]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: replace custom intern cache with weak cache
authorRoland Shoemaker <roland@golang.org>
Sat, 25 Jan 2025 18:28:02 +0000 (10:28 -0800)
committerGopher Robot <gobot@golang.org>
Wed, 21 May 2025 17:31:41 +0000 (10:31 -0700)
Uses the new weak package to replace the existing custom intern cache
with a map of weak.Pointers instead. This simplifies the cache, and
means we don't need to store a slice of handles on the Conn anymore.

Change-Id: I5c2bf6ef35fac4255e140e184f4e48574b34174c
Reviewed-on: https://go-review.googlesource.com/c/go/+/644176
TryBot-Bypass: Roland Shoemaker <roland@golang.org>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Auto-Submit: Roland Shoemaker <roland@golang.org>

src/crypto/tls/cache.go
src/crypto/tls/cache_test.go
src/crypto/tls/conn.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_client_tls13.go
src/crypto/tls/handshake_messages_test.go
src/crypto/tls/ticket.go

index 807f522947f3b53e847e4b3d58e2bb89e1eb3731..a2c255af88af2d82b467fadba645974732cd6798 100644 (file)
@@ -8,78 +8,19 @@ import (
        "crypto/x509"
        "runtime"
        "sync"
-       "sync/atomic"
+       "weak"
 )
 
-type cacheEntry struct {
-       refs atomic.Int64
-       cert *x509.Certificate
-}
-
-// certCache implements an intern table for reference counted x509.Certificates,
-// implemented in a similar fashion to BoringSSL's CRYPTO_BUFFER_POOL. This
-// allows for a single x509.Certificate to be kept in memory and referenced from
-// multiple Conns. Returned references should not be mutated by callers. Certificates
-// are still safe to use after they are removed from the cache.
-//
-// Certificates are returned wrapped in an activeCert struct that should be held by
-// the caller. When references to the activeCert are freed, the number of references
-// to the certificate in the cache is decremented. Once the number of references
-// reaches zero, the entry is evicted from the cache.
-//
-// The main difference between this implementation and CRYPTO_BUFFER_POOL is that
-// CRYPTO_BUFFER_POOL is a more  generic structure which supports blobs of data,
-// rather than specific structures. Since we only care about x509.Certificates,
-// certCache is implemented as a specific cache, rather than a generic one.
-//
-// See https://boringssl.googlesource.com/boringssl/+/master/include/openssl/pool.h
-// and https://boringssl.googlesource.com/boringssl/+/master/crypto/pool/pool.c
-// for the BoringSSL reference.
-type certCache struct {
-       sync.Map
-}
-
-var globalCertCache = new(certCache)
-
-// activeCert is a handle to a certificate held in the cache. Once there are
-// no alive activeCerts for a given certificate, the certificate is removed
-// from the cache by a cleanup.
-type activeCert struct {
-       cert *x509.Certificate
-}
+// weakCertCache provides a cache of *x509.Certificates, allowing multiple
+// connections to reuse parsed certificates, instead of re-parsing the
+// certificate for every connection, which is an expensive operation.
+type weakCertCache struct{ sync.Map }
 
