]> Cypherpunks repositories - gostls13.git/commitdiff
exp/ssh: move common code to common.go
authorDave Cheney <dave@cheney.net>
Mon, 26 Sep 2011 14:25:13 +0000 (10:25 -0400)
committerAdam Langley <agl@golang.org>
Mon, 26 Sep 2011 14:25:13 +0000 (10:25 -0400)
R=agl
CC=golang-dev
https://golang.org/cl/5132041

src/pkg/exp/ssh/Makefile
src/pkg/exp/ssh/common.go
src/pkg/exp/ssh/messages.go
src/pkg/exp/ssh/server.go
src/pkg/exp/ssh/transport.go

index e8f33b708c3ba5c3db383bdbfca494cbddd7ce3e..1a100e9b69bb830f701be0d8ef662c287010894f 100644 (file)
@@ -6,11 +6,11 @@ include ../../../Make.inc
 
 TARG=exp/ssh
 GOFILES=\
+       channel.go\
        common.go\
        messages.go\
-       server.go\
        transport.go\
-        channel.go\
-        server_shell.go\
+       server.go\
+       server_shell.go\
 
 include ../../../Make.pkg
index 698db60b8e601fa71a1635d3a63084c2504322cc..441c5b17d381ef287c73b1e7cd507748759d43d5 100644 (file)
@@ -5,7 +5,9 @@
 package ssh
 
 import (
+       "big"
        "strconv"
+       "sync"
 )
 
 // These are string constants in the SSH protocol.
@@ -19,6 +21,32 @@ const (
        serviceSSH      = "ssh-connection"
 )
 
