]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: set CipherSuite for VerifyConnection
authorKatie Hockman <katie@golang.org>
Thu, 4 Jun 2020 14:52:24 +0000 (10:52 -0400)
committerKatie Hockman <katie@golang.org>
Thu, 4 Jun 2020 20:16:53 +0000 (20:16 +0000)
The ConnectionState's CipherSuite was not set prior
to the VerifyConnection callback in TLS 1.2 servers,
both for full handshakes and resumptions.

Change-Id: Iab91783eff84d1b42ca09c8df08e07861e18da30
Reviewed-on: https://go-review.googlesource.com/c/go/+/236558
Run-TryBot: Katie Hockman <katie@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
src/crypto/tls/handshake_client_test.go
src/crypto/tls/handshake_server.go

index 88c974f83d456092d83125e938eecb3d3b32eea1..1cda90190cec2e7e3533199bb4dcae85b8b84248 100644 (file)
@@ -1470,25 +1470,28 @@ func TestVerifyConnection(t *testing.T) {
 }
 
 func testVerifyConnection(t *testing.T, version uint16) {
-       checkFields := func(c ConnectionState, called *int) error {
+       checkFields := func(c ConnectionState, called *int, errorType string) error {
                if c.Version != version {
-                       return fmt.Errorf("got Version %v, want %v", c.Version, version)
+                       return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version)
                }
                if c.HandshakeComplete {
-                       return fmt.Errorf("got HandshakeComplete, want false")
+                       return fmt.Errorf("%s: got HandshakeComplete, want false", errorType)
                }
                if c.ServerName != "example.golang" {
-                       return fmt.Errorf("got ServerName %s, want %s", c.ServerName, "example.golang")
+                       return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang")
                }
                if c.NegotiatedProtocol != "protocol1" {
-                       return fmt.Errorf("got NegotiatedProtocol %s, want %s", c.NegotiatedProtocol, "protocol1")
+                       return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1")
+               }
+               if c.CipherSuite == 0 {
+                       return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType)
                }
                wantDidResume := false
                if *called == 2 { // if this is the second time, then it should be a resumption
                        wantDidResume = true
                }
                if c.DidResume != wantDidResume {
-                       return fmt.Errorf("got DidResume %t, want %t", c.DidResume, wantDidResume)
+                       return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume)
                }
                return nil
        }
@@ -1510,7 +1513,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
                                        if len(c.VerifiedChains) == 0 {
                                                return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
                                        }
-                                       return checkFields(c, called)
+                                       return checkFields(c, called, "server")
                                }
                        },
                        configureClient: func(config *Config, called *int) {
@@ -1533,7 +1536,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
                                        if len(c.SignedCertificateTimestamps) == 0 {
                                                return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
                                        }
-                                       return checkFields(c, called)
+                                       return checkFields(c, called, "client")
                                }
                        },
                },
@@ -1550,7 +1553,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
                                        if c.VerifiedChains != nil {
                                                return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
                                        }
-                                       return checkFields(c, called)
+                                       return checkFields(c, called, "server")
                                }
                        },
                        configureClient: func(config *Config, called *int) {
@@ -1574,7 +1577,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
                                        if len(c.SignedCertificateTimestamps) == 0 {
                                                return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
                                        }
-                                       return checkFields(c, called)
+                                       return checkFields(c, called, "client")
                                }
                        },
                },
@@ -1584,13 +1587,13 @@ func testVerifyConnection(t *testing.T, version uint16) {
                                config.ClientAuth = NoClientCert
                                config.VerifyConnection = func(c ConnectionState) error {
                                        *called++
-                                       return checkFields(c, called)
+                                       return checkFields(c, called, "server")
                                }
                        },
                        configureClient: func(config *Config, called *int) {
                                config.VerifyConnection = func(c ConnectionState) error {
                                        *called++
-                                       return checkFields(c, called)
+                                       return checkFields(c, called, "client")
                                }
                        },
                },
@@ -1600,7 +1603,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
                                config.ClientAuth = RequestClientCert
                                config.VerifyConnection = func(c ConnectionState) error {
                                        *called++
-                                       return checkFields(c, called)
+                                       return checkFields(c, called, "server")
                                }
                        },
                        configureClient: func(config *Config, called *int) {
@@ -1624,7 +1627,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
                                        if len(c.SignedCertificateTimestamps) == 0 {
                                                return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
                                        }
-                                       return checkFields(c, called)
+                                       return checkFields(c, called, "client")
                                }
                        },
                },
index 2c2f0a48797e47b85191718e0ab9fb0ff460a96e..16d3e643f0b28ed45936142e6dacdb56385ee480 100644 (file)
@@ -308,6 +308,7 @@ func (hs *serverHandshakeState) pickCipherSuite() error {
                c.sendAlert(alertHandshakeFailure)
                return errors.New("tls: no cipher suite supported by both client and server")
        }
+       c.cipherSuite = hs.suite.id
 
        for _, id := range hs.clientHello.cipherSuites {
                if id == TLS_FALLBACK_SCSV {
@@ -407,6 +408,7 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
        c := hs.c
 
        hs.hello.cipherSuite = hs.suite.id
+       c.cipherSuite = hs.suite.id
        // We echo the client's session ID in the ServerHello to let it know
        // that we're doing a resumption.
        hs.hello.sessionId = hs.clientHello.sessionId
@@ -743,7 +745,6 @@ func (hs *serverHandshakeState) sendFinished(out []byte) error {
                return err
        }
 
-       c.cipherSuite = hs.suite.id
        copy(out, finished.verifyData)
 
        return nil