]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/internal/bigmod: don't panic on NewModulusFromBig(0)
authorRoland Shoemaker <roland@golang.org>
Wed, 24 May 2023 17:53:47 +0000 (10:53 -0700)
committerGopher Robot <gobot@golang.org>
Thu, 25 May 2023 01:01:41 +0000 (01:01 +0000)
Return an error instead. Makes usages of NewModulusFromBig a bit more
verbose, but better than returning nil or something and just moving the
panic down the road.

Fixes #60411

Change-Id: I10732c6ce56ccd9e4769281cea049dd4beb60a6e
Reviewed-on: https://go-review.googlesource.com/c/go/+/498035
Auto-Submit: Roland Shoemaker <roland@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Roland Shoemaker <roland@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
src/crypto/ecdsa/ecdsa.go
src/crypto/internal/bigmod/nat.go
src/crypto/internal/bigmod/nat_test.go
src/crypto/rsa/rsa.go

index 68272af41fc1378cf423e1a6888bad0356d6847d..1b04b2cb99b4da8c9865522dfcad11396b1a4e28 100644 (file)
@@ -655,6 +655,10 @@ func p521() *nistCurve[*nistec.P521Point] {
 func precomputeParams[Point nistPoint[Point]](c *nistCurve[Point], curve elliptic.Curve) {
        params := curve.Params()
        c.curve = curve
-       c.N = bigmod.NewModulusFromBig(params.N)
+       var err error
+       c.N, err = bigmod.NewModulusFromBig(params.N)
+       if err != nil {
+               panic(err)
+       }
        c.nMinus2 = new(big.Int).Sub(params.N, big.NewInt(2)).Bytes()
 }
index a08c12b76e65a0f61e5cb3cddd0bfa8846f927e0..5605e9f1c34bb94c1cd3af1af4909b4e27388436 100644 (file)
@@ -351,13 +351,18 @@ func minusInverseModW(x uint) uint {
 //
 // The Int must be odd. The number of significant bits (and nothing else) is
 // leaked through timing side-channels.
-func NewModulusFromBig(n *big.Int) *Modulus {
+func NewModulusFromBig(n *big.Int) (*Modulus, error) {
+       if b := n.Bits(); len(b) == 0 {
+               return nil, errors.New("modulus must be >= 0")
+       } else if b[0]&1 != 1 {
+               return nil, errors.New("modulus must be odd")
+       }
        m := &Modulus{}
        m.nat = NewNat().setBig(n)
        m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
        m.m0inv = minusInverseModW(m.nat.limbs[0])
        m.rr = rr(m)
-       return m
+       return m, nil
 }
 
 // bitLen is a version of bits.Len that only leaks the bit length of n, but not
index 1c615b988827ae040ae5e91fb509c5a20684507d..76e557048c38c267d279ec7758586dccda0bb446 100644 (file)
@@ -70,7 +70,7 @@ func TestMontgomeryRoundtrip(t *testing.T) {
                one.limbs[0] = 1
                aPlusOne := new(big.Int).SetBytes(natBytes(a))
                aPlusOne.Add(aPlusOne, big.NewInt(1))
-               m := NewModulusFromBig(aPlusOne)
+               m, _ := NewModulusFromBig(aPlusOne)
                monty := new(Nat).set(a)
                monty.montgomeryRepresentation(m)
                aAgain := new(Nat).set(monty)
@@ -319,7 +319,7 @@ func TestMulReductions(t *testing.T) {
        b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10)
        n := new(big.Int).Mul(a, b)
 
-       N := NewModulusFromBig(n)
+       N, _ := NewModulusFromBig(n)
        A := NewNat().setBig(a).ExpandFor(N)
        B := NewNat().setBig(b).ExpandFor(N)
 
@@ -328,7 +328,7 @@ func TestMulReductions(t *testing.T) {
        }
 
        i := new(big.Int).ModInverse(a, b)
-       N = NewModulusFromBig(b)
+       N, _ = NewModulusFromBig(b)
        A = NewNat().setBig(a).ExpandFor(N)
        I := NewNat().setBig(i).ExpandFor(N)
        one := NewNat().setBig(big.NewInt(1)).ExpandFor(N)
@@ -350,15 +350,17 @@ func natFromBytes(b []byte) *Nat {
 
 func modulusFromBytes(b []byte) *Modulus {
        bb := new(big.Int).SetBytes(b)
-       return NewModulusFromBig(bb)
+       m, _ := NewModulusFromBig(bb)
+       return m
 }
 
 // maxModulus returns the biggest modulus that can fit in n limbs.
 func maxModulus(n uint) *Modulus {
-       m := big.NewInt(1)
-       m.Lsh(m, n*_W)
-       m.Sub(m, big.NewInt(1))
-       return NewModulusFromBig(m)
+       b := big.NewInt(1)
+       b.Lsh(b, n*_W)
+       b.Sub(b, big.NewInt(1))
+       m, _ := NewModulusFromBig(b)
+       return m
 }
 
 func makeBenchmarkModulus() *Modulus {
@@ -462,3 +464,17 @@ func BenchmarkExp(b *testing.B) {
                out.Exp(x, e, m)
        }
 }
+
+func TestNewModFromBigZero(t *testing.T) {
+       expected := "modulus must be >= 0"
+       _, err := NewModulusFromBig(big.NewInt(0))
+       if err == nil || err.Error() != expected {
+               t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected)
+       }
+
+       expected = "modulus must be odd"
+       _, err = NewModulusFromBig(big.NewInt(2))
+       if err == nil || err.Error() != expected {
+               t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected)
+       }
+}
index 1d01ff3ed1e282603a905a2e1e3c6d6ae16aeacf..88e44508cdf2227fb6438da2ab41391ade13d41b 100644 (file)
@@ -309,6 +309,20 @@ func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (*PrivateKey
                if !E.IsInt64() || int64(int(e64)) != e64 {
                        return nil, errors.New("crypto/rsa: generated key exponent too large")
                }
+
+               mn, err := bigmod.NewModulusFromBig(N)
+               if err != nil {
+                       return nil, err
+               }
+               mp, err := bigmod.NewModulusFromBig(P)
+               if err != nil {
+                       return nil, err
+               }
+               mq, err := bigmod.NewModulusFromBig(Q)
+               if err != nil {
+                       return nil, err
+               }
+
                key := &PrivateKey{
                        PublicKey: PublicKey{
                                N: N,
@@ -321,9 +335,9 @@ func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (*PrivateKey
                                Dq:        Dq,
                                Qinv:      Qinv,
                                CRTValues: make([]CRTValue, 0), // non-nil, to match Precompute
-                               n:         bigmod.NewModulusFromBig(N),
-                               p:         bigmod.NewModulusFromBig(P),
-                               q:         bigmod.NewModulusFromBig(Q),
+                               n:         mn,
+                               p:         mp,
+                               q:         mq,
                        },
                }
                return key, nil
@@ -465,7 +479,10 @@ func encrypt(pub *PublicKey, plaintext []byte) ([]byte, error) {
        // NewModulusFromBig call, because PublicKey doesn't have a Precomputed
        // field. If performance becomes an issue, consider placing a private
        // sync.Once on PublicKey to compute this.
-       N := bigmod.NewModulusFromBig(pub.N)
+       N, err := bigmod.NewModulusFromBig(pub.N)
+       if err != nil {
+               return nil, err
+       }
        m, err := bigmod.NewNat().SetBytes(plaintext, N)
        if err != nil {
                return nil, err
@@ -555,9 +572,25 @@ var ErrVerification = errors.New("crypto/rsa: verification error")
 // in the future.
 func (priv *PrivateKey) Precompute() {
        if priv.Precomputed.n == nil && len(priv.Primes) == 2 {
-               priv.Precomputed.n = bigmod.NewModulusFromBig(priv.N)
-               priv.Precomputed.p = bigmod.NewModulusFromBig(priv.Primes[0])
-               priv.Precomputed.q = bigmod.NewModulusFromBig(priv.Primes[1])
+               // Precomputed values _should_ always be valid, but if they aren't
+               // just return. We could also panic.
+               var err error
+               priv.Precomputed.n, err = bigmod.NewModulusFromBig(priv.N)
+               if err != nil {
+                       return
+               }
+               priv.Precomputed.p, err = bigmod.NewModulusFromBig(priv.Primes[0])
+               if err != nil {
+                       // Unset previous values, so we either have everything or nothing
+                       priv.Precomputed.n = nil
+                       return
+               }
+               priv.Precomputed.q, err = bigmod.NewModulusFromBig(priv.Primes[1])
+               if err != nil {
+                       // Unset previous values, so we either have everything or nothing
+                       priv.Precomputed.n, priv.Precomputed.p = nil, nil
+                       return
+               }
        }
 
        // Fill in the backwards-compatibility *big.Int values.
@@ -607,7 +640,10 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) {
                t0   = bigmod.NewNat()
        )
        if priv.Precomputed.n == nil {
-               N = bigmod.NewModulusFromBig(priv.N)
+               N, err = bigmod.NewModulusFromBig(priv.N)
+               if err != nil {
+                       return nil, ErrDecryption
+               }
                c, err = bigmod.NewNat().SetBytes(ciphertext, N)
                if err != nil {
                        return nil, ErrDecryption