]> Cypherpunks repositories - gostls13.git/commitdiff
exp/ssh: fix length header leaking into channel data streams.
authorDave Cheney <dave@cheney.net>
Sat, 29 Oct 2011 18:22:30 +0000 (14:22 -0400)
committerAdam Langley <agl@golang.org>
Sat, 29 Oct 2011 18:22:30 +0000 (14:22 -0400)
The payload of a data message is defined as an SSH string type,
which uses the first four bytes to encode its length. When channelData
and channelExtendedData were added I defined Payload as []byte to
be able to use it directly without a string to []byte conversion. This
resulted in the length data leaking into the payload data.

This CL fixes the bug, and restores agl's original fast path code.

Additionally, a bug whereby s.lock was not released if a packet arrived
for an invalid channel has been fixed.

Finally, as they were no longer used, I have removed
the channelData and channelExtedendData structs.

R=agl, rsc
CC=golang-dev
https://golang.org/cl/5330053

src/pkg/exp/ssh/client.go
src/pkg/exp/ssh/messages.go
src/pkg/exp/ssh/server.go

index 9223b6c3cf21b60aad69586e4e40353048eeb4c4..fe76db16c4b1e978916e634e37f2bb6bfdbc2594 100644 (file)
@@ -258,51 +258,71 @@ func (c *ClientConn) openChan(typ string) (*clientChan, os.Error) {
 // mainloop reads incoming messages and routes channel messages
 // to their respective ClientChans.
 func (c *ClientConn) mainLoop() {
+       // TODO(dfc) signal the underlying close to all channels
+       defer c.Close()
        for {
                packet, err := c.readPacket()
                if err != nil {
-                       // TODO(dfc) signal the underlying close to all channels
-                       c.Close()
-                       return
+                       break
                }
                // TODO(dfc) A note on blocking channel use. 
                // The msg, win, data and dataExt channels of a clientChan can 
                // cause this loop to block indefinately if the consumer does 
                // not service them. 
-               switch msg := decode(packet).(type) {
-               case *channelOpenMsg:
-                       c.getChan(msg.PeersId).msg <- msg
-               case *channelOpenConfirmMsg:
-                       c.getChan(msg.PeersId).msg <- msg
-               case *channelOpenFailureMsg:
-                       c.getChan(msg.PeersId).msg <- msg
-               case *channelCloseMsg:
-                       ch := c.getChan(msg.PeersId)
-                       close(ch.win)
-                       close(ch.data)
-                       close(ch.dataExt)
-                       c.chanlist.remove(msg.PeersId)
-               case *channelEOFMsg:
-                       c.getChan(msg.PeersId).msg <- msg
-               case *channelRequestSuccessMsg:
-                       c.getChan(msg.PeersId).msg <- msg
-               case *channelRequestFailureMsg:
-                       c.getChan(msg.PeersId).msg <- msg
-               case *channelRequestMsg:
-                       c.getChan(msg.PeersId).msg <- msg
-               case *windowAdjustMsg:
-                       c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes)
-               case *channelData:
-                       c.getChan(msg.PeersId).data <- msg.Payload
-               case *channelExtendedData:
-                       // RFC 4254 5.2 defines data_type_code 1 to be data destined 
-                       // for stderr on interactive sessions. Other data types are
-                       // silently discarded.
-                       if msg.Datatype == 1 {
-                               c.getChan(msg.PeersId).dataExt <- msg.Payload
+               switch packet[0] {
+               case msgChannelData:
+                       if len(packet) < 9 {
+                               // malformed data packet
+                               break
+                       }
+                       peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
+                       if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
+                               packet = packet[9:]
+                               c.getChan(peersId).data <- packet[:length]
+                       }
+               case msgChannelExtendedData:
+                       if len(packet) < 13 {
+                               // malformed data packet
+                               break
+                       }
+                       peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
+                       datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
+                       if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
+                               packet = packet[13:]
+                               // RFC 4254 5.2 defines data_type_code 1 to be data destined 
+                               // for stderr on interactive sessions. Other data types are
+                               // silently discarded.
+                               if datatype == 1 {
+                                       c.getChan(peersId).dataExt <- packet[:length]
+                               }
                        }
                default:
-                       fmt.Printf("mainLoop: unhandled %#v\n", msg)
+                       switch msg := decode(packet).(type) {
+                       case *channelOpenMsg:
+                               c.getChan(msg.PeersId).msg <- msg
+                       case *channelOpenConfirmMsg:
+                               c.getChan(msg.PeersId).msg <- msg
+                       case *channelOpenFailureMsg:
+                               c.getChan(msg.PeersId).msg <- msg
+                       case *channelCloseMsg:
+                               ch := c.getChan(msg.PeersId)
+                               close(ch.win)
+                               close(ch.data)
+                               close(ch.dataExt)
+                               c.chanlist.remove(msg.PeersId)
+                       case *channelEOFMsg:
+                               c.getChan(msg.PeersId).msg <- msg
+                       case *channelRequestSuccessMsg:
+                               c.getChan(msg.PeersId).msg <- msg
+                       case *channelRequestFailureMsg:
+                               c.getChan(msg.PeersId).msg <- msg
+                       case *channelRequestMsg:
+                               c.getChan(msg.PeersId).msg <- msg
+                       case *windowAdjustMsg:
+                               c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes)
+                       default:
+                               fmt.Printf("mainLoop: unhandled %#v\n", msg)
+                       }
                }
        }
 }
index 7771f2b242e662287c3d64d0b6a9826e916dbca5..5f2c447142adb328e6da7358539180b0cc3579ff 100644 (file)
@@ -144,19 +144,6 @@ type channelOpenFailureMsg struct {
        Language string
 }
 
