CipherSuite uint16 // cipher suite in use (TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, ...)
NegotiatedProtocol string // negotiated next protocol (not guaranteed to be from Config.NextProtos)
NegotiatedProtocolIsMutual bool // negotiated protocol was advertised by server (client side only)
- ServerName string // server name requested by client, if any (server side only)
+ ServerName string // server name requested by client, if any
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates
SignedCertificateTimestamps [][]byte // SCTs from the peer, if any
// be considered but the verifiedChains argument will always be nil.
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
+ // VerifyConnection, if not nil, is called after normal certificate
+ // verification and after VerifyPeerCertificate by either a TLS client
+ // or server. If it returns a non-nil error, the handshake is aborted
+ // and that error results.
+ //
+ // If normal verification fails then the handshake will abort before
+ // considering this callback. This callback will run for all connections
+ // regardless of InsecureSkipVerify or ClientAuth settings.
+ VerifyConnection func(ConnectionState) error
+
// RootCAs defines the set of root certificate authorities
// that clients use when verifying server certificates.
// If RootCAs is nil, TLS uses the host's root CA set.
GetClientCertificate: c.GetClientCertificate,
GetConfigForClient: c.GetConfigForClient,
VerifyPeerCertificate: c.VerifyPeerCertificate,
+ VerifyConnection: c.VerifyConnection,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
ServerName: c.ServerName,
func (c *Conn) ConnectionState() ConnectionState {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
+ return c.connectionStateLocked()
+}
+func (c *Conn) connectionStateLocked() ConnectionState {
var state ConnectionState
state.HandshakeComplete = c.handshakeComplete()
+ state.Version = c.vers
+ state.NegotiatedProtocol = c.clientProtocol
+ state.DidResume = c.didResume
+ state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
state.ServerName = c.serverName
-
- if state.HandshakeComplete {
- state.Version = c.vers
- state.NegotiatedProtocol = c.clientProtocol
- state.DidResume = c.didResume
- state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
- state.CipherSuite = c.cipherSuite
- state.PeerCertificates = c.peerCertificates
- state.VerifiedChains = c.verifiedChains
- state.SignedCertificateTimestamps = c.scts
- state.OCSPResponse = c.ocspResponse
- if !c.didResume && c.vers != VersionTLS13 {
- if c.clientFinishedIsFirst {
- state.TLSUnique = c.clientFinished[:]
- } else {
- state.TLSUnique = c.serverFinished[:]
- }
- }
- if c.config.Renegotiation != RenegotiateNever {
- state.ekm = noExportedKeyingMaterial
+ state.CipherSuite = c.cipherSuite
+ state.PeerCertificates = c.peerCertificates
+ state.VerifiedChains = c.verifiedChains
+ state.SignedCertificateTimestamps = c.scts
+ state.OCSPResponse = c.ocspResponse
+ if !c.didResume && c.vers != VersionTLS13 {
+ if c.clientFinishedIsFirst {
+ state.TLSUnique = c.clientFinished[:]
} else {
- state.ekm = c.ekm
+ state.TLSUnique = c.serverFinished[:]
}
}
-
+ if c.config.Renegotiation != RenegotiateNever {
+ state.ekm = noExportedKeyingMaterial
+ } else {
+ state.ekm = c.ekm
+ }
return state
}
if err != nil {
return err
}
+ c.serverName = hello.serverName
cacheKey, session, earlySecret, binderKey := c.loadSession(hello)
if cacheKey != "" && session != nil {
hs.finishedHash.Write(hs.serverHello.marshal())
c.buffering = true
+ c.didResume = isResume
if isResume {
if err := hs.establishKeys(); err != nil {
return err
return err
}
c.clientFinishedIsFirst = false
+ // Make sure the connection is still being verified whether or not this
+ // is a resumption. Resumptions currently don't reverify certificates so
+ // they don't call verifyServerCertificate. See Issue 31641.
+ if c.config.VerifyConnection != nil {
+ if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+ c.sendAlert(alertBadCertificate)
+ return err
+ }
+ }
if err := hs.sendFinished(c.clientFinished[:]); err != nil {
return err
}
}
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random)
- c.didResume = isResume
atomic.StoreUint32(&c.handshakeStatus, 1)
return nil
}
hs.finishedHash.Write(certMsg.marshal())
- if c.handshakes == 0 {
- // If this is the first handshake on a connection, process and
- // (optionally) verify the server's certificates.
- if err := c.verifyServerCertificate(certMsg.certificates); err != nil {
- return err
- }
- } else {
- // This is a renegotiation handshake. We require that the
- // server's identity (i.e. leaf certificate) is unchanged and
- // thus any previous trust decision is still valid.
- //
- // See https://mitls.org/pages/attacks/3SHAKE for the
- // motivation behind this requirement.
- if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) {
- c.sendAlert(alertBadCertificate)
- return errors.New("tls: server's identity changed during renegotiation")
- }
- }
-
msg, err = c.readHandshake()
if err != nil {
return err
}
}
+ if c.handshakes == 0 {
+ // If this is the first handshake on a connection, process and
+ // (optionally) verify the server's certificates.
+ if err := c.verifyServerCertificate(certMsg.certificates); err != nil {
+ return err
+ }
+ } else {
+ // This is a renegotiation handshake. We require that the
+ // server's identity (i.e. leaf certificate) is unchanged and
+ // thus any previous trust decision is still valid.
+ //
+ // See https://mitls.org/pages/attacks/3SHAKE for the
+ // motivation behind this requirement.
+ if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) {
+ c.sendAlert(alertBadCertificate)
+ return errors.New("tls: server's identity changed during renegotiation")
+ }
+ }
+
keyAgreement := hs.suite.ka(c.vers)
skx, ok := msg.(*serverKeyExchangeMsg)
}
}
- if c.config.VerifyPeerCertificate != nil {
- if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
- c.sendAlert(alertBadCertificate)
- return err
- }
- }
-
switch certs[0].PublicKey.(type) {
case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey:
break
c.peerCertificates = certs
+ if c.config.VerifyPeerCertificate != nil {
+ if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
+ c.sendAlert(alertBadCertificate)
+ return err
+ }
+ }
+
+ if c.config.VerifyConnection != nil {
+ if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+ c.sendAlert(alertBadCertificate)
+ return err
+ }
+ }
+
return nil
}
if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
}
+ if got, want := hs.ServerName, clientConfig.ServerName; got != want {
+ t.Errorf("%s: server name %s, want %s", test, got, want)
+ }
}
getTicket := func() []byte {
sentinelErr := errors.New("TestVerifyPeerCertificate")
- verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+ verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
if l := len(rawCerts); l != 1 {
return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
}
*called = true
return nil
}
+ verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error {
+ if l := len(c.PeerCertificates); l != 1 {
+ return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l)
+ }
+ if len(c.VerifiedChains) == 0 {
+ return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero")
+ }
+ if isClient && len(c.OCSPResponse) == 0 {
+ return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero")
+ }
+ *called = true
+ return nil
+ }
tests := []struct {
configureServer func(*Config, *bool)
configureServer: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
- return verifyCallback(called, rawCerts, validatedChains)
+ return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
}
},
configureClient: func(config *Config, called *bool) {
config.InsecureSkipVerify = false
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
- return verifyCallback(called, rawCerts, validatedChains)
+ return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
}
},
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
}
},
},
+ {
+ configureServer: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyConnection = func(c ConnectionState) error {
+ return verifyConnectionCallback(called, false, c)
+ }
+ },
+ configureClient: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyConnection = func(c ConnectionState) error {
+ return verifyConnectionCallback(called, true, c)
+ }
+ },
+ validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+ if clientErr != nil {
+ t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
+ }
+ if serverErr != nil {
+ t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
+ }
+ if !clientCalled {
+ t.Errorf("test[%d]: client did not call callback", testNo)
+ }
+ if !serverCalled {
+ t.Errorf("test[%d]: server did not call callback", testNo)
+ }
+ },
+ },
+ {
+ configureServer: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyConnection = func(c ConnectionState) error {
+ return sentinelErr
+ }
+ },
+ configureClient: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyConnection = nil
+ },
+ validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+ if serverErr != sentinelErr {
+ t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
+ }
+ },
+ },
+ {
+ configureServer: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyConnection = nil
+ },
+ configureClient: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyConnection = func(c ConnectionState) error {
+ return sentinelErr
+ }
+ },
+ validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+ if clientErr != sentinelErr {
+ t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
+ }
+ },
+ },
+ {
+ configureServer: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+ return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
+ }
+ config.VerifyConnection = func(c ConnectionState) error {
+ return sentinelErr
+ }
+ },
+ configureClient: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyPeerCertificate = nil
+ config.VerifyConnection = nil
+ },
+ validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+ if serverErr != sentinelErr {
+ t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
+ }
+ if !serverCalled {
+ t.Errorf("test[%d]: server did not call callback", testNo)
+ }
+ },
+ },
+ {
+ configureServer: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyPeerCertificate = nil
+ config.VerifyConnection = nil
+ },
+ configureClient: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+ return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
+ }
+ config.VerifyConnection = func(c ConnectionState) error {
+ return sentinelErr
+ }
+ },
+ validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+ if clientErr != sentinelErr {
+ t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
+ }
+ if !clientCalled {
+ t.Errorf("test[%d]: client did not call callback", testNo)
+ }
+ },
+ },
}
for i, test := range tests {
config.ClientCAs = rootCAs
config.Time = now
config.MaxVersion = version
+ config.Certificates = make([]Certificate, 1)
+ config.Certificates[0].Certificate = [][]byte{testRSACertificate}
+ config.Certificates[0].PrivateKey = testRSAPrivateKey
+ config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
+ config.Certificates[0].OCSPStaple = []byte("dummy ocsp")
test.configureServer(config, &serverCalled)
err = Server(s, config).Handshake()
// Either a PSK or a certificate is always used, but not both.
// See RFC 8446, Section 4.1.1.
if hs.usingPSK {
+ // Make sure the connection is still being verified whether or not this
+ // is a resumption. Resumptions currently don't reverify certificates so
+ // they don't call verifyServerCertificate. See Issue 31641.
+ if c.config.VerifyConnection != nil {
+ if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+ c.sendAlert(alertBadCertificate)
+ return err
+ }
+ }
return nil
}
c.buffering = true
if hs.checkForResumption() {
// The client has included a session ticket and so we do an abbreviated handshake.
+ c.didResume = true
if err := hs.doResumeHandshake(); err != nil {
return err
}
if err := hs.readFinished(nil); err != nil {
return err
}
- c.didResume = true
} else {
// The client didn't include a session ticket, or it wasn't
// valid so we do a full handshake.
if err != nil {
return err
}
+ } else {
+ // Make sure the connection is still being verified whether or not
+ // the server requested a client certificate.
+ if c.config.VerifyConnection != nil {
+ if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+ c.sendAlert(alertBadCertificate)
+ return err
+ }
+ }
}
// Get client key exchange
c.verifiedChains = chains
}
+ c.peerCertificates = certs
+ c.ocspResponse = certificate.OCSPStaple
+ c.scts = certificate.SignedCertificateTimestamps
+
+ if len(certs) > 0 {
+ switch certs[0].PublicKey.(type) {
+ case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey:
+ default:
+ c.sendAlert(alertUnsupportedCertificate)
+ return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey)
+ }
+ }
+
if c.config.VerifyPeerCertificate != nil {
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
c.sendAlert(alertBadCertificate)
}
}
- if len(certs) == 0 {
- return nil
- }
-
- switch certs[0].PublicKey.(type) {
- case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey:
- default:
- c.sendAlert(alertUnsupportedCertificate)
- return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey)
+ if c.config.VerifyConnection != nil {
+ if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+ c.sendAlert(alertBadCertificate)
+ return err
+ }
}
- c.peerCertificates = certs
- c.ocspResponse = certificate.OCSPStaple
- c.scts = certificate.SignedCertificateTimestamps
return nil
}
return errors.New("tls: invalid PSK binder")
}
+ c.didResume = true
if err := c.processCertsFromClient(sessionState.certificate); err != nil {
return err
}
hs.hello.selectedIdentityPresent = true
hs.hello.selectedIdentity = uint16(i)
hs.usingPSK = true
- c.didResume = true
return nil
}
c := hs.c
if !hs.requestClientCert() {
+ // Make sure the connection is still being verified whether or not
+ // the server requested a client certificate.
+ if c.config.VerifyConnection != nil {
+ if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+ c.sendAlert(alertBadCertificate)
+ return err
+ }
+ }
return nil
}
}
func TestCloneFuncFields(t *testing.T) {
- const expectedCount = 5
+ const expectedCount = 6
called := 0
c1 := Config{
called |= 1 << 4
return nil
},
+ VerifyConnection: func(ConnectionState) error {
+ called |= 1 << 5
+ return nil
+ },
}
c2 := c1.Clone()
c2.GetClientCertificate(nil)
c2.GetConfigForClient(nil)
c2.VerifyPeerCertificate(nil, nil)
+ c2.VerifyConnection(ConnectionState{})
if called != (1<<expectedCount)-1 {
t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
switch fn := typ.Field(i).Name; fn {
case "Rand":
f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
- case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate":
+ case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate":
// DeepEqual can't compare functions. If you add a
// function field to this list, you must also change
// TestCloneFuncFields to ensure that the func field is
if ss.ServerName != serverName {
t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
}
- if cs.ServerName != "" {
- t.Errorf("Got unexpected server name on the client side")
+ if cs.ServerName != serverName {
+ t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName)
}
if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {