From d65a41329ee87f46f35719129d1d4b03d4a07cc8 Mon Sep 17 00:00:00 2001 From: Roland Shoemaker Date: Wed, 13 Apr 2022 08:58:01 -0700 Subject: [PATCH] crypto/x509: support nil pools in CertPool.Equal Otherwise we panic if either pool is nil. Change-Id: I8598e3c0f3a5294135f1c330e319128d552ebb67 Reviewed-on: https://go-review.googlesource.com/c/go/+/399161 Reviewed-by: Damien Neil Run-TryBot: Roland Shoemaker Auto-Submit: Roland Shoemaker Reviewed-by: Roland Shoemaker TryBot-Result: Gopher Robot --- src/crypto/x509/cert_pool.go | 3 + src/crypto/x509/cert_pool_test.go | 124 +++++++++++++++++++++--------- 2 files changed, 90 insertions(+), 37 deletions(-) diff --git a/src/crypto/x509/cert_pool.go b/src/crypto/x509/cert_pool.go index ae43c84424..266d1ea04a 100644 --- a/src/crypto/x509/cert_pool.go +++ b/src/crypto/x509/cert_pool.go @@ -252,6 +252,9 @@ func (s *CertPool) Subjects() [][]byte { // Equal reports whether s and other are equal. func (s *CertPool) Equal(other *CertPool) bool { + if s == nil || other == nil { + return s == other + } if s.systemPool != other.systemPool || len(s.haveSum) != len(other.haveSum) { return false } diff --git a/src/crypto/x509/cert_pool_test.go b/src/crypto/x509/cert_pool_test.go index d1ec9aaefd..a12beda83d 100644 --- a/src/crypto/x509/cert_pool_test.go +++ b/src/crypto/x509/cert_pool_test.go @@ -7,52 +7,102 @@ package x509 import "testing" func TestCertPoolEqual(t *testing.T) { - a, b := NewCertPool(), NewCertPool() - if !a.Equal(b) { - t.Error("two empty pools not equal") - } - tc := &Certificate{Raw: []byte{1, 2, 3}, RawSubject: []byte{2}} - a.AddCert(tc) - if a.Equal(b) { - t.Error("empty pool equals non-empty pool") - } - - b.AddCert(tc) - if !a.Equal(b) { - t.Error("two non-empty pools not equal") - } - otherTC := &Certificate{Raw: []byte{9, 8, 7}, RawSubject: []byte{8}} - a.AddCert(otherTC) - if a.Equal(b) { - t.Error("non-equal pools equal") - } - systemA, err := SystemCertPool() + emptyPool := NewCertPool() + nonSystemPopulated := NewCertPool() + nonSystemPopulated.AddCert(tc) + nonSystemPopulatedAlt := NewCertPool() + nonSystemPopulatedAlt.AddCert(otherTC) + emptySystem, err := SystemCertPool() if err != nil { - t.Fatalf("unable to load system cert pool: %s", err) + t.Fatal(err) } - systemB, err := SystemCertPool() + populatedSystem, err := SystemCertPool() if err != nil { - t.Fatalf("unable to load system cert pool: %s", err) - } - if !systemA.Equal(systemB) { - t.Error("two empty system pools not equal") + t.Fatal(err) } - - systemA.AddCert(tc) - if systemA.Equal(systemB) { - t.Error("empty system pool equals non-empty system pool") + populatedSystem.AddCert(tc) + populatedSystemAlt, err := SystemCertPool() + if err != nil { + t.Fatal(err) } - - systemB.AddCert(tc) - if !systemA.Equal(systemB) { - t.Error("two non-empty system pools not equal") + populatedSystemAlt.AddCert(otherTC) + tests := []struct { + name string + a *CertPool + b *CertPool + equal bool + }{ + { + name: "two empty pools", + a: emptyPool, + b: emptyPool, + equal: true, + }, + { + name: "one empty pool, one populated pool", + a: emptyPool, + b: nonSystemPopulated, + equal: false, + }, + { + name: "two populated pools", + a: nonSystemPopulated, + b: nonSystemPopulated, + equal: true, + }, + { + name: "two populated pools, different content", + a: nonSystemPopulated, + b: nonSystemPopulatedAlt, + equal: false, + }, + { + name: "two empty system pools", + a: emptySystem, + b: emptySystem, + equal: true, + }, + { + name: "one empty system pool, one populated system pool", + a: emptySystem, + b: populatedSystem, + equal: false, + }, + { + name: "two populated system pools", + a: populatedSystem, + b: populatedSystem, + equal: true, + }, + { + name: "two populated pools, different content", + a: populatedSystem, + b: populatedSystemAlt, + equal: false, + }, + { + name: "two nil pools", + a: nil, + b: nil, + equal: true, + }, + { + name: "one nil pool, one empty pool", + a: nil, + b: emptyPool, + equal: false, + }, } - systemA.AddCert(otherTC) - if systemA.Equal(systemB) { - t.Error("non-equal system pools equal") + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + equal := tc.a.Equal(tc.b) + if equal != tc.equal { + t.Errorf("Unexpected Equal result: got %t, want %t", equal, tc.equal) + } + }) } } -- 2.50.0