-// See RFC 4254, section 5.2.
-type channelData struct {
-       PeersId uint32
-       Payload []byte `ssh:"rest"`
-}
-
-// See RFC 4254, section 5.2.
-type channelExtendedData struct {
-       PeersId  uint32
-       Datatype uint32
-       Payload  []byte `ssh:"rest"`
-}
-
 type channelRequestMsg struct {
        PeersId             uint32
        Request             string
@@ -612,10 +599,6 @@ func decode(packet []byte) interface{} {
                msg = new(channelOpenFailureMsg)
        case msgChannelWindowAdjust:
                msg = new(windowAdjustMsg)
-       case msgChannelData:
-               msg = new(channelData)
-       case msgChannelExtendedData:
-               msg = new(channelExtendedData)
        case msgChannelEOF:
                msg = new(channelEOFMsg)
        case msgChannelClose:
index 3a640fc081ec432f633cb874f9cf49f4e139384b..0dd24ecd6ee7d4512e5a4a79af0250322a63c7c1 100644 (file)
@@ -581,75 +581,89 @@ func (s *ServerConn) Accept() (Channel, os.Error) {
                        return nil, err
                }
 
-               switch msg := decode(packet).(type) {
-               case *channelOpenMsg:
-                       c := new(channel)
-                       c.chanType = msg.ChanType
-                       c.theirId = msg.PeersId
-                       c.theirWindow = msg.PeersWindow
-                       c.maxPacketSize = msg.MaxPacketSize
-                       c.extraData = msg.TypeSpecificData
-                       c.myWindow = defaultWindowSize
-                       c.serverConn = s
-                       c.cond = sync.NewCond(&c.lock)
-                       c.pendingData = make([]byte, c.myWindow)
-
-                       s.lock.Lock()
-                       c.myId = s.nextChanId
-                       s.nextChanId++
-                       s.channels[c.myId] = c
-                       s.lock.Unlock()
-                       return c, nil
-
-               case *channelRequestMsg:
-                       s.lock.Lock()
-                       c, ok := s.channels[msg.PeersId]
-                       if !ok {
-                               continue
+               switch packet[0] {
+               case msgChannelData:
+                       if len(packet) < 9 {
+                               // malformed data packet
+                               return nil, ParseError{msgChannelData}
                        }
-                       c.handlePacket(msg)
-                       s.lock.Unlock()
-
-               case *channelData:
+                       peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
                        s.lock.Lock()
-                       c, ok := s.channels[msg.PeersId]
+                       c, ok := s.channels[peersId]
                        if !ok {
+                               s.lock.Unlock()
                                continue
                        }
-                       c.handleData(msg.Payload)
-                       s.lock.Unlock()
-
-               case *channelEOFMsg:
-                       s.lock.Lock()
-                       c, ok := s.channels[msg.PeersId]
-                       if !ok {
-                               continue
+                       if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
+                               packet = packet[9:]
+                               c.handleData(packet[:length])
                        }
-                       c.handlePacket(msg)
                        s.lock.Unlock()
+               default:
+                       switch msg := decode(packet).(type) {
+                       case *channelOpenMsg:
+                               c := new(channel)
+                               c.chanType = msg.ChanType
+                               c.theirId = msg.PeersId
+                               c.theirWindow = msg.PeersWindow
+                               c.maxPacketSize = msg.MaxPacketSize
+                               c.extraData = msg.TypeSpecificData
+                               c.myWindow = defaultWindowSize
+                               c.serverConn = s
+                               c.cond = sync.NewCond(&c.lock)
+                               c.pendingData = make([]byte, c.myWindow)
+
+                               s.lock.Lock()
+                               c.myId = s.nextChanId
+                               s.nextChanId++
+                               s.channels[c.myId] = c
+                               s.lock.Unlock()
+                               return c, nil
+
+                       case *channelRequestMsg:
+                               s.lock.Lock()
+                               c, ok := s.channels[msg.PeersId]
+                               if !ok {
+                                       s.lock.Unlock()
+                                       continue
+                               }
+                               c.handlePacket(msg)
+                               s.lock.Unlock()
 
-               case *channelCloseMsg:
-                       s.lock.Lock()
-                       c, ok := s.channels[msg.PeersId]
-                       if !ok {
-                               continue
-                       }
-                       c.handlePacket(msg)
-                       s.lock.Unlock()
+                       case *channelEOFMsg:
+                               s.lock.Lock()
+                               c, ok := s.channels[msg.PeersId]
+                               if !ok {
+                                       s.lock.Unlock()
+                                       continue
+                               }
+                               c.handlePacket(msg)
+                               s.lock.Unlock()
 
-               case *globalRequestMsg:
-                       if msg.WantReply {
-                               if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
-                                       return nil, err
+                       case *channelCloseMsg:
+                               s.lock.Lock()
+                               c, ok := s.channels[msg.PeersId]
+                               if !ok {
+                                       s.lock.Unlock()
+                                       continue
                                }
-                       }
+                               c.handlePacket(msg)
+                               s.lock.Unlock()
 
-               case UnexpectedMessageError:
-                       return nil, msg
-               case *disconnectMsg:
-                       return nil, os.EOF
-               default:
-                       // Unknown message. Ignore.
+                       case *globalRequestMsg:
+                               if msg.WantReply {
+                                       if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
+                                               return nil, err
+                                       }
+                               }
+
+                       case UnexpectedMessageError:
+                               return nil, msg
+                       case *disconnectMsg:
+                               return nil, os.EOF
+                       default:
+                               // Unknown message. Ignore.
+                       }
                }
        }