]> Cypherpunks repositories - gostls13.git/commitdiff
exp/ssh: server cleanups
authorDave Cheney <dave@cheney.net>
Fri, 21 Oct 2011 15:04:28 +0000 (11:04 -0400)
committerAdam Langley <agl@golang.org>
Fri, 21 Oct 2011 15:04:28 +0000 (11:04 -0400)
server.go/channel.go:
* rename Server to ServerConfig to match Client.
* rename ServerConnection to ServeConn to match Client.
* add Listen/Listener.
* ServerConn.Handshake(), general cleanups.

client.go:
* fix bug where fmt.Error was not assigned to err

R=rsc, agl
CC=golang-dev
https://golang.org/cl/5265049

src/pkg/exp/ssh/channel.go
src/pkg/exp/ssh/client.go
src/pkg/exp/ssh/server.go

index 922584f63175ac7862ce15f1cb225cf04db67ae0..f69b735fd47f1f53b9941378ddf9cc316f52c2b5 100644 (file)
@@ -68,7 +68,7 @@ type channel struct {
        weClosed    bool
        dead        bool
 
-       serverConn            *ServerConnection
+       serverConn            *ServerConn
        myId, theirId         uint32
        myWindow, theirWindow uint32
        maxPacketSize         uint32
index edb95eccc644caa3c97234a423cfdb784de2da8b..b3d7708a265ffd02f8aa21bb307a9812e9cfbc77 100644 (file)
@@ -115,7 +115,7 @@ func (c *ClientConn) handshake() os.Error {
                dhGroup14Once.Do(initDHGroup14)
                H, K, err = c.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
        default:
-               fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
+               err = fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
        }
        if err != nil {
                return err
index b5a5e017d39d6f15f7c29193d8b437df3aa3d1a3..3a640fc081ec432f633cb874f9cf49f4e139384b 100644 (file)
@@ -10,19 +10,23 @@ import (
        "crypto"
        "crypto/rand"
        "crypto/rsa"
-       _ "crypto/sha1"
        "crypto/x509"
        "encoding/pem"
+       "io"
        "net"
        "os"
        "sync"
 )
 
-// Server represents an SSH server. A Server may have several ServerConnections.
-type Server struct {
+type ServerConfig struct {
        rsa           *rsa.PrivateKey
        rsaSerialized []byte
 
+       // Rand provides the source of entropy for key exchange. If Rand is 
+       // nil, the cryptographic random reader in package crypto/rand will 
+       // be used.
+       Rand io.Reader
+
        // NoClientAuth is true if clients are allowed to connect without
        // authenticating.
        NoClientAuth bool
@@ -38,11 +42,18 @@ type Server struct {
        PubKeyCallback func(user, algo string, pubkey []byte) bool
 }
 
+func (c *ServerConfig) rand() io.Reader {
+       if c.Rand == nil {
+               return rand.Reader
+       }
+       return c.Rand
+}
+
 // SetRSAPrivateKey sets the private key for a Server. A Server must have a
 // private key configured in order to accept connections. The private key must
 // be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa"
 // typically contains such a key.
-func (s *Server) SetRSAPrivateKey(pemBytes []byte) os.Error {
+func (s *ServerConfig) SetRSAPrivateKey(pemBytes []byte) os.Error {
        block, _ := pem.Decode(pemBytes)
        if block == nil {
                return os.NewError("ssh: no key found")
@@ -109,7 +120,7 @@ func parseRSASig(in []byte) (sig []byte, ok bool) {
 }
 
 // cachedPubKey contains the results of querying whether a public key is
-// acceptable for a user. The cache only applies to a single ServerConnection.
+// acceptable for a user. The cache only applies to a single ServerConn.
 type cachedPubKey struct {
        user, algo string
        pubKey     []byte
@@ -118,11 +129,10 @@ type cachedPubKey struct {
 
 const maxCachedPubKeys = 16
 
-// ServerConnection represents an incomming connection to a Server.
-type ServerConnection struct {
-       Server *Server
-
+// A ServerConn represents an incomming connection.
+type ServerConn struct {
        *transport
+       config *ServerConfig
 
        channels   map[uint32]*channel
        nextChanId uint32
@@ -139,9 +149,20 @@ type ServerConnection struct {
        cachedPubKeys []cachedPubKey
 }
 
+// Server returns a new SSH server connection
+// using c as the underlying transport.
+func Server(c net.Conn, config *ServerConfig) *ServerConn {
+       conn := &ServerConn{
+               transport: newTransport(c, config.rand()),
+               channels:  make(map[uint32]*channel),
+               config:    config,
+       }
+       return conn
+}
+
 // kexDH performs Diffie-Hellman key agreement on a ServerConnection. The
 // returned values are given the same names as in RFC 4253, section 8.
-func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) {
+func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) {
        packet, err := s.readPacket()
        if err != nil {
                return
@@ -155,7 +176,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h
                return nil, nil, os.NewError("client DH parameter out of bounds")
        }
 
-       y, err := rand.Int(rand.Reader, group.p)
+       y, err := rand.Int(s.config.rand(), group.p)
        if err != nil {
                return
        }
@@ -166,7 +187,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h
        var serializedHostKey []byte
        switch hostKeyAlgo {
        case hostAlgoRSA:
-               serializedHostKey = s.Server.rsaSerialized
+               serializedHostKey = s.config.rsaSerialized
        default:
                return nil, nil, os.NewError("internal error")
        }
@@ -192,7 +213,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h
        var sig []byte
        switch hostKeyAlgo {
        case hostAlgoRSA:
-               sig, err = rsa.SignPKCS1v15(rand.Reader, s.Server.rsa, hashFunc, hh)
+               sig, err = rsa.SignPKCS1v15(s.config.rand(), s.config.rsa, hashFunc, hh)
                if err != nil {
                        return
                }
@@ -257,17 +278,18 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK
        return ret
 }
 
-// Handshake performs an SSH transport and client authentication on the given ServerConnection.
-func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
+// Handshake performs an SSH transport and client authentication on the given ServerConn.
+func (s *ServerConn) Handshake() os.Error {
        var magics handshakeMagics
-       s.transport = newTransport(conn, rand.Reader)
-
-       if _, err := conn.Write(serverVersion); err != nil {
+       if _, err := s.Write(serverVersion); err != nil {
+               return err
+       }
+       if err := s.Flush(); err != nil {
                return err
        }
        magics.serverVersion = serverVersion[:len(serverVersion)-2]
 
-       version, err := readVersion(s.transport)
+       version, err := readVersion(s)
        if err != nil {
                return err
        }
@@ -310,8 +332,7 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
        if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] {
                // The client sent a Kex message for the wrong algorithm,
                // which we have to ignore.
-               _, err := s.readPacket()
-               if err != nil {
+               if _, err := s.readPacket(); err != nil {
                        return err
                }
        }
@@ -324,32 +345,27 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
                dhGroup14Once.Do(initDHGroup14)
                H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
        default:
-               err = os.NewError("ssh: internal error")
+               err = os.NewError("ssh: unexpected key exchange algorithm " + kexAlgo)
        }
-
        if err != nil {
                return err
        }
 
-       packet = []byte{msgNewKeys}
-       if err = s.writePacket(packet); err != nil {
+       if err = s.writePacket([]byte{msgNewKeys}); err != nil {
                return err
        }
        if err = s.transport.writer.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
                return err
        }
-
        if packet, err = s.readPacket(); err != nil {
                return err
        }
+
        if packet[0] != msgNewKeys {
                return UnexpectedMessageError{msgNewKeys, packet[0]}
        }
-
        s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc)
-
-       packet, err = s.readPacket()
-       if err != nil {
+       if packet, err = s.readPacket(); err != nil {
                return err
        }
 
@@ -360,20 +376,16 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
        if serviceRequest.Service != serviceUserAuth {
                return os.NewError("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
        }
-
        serviceAccept := serviceAcceptMsg{
                Service: serviceUserAuth,
        }
-       packet = marshal(msgServiceAccept, serviceAccept)
-       if err = s.writePacket(packet); err != nil {
+       if err = s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
                return err
        }
 
        if err = s.authenticate(H); err != nil {
                return err
        }
-
-       s.channels = make(map[uint32]*channel)
        return nil
 }
 
@@ -382,8 +394,8 @@ func isAcceptableAlgo(algo string) bool {
 }
 
 // testPubKey returns true if the given public key is acceptable for the user.
-func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool {
-       if s.Server.PubKeyCallback == nil || !isAcceptableAlgo(algo) {
+func (s *ServerConn) testPubKey(user, algo string, pubKey []byte) bool {
+       if s.config.PubKeyCallback == nil || !isAcceptableAlgo(algo) {
                return false
        }
 
@@ -393,7 +405,7 @@ func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool {
                }
        }
 
-       result := s.Server.PubKeyCallback(user, algo, pubKey)
+       result := s.config.PubKeyCallback(user, algo, pubKey)
        if len(s.cachedPubKeys) < maxCachedPubKeys {
                c := cachedPubKey{
                        user:   user,
@@ -408,7 +420,7 @@ func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool {
        return result
 }
 
-func (s *ServerConnection) authenticate(H []byte) os.Error {
+func (s *ServerConn) authenticate(H []byte) os.Error {
        var userAuthReq userAuthRequestMsg
        var err os.Error
        var packet []byte
@@ -428,11 +440,11 @@ userAuthLoop:
 
                switch userAuthReq.Method {
                case "none":
-                       if s.Server.NoClientAuth {
+                       if s.config.NoClientAuth {
                                break userAuthLoop
                        }
                case "password":
-                       if s.Server.PasswordCallback == nil {
+                       if s.config.PasswordCallback == nil {
                                break
                        }
                        payload := userAuthReq.Payload
@@ -445,11 +457,11 @@ userAuthLoop:
                                return ParseError{msgUserAuthRequest}
                        }
 
-                       if s.Server.PasswordCallback(userAuthReq.User, string(password)) {
+                       if s.config.PasswordCallback(userAuthReq.User, string(password)) {
                                break userAuthLoop
                        }
                case "publickey":
-                       if s.Server.PubKeyCallback == nil {
+                       if s.config.PubKeyCallback == nil {
                                break
                        }
                        payload := userAuthReq.Payload
@@ -520,10 +532,10 @@ userAuthLoop:
                }
 
                var failureMsg userAuthFailureMsg
-               if s.Server.PasswordCallback != nil {
+               if s.config.PasswordCallback != nil {
                        failureMsg.Methods = append(failureMsg.Methods, "password")
                }
-               if s.Server.PubKeyCallback != nil {
+               if s.config.PubKeyCallback != nil {
                        failureMsg.Methods = append(failureMsg.Methods, "publickey")
                }
 
@@ -546,9 +558,9 @@ userAuthLoop:
 
 const defaultWindowSize = 32768
 
-// Accept reads and processes messages on a ServerConnection. It must be called
+// Accept reads and processes messages on a ServerConn. It must be called
 // in order to demultiplex messages to any resulting Channels.
-func (s *ServerConnection) Accept() (Channel, os.Error) {
+func (s *ServerConn) Accept() (Channel, os.Error) {
        if s.err != nil {
                return nil, s.err
        }
@@ -643,3 +655,44 @@ func (s *ServerConnection) Accept() (Channel, os.Error) {
 
        panic("unreachable")
 }
+
+// A Listener implements a network listener (net.Listener) for SSH connections.
+type Listener struct {
+       listener net.Listener
+       config   *ServerConfig
+}
+
+// Accept waits for and returns the next incoming SSH connection.
+// The receiver should call Handshake() in another goroutine 
+// to avoid blocking the accepter.
+func (l *Listener) Accept() (*ServerConn, os.Error) {
+       c, err := l.listener.Accept()
+       if err != nil {
+               return nil, err
+       }
+       conn := Server(c, l.config)
+       return conn, nil
+}
+
+// Addr returns the listener's network address.
+func (l *Listener) Addr() net.Addr {
+       return l.listener.Addr()
+}
+
+// Close closes the listener.
+func (l *Listener) Close() os.Error {
+       return l.listener.Close()
+}
+
+// Listen creates an SSH listener accepting connections on
+// the given network address using net.Listen.
+func Listen(network, addr string, config *ServerConfig) (*Listener, os.Error) {
+       l, err := net.Listen(network, addr)
+       if err != nil {
+               return nil, err
+       }
+       return &Listener{
+               l,
+               config,
+       }, nil
+}