+var supportedKexAlgos = []string{kexAlgoDH14SHA1}
+var supportedHostKeyAlgos = []string{hostAlgoRSA}
+var supportedCiphers = []string{cipherAES128CTR}
+var supportedMACs = []string{macSHA196}
+var supportedCompressions = []string{compressionNone}
+
+// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
+type dhGroup struct {
+       g, p *big.Int
+}
+
+// dhGroup14 is the group called diffie-hellman-group14-sha1 in RFC 4253 and
+// Oakley Group 14 in RFC 3526.
+var dhGroup14 *dhGroup
+
+var dhGroup14Once sync.Once
+
+func initDHGroup14() {
+       p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
+
+       dhGroup14 = &dhGroup{
+               g: new(big.Int).SetInt64(2),
+               p: p,
+       }
+}
+
 // UnexpectedMessageError results when the SSH message that we received didn't
 // match what we wanted.
 type UnexpectedMessageError struct {
@@ -38,6 +66,11 @@ func (p ParseError) String() string {
        return "ssh: parse error in message type " + strconv.Itoa(int(p.msgType))
 }
 
+type handshakeMagics struct {
+       clientVersion, serverVersion []byte
+       clientKexInit, serverKexInit []byte
+}
+
 func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) {
        for _, clientAlgo := range clientAlgos {
                for _, serverAlgo := range serverAlgos {
index d375eafae96a778369bf6ace67a970c646a402c0..bc2333e182c8337c380d1caf98e3d540c96f0270 100644 (file)
@@ -59,6 +59,13 @@ const (
 // in this file. The only wrinkle is that a final member of type []byte with a
 // tag of "rest" receives the remainder of a packet when unmarshaling.
 
+// See RFC 4253, section 11.1.
+type disconnectMsg struct {
+       Reason   uint32
+       Message  string
+       Language string
+}
+
 // See RFC 4253, section 7.1.
 type kexInitMsg struct {
        Cookie                  [16]byte
@@ -137,6 +144,12 @@ type channelOpenFailureMsg struct {
        Language string
 }
 
+// See RFC 4254, section 5.2.
+type channelData struct {
+       PeersId uint32
+       Payload []byte "rest"
+}
+
 type channelRequestMsg struct {
        PeersId             uint32
        Request             string
@@ -555,3 +568,60 @@ func marshalString(to []byte, s []byte) []byte {
 }
 
 var bigIntType = reflect.TypeOf((*big.Int)(nil))
+
+// Decode a packet into it's corresponding message.
+func decode(packet []byte) interface{} {
+       var msg interface{}
+       switch packet[0] {
+       case msgDisconnect:
+               msg = new(disconnectMsg)
+       case msgServiceRequest:
+               msg = new(serviceRequestMsg)
+       case msgServiceAccept:
+               msg = new(serviceAcceptMsg)
+       case msgKexInit:
+               msg = new(kexInitMsg)
+       case msgKexDHInit:
+               msg = new(kexDHInitMsg)
+       case msgKexDHReply:
+               msg = new(kexDHReplyMsg)
+       case msgUserAuthRequest:
+               msg = new(userAuthRequestMsg)
+       case msgUserAuthFailure:
+               msg = new(userAuthFailureMsg)
+       case msgUserAuthPubKeyOk:
+               msg = new(userAuthPubKeyOkMsg)
+       case msgGlobalRequest:
+               msg = new(globalRequestMsg)
+       case msgRequestSuccess:
+               msg = new(channelRequestSuccessMsg)
+       case msgRequestFailure:
+               msg = new(channelRequestFailureMsg)
+       case msgChannelOpen:
+               msg = new(channelOpenMsg)
+       case msgChannelOpenConfirm:
+               msg = new(channelOpenConfirmMsg)
+       case msgChannelOpenFailure:
+               msg = new(channelOpenFailureMsg)
+       case msgChannelWindowAdjust:
+               msg = new(windowAdjustMsg)
+       case msgChannelData:
+               msg = new(channelData)
+       case msgChannelEOF:
+               msg = new(channelEOFMsg)
+       case msgChannelClose:
+               msg = new(channelCloseMsg)
+       case msgChannelRequest:
+               msg = new(channelRequestMsg)
+       case msgChannelSuccess:
+               msg = new(channelRequestSuccessMsg)
+       case msgChannelFailure:
+               msg = new(channelRequestFailureMsg)
+       default:
+               return UnexpectedMessageError{0, packet[0]}
+       }
+       if err := unmarshal(msg, packet, packet[0]); err != nil {
+               return err
+       }
+       return msg
+}
index bc0af13e821f430657bb3363bef5da3ec0d04fbf..9e7a255ae3eb31a8c5c3cb701ea8719413771b1c 100644 (file)
@@ -6,7 +6,6 @@ package ssh
 
 import (
        "big"
-       "bufio"
        "bytes"
        "crypto"
        "crypto/rand"
@@ -19,12 +18,6 @@ import (
        "sync"
 )
 
-var supportedKexAlgos = []string{kexAlgoDH14SHA1}
-var supportedHostKeyAlgos = []string{hostAlgoRSA}
-var supportedCiphers = []string{cipherAES128CTR}
-var supportedMACs = []string{macSHA196}
-var supportedCompressions = []string{compressionNone}
-
 // Server represents an SSH server. A Server may have several ServerConnections.
 type Server struct {
        rsa           *rsa.PrivateKey
@@ -146,31 +139,6 @@ type ServerConnection struct {
        cachedPubKeys []cachedPubKey
 }
 
-// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
-type dhGroup struct {
-       g, p *big.Int
-}
-
-// dhGroup14 is the group called diffie-hellman-group14-sha1 in RFC 4253 and
-// Oakley Group 14 in RFC 3526.
-var dhGroup14 *dhGroup
-
-var dhGroup14Once sync.Once
-
-func initDHGroup14() {
-       p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
-
-       dhGroup14 = &dhGroup{
-               g: new(big.Int).SetInt64(2),
-               p: p,
-       }
-}
-
-type handshakeMagics struct {
-       clientVersion, serverVersion []byte
-       clientKexInit, serverKexInit []byte
-}
-
 // kexDH performs Diffie-Hellman key agreement on a ServerConnection. The
 // returned values are given the same names as in RFC 4253, section 8.
 func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) {
@@ -292,18 +260,7 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK
 // Handshake performs an SSH transport and client authentication on the given ServerConnection.
 func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
        var magics handshakeMagics
-       s.transport = &transport{
-               reader: reader{
-                       Reader: bufio.NewReader(conn),
-               },
-               writer: writer{
-                       Writer: bufio.NewWriter(conn),
-                       rand:   rand.Reader,
-               },
-               Close: func() os.Error {
-                       return conn.Close()
-               },
-       }
+       s.transport = newTransport(conn)
 
        if _, err := conn.Write(serverVersion); err != nil {
                return err
@@ -612,19 +569,14 @@ func (s *ServerConnection) Accept() (Channel, os.Error) {
                        return nil, err
                }
 
-               switch packet[0] {
-               case msgChannelOpen:
-                       var chanOpen channelOpenMsg
-                       if err := unmarshal(&chanOpen, packet, msgChannelOpen); err != nil {
-                               return nil, err
-                       }
-
+               switch msg := decode(packet).(type) {
+               case *channelOpenMsg:
                        c := new(channel)
-                       c.chanType = chanOpen.ChanType
-                       c.theirId = chanOpen.PeersId
-                       c.theirWindow = chanOpen.PeersWindow
-                       c.maxPacketSize = chanOpen.MaxPacketSize
-                       c.extraData = chanOpen.TypeSpecificData
+                       c.chanType = msg.ChanType
+                       c.theirId = msg.PeersId
+                       c.theirWindow = msg.PeersWindow
+                       c.maxPacketSize = msg.MaxPacketSize
+                       c.extraData = msg.TypeSpecificData
                        c.myWindow = defaultWindowSize
                        c.serverConn = s
                        c.cond = sync.NewCond(&c.lock)
@@ -637,74 +589,53 @@ func (s *ServerConnection) Accept() (Channel, os.Error) {
                        s.lock.Unlock()
                        return c, nil
 
-               case msgChannelRequest:
-                       var chanRequest channelRequestMsg
-                       if err := unmarshal(&chanRequest, packet, msgChannelRequest); err != nil {
-                               return nil, err
-                       }
-
+               case *channelRequestMsg:
                        s.lock.Lock()
-                       c, ok := s.channels[chanRequest.PeersId]
+                       c, ok := s.channels[msg.PeersId]
                        if !ok {
                                continue
                        }
-                       c.handlePacket(&chanRequest)
+                       c.handlePacket(msg)
                        s.lock.Unlock()
 
-               case msgChannelData:
-                       if len(packet) < 5 {
-                               return nil, ParseError{msgChannelData}
-                       }
-                       chanId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
-
+               case *channelData:
                        s.lock.Lock()
-                       c, ok := s.channels[chanId]
+                       c, ok := s.channels[msg.PeersId]
                        if !ok {
                                continue
                        }
-                       c.handleData(packet[9:])
+                       c.handleData(msg.Payload)
                        s.lock.Unlock()
 
-               case msgChannelEOF:
-                       var eofMsg channelEOFMsg
-                       if err := unmarshal(&eofMsg, packet, msgChannelEOF); err != nil {
-                               return nil, err
-                       }
-
+               case *channelEOFMsg:
                        s.lock.Lock()
-                       c, ok := s.channels[eofMsg.PeersId]
+                       c, ok := s.channels[msg.PeersId]
                        if !ok {
                                continue
                        }
-                       c.handlePacket(&eofMsg)
+                       c.handlePacket(msg)
                        s.lock.Unlock()
 
-               case msgChannelClose:
-                       var closeMsg channelCloseMsg
-                       if err := unmarshal(&closeMsg, packet, msgChannelClose); err != nil {
-                               return nil, err
-                       }
-
+               case *channelCloseMsg:
                        s.lock.Lock()
-                       c, ok := s.channels[closeMsg.PeersId]
+                       c, ok := s.channels[msg.PeersId]
                        if !ok {
                                continue
                        }
-                       c.handlePacket(&closeMsg)
+                       c.handlePacket(msg)
                        s.lock.Unlock()
 
-               case msgGlobalRequest:
-                       var request globalRequestMsg
-                       if err := unmarshal(&request, packet, msgGlobalRequest); err != nil {
-                               return nil, err
-                       }
-
-                       if request.WantReply {
+               case *globalRequestMsg:
+                       if msg.WantReply {
                                if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
                                        return nil, err
                                }
                        }
 
+               case UnexpectedMessageError:
+                       return nil, msg
+               case *disconnectMsg:
+                       return nil, os.EOF
                default:
                        // Unknown message. Ignore.
                }
index 5a474a9f2c2e30863a90010ea4c0fd6ed29f9c18..77660a26573873cb2eb27900691b6db8f9f14dd5 100644 (file)
@@ -10,10 +10,13 @@ import (
        "crypto/aes"
        "crypto/cipher"
        "crypto/hmac"
+       "crypto/rand"
        "crypto/subtle"
        "hash"
        "io"
+       "net"
        "os"
+       "sync"
 )
 
 const (
@@ -29,7 +32,8 @@ type transport struct {
        macAlgo         string
        compressionAlgo string
 
-       Close func() os.Error
+       Close      func() os.Error
+       RemoteAddr func() net.Addr
 }
 
 // reader represents the incoming connection state.
@@ -40,6 +44,7 @@ type reader struct {
 
 // writer represnts the outgoing connection state.
 type writer struct {
+       *sync.Mutex // protects writer.Writer from concurrent writes
        *bufio.Writer
        paddingMultiple int
        rand            io.Reader
@@ -126,6 +131,9 @@ func (t *transport) readPacket() ([]byte, os.Error) {
 
 // Encrypt and send a packet of data to the remote peer.
 func (w *writer) writePacket(packet []byte) os.Error {
+       w.Mutex.Lock()
+       defer w.Mutex.Unlock()
+
        paddingLength := paddingMultiple - (5+len(packet))%paddingMultiple
        if paddingLength < 4 {
                paddingLength += paddingMultiple
@@ -190,6 +198,31 @@ func (w *writer) writePacket(packet []byte) os.Error {
        return err
 }
 
+// Send a message to the remote peer
+func (t *transport) sendMessage(typ uint8, msg interface{}) os.Error {
+       packet := marshal(typ, msg)
+       return t.writePacket(packet)
+}
+
+func newTransport(conn net.Conn) *transport {
+       return &transport{
+               reader: reader{
+                       Reader: bufio.NewReader(conn),
+               },
+               writer: writer{
+                       Writer: bufio.NewWriter(conn),
+                       rand:   rand.Reader,
+                       Mutex:  new(sync.Mutex),
+               },
+               Close: func() os.Error {
+                       return conn.Close()
+               },
+               RemoteAddr: func() net.Addr {
+                       return conn.RemoteAddr()
+               },
+       }
+}
+
 type direction struct {
        ivTag     []byte
        keyTag    []byte