]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: fix duplicate calls to VerifyConnection
authorKatie Hockman <katie@golang.org>
Wed, 13 May 2020 21:44:20 +0000 (17:44 -0400)
committerKatie Hockman <katie@golang.org>
Wed, 3 Jun 2020 19:01:50 +0000 (19:01 +0000)
Also add a test that could reproduce this error and
ensure it doesn't occur in other configurations.

Fixes #39012

Change-Id: If792b5131f312c269fd2c5f08c9ed5c00188d1af
Reviewed-on: https://go-review.googlesource.com/c/go/+/233957
Run-TryBot: Katie Hockman <katie@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
src/crypto/tls/handshake_client_test.go
src/crypto/tls/handshake_server.go
src/crypto/tls/handshake_server_tls13.go

index de93e1b63f3a92c36b815a8a624b1545a0e01128..88c974f83d456092d83125e938eecb3d3b32eea1 100644 (file)
@@ -1464,6 +1464,225 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
        }
 }
 
+func TestVerifyConnection(t *testing.T) {
+       t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) })
+       t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) })
+}
+
+func testVerifyConnection(t *testing.T, version uint16) {
+       checkFields := func(c ConnectionState, called *int) error {
+               if c.Version != version {
+                       return fmt.Errorf("got Version %v, want %v", c.Version, version)
+               }
+               if c.HandshakeComplete {
+                       return fmt.Errorf("got HandshakeComplete, want false")
+               }
+               if c.ServerName != "example.golang" {
+                       return fmt.Errorf("got ServerName %s, want %s", c.ServerName, "example.golang")
+               }
+               if c.NegotiatedProtocol != "protocol1" {
+                       return fmt.Errorf("got NegotiatedProtocol %s, want %s", c.NegotiatedProtocol, "protocol1")
+               }
+               wantDidResume := false
+               if *called == 2 { // if this is the second time, then it should be a resumption
+                       wantDidResume = true
+               }
+               if c.DidResume != wantDidResume {
+                       return fmt.Errorf("got DidResume %t, want %t", c.DidResume, wantDidResume)
+               }
+               return nil
+       }
+
+       tests := []struct {
+               name            string
+               configureServer func(*Config, *int)
+               configureClient func(*Config, *int)
+       }{
+               {
+                       name: "RequireAndVerifyClientCert",
+                       configureServer: func(config *Config, called *int) {
+                               config.ClientAuth = RequireAndVerifyClientCert
+                               config.VerifyConnection = func(c ConnectionState) error {
+                                       *called++
+                                       if l := len(c.PeerCertificates); l != 1 {
+                                               return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
+                                       }
+                                       if len(c.VerifiedChains) == 0 {
+                                               return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
+                                       }
+                                       return checkFields(c, called)
+                               }
+                       },
+                       configureClient: func(config *Config, called *int) {
+                               config.VerifyConnection = func(c ConnectionState) error {
+                                       *called++
+                                       if l := len(c.PeerCertificates); l != 1 {
+                                               return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
+                                       }
+                                       if len(c.VerifiedChains) == 0 {
+                                               return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
+                                       }
+                                       if c.DidResume {
+                                               return nil
+                                               // The SCTs and OCSP Responce are dropped on resumption.
+                                               // See http://golang.org/issue/39075.
+                                       }
+                                       if len(c.OCSPResponse) == 0 {
+                                               return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
+                                       }
+                                       if len(c.SignedCertificateTimestamps) == 0 {
+                                               return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
+                                       }
+                                       return checkFields(c, called)
+                               }
+                       },
+               },
+               {
+                       name: "InsecureSkipVerify",
+                       configureServer: func(config *Config, called *int) {
+                               config.ClientAuth = RequireAnyClientCert
+                               config.InsecureSkipVerify = true
+                               config.VerifyConnection = func(c ConnectionState) error {
+                                       *called++
+                                       if l := len(c.PeerCertificates); l != 1 {
+                                               return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
+                                       }
+                                       if c.VerifiedChains != nil {
+                                               return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
+                                       }
+                                       return checkFields(c, called)
+                               }
+                       },
+                       configureClient: func(config *Config, called *int) {
+                               config.InsecureSkipVerify = true
+                               config.VerifyConnection = func(c ConnectionState) error {
+                                       *called++
+                                       if l := len(c.PeerCertificates); l != 1 {
+                                               return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
+                                       }
+                                       if c.VerifiedChains != nil {
+                                               return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
+                                       }
+                                       if c.DidResume {
+                                               return nil
+                                               // The SCTs and OCSP Responce are dropped on resumption.
+                                               // See http://golang.org/issue/39075.
+                                       }
+                                       if len(c.OCSPResponse) == 0 {
+                                               return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
+                                       }
+                                       if len(c.SignedCertificateTimestamps) == 0 {
+                                               return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
+                                       }
+                                       return checkFields(c, called)
+                               }
+                       },
+               },
+               {
+                       name: "NoClientCert",
+                       configureServer: func(config *Config, called *int) {
+                               config.ClientAuth = NoClientCert
+                               config.VerifyConnection = func(c ConnectionState) error {
+                                       *called++
+                                       return checkFields(c, called)
+                               }
+                       },
+                       configureClient: func(config *Config, called *int) {
+                               config.VerifyConnection = func(c ConnectionState) error {
+                                       *called++
+                                       return checkFields(c, called)
+                               }
+                       },
+               },
+               {
+                       name: "RequestClientCert",
+                       configureServer: func(config *Config, called *int) {
+                               config.ClientAuth = RequestClientCert
+                               config.VerifyConnection = func(c ConnectionState) error {
+                                       *called++
+                                       return checkFields(c, called)
+                               }
+                       },
+                       configureClient: func(config *Config, called *int) {
+                               config.Certificates = nil // clear the client cert
+                               config.VerifyConnection = func(c ConnectionState) error {
+                                       *called++
+                                       if l := len(c.PeerCertificates); l != 1 {
+                                               return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
+                                       }
+                                       if len(c.VerifiedChains) == 0 {
+                                               return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
+                                       }
+                                       if c.DidResume {
+                                               return nil
+                                               // The SCTs and OCSP Responce are dropped on resumption.
+                                               // See http://golang.org/issue/39075.
+                                       }
+                                       if len(c.OCSPResponse) == 0 {
+                                               return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
+                                       }
+                                       if len(c.SignedCertificateTimestamps) == 0 {
+                                               return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
+                                       }
+                                       return checkFields(c, called)
+                               }
+                       },
+               },
+       }
+       for _, test := range tests {
+               issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
+               if err != nil {
+                       panic(err)
+               }
+               rootCAs := x509.NewCertPool()
+               rootCAs.AddCert(issuer)
+
+               var serverCalled, clientCalled int
+
+               serverConfig := &Config{
+                       MaxVersion:   version,
+                       Certificates: []Certificate{testConfig.Certificates[0]},
+                       ClientCAs:    rootCAs,
+                       NextProtos:   []string{"protocol1"},
+               }
+               serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
+               serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp")
+               test.configureServer(serverConfig, &serverCalled)
+
+               clientConfig := &Config{
+                       MaxVersion:         version,
+                       ClientSessionCache: NewLRUClientSessionCache(32),
+                       RootCAs:            rootCAs,
+                       ServerName:         "example.golang",
+                       Certificates:       []Certificate{testConfig.Certificates[0]},
+                       NextProtos:         []string{"protocol1"},
+               }
+               test.configureClient(clientConfig, &clientCalled)
+
+               testHandshakeState := func(name string, didResume bool) {
+                       _, hs, err := testHandshake(t, clientConfig, serverConfig)
+                       if err != nil {
+                               t.Fatalf("%s: handshake failed: %s", name, err)
+                       }
+                       if hs.DidResume != didResume {
+                               t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume)
+                       }
+                       wantCalled := 1
+                       if didResume {
+                               wantCalled = 2 // resumption would mean this is the second time it was called in this test
+                       }
+                       if clientCalled != wantCalled {
+                               t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled)
+                       }
+                       if serverCalled != wantCalled {
+                               t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled)
+                       }
+               }
+               testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false)
+               testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true)
+       }
+}
+
 func TestVerifyPeerCertificate(t *testing.T) {
        t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) })
        t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) })
