]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: add Config.VerifyConnection callback
authorKatie Hockman <katie@golang.org>
Mon, 20 Apr 2020 21:55:37 +0000 (17:55 -0400)
committerKatie Hockman <katie@golang.org>
Fri, 8 May 2020 02:17:26 +0000 (02:17 +0000)
Since the ConnectionState will now be available during
verification, some code was moved around in order to
initialize and make available as much of the fields on
Conn as possible before the ConnectionState is verified.

Fixes #36736

Change-Id: I0e3efa97565ead7de5c48bb8a87e3ea54fbde140
Reviewed-on: https://go-review.googlesource.com/c/go/+/229122
Run-TryBot: Katie Hockman <katie@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/crypto/tls/common.go
src/crypto/tls/conn.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_client_test.go
src/crypto/tls/handshake_client_tls13.go
src/crypto/tls/handshake_server.go
src/crypto/tls/handshake_server_tls13.go
src/crypto/tls/tls_test.go

index 90846a3659ceaefe5c77069ad55bdb8d2e505167..fd21ae8fb10e3925b33b47ba8eb76e64e2b50116 100644 (file)
@@ -219,7 +219,7 @@ type ConnectionState struct {
        CipherSuite                 uint16                // cipher suite in use (TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, ...)
        NegotiatedProtocol          string                // negotiated next protocol (not guaranteed to be from Config.NextProtos)
        NegotiatedProtocolIsMutual  bool                  // negotiated protocol was advertised by server (client side only)
-       ServerName                  string                // server name requested by client, if any (server side only)
+       ServerName                  string                // server name requested by client, if any
        PeerCertificates            []*x509.Certificate   // certificate chain presented by remote peer
        VerifiedChains              [][]*x509.Certificate // verified chains built from PeerCertificates
        SignedCertificateTimestamps [][]byte              // SCTs from the peer, if any
@@ -520,6 +520,16 @@ type Config struct {
        // be considered but the verifiedChains argument will always be nil.
        VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
 
+       // VerifyConnection, if not nil, is called after normal certificate
+       // verification and after VerifyPeerCertificate by either a TLS client
+       // or server. If it returns a non-nil error, the handshake is aborted
+       // and that error results.
+       //
+       // If normal verification fails then the handshake will abort before
+       // considering this callback. This callback will run for all connections
+       // regardless of InsecureSkipVerify or ClientAuth settings.
+       VerifyConnection func(ConnectionState) error
+
        // RootCAs defines the set of root certificate authorities
        // that clients use when verifying server certificates.
        // If RootCAs is nil, TLS uses the host's root CA set.
@@ -685,6 +695,7 @@ func (c *Config) Clone() *Config {
                GetClientCertificate:        c.GetClientCertificate,
                GetConfigForClient:          c.GetConfigForClient,
                VerifyPeerCertificate:       c.VerifyPeerCertificate,
+               VerifyConnection:            c.VerifyConnection,
                RootCAs:                     c.RootCAs,
                NextProtos:                  c.NextProtos,
                ServerName:                  c.ServerName,
index d759986bb94a2c75c9791322908ba2c0d8b07f94..edcfecf81d77c63d9b3a14eee29a85bf0fed3324 100644 (file)
@@ -1379,35 +1379,34 @@ func (c *Conn) Handshake() error {
 func (c *Conn) ConnectionState() ConnectionState {
        c.handshakeMutex.Lock()
        defer c.handshakeMutex.Unlock()
+       return c.connectionStateLocked()
+}
 
+func (c *Conn) connectionStateLocked() ConnectionState {
        var state ConnectionState
        state.HandshakeComplete = c.handshakeComplete()
+       state.Version = c.vers
+       state.NegotiatedProtocol = c.clientProtocol
+       state.DidResume = c.didResume
+       state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
        state.ServerName = c.serverName
-
-       if state.HandshakeComplete {
-               state.Version = c.vers
-               state.NegotiatedProtocol = c.clientProtocol
-               state.DidResume = c.didResume
-               state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
-               state.CipherSuite = c.cipherSuite
-               state.PeerCertificates = c.peerCertificates
-               state.VerifiedChains = c.verifiedChains
-               state.SignedCertificateTimestamps = c.scts
-               state.OCSPResponse = c.ocspResponse
-               if !c.didResume && c.vers != VersionTLS13 {
-                       if c.clientFinishedIsFirst {
-                               state.TLSUnique = c.clientFinished[:]
-                       } else {
-                               state.TLSUnique = c.serverFinished[:]
-                       }
-               }
-               if c.config.Renegotiation != RenegotiateNever {
-                       state.ekm = noExportedKeyingMaterial
+       state.CipherSuite = c.cipherSuite
+       state.PeerCertificates = c.peerCertificates
+       state.VerifiedChains = c.verifiedChains
+       state.SignedCertificateTimestamps = c.scts
+       state.OCSPResponse = c.ocspResponse
+       if !c.didResume && c.vers != VersionTLS13 {
+               if c.clientFinishedIsFirst {
+                       state.TLSUnique = c.clientFinished[:]
                } else {
-                       state.ekm = c.ekm
+                       state.TLSUnique = c.serverFinished[:]
                }
        }
-
+       if c.config.Renegotiation != RenegotiateNever {
+               state.ekm = noExportedKeyingMaterial
+       } else {
+               state.ekm = c.ekm
+       }
        return state
 }
 
index 210eece26dd4f9e0826e6a762978e72f5b96dccd..40c8e02c53dc7f980581f66b755e38dddf4fbaf5 100644 (file)
@@ -146,6 +146,7 @@ func (c *Conn) clientHandshake() (err error) {
        if err != nil {
                return err
        }
+       c.serverName = hello.serverName
 
        cacheKey, session, earlySecret, binderKey := c.loadSession(hello)
        if cacheKey != "" && session != nil {
@@ -388,6 +389,7 @@ func (hs *clientHandshakeState) handshake() error {
        hs.finishedHash.Write(hs.serverHello.marshal())
 
        c.buffering = true
+       c.didResume = isResume
        if isResume {
                if err := hs.establishKeys(); err != nil {
                        return err
@@ -399,6 +401,15 @@ func (hs *clientHandshakeState) handshake() error {
                        return err
                }
                c.clientFinishedIsFirst = false
+               // Make sure the connection is still being verified whether or not this
+               // is a resumption. Resumptions currently don't reverify certificates so
+               // they don't call verifyServerCertificate. See Issue 31641.
+               if c.config.VerifyConnection != nil {
+                       if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+                               c.sendAlert(alertBadCertificate)
+                               return err
+                       }
+               }
                if err := hs.sendFinished(c.clientFinished[:]); err != nil {
                        return err
                }
@@ -428,7 +439,6 @@ func (hs *clientHandshakeState) handshake() error {
        }
 
        c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random)
-       c.didResume = isResume
        atomic.StoreUint32(&c.handshakeStatus, 1)
 
        return nil
@@ -458,25 +468,6 @@ func (hs *clientHandshakeState) doFullHandshake() error {
        }
        hs.finishedHash.Write(certMsg.marshal())
 
-       if c.handshakes == 0 {
-               // If this is the first handshake on a connection, process and
-               // (optionally) verify the server's certificates.
-               if err := c.verifyServerCertificate(certMsg.certificates); err != nil {
-                       return err
-               }
-       } else {
-               // This is a renegotiation handshake. We require that the
-               // server's identity (i.e. leaf certificate) is unchanged and
-               // thus any previous trust decision is still valid.
-               //
-               // See https://mitls.org/pages/attacks/3SHAKE for the
-               // motivation behind this requirement.
-               if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) {
-                       c.sendAlert(alertBadCertificate)
-                       return errors.New("tls: server's identity changed during renegotiation")
-               }
-       }
-
        msg, err = c.readHandshake()
        if err != nil {
                return err
@@ -505,6 +496,25 @@ func (hs *clientHandshakeState) doFullHandshake() error {
                }
        }
 
+       if c.handshakes == 0 {
+               // If this is the first handshake on a connection, process and
+               // (optionally) verify the server's certificates.
+               if err := c.verifyServerCertificate(certMsg.certificates); err != nil {
+                       return err
+               }
+       } else {
+               // This is a renegotiation handshake. We require that the
+               // server's identity (i.e. leaf certificate) is unchanged and
+               // thus any previous trust decision is still valid.
+               //
+               // See https://mitls.org/pages/attacks/3SHAKE for the
+               // motivation behind this requirement.
+               if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) {
+                       c.sendAlert(alertBadCertificate)
+                       return errors.New("tls: server's identity changed during renegotiation")
+               }
+       }
+
        keyAgreement := hs.suite.ka(c.vers)
 
        skx, ok := msg.(*serverKeyExchangeMsg)
@@ -831,13 +841,6 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
                }
        }
 
-       if c.config.VerifyPeerCertificate != nil {
-               if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
-                       c.sendAlert(alertBadCertificate)
-                       return err
-               }
-       }
-
        switch certs[0].PublicKey.(type) {
        case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey:
                break
@@ -848,6 +851,20 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
 
        c.peerCertificates = certs
 
+       if c.config.VerifyPeerCertificate != nil {
+               if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
+                       c.sendAlert(alertBadCertificate)
+                       return err
+               }
+       }
+
+       if c.config.VerifyConnection != nil {
+               if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+                       c.sendAlert(alertBadCertificate)
+                       return err
+               }
+       }
+
        return nil
 }
 
index cd387dcc6cb1c28766156b8c9a4344c637e4ee4f..313872ca76c6d37dc8ceedbece479b02c666833a 100644 (file)
@@ -907,6 +907,9 @@ func testResumption(t *testing.T, version uint16) {
                if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
                        t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
                }
+               if got, want := hs.ServerName, clientConfig.ServerName; got != want {
+                       t.Errorf("%s: server name %s, want %s", test, got, want)
+               }
        }
 
        getTicket := func() []byte {
@@ -1458,7 +1461,7 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
 
        sentinelErr := errors.New("TestVerifyPeerCertificate")
 
-       verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+       verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
                if l := len(rawCerts); l != 1 {
                        return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
                }
@@ -1468,6 +1471,19 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
                *called = true
                return nil
        }
+       verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error {
+               if l := len(c.PeerCertificates); l != 1 {
+                       return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l)
+               }
+               if len(c.VerifiedChains) == 0 {
+                       return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero")
+               }
+               if isClient && len(c.OCSPResponse) == 0 {
+                       return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero")
+               }
+               *called = true
+               return nil
+       }
 
        tests := []struct {
                configureServer func(*Config, *bool)
@@ -1478,13 +1494,13 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
                        configureServer: func(config *Config, called *bool) {
                                config.InsecureSkipVerify = false
                                config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
-                                       return verifyCallback(called, rawCerts, validatedChains)
+                                       return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
                                }
                        },
                        configureClient: func(config *Config, called *bool) {
                                config.InsecureSkipVerify = false
                                config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
-                                       return verifyCallback(called, rawCerts, validatedChains)
+                                       return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
                                }
                        },
                        validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
