From 8cecfad2a99987a35edfbcd875bef5e894abbce7 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Thu, 21 Nov 2024 13:51:21 +0100 Subject: [PATCH] crypto/rsa: port Validate to bigmod MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit This is quite a bit slower (almost entirely in the e * d reductions, which could be optimized), but the slowdown is only 12% of a signature operation. Also, call Validate at the end of GenerateKey as a backstop. Key generation is so incredibly slow that the extra time is negligible. goos: darwin goarch: arm64 pkg: crypto/rsa cpu: Apple M2 │ ec9643bbed │ ec9643bbed-dirty │ │ sec/op │ sec/op vs base │ SignPSS/2048-8 869.8µ ± 1% 870.2µ ± 0% ~ (p=0.937 n=6) GenerateKey/2048-8 104.2m ± 17% 106.9m ± 10% ~ (p=0.589 n=6) ParsePKCS8PrivateKey/2048-8 28.54µ ± 2% 136.78µ ± 8% +379.23% (p=0.002 n=6) Fixes #57751 Co-authored-by: Derek Parker Change-Id: Ifb476859207925a018b433c16dd62fb767afd2d5 Reviewed-on: https://go-review.googlesource.com/c/go/+/630517 Auto-Submit: Filippo Valsorda Reviewed-by: Roland Shoemaker Reviewed-by: Russ Cox LUCI-TryBot-Result: Go LUCI --- src/crypto/internal/fips140/bigmod/nat.go | 13 ++++ src/crypto/rsa/rsa.go | 82 ++++++++++++++++------- src/crypto/rsa/rsa_test.go | 8 +-- 3 files changed, 76 insertions(+), 27 deletions(-) diff --git a/src/crypto/internal/fips140/bigmod/nat.go b/src/crypto/internal/fips140/bigmod/nat.go index 0a95928536..13a4ba6e96 100644 --- a/src/crypto/internal/fips140/bigmod/nat.go +++ b/src/crypto/internal/fips140/bigmod/nat.go @@ -202,6 +202,19 @@ func (x *Nat) setBytes(b []byte) error { return nil } +// SetUint assigns x = y, and returns an error if y >= m. +// +// The output will be resized to the size of m and overwritten. +func (x *Nat) SetUint(y uint, m *Modulus) (*Nat, error) { + x.resetFor(m) + // Modulus is never zero, so always at least one limb. + x.limbs[0] = y + if x.cmpGeq(m.nat) == yes { + return nil, errors.New("input overflows the modulus") + } + return x, nil +} + // Equal returns 1 if x == y, and 0 otherwise. // // Both operands must have the same announced length. diff --git a/src/crypto/rsa/rsa.go b/src/crypto/rsa/rsa.go index 9a57056f03..eb6ce73e0f 100644 --- a/src/crypto/rsa/rsa.go +++ b/src/crypto/rsa/rsa.go @@ -20,8 +20,7 @@ // Decrypter and Signer interfaces from the crypto package. // // Operations involving private keys are implemented using constant-time -// algorithms, except for [GenerateKey], [PrivateKey.Precompute], and -// [PrivateKey.Validate]. +// algorithms, except for [GenerateKey] and [PrivateKey.Precompute]. // // # Minimum key size // @@ -236,34 +235,67 @@ func (priv *PrivateKey) Validate() error { return errors.New("crypto/rsa: public exponent too large") } - // Check that Πprimes == n. - modulus := new(big.Int).Set(bigOne) - for _, prime := range priv.Primes { - // Any primes ≤ 1 will cause divide-by-zero panics later. - if prime.Cmp(bigOne) <= 0 { - return errors.New("crypto/rsa: invalid prime value") - } - modulus.Mul(modulus, prime) + N, err := bigmod.NewModulus(pub.N.Bytes()) + if err != nil { + return fmt.Errorf("crypto/rsa: invalid public modulus: %v", err) } - if modulus.Cmp(priv.N) != 0 { - return errors.New("crypto/rsa: invalid modulus") + d, err := bigmod.NewNat().SetBytes(priv.D.Bytes(), N) + if err != nil { + return fmt.Errorf("crypto/rsa: invalid private exponent: %v", err) + } + one, err := bigmod.NewNat().SetUint(1, N) + if err != nil { + return fmt.Errorf("crypto/rsa: internal error: %v", err) } - // Check that de ≡ 1 mod p-1, for each prime. - // This implies that e is coprime to each p-1 as e has a multiplicative - // inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) = - // exponent(ℤ/nℤ). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1 - // mod p. Thus a^de ≡ a mod n for all a coprime to n, as required. - congruence := new(big.Int) - de := new(big.Int).SetInt64(int64(priv.E)) - de.Mul(de, priv.D) + Π := bigmod.NewNat().ExpandFor(N) for _, prime := range priv.Primes { - pminus1 := new(big.Int).Sub(prime, bigOne) - congruence.Mod(de, pminus1) - if congruence.Cmp(bigOne) != 0 { + p, err := bigmod.NewNat().SetBytes(prime.Bytes(), N) + if err != nil { + return fmt.Errorf("crypto/rsa: invalid prime: %v", err) + } + if p.IsZero() == 1 { + return errors.New("crypto/rsa: invalid prime") + } + Π.Mul(p, N) + + // Check that de ≡ 1 mod p-1, for each prime. + // This implies that e is coprime to each p-1 as e has a multiplicative + // inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) = + // exponent(ℤ/nℤ). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1 + // mod p. Thus a^de ≡ a mod n for all a coprime to n, as required. + + p.Sub(one, N) + if p.IsZero() == 1 { + return errors.New("crypto/rsa: invalid prime") + } + pMinus1, err := bigmod.NewModulus(p.Bytes(N)) + if err != nil { + return fmt.Errorf("crypto/rsa: internal error: %v", err) + } + + e, err := bigmod.NewNat().SetUint(uint(pub.E), pMinus1) + if err != nil { + return fmt.Errorf("crypto/rsa: invalid public exponent: %v", err) + } + one, err := bigmod.NewNat().SetUint(1, pMinus1) + if err != nil { + return fmt.Errorf("crypto/rsa: internal error: %v", err) + } + + de := bigmod.NewNat() + de.Mod(d, pMinus1) + de.Mul(e, pMinus1) + de.Sub(one, pMinus1) + if de.IsZero() != 1 { return errors.New("crypto/rsa: invalid exponents") } } + // Check that Πprimes == n. + if Π.IsZero() != 1 { + return errors.New("crypto/rsa: invalid modulus") + } + return nil } @@ -450,6 +482,10 @@ NextSetOfPrimes: } priv.Precompute() + if err := priv.Validate(); err != nil { + return nil, err + } + return priv, nil } diff --git a/src/crypto/rsa/rsa_test.go b/src/crypto/rsa/rsa_test.go index dbf5e0a52a..7a3e02f09c 100644 --- a/src/crypto/rsa/rsa_test.go +++ b/src/crypto/rsa/rsa_test.go @@ -98,10 +98,10 @@ func TestNPrimeKeyGeneration(t *testing.T) { } func TestImpossibleKeyGeneration(t *testing.T) { - // This test ensures that trying to generate toy RSA keys doesn't enter - // an infinite loop. + // This test ensures that trying to generate or validate toy RSA keys + // doesn't enter an infinite loop or panic. t.Setenv("GODEBUG", "rsa1024min=0") - for i := 0; i < 32; i++ { + for i := 0; i < 128; i++ { GenerateKey(rand.Reader, i) GenerateMultiPrimeKey(rand.Reader, 3, i) GenerateMultiPrimeKey(rand.Reader, 4, i) @@ -184,7 +184,7 @@ func TestEverything(t *testing.T) { } t.Setenv("GODEBUG", "rsa1024min=0") - min := 32 + min := 128 max := 560 // any smaller than this and not all tests will run if *allFlag { max = 2048 -- 2.48.1