]> Cypherpunks repositories - gostls13.git/commitdiff
exp/ssh: alter Session to match the exec.Cmd API
authorDave Cheney <dave@cheney.net>
Sun, 20 Nov 2011 16:46:35 +0000 (11:46 -0500)
committerAdam Langley <agl@golang.org>
Sun, 20 Nov 2011 16:46:35 +0000 (11:46 -0500)
This CL inverts the direction of the Stdin/out/err members of the
Session struct so they reflect the API of the exec.Cmd. In doing so
it borrows heavily from the exec package.

Additionally Shell now returns immediately, wait for completion using
Wait. Exec calls Wait internally and so blocks until the remote
command is complete.

Credit to Gustavo Niemeyer for the impetus for this CL.

R=rsc, agl, n13m3y3r, huin, bradfitz
CC=cw, golang-dev
https://golang.org/cl/5322055

src/pkg/exp/ssh/client.go
src/pkg/exp/ssh/session.go

index 24569ad9389bd65d6c026a208606660ecd75d032..9721723488bbbcc7d08a25ee6c68e0d53397f5b0 100644 (file)
@@ -342,17 +342,6 @@ func (c *clientChan) Close() error {
        }))
 }
 
-func (c *clientChan) sendChanReq(req channelRequestMsg) error {
-       if err := c.writePacket(marshal(msgChannelRequest, req)); err != nil {
-               return err
-       }
-       msg := <-c.msg
-       if _, ok := msg.(*channelRequestSuccessMsg); ok {
-               return nil
-       }
-       return fmt.Errorf("failed to complete request: %s, %#v", req.Request, msg)
-}
-
 // Thread safe channel list.
 type chanlist struct {
        // protects concurrent access to chans
index 77154f2c3c30f544dbc0ac5764b0d3d26377f8a9..181a89688396ec6f48626300ec95946c00ead890 100644 (file)
@@ -8,66 +8,104 @@ package ssh
 // "RFC 4254, section 6".
 
 import (
-       "encoding/binary"
+       "bytes"
        "errors"
+       "fmt"
        "io"
+       "io/ioutil"
 )
 
 // A Session represents a connection to a remote command or shell.
 type Session struct {
-       // Writes to Stdin are made available to the remote command's standard input.
-       // Closing Stdin causes the command to observe an EOF on its standard input.
-       Stdin io.WriteCloser
-
-       // Reads from Stdout and Stderr consume from the remote command's standard
-       // output and error streams, respectively.
-       // There is a fixed amount of buffering that is shared for the two streams.
-       // Failing to read from either may eventually cause the command to block.
-       // Closing Stdout unblocks such writes and causes them to return errors.
-       Stdout io.ReadCloser
-       Stderr io.Reader
+       // Stdin specifies the remote process's standard input.
+       // If Stdin is nil, the remote process reads from an empty 
+       // bytes.Buffer.
+       Stdin io.Reader
+
+       // Stdout and Stderr specify the remote process's standard 
+       // output and error.
+       //
+       // If either is nil, Run connects the corresponding file 
+       // descriptor to an instance of ioutil.Discard. There is a 
+       // fixed amount of buffering that is shared for the two streams. 
+       // If either blocks it may eventually cause the remote 
+       // command to block.
+       Stdout io.Writer
+       Stderr io.Writer
 
        *clientChan // the channel backing this session
 
-       started bool // started is set to true once a Shell or Exec is invoked.
+       started   bool // true once a Shell or Exec is invoked.
+       copyFuncs []func() error
+       errch     chan error // one send per copyFunc
+}
+
+// RFC 4254 Section 6.4.
+type setenvRequest struct {
+       PeersId   uint32
+       Request   string
+       WantReply bool
+       Name      string
+       Value     string
 }
 
 // Setenv sets an environment variable that will be applied to any
 // command executed by Shell or Exec.
 func (s *Session) Setenv(name, value string) error {
-       n, v := []byte(name), []byte(value)
-       nlen, vlen := stringLength(n), stringLength(v)
-       payload := make([]byte, nlen+vlen)
-       marshalString(payload[:nlen], n)
-       marshalString(payload[nlen:], v)
-
-       return s.sendChanReq(channelRequestMsg{
-               PeersId:             s.id,
-               Request:             "env",
-               WantReply:           true,
-               RequestSpecificData: payload,
-       })
+       req := setenvRequest{
+               PeersId:   s.id,
+               Request:   "env",
+               WantReply: true,
+               Name:      name,
+               Value:     value,
+       }
+       if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
+               return err
+       }
+       return s.waitForResponse()
 }
 
