// mainloop reads incoming messages and routes channel messages
// to their respective ClientChans.
func (c *ClientConn) mainLoop() {
+ // TODO(dfc) signal the underlying close to all channels
+ defer c.Close()
for {
packet, err := c.readPacket()
if err != nil {
- // TODO(dfc) signal the underlying close to all channels
- c.Close()
- return
+ break
}
// TODO(dfc) A note on blocking channel use.
// The msg, win, data and dataExt channels of a clientChan can
// cause this loop to block indefinately if the consumer does
// not service them.
- switch msg := decode(packet).(type) {
- case *channelOpenMsg:
- c.getChan(msg.PeersId).msg <- msg
- case *channelOpenConfirmMsg:
- c.getChan(msg.PeersId).msg <- msg
- case *channelOpenFailureMsg:
- c.getChan(msg.PeersId).msg <- msg
- case *channelCloseMsg:
- ch := c.getChan(msg.PeersId)
- close(ch.win)
- close(ch.data)
- close(ch.dataExt)
- c.chanlist.remove(msg.PeersId)
- case *channelEOFMsg:
- c.getChan(msg.PeersId).msg <- msg
- case *channelRequestSuccessMsg:
- c.getChan(msg.PeersId).msg <- msg
- case *channelRequestFailureMsg:
- c.getChan(msg.PeersId).msg <- msg
- case *channelRequestMsg:
- c.getChan(msg.PeersId).msg <- msg
- case *windowAdjustMsg:
- c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes)
- case *channelData:
- c.getChan(msg.PeersId).data <- msg.Payload
- case *channelExtendedData:
- // RFC 4254 5.2 defines data_type_code 1 to be data destined
- // for stderr on interactive sessions. Other data types are
- // silently discarded.
- if msg.Datatype == 1 {
- c.getChan(msg.PeersId).dataExt <- msg.Payload
+ switch packet[0] {
+ case msgChannelData:
+ if len(packet) < 9 {
+ // malformed data packet
+ break
+ }
+ peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
+ if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
+ packet = packet[9:]
+ c.getChan(peersId).data <- packet[:length]
+ }
+ case msgChannelExtendedData:
+ if len(packet) < 13 {
+ // malformed data packet
+ break
+ }
+ peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
+ datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
+ if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
+ packet = packet[13:]
+ // RFC 4254 5.2 defines data_type_code 1 to be data destined
+ // for stderr on interactive sessions. Other data types are
+ // silently discarded.
+ if datatype == 1 {
+ c.getChan(peersId).dataExt <- packet[:length]
+ }
}
default:
- fmt.Printf("mainLoop: unhandled %#v\n", msg)
+ switch msg := decode(packet).(type) {
+ case *channelOpenMsg:
+ c.getChan(msg.PeersId).msg <- msg
+ case *channelOpenConfirmMsg:
+ c.getChan(msg.PeersId).msg <- msg
+ case *channelOpenFailureMsg:
+ c.getChan(msg.PeersId).msg <- msg
+ case *channelCloseMsg:
+ ch := c.getChan(msg.PeersId)
+ close(ch.win)
+ close(ch.data)
+ close(ch.dataExt)
+ c.chanlist.remove(msg.PeersId)
+ case *channelEOFMsg:
+ c.getChan(msg.PeersId).msg <- msg
+ case *channelRequestSuccessMsg:
+ c.getChan(msg.PeersId).msg <- msg
+ case *channelRequestFailureMsg:
+ c.getChan(msg.PeersId).msg <- msg
+ case *channelRequestMsg:
+ c.getChan(msg.PeersId).msg <- msg
+ case *windowAdjustMsg:
+ c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes)
+ default:
+ fmt.Printf("mainLoop: unhandled %#v\n", msg)
+ }
}
}
}
Language string
}
-// See RFC 4254, section 5.2.
-type channelData struct {
- PeersId uint32
- Payload []byte `ssh:"rest"`
-}
-
-// See RFC 4254, section 5.2.
-type channelExtendedData struct {
- PeersId uint32
- Datatype uint32
- Payload []byte `ssh:"rest"`
-}
-
type channelRequestMsg struct {
PeersId uint32
Request string
msg = new(channelOpenFailureMsg)
case msgChannelWindowAdjust:
msg = new(windowAdjustMsg)
- case msgChannelData:
- msg = new(channelData)
- case msgChannelExtendedData:
- msg = new(channelExtendedData)
case msgChannelEOF:
msg = new(channelEOFMsg)
case msgChannelClose:
return nil, err
}
- switch msg := decode(packet).(type) {
- case *channelOpenMsg:
- c := new(channel)
- c.chanType = msg.ChanType
- c.theirId = msg.PeersId
- c.theirWindow = msg.PeersWindow
- c.maxPacketSize = msg.MaxPacketSize
- c.extraData = msg.TypeSpecificData
- c.myWindow = defaultWindowSize
- c.serverConn = s
- c.cond = sync.NewCond(&c.lock)
- c.pendingData = make([]byte, c.myWindow)
-
- s.lock.Lock()
- c.myId = s.nextChanId
- s.nextChanId++
- s.channels[c.myId] = c
- s.lock.Unlock()
- return c, nil
-
- case *channelRequestMsg:
- s.lock.Lock()
- c, ok := s.channels[msg.PeersId]
- if !ok {
- continue
+ switch packet[0] {
+ case msgChannelData:
+ if len(packet) < 9 {
+ // malformed data packet
+ return nil, ParseError{msgChannelData}
}
- c.handlePacket(msg)
- s.lock.Unlock()
-
- case *channelData:
+ peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
s.lock.Lock()
- c, ok := s.channels[msg.PeersId]
+ c, ok := s.channels[peersId]
if !ok {
+ s.lock.Unlock()
continue
}
- c.handleData(msg.Payload)
- s.lock.Unlock()
-
- case *channelEOFMsg:
- s.lock.Lock()
- c, ok := s.channels[msg.PeersId]
- if !ok {
- continue
+ if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
+ packet = packet[9:]
+ c.handleData(packet[:length])
}
- c.handlePacket(msg)
s.lock.Unlock()
+ default:
+ switch msg := decode(packet).(type) {
+ case *channelOpenMsg:
+ c := new(channel)
+ c.chanType = msg.ChanType
+ c.theirId = msg.PeersId
+ c.theirWindow = msg.PeersWindow
+ c.maxPacketSize = msg.MaxPacketSize
+ c.extraData = msg.TypeSpecificData
+ c.myWindow = defaultWindowSize
+ c.serverConn = s
+ c.cond = sync.NewCond(&c.lock)
+ c.pendingData = make([]byte, c.myWindow)
+
+ s.lock.Lock()
+ c.myId = s.nextChanId
+ s.nextChanId++
+ s.channels[c.myId] = c
+ s.lock.Unlock()
+ return c, nil
+
+ case *channelRequestMsg:
+ s.lock.Lock()
+ c, ok := s.channels[msg.PeersId]
+ if !ok {
+ s.lock.Unlock()
+ continue
+ }
+ c.handlePacket(msg)
+ s.lock.Unlock()
- case *channelCloseMsg:
- s.lock.Lock()
- c, ok := s.channels[msg.PeersId]
- if !ok {
- continue
- }
- c.handlePacket(msg)
- s.lock.Unlock()
+ case *channelEOFMsg:
+ s.lock.Lock()
+ c, ok := s.channels[msg.PeersId]
+ if !ok {
+ s.lock.Unlock()
+ continue
+ }
+ c.handlePacket(msg)
+ s.lock.Unlock()
- case *globalRequestMsg:
- if msg.WantReply {
- if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
- return nil, err
+ case *channelCloseMsg:
+ s.lock.Lock()
+ c, ok := s.channels[msg.PeersId]
+ if !ok {
+ s.lock.Unlock()
+ continue
}
- }
+ c.handlePacket(msg)
+ s.lock.Unlock()
- case UnexpectedMessageError:
- return nil, msg
- case *disconnectMsg:
- return nil, os.EOF
- default:
- // Unknown message. Ignore.
+ case *globalRequestMsg:
+ if msg.WantReply {
+ if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
+ return nil, err
+ }
+ }
+
+ case UnexpectedMessageError:
+ return nil, msg
+ case *disconnectMsg:
+ return nil, os.EOF
+ default:
+ // Unknown message. Ignore.
+ }
}
}