]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/rsa: move precomputation to crypto/internal/fips140/rsa
authorFilippo Valsorda <filippo@golang.org>
Fri, 29 Nov 2024 12:15:11 +0000 (13:15 +0100)
committerGopher Robot <gobot@golang.org>
Sat, 30 Nov 2024 01:49:31 +0000 (01:49 +0000)
We are severely limited by the crypto/rsa API in a few ways:

 - Precompute doesn't return an error, but is the only function allowed
   to modify a PrivateKey.

 - Clients presumably expect the PrecomputedValues big.Ints to be
   populated after Precompute.

 - MarshalPKCS1PrivateKey requires the precomputed values, and doesn't
   have an error return.

 - PrivateKeys with only N, e, and D have worked so far, so they might
   have to keep working.

To move precomputation to the FIPS module, we focus on the happy path of
a PrivateKey with two primes where Precompute is called before anything
else, which match ParsePKCS1PrivateKey and GenerateKey.

There is a significant slowdown in the Parse benchmark due to the
constant-time inversion of qInv. This will be addressed in a follow-up
CL that will use (and check) the value in the ASN.1.

Note that the prime product check now moved to checkPrivateKey is broken
(Π should start at 1 not 0) and fixed in CL 632478.

Updates #69799
For #69536

Change-Id: I95a8bc1244755c6d15d7c4eb179135a15608ddd6
Reviewed-on: https://go-review.googlesource.com/c/go/+/632476
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Russ Cox <rsc@golang.org>
doc/next/6-stdlib/99-minor/crypto/x509/69799.md [new file with mode: 0644]
src/crypto/internal/fips140/bigmod/nat.go
src/crypto/internal/fips140/rsa/rsa.go
src/crypto/rsa/rsa.go
src/crypto/x509/pkcs1.go
src/crypto/x509/pkcs8.go
src/crypto/x509/x509_test.go

