From: Filippo Valsorda Date: Sat, 16 Nov 2024 13:10:02 +0000 (+0100) Subject: crypto/internal/bigmod: drop math/big dependency X-Git-Tag: go1.24rc1~223 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=e24eb3ab36d0bd0772d650e4c7f5b2dd261d7970;p=gostls13.git crypto/internal/bigmod: drop math/big dependency If when the dust settles the Bytes and SetBytes round-trip is visible in profiles (only plausible in RSA), then we can add a SetBits method like in CL 511375. Change-Id: I3e6677e849d7a3786fa7297437b119a47715225f Reviewed-on: https://go-review.googlesource.com/c/go/+/628675 LUCI-TryBot-Result: Go LUCI Auto-Submit: Filippo Valsorda Reviewed-by: Dmitri Shuralyov Reviewed-by: Russ Cox --- diff --git a/src/crypto/ecdsa/ecdsa.go b/src/crypto/ecdsa/ecdsa.go index 95a4b4be69..45215abed0 100644 --- a/src/crypto/ecdsa/ecdsa.go +++ b/src/crypto/ecdsa/ecdsa.go @@ -669,7 +669,7 @@ func precomputeParams[Point nistPoint[Point]](c *nistCurve[Point], curve ellipti params := curve.Params() c.curve = curve var err error - c.N, err = bigmod.NewModulusFromBig(params.N) + c.N, err = bigmod.NewModulus(params.N.Bytes()) if err != nil { panic(err) } diff --git a/src/crypto/internal/bigmod/nat.go b/src/crypto/internal/bigmod/nat.go index 5cbae40efe..7bd09b37ac 100644 --- a/src/crypto/internal/bigmod/nat.go +++ b/src/crypto/internal/bigmod/nat.go @@ -7,7 +7,6 @@ package bigmod import ( "errors" "internal/byteorder" - "math/big" "math/bits" ) @@ -92,26 +91,34 @@ func (x *Nat) reset(n int) *Nat { return x } -// set assigns x = y, optionally resizing x to the appropriate size. -func (x *Nat) set(y *Nat) *Nat { - x.reset(len(y.limbs)) - copy(x.limbs, y.limbs) - return x -} - -// setBig assigns x = n, optionally resizing n to the appropriate size. +// resetToBytes assigns x = b, where b is a slice of big-endian bytes, resizing +// n to the appropriate size. // // The announced length of x is set based on the actual bit size of the input, // ignoring leading zeroes. -func (x *Nat) setBig(n *big.Int) *Nat { - limbs := n.Bits() - x.reset(len(limbs)) - for i := range limbs { - x.limbs[i] = uint(limbs[i]) +func (x *Nat) resetToBytes(b []byte) *Nat { + x.reset((len(b) + _S - 1) / _S) + if err := x.setBytes(b); err != nil { + panic("bigmod: internal error: bad arithmetic") + } + // Trim most significant (trailing in little-endian) zero limbs. + // We assume comparison with zero (but not the branch) is constant time. + for i := len(x.limbs) - 1; i >= 0; i-- { + if x.limbs[i] != 0 { + break + } + x.limbs = x.limbs[:i] } return x } +// set assigns x = y, optionally resizing x to the appropriate size. +func (x *Nat) set(y *Nat) *Nat { + x.reset(len(y.limbs)) + copy(x.limbs, y.limbs) + return x +} + // Bytes returns x as a zero-extended big-endian byte slice. The size of the // slice will match the size of m. // @@ -140,7 +147,8 @@ func (x *Nat) Bytes(m *Modulus) []byte { // // The output will be resized to the size of m and overwritten. func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) { - if err := x.setBytes(b, m); err != nil { + x.resetFor(m) + if err := x.setBytes(b); err != nil { return nil, err } if x.cmpGeq(m.nat) == yes { @@ -155,7 +163,8 @@ func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) { // // The output will be resized to the size of m and overwritten. func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) { - if err := x.setBytes(b, m); err != nil { + x.resetFor(m) + if err := x.setBytes(b); err != nil { return nil, err } leading := _W - bitLen(x.limbs[len(x.limbs)-1]) @@ -175,8 +184,7 @@ func bigEndianUint(buf []byte) uint { return uint(byteorder.BeUint32(buf)) } -func (x *Nat) setBytes(b []byte, m *Modulus) error { - x.resetFor(m) +func (x *Nat) setBytes(b []byte) error { i, k := len(b), 0 for k < len(x.limbs) && i >= _S { x.limbs[k] = bigEndianUint(b[i-_S : i]) @@ -369,18 +377,16 @@ func minusInverseModW(x uint) uint { return -y } -// NewModulusFromBig creates a new Modulus from a [big.Int]. +// NewModulus creates a new Modulus from a slice of big-endian bytes. // -// The Int must be odd. The number of significant bits (and nothing else) is +// The value must be odd. The number of significant bits (and nothing else) is // leaked through timing side-channels. -func NewModulusFromBig(n *big.Int) (*Modulus, error) { - if b := n.Bits(); len(b) == 0 { - return nil, errors.New("modulus must be >= 0") - } else if b[0]&1 != 1 { - return nil, errors.New("modulus must be odd") +func NewModulus(b []byte) (*Modulus, error) { + if len(b) == 0 || b[len(b)-1]&1 != 1 { + return nil, errors.New("modulus must be > 0 and odd") } m := &Modulus{} - m.nat = NewNat().setBig(n) + m.nat = NewNat().resetToBytes(b) m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1]) m.m0inv = minusInverseModW(m.nat.limbs[0]) m.rr = rr(m) diff --git a/src/crypto/internal/bigmod/nat_test.go b/src/crypto/internal/bigmod/nat_test.go index 79b143ab02..2b1c22ddf0 100644 --- a/src/crypto/internal/bigmod/nat_test.go +++ b/src/crypto/internal/bigmod/nat_test.go @@ -16,6 +16,19 @@ import ( "testing/quick" ) +// setBig assigns x = n, optionally resizing n to the appropriate size. +// +// The announced length of x is set based on the actual bit size of the input, +// ignoring leading zeroes. +func (x *Nat) setBig(n *big.Int) *Nat { + limbs := n.Bits() + x.reset(len(limbs)) + for i := range limbs { + x.limbs[i] = uint(limbs[i]) + } + return x +} + func (n *Nat) String() string { var limbs []string for i := range n.limbs { @@ -71,7 +84,7 @@ func TestMontgomeryRoundtrip(t *testing.T) { one.limbs[0] = 1 aPlusOne := new(big.Int).SetBytes(natBytes(a)) aPlusOne.Add(aPlusOne, big.NewInt(1)) - m, _ := NewModulusFromBig(aPlusOne) + m, _ := NewModulus(aPlusOne.Bytes()) monty := new(Nat).set(a) monty.montgomeryRepresentation(m) aAgain := new(Nat).set(monty) @@ -320,7 +333,7 @@ func TestMulReductions(t *testing.T) { b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10) n := new(big.Int).Mul(a, b) - N, _ := NewModulusFromBig(n) + N, _ := NewModulus(n.Bytes()) A := NewNat().setBig(a).ExpandFor(N) B := NewNat().setBig(b).ExpandFor(N) @@ -329,7 +342,7 @@ func TestMulReductions(t *testing.T) { } i := new(big.Int).ModInverse(a, b) - N, _ = NewModulusFromBig(b) + N, _ = NewModulus(b.Bytes()) A = NewNat().setBig(a).ExpandFor(N) I := NewNat().setBig(i).ExpandFor(N) one := NewNat().setBig(big.NewInt(1)).ExpandFor(N) @@ -351,7 +364,7 @@ func natFromBytes(b []byte) *Nat { func modulusFromBytes(b []byte) *Modulus { bb := new(big.Int).SetBytes(b) - m, _ := NewModulusFromBig(bb) + m, _ := NewModulus(bb.Bytes()) return m } @@ -360,7 +373,7 @@ func maxModulus(n uint) *Modulus { b := big.NewInt(1) b.Lsh(b, n*_W) b.Sub(b, big.NewInt(1)) - m, _ := NewModulusFromBig(b) + m, _ := NewModulus(b.Bytes()) return m } @@ -466,17 +479,23 @@ func BenchmarkExp(b *testing.B) { } } -func TestNewModFromBigZero(t *testing.T) { - expected := "modulus must be >= 0" - _, err := NewModulusFromBig(big.NewInt(0)) +func TestNewModulus(t *testing.T) { + expected := "modulus must be > 0 and odd" + _, err := NewModulus([]byte{}) if err == nil || err.Error() != expected { - t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected) + t.Errorf("NewModulus(0) got %q, want %q", err, expected) } - - expected = "modulus must be odd" - _, err = NewModulusFromBig(big.NewInt(2)) + _, err = NewModulus([]byte{0}) + if err == nil || err.Error() != expected { + t.Errorf("NewModulus(0) got %q, want %q", err, expected) + } + _, err = NewModulus([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) + if err == nil || err.Error() != expected { + t.Errorf("NewModulus(0) got %q, want %q", err, expected) + } + _, err = NewModulus([]byte{1, 1, 1, 1, 2}) if err == nil || err.Error() != expected { - t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected) + t.Errorf("NewModulus(2) got %q, want %q", err, expected) } } diff --git a/src/crypto/rsa/rsa.go b/src/crypto/rsa/rsa.go index 4d78d1eaaa..3764e02127 100644 --- a/src/crypto/rsa/rsa.go +++ b/src/crypto/rsa/rsa.go @@ -316,15 +316,15 @@ func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (*PrivateKey return nil, errors.New("crypto/rsa: generated key exponent too large") } - mn, err := bigmod.NewModulusFromBig(N) + mn, err := bigmod.NewModulus(N.Bytes()) if err != nil { return nil, err } - mp, err := bigmod.NewModulusFromBig(P) + mp, err := bigmod.NewModulus(P.Bytes()) if err != nil { return nil, err } - mq, err := bigmod.NewModulusFromBig(Q) + mq, err := bigmod.NewModulus(Q.Bytes()) if err != nil { return nil, err } @@ -481,7 +481,7 @@ var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA key siz func encrypt(pub *PublicKey, plaintext []byte) ([]byte, error) { boring.Unreachable() - N, err := bigmod.NewModulusFromBig(pub.N) + N, err := bigmod.NewModulus(pub.N.Bytes()) if err != nil { return nil, err } @@ -584,17 +584,17 @@ func (priv *PrivateKey) Precompute() { // Precomputed values _should_ always be valid, but if they aren't // just return. We could also panic. var err error - priv.Precomputed.n, err = bigmod.NewModulusFromBig(priv.N) + priv.Precomputed.n, err = bigmod.NewModulus(priv.N.Bytes()) if err != nil { return } - priv.Precomputed.p, err = bigmod.NewModulusFromBig(priv.Primes[0]) + priv.Precomputed.p, err = bigmod.NewModulus(priv.Primes[0].Bytes()) if err != nil { // Unset previous values, so we either have everything or nothing priv.Precomputed.n = nil return } - priv.Precomputed.q, err = bigmod.NewModulusFromBig(priv.Primes[1]) + priv.Precomputed.q, err = bigmod.NewModulus(priv.Primes[1].Bytes()) if err != nil { // Unset previous values, so we either have everything or nothing priv.Precomputed.n, priv.Precomputed.p = nil, nil @@ -649,7 +649,7 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) { t0 = bigmod.NewNat() ) if priv.Precomputed.n == nil { - N, err = bigmod.NewModulusFromBig(priv.N) + N, err = bigmod.NewModulus(priv.N.Bytes()) if err != nil { return nil, ErrDecryption }