From: Roland Shoemaker Date: Sat, 25 Jan 2025 18:28:02 +0000 (-0800) Subject: crypto/tls: replace custom intern cache with weak cache X-Git-Tag: go1.25rc1~149 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=360600b1d20bc4b421217435d921a2437be07893;p=gostls13.git crypto/tls: replace custom intern cache with weak cache Uses the new weak package to replace the existing custom intern cache with a map of weak.Pointers instead. This simplifies the cache, and means we don't need to store a slice of handles on the Conn anymore. Change-Id: I5c2bf6ef35fac4255e140e184f4e48574b34174c Reviewed-on: https://go-review.googlesource.com/c/go/+/644176 TryBot-Bypass: Roland Shoemaker Reviewed-by: Michael Knyszek Auto-Submit: Roland Shoemaker --- diff --git a/src/crypto/tls/cache.go b/src/crypto/tls/cache.go index 807f522947..a2c255af88 100644 --- a/src/crypto/tls/cache.go +++ b/src/crypto/tls/cache.go @@ -8,78 +8,19 @@ import ( "crypto/x509" "runtime" "sync" - "sync/atomic" + "weak" ) -type cacheEntry struct { - refs atomic.Int64 - cert *x509.Certificate -} - -// certCache implements an intern table for reference counted x509.Certificates, -// implemented in a similar fashion to BoringSSL's CRYPTO_BUFFER_POOL. This -// allows for a single x509.Certificate to be kept in memory and referenced from -// multiple Conns. Returned references should not be mutated by callers. Certificates -// are still safe to use after they are removed from the cache. -// -// Certificates are returned wrapped in an activeCert struct that should be held by -// the caller. When references to the activeCert are freed, the number of references -// to the certificate in the cache is decremented. Once the number of references -// reaches zero, the entry is evicted from the cache. -// -// The main difference between this implementation and CRYPTO_BUFFER_POOL is that -// CRYPTO_BUFFER_POOL is a more generic structure which supports blobs of data, -// rather than specific structures. Since we only care about x509.Certificates, -// certCache is implemented as a specific cache, rather than a generic one. -// -// See https://boringssl.googlesource.com/boringssl/+/master/include/openssl/pool.h -// and https://boringssl.googlesource.com/boringssl/+/master/crypto/pool/pool.c -// for the BoringSSL reference. -type certCache struct { - sync.Map -} - -var globalCertCache = new(certCache) - -// activeCert is a handle to a certificate held in the cache. Once there are -// no alive activeCerts for a given certificate, the certificate is removed -// from the cache by a cleanup. -type activeCert struct { - cert *x509.Certificate -} +// weakCertCache provides a cache of *x509.Certificates, allowing multiple +// connections to reuse parsed certificates, instead of re-parsing the +// certificate for every connection, which is an expensive operation. +type weakCertCache struct{ sync.Map } -// active increments the number of references to the entry, wraps the -// certificate in the entry in an activeCert, and sets the cleanup. -// -// Note that there is a race between active and the cleanup set on the -// returned activeCert, triggered if active is called after the ref count is -// decremented such that refs may be > 0 when evict is called. We consider this -// safe, since the caller holding an activeCert for an entry that is no longer -// in the cache is fine, with the only side effect being the memory overhead of -// there being more than one distinct reference to a certificate alive at once. -func (cc *certCache) active(e *cacheEntry) *activeCert { - e.refs.Add(1) - a := &activeCert{e.cert} - runtime.AddCleanup(a, func(ce *cacheEntry) { - if ce.refs.Add(-1) == 0 { - cc.evict(ce) +func (wcc *weakCertCache) newCert(der []byte) (*x509.Certificate, error) { + if entry, ok := wcc.Load(string(der)); ok { + if v := entry.(weak.Pointer[x509.Certificate]).Value(); v != nil { + return v, nil } - }, e) - return a -} - -// evict removes a cacheEntry from the cache. -func (cc *certCache) evict(e *cacheEntry) { - cc.Delete(string(e.cert.Raw)) -} - -// newCert returns a x509.Certificate parsed from der. If there is already a copy -// of the certificate in the cache, a reference to the existing certificate will -// be returned. Otherwise, a fresh certificate will be added to the cache, and -// the reference returned. The returned reference should not be mutated. -func (cc *certCache) newCert(der []byte) (*activeCert, error) { - if entry, ok := cc.Load(string(der)); ok { - return cc.active(entry.(*cacheEntry)), nil } cert, err := x509.ParseCertificate(der) @@ -87,9 +28,17 @@ func (cc *certCache) newCert(der []byte) (*activeCert, error) { return nil, err } - entry := &cacheEntry{cert: cert} - if entry, loaded := cc.LoadOrStore(string(der), entry); loaded { - return cc.active(entry.(*cacheEntry)), nil + wp := weak.Make(cert) + if entry, loaded := wcc.LoadOrStore(string(der), wp); !loaded { + runtime.AddCleanup(cert, func(_ any) { wcc.CompareAndDelete(string(der), entry) }, any(string(der))) + } else if v := entry.(weak.Pointer[x509.Certificate]).Value(); v != nil { + return v, nil + } else { + if wcc.CompareAndSwap(string(der), entry, wp) { + runtime.AddCleanup(cert, func(_ any) { wcc.CompareAndDelete(string(der), wp) }, any(string(der))) + } } - return cc.active(entry), nil + return cert, nil } + +var globalCertCache = new(weakCertCache) diff --git a/src/crypto/tls/cache_test.go b/src/crypto/tls/cache_test.go index ea6b726d5e..75a0508ec0 100644 --- a/src/crypto/tls/cache_test.go +++ b/src/crypto/tls/cache_test.go @@ -1,45 +1,39 @@ -// Copyright 2022 The Go Authors. All rights reserved. +// Copyright 2025 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. - package tls import ( "encoding/pem" - "fmt" "runtime" "testing" "time" ) -func TestCertCache(t *testing.T) { - cc := certCache{} +func TestWeakCertCache(t *testing.T) { + wcc := weakCertCache{} p, _ := pem.Decode([]byte(rsaCertPEM)) if p == nil { t.Fatal("Failed to decode certificate") } - certA, err := cc.newCert(p.Bytes) + certA, err := wcc.newCert(p.Bytes) if err != nil { t.Fatalf("newCert failed: %s", err) } - certB, err := cc.newCert(p.Bytes) + certB, err := wcc.newCert(p.Bytes) if err != nil { t.Fatalf("newCert failed: %s", err) } - if certA.cert != certB.cert { + if certA != certB { t.Fatal("newCert returned a unique reference for a duplicate certificate") } - if entry, ok := cc.Load(string(p.Bytes)); !ok { + if _, ok := wcc.Load(string(p.Bytes)); !ok { t.Fatal("cache does not contain expected entry") - } else { - if refs := entry.(*cacheEntry).refs.Load(); refs != 2 { - t.Fatalf("unexpected number of references: got %d, want 2", refs) - } } - timeoutRefCheck := func(t *testing.T, key string, count int64) { + timeoutRefCheck := func(t *testing.T, key string, present bool) { t.Helper() timeout := time.After(4 * time.Second) for { @@ -47,14 +41,8 @@ func TestCertCache(t *testing.T) { case <-timeout: t.Fatal("timed out waiting for expected ref count") default: - e, ok := cc.Load(key) - if !ok && count != 0 { - t.Fatal("cache does not contain expected key") - } else if count == 0 && !ok { - return - } - - if e.(*cacheEntry).refs.Load() == count { + _, ok := wcc.Load(key) + if ok == present { return } } @@ -77,7 +65,7 @@ func TestCertCache(t *testing.T) { certA = nil runtime.GC() - timeoutRefCheck(t, string(p.Bytes), 1) + timeoutRefCheck(t, string(p.Bytes), true) // Keep certB alive until at least now, so that we can // purposefully nil it and force the finalizer to be @@ -86,41 +74,5 @@ func TestCertCache(t *testing.T) { certB = nil runtime.GC() - timeoutRefCheck(t, string(p.Bytes), 0) -} - -func BenchmarkCertCache(b *testing.B) { - p, _ := pem.Decode([]byte(rsaCertPEM)) - if p == nil { - b.Fatal("Failed to decode certificate") - } - - cc := certCache{} - b.ReportAllocs() - b.ResetTimer() - // We expect that calling newCert additional times after - // the initial call should not cause additional allocations. - for extra := 0; extra < 4; extra++ { - b.Run(fmt.Sprint(extra), func(b *testing.B) { - actives := make([]*activeCert, extra+1) - b.ResetTimer() - for i := 0; i < b.N; i++ { - var err error - actives[0], err = cc.newCert(p.Bytes) - if err != nil { - b.Fatal(err) - } - for j := 0; j < extra; j++ { - actives[j+1], err = cc.newCert(p.Bytes) - if err != nil { - b.Fatal(err) - } - } - for j := 0; j < extra+1; j++ { - actives[j] = nil - } - runtime.GC() - } - }) - } + timeoutRefCheck(t, string(p.Bytes), false) } diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go index 1276665a2f..141175c801 100644 --- a/src/crypto/tls/conn.go +++ b/src/crypto/tls/conn.go @@ -54,9 +54,6 @@ type Conn struct { ocspResponse []byte // stapled OCSP response scts [][]byte // signed certificate timestamps from server peerCertificates []*x509.Certificate - // activeCertHandles contains the cache handles to certificates in - // peerCertificates that are used to track active references. - activeCertHandles []*activeCert // verifiedChains contains the certificate chains that we built, as // opposed to the ones presented by the server. verifiedChains [][]*x509.Certificate diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index bb5b7a042a..55790a11b6 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -956,7 +956,6 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) { hs.masterSecret = hs.session.secret c.extMasterSecret = hs.session.extMasterSecret c.peerCertificates = hs.session.peerCertificates - c.activeCertHandles = hs.c.activeCertHandles c.verifiedChains = hs.session.verifiedChains c.ocspResponse = hs.session.ocspResponse // Let the ServerHello SCTs override the session SCTs from the original @@ -1107,7 +1106,6 @@ func checkKeySize(n int) (max int, ok bool) { // verifyServerCertificate parses and verifies the provided chain, setting // c.verifiedChains and c.peerCertificates or sending the appropriate alert. func (c *Conn) verifyServerCertificate(certificates [][]byte) error { - activeHandles := make([]*activeCert, len(certificates)) certs := make([]*x509.Certificate, len(certificates)) for i, asn1Data := range certificates { cert, err := globalCertCache.newCert(asn1Data) @@ -1115,15 +1113,14 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { c.sendAlert(alertDecodeError) return errors.New("tls: failed to parse certificate from server: " + err.Error()) } - if cert.cert.PublicKeyAlgorithm == x509.RSA { - n := cert.cert.PublicKey.(*rsa.PublicKey).N.BitLen() + if cert.PublicKeyAlgorithm == x509.RSA { + n := cert.PublicKey.(*rsa.PublicKey).N.BitLen() if max, ok := checkKeySize(n); !ok { c.sendAlert(alertBadCertificate) return fmt.Errorf("tls: server sent certificate containing RSA key larger than %d bits", max) } } - activeHandles[i] = cert - certs[i] = cert.cert + certs[i] = cert } echRejected := c.config.EncryptedClientHelloConfigList != nil && !c.echAccepted @@ -1188,7 +1185,6 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) } - c.activeCertHandles = activeHandles c.peerCertificates = certs if c.config.VerifyPeerCertificate != nil && !echRejected { diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go index 444c6f311c..461a0e6962 100644 --- a/src/crypto/tls/handshake_client_tls13.go +++ b/src/crypto/tls/handshake_client_tls13.go @@ -466,7 +466,6 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error { hs.usingPSK = true c.didResume = true c.peerCertificates = hs.session.peerCertificates - c.activeCertHandles = hs.session.activeCertHandles c.verifiedChains = hs.session.verifiedChains c.ocspResponse = hs.session.ocspResponse c.scts = hs.session.scts diff --git a/src/crypto/tls/handshake_messages_test.go b/src/crypto/tls/handshake_messages_test.go index aafb889b30..448bc31d3a 100644 --- a/src/crypto/tls/handshake_messages_test.go +++ b/src/crypto/tls/handshake_messages_test.go @@ -72,10 +72,6 @@ func TestMarshalUnmarshal(t *testing.T) { break } - if m, ok := m.(*SessionState); ok { - m.activeCertHandles = nil - } - if ch, ok := m.(*clientHelloMsg); ok { // extensions is special cased, as it is only populated by the // server-side of a handshake and is not expected to roundtrip diff --git a/src/crypto/tls/ticket.go b/src/crypto/tls/ticket.go index dbbcef7637..c56898c6f7 100644 --- a/src/crypto/tls/ticket.go +++ b/src/crypto/tls/ticket.go @@ -84,15 +84,14 @@ type SessionState struct { // createdAt is the generation time of the secret on the sever (which for // TLS 1.0–1.2 might be earlier than the current session) and the time at // which the ticket was received on the client. - createdAt uint64 // seconds since UNIX epoch - secret []byte // master secret for TLS 1.2, or the PSK for TLS 1.3 - extMasterSecret bool - peerCertificates []*x509.Certificate - activeCertHandles []*activeCert - ocspResponse []byte - scts [][]byte - verifiedChains [][]*x509.Certificate - alpnProtocol string // only set if EarlyData is true + createdAt uint64 // seconds since UNIX epoch + secret []byte // master secret for TLS 1.2, or the PSK for TLS 1.3 + extMasterSecret bool + peerCertificates []*x509.Certificate + ocspResponse []byte + scts [][]byte + verifiedChains [][]*x509.Certificate + alpnProtocol string // only set if EarlyData is true // Client-side TLS 1.3-only fields. useBy uint64 // seconds since UNIX epoch @@ -239,8 +238,7 @@ func ParseSessionState(data []byte) (*SessionState, error) { if err != nil { return nil, err } - ss.activeCertHandles = append(ss.activeCertHandles, c) - ss.peerCertificates = append(ss.peerCertificates, c.cert) + ss.peerCertificates = append(ss.peerCertificates, c) } if ss.isClient && len(ss.peerCertificates) == 0 { return nil, errors.New("tls: no server certificates in client session") @@ -270,8 +268,7 @@ func ParseSessionState(data []byte) (*SessionState, error) { if err != nil { return nil, err } - ss.activeCertHandles = append(ss.activeCertHandles, c) - chain = append(chain, c.cert) + chain = append(chain, c) } ss.verifiedChains = append(ss.verifiedChains, chain) } @@ -300,18 +297,17 @@ func ParseSessionState(data []byte) (*SessionState, error) { // from the current connection. func (c *Conn) sessionState() *SessionState { return &SessionState{ - version: c.vers, - cipherSuite: c.cipherSuite, - createdAt: uint64(c.config.time().Unix()), - alpnProtocol: c.clientProtocol, - peerCertificates: c.peerCertificates, - activeCertHandles: c.activeCertHandles, - ocspResponse: c.ocspResponse, - scts: c.scts, - isClient: c.isClient, - extMasterSecret: c.extMasterSecret, - verifiedChains: c.verifiedChains, - curveID: c.curveID, + version: c.vers, + cipherSuite: c.cipherSuite, + createdAt: uint64(c.config.time().Unix()), + alpnProtocol: c.clientProtocol, + peerCertificates: c.peerCertificates, + ocspResponse: c.ocspResponse, + scts: c.scts, + isClient: c.isClient, + extMasterSecret: c.extMasterSecret, + verifiedChains: c.verifiedChains, + curveID: c.curveID, } }