]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/rsa: port Validate to bigmod
authorFilippo Valsorda <filippo@golang.org>
Thu, 21 Nov 2024 12:51:21 +0000 (13:51 +0100)
committerGopher Robot <gobot@golang.org>
Fri, 22 Nov 2024 01:50:41 +0000 (01:50 +0000)
This is quite a bit slower (almost entirely in the e * d reductions,
which could be optimized), but the slowdown is only 12% of a signature
operation.

Also, call Validate at the end of GenerateKey as a backstop. Key
generation is so incredibly slow that the extra time is negligible.

goos: darwin
goarch: arm64
pkg: crypto/rsa
cpu: Apple M2
                            │  ec9643bbed  │           ec9643bbed-dirty            │
                            │    sec/op    │    sec/op      vs base                │
SignPSS/2048-8                869.8µ ±  1%    870.2µ ±  0%         ~ (p=0.937 n=6)
GenerateKey/2048-8            104.2m ± 17%    106.9m ± 10%         ~ (p=0.589 n=6)
ParsePKCS8PrivateKey/2048-8   28.54µ ±  2%   136.78µ ±  8%  +379.23% (p=0.002 n=6)

Fixes #57751

Co-authored-by: Derek Parker <parkerderek86@gmail.com>
Change-Id: Ifb476859207925a018b433c16dd62fb767afd2d5
Reviewed-on: https://go-review.googlesource.com/c/go/+/630517
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Russ Cox <rsc@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/crypto/internal/fips140/bigmod/nat.go
src/crypto/rsa/rsa.go
src/crypto/rsa/rsa_test.go

index 0a95928536380d0d853a674a73308e4db40a7706..13a4ba6e967099bc9eeefcfa1ff5f578f6d5bc6c 100644 (file)
@@ -202,6 +202,19 @@ func (x *Nat) setBytes(b []byte) error {
        return nil
 }
 
+// SetUint assigns x = y, and returns an error if y >= m.
+//
+// The output will be resized to the size of m and overwritten.
+func (x *Nat) SetUint(y uint, m *Modulus) (*Nat, error) {
+       x.resetFor(m)
+       // Modulus is never zero, so always at least one limb.
+       x.limbs[0] = y
+       if x.cmpGeq(m.nat) == yes {
+               return nil, errors.New("input overflows the modulus")
+       }
+       return x, nil
+}
+
 // Equal returns 1 if x == y, and 0 otherwise.
 //
 // Both operands must have the same announced length.
index 9a57056f03aa14c9f44d3dec4f40d90d5de4608f..eb6ce73e0f1a97ec8419ca98a925d6890109ab14 100644 (file)
@@ -20,8 +20,7 @@
 // Decrypter and Signer interfaces from the crypto package.
 //
 // Operations involving private keys are implemented using constant-time
-// algorithms, except for [GenerateKey], [PrivateKey.Precompute], and
-// [PrivateKey.Validate].
+// algorithms, except for [GenerateKey] and [PrivateKey.Precompute].
 //
 // # Minimum key size
 //
@@ -236,34 +235,67 @@ func (priv *PrivateKey) Validate() error {
                return errors.New("crypto/rsa: public exponent too large")
        }
 
