]> Cypherpunks repositories - gostls13.git/commitdiff
exp/ssh: add client side support for publickey auth
authorDave Cheney <dave@cheney.net>
Sun, 13 Nov 2011 19:48:22 +0000 (14:48 -0500)
committerAdam Langley <agl@golang.org>
Sun, 13 Nov 2011 19:48:22 +0000 (14:48 -0500)
client.go/client_auth.go:
* add support for publickey key auth using the interface
  outlined by rsc in the previous auth CL

client_auth_test.go:
* password and publickey tests against server.go

common.go/server.go:
* move some helper methods from server.go into common.go
* generalise serializeRSASignature

R=rsc, agl, huin
CC=cw, golang-dev, n13m3y3r
https://golang.org/cl/5373055

src/pkg/exp/ssh/client.go
src/pkg/exp/ssh/client_auth.go
src/pkg/exp/ssh/client_auth_test.go [new file with mode: 0644]
src/pkg/exp/ssh/common.go
src/pkg/exp/ssh/server.go

index 669182b2c80b0bee2a05d40142824e52328c9513..0ea48437b6ad1517d9b6c20153a64d2e4cf64ea1 100644 (file)
@@ -35,10 +35,6 @@ func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
                conn.Close()
                return nil, err
        }
-       if err := conn.authenticate(); err != nil {
-               conn.Close()
-               return nil, err
-       }
        go conn.mainLoop()
        return conn, nil
 }
@@ -128,7 +124,10 @@ func (c *ClientConn) handshake() error {
        if packet[0] != msgNewKeys {
                return UnexpectedMessageError{msgNewKeys, packet[0]}
        }
-       return c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc)
+       if err := c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
+               return err
+       }
+       return c.authenticate(H)
 }
 
 // kexDH performs Diffie-Hellman key agreement on a ClientConn. The
index 0089d0c769e33b7702a7324c94dbbb434f0eff67..bb43eafff517605da12a46e2b135d871d52a0819 100644 (file)
@@ -5,11 +5,13 @@
 package ssh
 
 import (
+       "crypto/rand"
        "errors"
+       "io"
 )
 
 // authenticate authenticates with the remote server. See RFC 4252. 