-// An empty mode list (a string of 1 character, opcode 0), see RFC 4254 Section 8.
-var emptyModeList = []byte{0, 0, 0, 1, 0}
+// An empty mode list, see RFC 4254 Section 8.
+var emptyModelist = "\x00"
+
+// RFC 4254 Section 6.2.
+type ptyRequestMsg struct {
+       PeersId   uint32
+       Request   string
+       WantReply bool
+       Term      string
+       Columns   uint32
+       Rows      uint32
+       Width     uint32
+       Height    uint32
+       Modelist  string
+}
 
 // RequestPty requests the association of a pty with the session on the remote host.
 func (s *Session) RequestPty(term string, h, w int) error {
-       buf := make([]byte, 4+len(term)+16+len(emptyModeList))
-       b := marshalString(buf, []byte(term))
-       binary.BigEndian.PutUint32(b, uint32(h))
-       binary.BigEndian.PutUint32(b[4:], uint32(w))
-       binary.BigEndian.PutUint32(b[8:], uint32(h*8))
-       binary.BigEndian.PutUint32(b[12:], uint32(w*8))
-       copy(b[16:], emptyModeList)
-
-       return s.sendChanReq(channelRequestMsg{
-               PeersId:             s.id,
-               Request:             "pty-req",
-               WantReply:           true,
-               RequestSpecificData: buf,
-       })
+       req := ptyRequestMsg{
+               PeersId:   s.id,
+               Request:   "pty-req",
+               WantReply: true,
+               Term:      term,
+               Columns:   uint32(w),
+               Rows:      uint32(h),
+               Width:     uint32(w * 8),
+               Height:    uint32(h * 8),
+               Modelist:  emptyModelist,
+       }
+       if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
+               return err
+       }
+       return s.waitForResponse()
+}
+
+// RFC 4254 Section 6.5.
+type execMsg struct {
+       PeersId   uint32
+       Request   string
+       WantReply bool
+       Command   string
 }
 
 // Exec runs cmd on the remote host. Typically, the remote 
