]> Cypherpunks repositories - gostls13.git/commitdiff
exp/ssh: allow for msgUserAuthBanner during authentication
authorGustav Paul <gustav.paul@gmail.com>
Fri, 2 Dec 2011 15:34:42 +0000 (10:34 -0500)
committerAdam Langley <agl@golang.org>
Fri, 2 Dec 2011 15:34:42 +0000 (10:34 -0500)
The SSH spec allows for the server to send a banner message to the client at any point during the authentication process. Currently the ssh client auth types all assume that the first response from the server after issuing a userAuthRequestMsg will be one of a couple of possible authentication success/failure messages. This means that client authentication breaks if the ssh server being connected to has a banner message configured.

This changeset refactors the noneAuth, passwordAuth and publickeyAuth types' auth() function and allows for msgUserAuthBanner during authentication.

R=golang-dev, rsc, dave, agl
CC=golang-dev
https://golang.org/cl/5432065

src/pkg/exp/ssh/client_auth.go

index 25f9e21622530b40ef687fcf5133fc3066cd80ad..1a382357b46e7100e3657d8d26d03dde93269134 100644 (file)
@@ -9,7 +9,7 @@ import (
        "io"
 )
 
-// authenticate authenticates with the remote server. See RFC 4252. 
+// authenticate authenticates with the remote server. See RFC 4252.
 func (c *ClientConn) authenticate(session []byte) error {
        // initiate user auth session
        if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
@@ -24,7 +24,7 @@ func (c *ClientConn) authenticate(session []byte) error {
                return err
        }
        // during the authentication phase the client first attempts the "none" method
-       // then any untried methods suggested by the server. 
+       // then any untried methods suggested by the server.
        tried, remain := make(map[string]bool), make(map[string]bool)
        for auth := ClientAuth(new(noneAuth)); auth != nil; {
                ok, methods, err := auth.auth(session, c.config.User, c.transport, c.config.rand())
@@ -57,9 +57,9 @@ func (c *ClientConn) authenticate(session []byte) error {
 
 // A ClientAuth represents an instance of an RFC 4252 authentication method.
 type ClientAuth interface {
-       // auth authenticates user over transport t. 
+       // auth authenticates user over transport t.
        // Returns true if authentication is successful.
-       // If authentication is not successful, a []string of alternative 
+       // If authentication is not successful, a []string of alternative
        // method names is returned.
        auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error)
 
@@ -79,19 +79,7 @@ func (n *noneAuth) auth(session []byte, user string, t *transport, rand io.Reade
                return false, nil, err
        }
 
-       packet, err := t.readPacket()
-       if err != nil {
-               return false, nil, err
-       }
-
-       switch packet[0] {
-       case msgUserAuthSuccess:
-               return true, nil, nil
-       case msgUserAuthFailure:
-               msg := decode(packet).(*userAuthFailureMsg)
-               return false, msg.Methods, nil
-       }
-       return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
+       return handleAuthResponse(t)
 }
 
 func (n *noneAuth) method() string {
@@ -127,19 +115,7 @@ func (p *passwordAuth) auth(session []byte, user string, t *transport, rand io.R
                return false, nil, err
        }
 
-       packet, err := t.readPacket()
-       if err != nil {
-               return false, nil, err
-       }
-
-       switch packet[0] {
-       case msgUserAuthSuccess:
-               return true, nil, nil
-       case msgUserAuthFailure:
-               msg := decode(packet).(*userAuthFailureMsg)
-               return false, msg.Methods, nil
-       }
-       return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
+       return handleAuthResponse(t)
 }
 
 func (p *passwordAuth) method() string {
@@ -159,7 +135,7 @@ func ClientAuthPassword(impl ClientPassword) ClientAuth {
 
 // ClientKeyring implements access to a client key ring.
 type ClientKeyring interface {
-       // Key returns the i'th rsa.Publickey or dsa.Publickey, or nil if 
+       // Key returns the i'th rsa.Publickey or dsa.Publickey, or nil if
        // no key exists at i.
        Key(i int) (key interface{}, err error)
 
@@ -173,27 +149,28 @@ type publickeyAuth struct {
        ClientKeyring
 }
 
+type publickeyAuthMsg struct {
+       User    string
+       Service string
+       Method  string
+       // HasSig indicates to the reciver packet that the auth request is signed and
+       // should be used for authentication of the request.
+       HasSig   bool
+       Algoname string
+       Pubkey   string
+       // Sig is defined as []byte so marshal will exclude it during validateKey
+       Sig []byte `ssh:"rest"`
+}
+
 func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
-       type publickeyAuthMsg struct {
-               User    string
-               Service string
-               Method  string
-               // HasSig indicates to the reciver packet that the auth request is signed and
-               // should be used for authentication of the request.
-               HasSig   bool
-               Algoname string
-               Pubkey   string
-               // Sig is defined as []byte so marshal will exclude it during the query phase
-               Sig []byte `ssh:"rest"`
-       }
 
        // Authentication is performed in two stages. The first stage sends an
        // enquiry to test if each key is acceptable to the remote. The second
-       // stage attempts to authenticate with the valid keys obtained in the 
+       // stage attempts to authenticate with the valid keys obtained in the
        // first stage.
 
        var index int
-       // a map of public keys to their index in the keyring 
+       // a map of public keys to their index in the keyring
        validKeys := make(map[int]interface{})
        for {
                key, err := p.Key(index)
@@ -204,33 +181,13 @@ func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.
                        // no more keys in the keyring
                        break
                }
-               pubkey := serializePublickey(key)
-               algoname := algoName(key)
-               msg := publickeyAuthMsg{
-                       User:     user,
-                       Service:  serviceSSH,
-                       Method:   p.method(),
-                       HasSig:   false,
-                       Algoname: algoname,
-                       Pubkey:   string(pubkey),
-               }
-               if err := t.writePacket(marshal(msgUserAuthRequest, msg)); err != nil {
-                       return false, nil, err
-               }
-               packet, err := t.readPacket()
-               if err != nil {
-                       return false, nil, err
-               }
-               switch packet[0] {
-               case msgUserAuthPubKeyOk:
-                       msg := decode(packet).(*userAuthPubKeyOkMsg)
-                       if msg.Algo != algoname || msg.PubKey != string(pubkey) {
-                               continue
-                       }
+
+               if ok, err := p.validateKey(key, user, t); ok {
                        validKeys[index] = key
-               case msgUserAuthFailure:
-               default:
-                       return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
+               } else {
+                       if err != nil {
+                               return false, nil, err
+                       }
                }
                index++
        }
@@ -265,24 +222,61 @@ func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.
                if err := t.writePacket(p); err != nil {
                        return false, nil, err
                }
-               packet, err := t.readPacket()
+               success, methods, err := handleAuthResponse(t)
                if err != nil {
                        return false, nil, err
                }
+               if success {
+                       return success, methods, err
+               }
+       }
+       return false, methods, nil
+}
+
+// validateKey validates the key provided it is acceptable to the server.
+func (p *publickeyAuth) validateKey(key interface{}, user string, t *transport) (bool, error) {
+       pubkey := serializePublickey(key)
+       algoname := algoName(key)
+       msg := publickeyAuthMsg{
+               User:     user,
+               Service:  serviceSSH,
+               Method:   p.method(),
+               HasSig:   false,
+               Algoname: algoname,
+               Pubkey:   string(pubkey),
+       }
+       if err := t.writePacket(marshal(msgUserAuthRequest, msg)); err != nil {
+               return false, err
+       }
+
+       return p.confirmKeyAck(key, t)
+}
+
+func (p *publickeyAuth) confirmKeyAck(key interface{}, t *transport) (bool, error) {
+       pubkey := serializePublickey(key)
+       algoname := algoName(key)
+
+       for {
+               packet, err := t.readPacket()
+               if err != nil {
+                       return false, err
+               }
                switch packet[0] {
-               case msgUserAuthSuccess:
-                       return true, nil, nil
+               case msgUserAuthBanner:
+                       // TODO(gpaul): add callback to present the banner to the user
+               case msgUserAuthPubKeyOk:
+                       msg := decode(packet).(*userAuthPubKeyOkMsg)
+                       if msg.Algo != algoname || msg.PubKey != string(pubkey) {
+                               return false, nil
+                       }
+                       return true, nil
                case msgUserAuthFailure:
-                       msg := decode(packet).(*userAuthFailureMsg)
-                       methods = msg.Methods
-                       continue
-               case msgDisconnect:
-                       return false, nil, io.EOF
+                       return false, nil
                default:
-                       return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
+                       return false, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
                }
        }
-       return false, methods, nil
+       panic("unreachable")
 }
 
 func (p *publickeyAuth) method() string {
@@ -293,3 +287,30 @@ func (p *publickeyAuth) method() string {
 func ClientAuthPublickey(impl ClientKeyring) ClientAuth {
        return &publickeyAuth{impl}
 }
+
+// handleAuthResponse returns whether the preceding authentication request succeeded
+// along with a list of remaining authentication methods to try next and
+// an error if an unexpected response was received.
+func handleAuthResponse(t *transport) (bool, []string, error) {
+       for {
+               packet, err := t.readPacket()
+               if err != nil {
+                       return false, nil, err
+               }
+
+               switch packet[0] {
+               case msgUserAuthBanner:
+                       // TODO: add callback to present the banner to the user
+               case msgUserAuthFailure:
+                       msg := decode(packet).(*userAuthFailureMsg)
+                       return false, msg.Methods, nil
+               case msgUserAuthSuccess:
+                       return true, nil, nil
+               case msgDisconnect:
+                       return false, nil, io.EOF
+               default:
+                       return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
+               }
+       }
+       panic("unreachable")
+}