-func (c *ClientConn) authenticate() error {
+func (c *ClientConn) authenticate(session []byte) error {
        // initiate user auth session
        if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
                return err
@@ -26,7 +28,7 @@ func (c *ClientConn) authenticate() error {
        // then any untried methods suggested by the server. 
        tried, remain := make(map[string]bool), make(map[string]bool)
        for auth := ClientAuth(new(noneAuth)); auth != nil; {
-               ok, methods, err := auth.auth(c.config.User, c.transport)
+               ok, methods, err := auth.auth(session, c.config.User, c.transport)
                if err != nil {
                        return err
                }
@@ -60,7 +62,7 @@ type ClientAuth interface {
        // Returns true if authentication is successful.
        // If authentication is not successful, a []string of alternative 
        // method names is returned.
-       auth(user string, t *transport) (bool, []string, error)
+       auth(session []byte, user string, t *transport) (bool, []string, error)
 
        // method returns the RFC 4252 method name.
        method() string
@@ -69,7 +71,7 @@ type ClientAuth interface {
 // "none" authentication, RFC 4252 section 5.2.
 type noneAuth int
 
-func (n *noneAuth) auth(user string, t *transport) (bool, []string, error) {
+func (n *noneAuth) auth(session []byte, user string, t *transport) (bool, []string, error) {
        if err := t.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
                User:    user,
                Service: serviceSSH,
@@ -102,7 +104,7 @@ type passwordAuth struct {
        ClientPassword
 }
 
-func (p *passwordAuth) auth(user string, t *transport) (bool, []string, error) {
+func (p *passwordAuth) auth(session []byte, user string, t *transport) (bool, []string, error) {
        type passwordAuthMsg struct {
                User     string
                Service  string
@@ -155,3 +157,141 @@ type ClientPassword interface {
 func ClientAuthPassword(impl ClientPassword) ClientAuth {
        return &passwordAuth{impl}
 }
+
+// ClientKeyring implements access to a client key ring.
+type ClientKeyring interface {
+       // Key returns the i'th rsa.Publickey or dsa.Publickey, or nil if 
+       // no key exists at i.
+       Key(i int) (key interface{}, err error)
+
+       // Sign returns a signature of the given data using the i'th key
+       // and the supplied random source.
+       Sign(i int, rand io.Reader, data []byte) (sig []byte, err error)
+}
+
+// "publickey" authentication, RFC 4252 Section 7.
+type publickeyAuth struct {
+       ClientKeyring
+}
+
+func (p *publickeyAuth) auth(session []byte, user string, t *transport) (bool, []string, error) {
+       type publickeyAuthMsg struct {
+               User    string
+               Service string
+               Method  string
+               // HasSig indicates to the reciver packet that the auth request is signed and
+               // should be used for authentication of the request.
+               HasSig   bool
+               Algoname string
+               Pubkey   string
+               // Sig is defined as []byte so marshal will exclude it during the query phase
+               Sig []byte `ssh:"rest"`
+       }
+
+       // Authentication is performed in two stages. The first stage sends an
+       // enquiry to test if each key is acceptable to the remote. The second
+       // stage attempts to authenticate with the valid keys obtained in the 
+       // first stage.
+
+       var index int
+       // a map of public keys to their index in the keyring 
+       validKeys := make(map[int]interface{})
+       for {
+               key, err := p.Key(index)
+               if err != nil {
+                       return false, nil, err
+               }
+               if key == nil {
+                       // no more keys in the keyring
+                       break
+               }
+               pubkey := serializePublickey(key)
+               algoname := algoName(key)
+               msg := publickeyAuthMsg{
+                       User:     user,
+                       Service:  serviceSSH,
+                       Method:   p.method(),
+                       HasSig:   false,
+                       Algoname: algoname,
+                       Pubkey:   string(pubkey),
+               }
+               if err := t.writePacket(marshal(msgUserAuthRequest, msg)); err != nil {
+                       return false, nil, err
+               }
+               packet, err := t.readPacket()
+               if err != nil {
+                       return false, nil, err
+               }
+               switch packet[0] {
+               case msgUserAuthPubKeyOk:
+                       msg := decode(packet).(*userAuthPubKeyOkMsg)
+                       if msg.Algo != algoname || msg.PubKey != string(pubkey) {
+                               continue
+                       }
+                       validKeys[index] = key
+               case msgUserAuthFailure:
+               default:
+                       return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
+               }
+               index++
+       }
+
+       // methods that may continue if this auth is not successful.
+       var methods []string
+       for i, key := range validKeys {
+               pubkey := serializePublickey(key)
+               algoname := algoName(key)
+               // TODO(dfc) use random source from the ClientConfig
+               sign, err := p.Sign(i, rand.Reader, buildDataSignedForAuth(session, userAuthRequestMsg{
+                       User:    user,
+                       Service: serviceSSH,
+                       Method:  p.method(),
+               }, []byte(algoname), pubkey))
+               if err != nil {
+                       return false, nil, err
+               }
+               // manually wrap the serialized signature in a string
+               s := serializeSignature(algoname, sign)
+               sig := make([]byte, stringLength(s))
+               marshalString(sig, s)
+               msg := publickeyAuthMsg{
+                       User:     user,
+                       Service:  serviceSSH,
+                       Method:   p.method(),
+                       HasSig:   true,
+                       Algoname: algoname,
+                       Pubkey:   string(pubkey),
+                       Sig:      sig,
+               }
+               p := marshal(msgUserAuthRequest, msg)
+               if err := t.writePacket(p); err != nil {
+                       return false, nil, err
+               }
+               packet, err := t.readPacket()
+               if err != nil {
+                       return false, nil, err
+               }
+               switch packet[0] {
+               case msgUserAuthSuccess:
+                       return true, nil, nil
+               case msgUserAuthFailure:
+                       msg := decode(packet).(*userAuthFailureMsg)
+                       methods = msg.Methods
+                       continue
+               case msgDisconnect:
+                       return false, nil, io.EOF
+               default:
+                       return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
+               }
+       }
+       return false, methods, nil
+}
+
+func (p *publickeyAuth) method() string {
+       return "publickey"
+}
+
+// ClientAuthPublickey returns a ClientAuth using public key authentication.
+func ClientAuthPublickey(impl ClientKeyring) ClientAuth {
+       return &publickeyAuth{impl}
+}
diff --git a/src/pkg/exp/ssh/client_auth_test.go b/src/pkg/exp/ssh/client_auth_test.go
new file mode 100644 (file)
index 0000000..ccd6cd2
--- /dev/null
@@ -0,0 +1,248 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+       "bytes"
+       "crypto"
+       "crypto/rand"
+       "crypto/rsa"
+       "crypto/x509"
+       "encoding/pem"
+       "errors"
+       "io"
+       "io/ioutil"
+       "testing"
+)
+
+const _pem = `-----BEGIN RSA PRIVATE KEY-----
+MIIEpAIBAAKCAQEA19lGVsTqIT5iiNYRgnoY1CwkbETW5cq+Rzk5v/kTlf31XpSU
+70HVWkbTERECjaYdXM2gGcbb+sxpq6GtXf1M3kVomycqhxwhPv4Cr6Xp4WT/jkFx
+9z+FFzpeodGJWjOH6L2H5uX1Cvr9EDdQp9t9/J32/qBFntY8GwoUI/y/1MSTmMiF
+tupdMODN064vd3gyMKTwrlQ8tZM6aYuyOPsutLlUY7M5x5FwMDYvnPDSeyT/Iw0z
+s3B+NCyqeeMd2T7YzQFnRATj0M7rM5LoSs7DVqVriOEABssFyLj31PboaoLhOKgc
+qoM9khkNzr7FHVvi+DhYM2jD0DwvqZLN6NmnLwIDAQABAoIBAQCGVj+kuSFOV1lT
++IclQYA6bM6uY5mroqcSBNegVxCNhWU03BxlW//BE9tA/+kq53vWylMeN9mpGZea
+riEMIh25KFGWXqXlOOioH8bkMsqA8S7sBmc7jljyv+0toQ9vCCtJ+sueNPhxQQxH
+D2YvUjfzBQ04I9+wn30BByDJ1QA/FoPsunxIOUCcRBE/7jxuLYcpR+JvEF68yYIh
+atXRld4W4in7T65YDR8jK1Uj9XAcNeDYNpT/M6oFLx1aPIlkG86aCWRO19S1jLPT
+b1ZAKHHxPMCVkSYW0RqvIgLXQOR62D0Zne6/2wtzJkk5UCjkSQ2z7ZzJpMkWgDgN
+ifCULFPBAoGBAPoMZ5q1w+zB+knXUD33n1J+niN6TZHJulpf2w5zsW+m2K6Zn62M
+MXndXlVAHtk6p02q9kxHdgov34Uo8VpuNjbS1+abGFTI8NZgFo+bsDxJdItemwC4
+KJ7L1iz39hRN/ZylMRLz5uTYRGddCkeIHhiG2h7zohH/MaYzUacXEEy3AoGBANz8
+e/msleB+iXC0cXKwds26N4hyMdAFE5qAqJXvV3S2W8JZnmU+sS7vPAWMYPlERPk1
+D8Q2eXqdPIkAWBhrx4RxD7rNc5qFNcQWEhCIxC9fccluH1y5g2M+4jpMX2CT8Uv+
+3z+NoJ5uDTXZTnLCfoZzgZ4nCZVZ+6iU5U1+YXFJAoGBANLPpIV920n/nJmmquMj
+orI1R/QXR9Cy56cMC65agezlGOfTYxk5Cfl5Ve+/2IJCfgzwJyjWUsFx7RviEeGw
+64o7JoUom1HX+5xxdHPsyZ96OoTJ5RqtKKoApnhRMamau0fWydH1yeOEJd+TRHhc
+XStGfhz8QNa1dVFvENczja1vAoGABGWhsd4VPVpHMc7lUvrf4kgKQtTC2PjA4xoc
+QJ96hf/642sVE76jl+N6tkGMzGjnVm4P2j+bOy1VvwQavKGoXqJBRd5Apppv727g
+/SM7hBXKFc/zH80xKBBgP/i1DR7kdjakCoeu4ngeGywvu2jTS6mQsqzkK+yWbUxJ
+I7mYBsECgYB/KNXlTEpXtz/kwWCHFSYA8U74l7zZbVD8ul0e56JDK+lLcJ0tJffk
+gqnBycHj6AhEycjda75cs+0zybZvN4x65KZHOGW/O/7OAWEcZP5TPb3zf9ned3Hl
+NsZoFj52ponUM6+99A2CmezFCN16c4mbA//luWF+k3VVqR6BpkrhKw==
+-----END RSA PRIVATE KEY-----`
+
+// reused internally by tests
+var serverConfig = new(ServerConfig)
+
+func init() {
+       if err := serverConfig.SetRSAPrivateKey([]byte(_pem)); err != nil {
+               panic("unable to set private key: " + err.Error())
+       }
+}
+
+// keychain implements the ClientPublickey interface
+type keychain struct {
+       keys []*rsa.PrivateKey
+}
+
+func (k *keychain) Key(i int) (interface{}, error) {
+       if i < 0 || i >= len(k.keys) {
+               return nil, nil
+       }
+       return k.keys[i].PublicKey, nil
+}
+
+func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
+       hashFunc := crypto.SHA1
+       h := hashFunc.New()
+       h.Write(data)
+       digest := h.Sum()
+       return rsa.SignPKCS1v15(rand, k.keys[i], hashFunc, digest)
+}
+
+func (k *keychain) loadPEM(file string) error {
+       buf, err := ioutil.ReadFile(file)
+       if err != nil {
+               return err
+       }
+       block, _ := pem.Decode(buf)
+       if block == nil {
+               return errors.New("ssh: no key found")
+       }
+       r, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+       if err != nil {
+               return err
+       }
+       k.keys = append(k.keys, r)
+       return nil
+}
+
+var pkey *rsa.PrivateKey
+
+func init() {
+       var err error
+       pkey, err = rsa.GenerateKey(rand.Reader, 512)
+       if err != nil {
+               panic("unable to generate public key")
+       }
+}
+
+func TestClientAuthPublickey(t *testing.T) {
+       k := new(keychain)
+       k.keys = append(k.keys, pkey)
+
+       serverConfig.PubKeyCallback = func(user, algo string, pubkey []byte) bool {
+               expected := []byte(serializePublickey(k.keys[0].PublicKey))
+               algoname := algoName(k.keys[0].PublicKey)
+               return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
+       }
+       serverConfig.PasswordCallback = nil
+
+       l, err := Listen("tcp", "0.0.0.0:0", serverConfig)
+       if err != nil {
+               t.Fatalf("unable to listen: %s", err)
+       }
+       defer l.Close()
+
+       done := make(chan bool)
+       go func() {
+               c, err := l.Accept()
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if err := c.Handshake(); err != nil {
+                       t.Error(err)
+               }
+               defer c.Close()
+               done <- true
+       }()
+
+       config := &ClientConfig{
+               User: "testuser",
+               Auth: []ClientAuth{
+                       ClientAuthPublickey(k),
+               },
+       }
+
+       c, err := Dial("tcp", l.Addr().String(), config)
+       if err != nil {
+               t.Errorf("unable to dial remote side: %s", err)
+       }
+       defer c.Close()
+       <-done
+}
+
+// password implements the ClientPassword interface
+type password string
+
+func (p password) Password(user string) (string, error) {
+       return string(p), nil
+}
+
+func TestClientAuthPassword(t *testing.T) {
+       pw := password("tiger")
+
+       serverConfig.PasswordCallback = func(user, pass string) bool {
+               return user == "testuser" && pass == string(pw)
+       }
+       serverConfig.PubKeyCallback = nil
+
+       l, err := Listen("tcp", "0.0.0.0:0", serverConfig)
+       if err != nil {
+               t.Fatalf("unable to listen: %s", err)
+       }
+       defer l.Close()
+
+       done := make(chan bool)
+       go func() {
+               c, err := l.Accept()
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if err := c.Handshake(); err != nil {
+                       t.Error(err)
+               }
+               defer c.Close()
+               done <- true
+       }()
+
+       config := &ClientConfig{
+               User: "testuser",
+               Auth: []ClientAuth{
+                       ClientAuthPassword(pw),
+               },
+       }
+
+       c, err := Dial("tcp", l.Addr().String(), config)
+       if err != nil {
+               t.Errorf("unable to dial remote side: %s", err)
+       }
+       defer c.Close()
+       <-done
+}
+
+func TestClientAuthPasswordAndPublickey(t *testing.T) {
+       pw := password("tiger")
+
+       serverConfig.PasswordCallback = func(user, pass string) bool {
+               return user == "testuser" && pass == string(pw)
+       }
+
+       k := new(keychain)
+       k.keys = append(k.keys, pkey)
+
+       serverConfig.PubKeyCallback = func(user, algo string, pubkey []byte) bool {
+               expected := []byte(serializePublickey(k.keys[0].PublicKey))
+               algoname := algoName(k.keys[0].PublicKey)
+               return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
+       }
+
+       l, err := Listen("tcp", "0.0.0.0:0", serverConfig)
+       if err != nil {
+               t.Fatalf("unable to listen: %s", err)
+       }
+       defer l.Close()
+
+       done := make(chan bool)
+       go func() {
+               c, err := l.Accept()
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if err := c.Handshake(); err != nil {
+                       t.Error(err)
+               }
+               defer c.Close()
+               done <- true
+       }()
+
+       wrongPw := password("wrong")
+       config := &ClientConfig{
+               User: "testuser",
+               Auth: []ClientAuth{
+                       ClientAuthPassword(wrongPw),
+                       ClientAuthPublickey(k),
+               },
+       }
+
+       c, err := Dial("tcp", l.Addr().String(), config)
+       if err != nil {
+               t.Errorf("unable to dial remote side: %s", err)
+       }
+       defer c.Close()
+       <-done
+}
index 273820b6428c62706df848f91f0cce4f83b37a7c..cc720558fc4e5297bff2b1e5843593856bb13ae9 100644 (file)
@@ -5,6 +5,8 @@
 package ssh
 
 import (
+       "crypto/dsa"
+       "crypto/rsa"
        "math/big"
        "strconv"
        "sync"
@@ -127,3 +129,86 @@ func findAgreedAlgorithms(transport *transport, clientKexInit, serverKexInit *ke
        ok = true
        return
 }
+
+// serialize a signed slice according to RFC 4254 6.6.
+func serializeSignature(algoname string, sig []byte) []byte {
+       length := stringLength([]byte(algoname))
+       length += stringLength(sig)
+
+       ret := make([]byte, length)
+       r := marshalString(ret, []byte(algoname))
+       r = marshalString(r, sig)
+
+       return ret
+}
+
+// serialize an rsa.PublicKey or dsa.PublicKey according to RFC 4253 6.6.
+func serializePublickey(key interface{}) []byte {
+       algoname := algoName(key)
+       switch key := key.(type) {
+       case rsa.PublicKey:
+               e := new(big.Int).SetInt64(int64(key.E))
+               length := stringLength([]byte(algoname))
+               length += intLength(e)
+               length += intLength(key.N)
+               ret := make([]byte, length)
+               r := marshalString(ret, []byte(algoname))
+               r = marshalInt(r, e)
+               marshalInt(r, key.N)
+               return ret
+       case dsa.PublicKey:
+               length := stringLength([]byte(algoname))
+               length += intLength(key.P)
+               length += intLength(key.Q)
+               length += intLength(key.G)
+               length += intLength(key.Y)
+               ret := make([]byte, length)
+               r := marshalString(ret, []byte(algoname))
+               r = marshalInt(r, key.P)
+               r = marshalInt(r, key.Q)
+               r = marshalInt(r, key.G)
+               marshalInt(r, key.Y)
+               return ret
+       }
+       panic("unexpected key type")
+}
+
+func algoName(key interface{}) string {
+       switch key.(type) {
+       case rsa.PublicKey:
+               return "ssh-rsa"
+       case dsa.PublicKey:
+               return "ssh-dss"
+       }
+       panic("unexpected key type")
+}
+
+// buildDataSignedForAuth returns the data that is signed in order to prove
+// posession of a private key. See RFC 4252, section 7.
+func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
+       user := []byte(req.User)
+       service := []byte(req.Service)
+       method := []byte(req.Method)
+
+       length := stringLength(sessionId)
+       length += 1
+       length += stringLength(user)
+       length += stringLength(service)
+       length += stringLength(method)
+       length += 1
+       length += stringLength(algo)
+       length += stringLength(pubKey)
+
+       ret := make([]byte, length)
+       r := marshalString(ret, sessionId)
+       r[0] = msgUserAuthRequest
+       r = r[1:]
+       r = marshalString(r, user)
+       r = marshalString(r, service)
+       r = marshalString(r, method)
+       r[0] = 1
+       r = r[1:]
+       r = marshalString(r, algo)
+       r = marshalString(r, pubKey)
+       return ret
+}
index 62035d52b7dd6ba7e5539ac497ecfee19440da17..55dd5b0e0290ae46a631b19c26ad1053bdb3076f 100644 (file)
@@ -221,7 +221,7 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
                return nil, nil, errors.New("internal error")
        }
 
-       serializedSig := serializeRSASignature(sig)
+       serializedSig := serializeSignature(hostAlgoRSA, sig)
 
        kexDHReply := kexDHReplyMsg{
                HostKey:   serializedHostKey,
@@ -234,50 +234,9 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
        return
 }
 
-func serializeRSASignature(sig []byte) []byte {
-       length := stringLength([]byte(hostAlgoRSA))
-       length += stringLength(sig)
-
-       ret := make([]byte, length)
-       r := marshalString(ret, []byte(hostAlgoRSA))
-       r = marshalString(r, sig)
-
-       return ret
-}
-
 // serverVersion is the fixed identification string that Server will use.
 var serverVersion = []byte("SSH-2.0-Go\r\n")
 
-// buildDataSignedForAuth returns the data that is signed in order to prove
-// posession of a private key. See RFC 4252, section 7.
-func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
-       user := []byte(req.User)
-       service := []byte(req.Service)
-       method := []byte(req.Method)
-
-       length := stringLength(sessionId)
-       length += 1
-       length += stringLength(user)
-       length += stringLength(service)
-       length += stringLength(method)
-       length += 1
-       length += stringLength(algo)
-       length += stringLength(pubKey)
-
-       ret := make([]byte, length)
-       r := marshalString(ret, sessionId)
-       r[0] = msgUserAuthRequest
-       r = r[1:]
-       r = marshalString(r, user)
-       r = marshalString(r, service)
-       r = marshalString(r, method)
-       r[0] = 1
-       r = r[1:]
-       r = marshalString(r, algo)
-       r = marshalString(r, pubKey)
-       return ret
-}
-
 // Handshake performs an SSH transport and client authentication on the given ServerConn.
 func (s *ServerConn) Handshake() error {
        var magics handshakeMagics