-       // Check that Πprimes == n.
-       modulus := new(big.Int).Set(bigOne)
-       for _, prime := range priv.Primes {
-               // Any primes ≤ 1 will cause divide-by-zero panics later.
-               if prime.Cmp(bigOne) <= 0 {
-                       return errors.New("crypto/rsa: invalid prime value")
-               }
-               modulus.Mul(modulus, prime)
+       N, err := bigmod.NewModulus(pub.N.Bytes())
+       if err != nil {
+               return fmt.Errorf("crypto/rsa: invalid public modulus: %v", err)
        }
-       if modulus.Cmp(priv.N) != 0 {
-               return errors.New("crypto/rsa: invalid modulus")
+       d, err := bigmod.NewNat().SetBytes(priv.D.Bytes(), N)
+       if err != nil {
+               return fmt.Errorf("crypto/rsa: invalid private exponent: %v", err)
+       }
+       one, err := bigmod.NewNat().SetUint(1, N)
+       if err != nil {
+               return fmt.Errorf("crypto/rsa: internal error: %v", err)
        }
 
-       // Check that de ≡ 1 mod p-1, for each prime.
-       // This implies that e is coprime to each p-1 as e has a multiplicative
-       // inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) =
-       // exponent(ℤ/nℤ). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1
-       // mod p. Thus a^de ≡ a mod n for all a coprime to n, as required.
-       congruence := new(big.Int)
-       de := new(big.Int).SetInt64(int64(priv.E))
-       de.Mul(de, priv.D)
+       Π := bigmod.NewNat().ExpandFor(N)
        for _, prime := range priv.Primes {
-               pminus1 := new(big.Int).Sub(prime, bigOne)
-               congruence.Mod(de, pminus1)
-               if congruence.Cmp(bigOne) != 0 {
+               p, err := bigmod.NewNat().SetBytes(prime.Bytes(), N)
+               if err != nil {
+                       return fmt.Errorf("crypto/rsa: invalid prime: %v", err)
+               }
+               if p.IsZero() == 1 {
+                       return errors.New("crypto/rsa: invalid prime")
+               }
+               Π.Mul(p, N)
+
+               // Check that de ≡ 1 mod p-1, for each prime.
+               // This implies that e is coprime to each p-1 as e has a multiplicative
+               // inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) =
+               // exponent(ℤ/nℤ). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1
+               // mod p. Thus a^de ≡ a mod n for all a coprime to n, as required.
+
+               p.Sub(one, N)
+               if p.IsZero() == 1 {
+                       return errors.New("crypto/rsa: invalid prime")
+               }
+               pMinus1, err := bigmod.NewModulus(p.Bytes(N))
+               if err != nil {
+                       return fmt.Errorf("crypto/rsa: internal error: %v", err)
+               }
+
+               e, err := bigmod.NewNat().SetUint(uint(pub.E), pMinus1)
+               if err != nil {
+                       return fmt.Errorf("crypto/rsa: invalid public exponent: %v", err)
+               }
+               one, err := bigmod.NewNat().SetUint(1, pMinus1)
+               if err != nil {
+                       return fmt.Errorf("crypto/rsa: internal error: %v", err)
+               }
+
+               de := bigmod.NewNat()
+               de.Mod(d, pMinus1)
+               de.Mul(e, pMinus1)
+               de.Sub(one, pMinus1)
+               if de.IsZero() != 1 {
                        return errors.New("crypto/rsa: invalid exponents")
                }
        }
+       // Check that Πprimes == n.
+       if Π.IsZero() != 1 {
+               return errors.New("crypto/rsa: invalid modulus")
+       }
+
        return nil
 }
 
@@ -450,6 +482,10 @@ NextSetOfPrimes:
        }
 
        priv.Precompute()
+       if err := priv.Validate(); err != nil {
+               return nil, err
+       }
+
        return priv, nil
 }
 
index dbf5e0a52aabe904b3b9240f69b3943befa353ac..7a3e02f09c92a2d9bdce0bbdcb5025ba1fd082d0 100644 (file)
@@ -98,10 +98,10 @@ func TestNPrimeKeyGeneration(t *testing.T) {
 }
 
 func TestImpossibleKeyGeneration(t *testing.T) {
-       // This test ensures that trying to generate toy RSA keys doesn't enter
-       // an infinite loop.
+       // This test ensures that trying to generate or validate toy RSA keys
+       // doesn't enter an infinite loop or panic.
        t.Setenv("GODEBUG", "rsa1024min=0")
-       for i := 0; i < 32; i++ {
+       for i := 0; i < 128; i++ {
                GenerateKey(rand.Reader, i)
                GenerateMultiPrimeKey(rand.Reader, 3, i)
                GenerateMultiPrimeKey(rand.Reader, 4, i)
@@ -184,7 +184,7 @@ func TestEverything(t *testing.T) {
        }
 
        t.Setenv("GODEBUG", "rsa1024min=0")
-       min := 32
+       min := 128
        max := 560 // any smaller than this and not all tests will run
        if *allFlag {
                max = 2048