]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/x509: support nil pools in CertPool.Equal
authorRoland Shoemaker <roland@golang.org>
Wed, 13 Apr 2022 15:58:01 +0000 (08:58 -0700)
committerGopher Robot <gobot@golang.org>
Wed, 13 Apr 2022 18:06:43 +0000 (18:06 +0000)
Otherwise we panic if either pool is nil.

Change-Id: I8598e3c0f3a5294135f1c330e319128d552ebb67
Reviewed-on: https://go-review.googlesource.com/c/go/+/399161
Reviewed-by: Damien Neil <dneil@google.com>
Run-TryBot: Roland Shoemaker <roland@golang.org>
Auto-Submit: Roland Shoemaker <roland@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>

src/crypto/x509/cert_pool.go
src/crypto/x509/cert_pool_test.go

index ae43c84424bcfd77a2b546e1f71bac586e2256aa..266d1ea04a1478736f881b93bf4d90e30c71223c 100644 (file)
@@ -252,6 +252,9 @@ func (s *CertPool) Subjects() [][]byte {
 
 // Equal reports whether s and other are equal.
 func (s *CertPool) Equal(other *CertPool) bool {
+       if s == nil || other == nil {
+               return s == other
+       }
        if s.systemPool != other.systemPool || len(s.haveSum) != len(other.haveSum) {
                return false
        }
index d1ec9aaefd8c6cb4251438d7140445034674e8fc..a12beda83d353984abdc6e9a45a3206f0d2d212a 100644 (file)
@@ -7,52 +7,102 @@ package x509
 import "testing"
 
 func TestCertPoolEqual(t *testing.T) {
-       a, b := NewCertPool(), NewCertPool()
-       if !a.Equal(b) {
-               t.Error("two empty pools not equal")
-       }
-
        tc := &Certificate{Raw: []byte{1, 2, 3}, RawSubject: []byte{2}}
-       a.AddCert(tc)
-       if a.Equal(b) {
-               t.Error("empty pool equals non-empty pool")
-       }
-
-       b.AddCert(tc)
-       if !a.Equal(b) {
-               t.Error("two non-empty pools not equal")
-       }
-
        otherTC := &Certificate{Raw: []byte{9, 8, 7}, RawSubject: []byte{8}}
-       a.AddCert(otherTC)
-       if a.Equal(b) {
-               t.Error("non-equal pools equal")
-       }
 
-       systemA, err := SystemCertPool()
+       emptyPool := NewCertPool()
+       nonSystemPopulated := NewCertPool()
+       nonSystemPopulated.AddCert(tc)
+       nonSystemPopulatedAlt := NewCertPool()
+       nonSystemPopulatedAlt.AddCert(otherTC)
+       emptySystem, err := SystemCertPool()
        if err != nil {
-               t.Fatalf("unable to load system cert pool: %s", err)
+               t.Fatal(err)
        }
-       systemB, err := SystemCertPool()
+       populatedSystem, err := SystemCertPool()
        if err != nil {
-               t.Fatalf("unable to load system cert pool: %s", err)
-       }
-       if !systemA.Equal(systemB) {
-               t.Error("two empty system pools not equal")
+               t.Fatal(err)
        }
-
-       systemA.AddCert(tc)
-       if systemA.Equal(systemB) {
-               t.Error("empty system pool equals non-empty system pool")
+       populatedSystem.AddCert(tc)
+       populatedSystemAlt, err := SystemCertPool()
+       if err != nil {
+               t.Fatal(err)
        }
-
-       systemB.AddCert(tc)
-       if !systemA.Equal(systemB) {
-               t.Error("two non-empty system pools not equal")
+       populatedSystemAlt.AddCert(otherTC)
+       tests := []struct {
+               name  string
+               a     *CertPool
+               b     *CertPool
+               equal bool
+       }{
+               {
+                       name:  "two empty pools",
+                       a:     emptyPool,
+                       b:     emptyPool,
+                       equal: true,
+               },
+               {
+                       name:  "one empty pool, one populated pool",
+                       a:     emptyPool,
+                       b:     nonSystemPopulated,
+                       equal: false,
+               },
+               {
+                       name:  "two populated pools",
+                       a:     nonSystemPopulated,
+                       b:     nonSystemPopulated,
+                       equal: true,
+               },
+               {
+                       name:  "two populated pools, different content",
+                       a:     nonSystemPopulated,
+                       b:     nonSystemPopulatedAlt,
+                       equal: false,
+               },
+               {
+                       name:  "two empty system pools",
+                       a:     emptySystem,
+                       b:     emptySystem,
+                       equal: true,
+               },
+               {
+                       name:  "one empty system pool, one populated system pool",
+                       a:     emptySystem,
+                       b:     populatedSystem,
+                       equal: false,
+               },
+               {
+                       name:  "two populated system pools",
+                       a:     populatedSystem,
+                       b:     populatedSystem,
+                       equal: true,
+               },
+               {
+                       name:  "two populated pools, different content",
+                       a:     populatedSystem,
+                       b:     populatedSystemAlt,
+                       equal: false,
+               },
+               {
+                       name:  "two nil pools",
+                       a:     nil,
+                       b:     nil,
+                       equal: true,
+               },
+               {
+                       name:  "one nil pool, one empty pool",
+                       a:     nil,
+                       b:     emptyPool,
+                       equal: false,
+               },
        }
 
-       systemA.AddCert(otherTC)
-       if systemA.Equal(systemB) {
-               t.Error("non-equal system pools equal")
+       for _, tc := range tests {
+               t.Run(tc.name, func(t *testing.T) {
+                       equal := tc.a.Equal(tc.b)
+                       if equal != tc.equal {
+                               t.Errorf("Unexpected Equal result: got %t, want %t", equal, tc.equal)
+                       }
+               })
        }
 }