]> Cypherpunks repositories - gostls13.git/commitdiff
exp/ssh: simplify client channel open logic
authorDave Cheney <dave@cheney.net>
Tue, 6 Dec 2011 14:33:23 +0000 (09:33 -0500)
committerAdam Langley <agl@golang.org>
Tue, 6 Dec 2011 14:33:23 +0000 (09:33 -0500)
This is part one of a small set of CL's that aim to resolve
the outstanding TODOs relating to channel close and blocking
behavior.

Firstly, the hairy handling of assigning the peersId is now
done in one place. The cost of this change is the slightly
paradoxical construction of the partially created clientChan.

Secondly, by creating clientChan.stdin/out/err when the channel
is opened, the creation of consumers like tcpchan and Session
is simplified; they just have to wire themselves up to the
relevant readers/writers.

R=agl, gustav.paul, rsc
CC=golang-dev
https://golang.org/cl/5448073

src/pkg/exp/ssh/client.go
src/pkg/exp/ssh/session.go
src/pkg/exp/ssh/tcpip.go

index 429dee975bce754d9876a5b1ea12476bca4837ee..d89b908cdcf47ddb952da9a35c5a05c3bccb96c1 100644 (file)
@@ -200,7 +200,7 @@ func (c *ClientConn) mainLoop() {
                        peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
                        if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
                                packet = packet[9:]
-                               c.getChan(peersId).data <- packet[:length]
+                               c.getChan(peersId).stdout.data <- packet[:length]
                        }
                case msgChannelExtendedData:
                        if len(packet) < 13 {
@@ -215,7 +215,7 @@ func (c *ClientConn) mainLoop() {
                                // for stderr on interactive sessions. Other data types are
                                // silently discarded.
                                if datatype == 1 {
-                                       c.getChan(peersId).dataExt <- packet[:length]
+                                       c.getChan(peersId).stderr.data <- packet[:length]
                                }
                        }
                default:
@@ -228,9 +228,9 @@ func (c *ClientConn) mainLoop() {
                                c.getChan(msg.PeersId).msg <- msg
                        case *channelCloseMsg:
                                ch := c.getChan(msg.PeersId)
-                               close(ch.win)
-                               close(ch.data)
-                               close(ch.dataExt)
+                               close(ch.stdin.win)
+                               close(ch.stdout.data)
+                               close(ch.stderr.data)
                                c.chanlist.remove(msg.PeersId)
                        case *channelEOFMsg:
                                c.getChan(msg.PeersId).msg <- msg
@@ -241,7 +241,7 @@ func (c *ClientConn) mainLoop() {
                        case *channelRequestMsg:
                                c.getChan(msg.PeersId).msg <- msg
                        case *windowAdjustMsg:
-                               c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes)
+                               c.getChan(msg.PeersId).stdin.win <- int(msg.AdditionalBytes)
                        default:
                                fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
                        }
@@ -290,21 +290,49 @@ func (c *ClientConfig) rand() io.Reader {
 type clientChan struct {
        packetWriter
        id, peersId uint32
-       data        chan []byte      // receives the payload of channelData messages
-       dataExt     chan []byte      // receives the payload of channelExtendedData messages
-       win         chan int         // receives window adjustments
+       stdin       *chanWriter      // receives window adjustments
+       stdout      *chanReader      // receives the payload of channelData messages
+       stderr      *chanReader      // receives the payload of channelExtendedData messages
        msg         chan interface{} // incoming messages
 }
 
+// newClientChan returns a partially constructed *clientChan
+// using the local id provided. To be usable clientChan.peersId 
+// needs to be assigned once known.
 func newClientChan(t *transport, id uint32) *clientChan {
-       return &clientChan{
+       c := &clientChan{
                packetWriter: t,
                id:           id,
-               data:         make(chan []byte, 16),
-               dataExt:      make(chan []byte, 16),
-               win:          make(chan int, 16),
                msg:          make(chan interface{}, 16),
        }
+       c.stdin = &chanWriter{
+               win:        make(chan int, 16),
+               clientChan: c,
+       }
+       c.stdout = &chanReader{
+               data:       make(chan []byte, 16),
+               clientChan: c,
+       }
+       c.stderr = &chanReader{
+               data:       make(chan []byte, 16),
+               clientChan: c,
+       }
+       return c
+}
+
+// waitForChannelOpenResponse, if successful, fills out 
+// the peerId and records any initial window advertisement. 
+func (c *clientChan) waitForChannelOpenResponse() error {
+       switch msg := (<-c.msg).(type) {
+       case *channelOpenConfirmMsg:
+               // fixup peersId field
+               c.peersId = msg.MyId
+               c.stdin.win <- int(msg.MyWindow)
+               return nil
+       case *channelOpenFailureMsg:
+               return errors.New(safeString(msg.Message))
+       }
+       return errors.New("unexpected packet")
 }
 
 // Close closes the channel. This does not close the underlying connection.
@@ -355,10 +383,9 @@ func (c *chanlist) remove(id uint32) {
 
 // A chanWriter represents the stdin of a remote process.
 type chanWriter struct {
-       win          chan int // receives window adjustments
-       peersId      uint32   // the peer's id
-       rwin         int      // current rwin size
-       packetWriter          // for sending channelDataMsg
+       win        chan int    // receives window adjustments
+       rwin       int         // current rwin size
+       clientChan *clientChan // the channel backing this writer
 }
 
 // Write writes data to the remote process's standard input.
@@ -372,12 +399,13 @@ func (w *chanWriter) Write(data []byte) (n int, err error) {
                        w.rwin += win
                        continue
                }
+               peersId := w.clientChan.peersId
                n = len(data)
                packet := make([]byte, 0, 9+n)
                packet = append(packet, msgChannelData,
-                       byte(w.peersId>>24), byte(w.peersId>>16), byte(w.peersId>>8), byte(w.peersId),
+                       byte(peersId>>24), byte(peersId>>16), byte(peersId>>8), byte(peersId),
                        byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
-               err = w.writePacket(append(packet, data...))
+               err = w.clientChan.writePacket(append(packet, data...))
                w.rwin -= n
                return
        }
@@ -385,7 +413,7 @@ func (w *chanWriter) Write(data []byte) (n int, err error) {
 }
 
 func (w *chanWriter) Close() error {
-       return w.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.peersId}))
+       return w.clientChan.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.clientChan.peersId}))
 }
 
 // A chanReader represents stdout or stderr of a remote process.