@@ -75,34 +113,166 @@ func (s *Session) RequestPty(term string, h, w int) error {
 // A Session only accepts one call to Exec or Shell.
 func (s *Session) Exec(cmd string) error {
        if s.started {
-               return errors.New("session already started")
+               return errors.New("ssh: session already started")
        }
-       cmdLen := stringLength([]byte(cmd))
-       payload := make([]byte, cmdLen)
-       marshalString(payload, []byte(cmd))
-       s.started = true
-
-       return s.sendChanReq(channelRequestMsg{
-               PeersId:             s.id,
-               Request:             "exec",
-               WantReply:           true,
-               RequestSpecificData: payload,
-       })
+       req := execMsg{
+               PeersId:   s.id,
+               Request:   "exec",
+               WantReply: true,
+               Command:   cmd,
+       }
+       if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
+               return err
+       }
+       if err := s.waitForResponse(); err != nil {
+               return fmt.Errorf("ssh: could not execute command %s: %v", cmd, err)
+       }
+       if err := s.start(); err != nil {
+               return err
+       }
+       return s.Wait()
 }
 
 // Shell starts a login shell on the remote host. A Session only 
 // accepts one call to Exec or Shell.
 func (s *Session) Shell() error {
        if s.started {
-               return errors.New("session already started")
+               return errors.New("ssh: session already started")
        }
-       s.started = true
-
-       return s.sendChanReq(channelRequestMsg{
+       req := channelRequestMsg{
                PeersId:   s.id,
                Request:   "shell",
                WantReply: true,
+       }
+       if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
+               return err
+       }
+       if err := s.waitForResponse(); err != nil {
+               return fmt.Errorf("ssh: cound not execute shell: %v", err)
+       }
+       return s.start()
+}
+
+func (s *Session) waitForResponse() error {
+       msg := <-s.msg
+       switch msg.(type) {
+       case *channelRequestSuccessMsg:
+               return nil
+       case *channelRequestFailureMsg:
+               return errors.New("request failed")
+       }
+       return fmt.Errorf("unknown packet %T received: %v", msg, msg)
+}
+
+func (s *Session) start() error {
+       s.started = true
+
+       type F func(*Session) error
+       for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} {
+               if err := setupFd(s); err != nil {
+                       return err
+               }
+       }
+
+       s.errch = make(chan error, len(s.copyFuncs))
+       for _, fn := range s.copyFuncs {
+               go func(fn func() error) {
+                       s.errch <- fn()
+               }(fn)
+       }
+       return nil
+}
+
+// Wait waits for the remote command to exit. 
+func (s *Session) Wait() error {
+       if !s.started {
+               return errors.New("ssh: session not started")
+       }
+       waitErr := s.wait()
+
+       var copyError error
+       for _ = range s.copyFuncs {
+               if err := <-s.errch; err != nil && copyError == nil {
+                       copyError = err
+               }
+       }
+
+       if waitErr != nil {
+               return waitErr
+       }
+
+       return copyError
+}
+
+func (s *Session) wait() error {
+       for {
+               switch msg := (<-s.msg).(type) {
+               case *channelRequestMsg:
+                       // TODO(dfc) improve this behavior to match os.Waitmsg
+                       switch msg.Request {
+                       case "exit-status":
+                               d := msg.RequestSpecificData
+                               status := int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3])
+                               if status > 0 {
+                                       return fmt.Errorf("remote process exited with %d", status)
+                               }
+                               return nil
+                       case "exit-signal":
+                               // TODO(dfc) make a more readable error message
+                               return fmt.Errorf("%v", msg.RequestSpecificData)
+                       default:
+                               return fmt.Errorf("wait: unexpected channel request: %v", msg)
+                       }
+               default:
+                       return fmt.Errorf("wait: unexpected packet %T received: %v", msg, msg)
+               }
+       }
+       panic("unreachable")
+}
+
+func (s *Session) stdin() error {
+       if s.Stdin == nil {
+               s.Stdin = new(bytes.Buffer)
+       }
+       s.copyFuncs = append(s.copyFuncs, func() error {
+               _, err := io.Copy(&chanWriter{
+                       packetWriter: s,
+                       id:           s.id,
+                       win:          s.win,
+               }, s.Stdin)
+               return err
+       })
+       return nil
+}
+
+func (s *Session) stdout() error {
+       if s.Stdout == nil {
+               s.Stdout = ioutil.Discard
+       }
+       s.copyFuncs = append(s.copyFuncs, func() error {
+               _, err := io.Copy(s.Stdout, &chanReader{
+                       packetWriter: s,
+                       id:           s.id,
+                       data:         s.data,
+               })
+               return err
+       })
+       return nil
+}
+
+func (s *Session) stderr() error {
+       if s.Stderr == nil {
+               s.Stderr = ioutil.Discard
+       }
+       s.copyFuncs = append(s.copyFuncs, func() error {
+               _, err := io.Copy(s.Stderr, &chanReader{
+                       packetWriter: s,
+                       id:           s.id,
+                       data:         s.dataExt,
+               })
+               return err
        })
+       return nil
 }
 
 // NewSession returns a new interactive session on the remote host.
@@ -112,21 +282,6 @@ func (c *ClientConn) NewSession() (*Session, error) {
                return nil, err
        }
        return &Session{
-               Stdin: &chanWriter{
-                       packetWriter: ch,
-                       id:           ch.id,
-                       win:          ch.win,
-               },
-               Stdout: &chanReader{
-                       packetWriter: ch,
-                       id:           ch.id,
-                       data:         ch.data,
-               },
-               Stderr: &chanReader{
-                       packetWriter: ch,
-                       id:           ch.id,
-                       data:         ch.dataExt,
-               },
                clientChan: ch,
        }, nil
 }