}
}
+func TestVerifyConnection(t *testing.T) {
+ t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) })
+ t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) })
+}
+
+func testVerifyConnection(t *testing.T, version uint16) {
+ checkFields := func(c ConnectionState, called *int) error {
+ if c.Version != version {
+ return fmt.Errorf("got Version %v, want %v", c.Version, version)
+ }
+ if c.HandshakeComplete {
+ return fmt.Errorf("got HandshakeComplete, want false")
+ }
+ if c.ServerName != "example.golang" {
+ return fmt.Errorf("got ServerName %s, want %s", c.ServerName, "example.golang")
+ }
+ if c.NegotiatedProtocol != "protocol1" {
+ return fmt.Errorf("got NegotiatedProtocol %s, want %s", c.NegotiatedProtocol, "protocol1")
+ }
+ wantDidResume := false
+ if *called == 2 { // if this is the second time, then it should be a resumption
+ wantDidResume = true
+ }
+ if c.DidResume != wantDidResume {
+ return fmt.Errorf("got DidResume %t, want %t", c.DidResume, wantDidResume)
+ }
+ return nil
+ }
+
+ tests := []struct {
+ name string
+ configureServer func(*Config, *int)
+ configureClient func(*Config, *int)
+ }{
+ {
+ name: "RequireAndVerifyClientCert",
+ configureServer: func(config *Config, called *int) {
+ config.ClientAuth = RequireAndVerifyClientCert
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ if l := len(c.PeerCertificates); l != 1 {
+ return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
+ }
+ if len(c.VerifiedChains) == 0 {
+ return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
+ }
+ return checkFields(c, called)
+ }
+ },
+ configureClient: func(config *Config, called *int) {
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ if l := len(c.PeerCertificates); l != 1 {
+ return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
+ }
+ if len(c.VerifiedChains) == 0 {
+ return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
+ }
+ if c.DidResume {
+ return nil
+ // The SCTs and OCSP Responce are dropped on resumption.
+ // See http://golang.org/issue/39075.
+ }
+ if len(c.OCSPResponse) == 0 {
+ return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
+ }
+ if len(c.SignedCertificateTimestamps) == 0 {
+ return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
+ }
+ return checkFields(c, called)
+ }
+ },
+ },
+ {
+ name: "InsecureSkipVerify",
+ configureServer: func(config *Config, called *int) {
+ config.ClientAuth = RequireAnyClientCert
+ config.InsecureSkipVerify = true
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ if l := len(c.PeerCertificates); l != 1 {
+ return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
+ }
+ if c.VerifiedChains != nil {
+ return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
+ }
+ return checkFields(c, called)
+ }
+ },
+ configureClient: func(config *Config, called *int) {
+ config.InsecureSkipVerify = true
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ if l := len(c.PeerCertificates); l != 1 {
+ return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
+ }
+ if c.VerifiedChains != nil {
+ return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
+ }
+ if c.DidResume {
+ return nil
+ // The SCTs and OCSP Responce are dropped on resumption.
+ // See http://golang.org/issue/39075.
+ }
+ if len(c.OCSPResponse) == 0 {
+ return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
+ }
+ if len(c.SignedCertificateTimestamps) == 0 {
+ return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
+ }
+ return checkFields(c, called)
+ }
+ },
+ },
+ {
+ name: "NoClientCert",
+ configureServer: func(config *Config, called *int) {
+ config.ClientAuth = NoClientCert
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ return checkFields(c, called)
+ }
+ },
+ configureClient: func(config *Config, called *int) {
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ return checkFields(c, called)
+ }
+ },
+ },
+ {
+ name: "RequestClientCert",
+ configureServer: func(config *Config, called *int) {
+ config.ClientAuth = RequestClientCert
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ return checkFields(c, called)
+ }
+ },
+ configureClient: func(config *Config, called *int) {
+ config.Certificates = nil // clear the client cert
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ if l := len(c.PeerCertificates); l != 1 {
+ return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
+ }
+ if len(c.VerifiedChains) == 0 {
+ return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
+ }
+ if c.DidResume {
+ return nil
+ // The SCTs and OCSP Responce are dropped on resumption.
+ // See http://golang.org/issue/39075.
+ }
+ if len(c.OCSPResponse) == 0 {
+ return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
+ }
+ if len(c.SignedCertificateTimestamps) == 0 {
+ return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
+ }
+ return checkFields(c, called)
+ }
+ },
+ },
+ }
+ for _, test := range tests {
+ issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
+ if err != nil {
+ panic(err)
+ }
+ rootCAs := x509.NewCertPool()
+ rootCAs.AddCert(issuer)
+
+ var serverCalled, clientCalled int
+
+ serverConfig := &Config{
+ MaxVersion: version,
+ Certificates: []Certificate{testConfig.Certificates[0]},
+ ClientCAs: rootCAs,
+ NextProtos: []string{"protocol1"},
+ }
+ serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
+ serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp")
+ test.configureServer(serverConfig, &serverCalled)
+
+ clientConfig := &Config{
+ MaxVersion: version,
+ ClientSessionCache: NewLRUClientSessionCache(32),
+ RootCAs: rootCAs,
+ ServerName: "example.golang",
+ Certificates: []Certificate{testConfig.Certificates[0]},
+ NextProtos: []string{"protocol1"},
+ }
+ test.configureClient(clientConfig, &clientCalled)
+
+ testHandshakeState := func(name string, didResume bool) {
+ _, hs, err := testHandshake(t, clientConfig, serverConfig)
+ if err != nil {
+ t.Fatalf("%s: handshake failed: %s", name, err)
+ }
+ if hs.DidResume != didResume {
+ t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume)
+ }
+ wantCalled := 1
+ if didResume {
+ wantCalled = 2 // resumption would mean this is the second time it was called in this test
+ }
+ if clientCalled != wantCalled {
+ t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled)
+ }
+ if serverCalled != wantCalled {
+ t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled)
+ }
+ }
+ testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false)
+ testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true)
+ }
+}
+
func TestVerifyPeerCertificate(t *testing.T) {
t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) })
t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) })