]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/internal/bigmod: drop math/big dependency
authorFilippo Valsorda <filippo@golang.org>
Sat, 16 Nov 2024 13:10:02 +0000 (14:10 +0100)
committerGopher Robot <gobot@golang.org>
Tue, 19 Nov 2024 23:01:40 +0000 (23:01 +0000)
If when the dust settles the Bytes and SetBytes round-trip is visible in
profiles (only plausible in RSA), then we can add a SetBits method like
in CL 511375.

Change-Id: I3e6677e849d7a3786fa7297437b119a47715225f
Reviewed-on: https://go-review.googlesource.com/c/go/+/628675
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Russ Cox <rsc@golang.org>
src/crypto/ecdsa/ecdsa.go
src/crypto/internal/bigmod/nat.go
src/crypto/internal/bigmod/nat_test.go
src/crypto/rsa/rsa.go

index 95a4b4be69f32e7fd07376ba316abdde96496b9d..45215abed0301730715fa30d4c3b6b933b0e64a9 100644 (file)
@@ -669,7 +669,7 @@ func precomputeParams[Point nistPoint[Point]](c *nistCurve[Point], curve ellipti
        params := curve.Params()
        c.curve = curve
        var err error
-       c.N, err = bigmod.NewModulusFromBig(params.N)
+       c.N, err = bigmod.NewModulus(params.N.Bytes())
        if err != nil {
                panic(err)
        }
index 5cbae40efe997d002a80a45813e25e2a24c5df5e..7bd09b37ac90216eeb82be32ce99ec7bb377fa7b 100644 (file)
@@ -7,7 +7,6 @@ package bigmod
 import (
        "errors"
        "internal/byteorder"
-       "math/big"
        "math/bits"
 )
 
@@ -92,26 +91,34 @@ func (x *Nat) reset(n int) *Nat {
        return x
 }
 
-// set assigns x = y, optionally resizing x to the appropriate size.
-func (x *Nat) set(y *Nat) *Nat {
-       x.reset(len(y.limbs))
-       copy(x.limbs, y.limbs)
-       return x
-}
-
-// setBig assigns x = n, optionally resizing n to the appropriate size.
+// resetToBytes assigns x = b, where b is a slice of big-endian bytes, resizing
+// n to the appropriate size.
 //
 // The announced length of x is set based on the actual bit size of the input,
 // ignoring leading zeroes.
-func (x *Nat) setBig(n *big.Int) *Nat {
-       limbs := n.Bits()
-       x.reset(len(limbs))
-       for i := range limbs {
-               x.limbs[i] = uint(limbs[i])
+func (x *Nat) resetToBytes(b []byte) *Nat {
+       x.reset((len(b) + _S - 1) / _S)
+       if err := x.setBytes(b); err != nil {
+               panic("bigmod: internal error: bad arithmetic")
+       }
+       // Trim most significant (trailing in little-endian) zero limbs.
+       // We assume comparison with zero (but not the branch) is constant time.
+       for i := len(x.limbs) - 1; i >= 0; i-- {
+               if x.limbs[i] != 0 {
+                       break
+               }
+               x.limbs = x.limbs[:i]
        }
        return x
 }
 
+// set assigns x = y, optionally resizing x to the appropriate size.
+func (x *Nat) set(y *Nat) *Nat {
+       x.reset(len(y.limbs))
+       copy(x.limbs, y.limbs)
+       return x
+}
+
 // Bytes returns x as a zero-extended big-endian byte slice. The size of the
 // slice will match the size of m.
 //
@@ -140,7 +147,8 @@ func (x *Nat) Bytes(m *Modulus) []byte {
 //
 // The output will be resized to the size of m and overwritten.
 func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
-       if err := x.setBytes(b, m); err != nil {
+       x.resetFor(m)
+       if err := x.setBytes(b); err != nil {
                return nil, err
        }
        if x.cmpGeq(m.nat) == yes {
@@ -155,7 +163,8 @@ func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
 //
 // The output will be resized to the size of m and overwritten.
 func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
-       if err := x.setBytes(b, m); err != nil {
+       x.resetFor(m)
+       if err := x.setBytes(b); err != nil {
                return nil, err
        }
        leading := _W - bitLen(x.limbs[len(x.limbs)-1])
@@ -175,8 +184,7 @@ func bigEndianUint(buf []byte) uint {
        return uint(byteorder.BeUint32(buf))
 }
 
-func (x *Nat) setBytes(b []byte, m *Modulus) error {
-       x.resetFor(m)
+func (x *Nat) setBytes(b []byte) error {
        i, k := len(b), 0
        for k < len(x.limbs) && i >= _S {
                x.limbs[k] = bigEndianUint(b[i-_S : i])
@@ -369,18 +377,16 @@ func minusInverseModW(x uint) uint {
        return -y
 }
 
-// NewModulusFromBig creates a new Modulus from a [big.Int].
+// NewModulus creates a new Modulus from a slice of big-endian bytes.
 //
-// The Int must be odd. The number of significant bits (and nothing else) is
+// The value must be odd. The number of significant bits (and nothing else) is
 // leaked through timing side-channels.
-func NewModulusFromBig(n *big.Int) (*Modulus, error) {
-       if b := n.Bits(); len(b) == 0 {
-               return nil, errors.New("modulus must be >= 0")
-       } else if b[0]&1 != 1 {
-               return nil, errors.New("modulus must be odd")
+func NewModulus(b []byte) (*Modulus, error) {
+       if len(b) == 0 || b[len(b)-1]&1 != 1 {
+               return nil, errors.New("modulus must be > 0 and odd")
        }
        m := &Modulus{}
-       m.nat = NewNat().setBig(n)
+       m.nat = NewNat().resetToBytes(b)
        m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
        m.m0inv = minusInverseModW(m.nat.limbs[0])
        m.rr = rr(m)
index 79b143ab0202561f454e878732d115e7ca6fc32f..2b1c22ddf029defe20d50bc6e8c5815df31bb60a 100644 (file)
@@ -16,6 +16,19 @@ import (
        "testing/quick"
 )
 
+// setBig assigns x = n, optionally resizing n to the appropriate size.
+//
+// The announced length of x is set based on the actual bit size of the input,
+// ignoring leading zeroes.
+func (x *Nat) setBig(n *big.Int) *Nat {
+       limbs := n.Bits()
+       x.reset(len(limbs))
+       for i := range limbs {
+               x.limbs[i] = uint(limbs[i])
+       }
+       return x
+}
+
 func (n *Nat) String() string {
        var limbs []string
        for i := range n.limbs {
@@ -71,7 +84,7 @@ func TestMontgomeryRoundtrip(t *testing.T) {
                one.limbs[0] = 1
                aPlusOne := new(big.Int).SetBytes(natBytes(a))
                aPlusOne.Add(aPlusOne, big.NewInt(1))
-               m, _ := NewModulusFromBig(aPlusOne)
+               m, _ := NewModulus(aPlusOne.Bytes())
                monty := new(Nat).set(a)
                monty.montgomeryRepresentation(m)
                aAgain := new(Nat).set(monty)
@@ -320,7 +333,7 @@ func TestMulReductions(t *testing.T) {
        b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10)
        n := new(big.Int).Mul(a, b)
 
-       N, _ := NewModulusFromBig(n)
+       N, _ := NewModulus(n.Bytes())
        A := NewNat().setBig(a).ExpandFor(N)
        B := NewNat().setBig(b).ExpandFor(N)
 
@@ -329,7 +342,7 @@ func TestMulReductions(t *testing.T) {
        }
 
        i := new(big.Int).ModInverse(a, b)
-       N, _ = NewModulusFromBig(b)
+       N, _ = NewModulus(b.Bytes())
        A = NewNat().setBig(a).ExpandFor(N)
        I := NewNat().setBig(i).ExpandFor(N)
        one := NewNat().setBig(big.NewInt(1)).ExpandFor(N)
@@ -351,7 +364,7 @@ func natFromBytes(b []byte) *Nat {
 
 func modulusFromBytes(b []byte) *Modulus {
        bb := new(big.Int).SetBytes(b)
-       m, _ := NewModulusFromBig(bb)
+       m, _ := NewModulus(bb.Bytes())
        return m
 }
 
@@ -360,7 +373,7 @@ func maxModulus(n uint) *Modulus {
        b := big.NewInt(1)
        b.Lsh(b, n*_W)
        b.Sub(b, big.NewInt(1))
-       m, _ := NewModulusFromBig(b)
+       m, _ := NewModulus(b.Bytes())
        return m
 }
 
@@ -466,17 +479,23 @@ func BenchmarkExp(b *testing.B) {
        }
 }
 
-func TestNewModFromBigZero(t *testing.T) {
-       expected := "modulus must be >= 0"
-       _, err := NewModulusFromBig(big.NewInt(0))
+func TestNewModulus(t *testing.T) {
+       expected := "modulus must be > 0 and odd"
+       _, err := NewModulus([]byte{})
        if err == nil || err.Error() != expected {
-               t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected)
+               t.Errorf("NewModulus(0) got %q, want %q", err, expected)
        }
-
-       expected = "modulus must be odd"
-       _, err = NewModulusFromBig(big.NewInt(2))
+       _, err = NewModulus([]byte{0})
+       if err == nil || err.Error() != expected {
+               t.Errorf("NewModulus(0) got %q, want %q", err, expected)
+       }
+       _, err = NewModulus([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
+       if err == nil || err.Error() != expected {
+               t.Errorf("NewModulus(0) got %q, want %q", err, expected)
+       }
+       _, err = NewModulus([]byte{1, 1, 1, 1, 2})
        if err == nil || err.Error() != expected {
-               t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected)
+               t.Errorf("NewModulus(2) got %q, want %q", err, expected)
        }
 }
 
index 4d78d1eaaa6be014c8261123322314abc3556b19..3764e0212781804ef4b9876a63277eff0d83fb35 100644 (file)
@@ -316,15 +316,15 @@ func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (*PrivateKey
                        return nil, errors.New("crypto/rsa: generated key exponent too large")
                }
 
-               mn, err := bigmod.NewModulusFromBig(N)
+               mn, err := bigmod.NewModulus(N.Bytes())
                if err != nil {
                        return nil, err
                }
-               mp, err := bigmod.NewModulusFromBig(P)
+               mp, err := bigmod.NewModulus(P.Bytes())
                if err != nil {
                        return nil, err
                }
-               mq, err := bigmod.NewModulusFromBig(Q)
+               mq, err := bigmod.NewModulus(Q.Bytes())
                if err != nil {
                        return nil, err
                }
@@ -481,7 +481,7 @@ var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA key siz
 func encrypt(pub *PublicKey, plaintext []byte) ([]byte, error) {
        boring.Unreachable()
 
-       N, err := bigmod.NewModulusFromBig(pub.N)
+       N, err := bigmod.NewModulus(pub.N.Bytes())
        if err != nil {
                return nil, err
        }
@@ -584,17 +584,17 @@ func (priv *PrivateKey) Precompute() {
                // Precomputed values _should_ always be valid, but if they aren't
                // just return. We could also panic.
                var err error
-               priv.Precomputed.n, err = bigmod.NewModulusFromBig(priv.N)
+               priv.Precomputed.n, err = bigmod.NewModulus(priv.N.Bytes())
                if err != nil {
                        return
                }
-               priv.Precomputed.p, err = bigmod.NewModulusFromBig(priv.Primes[0])
+               priv.Precomputed.p, err = bigmod.NewModulus(priv.Primes[0].Bytes())
                if err != nil {
                        // Unset previous values, so we either have everything or nothing
                        priv.Precomputed.n = nil
                        return
                }
-               priv.Precomputed.q, err = bigmod.NewModulusFromBig(priv.Primes[1])
+               priv.Precomputed.q, err = bigmod.NewModulus(priv.Primes[1].Bytes())
                if err != nil {
                        // Unset previous values, so we either have everything or nothing
                        priv.Precomputed.n, priv.Precomputed.p = nil, nil
@@ -649,7 +649,7 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) {
                t0   = bigmod.NewNat()
        )
        if priv.Precomputed.n == nil {
-               N, err = bigmod.NewModulusFromBig(priv.N)
+               N, err = bigmod.NewModulus(priv.N.Bytes())
                if err != nil {
                        return nil, ErrDecryption
                }