// Modulus is used for modular arithmetic, precomputing relevant constants.
//
-// Moduli are assumed to be odd numbers. Moduli can also leak the exact
-// number of bits needed to store their value, and are stored without padding.
-//
-// Their actual value is still kept secret.
+// A Modulus can leak the exact number of bits needed to store its value
+// and is stored without padding. Its actual value is still kept secret.
type Modulus struct {
// The underlying natural number for this modulus.
//
// This will be stored without any padding, and shouldn't alias with any
// other natural number being used.
nat *Nat
- leading int // number of leading zeros in the modulus
- m0inv uint // -nat.limbs[0]⁻¹ mod _W
- rr *Nat // R*R for montgomeryRepresentation
+ leading int // number of leading zeros in the modulus
+
+ // If m is even, the following fields are not set.
+ odd bool
+ m0inv uint // -nat.limbs[0]⁻¹ mod _W
+ rr *Nat // R*R for montgomeryRepresentation
}
// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
// NewModulus creates a new Modulus from a slice of big-endian bytes.
//
-// The value must be odd. The number of significant bits (and nothing else) is
-// leaked through timing side-channels.
+// The number of significant bits and whether the modulus is even is leaked
+// through timing side-channels.
func NewModulus(b []byte) (*Modulus, error) {
- if len(b) == 0 || b[len(b)-1]&1 != 1 {
- return nil, errors.New("modulus must be > 0 and odd")
- }
m := &Modulus{}
m.nat = NewNat().resetToBytes(b)
+ if len(m.nat.limbs) == 0 {
+ return nil, errors.New("modulus must be > 0")
+ }
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
- m.m0inv = minusInverseModW(m.nat.limbs[0])
- m.rr = rr(m)
+ if m.nat.limbs[0]&1 == 1 {
+ m.odd = true
+ m.m0inv = minusInverseModW(m.nat.limbs[0])
+ m.rr = rr(m)
+ }
return m, nil
}
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
- // A Montgomery multiplication by a value out of the Montgomery domain
- // takes the result out of Montgomery representation.
- xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
- return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
+ if m.odd {
+ // A Montgomery multiplication by a value out of the Montgomery domain
+ // takes the result out of Montgomery representation.
+ xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
+ return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
+ }
+
+ n := len(m.nat.limbs)
+ xLimbs := x.limbs[:n]
+ yLimbs := y.limbs[:n]
+
+ switch n {
+ default:
+ // Attempt to use a stack-allocated backing array.
+ T := make([]uint, 0, preallocLimbs*2)
+ if cap(T) < n*2 {
+ T = make([]uint, 0, n*2)
+ }
+ T = T[:n*2]
+
+ // T = x * y
+ for i := 0; i < n; i++ {
+ T[n+i] = addMulVVW(T[i:n+i], xLimbs, yLimbs[i])
+ }
+
+ // x = T mod m
+ return x.Mod(&Nat{limbs: T}, m)
+
+ // The following specialized cases follow the exact same algorithm, but
+ // optimized for the sizes most used in RSA. See montgomeryMul for details.
+ case 1024 / _W:
+ const n = 1024 / _W // compiler hint
+ T := make([]uint, n*2)
+ for i := 0; i < n; i++ {
+ T[n+i] = addMulVVW1024(&T[i], &xLimbs[0], yLimbs[i])
+ }
+ return x.Mod(&Nat{limbs: T}, m)
+ case 1536 / _W:
+ const n = 1536 / _W // compiler hint
+ T := make([]uint, n*2)
+ for i := 0; i < n; i++ {
+ T[n+i] = addMulVVW1536(&T[i], &xLimbs[0], yLimbs[i])
+ }
+ return x.Mod(&Nat{limbs: T}, m)
+ case 2048 / _W:
+ const n = 2048 / _W // compiler hint
+ T := make([]uint, n*2)
+ for i := 0; i < n; i++ {
+ T[n+i] = addMulVVW2048(&T[i], &xLimbs[0], yLimbs[i])
+ }
+ return x.Mod(&Nat{limbs: T}, m)
+ }
}
// Exp calculates out = x^e mod m.
//
// The exponent e is represented in big-endian order. The output will be resized
// to the size of m and overwritten. x must already be reduced modulo m.
+//
+// m must be odd, or Exp will panic.
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
+ if !m.odd {
+ panic("bigmod: modulus for Exp must be odd")
+ }
+
// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
// than 2 bit windows, but use an extra 12 nats worth of scratch space.
// Using bit sizes that don't divide 8 are more complex to implement, but
//
// The output will be resized to the size of m and overwritten. x must already
// be reduced modulo m. This leaks the exponent through timing side-channels.
+//
+// m must be odd, or ExpShortVarTime will panic.
func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
+ if !m.odd {
+ panic("bigmod: modulus for ExpShortVarTime must be odd")
+ }
// For short exponents, precomputing a table and using a window like in Exp
// doesn't pay off. Instead, we do a simple conditional square-and-multiply
// chain, skipping the initial run of zeroes.
package bigmod
import (
+ "bytes"
+ cryptorand "crypto/rand"
"fmt"
"math/big"
"math/bits"
}
}
+func TestMul(t *testing.T) {
+ t.Run("small", func(t *testing.T) { testMul(t, 760/8) })
+ t.Run("1024", func(t *testing.T) { testMul(t, 1024/8) })
+ t.Run("1536", func(t *testing.T) { testMul(t, 1536/8) })
+ t.Run("2048", func(t *testing.T) { testMul(t, 2048/8) })
+}
+
+func testMul(t *testing.T, n int) {
+ a, b, m := make([]byte, n), make([]byte, n), make([]byte, n)
+ cryptorand.Read(a)
+ cryptorand.Read(b)
+ cryptorand.Read(m)
+
+ // Pick the highest as the modulus.
+ if bytes.Compare(a, m) > 0 {
+ a, m = m, a
+ }
+ if bytes.Compare(b, m) > 0 {
+ b, m = m, b
+ }
+
+ M, err := NewModulus(m)
+ if err != nil {
+ t.Fatal(err)
+ }
+ A, err := NewNat().SetBytes(a, M)
+ if err != nil {
+ t.Fatal(err)
+ }
+ B, err := NewNat().SetBytes(b, M)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ A.Mul(B, M)
+ ABytes := A.Bytes(M)
+
+ mBig := new(big.Int).SetBytes(m)
+ aBig := new(big.Int).SetBytes(a)
+ bBig := new(big.Int).SetBytes(b)
+ nBig := new(big.Int).Mul(aBig, bBig)
+ nBig.Mod(nBig, mBig)
+ nBigBytes := make([]byte, len(ABytes))
+ nBig.FillBytes(nBigBytes)
+
+ if !bytes.Equal(ABytes, nBigBytes) {
+ t.Errorf("got %x, want %x", ABytes, nBigBytes)
+ }
+}
+
func natBytes(n *Nat) []byte {
return n.Bytes(maxModulus(uint(len(n.limbs))))
}
}
func TestNewModulus(t *testing.T) {
- expected := "modulus must be > 0 and odd"
+ expected := "modulus must be > 0"
_, err := NewModulus([]byte{})
if err == nil || err.Error() != expected {
t.Errorf("NewModulus(0) got %q, want %q", err, expected)
if err == nil || err.Error() != expected {
t.Errorf("NewModulus(0) got %q, want %q", err, expected)
}
- _, err = NewModulus([]byte{1, 1, 1, 1, 2})
- if err == nil || err.Error() != expected {
- t.Errorf("NewModulus(2) got %q, want %q", err, expected)
- }
}
func makeTestValue(nbits int) []uint {