]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/x509: return additional chains from Verify on Windows
authorKoen <me@koenspanjer.com>
Tue, 13 Oct 2020 23:05:43 +0000 (01:05 +0200)
committerFilippo Valsorda <filippo@golang.org>
Mon, 9 Nov 2020 18:32:41 +0000 (18:32 +0000)
Previously windows only returned the certificate-chain with the highest quality.
This change makes it so chains with a potentially lower quality
originating from other root certificates are also returned by verify.

Tests in verify_test flagged with systemLax are now allowed to pass if the system returns additional chains

Fixes #40604

Change-Id: I66edc233219f581039d47a15f2200ff627154691
Reviewed-on: https://go-review.googlesource.com/c/go/+/257257
Reviewed-by: Tobias Klauser <tobias.klauser@gmail.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Trust: Tobias Klauser <tobias.klauser@gmail.com>
Run-TryBot: Tobias Klauser <tobias.klauser@gmail.com>
TryBot-Result: Go Bot <gobot@golang.org>

src/crypto/x509/root_windows.go
src/crypto/x509/verify_test.go

index 22e5a9382bef9087076d79b96e7c962c9bedd1b4..1e9be80b7dd67d35c73cfb01ec743a23c879da3a 100644 (file)
@@ -155,6 +155,44 @@ func init() {
        }
 }
 
+func verifyChain(c *Certificate, chainCtx *syscall.CertChainContext, opts *VerifyOptions) (chain []*Certificate, err error) {
+       err = checkChainTrustStatus(c, chainCtx)
+       if err != nil {
+               return nil, err
+       }
+
+       if opts != nil && len(opts.DNSName) > 0 {
+               err = checkChainSSLServerPolicy(c, chainCtx, opts)
+               if err != nil {
+                       return nil, err
+               }
+       }
+
+       chain, err = extractSimpleChain(chainCtx.Chains, int(chainCtx.ChainCount))
+       if err != nil {
+               return nil, err
+       }
+       if len(chain) == 0 {
+               return nil, errors.New("x509: internal error: system verifier returned an empty chain")
+       }
+
+       // Mitigate CVE-2020-0601, where the Windows system verifier might be
+       // tricked into using custom curve parameters for a trusted root, by
+       // double-checking all ECDSA signatures. If the system was tricked into
+       // using spoofed parameters, the signature will be invalid for the correct
+       // ones we parsed. (We don't support custom curves ourselves.)
+       for i, parent := range chain[1:] {
+               if parent.PublicKeyAlgorithm != ECDSA {
+                       continue
+               }
+               if err := parent.CheckSignature(chain[i].SignatureAlgorithm,
+                       chain[i].RawTBSCertificate, chain[i].Signature); err != nil {
+                       return nil, err
+               }
+       }
+       return chain, nil
+}
+
 // systemVerify is like Verify, except that it uses CryptoAPI calls
 // to build certificate chains and verify them.
 func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) {
@@ -202,67 +240,41 @@ func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate
                verifyTime = &ft
        }
 
-       // CertGetCertificateChain will traverse Windows's root stores
-       // in an attempt to build a verified certificate chain. Once
-       // it has found a verified chain, it stops. MSDN docs on
-       // CERT_CHAIN_CONTEXT:
-       //
-       //   When a CERT_CHAIN_CONTEXT is built, the first simple chain
-       //   begins with an end certificate and ends with a self-signed
-       //   certificate. If that self-signed certificate is not a root
-       //   or otherwise trusted certificate, an attempt is made to
-       //   build a new chain. CTLs are used to create the new chain
-       //   beginning with the self-signed certificate from the original
-       //   chain as the end certificate of the new chain. This process
-       //   continues building additional simple chains until the first
-       //   self-signed certificate is a trusted certificate or until
-       //   an additional simple chain cannot be built.
-       //
-       // The result is that we'll only get a single trusted chain to
-       // return to our caller.
-       var chainCtx *syscall.CertChainContext
-       err = syscall.CertGetCertificateChain(syscall.Handle(0), storeCtx, verifyTime, storeCtx.Store, para, 0, 0, &chainCtx)
-       if err != nil {
-               return nil, err
-       }
-       defer syscall.CertFreeCertificateChain(chainCtx)
+       // The default is to return only the highest quality chain,
+       // setting this flag will add additional lower quality contexts.
+       // These are returned in the LowerQualityChains field.
+       const CERT_CHAIN_RETURN_LOWER_QUALITY_CONTEXTS = 0x00000080
 
-       err = checkChainTrustStatus(c, chainCtx)
+       // CertGetCertificateChain will traverse Windows's root stores in an attempt to build a verified certificate chain
+       var topCtx *syscall.CertChainContext
+       err = syscall.CertGetCertificateChain(syscall.Handle(0), storeCtx, verifyTime, storeCtx.Store, para, CERT_CHAIN_RETURN_LOWER_QUALITY_CONTEXTS, 0, &topCtx)
        if err != nil {
                return nil, err
        }