@@ -1565,6 +1581,116 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
                                }
                        },
                },
+               {
+                       configureServer: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = false
+                               config.VerifyConnection = func(c ConnectionState) error {
+                                       return verifyConnectionCallback(called, false, c)
+                               }
+                       },
+                       configureClient: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = false
+                               config.VerifyConnection = func(c ConnectionState) error {
+                                       return verifyConnectionCallback(called, true, c)
+                               }
+                       },
+                       validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+                               if clientErr != nil {
+                                       t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
+                               }
+                               if serverErr != nil {
+                                       t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
+                               }
+                               if !clientCalled {
+                                       t.Errorf("test[%d]: client did not call callback", testNo)
+                               }
+                               if !serverCalled {
+                                       t.Errorf("test[%d]: server did not call callback", testNo)
+                               }
+                       },
+               },
+               {
+                       configureServer: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = false
+                               config.VerifyConnection = func(c ConnectionState) error {
+                                       return sentinelErr
+                               }
+                       },
+                       configureClient: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = false
+                               config.VerifyConnection = nil
+                       },
+                       validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+                               if serverErr != sentinelErr {
+                                       t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
+                               }
+                       },
+               },
+               {
+                       configureServer: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = false
+                               config.VerifyConnection = nil
+                       },
+                       configureClient: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = false
+                               config.VerifyConnection = func(c ConnectionState) error {
+                                       return sentinelErr
+                               }
+                       },
+                       validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+                               if clientErr != sentinelErr {
+                                       t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
+                               }
+                       },
+               },
+               {
+                       configureServer: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = false
+                               config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+                                       return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
+                               }
+                               config.VerifyConnection = func(c ConnectionState) error {
+                                       return sentinelErr
+                               }
+                       },
+                       configureClient: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = false
+                               config.VerifyPeerCertificate = nil
+                               config.VerifyConnection = nil
+                       },
+                       validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+                               if serverErr != sentinelErr {
+                                       t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
+                               }
+                               if !serverCalled {
+                                       t.Errorf("test[%d]: server did not call callback", testNo)
+                               }
+                       },
+               },
+               {
+                       configureServer: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = false
+                               config.VerifyPeerCertificate = nil
+                               config.VerifyConnection = nil
+                       },
+                       configureClient: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = false
+                               config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+                                       return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
+                               }
+                               config.VerifyConnection = func(c ConnectionState) error {
+                                       return sentinelErr
+                               }
+                       },
+                       validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+                               if clientErr != sentinelErr {
+                                       t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
+                               }
+                               if !clientCalled {
+                                       t.Errorf("test[%d]: client did not call callback", testNo)
+                               }
+                       },
+               },
        }
 
        for i, test := range tests {
@@ -1580,6 +1706,11 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) {
                        config.ClientCAs = rootCAs
                        config.Time = now
                        config.MaxVersion = version
+                       config.Certificates = make([]Certificate, 1)
+                       config.Certificates[0].Certificate = [][]byte{testRSACertificate}
+                       config.Certificates[0].PrivateKey = testRSAPrivateKey
+                       config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
+                       config.Certificates[0].OCSPStaple = []byte("dummy ocsp")
                        test.configureServer(config, &serverCalled)
 
                        err = Server(s, config).Handshake()
index 97122bd2202f45e0bdae81aa8e73ac89a83991b5..35a00f2f3a4c87376f9cdc8e2518541eb721f6f6 100644 (file)
@@ -407,6 +407,15 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
        // Either a PSK or a certificate is always used, but not both.
        // See RFC 8446, Section 4.1.1.
        if hs.usingPSK {
+               // Make sure the connection is still being verified whether or not this
+               // is a resumption. Resumptions currently don't reverify certificates so
+               // they don't call verifyServerCertificate. See Issue 31641.
+               if c.config.VerifyConnection != nil {
+                       if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+                               c.sendAlert(alertBadCertificate)
+                               return err
+                       }
+               }
                return nil
        }
 
