]> Cypherpunks repositories - gostls13.git/commitdiff
exp/ssh: improve client channel close behavior
authorDave Cheney <dave@cheney.net>
Tue, 13 Dec 2011 15:27:17 +0000 (10:27 -0500)
committerAdam Langley <agl@golang.org>
Tue, 13 Dec 2011 15:27:17 +0000 (10:27 -0500)
R=gustav.paul
CC=golang-dev
https://golang.org/cl/5480062

src/pkg/exp/ssh/client.go

index 0ce8bcaf4ffdd41aa41708be59fc080a1d37d220..7c862078b7e95a46cdfd19149a6063b31c0c5663 100644 (file)
@@ -200,7 +200,7 @@ func (c *ClientConn) mainLoop() {
                        peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
                        if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
                                packet = packet[9:]
-                               c.getChan(peersId).stdout.data <- packet[:length]
+                               c.getChan(peersId).stdout.handleData(packet[:length])
                        }
                case msgChannelExtendedData:
                        if len(packet) < 13 {
@@ -215,7 +215,7 @@ func (c *ClientConn) mainLoop() {
                                // for stderr on interactive sessions. Other data types are
                                // silently discarded.
                                if datatype == 1 {
-                                       c.getChan(peersId).stderr.data <- packet[:length]
+                                       c.getChan(peersId).stderr.handleData(packet[:length])
                                }
                        }
                default:
@@ -228,13 +228,22 @@ func (c *ClientConn) mainLoop() {
                                c.getChan(msg.PeersId).msg <- msg
                        case *channelCloseMsg:
                                ch := c.getChan(msg.PeersId)
+                               ch.theyClosed = true
                                close(ch.stdin.win)
-                               close(ch.stdout.data)
-                               close(ch.stderr.data)
+                               ch.stdout.eof()
+                               ch.stderr.eof()
                                close(ch.msg)
+                               if !ch.weClosed {
+                                       ch.weClosed = true
+                                       ch.sendClose()
+                               }
                                c.chanlist.remove(msg.PeersId)
                        case *channelEOFMsg:
-                               c.getChan(msg.PeersId).sendEOF()
+                               ch := c.getChan(msg.PeersId)
+                               ch.stdout.eof()
+                               // RFC 4254 is mute on how EOF affects dataExt messages but
+                               // it is logical to signal EOF at the same time.
+                               ch.stderr.eof()
                        case *channelRequestSuccessMsg:
                                c.getChan(msg.PeersId).msg <- msg
                        case *channelRequestFailureMsg:
@@ -243,6 +252,8 @@ func (c *ClientConn) mainLoop() {
                                c.getChan(msg.PeersId).msg <- msg
                        case *windowAdjustMsg:
                                c.getChan(msg.PeersId).stdin.win <- int(msg.AdditionalBytes)
+                       case *disconnectMsg:
+                               break
                        default:
                                fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
                        }
@@ -295,6 +306,9 @@ type clientChan struct {
        stdout      *chanReader      // receives the payload of channelData messages
        stderr      *chanReader      // receives the payload of channelExtendedData messages
        msg         chan interface{} // incoming messages
+
+       theyClosed bool // indicates the close msg has been received from the remote side
+       weClosed   bool // incidates the close msg has been sent from our side
 }
 
 // newClientChan returns a partially constructed *clientChan
@@ -336,20 +350,29 @@ func (c *clientChan) waitForChannelOpenResponse() error {
        return errors.New("unexpected packet")
 }
 
-// sendEOF Sends EOF to the server. RFC 4254 Section 5.3
+// sendEOF sends EOF to the server. RFC 4254 Section 5.3
 func (c *clientChan) sendEOF() error {
        return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
                PeersId: c.peersId,
        }))
 }
 
-// Close closes the channel. This does not close the underlying connection.
-func (c *clientChan) Close() error {
+// sendClose signals the intent to close the channel.
+func (c *clientChan) sendClose() error {
        return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
                PeersId: c.peersId,
        }))
 }
 
+// Close closes the channel. This does not close the underlying connection.
+func (c *clientChan) Close() error {
+       if !c.weClosed {
+               c.weClosed = true
+               return c.sendClose()
+       }
+       return nil
+}
+
 // Thread safe channel list.
 type chanlist struct {
        // protects concurrent access to chans
@@ -421,7 +444,7 @@ func (w *chanWriter) Write(data []byte) (n int, err error) {
 }
 
 func (w *chanWriter) Close() error {
-       return w.clientChan.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.clientChan.peersId}))
+       return w.clientChan.sendEOF()
 }
 
 // A chanReader represents stdout or stderr of a remote process.
@@ -430,10 +453,27 @@ type chanReader struct {
        // If writes to this channel block, they will block mainLoop, making
        // it unable to receive new messages from the remote side.
        data       chan []byte // receives data from remote
+       dataClosed bool        // protects data from being closed twice
        clientChan *clientChan // the channel backing this reader
        buf        []byte
 }
 
+// eof signals to the consumer that there is no more data to be received.
+func (r *chanReader) eof() {
+       if !r.dataClosed {
+               r.dataClosed = true
+               close(r.data)
+       }
+}
+
+// handleData sends buf to the reader's consumer. If r.data is closed
+// the data will be silently discarded
+func (r *chanReader) handleData(buf []byte) {
+       if !r.dataClosed {
+               r.data <- buf
+       }
+}
+
 // Read reads data from the remote process's stdout or stderr.
 func (r *chanReader) Read(data []byte) (int, error) {
        var ok bool