+       defer syscall.CertFreeCertificateChain(topCtx)
 
-       if opts != nil && len(opts.DNSName) > 0 {
-               err = checkChainSSLServerPolicy(c, chainCtx, opts)
-               if err != nil {
-                       return nil, err
-               }
+       chain, topErr := verifyChain(c, topCtx, opts)
+       if topErr == nil {
+               chains = append(chains, chain)
        }
 
-       chain, err := extractSimpleChain(chainCtx.Chains, int(chainCtx.ChainCount))
-       if err != nil {
-               return nil, err
-       }
-       if len(chain) < 1 {
-               return nil, errors.New("x509: internal error: system verifier returned an empty chain")
-       }
+       if lqCtxCount := topCtx.LowerQualityChainCount; lqCtxCount > 0 {
+               lqCtxs := (*[1 << 20]*syscall.CertChainContext)(unsafe.Pointer(topCtx.LowerQualityChains))[:lqCtxCount:lqCtxCount]
 
-       // Mitigate CVE-2020-0601, where the Windows system verifier might be
-       // tricked into using custom curve parameters for a trusted root, by
-       // double-checking all ECDSA signatures. If the system was tricked into
-       // using spoofed parameters, the signature will be invalid for the correct
-       // ones we parsed. (We don't support custom curves ourselves.)
-       for i, parent := range chain[1:] {
-               if parent.PublicKeyAlgorithm != ECDSA {
-                       continue
-               }
-               if err := parent.CheckSignature(chain[i].SignatureAlgorithm,
-                       chain[i].RawTBSCertificate, chain[i].Signature); err != nil {
-                       return nil, err
+               for _, ctx := range lqCtxs {
+                       chain, err := verifyChain(c, ctx, opts)
+                       if err == nil {
+                               chains = append(chains, chain)
+                       }
                }
        }
 
-       return [][]*Certificate{chain}, nil
+       if len(chains) == 0 {
+               // Return the error from the highest quality context.
+               return nil, topErr
+       }
+
+       return chains, nil
 }
 
 func loadSystemRoots() (*CertPool, error) {
index 9cc17c7b3d89b6832fbdfe8d50f3ce54e055d84a..8e0a7bef4755026d8188bfdffc0dad5702e40a3e 100644 (file)
@@ -550,34 +550,55 @@ func testVerify(t *testing.T, test verifyTest, useSystemRoots bool) {
                }
        }
 
-       if len(chains) != len(test.expectedChains) {
-               t.Errorf("wanted %d chains, got %d", len(test.expectedChains), len(chains))
+       doesMatch := func(expectedChain []string, chain []*Certificate) bool {
+               if len(chain) != len(expectedChain) {
+                       return false
+               }
+
+               for k, cert := range chain {
+                       if !strings.Contains(nameToKey(&cert.Subject), expectedChain[k]) {
+                               return false
+                       }
+               }
+               return true
        }
 
-       // We check that each returned chain matches a chain from
-       // expectedChains but an entry in expectedChains can't match
-       // two chains.
-       seenChains := make([]bool, len(chains))
-NextOutputChain:
-       for _, chain := range chains {
-       TryNextExpected:
-               for j, expectedChain := range test.expectedChains {
-                       if seenChains[j] {
-                               continue
+       // Every expected chain should match 1 returned chain
+       for _, expectedChain := range test.expectedChains {
+               nChainMatched := 0
+               for _, chain := range chains {
+                       if doesMatch(expectedChain, chain) {
+                               nChainMatched++
+                       }
+               }
+
+               if nChainMatched != 1 {
+                       t.Errorf("Got %v matches instead of %v for expected chain %v", nChainMatched, 1, expectedChain)
+                       for _, chain := range chains {
+                               if doesMatch(expectedChain, chain) {
+                                       t.Errorf("\t matched %v", chainToDebugString(chain))
+                               }
                        }
-                       if len(chain) != len(expectedChain) {
-                               continue
+               }
+       }
+
+       // Every returned chain should match 1 expected chain (or <2 if testing against the system)
+       for _, chain := range chains {
+               nMatched := 0
+               for _, expectedChain := range test.expectedChains {
+                       if doesMatch(expectedChain, chain) {
+                               nMatched++
                        }
-                       for k, cert := range chain {
-                               if !strings.Contains(nameToKey(&cert.Subject), expectedChain[k]) {
-                                       continue TryNextExpected
+               }
+               // Allow additional unknown chains if systemLax is set
+               if nMatched == 0 && test.systemLax == false || nMatched > 1 {
+                       t.Errorf("Got %v matches for chain %v", nMatched, chainToDebugString(chain))
+                       for _, expectedChain := range test.expectedChains {
+                               if doesMatch(expectedChain, chain) {
+                                       t.Errorf("\t matched %v", expectedChain)
                                }
                        }
-                       // we matched
-                       seenChains[j] = true
-                       continue NextOutputChain
                }
-               t.Errorf("no expected chain matched %s", chainToDebugString(chain))
        }
 }