index 4885c69568d439b8ac5d35e99f558df716ca0d1f..6aacfa1ff6622c364f1573ac3c9fd3a218ed68a2 100644 (file)
@@ -68,6 +68,7 @@ func (hs *serverHandshakeState) handshake() error {
        c.buffering = true
        if hs.checkForResumption() {
                // The client has included a session ticket and so we do an abbreviated handshake.
+               c.didResume = true
                if err := hs.doResumeHandshake(); err != nil {
                        return err
                }
@@ -92,7 +93,6 @@ func (hs *serverHandshakeState) handshake() error {
                if err := hs.readFinished(nil); err != nil {
                        return err
                }
-               c.didResume = true
        } else {
                // The client didn't include a session ticket, or it wasn't
                // valid so we do a full handshake.
@@ -553,6 +553,15 @@ func (hs *serverHandshakeState) doFullHandshake() error {
                if err != nil {
                        return err
                }
+       } else {
+               // Make sure the connection is still being verified whether or not
+               // the server requested a client certificate.
+               if c.config.VerifyConnection != nil {
+                       if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+                               c.sendAlert(alertBadCertificate)
+                               return err
+                       }
+               }
        }
 
        // Get client key exchange
@@ -771,6 +780,19 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
                c.verifiedChains = chains
        }
 
+       c.peerCertificates = certs
+       c.ocspResponse = certificate.OCSPStaple
+       c.scts = certificate.SignedCertificateTimestamps
+
+       if len(certs) > 0 {
+               switch certs[0].PublicKey.(type) {
+               case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey:
+               default:
+                       c.sendAlert(alertUnsupportedCertificate)
+                       return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey)
+               }
+       }
+
        if c.config.VerifyPeerCertificate != nil {
                if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
                        c.sendAlert(alertBadCertificate)
@@ -778,20 +800,13 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
                }
        }
 