diff --git a/doc/next/6-stdlib/99-minor/crypto/x509/69799.md b/doc/next/6-stdlib/99-minor/crypto/x509/69799.md
new file mode 100644 (file)
index 0000000..6eb000b
--- /dev/null
@@ -0,0 +1,3 @@
+[MarshalPKCS8PrivateKey] now returns an error instead of marshaling an invalid
+RSA key. ([MarshalPKCS1PrivateKey] doesn't have an error return, and its behavior
+when provided invalid keys continues to be undefined.)
index dd2cd3690bdb76ee7b2f92113a1cb8ce3b222d55..18e1203c241bb9abfb76c9b70720e524e02d5737 100644 (file)
@@ -123,7 +123,7 @@ func (x *Nat) set(y *Nat) *Nat {
 // Bytes returns x as a zero-extended big-endian byte slice. The size of the
 // slice will match the size of m.
 //
-// x must have the same size as m and it must be reduced modulo m.
+// x must have the same size as m and it must be less than or equal to m.
 func (x *Nat) Bytes(m *Modulus) []byte {
        i := m.Size()
        bytes := make([]byte, i)
@@ -202,17 +202,13 @@ func (x *Nat) setBytes(b []byte) error {
        return nil
 }
 
-// SetUint assigns x = y, and returns an error if y >= m.
+// SetUint assigns x = y.
 //
-// The output will be resized to the size of m and overwritten.
-func (x *Nat) SetUint(y uint, m *Modulus) (*Nat, error) {
-       x.resetFor(m)
-       // Modulus is never zero, so always at least one limb.
+// The output will be resized to a single limb and overwritten.
+func (x *Nat) SetUint(y uint) *Nat {
+       x.reset(1)
        x.limbs[0] = y
-       if x.cmpGeq(m.nat) == yes {
-               return nil, errors.New("input overflows the modulus")
-       }
-       return x, nil
+       return x
 }
 
 // Equal returns 1 if x == y, and 0 otherwise.
@@ -641,11 +637,12 @@ func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
 
 // SubOne computes x = x - 1 mod m.
 //
-// The length of x must be the same as the modulus. x must already be reduced
-// modulo m.
+// The length of x must be the same as the modulus.
 func (x *Nat) SubOne(m *Modulus) *Nat {
        one := NewNat().ExpandFor(m)
        one.limbs[0] = 1
+       // Sub asks for x to be reduced modulo m, while SubOne doesn't, but when
+       // y = 1, it works, and this is an internal use.
        return x.Sub(one, m)
 }
 
index 692cd3b1ad99dcb74e9a603d0b6d3998b4c787b9..59951f838b9147dadca68256d0a3e5ac0c4153ea 100644 (file)
@@ -5,6 +5,7 @@
 package rsa
 
 import (
+       "bytes"
        "crypto/internal/fips140"
        "crypto/internal/fips140/bigmod"
        "errors"
@@ -45,10 +46,7 @@ func (priv *PrivateKey) PublicKey() *PublicKey {
 //
 // All values are in big-endian byte slice format, and may have leading zeros
 // or be shorter if leading zeroes were trimmed.
-//
-// N, e, d, P, and Q are required. dP, dQ, and qInv can be nil and will be
-// precomputed if missing.
-func NewPrivateKey(N []byte, e int, d, P, Q, dP, dQ, qInv []byte) (*PrivateKey, error) {
+func NewPrivateKey(N []byte, e int, d, P, Q []byte) (*PrivateKey, error) {
        n, err := bigmod.NewModulus(N)
        if err != nil {
                return nil, err
@@ -65,23 +63,42 @@ func NewPrivateKey(N []byte, e int, d, P, Q, dP, dQ, qInv []byte) (*PrivateKey,
        if err != nil {
                return nil, err
        }
-       // TODO(filippo): implement CRT computation. For now, NewPrivateKey is
-       // always called with CRT values.
-       if dP == nil || dQ == nil || qInv == nil {
-               panic("crypto/internal/fips140/rsa: internal error: missing CRT parameters")
+       return newPrivateKey(n, e, dN, p, q)
+}
+
+func newPrivateKey(n *bigmod.Modulus, e int, d *bigmod.Nat, p, q *bigmod.Modulus) (*PrivateKey, error) {
+       pMinusOne := p.Nat().SubOne(p)
+       pMinusOneMod, err := bigmod.NewModulus(pMinusOne.Bytes(p))
+       if err != nil {
+               return nil, err
        }
-       qInvN, err := bigmod.NewNat().SetBytes(qInv, p)
+       dP := bigmod.NewNat().Mod(d, pMinusOneMod).Bytes(pMinusOneMod)
+
+       qMinusOne := q.Nat().SubOne(q)
+       qMinusOneMod, err := bigmod.NewModulus(qMinusOne.Bytes(q))
        if err != nil {
                return nil, err
        }
+       dQ := bigmod.NewNat().Mod(d, qMinusOneMod).Bytes(qMinusOneMod)
+
+       // Constant-time modular inversion with prime modulus by Fermat's Little
+       // Theorem: qInv = q⁻¹ mod p = q^(p-2) mod p.
+       if p.Nat().IsOdd() == 0 {
+               // [bigmod.Nat.Exp] requires an odd modulus.
+               return nil, errors.New("crypto/rsa: p is even")
+       }
+       pMinusTwo := p.Nat().SubOne(p).SubOne(p).Bytes(p)
+       qInv := bigmod.NewNat().Mod(q.Nat(), p)
+       qInv.Exp(qInv, pMinusTwo, p)
+
        pk := &PrivateKey{
                pub: PublicKey{
                        N: n, E: e,
                },
-               d: dN, p: p, q: q,
-               dP: dP, dQ: dQ, qInv: qInvN,
+               d: d, p: p, q: q,
+               dP: dP, dQ: dQ, qInv: qInv,
        }
-       if err := checkPublicKey(&pk.pub); err != nil {
+       if err := checkPrivateKey(pk); err != nil {
                return nil, err
        }
        return pk, nil
@@ -105,12 +122,77 @@ func NewPrivateKeyWithoutCRT(N []byte, e int, d []byte) (*PrivateKey, error) {
                },
                d: dN,
        }
-       if err := checkPublicKey(&pk.pub); err != nil {
+       if err := checkPrivateKey(pk); err != nil {
                return nil, err
        }
        return pk, nil
 }
 
+// Export returns the key parameters in big-endian byte slice format.
+//
+// P, Q, dP, dQ, and qInv may be nil if the key was created with
+// NewPrivateKeyWithoutCRT.
+func (priv *PrivateKey) Export() (N []byte, e int, d, P, Q, dP, dQ, qInv []byte) {
+       N = priv.pub.N.Nat().Bytes(priv.pub.N)
+       e = priv.pub.E
+       d = priv.d.Bytes(priv.pub.N)
+       if priv.dP == nil {
+               return
+       }
+       P = priv.p.Nat().Bytes(priv.p)
+       Q = priv.q.Nat().Bytes(priv.q)
+       dP = bytes.Clone(priv.dP)
+       dQ = bytes.Clone(priv.dQ)
+       qInv = priv.qInv.Bytes(priv.p)
+       return
+}
+
+func checkPrivateKey(priv *PrivateKey) error {
+       if err := checkPublicKey(&priv.pub); err != nil {
+               return err
+       }
+
+       if priv.dP == nil {
+               return nil
+       }
+
+       N := priv.pub.N
+       Π := bigmod.NewNat().ExpandFor(N)
+       for _, prime := range []*bigmod.Modulus{priv.p, priv.q} {
+               p := prime.Nat().ExpandFor(N)
+               if p.IsZero() == 1 || p.IsOne() == 1 {
+                       return errors.New("crypto/rsa: invalid prime")
+               }
+               Π.Mul(p, N)
+
+               // Check that de ≡ 1 mod p-1, for each prime.
+               // This implies that e is coprime to each p-1 as e has a multiplicative
+               // inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) =
+               // exponent(ℤ/nℤ). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1
+               // mod p. Thus a^de ≡ a mod n for all a coprime to n, as required.
+
+               pMinus1, err := bigmod.NewModulus(p.SubOne(N).Bytes(N))
+               if err != nil {
+                       return errors.New("crypto/rsa: invalid prime")
+               }
+
+               e := bigmod.NewNat().SetUint(uint(priv.pub.E)).ExpandFor(pMinus1)
+
+               de := bigmod.NewNat()
+               de.Mod(priv.d, pMinus1)
+               de.Mul(e, pMinus1)
+               if de.IsOne() != 1 {
+                       return errors.New("crypto/rsa: invalid exponents")
+               }
+       }
+       // Check that Πprimes == n.
+       if Π.IsZero() != 1 {
+               return errors.New("crypto/rsa: invalid modulus")
+       }
+
+       return nil
+}
+
 func checkPublicKey(pub *PublicKey) error {
        if pub.N == nil {
                return errors.New("crypto/rsa: missing public modulus")
index 9051f176f7a9f3679d53c737257459b45227ddef..8cca6a8cdd5d7f0d970996a9ba87f317da893a8c 100644 (file)
@@ -20,7 +20,8 @@
 // Decrypter and Signer interfaces from the crypto package.
 //
 // Operations involving private keys are implemented using constant-time
-// algorithms, except for [GenerateKey] and [PrivateKey.Precompute].
+// algorithms, except for [GenerateKey] and for some operations involving
+// deprecated multi-prime keys.
 //
 // # Minimum key size
 //
@@ -223,73 +224,22 @@ type CRTValue struct {
 
 // Validate performs basic sanity checks on the key.
 // It returns nil if the key is valid, or else an error describing a problem.
+//
+// It runs faster on valid keys if run after [Precompute].
 func (priv *PrivateKey) Validate() error {
-       pub := &priv.PublicKey
-       if pub.N == nil {
-               return errors.New("crypto/rsa: missing public modulus")
-       }
-       if pub.N.Bit(0) == 0 {
-               return errors.New("crypto/rsa: public modulus is even")
-       }
-       if pub.E < 2 {
-               return errors.New("crypto/rsa: public exponent is less than 2")
-       }
-       if pub.E&1 == 0 {
-               return errors.New("crypto/rsa: public exponent is even")
-       }
-       if pub.E > 1<<31-1 {
-               return errors.New("crypto/rsa: public exponent too large")
-       }
-
-       N, err := bigmod.NewModulus(pub.N.Bytes())
-       if err != nil {
-               return fmt.Errorf("crypto/rsa: invalid public modulus: %v", err)
-       }
-       d, err := bigmod.NewNat().SetBytes(priv.D.Bytes(), N)
-       if err != nil {
-               return fmt.Errorf("crypto/rsa: invalid private exponent: %v", err)
-       }
-
-       Π := bigmod.NewNat().ExpandFor(N)
-       for _, prime := range priv.Primes {
-               p, err := bigmod.NewNat().SetBytes(prime.Bytes(), N)
-               if err != nil {
-                       return fmt.Errorf("crypto/rsa: invalid prime: %v", err)
-               }
-               if p.IsZero() == 1 || p.IsOne() == 1 {
-                       return errors.New("crypto/rsa: invalid prime")
-               }
-               Π.Mul(p, N)
-
-               // Check that de ≡ 1 mod p-1, for each prime.
-               // This implies that e is coprime to each p-1 as e has a multiplicative
-               // inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) =
-               // exponent(ℤ/nℤ). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1
-               // mod p. Thus a^de ≡ a mod n for all a coprime to n, as required.
-
-               pMinus1, err := bigmod.NewModulus(p.SubOne(N).Bytes(N))
-               if err != nil {
-                       return fmt.Errorf("crypto/rsa: internal error: %v", err)
-               }
-
-               e, err := bigmod.NewNat().SetUint(uint(pub.E), pMinus1)
-               if err != nil {
-                       return fmt.Errorf("crypto/rsa: invalid public exponent: %v", err)
-               }
-
-               de := bigmod.NewNat()
-               de.Mod(d, pMinus1)
-               de.Mul(e, pMinus1)
-               if de.IsOne() != 1 {
-                       return errors.New("crypto/rsa: invalid exponents")
-               }
+       // We can operate on keys based on d alone, but it isn't possible to encode
+       // with [crypto/x509.MarshalPKCS1PrivateKey], which unfortunately doesn't
+       // return an error.
+       if len(priv.Primes) < 2 {
+               return errors.New("crypto/rsa: missing primes")
        }
-       // Check that Πprimes == n.
-       if Π.IsZero() != 1 {
-               return errors.New("crypto/rsa: invalid modulus")
+       // If Precomputed.fips is set, then the key has been validated by
+       // [rsa.NewPrivateKey] or [rsa.NewPrivateKeyWithoutCRT].
+       if priv.Precomputed.fips != nil {
+               return nil
        }
-
-       return nil
+       _, err := priv.precompute()
+       return err
 }
 
 // rsa1024min is a GODEBUG that re-enables weak RSA keys if set to "0".
@@ -496,53 +446,108 @@ var ErrDecryption = errors.New("crypto/rsa: decryption error")
 var ErrVerification = errors.New("crypto/rsa: verification error")
 
 // Precompute performs some calculations that speed up private key operations
-// in the future.
+// in the future. It is safe to run on non-validated private keys.
 func (priv *PrivateKey) Precompute() {
        if priv.Precomputed.fips != nil {
                return
        }
 
-       if len(priv.Primes) < 2 {
-               priv.Precomputed.fips, _ = rsa.NewPrivateKeyWithoutCRT(
-                       priv.N.Bytes(), priv.E, priv.D.Bytes())
+       precomputed, err := priv.precompute()
+       if err != nil {
+               // We don't have a way to report errors, so just leave the key
+               // unmodified. Validate will re-run precompute.
                return
        }
+       priv.Precomputed = precomputed
+}
+
+func (priv *PrivateKey) precompute() (PrecomputedValues, error) {
+       var precomputed PrecomputedValues
+
+       if priv.N == nil {
+               return precomputed, errors.New("crypto/rsa: missing public modulus")
+       }
+       if priv.D == nil {
+               return precomputed, errors.New("crypto/rsa: missing private exponent")
+       }
+       if len(priv.Primes) != 2 {
+               return priv.precomputeLegacy()
+       }
+       if priv.Primes[0] == nil {
+               return precomputed, errors.New("crypto/rsa: prime P is nil")
+       }
+       if priv.Primes[1] == nil {
+               return precomputed, errors.New("crypto/rsa: prime Q is nil")
+       }
+
+       k, err := rsa.NewPrivateKey(priv.N.Bytes(), priv.E, priv.D.Bytes(),
+               priv.Primes[0].Bytes(), priv.Primes[1].Bytes())
+       if err != nil {
+               return precomputed, err
+       }
 
-       priv.Precomputed.Dp = new(big.Int).Sub(priv.Primes[0], bigOne)
-       priv.Precomputed.Dp.Mod(priv.D, priv.Precomputed.Dp)
+       precomputed.fips = k
+       _, _, _, _, _, dP, dQ, qInv := k.Export()
+       precomputed.Dp = new(big.Int).SetBytes(dP)
+       precomputed.Dq = new(big.Int).SetBytes(dQ)
+       precomputed.Qinv = new(big.Int).SetBytes(qInv)
+       precomputed.CRTValues = make([]CRTValue, 0)
+       return precomputed, nil
+}
 
-       priv.Precomputed.Dq = new(big.Int).Sub(priv.Primes[1], bigOne)
-       priv.Precomputed.Dq.Mod(priv.D, priv.Precomputed.Dq)
+func (priv *PrivateKey) precomputeLegacy() (PrecomputedValues, error) {
+       var precomputed PrecomputedValues
 
-       priv.Precomputed.Qinv = new(big.Int).ModInverse(priv.Primes[1], priv.Primes[0])
+       k, err := rsa.NewPrivateKeyWithoutCRT(priv.N.Bytes(), priv.E, priv.D.Bytes())
+       if err != nil {
+               return precomputed, err
+       }
+       precomputed.fips = k
+
+       if len(priv.Primes) < 2 {
+               return precomputed, nil
+       }
+
+       // Ensure the Mod and ModInverse calls below don't panic.
+       for _, prime := range priv.Primes {
+               if prime == nil {
+                       return precomputed, errors.New("crypto/rsa: prime factor is nil")
+               }
+               if prime.Cmp(bigOne) <= 0 {
+                       return precomputed, errors.New("crypto/rsa: prime factor is <= 1")
+               }
+       }
+
+       precomputed.Dp = new(big.Int).Sub(priv.Primes[0], bigOne)
+       precomputed.Dp.Mod(priv.D, precomputed.Dp)
+
+       precomputed.Dq = new(big.Int).Sub(priv.Primes[1], bigOne)
+       precomputed.Dq.Mod(priv.D, precomputed.Dq)
+
+       precomputed.Qinv = new(big.Int).ModInverse(priv.Primes[1], priv.Primes[0])
+       if precomputed.Qinv == nil {
+               return precomputed, errors.New("crypto/rsa: prime factors are not relatively prime")
+       }
 
        r := new(big.Int).Mul(priv.Primes[0], priv.Primes[1])
-       priv.Precomputed.CRTValues = make([]CRTValue, len(priv.Primes)-2)
+       precomputed.CRTValues = make([]CRTValue, len(priv.Primes)-2)
        for i := 2; i < len(priv.Primes); i++ {
                prime := priv.Primes[i]
-               values := &priv.Precomputed.CRTValues[i-2]
+               values := &precomputed.CRTValues[i-2]
 
                values.Exp = new(big.Int).Sub(prime, bigOne)
                values.Exp.Mod(priv.D, values.Exp)
 
                values.R = new(big.Int).Set(r)
                values.Coeff = new(big.Int).ModInverse(r, prime)
+               if values.Coeff == nil {
+                       return precomputed, errors.New("crypto/rsa: prime factors are not relatively prime")
+               }
 
                r.Mul(r, prime)
        }
 
-       // Errors are discarded because we don't have a way to report them.
-       // Anything that relies on Precomputed.fips will need to check for nil.
-       if len(priv.Primes) == 2 {
-               priv.Precomputed.fips, _ = rsa.NewPrivateKey(
-                       priv.N.Bytes(), priv.E, priv.D.Bytes(),
-                       priv.Primes[0].Bytes(), priv.Primes[1].Bytes(),
-                       priv.Precomputed.Dp.Bytes(), priv.Precomputed.Dq.Bytes(),
-                       priv.Precomputed.Qinv.Bytes())
-       } else {
-               priv.Precomputed.fips, _ = rsa.NewPrivateKeyWithoutCRT(
-                       priv.N.Bytes(), priv.E, priv.D.Bytes())
-       }
+       return precomputed, nil
 }
 
 func fipsPublicKey(pub *PublicKey) (*rsa.PublicKey, error) {
@@ -557,11 +562,9 @@ func fipsPrivateKey(priv *PrivateKey) (*rsa.PrivateKey, error) {
        if priv.Precomputed.fips != nil {
                return priv.Precomputed.fips, nil
        }
-       // Make a copy of the private key to avoid modifying the original.
-       k := *priv
-       k.Precompute()
-       if k.Precomputed.fips == nil {
-               return nil, errors.New("crypto/rsa: invalid private key")
+       precomputed, err := priv.precompute()
+       if err != nil {
+               return nil, err
        }
-       return k.Precomputed.fips, nil
+       return precomputed.fips, nil
 }
index 94c7bbb2304fd116418f6d4d49b20416bcafb6db..7929867ac6d2e722762288e7c24d61400993fba3 100644 (file)
@@ -87,11 +87,10 @@ func ParsePKCS1PrivateKey(der []byte) (*rsa.PrivateKey, error) {
                // them as needed.
        }
 
-       err = key.Validate()
-       if err != nil {
+       key.Precompute()
+       if err := key.Validate(); err != nil {
                return nil, err
        }
-       key.Precompute()
 
        return key, nil
 }
@@ -101,6 +100,10 @@ func ParsePKCS1PrivateKey(der []byte) (*rsa.PrivateKey, error) {
 // This kind of key is commonly encoded in PEM blocks of type "RSA PRIVATE KEY".
 // For a more flexible key format which is not [RSA] specific, use
 // [MarshalPKCS8PrivateKey].
+//
+// The key must have passed validation by calling [rsa.PrivateKey.Validate]
+// first. MarshalPKCS1PrivateKey calls [rsa.PrivateKey.Precompute], which may
+// modify the key if not already precomputed.
 func MarshalPKCS1PrivateKey(key *rsa.PrivateKey) []byte {
        key.Precompute()
 
index 08e9da404c33e96503c9226082e66da8da505762..6268c367575025fff96781d5722ece209cb80a24 100644 (file)
@@ -98,6 +98,8 @@ func ParsePKCS8PrivateKey(der []byte) (key any, err error) {
 // Unsupported key types result in an error.
 //
 // This kind of key is commonly encoded in PEM blocks of type "PRIVATE KEY".
+//
+// MarshalPKCS8PrivateKey runs [rsa.PrivateKey.Precompute] on RSA keys.
 func MarshalPKCS8PrivateKey(key any) ([]byte, error) {
        var privKey pkcs8
 
@@ -107,6 +109,10 @@ func MarshalPKCS8PrivateKey(key any) ([]byte, error) {
                        Algorithm:  oidPublicKeyRSA,
                        Parameters: asn1.NullRawValue,
                }
+               k.Precompute()
+               if err := k.Validate(); err != nil {
+                       return nil, err
+               }
                privKey.PrivateKey = MarshalPKCS1PrivateKey(k)
 
        case *ecdsa.PrivateKey:
index 3eeeb0212878aef67169e42738ddc1fa5afe6712..1a714e7b62cf4775f68f53f4a7d2eae1792fb849 100644 (file)
@@ -251,7 +251,7 @@ func TestMarshalRSAPrivateKey(t *testing.T) {
                priv.Primes[0].Cmp(priv2.Primes[0]) != 0 ||
                priv.Primes[1].Cmp(priv2.Primes[1]) != 0 ||
                priv.Primes[2].Cmp(priv2.Primes[2]) != 0 {
-               t.Errorf("got:%+v want:%+v", priv, priv2)
+               t.Errorf("wrong priv:\ngot  %+v\nwant %+v", priv2, priv)
        }
 }