]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/internal/fips/bigmod: add support for even moduli
authorFilippo Valsorda <filippo@golang.org>
Tue, 19 Nov 2024 11:57:55 +0000 (12:57 +0100)
committerGopher Robot <gobot@golang.org>
Fri, 22 Nov 2024 01:50:35 +0000 (01:50 +0000)
It doesn't need to be fast because we will only use it for RSA key
generation / precomputation / validation.

Change-Id: If4f5d0d4ac350939b69561e75dec5791db77f68c
Reviewed-on: https://go-review.googlesource.com/c/go/+/630515
Reviewed-by: Russ Cox <rsc@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>

src/crypto/internal/fips140/bigmod/nat.go
src/crypto/internal/fips140/bigmod/nat_test.go

index 0a305b4ce6e72d36dd8f6fc16b835291315ff1d9..0a95928536380d0d853a674a73308e4db40a7706 100644 (file)
@@ -297,19 +297,20 @@ func (x *Nat) sub(y *Nat) (c uint) {
 
 // Modulus is used for modular arithmetic, precomputing relevant constants.
 //
-// Moduli are assumed to be odd numbers. Moduli can also leak the exact
-// number of bits needed to store their value, and are stored without padding.
-//
-// Their actual value is still kept secret.
+// A Modulus can leak the exact number of bits needed to store its value
+// and is stored without padding. Its actual value is still kept secret.
 type Modulus struct {
        // The underlying natural number for this modulus.
        //
        // This will be stored without any padding, and shouldn't alias with any
        // other natural number being used.
        nat     *Nat
-       leading int  // number of leading zeros in the modulus
-       m0inv   uint // -nat.limbs[0]⁻¹ mod _W
-       rr      *Nat // R*R for montgomeryRepresentation
+       leading int // number of leading zeros in the modulus
+
+       // If m is even, the following fields are not set.
+       odd   bool
+       m0inv uint // -nat.limbs[0]⁻¹ mod _W
+       rr    *Nat // R*R for montgomeryRepresentation
 }
 
 // rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
@@ -380,17 +381,20 @@ func minusInverseModW(x uint) uint {
 
 // NewModulus creates a new Modulus from a slice of big-endian bytes.
 //
-// The value must be odd. The number of significant bits (and nothing else) is
-// leaked through timing side-channels.
+// The number of significant bits and whether the modulus is even is leaked
+// through timing side-channels.
 func NewModulus(b []byte) (*Modulus, error) {
-       if len(b) == 0 || b[len(b)-1]&1 != 1 {
-               return nil, errors.New("modulus must be > 0 and odd")
-       }
        m := &Modulus{}
        m.nat = NewNat().resetToBytes(b)
+       if len(m.nat.limbs) == 0 {
+               return nil, errors.New("modulus must be > 0")
+       }
        m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
-       m.m0inv = minusInverseModW(m.nat.limbs[0])
-       m.rr = rr(m)
+       if m.nat.limbs[0]&1 == 1 {
+               m.odd = true
+               m.m0inv = minusInverseModW(m.nat.limbs[0])
+               m.rr = rr(m)
+       }
        return m, nil
 }
 
@@ -719,17 +723,71 @@ func addMulVVW(z, x []uint, y uint) (carry uint) {
 // The length of both operands must be the same as the modulus. Both operands
 // must already be reduced modulo m.
 func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
-       // A Montgomery multiplication by a value out of the Montgomery domain
-       // takes the result out of Montgomery representation.
-       xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
-       return x.montgomeryMul(xR, y, m)                  // x = xR * y / R mod m
+       if m.odd {
+               // A Montgomery multiplication by a value out of the Montgomery domain
+               // takes the result out of Montgomery representation.
+               xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
+               return x.montgomeryMul(xR, y, m)                  // x = xR * y / R mod m
+       }
+
+       n := len(m.nat.limbs)
+       xLimbs := x.limbs[:n]
+       yLimbs := y.limbs[:n]
+
+       switch n {
+       default:
+               // Attempt to use a stack-allocated backing array.
+               T := make([]uint, 0, preallocLimbs*2)
+               if cap(T) < n*2 {
+                       T = make([]uint, 0, n*2)
+               }
+               T = T[:n*2]
+
+               // T = x * y
+               for i := 0; i < n; i++ {
+                       T[n+i] = addMulVVW(T[i:n+i], xLimbs, yLimbs[i])
+               }
+
+               // x = T mod m
+               return x.Mod(&Nat{limbs: T}, m)
+
+       // The following specialized cases follow the exact same algorithm, but
+       // optimized for the sizes most used in RSA. See montgomeryMul for details.
+       case 1024 / _W:
+               const n = 1024 / _W // compiler hint
+               T := make([]uint, n*2)
+               for i := 0; i < n; i++ {
+                       T[n+i] = addMulVVW1024(&T[i], &xLimbs[0], yLimbs[i])
+               }
+               return x.Mod(&Nat{limbs: T}, m)
+       case 1536 / _W:
+               const n = 1536 / _W // compiler hint
+               T := make([]uint, n*2)
+               for i := 0; i < n; i++ {
+                       T[n+i] = addMulVVW1536(&T[i], &xLimbs[0], yLimbs[i])
+               }
+               return x.Mod(&Nat{limbs: T}, m)
+       case 2048 / _W:
+               const n = 2048 / _W // compiler hint
+               T := make([]uint, n*2)
+               for i := 0; i < n; i++ {
+                       T[n+i] = addMulVVW2048(&T[i], &xLimbs[0], yLimbs[i])
+               }
+               return x.Mod(&Nat{limbs: T}, m)
+       }
 }
 
 // Exp calculates out = x^e mod m.
 //
 // The exponent e is represented in big-endian order. The output will be resized
 // to the size of m and overwritten. x must already be reduced modulo m.
+//
+// m must be odd, or Exp will panic.
 func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
+       if !m.odd {
+               panic("bigmod: modulus for Exp must be odd")
+       }
+
        // We use a 4 bit window. For our RSA workload, 4 bit windows are faster
        // than 2 bit windows, but use an extra 12 nats worth of scratch space.
        // Using bit sizes that don't divide 8 are more complex to implement, but
@@ -778,7 +836,12 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
 //
 // The output will be resized to the size of m and overwritten. x must already
 // be reduced modulo m. This leaks the exponent through timing side-channels.
+//
+// m must be odd, or ExpShortVarTime will panic.
 func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
+       if !m.odd {
+               panic("bigmod: modulus for ExpShortVarTime must be odd")
+       }
        // For short exponents, precomputing a table and using a window like in Exp
        // doesn't pay off. Instead, we do a simple conditional square-and-multiply
        // chain, skipping the initial run of zeroes.
index 2b1c22ddf029defe20d50bc6e8c5815df31bb60a..6ee0dd48da20ddc0c338de7a6e07a561ba70e3fa 100644 (file)
@@ -5,6 +5,8 @@
 package bigmod
 
 import (
+       "bytes"
+       cryptorand "crypto/rand"
        "fmt"
        "math/big"
        "math/bits"
@@ -352,6 +354,56 @@ func TestMulReductions(t *testing.T) {
        }
 }
 
+func TestMul(t *testing.T) {
+       t.Run("small", func(t *testing.T) { testMul(t, 760/8) })
+       t.Run("1024", func(t *testing.T) { testMul(t, 1024/8) })
+       t.Run("1536", func(t *testing.T) { testMul(t, 1536/8) })
+       t.Run("2048", func(t *testing.T) { testMul(t, 2048/8) })
+}
+
+func testMul(t *testing.T, n int) {
+       a, b, m := make([]byte, n), make([]byte, n), make([]byte, n)
+       cryptorand.Read(a)
+       cryptorand.Read(b)
+       cryptorand.Read(m)
+
+       // Pick the highest as the modulus.
+       if bytes.Compare(a, m) > 0 {
+               a, m = m, a
+       }
+       if bytes.Compare(b, m) > 0 {
+               b, m = m, b
+       }
+
+       M, err := NewModulus(m)
+       if err != nil {
+               t.Fatal(err)
+       }
+       A, err := NewNat().SetBytes(a, M)
+       if err != nil {
+               t.Fatal(err)
+       }
+       B, err := NewNat().SetBytes(b, M)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       A.Mul(B, M)
+       ABytes := A.Bytes(M)
+
+       mBig := new(big.Int).SetBytes(m)
+       aBig := new(big.Int).SetBytes(a)
+       bBig := new(big.Int).SetBytes(b)
+       nBig := new(big.Int).Mul(aBig, bBig)
+       nBig.Mod(nBig, mBig)
+       nBigBytes := make([]byte, len(ABytes))
+       nBig.FillBytes(nBigBytes)
+
+       if !bytes.Equal(ABytes, nBigBytes) {
+               t.Errorf("got %x, want %x", ABytes, nBigBytes)
+       }
+}
+
 func natBytes(n *Nat) []byte {
        return n.Bytes(maxModulus(uint(len(n.limbs))))
 }
@@ -480,7 +532,7 @@ func BenchmarkExp(b *testing.B) {
 }
 
 func TestNewModulus(t *testing.T) {
-       expected := "modulus must be > 0 and odd"
+       expected := "modulus must be > 0"
        _, err := NewModulus([]byte{})
        if err == nil || err.Error() != expected {
                t.Errorf("NewModulus(0) got %q, want %q", err, expected)
@@ -493,10 +545,6 @@ func TestNewModulus(t *testing.T) {
        if err == nil || err.Error() != expected {
                t.Errorf("NewModulus(0) got %q, want %q", err, expected)
        }
-       _, err = NewModulus([]byte{1, 1, 1, 1, 2})
-       if err == nil || err.Error() != expected {
-               t.Errorf("NewModulus(2) got %q, want %q", err, expected)
-       }
 }
 
 func makeTestValue(nbits int) []uint {