-       if len(certs) == 0 {
-               return nil
-       }
-
-       switch certs[0].PublicKey.(type) {
-       case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey:
-       default:
-               c.sendAlert(alertUnsupportedCertificate)
-               return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey)
+       if c.config.VerifyConnection != nil {
+               if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+                       c.sendAlert(alertBadCertificate)
+                       return err
+               }
        }
 
-       c.peerCertificates = certs
-       c.ocspResponse = certificate.OCSPStaple
-       c.scts = certificate.SignedCertificateTimestamps
        return nil
 }
 
index 5432145de46292095403c8e95588c931d6e0595f..fb7f871390e5b9d2183ef5d5938d297a6bba9f68 100644 (file)
@@ -306,6 +306,7 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
                        return errors.New("tls: invalid PSK binder")
                }
 
+               c.didResume = true
                if err := c.processCertsFromClient(sessionState.certificate); err != nil {
                        return err
                }
@@ -313,7 +314,6 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
                hs.hello.selectedIdentityPresent = true
                hs.hello.selectedIdentity = uint16(i)
                hs.usingPSK = true
-               c.didResume = true
                return nil
        }
 
@@ -753,6 +753,14 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
        c := hs.c
 
        if !hs.requestClientCert() {
+               // Make sure the connection is still being verified whether or not
+               // the server requested a client certificate.
+               if c.config.VerifyConnection != nil {
+                       if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+                               c.sendAlert(alertBadCertificate)
+                               return err
+                       }
+               }
                return nil
        }
 