-// active increments the number of references to the entry, wraps the
-// certificate in the entry in an activeCert, and sets the cleanup.
-//
-// Note that there is a race between active and the cleanup set on the
-// returned activeCert, triggered if active is called after the ref count is
-// decremented such that refs may be > 0 when evict is called. We consider this
-// safe, since the caller holding an activeCert for an entry that is no longer
-// in the cache is fine, with the only side effect being the memory overhead of
-// there being more than one distinct reference to a certificate alive at once.
-func (cc *certCache) active(e *cacheEntry) *activeCert {
-       e.refs.Add(1)
-       a := &activeCert{e.cert}
-       runtime.AddCleanup(a, func(ce *cacheEntry) {
-               if ce.refs.Add(-1) == 0 {
-                       cc.evict(ce)
+func (wcc *weakCertCache) newCert(der []byte) (*x509.Certificate, error) {
+       if entry, ok := wcc.Load(string(der)); ok {
+               if v := entry.(weak.Pointer[x509.Certificate]).Value(); v != nil {
+                       return v, nil
                }
-       }, e)
-       return a
-}
-
-// evict removes a cacheEntry from the cache.
-func (cc *certCache) evict(e *cacheEntry) {
-       cc.Delete(string(e.cert.Raw))
-}
-
-// newCert returns a x509.Certificate parsed from der. If there is already a copy
-// of the certificate in the cache, a reference to the existing certificate will
-// be returned. Otherwise, a fresh certificate will be added to the cache, and
-// the reference returned. The returned reference should not be mutated.
-func (cc *certCache) newCert(der []byte) (*activeCert, error) {
-       if entry, ok := cc.Load(string(der)); ok {
-               return cc.active(entry.(*cacheEntry)), nil
        }
 
        cert, err := x509.ParseCertificate(der)
@@ -87,9 +28,17 @@ func (cc *certCache) newCert(der []byte) (*activeCert, error) {
                return nil, err
        }
 
-       entry := &cacheEntry{cert: cert}
-       if entry, loaded := cc.LoadOrStore(string(der), entry); loaded {
-               return cc.active(entry.(*cacheEntry)), nil
+       wp := weak.Make(cert)
+       if entry, loaded := wcc.LoadOrStore(string(der), wp); !loaded {
+               runtime.AddCleanup(cert, func(_ any) { wcc.CompareAndDelete(string(der), entry) }, any(string(der)))
+       } else if v := entry.(weak.Pointer[x509.Certificate]).Value(); v != nil {
+               return v, nil
+       } else {
+               if wcc.CompareAndSwap(string(der), entry, wp) {
+                       runtime.AddCleanup(cert, func(_ any) { wcc.CompareAndDelete(string(der), wp) }, any(string(der)))
+               }
        }
-       return cc.active(entry), nil
+       return cert, nil
 }
+
+var globalCertCache = new(weakCertCache)
index ea6b726d5e5da2ad87fc36d67c039908e0811410..75a0508ec0defa3f40939353d8d6795f3ac7ba0b 100644 (file)
@@ -1,45 +1,39 @@
-// Copyright 2022 The Go Authors. All rights reserved.
+// Copyright 2025 The Go Authors. All rights reserved.
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
-
 package tls
 
 import (
        "encoding/pem"
-       "fmt"
        "runtime"
        "testing"
        "time"
 )
 
-func TestCertCache(t *testing.T) {
-       cc := certCache{}
+func TestWeakCertCache(t *testing.T) {
+       wcc := weakCertCache{}
        p, _ := pem.Decode([]byte(rsaCertPEM))
        if p == nil {
                t.Fatal("Failed to decode certificate")
        }
 
-       certA, err := cc.newCert(p.Bytes)
+       certA, err := wcc.newCert(p.Bytes)
        if err != nil {
                t.Fatalf("newCert failed: %s", err)
        }
-       certB, err := cc.newCert(p.Bytes)
+       certB, err := wcc.newCert(p.Bytes)
        if err != nil {
                t.Fatalf("newCert failed: %s", err)
        }
-       if certA.cert != certB.cert {
+       if certA != certB {
                t.Fatal("newCert returned a unique reference for a duplicate certificate")
        }
 
-       if entry, ok := cc.Load(string(p.Bytes)); !ok {
+       if _, ok := wcc.Load(string(p.Bytes)); !ok {
                t.Fatal("cache does not contain expected entry")
-       } else {
-               if refs := entry.(*cacheEntry).refs.Load(); refs != 2 {
-                       t.Fatalf("unexpected number of references: got %d, want 2", refs)
-               }
        }
 
-       timeoutRefCheck := func(t *testing.T, key string, count int64) {
+       timeoutRefCheck := func(t *testing.T, key string, present bool) {
                t.Helper()
                timeout := time.After(4 * time.Second)
                for {
@@ -47,14 +41,8 @@ func TestCertCache(t *testing.T) {
                        case <-timeout:
                                t.Fatal("timed out waiting for expected ref count")
                        default:
-                               e, ok := cc.Load(key)
-                               if !ok && count != 0 {
-                                       t.Fatal("cache does not contain expected key")
-                               } else if count == 0 && !ok {
-                                       return
-                               }
-
-                               if e.(*cacheEntry).refs.Load() == count {
+                               _, ok := wcc.Load(key)
+                               if ok == present {
                                        return
                                }
                        }
@@ -77,7 +65,7 @@ func TestCertCache(t *testing.T) {
        certA = nil
        runtime.GC()
 
-       timeoutRefCheck(t, string(p.Bytes), 1)
+       timeoutRefCheck(t, string(p.Bytes), true)
 
        // Keep certB alive until at least now, so that we can
        // purposefully nil it and force the finalizer to be
@@ -86,41 +74,5 @@ func TestCertCache(t *testing.T) {
        certB = nil
        runtime.GC()
 
-       timeoutRefCheck(t, string(p.Bytes), 0)
-}
-
-func BenchmarkCertCache(b *testing.B) {
-       p, _ := pem.Decode([]byte(rsaCertPEM))
-       if p == nil {
-               b.Fatal("Failed to decode certificate")
-       }
-
-       cc := certCache{}
-       b.ReportAllocs()
-       b.ResetTimer()
-       // We expect that calling newCert additional times after
-       // the initial call should not cause additional allocations.
-       for extra := 0; extra < 4; extra++ {
-               b.Run(fmt.Sprint(extra), func(b *testing.B) {
-                       actives := make([]*activeCert, extra+1)
-                       b.ResetTimer()
-                       for i := 0; i < b.N; i++ {
-                               var err error
-                               actives[0], err = cc.newCert(p.Bytes)
-                               if err != nil {
-                                       b.Fatal(err)
-                               }
-                               for j := 0; j < extra; j++ {
-                                       actives[j+1], err = cc.newCert(p.Bytes)
-                                       if err != nil {
-                                               b.Fatal(err)
-                                       }
-                               }
-                               for j := 0; j < extra+1; j++ {
-                                       actives[j] = nil
-                               }
-                               runtime.GC()
-                       }
-               })
-       }
+       timeoutRefCheck(t, string(p.Bytes), false)
 }
index 1276665a2feeb4bb01e02228231798116d83be52..141175c801ea7f3a7dc1f7ec5de7bdaa62acd6a0 100644 (file)
@@ -54,9 +54,6 @@ type Conn struct {
        ocspResponse     []byte   // stapled OCSP response
        scts             [][]byte // signed certificate timestamps from server
        peerCertificates []*x509.Certificate
-       // activeCertHandles contains the cache handles to certificates in
-       // peerCertificates that are used to track active references.
-       activeCertHandles []*activeCert
        // verifiedChains contains the certificate chains that we built, as
        // opposed to the ones presented by the server.
        verifiedChains [][]*x509.Certificate
index bb5b7a042a7399a9ae43450894afec54ddf0aab5..55790a11b6c8d1ea353c34c6c63659cdc75ac7fd 100644 (file)
@@ -956,7 +956,6 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
        hs.masterSecret = hs.session.secret
        c.extMasterSecret = hs.session.extMasterSecret
        c.peerCertificates = hs.session.peerCertificates
-       c.activeCertHandles = hs.c.activeCertHandles
        c.verifiedChains = hs.session.verifiedChains
        c.ocspResponse = hs.session.ocspResponse
        // Let the ServerHello SCTs override the session SCTs from the original
@@ -1107,7 +1106,6 @@ func checkKeySize(n int) (max int, ok bool) {
 // verifyServerCertificate parses and verifies the provided chain, setting
 // c.verifiedChains and c.peerCertificates or sending the appropriate alert.
 func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
-       activeHandles := make([]*activeCert, len(certificates))
        certs := make([]*x509.Certificate, len(certificates))
        for i, asn1Data := range certificates {
                cert, err := globalCertCache.newCert(asn1Data)
@@ -1115,15 +1113,14 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
                        c.sendAlert(alertDecodeError)
                        return errors.New("tls: failed to parse certificate from server: " + err.Error())
                }
-               if cert.cert.PublicKeyAlgorithm == x509.RSA {
-                       n := cert.cert.PublicKey.(*rsa.PublicKey).N.BitLen()
+               if cert.PublicKeyAlgorithm == x509.RSA {
+                       n := cert.PublicKey.(*rsa.PublicKey).N.BitLen()
                        if max, ok := checkKeySize(n); !ok {
                                c.sendAlert(alertBadCertificate)
                                return fmt.Errorf("tls: server sent certificate containing RSA key larger than %d bits", max)
                        }
                }
-               activeHandles[i] = cert
-               certs[i] = cert.cert
+               certs[i] = cert
        }
 
        echRejected := c.config.EncryptedClientHelloConfigList != nil && !c.echAccepted
@@ -1188,7 +1185,6 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
                return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey)
        }
 
-       c.activeCertHandles = activeHandles
        c.peerCertificates = certs
 
        if c.config.VerifyPeerCertificate != nil && !echRejected {
index 444c6f311ce1e3f78e5332442a069c8f7ba9f29a..461a0e69628008fff7a11fbdc07ad15176cad3ea 100644 (file)
@@ -466,7 +466,6 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error {
        hs.usingPSK = true
        c.didResume = true
        c.peerCertificates = hs.session.peerCertificates
-       c.activeCertHandles = hs.session.activeCertHandles
        c.verifiedChains = hs.session.verifiedChains
        c.ocspResponse = hs.session.ocspResponse
        c.scts = hs.session.scts
index aafb889b3017c27a7301a10e5b9bb8c90122e586..448bc31d3a00e31427b415c1ee7c355c4167ef1a 100644 (file)
@@ -72,10 +72,6 @@ func TestMarshalUnmarshal(t *testing.T) {
                                        break
                                }
 
-                               if m, ok := m.(*SessionState); ok {
-                                       m.activeCertHandles = nil
-                               }
-
                                if ch, ok := m.(*clientHelloMsg); ok {
                                        // extensions is special cased, as it is only populated by the
                                        // server-side of a handshake and is not expected to roundtrip
index dbbcef763710a085302c3c6c5d09c0cd1c2e5205..c56898c6f79813c957fb8497ea416f0f6400a34c 100644 (file)
@@ -84,15 +84,14 @@ type SessionState struct {
        // createdAt is the generation time of the secret on the sever (which for
        // TLS 1.0–1.2 might be earlier than the current session) and the time at
        // which the ticket was received on the client.
-       createdAt         uint64 // seconds since UNIX epoch
-       secret            []byte // master secret for TLS 1.2, or the PSK for TLS 1.3
-       extMasterSecret   bool
-       peerCertificates  []*x509.Certificate
-       activeCertHandles []*activeCert
-       ocspResponse      []byte
-       scts              [][]byte
-       verifiedChains    [][]*x509.Certificate
-       alpnProtocol      string // only set if EarlyData is true
+       createdAt        uint64 // seconds since UNIX epoch
+       secret           []byte // master secret for TLS 1.2, or the PSK for TLS 1.3
+       extMasterSecret  bool
+       peerCertificates []*x509.Certificate
+       ocspResponse     []byte
+       scts             [][]byte
+       verifiedChains   [][]*x509.Certificate
+       alpnProtocol     string // only set if EarlyData is true
 
        // Client-side TLS 1.3-only fields.
        useBy  uint64 // seconds since UNIX epoch
@@ -239,8 +238,7 @@ func ParseSessionState(data []byte) (*SessionState, error) {
                if err != nil {
                        return nil, err
                }
-               ss.activeCertHandles = append(ss.activeCertHandles, c)
-               ss.peerCertificates = append(ss.peerCertificates, c.cert)
+               ss.peerCertificates = append(ss.peerCertificates, c)
        }
        if ss.isClient && len(ss.peerCertificates) == 0 {
                return nil, errors.New("tls: no server certificates in client session")
@@ -270,8 +268,7 @@ func ParseSessionState(data []byte) (*SessionState, error) {
                        if err != nil {
                                return nil, err
                        }
-                       ss.activeCertHandles = append(ss.activeCertHandles, c)
-                       chain = append(chain, c.cert)
+                       chain = append(chain, c)
                }
                ss.verifiedChains = append(ss.verifiedChains, chain)
        }
@@ -300,18 +297,17 @@ func ParseSessionState(data []byte) (*SessionState, error) {
 // from the current connection.
 func (c *Conn) sessionState() *SessionState {
        return &SessionState{
-               version:           c.vers,
-               cipherSuite:       c.cipherSuite,
-               createdAt:         uint64(c.config.time().Unix()),
-               alpnProtocol:      c.clientProtocol,
-               peerCertificates:  c.peerCertificates,
-               activeCertHandles: c.activeCertHandles,
-               ocspResponse:      c.ocspResponse,
-               scts:              c.scts,
-               isClient:          c.isClient,
-               extMasterSecret:   c.extMasterSecret,
-               verifiedChains:    c.verifiedChains,
-               curveID:           c.curveID,
+               version:          c.vers,
+               cipherSuite:      c.cipherSuite,
+               createdAt:        uint64(c.config.time().Unix()),
+               alpnProtocol:     c.clientProtocol,
+               peerCertificates: c.peerCertificates,
+               ocspResponse:     c.ocspResponse,
+               scts:             c.scts,
+               isClient:         c.isClient,
+               extMasterSecret:  c.extMasterSecret,
+               verifiedChains:   c.verifiedChains,
+               curveID:          c.curveID,
        }
 }