@@ -393,10 +421,9 @@ type chanReader struct {
        // TODO(dfc) a fixed size channel may not be the right data structure.
        // If writes to this channel block, they will block mainLoop, making
        // it unable to receive new messages from the remote side.
-       data         chan []byte // receives data from remote
-       peersId      uint32      // the peer's id
-       packetWriter             // for sending windowAdjustMsg
-       buf          []byte
+       data       chan []byte // receives data from remote
+       clientChan *clientChan // the channel backing this reader
+       buf        []byte
 }
 
 // Read reads data from the remote process's stdout or stderr.
@@ -407,10 +434,10 @@ func (r *chanReader) Read(data []byte) (int, error) {
                        n := copy(data, r.buf)
                        r.buf = r.buf[n:]
                        msg := windowAdjustMsg{
-                               PeersId:         r.peersId,
+                               PeersId:         r.clientChan.peersId,
                                AdditionalBytes: uint32(n),
                        }
-                       return n, r.writePacket(marshal(msgChannelWindowAdjust, msg))
+                       return n, r.clientChan.writePacket(marshal(msgChannelWindowAdjust, msg))
                }
                r.buf, ok = <-r.data
                if !ok {
index 5f98a8d58c683956b701cef5df1c1c9f9ace1fa9..23ea18c29ae697e01f3651dfca0df7f3978b7f26 100644 (file)
@@ -285,13 +285,8 @@ func (s *Session) stdin() error {
                s.Stdin = new(bytes.Buffer)
        }
        s.copyFuncs = append(s.copyFuncs, func() error {
-               w := &chanWriter{
-                       packetWriter: s,
-                       peersId:      s.peersId,
-                       win:          s.win,
-               }
-               _, err := io.Copy(w, s.Stdin)
-               if err1 := w.Close(); err == nil {
+               _, err := io.Copy(s.clientChan.stdin, s.Stdin)
+               if err1 := s.clientChan.stdin.Close(); err == nil {
                        err = err1
                }
                return err
@@ -304,12 +299,7 @@ func (s *Session) stdout() error {
                s.Stdout = ioutil.Discard
        }
        s.copyFuncs = append(s.copyFuncs, func() error {
-               r := &chanReader{
-                       packetWriter: s,
-                       peersId:      s.peersId,
-                       data:         s.data,
-               }
-               _, err := io.Copy(s.Stdout, r)
+               _, err := io.Copy(s.Stdout, s.clientChan.stdout)
                return err
        })
        return nil
@@ -320,12 +310,7 @@ func (s *Session) stderr() error {
                s.Stderr = ioutil.Discard
        }
        s.copyFuncs = append(s.copyFuncs, func() error {
-               r := &chanReader{
-                       packetWriter: s,
-                       peersId:      s.peersId,
-                       data:         s.dataExt,
-               }
-               _, err := io.Copy(s.Stderr, r)
+               _, err := io.Copy(s.Stderr, s.clientChan.stderr)
                return err
        })
        return nil
@@ -398,19 +383,11 @@ func (c *ClientConn) NewSession() (*Session, error) {
                c.chanlist.remove(ch.id)
                return nil, err
        }
-       // wait for response
-       msg := <-ch.msg
-       switch msg := msg.(type) {
-       case *channelOpenConfirmMsg:
-               ch.peersId = msg.MyId
-               ch.win <- int(msg.MyWindow)
-               return &Session{
-                       clientChan: ch,
-               }, nil
-       case *channelOpenFailureMsg:
+       if err := ch.waitForChannelOpenResponse(); err != nil {
                c.chanlist.remove(ch.id)
-               return nil, fmt.Errorf("ssh: channel open failed: %s", msg.Message)
+               return nil, fmt.Errorf("ssh: unable to open session: %v", err)
        }
-       c.chanlist.remove(ch.id)
-       return nil, fmt.Errorf("ssh: unexpected message %T: %v", msg, msg)
+       return &Session{
+               clientChan: ch,
+       }, nil
 }
index f3bbac5d19e1071a07b8e86029692b96b0bfb13b..a85044ace9cecc2c6ee2b3b29ae292de5590989b 100644 (file)
@@ -6,6 +6,7 @@ package ssh
 
 import (
        "errors"
+       "fmt"
        "io"
        "net"
 )
@@ -42,20 +43,21 @@ func (c *ClientConn) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, err
        }, nil
 }
 
+// RFC 4254 7.2
+type channelOpenDirectMsg struct {
+       ChanType      string
+       PeersId       uint32
+       PeersWindow   uint32
+       MaxPacketSize uint32
+       raddr         string
+       rport         uint32
+       laddr         string
+       lport         uint32
+}
+
 // dial opens a direct-tcpip connection to the remote server. laddr and raddr are passed as
 // strings and are expected to be resolveable at the remote end.
 func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpchan, error) {
-       // RFC 4254 7.2
-       type channelOpenDirectMsg struct {
-               ChanType      string
-               PeersId       uint32
-               PeersWindow   uint32
-               MaxPacketSize uint32
-               raddr         string
-               rport         uint32
-               laddr         string
-               lport         uint32
-       }
        ch := c.newChan(c.transport)
        if err := c.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
                ChanType:      "direct-tcpip",
@@ -70,30 +72,14 @@ func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tc
                c.chanlist.remove(ch.id)
                return nil, err
        }
-       // wait for response
-       switch msg := (<-ch.msg).(type) {
-       case *channelOpenConfirmMsg:
-               ch.peersId = msg.MyId
-               ch.win <- int(msg.MyWindow)
-       case *channelOpenFailureMsg:
-               c.chanlist.remove(ch.id)
-               return nil, errors.New("ssh: error opening remote TCP connection: " + msg.Message)
-       default:
+       if err := ch.waitForChannelOpenResponse(); err != nil {
                c.chanlist.remove(ch.id)
-               return nil, errors.New("ssh: unexpected packet")
+               return nil, fmt.Errorf("ssh: unable to open direct tcpip connection: %v", err)
        }
        return &tcpchan{
                clientChan: ch,
-               Reader: &chanReader{
-                       packetWriter: ch,
-                       peersId:      ch.peersId,
-                       data:         ch.data,
-               },
-               Writer: &chanWriter{
-                       packetWriter: ch,
-                       peersId:      ch.peersId,
-                       win:          ch.win,
-               },
+               Reader:     ch.stdout,
+               Writer:     ch.stdin,
        }, nil
 }