index 85005d4950ae17e2687c228572df1d3ae1783abf..9e340774b679ed25e9b79cda23610dfd15efafdc 100644 (file)
@@ -734,7 +734,7 @@ func TestWarningAlertFlood(t *testing.T) {
 }
 
 func TestCloneFuncFields(t *testing.T) {
-       const expectedCount = 5
+       const expectedCount = 6
        called := 0
 
        c1 := Config{
@@ -758,6 +758,10 @@ func TestCloneFuncFields(t *testing.T) {
                        called |= 1 << 4
                        return nil
                },
+               VerifyConnection: func(ConnectionState) error {
+                       called |= 1 << 5
+                       return nil
+               },
        }
 
        c2 := c1.Clone()
@@ -767,6 +771,7 @@ func TestCloneFuncFields(t *testing.T) {
        c2.GetClientCertificate(nil)
        c2.GetConfigForClient(nil)
        c2.VerifyPeerCertificate(nil, nil)
+       c2.VerifyConnection(ConnectionState{})
 
        if called != (1<<expectedCount)-1 {
                t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
@@ -790,7 +795,7 @@ func TestCloneNonFuncFields(t *testing.T) {
                switch fn := typ.Field(i).Name; fn {
                case "Rand":
                        f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
-               case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate":
+               case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate":
                        // DeepEqual can't compare functions. If you add a
                        // function field to this list, you must also change
                        // TestCloneFuncFields to ensure that the func field is
@@ -1116,8 +1121,8 @@ func TestConnectionState(t *testing.T) {
                        if ss.ServerName != serverName {
                                t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
                        }
-                       if cs.ServerName != "" {
-                               t.Errorf("Got unexpected server name on the client side")
+                       if cs.ServerName != serverName {
+                               t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName)
                        }
 
                        if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {