natTen = nat{10}
)
+func (z nat) String() string {
+ return "0x" + string(z.itoa(false, 16))
+}
+
func (z nat) clear() {
for i := range z {
z[i] = 0
z = new(nat)
}
*z = z.make(n)
+ if n > 0 {
+ (*z)[0] = 0xfedcb // break code expecting zero
+ }
return z
}
var natPool sync.Pool
-// Length of x in bits. x must be normalized.
+// bitLen returns the length of x in bits.
+// Unlike most methods, it works even if x is not normalized.
func (x nat) bitLen() int {
if i := len(x) - 1; i >= 0 {
return i*_W + bits.Len(uint(x[i]))
return i*_W + uint(bits.TrailingZeros(uint(x[i])))
}
+// isPow2 returns i, true when x == 2**i and 0, false otherwise.
+func (x nat) isPow2() (uint, bool) {
+ var i uint
+ for x[i] == 0 {
+ i++
+ }
+ if i == uint(len(x))-1 && x[i]&(x[i]-1) == 0 {
+ return i*_W + uint(bits.TrailingZeros(uint(x[i]))), true
+ }
+ return 0, false
+}
+
func same(x, y nat) bool {
return len(x) == len(y) && len(x) > 0 && &x[0] == &y[0]
}
return z.norm()
}
+// trunc returns z = x mod 2ⁿ.
+func (z nat) trunc(x nat, n uint) nat {
+ w := (n + _W - 1) / _W
+ if uint(len(x)) < w {
+ return z.set(x)
+ }
+ z = z.make(int(w))
+ copy(z, x)
+ if n%_W != 0 {
+ z[len(z)-1] &= 1<<(n%_W) - 1
+ }
+ return z.norm()
+}
+
func (z nat) andNot(x, y nat) nat {
m := len(x)
n := len(y)
// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
// otherwise it sets z to x**y. The result is the value of z.
-func (z nat) expNN(x, y, m nat) nat {
+func (z nat) expNN(x, y, m nat, slow bool) nat {
if alias(z, x) || alias(z, y) {
// We cannot allow in-place modification of x or y.
z = nil
}
// y > 0
- // x**1 mod m == x mod m
- if len(y) == 1 && y[0] == 1 && len(m) != 0 {
- _, z = nat(nil).div(z, x, m)
- return z
+ // 0**y = 0
+ if len(x) == 0 {
+ return z.setWord(0)
+ }
+ // x > 0
+
+ // 1**y = 1
+ if len(x) == 1 && x[0] == 1 {
+ return z.setWord(1)
+ }
+ // x > 1
+
+ // x**1 == x
+ if len(y) == 1 && y[0] == 1 {
+ if len(m) != 0 {
+ return z.rem(x, m)
+ }
+ return z.set(x)
}
// y > 1
if len(m) != 0 {
// We likely end up being as long as the modulus.
z = z.make(len(m))
- }
- z = z.set(x)
- // If the base is non-trivial and the exponent is large, we use
- // 4-bit, windowed exponentiation. This involves precomputing 14 values
- // (x^2...x^15) but then reduces the number of multiply-reduces by a
- // third. Even for a 32-bit exponent, this reduces the number of
- // operations. Uses Montgomery method for odd moduli.
- if x.cmp(natOne) > 0 && len(y) > 1 && len(m) > 0 {
- if m[0]&1 == 1 {
- return z.expNNMontgomery(x, y, m)
+ // If the exponent is large, we use the Montgomery method for odd values,
+ // and a 4-bit, windowed exponentiation for powers of two,
+ // and a CRT-decomposed Montgomery method for the remaining values
+ // (even values times non-trivial odd values, which decompose into one
+ // instance of each of the first two cases).
+ if len(y) > 1 && !slow {
+ if m[0]&1 == 1 {
+ return z.expNNMontgomery(x, y, m)
+ }
+ if logM, ok := m.isPow2(); ok {
+ return z.expNNWindowed(x, y, logM)
+ }
+ return z.expNNMontgomeryEven(x, y, m)
}
- return z.expNNWindowed(x, y, m)
}
+ z = z.set(x)
v := y[len(y)-1] // v > 0 because y is normalized and y > 0
shift := nlz(v) + 1
v <<= shift
return z.norm()
}
-// expNNWindowed calculates x**y mod m using a fixed, 4-bit window.
-func (z nat) expNNWindowed(x, y, m nat) nat {
- // zz and r are used to avoid allocating in mul and div as otherwise
+// expNNMontgomeryEven calculates x**y mod m where m = m1 × m2 for m1 = 2ⁿ and m2 odd.
+// It uses two recursive calls to expNN for x**y mod m1 and x**y mod m2
+// and then uses the Chinese Remainder Theorem to combine the results.
+// The recursive call using m1 will use expNNWindowed,
+// while the recursive call using m2 will use expNNMontgomery.
+// For more details, see Ç. K. Koç, “Montgomery Reduction with Even Modulus”,
+// IEE Proceedings: Computers and Digital Techniques, 141(5) 314-316, September 1994.
+// http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf
+func (z nat) expNNMontgomeryEven(x, y, m nat) nat {
+ // Split m = m₁ × m₂ where m₁ = 2ⁿ
+ n := m.trailingZeroBits()
+ m1 := nat(nil).shl(natOne, n)
+ m2 := nat(nil).shr(m, n)
+
+ // We want z = x**y mod m.
+ // z₁ = x**y mod m1 = (x**y mod m) mod m1 = z mod m1
+ // z₂ = x**y mod m2 = (x**y mod m) mod m2 = z mod m2
+ // (We are using the math/big convention for names here,
+ // where the computation is z = x**y mod m, so its parts are z1 and z2.
+ // The paper is computing x = a**e mod n; it refers to these as x2 and z1.)
+ z1 := nat(nil).expNN(x, y, m1, false)
+ z2 := nat(nil).expNN(x, y, m2, false)
+
+ // Reconstruct z from z₁, z₂ using CRT, using algorithm from paper,
+ // which uses only a single modInverse (and an easy one at that).
+ // p = (z₁ - z₂) × m₂⁻¹ (mod m₁)
+ // z = z₂ + p × m₂
+ // The final addition is in range because:
+ // z = z₂ + p × m₂
+ // ≤ z₂ + (m₁-1) × m₂
+ // < m₂ + (m₁-1) × m₂
+ // = m₁ × m₂
+ // = m.
+ z = z.set(z2)
+
+ // Compute (z₁ - z₂) mod m1 [m1 == 2**n] into z1.
+ z1 = z1.subMod2N(z1, z2, n)
+
+ // Reuse z2 for p = (z₁ - z₂) [in z1] * m2⁻¹ (mod m₁ [= 2ⁿ]).
+ m2inv := nat(nil).modInverse(m2, m1)
+ z2 = z2.mul(z1, m2inv)
+ z2 = z2.trunc(z2, n)
+
+ // Reuse z1 for p * m2.
+ z = z.add(z, z1.mul(z2, m2))
+
+ return z
+}
+
+// expNNWindowed calculates x**y mod m using a fixed, 4-bit window,
+// where m = 2**logM.
+func (z nat) expNNWindowed(x, y nat, logM uint) nat {
+ if len(y) <= 1 {
+ panic("big: misuse of expNNWindowed")
+ }
+ if x[0]&1 == 0 {
+ // len(y) > 1, so y > logM.
+ // x is even, so x**y is a multiple of 2**y which is a multiple of 2**logM.
+ return z.setWord(0)
+ }
+ if logM == 1 {
+ return z.setWord(1)
+ }
+
+ // zz is used to avoid allocating in mul as otherwise
// the arguments would alias.
- var zz, r nat
+ w := int((logM + _W - 1) / _W)
+ zzp := getNat(w)
+ zz := *zzp
const n = 4
// powers[i] contains x^i.
- var powers [1 << n]nat
- powers[0] = natOne
- powers[1] = x
+ var powers [1 << n]*nat
+ for i := range powers {
+ powers[i] = getNat(w)
+ }
+ *powers[0] = powers[0].set(natOne)
+ *powers[1] = powers[1].trunc(x, logM)
for i := 2; i < 1<<n; i += 2 {
- p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
+ p2, p, p1 := powers[i/2], powers[i], powers[i+1]
*p = p.sqr(*p2)
- zz, r = zz.div(r, *p, m)
- *p, r = r, *p
+ *p = p.trunc(*p, logM)
*p1 = p1.mul(*p, x)
- zz, r = zz.div(r, *p1, m)
- *p1, r = r, *p1
+ *p1 = p1.trunc(*p1, logM)
}
+ // Because phi(2**logM) = 2**(logM-1), x**(2**(logM-1)) = 1,
+ // so we can compute x**(y mod 2**(logM-1)) instead of x**y.
+ // That is, we can throw away all but the bottom logM-1 bits of y.
+ // Instead of allocating a new y, we start reading y at the right word
+ // and truncate it appropriately at the start of the loop.
+ i := len(y) - 1
+ mtop := int((logM - 2) / _W) // -2 because the top word of N bits is the (N-1)/W'th word.
+ mmask := ^Word(0)
+ if mbits := (logM - 1) & (_W - 1); mbits != 0 {
+ mmask = (1 << mbits) - 1
+ }
+ if i > mtop {
+ i = mtop
+ }
+ advance := false
z = z.setWord(1)
-
- for i := len(y) - 1; i >= 0; i-- {
+ for ; i >= 0; i-- {
yi := y[i]
+ if i == mtop {
+ yi &= mmask
+ }
for j := 0; j < _W; j += n {
- if i != len(y)-1 || j != 0 {
+ if advance {
+ // Account for use of 4 bits in previous iteration.
// Unrolled loop for significant performance
// gain. Use go test -bench=".*" in crypto/rsa
// to check performance before making changes.
zz = zz.sqr(z)
zz, z = z, zz
- zz, r = zz.div(r, z, m)
- z, r = r, z
+ z = z.trunc(z, logM)
zz = zz.sqr(z)
zz, z = z, zz
- zz, r = zz.div(r, z, m)
- z, r = r, z
+ z = z.trunc(z, logM)
zz = zz.sqr(z)
zz, z = z, zz
- zz, r = zz.div(r, z, m)
- z, r = r, z
+ z = z.trunc(z, logM)
zz = zz.sqr(z)
zz, z = z, zz
- zz, r = zz.div(r, z, m)
- z, r = r, z
+ z = z.trunc(z, logM)
}
- zz = zz.mul(z, powers[yi>>(_W-n)])
+ zz = zz.mul(z, *powers[yi>>(_W-n)])
zz, z = z, zz
- zz, r = zz.div(r, z, m)
- z, r = r, z
+ z = z.trunc(z, logM)
yi <<= n
+ advance = true
}
}
+ *zzp = zz
+ putNat(zzp)
+ for i := range powers {
+ putNat(powers[i])
+ }
+
return z.norm()
}
z1, z2 = z2, z1
}
}
+
+// subMod2N returns z = (x - y) mod 2ⁿ.
+func (z nat) subMod2N(x, y nat, n uint) nat {
+ if uint(x.bitLen()) > n {
+ if alias(z, x) {
+ // ok to overwrite x in place
+ x = x.trunc(x, n)
+ } else {
+ x = nat(nil).trunc(x, n)
+ }
+ }
+ if uint(y.bitLen()) > n {
+ if alias(z, y) {
+ // ok to overwrite y in place
+ y = y.trunc(y, n)
+ } else {
+ y = nat(nil).trunc(y, n)
+ }
+ }
+ if x.cmp(y) >= 0 {
+ return z.sub(x, y)
+ }
+ // x - y < 0; x - y mod 2ⁿ = x - y + 2ⁿ = 2ⁿ - (y - x) = 1 + 2ⁿ-1 - (y - x) = 1 + ^(y - x).
+ z = z.sub(y, x)
+ for uint(len(z))*_W < n {
+ z = append(z, 0)
+ }
+ for i := range z {
+ z[i] = ^z[i]
+ }
+ z = z.trunc(z, n)
+ return z.add(z, natOne)
+}
"444747819283133684179",
"42",
},
+ {"375", "249", "388", "175"},
+ {"375", "18446744073709551801", "388", "175"},
+ {"0", "0x40000000000000", "0x200", "0"},
+ {"0xeffffff900002f00", "0x40000000000000", "0x200", "0"},
+ {"5", "1435700818", "72", "49"},
+ {"0xffff", "0x300030003000300030003000300030003000302a3000300030003000300030003000300030003000300030003000300030003030623066307f3030783062303430383064303630343036", "0x300000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", "0xa3f94c08b0b90e87af637cacc9383f7ea032352b8961fc036a52b659b6c9b33491b335ffd74c927f64ddd62cfca0001"},
}
func TestExpNN(t *testing.T) {
m = natFromString(test.m)
}
- z := nat(nil).expNN(x, y, m)
+ z := nat(nil).expNN(x, y, m, false)
if z.cmp(out) != 0 {
t.Errorf("#%d got %s want %s", i, z.utoa(10), out.utoa(10))
}
}
}
+func FuzzExpMont(f *testing.F) {
+ f.Fuzz(func(t *testing.T, x1, x2, x3, y1, y2, y3, m1, m2, m3 uint) {
+ if m1 == 0 && m2 == 0 && m3 == 0 {
+ return
+ }
+ x := new(Int).SetBits([]Word{Word(x1), Word(x2), Word(x3)})
+ y := new(Int).SetBits([]Word{Word(y1), Word(y2), Word(y3)})
+ m := new(Int).SetBits([]Word{Word(m1), Word(m2), Word(m3)})
+ out := new(Int).Exp(x, y, m)
+ want := new(Int).expSlow(x, y, m)
+ if out.Cmp(want) != 0 {
+ t.Errorf("x = %#x\ny=%#x\nz=%#x\nout=%#x\nwant=%#x\ndc: 16o 16i %X %X %X |p", x, y, m, out, want, x, y, m)
+ }
+ })
+}
+
func BenchmarkExp3Power(b *testing.B) {
const x = 3
for _, y := range []Word{
}
}
+var subMod2NTests = []struct {
+ x string
+ y string
+ n uint
+ z string
+}{
+ {"1", "2", 0, "0"},
+ {"1", "0", 1, "1"},
+ {"0", "1", 1, "1"},
+ {"3", "5", 3, "6"},
+ {"5", "3", 3, "2"},
+ // 2^65, 2^66-1, 2^65 - (2^66-1) + 2^67
+ {"36893488147419103232", "73786976294838206463", 67, "110680464442257309697"},
+ // 2^66-1, 2^65, 2^65-1
+ {"73786976294838206463", "36893488147419103232", 67, "36893488147419103231"},
+}
+
+func TestNatSubMod2N(t *testing.T) {
+ for _, mode := range []string{"noalias", "aliasX", "aliasY"} {
+ t.Run(mode, func(t *testing.T) {
+ for _, tt := range subMod2NTests {
+ x0 := natFromString(tt.x)
+ y0 := natFromString(tt.y)
+ want := natFromString(tt.z)
+ x := nat(nil).set(x0)
+ y := nat(nil).set(y0)
+ var z nat
+ switch mode {
+ case "aliasX":
+ z = x
+ case "aliasY":
+ z = y
+ }
+ z = z.subMod2N(x, y, tt.n)
+ if z.cmp(want) != 0 {
+ t.Fatalf("subMod2N(%d, %d, %d) = %d, want %d", x0, y0, tt.n, z, want)
+ }
+ if mode != "aliasX" && x.cmp(x0) != 0 {
+ t.Fatalf("subMod2N(%d, %d, %d) modified x", x0, y0, tt.n)
+ }
+ if mode != "aliasY" && y.cmp(y0) != 0 {
+ t.Fatalf("subMod2N(%d, %d, %d) modified y", x0, y0, tt.n)
+ }
+ }
+ })
+ }
+}
+
func BenchmarkNatSetBytes(b *testing.B) {
const maxLength = 128
lengths := []int{