]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/x509: add support for CertPool to load certs lazily
authorBrad Fitzpatrick <bradfitz@golang.org>
Fri, 24 Apr 2020 15:04:16 +0000 (08:04 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Sat, 7 Nov 2020 16:59:40 +0000 (16:59 +0000)
This will allow building CertPools that consume less memory. (Most
certs are never accessed. Different users/programs access different
ones, but not many.)

This CL only adds the new internal mechanism (and uses it for the
old AddCert) but does not modify any existing root pool behavior.
(That is, the default Unix roots are still all slurped into memory as
of this CL)

Change-Id: Ib3a42e4050627b5e34413c595d8ced839c7bfa14
Reviewed-on: https://go-review.googlesource.com/c/go/+/229917
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Brad Fitzpatrick <bradfitz@golang.org>
Trust: Roland Shoemaker <roland@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
src/crypto/x509/cert_pool.go
src/crypto/x509/name_constraints_test.go
src/crypto/x509/root_cgo_darwin.go
src/crypto/x509/root_darwin_test.go
src/crypto/x509/root_unix.go
src/crypto/x509/root_unix_test.go
src/crypto/x509/root_windows.go
src/crypto/x509/verify.go
src/crypto/x509/x509_test.go

index 167390da9f638c958ac23a62fd9e3e38a04e53a7..2cfaeb2d9ef6117ee60c59ac2b298195c874e097 100644 (file)
@@ -6,35 +6,87 @@ package x509
 
 import (
        "bytes"
+       "crypto/sha256"
        "encoding/pem"
        "errors"
        "runtime"
 )
 
+type sum224 [sha256.Size224]byte
+
 // CertPool is a set of certificates.
 type CertPool struct {
-       byName map[string][]int
-       certs  []*Certificate
+       byName map[string][]int // cert.RawSubject => index into lazyCerts
+
+       // lazyCerts contains funcs that return a certificate,
+       // lazily parsing/decompressing it as needed.
+       lazyCerts []lazyCert
+
+       // haveSum maps from sum224(cert.Raw) to true. It's used only
+       // for AddCert duplicate detection, to avoid CertPool.contains
+       // calls in the AddCert path (because the contains method can
+       // call getCert and otherwise negate savings from lazy getCert
+       // funcs).
+       haveSum map[sum224]bool
+}
+
+// lazyCert is minimal metadata about a Cert and a func to retrieve it
+// in its normal expanded *Certificate form.
+type lazyCert struct {
+       // rawSubject is the Certificate.RawSubject value.
+       // It's the same as the CertPool.byName key, but in []byte
+       // form to make CertPool.Subjects (as used by crypto/tls) do
+       // fewer allocations.
+       rawSubject []byte
+
+       // getCert returns the certificate.
+       //
+       // It is not meant to do network operations or anything else
+       // where a failure is likely; the func is meant to lazily
+       // parse/decompress data that is already known to be good. The
+       // error in the signature primarily is meant for use in the
+       // case where a cert file existed on local disk when the program
+       // started up is deleted later before it's read.
+       getCert func() (*Certificate, error)
 }
 
 // NewCertPool returns a new, empty CertPool.
 func NewCertPool() *CertPool {
        return &CertPool{
-               byName: make(map[string][]int),
+               byName:  make(map[string][]int),
+               haveSum: make(map[sum224]bool),
        }
 }
 
+// len returns the number of certs in the set.
+// A nil set is a valid empty set.
+func (s *CertPool) len() int {
+       if s == nil {
+               return 0
+       }
+       return len(s.lazyCerts)
+}
+
+// cert returns cert index n in s.
+func (s *CertPool) cert(n int) (*Certificate, error) {
+       return s.lazyCerts[n].getCert()
+}
+
 func (s *CertPool) copy() *CertPool {
        p := &CertPool{
-               byName: make(map[string][]int, len(s.byName)),
-               certs:  make([]*Certificate, len(s.certs)),
+               byName:    make(map[string][]int, len(s.byName)),
+               lazyCerts: make([]lazyCert, len(s.lazyCerts)),
+               haveSum:   make(map[sum224]bool, len(s.haveSum)),
        }
        for k, v := range s.byName {
                indexes := make([]int, len(v))
                copy(indexes, v)
                p.byName[k] = indexes
        }
-       copy(p.certs, s.certs)
+       for k := range s.haveSum {
+               p.haveSum[k] = true
+       }
+       copy(p.lazyCerts, s.lazyCerts)
        return p
 }
 
@@ -64,7 +116,7 @@ func SystemCertPool() (*CertPool, error) {
 
 // findPotentialParents returns the indexes of certificates in s which might
 // have signed cert.
-func (s *CertPool) findPotentialParents(cert *Certificate) []int {
+func (s *CertPool) findPotentialParents(cert *Certificate) []*Certificate {
        if s == nil {
                return nil
        }
@@ -75,18 +127,21 @@ func (s *CertPool) findPotentialParents(cert *Certificate) []int {
        //   AKID and SKID match
        //   AKID present, SKID missing / AKID missing, SKID present
        //   AKID and SKID don't match
-       var matchingKeyID, oneKeyID, mismatchKeyID []int
+       var matchingKeyID, oneKeyID, mismatchKeyID []*Certificate
        for _, c := range s.byName[string(cert.RawIssuer)] {
-               candidate := s.certs[c]
+               candidate, err := s.cert(c)
+               if err != nil {
+                       continue
+               }
                kidMatch := bytes.Equal(candidate.SubjectKeyId, cert.AuthorityKeyId)
                switch {
                case kidMatch:
-                       matchingKeyID = append(matchingKeyID, c)
+                       matchingKeyID = append(matchingKeyID, candidate)
                case (len(candidate.SubjectKeyId) == 0 && len(cert.AuthorityKeyId) > 0) ||
                        (len(candidate.SubjectKeyId) > 0 && len(cert.AuthorityKeyId) == 0):
-                       oneKeyID = append(oneKeyID, c)
+                       oneKeyID = append(oneKeyID, candidate)
                default:
-                       mismatchKeyID = append(mismatchKeyID, c)
+                       mismatchKeyID = append(mismatchKeyID, candidate)
                }
        }
 
@@ -94,11 +149,10 @@ func (s *CertPool) findPotentialParents(cert *Certificate) []int {
        if found == 0 {
                return nil
        }
-       candidates := make([]int, 0, found)
+       candidates := make([]*Certificate, 0, found)
        candidates = append(candidates, matchingKeyID...)
        candidates = append(candidates, oneKeyID...)
        candidates = append(candidates, mismatchKeyID...)
-
        return candidates
 }
 
@@ -106,10 +160,13 @@ func (s *CertPool) contains(cert *Certificate) bool {
        if s == nil {
                return false
        }
-
        candidates := s.byName[string(cert.RawSubject)]
-       for _, c := range candidates {
-               if s.certs[c].Equal(cert) {
+       for _, i := range candidates {
+               c, err := s.cert(i)
+               if err != nil {
+                       return false
+               }
+               if c.Equal(cert) {
                        return true
                }
        }
@@ -122,17 +179,32 @@ func (s *CertPool) AddCert(cert *Certificate) {
        if cert == nil {
                panic("adding nil Certificate to CertPool")
        }
+       s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) {
+               return cert, nil
+       })
+}
+
+// addCertFunc adds metadata about a certificate to a pool, along with
+// a func to fetch that certificate later when needed.
+//
+// The rawSubject is Certificate.RawSubject and must be non-empty.
+// The getCert func may be called 0 or more times.
+func (s *CertPool) addCertFunc(rawSum224 sum224, rawSubject string, getCert func() (*Certificate, error)) {
+       if getCert == nil {
+               panic("getCert can't be nil")
+       }
 
        // Check that the certificate isn't being added twice.
-       if s.contains(cert) {
+       if s.haveSum[rawSum224] {
                return
        }
 
-       n := len(s.certs)
-       s.certs = append(s.certs, cert)
-
-       name := string(cert.RawSubject)
-       s.byName[name] = append(s.byName[name], n)
+       s.haveSum[rawSum224] = true
+       s.lazyCerts = append(s.lazyCerts, lazyCert{
+               rawSubject: []byte(rawSubject),
+               getCert:    getCert,
+       })
+       s.byName[rawSubject] = append(s.byName[rawSubject], len(s.lazyCerts)-1)
 }
 
 // AppendCertsFromPEM attempts to parse a series of PEM encoded certificates.
@@ -167,9 +239,9 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
 // Subjects returns a list of the DER-encoded subjects of
 // all of the certificates in the pool.
 func (s *CertPool) Subjects() [][]byte {
-       res := make([][]byte, len(s.certs))
-       for i, c := range s.certs {
-               res[i] = c.RawSubject
+       res := make([][]byte, s.len())
+       for i, lc := range s.lazyCerts {
+               res[i] = lc.rawSubject
        }
        return res
 }
index 5469e28de24aa18493fc9800dea85e4ea5ad5e7a..34055d07b538e65b0f5d0192ef1f2a2c5b3abc53 100644 (file)
@@ -1941,7 +1941,7 @@ func TestConstraintCases(t *testing.T) {
                // Skip tests with CommonName set because OpenSSL will try to match it
                // against name constraints, while we ignore it when it's not hostname-looking.
                if !test.noOpenSSL && testNameConstraintsAgainstOpenSSL && test.leaf.cn == "" {
-                       output, err := testChainAgainstOpenSSL(leafCert, intermediatePool, rootPool)
+                       output, err := testChainAgainstOpenSSL(t, leafCert, intermediatePool, rootPool)
                        if err == nil && len(test.expectedError) > 0 {
                                t.Errorf("#%d: unexpectedly succeeded against OpenSSL", i)
                                if debugOpenSSLFailure {
@@ -1993,7 +1993,7 @@ func TestConstraintCases(t *testing.T) {
                                pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
                                return buf.String()
                        }
-                       t.Errorf("#%d: root:\n%s", i, certAsPEM(rootPool.certs[0]))
+                       t.Errorf("#%d: root:\n%s", i, certAsPEM(rootPool.mustCert(t, 0)))
                        t.Errorf("#%d: leaf:\n%s", i, certAsPEM(leafCert))
                }
 
@@ -2019,10 +2019,10 @@ func writePEMsToTempFile(certs []*Certificate) *os.File {
        return file
 }
 
-func testChainAgainstOpenSSL(leaf *Certificate, intermediates, roots *CertPool) (string, error) {
+func testChainAgainstOpenSSL(t *testing.T, leaf *Certificate, intermediates, roots *CertPool) (string, error) {
        args := []string{"verify", "-no_check_time"}
 
-       rootsFile := writePEMsToTempFile(roots.certs)
+       rootsFile := writePEMsToTempFile(allCerts(t, roots))
        if debugOpenSSLFailure {
                println("roots file:", rootsFile.Name())
        } else {
@@ -2030,8 +2030,8 @@ func testChainAgainstOpenSSL(leaf *Certificate, intermediates, roots *CertPool)
        }
        args = append(args, "-CAfile", rootsFile.Name())
 
-       if len(intermediates.certs) > 0 {
-               intermediatesFile := writePEMsToTempFile(intermediates.certs)
+       if intermediates.len() > 0 {
+               intermediatesFile := writePEMsToTempFile(allCerts(t, intermediates))
                if debugOpenSSLFailure {
                        println("intermediates file:", intermediatesFile.Name())
                } else {
index 15c72cc0c837bf9a3a36f0ea757c11d301b0bb39..825e8d4812b9cf1f0f0ea182e9b6d3619c7afd54 100644 (file)
@@ -313,7 +313,11 @@ func _loadSystemRootsWithCgo() (*CertPool, error) {
        untrustedRoots.AppendCertsFromPEM(buf)
 
        trustedRoots := NewCertPool()
-       for _, c := range roots.certs {
+       for _, lc := range roots.lazyCerts {
+               c, err := lc.getCert()
+               if err != nil {
+                       return nil, err
+               }
                if !untrustedRoots.contains(c) {
                        trustedRoots.AddCert(c)
                }
index 2c773b91203bd055c312b28cb9ed8684f7be8cbe..69f181c2d482198dbc6ccb95878e31b1e7d64e98 100644 (file)
@@ -24,7 +24,7 @@ func TestSystemRoots(t *testing.T) {
 
        // There are 174 system roots on Catalina, and 163 on iOS right now, require
        // at least 100 to make sure this is not completely broken.
-       if want, have := 100, len(sysRoots.certs); have < want {
+       if want, have := 100, sysRoots.len(); have < want {
                t.Errorf("want at least %d system roots, have %d", want, have)
        }
 
@@ -43,11 +43,14 @@ func TestSystemRoots(t *testing.T) {
        t.Logf("loadSystemRootsWithCgo: %v", cgoSysRootsDuration)
 
        // Check that the two cert pools are the same.
-       sysPool := make(map[string]*Certificate, len(sysRoots.certs))
-       for _, c := range sysRoots.certs {
+       sysPool := make(map[string]*Certificate, sysRoots.len())
+       for i := 0; i < sysRoots.len(); i++ {
+               c := sysRoots.mustCert(t, i)
                sysPool[string(c.Raw)] = c
        }
-       for _, c := range cgoRoots.certs {
+       for i := 0; i < cgoRoots.len(); i++ {
+               c := cgoRoots.mustCert(t, i)
+
                if _, ok := sysPool[string(c.Raw)]; ok {
                        delete(sysPool, string(c.Raw))
                } else {
index ae72f025c30f6bed1fe0ac8033cb2cc48be35bb5..1090b69f839bebefdf399eaab22bd5e3b1c01f5e 100644 (file)
@@ -75,7 +75,7 @@ func loadSystemRoots() (*CertPool, error) {
                }
        }
 
-       if len(roots.certs) > 0 || firstErr == nil {
+       if roots.len() > 0 || firstErr == nil {
                return roots, nil
        }
 
index 5a8015429c0091e80a5c2bf10715319aa1686afe..b2e832ff3685ff979af99513a3e61826c896b96f 100644 (file)
@@ -113,15 +113,15 @@ func TestEnvVars(t *testing.T) {
 
                        // Verify that the returned certs match, otherwise report where the mismatch is.
                        for i, cn := range tc.cns {
-                               if i >= len(r.certs) {
+                               if i >= r.len() {
                                        t.Errorf("missing cert %v @ %v", cn, i)
-                               } else if r.certs[i].Subject.CommonName != cn {
-                                       fmt.Printf("%#v\n", r.certs[0].Subject)
-                                       t.Errorf("unexpected cert common name %q, want %q", r.certs[i].Subject.CommonName, cn)
+                               } else if r.mustCert(t, i).Subject.CommonName != cn {
+                                       fmt.Printf("%#v\n", r.mustCert(t, 0).Subject)
+                                       t.Errorf("unexpected cert common name %q, want %q", r.mustCert(t, i).Subject.CommonName, cn)
                                }
                        }
-                       if len(r.certs) > len(tc.cns) {
-                               t.Errorf("got %v certs, which is more than %v wanted", len(r.certs), len(tc.cns))
+                       if r.len() > len(tc.cns) {
+                               t.Errorf("got %v certs, which is more than %v wanted", r.len(), len(tc.cns))
                        }
                })
        }
@@ -197,7 +197,8 @@ func TestLoadSystemCertsLoadColonSeparatedDirs(t *testing.T) {
        strCertPool := func(p *CertPool) string {
                return string(bytes.Join(p.Subjects(), []byte("\n")))
        }
-       if !reflect.DeepEqual(gotPool, wantPool) {
+
+       if !certPoolEqual(gotPool, wantPool) {
                g, w := strCertPool(gotPool), strCertPool(wantPool)
                t.Fatalf("Mismatched certPools\nGot:\n%s\n\nWant:\n%s", g, w)
        }
index 1e0f3acb6700e66d9167a2e056983a4335e23abf..22e5a9382bef9087076d79b96e7c962c9bedd1b4 100644 (file)
@@ -38,7 +38,11 @@ func createStoreContext(leaf *Certificate, opts *VerifyOptions) (*syscall.CertCo
        }
 
        if opts.Intermediates != nil {
-               for _, intermediate := range opts.Intermediates.certs {
+               for i := 0; i < opts.Intermediates.len(); i++ {
+                       intermediate, err := opts.Intermediates.cert(i)
+                       if err != nil {
+                               return nil, err
+                       }
                        ctx, err := syscall.CertCreateCertificateContext(syscall.X509_ASN_ENCODING|syscall.PKCS_7_ASN_ENCODING, &intermediate.Raw[0], uint32(len(intermediate.Raw)))
                        if err != nil {
                                return nil, err
index 5fdd4cb9fe1bbbecf86d43dd42076060ed20e818..46afb2698a919ddb64131ee9d5df0958f2063309 100644 (file)
@@ -761,11 +761,13 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
        if len(c.Raw) == 0 {
                return nil, errNotParsed
        }
-       if opts.Intermediates != nil {
-               for _, intermediate := range opts.Intermediates.certs {
-                       if len(intermediate.Raw) == 0 {
-                               return nil, errNotParsed
-                       }
+       for i := 0; i < opts.Intermediates.len(); i++ {
+               c, err := opts.Intermediates.cert(i)
+               if err != nil {
+                       return nil, fmt.Errorf("crypto/x509: error fetching intermediate: %w", err)
+               }
+               if len(c.Raw) == 0 {
+                       return nil, errNotParsed
                }
        }
 
@@ -891,11 +893,11 @@ func (c *Certificate) buildChains(cache map[*Certificate][][]*Certificate, curre
                }
        }
 
-       for _, rootNum := range opts.Roots.findPotentialParents(c) {
-               considerCandidate(rootCertificate, opts.Roots.certs[rootNum])
+       for _, root := range opts.Roots.findPotentialParents(c) {
+               considerCandidate(rootCertificate, root)
        }
-       for _, intermediateNum := range opts.Intermediates.findPotentialParents(c) {
-               considerCandidate(intermediateCertificate, opts.Intermediates.certs[intermediateNum])
+       for _, intermediate := range opts.Intermediates.findPotentialParents(c) {
+               considerCandidate(intermediateCertificate, intermediate)
        }
 
        if len(chains) > 0 {
index 47d78cf02afa5561dd740edcc6a1ab3b4e3294cc..1ba31aeff3225f974cd7603be1bc000cb8cbbd8c 100644 (file)
@@ -1960,7 +1960,7 @@ func TestSystemCertPool(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
-       if !reflect.DeepEqual(a, b) {
+       if !certPoolEqual(a, b) {
                t.Fatal("two calls to SystemCertPool had different results")
        }
        if ok := b.AppendCertsFromPEM([]byte(`
@@ -2912,3 +2912,53 @@ func TestCreateCertificateMD5(t *testing.T) {
                t.Fatalf("CreateCertificate failed when SignatureAlgorithm = MD5WithRSA: %s", err)
        }
 }
+
+func (s *CertPool) mustCert(t *testing.T, n int) *Certificate {
+       c, err := s.lazyCerts[n].getCert()
+       if err != nil {
+               t.Fatalf("failed to load cert %d: %v", n, err)
+       }
+       return c
+}
+
+func allCerts(t *testing.T, p *CertPool) []*Certificate {
+       all := make([]*Certificate, p.len())
+       for i := range all {
+               all[i] = p.mustCert(t, i)
+       }
+       return all
+}
+
+// certPoolEqual reports whether a and b are equal, except for the
+// function pointers.
+func certPoolEqual(a, b *CertPool) bool {
+       if (a != nil) != (b != nil) {
+               return false
+       }
+       if a == nil {
+               return true
+       }
+       if !reflect.DeepEqual(a.byName, b.byName) ||
+               len(a.lazyCerts) != len(b.lazyCerts) {
+               return false
+       }
+       for i := range a.lazyCerts {
+               la, lb := a.lazyCerts[i], b.lazyCerts[i]
+               if !bytes.Equal(la.rawSubject, lb.rawSubject) {
+                       return false
+               }
+               ca, err := la.getCert()
+               if err != nil {
+                       panic(err)
+               }
+               cb, err := la.getCert()
+               if err != nil {
+                       panic(err)
+               }
+               if !ca.Equal(cb) {
+                       return false
+               }
+       }
+
+       return true
+}