]> Cypherpunks repositories - gostls13.git/commitdiff
exp/ssh: Have Wait() return an *ExitError
authorGustav Paul <gustav.paul@gmail.com>
Wed, 7 Dec 2011 14:58:22 +0000 (09:58 -0500)
committerAdam Langley <agl@golang.org>
Wed, 7 Dec 2011 14:58:22 +0000 (09:58 -0500)
I added the clientChan's msg channel to the list of channels that are closed in mainloop when the server sends a channelCloseMsg.

I added an ExitError type that wraps a Waitmsg similar to that of os/exec. I fill ExitStatus with the data returned in the 'exit-status' channel message and Msg with the data returned in the 'exit-signal' channel message.

Instead of having Wait() return on the first 'exit-status'/'exit-signal' I have it return an ExitError containing the status and signal when the clientChan's msg channel is closed.

I added two tests cases to session_test.go that test for exit status 0 (in which case Wait() returns nil) and exit status 1 (in which case Wait() returns an ExitError with ExitStatus 1)

R=dave, agl, rsc, golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/5452051

src/pkg/exp/ssh/client.go
src/pkg/exp/ssh/client_func_test.go
src/pkg/exp/ssh/session.go
src/pkg/exp/ssh/session_test.go

index d89b908cdcf47ddb952da9a35c5a05c3bccb96c1..0ce8bcaf4ffdd41aa41708be59fc080a1d37d220 100644 (file)
@@ -187,10 +187,10 @@ func (c *ClientConn) mainLoop() {
                if err != nil {
                        break
                }
-               // TODO(dfc) A note on blocking channel use. 
-               // The msg, win, data and dataExt channels of a clientChan can 
-               // cause this loop to block indefinately if the consumer does 
-               // not service them. 
+               // TODO(dfc) A note on blocking channel use.
+               // The msg, win, data and dataExt channels of a clientChan can
+               // cause this loop to block indefinately if the consumer does
+               // not service them.
                switch packet[0] {
                case msgChannelData:
                        if len(packet) < 9 {
@@ -211,7 +211,7 @@ func (c *ClientConn) mainLoop() {
                        datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
                        if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
                                packet = packet[13:]
-                               // RFC 4254 5.2 defines data_type_code 1 to be data destined 
+                               // RFC 4254 5.2 defines data_type_code 1 to be data destined
                                // for stderr on interactive sessions. Other data types are
                                // silently discarded.
                                if datatype == 1 {
@@ -231,9 +231,10 @@ func (c *ClientConn) mainLoop() {
                                close(ch.stdin.win)
                                close(ch.stdout.data)
                                close(ch.stderr.data)
+                               close(ch.msg)
                                c.chanlist.remove(msg.PeersId)
                        case *channelEOFMsg:
-                               c.getChan(msg.PeersId).msg <- msg
+                               c.getChan(msg.PeersId).sendEOF()
                        case *channelRequestSuccessMsg:
                                c.getChan(msg.PeersId).msg <- msg
                        case *channelRequestFailureMsg:
@@ -249,7 +250,7 @@ func (c *ClientConn) mainLoop() {
        }
 }
 
-// Dial connects to the given network address using net.Dial and 
+// Dial connects to the given network address using net.Dial and
 // then initiates a SSH handshake, returning the resulting client connection.
 func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
        conn, err := net.Dial(network, addr)
@@ -259,18 +260,18 @@ func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
        return Client(conn, config)
 }
 
-// A ClientConfig structure is used to configure a ClientConn. After one has 
+// A ClientConfig structure is used to configure a ClientConn. After one has
 // been passed to an SSH function it must not be modified.
 type ClientConfig struct {
-       // Rand provides the source of entropy for key exchange. If Rand is 
-       // nil, the cryptographic random reader in package crypto/rand will 
+       // Rand provides the source of entropy for key exchange. If Rand is
+       // nil, the cryptographic random reader in package crypto/rand will
        // be used.
        Rand io.Reader
 
        // The username to authenticate.
        User string
 
-       // A slice of ClientAuth methods. Only the first instance 
+       // A slice of ClientAuth methods. Only the first instance
        // of a particular RFC 4252 method will be used during authentication.
        Auth []ClientAuth
 
@@ -285,7 +286,7 @@ func (c *ClientConfig) rand() io.Reader {
        return c.Rand
 }
 
-// A clientChan represents a single RFC 4254 channel that is multiplexed 
+// A clientChan represents a single RFC 4254 channel that is multiplexed
 // over a single SSH connection.
 type clientChan struct {
        packetWriter
@@ -297,7 +298,7 @@ type clientChan struct {
 }
 
 // newClientChan returns a partially constructed *clientChan
-// using the local id provided. To be usable clientChan.peersId 
+// using the local id provided. To be usable clientChan.peersId
 // needs to be assigned once known.
 func newClientChan(t *transport, id uint32) *clientChan {
        c := &clientChan{
@@ -320,8 +321,8 @@ func newClientChan(t *transport, id uint32) *clientChan {
        return c
 }
 
-// waitForChannelOpenResponse, if successful, fills out 
-// the peerId and records any initial window advertisement. 
+// waitForChannelOpenResponse, if successful, fills out
+// the peerId and records any initial window advertisement.
 func (c *clientChan) waitForChannelOpenResponse() error {
        switch msg := (<-c.msg).(type) {
        case *channelOpenConfirmMsg:
@@ -335,6 +336,13 @@ func (c *clientChan) waitForChannelOpenResponse() error {
        return errors.New("unexpected packet")
 }
 
+// sendEOF Sends EOF to the server. RFC 4254 Section 5.3
+func (c *clientChan) sendEOF() error {
+       return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
+               PeersId: c.peersId,
+       }))
+}
+
 // Close closes the channel. This does not close the underlying connection.
 func (c *clientChan) Close() error {
        return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
index 137456095a06cc13265da077c751c17e13cbd78a..24e3a6334e50ef3e735594cc54b85c1b20d8ce24 100644 (file)
@@ -6,7 +6,7 @@ package ssh
 
 // ClientConn functional tests.
 // These tests require a running ssh server listening on port 22
-// on the local host. Functional tests will be skipped unless 
+// on the local host. Functional tests will be skipped unless
 // -ssh.user and -ssh.pass must be passed to gotest.
 
 import (
index 23ea18c29ae697e01f3651dfca0df7f3978b7f26..bf9a88e97efc39ad7fd46fbd39754b3ac470fae1 100644 (file)
@@ -34,6 +34,20 @@ const (
        SIGUSR2 Signal = "USR2"
 )
 
+var signals = map[Signal]int{
+       SIGABRT: 6,
+       SIGALRM: 14,
+       SIGFPE:  8,
+       SIGHUP:  1,
+       SIGILL:  4,
+       SIGINT:  2,
+       SIGKILL: 9,
+       SIGPIPE: 13,
+       SIGQUIT: 3,
+       SIGSEGV: 11,
+       SIGTERM: 15,
+}
+
 // A Session represents a connection to a remote command or shell.
 type Session struct {
        // Stdin specifies the remote process's standard input.
@@ -170,10 +184,17 @@ func (s *Session) Start(cmd string) error {
        return s.start()
 }
 
-// Run runs cmd on the remote host and waits for it to terminate.
-// Typically, the remote server passes cmd to the shell for
-// interpretation. A Session only accepts one call to Run,
-// Start or Shell.
+// Run runs cmd on the remote host. Typically, the remote
+// server passes cmd to the shell for interpretation.
+// A Session only accepts one call to Run, Start or Shell.
+//
+// The returned error is nil if the command runs, has no problems
+// copying stdin, stdout, and stderr, and exits with a zero exit
+// status.
+//
+// If the command fails to run or doesn't complete successfully, the
+// error is of type *ExitError. Other error types may be
+// returned for I/O problems.
 func (s *Session) Run(cmd string) error {
        err := s.Start(cmd)
        if err != nil {
@@ -233,6 +254,14 @@ func (s *Session) start() error {
 }
 
 // Wait waits for the remote command to exit.
+//
+// The returned error is nil if the command runs, has no problems
+// copying stdin, stdout, and stderr, and exits with a zero exit
+// status.
+//
+// If the command fails to run or doesn't complete successfully, the
+// error is of type *ExitError. Other error types may be
+// returned for I/O problems.
 func (s *Session) Wait() error {
        if !s.started {
                return errors.New("ssh: session not started")
@@ -255,21 +284,40 @@ func (s *Session) Wait() error {
 }
 
 func (s *Session) wait() error {
-       for {
-               switch msg := (<-s.msg).(type) {
+       wm := Waitmsg{status: -1}
+
+       // Wait for msg channel to be closed before returning.
+       for msg := range s.msg {
+               switch msg := msg.(type) {
                case *channelRequestMsg:
-                       // TODO(dfc) improve this behavior to match os.Waitmsg
                        switch msg.Request {
                        case "exit-status":
                                d := msg.RequestSpecificData
-                               status := int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3])
-                               if status > 0 {
-                                       return fmt.Errorf("remote process exited with %d", status)
-                               }
-                               return nil
+                               wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3])
                        case "exit-signal":
-                               // TODO(dfc) make a more readable error message
-                               return fmt.Errorf("%v", msg.RequestSpecificData)
+                               signal, rest, ok := parseString(msg.RequestSpecificData)
+                               if !ok {
+                                       return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData)
+                               }
+                               wm.signal = safeString(string(signal))
+
+                               // skip coreDumped bool
+                               if len(rest) == 0 {
+                                       return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData)
+                               }
+                               rest = rest[1:]
+
+                               errmsg, rest, ok := parseString(rest)
+                               if !ok {
+                                       return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData)
+                               }
+                               wm.msg = safeString(string(errmsg))
+
+                               lang, _, ok := parseString(rest)
+                               if !ok {
+                                       return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData)
+                               }
+                               wm.lang = safeString(string(lang))
                        default:
                                return fmt.Errorf("wait: unexpected channel request: %v", msg)
                        }
@@ -277,7 +325,20 @@ func (s *Session) wait() error {
                        return fmt.Errorf("wait: unexpected packet %T received: %v", msg, msg)
                }
        }
-       panic("unreachable")
+       if wm.status == 0 {
+               return nil
+       }
+       if wm.status == -1 {
+               // exit-status was never sent from server
+               if wm.signal == "" {
+                       return errors.New("wait: remote command exited without exit status or exit signal")
+               }
+               wm.status = 128
+               if _, ok := signals[Signal(wm.signal)]; ok {
+                       wm.status += signals[Signal(wm.signal)]
+               }
+       }
+       return &ExitError{wm}
 }
 
 func (s *Session) stdin() error {
@@ -391,3 +452,46 @@ func (c *ClientConn) NewSession() (*Session, error) {
                clientChan: ch,
        }, nil
 }
+
+// An ExitError reports unsuccessful completion of a remote command.
+type ExitError struct {
+       Waitmsg
+}
+
+func (e *ExitError) Error() string {
+       return e.Waitmsg.String()
+}
+
+// Waitmsg stores the information about an exited remote command
+// as reported by Wait.
+type Waitmsg struct {
+       status int
+       signal string
+       msg    string
+       lang   string
+}
+
+// ExitStatus returns the exit status of the remote command.
+func (w Waitmsg) ExitStatus() int {
+       return w.status
+}
+
+// Signal returns the exit signal of the remote command if
+// it was terminated violently.
+func (w Waitmsg) Signal() string {
+       return w.signal
+}
+
+// Msg returns the exit message given by the remote command
+func (w Waitmsg) Msg() string {
+       return w.msg
+}
+
+// Lang returns the language tag. See RFC 3066
+func (w Waitmsg) Lang() string {
+       return w.lang
+}
+
+func (w Waitmsg) String() string {
+       return fmt.Sprintf("Process exited with: %v. Reason was: %v (%v)", w.status, w.msg, w.signal)
+}
index d4818c29f7051a8bdf6318a02c1f5e48f0f199e1..a28ead087369dfa7007b7d65d660753b1f7d79ef 100644 (file)
@@ -12,8 +12,10 @@ import (
        "testing"
 )
 
+type serverType func(*channel)
+
 // dial constructs a new test server and returns a *ClientConn.
-func dial(t *testing.T) *ClientConn {
+func dial(handler serverType, t *testing.T) *ClientConn {
        pw := password("tiger")
        serverConfig.PasswordCallback = func(user, pass string) bool {
                return user == "testuser" && pass == string(pw)
@@ -50,27 +52,7 @@ func dial(t *testing.T) *ClientConn {
                                continue
                        }
                        ch.Accept()
-                       go func() {
-                               defer ch.Close()
-                               // this string is returned to stdout
-                               shell := NewServerShell(ch, "golang")
-                               shell.ReadLine()
-                               type exitMsg struct {
-                                       PeersId   uint32
-                                       Request   string
-                                       WantReply bool
-                                       Status    uint32
-                               }
-                               // TODO(dfc) converting to the concrete type should not be
-                               // necessary to send a packet.
-                               msg := exitMsg{
-                                       PeersId:   ch.(*channel).theirId,
-                                       Request:   "exit-status",
-                                       WantReply: false,
-                                       Status:    0,
-                               }
-                               ch.(*channel).serverConn.writePacket(marshal(msgChannelRequest, msg))
-                       }()
+                       go handler(ch.(*channel))
                }
                t.Log("done")
        }()
@@ -91,7 +73,7 @@ func dial(t *testing.T) *ClientConn {
 
 // Test a simple string is returned to session.Stdout.
 func TestSessionShell(t *testing.T) {
-       conn := dial(t)
+       conn := dial(shellHandler, t)
        defer conn.Close()
        session, err := conn.NewSession()
        if err != nil {
@@ -116,7 +98,7 @@ func TestSessionShell(t *testing.T) {
 
 // Test a simple string is returned via StdoutPipe.
 func TestSessionStdoutPipe(t *testing.T) {
-       conn := dial(t)
+       conn := dial(shellHandler, t)
        defer conn.Close()
        session, err := conn.NewSession()
        if err != nil {
@@ -147,3 +129,237 @@ func TestSessionStdoutPipe(t *testing.T) {
                t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
        }
 }
+
+// Test non-0 exit status is returned correctly.
+func TestExitStatusNonZero(t *testing.T) {
+       conn := dial(exitStatusNonZeroHandler, t)
+       defer conn.Close()
+       session, err := conn.NewSession()
+       if err != nil {
+               t.Fatalf("Unable to request new session: %s", err)
+       }
+       defer session.Close()
+       if err := session.Shell(); err != nil {
+               t.Fatalf("Unable to execute command: %s", err)
+       }
+       err = session.Wait()
+       if err == nil {
+               t.Fatalf("expected command to fail but it didn't")
+       }
+       e, ok := err.(*ExitError)
+       if !ok {
+               t.Fatalf("expected *ExitError but got %T", err)
+       }
+       if e.ExitStatus() != 15 {
+               t.Fatalf("expected command to exit with 15 but got %s", e.ExitStatus())
+       }
+}
+
+// Test 0 exit status is returned correctly.
+func TestExitStatusZero(t *testing.T) {
+       conn := dial(exitStatusZeroHandler, t)
+       defer conn.Close()
+       session, err := conn.NewSession()
+       if err != nil {
+               t.Fatalf("Unable to request new session: %s", err)
+       }
+       defer session.Close()
+
+       if err := session.Shell(); err != nil {
+               t.Fatalf("Unable to execute command: %s", err)
+       }
+       err = session.Wait()
+       if err != nil {
+               t.Fatalf("expected nil but got %s", err)
+       }
+}
+
+// Test exit signal and status are both returned correctly.
+func TestExitSignalAndStatus(t *testing.T) {
+       conn := dial(exitSignalAndStatusHandler, t)
+       defer conn.Close()
+       session, err := conn.NewSession()
+       if err != nil {
+               t.Fatalf("Unable to request new session: %s", err)
+       }
+       defer session.Close()
+       if err := session.Shell(); err != nil {
+               t.Fatalf("Unable to execute command: %s", err)
+       }
+       err = session.Wait()
+       if err == nil {
+               t.Fatalf("expected command to fail but it didn't")
+       }
+       e, ok := err.(*ExitError)
+       if !ok {
+               t.Fatalf("expected *ExitError but got %T", err)
+       }
+       if e.Signal() != "TERM" || e.ExitStatus() != 15 {
+               t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus())
+       }
+}
+
+// Test exit signal and status are both returned correctly.
+func TestKnownExitSignalOnly(t *testing.T) {
+       conn := dial(exitSignalHandler, t)
+       defer conn.Close()
+       session, err := conn.NewSession()
+       if err != nil {
+               t.Fatalf("Unable to request new session: %s", err)
+       }
+       defer session.Close()
+       if err := session.Shell(); err != nil {
+               t.Fatalf("Unable to execute command: %s", err)
+       }
+       err = session.Wait()
+       if err == nil {
+               t.Fatalf("expected command to fail but it didn't")
+       }
+       e, ok := err.(*ExitError)
+       if !ok {
+               t.Fatalf("expected *ExitError but got %T", err)
+       }
+       if e.Signal() != "TERM" || e.ExitStatus() != 143 {
+               t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus())
+       }
+}
+
+// Test exit signal and status are both returned correctly.
+func TestUnknownExitSignal(t *testing.T) {
+       conn := dial(exitSignalUnknownHandler, t)
+       defer conn.Close()
+       session, err := conn.NewSession()
+       if err != nil {
+               t.Fatalf("Unable to request new session: %s", err)
+       }
+       defer session.Close()
+       if err := session.Shell(); err != nil {
+               t.Fatalf("Unable to execute command: %s", err)
+       }
+       err = session.Wait()
+       if err == nil {
+               t.Fatalf("expected command to fail but it didn't")
+       }
+       e, ok := err.(*ExitError)
+       if !ok {
+               t.Fatalf("expected *ExitError but got %T", err)
+       }
+       if e.Signal() != "SYS" || e.ExitStatus() != 128 {
+               t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus())
+       }
+}
+
+// Test WaitMsg is not returned if the channel closes abruptly.
+func TestExitWithoutStatusOrSignal(t *testing.T) {
+       conn := dial(exitWithoutSignalOrStatus, t)
+       defer conn.Close()
+       session, err := conn.NewSession()
+       if err != nil {
+               t.Fatalf("Unable to request new session: %s", err)
+       }
+       defer session.Close()
+       if err := session.Shell(); err != nil {
+               t.Fatalf("Unable to execute command: %s", err)
+       }
+       err = session.Wait()
+       if err == nil {
+               t.Fatalf("expected command to fail but it didn't")
+       }
+       _, ok := err.(*ExitError)
+       if ok {
+               // you can't actually test for errors.errorString
+               // because it's not exported.
+               t.Fatalf("expected *errorString but got %T", err)
+       }
+}
+
+type exitStatusMsg struct {
+       PeersId   uint32
+       Request   string
+       WantReply bool
+       Status    uint32
+}
+
+type exitSignalMsg struct {
+       PeersId    uint32
+       Request    string
+       WantReply  bool
+       Signal     string
+       CoreDumped bool
+       Errmsg     string
+       Lang       string
+}
+
+func exitStatusZeroHandler(ch *channel) {
+       defer ch.Close()
+       // this string is returned to stdout
+       shell := NewServerShell(ch, "> ")
+       shell.ReadLine()
+       sendStatus(0, ch)
+}
+
+func exitStatusNonZeroHandler(ch *channel) {
+       defer ch.Close()
+       shell := NewServerShell(ch, "> ")
+       shell.ReadLine()
+       sendStatus(15, ch)
+}
+
+func exitSignalAndStatusHandler(ch *channel) {
+       defer ch.Close()
+       shell := NewServerShell(ch, "> ")
+       shell.ReadLine()
+       sendStatus(15, ch)
+       sendSignal("TERM", ch)
+}
+
+func exitSignalHandler(ch *channel) {
+       defer ch.Close()
+       shell := NewServerShell(ch, "> ")
+       shell.ReadLine()
+       sendSignal("TERM", ch)
+}
+
+func exitSignalUnknownHandler(ch *channel) {
+       defer ch.Close()
+       shell := NewServerShell(ch, "> ")
+       shell.ReadLine()
+       sendSignal("SYS", ch)
+}
+
+func exitWithoutSignalOrStatus(ch *channel) {
+       defer ch.Close()
+       shell := NewServerShell(ch, "> ")
+       shell.ReadLine()
+}
+
+func shellHandler(ch *channel) {
+       defer ch.Close()
+       // this string is returned to stdout
+       shell := NewServerShell(ch, "golang")
+       shell.ReadLine()
+       sendStatus(0, ch)
+}
+
+func sendStatus(status uint32, ch *channel) {
+       msg := exitStatusMsg{
+               PeersId:   ch.theirId,
+               Request:   "exit-status",
+               WantReply: false,
+               Status:    status,
+       }
+       ch.serverConn.writePacket(marshal(msgChannelRequest, msg))
+}
+
+func sendSignal(signal string, ch *channel) {
+       sig := exitSignalMsg{
+               PeersId:    ch.theirId,
+               Request:    "exit-signal",
+               WantReply:  false,
+               Signal:     signal,
+               CoreDumped: false,
+               Errmsg:     "Process terminated",
+               Lang:       "en-GB-oed",
+       }
+       ch.serverConn.writePacket(marshal(msgChannelRequest, sig))
+}