]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: Added dynamic alternative to NameToCertificate map for SNI
authorPercy Wegmann <ox.to.a.cart@gmail.com>
Wed, 6 Aug 2014 18:22:00 +0000 (11:22 -0700)
committerAdam Langley <agl@golang.org>
Wed, 6 Aug 2014 18:22:00 +0000 (11:22 -0700)
Revised version of https://golang.org/cl/81260045/

LGTM=agl
R=golang-codereviews, gobot, agl, ox
CC=golang-codereviews
https://golang.org/cl/107400043

src/pkg/crypto/tls/common.go
src/pkg/crypto/tls/conn_test.go
src/pkg/crypto/tls/handshake_server.go
src/pkg/crypto/tls/handshake_server_test.go

index ce30385917efedbe983efc7e44646c245959ce7e..2b59136e654c54e43b5ec6cad4522719c6446ef4 100644 (file)
@@ -202,6 +202,32 @@ type ClientSessionCache interface {
        Put(sessionKey string, cs *ClientSessionState)
 }
 
+// ClientHelloInfo contains information from a ClientHello message in order to
+// guide certificate selection in the GetCertificate callback.
+type ClientHelloInfo struct {
+       // CipherSuites lists the CipherSuites supported by the client (e.g.
+       // TLS_RSA_WITH_RC4_128_SHA).
+       CipherSuites []uint16
+
+       // ServerName indicates the name of the server requested by the client
+       // in order to support virtual hosting. ServerName is only set if the
+       // client is using SNI (see
+       // http://tools.ietf.org/html/rfc4366#section-3.1).
+       ServerName string
+
+       // SupportedCurves lists the elliptic curves supported by the client.
+       // SupportedCurves is set only if the Supported Elliptic Curves
+       // Extension is being used (see
+       // http://tools.ietf.org/html/rfc4492#section-5.1.1).
+       SupportedCurves []CurveID
+
+       // SupportedPoints lists the point formats supported by the client.
+       // SupportedPoints is set only if the Supported Point Formats Extension
+       // is being used (see
+       // http://tools.ietf.org/html/rfc4492#section-5.1.2).
+       SupportedPoints []uint8
+}
+
 // A Config structure is used to configure a TLS client or server.
 // After one has been passed to a TLS function it must not be
 // modified. A Config may be reused; the tls package will also not