index 57fba108a72420a8ec5570c08d9221789a0c54be..2c2f0a48797e47b85191718e0ab9fb0ff460a96e 100644 (file)
@@ -425,6 +425,13 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
                return err
        }
 
+       if c.config.VerifyConnection != nil {
+               if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+                       c.sendAlert(alertBadCertificate)
+                       return err
+               }
+       }
+
        hs.masterSecret = hs.sessionState.masterSecret
 
        return nil
@@ -548,14 +555,11 @@ func (hs *serverHandshakeState) doFullHandshake() error {
                if err != nil {
                        return err
                }
-       } else {
-               // Make sure the connection is still being verified whether or not
-               // the server requested a client certificate.
-               if c.config.VerifyConnection != nil {
-                       if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
-                               c.sendAlert(alertBadCertificate)
-                               return err
-                       }
+       }
+       if c.config.VerifyConnection != nil {
+               if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+                       c.sendAlert(alertBadCertificate)
+                       return err
                }
        }
 
@@ -805,13 +809,6 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
                }
        }
 
-       if c.config.VerifyConnection != nil {
-               if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
-                       c.sendAlert(alertBadCertificate)
-                       return err
-               }
-       }
-
        return nil
 }
 
index fb7f871390e5b9d2183ef5d5938d297a6bba9f68..92d55e0293a46e9d362fc7ec34db9fd990fe3932 100644 (file)
@@ -783,6 +783,13 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
                return err
        }
 
+       if c.config.VerifyConnection != nil {
+               if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+                       c.sendAlert(alertBadCertificate)
+                       return err
+               }
+       }
+
        if len(certMsg.certificate.Certificate) != 0 {
                msg, err = c.readHandshake()
                if err != nil {