@@ -230,6 +256,13 @@ type Config struct {
        // for all connections.
        NameToCertificate map[string]*Certificate
 
+       // GetCertificate returns a Certificate based on the given
+       // ClientHelloInfo. If GetCertificate is nil or returns nil, then the
+       // certificate is retrieved from NameToCertificate. If
+       // NameToCertificate is nil, the first element of Certificates will be
+       // used.
+       GetCertificate func(clientHello *ClientHelloInfo) (*Certificate, error)
+
        // RootCAs defines the set of root certificate authorities
        // that clients use when verifying server certificates.
        // If RootCAs is nil, TLS uses the host's root CA set.
@@ -384,22 +417,28 @@ func (c *Config) mutualVersion(vers uint16) (uint16, bool) {
        return vers, true
 }
 
-// getCertificateForName returns the best certificate for the given name,
-// defaulting to the first element of c.Certificates if there are no good
-// options.
-func (c *Config) getCertificateForName(name string) *Certificate {
+// getCertificate returns the best certificate for the given ClientHelloInfo,
+// defaulting to the first element of c.Certificates.
+func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) {
+       if c.GetCertificate != nil {
+               cert, err := c.GetCertificate(clientHello)
+               if cert != nil || err != nil {
+                       return cert, err
+               }
+       }
+
        if len(c.Certificates) == 1 || c.NameToCertificate == nil {
                // There's only one choice, so no point doing any work.
-               return &c.Certificates[0]
+               return &c.Certificates[0], nil
        }
 
-       name = strings.ToLower(name)
+       name := strings.ToLower(clientHello.ServerName)
        for len(name) > 0 && name[len(name)-1] == '.' {
                name = name[:len(name)-1]
        }
 
        if cert, ok := c.NameToCertificate[name]; ok {
-               return cert
+               return cert, nil
        }
 
        // try replacing labels in the name with wildcards until we get a
@@ -409,12 +448,12 @@ func (c *Config) getCertificateForName(name string) *Certificate {
                labels[i] = "*"
                candidate := strings.Join(labels, ".")
                if cert, ok := c.NameToCertificate[candidate]; ok {
-                       return cert
+                       return cert, nil
                }
        }
 
        // If nothing matches, return the first certificate.
-       return &c.Certificates[0]
+       return &c.Certificates[0], nil
 }
 
 // BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate
index 5c555147ca893fb043613ccf9532f7819f964750..ec802cad70fc74d599dbe0ad48a7eb8dbaaf116a 100644 (file)
@@ -88,19 +88,31 @@ func TestCertificateSelection(t *testing.T) {
                return -1
        }
 
-       if n := pointerToIndex(config.getCertificateForName("example.com")); n != 0 {
+       certificateForName := func(name string) *Certificate {
+               clientHello := &ClientHelloInfo{
+                       ServerName: name,
+               }
+               if cert, err := config.getCertificate(clientHello); err != nil {
+                       t.Errorf("unable to get certificate for name '%s': %s", name, err)
+                       return nil
+               } else {
+                       return cert
+               }
+       }
+
+       if n := pointerToIndex(certificateForName("example.com")); n != 0 {
                t.Errorf("example.com returned certificate %d, not 0", n)
        }
-       if n := pointerToIndex(config.getCertificateForName("bar.example.com")); n != 1 {
+       if n := pointerToIndex(certificateForName("bar.example.com")); n != 1 {
                t.Errorf("bar.example.com returned certificate %d, not 1", n)
        }
-       if n := pointerToIndex(config.getCertificateForName("foo.example.com")); n != 2 {
+       if n := pointerToIndex(certificateForName("foo.example.com")); n != 2 {
                t.Errorf("foo.example.com returned certificate %d, not 2", n)
        }
-       if n := pointerToIndex(config.getCertificateForName("foo.bar.example.com")); n != 3 {
+       if n := pointerToIndex(certificateForName("foo.bar.example.com")); n != 3 {
                t.Errorf("foo.bar.example.com returned certificate %d, not 3", n)
        }
-       if n := pointerToIndex(config.getCertificateForName("foo.bar.baz.example.com")); n != 0 {
+       if n := pointerToIndex(certificateForName("foo.bar.baz.example.com")); n != 0 {
                t.Errorf("foo.bar.baz.example.com returned certificate %d, not 0", n)
        }
 }
index 910f3d6ef0f82b52f88ab836bc9d3176a700aa23..39eeb363cdfdc3fa174cc08fbcee7bf87583289d 100644 (file)
@@ -186,7 +186,16 @@ Curves:
        }
        hs.cert = &config.Certificates[0]
        if len(hs.clientHello.serverName) > 0 {
-               hs.cert = config.getCertificateForName(hs.clientHello.serverName)
+               chi := &ClientHelloInfo{
+                       CipherSuites:    hs.clientHello.cipherSuites,
+                       ServerName:      hs.clientHello.serverName,
+                       SupportedCurves: hs.clientHello.supportedCurves,
+                       SupportedPoints: hs.clientHello.supportedPoints,
+               }
+               if hs.cert, err = config.getCertificate(chi); err != nil {
+                       c.sendAlert(alertInternalError)
+                       return false, err
+               }
        }
 
        _, hs.ecdsaOk = hs.cert.PrivateKey.(*ecdsa.PrivateKey)
index e94dc9f9952c9d39019353ca8aef5469929b9518..36c79a9b0c534df65edd659604eb9dfe96f77392 100644 (file)
@@ -257,6 +257,9 @@ type serverTest struct {
        expectedPeerCerts []string
        // config, if not nil, contains a custom Config to use for this test.
        config *Config
+       // expectAlert, if true, indicates that a fatal alert should be returned
+       // when handshaking with the server.
+       expectAlert bool
        // validate, if not nil, is a function that will be called with the
        // ConnectionState of the resulting connection. It returns false if the
        // ConnectionState is unacceptable.
@@ -370,7 +373,9 @@ func (test *serverTest) run(t *testing.T, write bool) {
        if !write {
                flows, err := test.loadData()
                if err != nil {
-                       t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath())
+                       if !test.expectAlert {
+                               t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath())
+                       }
                }
                for i, b := range flows {
                        if i%2 == 0 {
@@ -379,11 +384,17 @@ func (test *serverTest) run(t *testing.T, write bool) {
                        }
                        bb := make([]byte, len(b))
                        n, err := io.ReadFull(clientConn, bb)
-                       if err != nil {
-                               t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b)
-                       }
-                       if !bytes.Equal(b, bb) {
-                               t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
+                       if test.expectAlert {
+                               if err == nil {
+                                       t.Fatal("Expected read failure but read succeeded")
+                               }
+                       } else {
+                               if err != nil {
+                                       t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b)
+                               }
+                               if !bytes.Equal(b, bb) {
+                                       t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
+                               }
                        }
                }
                clientConn.Close()
@@ -562,6 +573,61 @@ func TestHandshakeServerSNI(t *testing.T) {
        runServerTestTLS12(t, test)
 }
 
+// TestHandshakeServerSNICertForName is similar to TestHandshakeServerSNI, but
+// tests the dynamic GetCertificate method
+func TestHandshakeServerSNIGetCertificate(t *testing.T) {
+       config := *testConfig
+
+       // Replace the NameToCertificate map with a GetCertificate function
+       nameToCert := config.NameToCertificate
+       config.NameToCertificate = nil
+       config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
+               cert, _ := nameToCert[clientHello.ServerName]
+               return cert, nil
+       }
+       test := &serverTest{
+               name:    "SNI",
+               command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"},
+               config:  &config,
+       }
+       runServerTestTLS12(t, test)
+}
+
+// TestHandshakeServerSNICertForNameNotFound is similar to
+// TestHandshakeServerSNICertForName, but tests to make sure that when the
+// GetCertificate method doesn't return a cert, we fall back to what's in
+// the NameToCertificate map.
+func TestHandshakeServerSNIGetCertificateNotFound(t *testing.T) {
+       config := *testConfig
+
+       config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
+               return nil, nil
+       }
+       test := &serverTest{
+               name:    "SNI",
+               command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"},
+               config:  &config,
+       }
+       runServerTestTLS12(t, test)
+}
+
+// TestHandshakeServerSNICertForNameError tests to make sure that errors in
+// GetCertificate result in a tls alert.
+func TestHandshakeServerSNIGetCertificateError(t *testing.T) {
+       config := *testConfig
+
+       config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
+               return nil, fmt.Errorf("Test error in GetCertificate")
+       }
+       test := &serverTest{
+               name:        "SNI",
+               command:     []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"},
+               config:      &config,
+               expectAlert: true,
+       }
+       runServerTestTLS12(t, test)
+}
+
 // TestCipherSuiteCertPreferance ensures that we select an RSA ciphersuite with
 // an RSA certificate and an ECDSA ciphersuite with an ECDSA certificate.
 func TestCipherSuiteCertPreferenceECDSA(t *testing.T) {