]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: use Montgomery for z.Exp(x, y, m) even for even m
authorRuss Cox <rsc@golang.org>
Wed, 27 Jul 2022 04:38:37 +0000 (00:38 -0400)
committerGopher Robot <gobot@golang.org>
Wed, 2 Nov 2022 14:43:52 +0000 (14:43 +0000)
Montgomery multiplication can be used for Exp mod even m
by splitting it into two steps - Exp mod an odd number and
Exp mod a power of two - and then combining the two results
using the Chinese Remainder Theorem.

For more details, see Ç. K. Koç, “Montgomery Reduction with Even Modulus”,
IEE Proceedings: Computers and Digital Techniques, 141(5) 314-316, September 1994.
http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf

Thanks to Guido Vranken for suggesting that we use a faster algorithm.

name                   old time/op    new time/op    delta
ExpMont/Odd-16            240µs ± 2%     239µs ± 2%      ~     (p=0.853 n=10+10)
ExpMont/Even1-16          757µs ± 3%     249µs ± 6%   -67.17%  (p=0.000 n=10+10)
ExpMont/Even2-16          755µs ± 1%     244µs ± 4%   -67.64%  (p=0.000 n=8+10)
ExpMont/Even3-16          771µs ± 3%     240µs ± 2%   -68.90%  (p=0.000 n=10+10)
ExpMont/Even4-16          775µs ± 3%     241µs ± 2%   -68.91%  (p=0.000 n=10+10)
ExpMont/Even8-16          779µs ± 2%     241µs ± 3%   -69.06%  (p=0.000 n=9+10)
ExpMont/Even32-16         778µs ± 3%     240µs ± 4%   -69.11%  (p=0.000 n=9+9)
ExpMont/Even64-16         774µs ± 6%     186µs ± 2%   -76.00%  (p=0.000 n=10+10)
ExpMont/Even96-16         776µs ± 4%     186µs ± 6%   -76.09%  (p=0.000 n=9+10)
ExpMont/Even128-16        764µs ± 2%     143µs ± 3%   -81.24%  (p=0.000 n=10+10)
ExpMont/Even255-16        761µs ± 3%     109µs ± 2%   -85.73%  (p=0.000 n=10+10)
ExpMont/SmallEven1-16    45.6µs ± 1%    36.3µs ± 3%   -20.49%  (p=0.000 n=9+10)
ExpMont/SmallEven2-16    44.3µs ± 2%    37.5µs ± 2%   -15.26%  (p=0.000 n=9+10)
ExpMont/SmallEven3-16    44.1µs ± 5%    37.3µs ± 3%   -15.32%  (p=0.000 n=9+10)
ExpMont/SmallEven4-16    47.1µs ± 6%    38.0µs ± 5%   -19.40%  (p=0.000 n=10+9)

name                   old alloc/op   new alloc/op   delta
ExpMont/Odd-16           2.53kB ± 0%    2.53kB ± 0%      ~     (p=0.137 n=8+10)
ExpMont/Even1-16         2.57kB ± 0%    3.31kB ± 0%   +28.90%  (p=0.000 n=8+10)
ExpMont/Even2-16         2.57kB ± 0%    3.37kB ± 0%   +31.09%  (p=0.000 n=9+10)
ExpMont/Even3-16         2.57kB ± 0%    3.37kB ± 0%   +31.08%  (p=0.000 n=8+8)
ExpMont/Even4-16         2.57kB ± 0%    3.37kB ± 0%   +31.09%  (p=0.000 n=9+10)
ExpMont/Even8-16         2.57kB ± 0%    3.37kB ± 0%   +31.09%  (p=0.000 n=9+10)
ExpMont/Even32-16        2.57kB ± 0%    3.37kB ± 0%   +31.14%  (p=0.000 n=10+10)
ExpMont/Even64-16        2.57kB ± 0%    3.16kB ± 0%   +22.99%  (p=0.000 n=9+9)
ExpMont/Even96-16        2.57kB ± 0%    3.44kB ± 0%   +33.90%  (p=0.000 n=10+8)
ExpMont/Even128-16       2.57kB ± 0%    2.88kB ± 0%   +12.10%  (p=0.000 n=10+10)
ExpMont/Even255-16       2.57kB ± 0%    2.54kB ± 0%    -1.30%  (p=0.000 n=9+10)
ExpMont/SmallEven1-16      872B ± 0%     1232B ± 0%   +41.28%  (p=0.000 n=10+10)
ExpMont/SmallEven2-16      872B ± 0%     1233B ± 0%   +41.40%  (p=0.000 n=10+9)
ExpMont/SmallEven3-16      872B ± 0%     1289B ± 0%   +47.82%  (p=0.000 n=10+10)
ExpMont/SmallEven4-16      872B ± 0%     1241B ± 0%   +42.32%  (p=0.000 n=10+10)

name                   old allocs/op  new allocs/op  delta
ExpMont/Odd-16             21.0 ± 0%      21.0 ± 0%      ~     (all equal)
ExpMont/Even1-16           24.0 ± 0%      38.0 ± 0%   +58.33%  (p=0.000 n=10+10)
ExpMont/Even2-16           24.0 ± 0%      40.0 ± 0%   +66.67%  (p=0.000 n=10+10)
ExpMont/Even3-16           24.0 ± 0%      40.0 ± 0%   +66.67%  (p=0.000 n=10+10)
ExpMont/Even4-16           24.0 ± 0%      40.0 ± 0%   +66.67%  (p=0.000 n=10+10)
ExpMont/Even8-16           24.0 ± 0%      40.0 ± 0%   +66.67%  (p=0.000 n=10+10)
ExpMont/Even32-16          24.0 ± 0%      40.0 ± 0%   +66.67%  (p=0.000 n=10+10)
ExpMont/Even64-16          24.0 ± 0%      39.0 ± 0%   +62.50%  (p=0.000 n=10+10)
ExpMont/Even96-16          24.0 ± 0%      42.0 ± 0%   +75.00%  (p=0.000 n=10+10)
ExpMont/Even128-16         24.0 ± 0%      40.0 ± 0%   +66.67%  (p=0.000 n=10+10)
ExpMont/Even255-16         24.0 ± 0%      38.0 ± 0%   +58.33%  (p=0.000 n=10+10)
ExpMont/SmallEven1-16      16.0 ± 0%      35.0 ± 0%  +118.75%  (p=0.000 n=10+10)
ExpMont/SmallEven2-16      16.0 ± 0%      35.0 ± 0%  +118.75%  (p=0.000 n=10+10)
ExpMont/SmallEven3-16      16.0 ± 0%      37.0 ± 0%  +131.25%  (p=0.000 n=10+10)
ExpMont/SmallEven4-16      16.0 ± 0%      36.0 ± 0%  +125.00%  (p=0.000 n=10+10)

Change-Id: Ib7f70290f8f087b78805ec3120baf17dd7737ac3
Reviewed-on: https://go-review.googlesource.com/c/go/+/420897
Run-TryBot: Russ Cox <rsc@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Auto-Submit: Russ Cox <rsc@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>

src/math/big/arith_test.go
src/math/big/int.go
src/math/big/nat.go
src/math/big/nat_test.go
src/math/big/natconv.go
src/math/big/natdiv.go
src/math/big/prime.go
src/math/big/ratconv.go

index e530dd9750601bee062da35c2d2287a28d0eceb3..64225bbd53661ce9ad1fd034c5d135bf6c4caafa 100644 (file)
@@ -370,7 +370,7 @@ func TestShiftOverlap(t *testing.T) {
 func TestIssue31084(t *testing.T) {
        // compute 10^n via 5^n << n.
        const n = 165
-       p := nat(nil).expNN(nat{5}, nat{n}, nil)
+       p := nat(nil).expNN(nat{5}, nat{n}, nil, false)
        p = p.shl(p, n)
        got := string(p.utoa(10))
        want := "1" + strings.Repeat("0", n)
index a26fdbb90e497efa5a85b1bbaf702a067388c9a7..411a56966b9f0634d7cad04aed0171a3bc93d25c 100644 (file)
@@ -523,6 +523,14 @@ func (x *Int) TrailingZeroBits() uint {
 // Modular exponentiation of inputs of a particular size is not a
 // cryptographically constant-time operation.
 func (z *Int) Exp(x, y, m *Int) *Int {
+       return z.exp(x, y, m, false)
+}
+
+func (z *Int) expSlow(x, y, m *Int) *Int {
+       return z.exp(x, y, m, true)
+}
+
+func (z *Int) exp(x, y, m *Int, slow bool) *Int {
        // See Knuth, volume 2, section 4.6.3.
        xWords := x.abs
        if y.neg {
@@ -546,7 +554,7 @@ func (z *Int) Exp(x, y, m *Int) *Int {
                mWords = m.abs // m.abs may be nil for m == 0
        }
 
-       z.abs = z.abs.expNN(xWords, yWords, mWords)
+       z.abs = z.abs.expNN(xWords, yWords, mWords, slow)
        z.neg = len(z.abs) > 0 && x.neg && len(yWords) > 0 && yWords[0]&1 == 1 // 0 has no sign
        if z.neg && len(mWords) > 0 {
                // make modulus result positive
@@ -879,6 +887,11 @@ func (z *Int) ModInverse(g, n *Int) *Int {
        return z
 }
 
+func (z nat) modInverse(g, n nat) nat {
+       // TODO(rsc): ModInverse should be implemented in terms of this function.
+       return (&Int{abs: z}).ModInverse(&Int{abs: g}, &Int{abs: n}).abs
+}
+
 // Jacobi returns the Jacobi symbol (x/y), either +1, -1, or 0.
 // The y argument must be an odd integer.
 func Jacobi(x, y *Int) int {
index 5cc42b80dc772addae7d6a6ba9e0b7445d505487..a7f4dc6999addccbf70b5152238c7e925137a5d4 100644 (file)
@@ -40,6 +40,10 @@ var (
        natTen  = nat{10}
 )
 
+func (z nat) String() string {
+       return "0x" + string(z.itoa(false, 16))
+}
+
 func (z nat) clear() {
        for i := range z {
                z[i] = 0
@@ -642,6 +646,9 @@ func getNat(n int) *nat {
                z = new(nat)
        }
        *z = z.make(n)
+       if n > 0 {
+               (*z)[0] = 0xfedcb // break code expecting zero
+       }
        return z
 }
 
@@ -651,7 +658,8 @@ func putNat(x *nat) {
 
 var natPool sync.Pool
 
-// Length of x in bits. x must be normalized.
+// bitLen returns the length of x in bits.
+// Unlike most methods, it works even if x is not normalized.
 func (x nat) bitLen() int {
        if i := len(x) - 1; i >= 0 {
                return i*_W + bits.Len(uint(x[i]))
@@ -673,6 +681,18 @@ func (x nat) trailingZeroBits() uint {
        return i*_W + uint(bits.TrailingZeros(uint(x[i])))
 }
 
+// isPow2 returns i, true when x == 2**i and 0, false otherwise.
+func (x nat) isPow2() (uint, bool) {
+       var i uint
+       for x[i] == 0 {
+               i++
+       }
+       if i == uint(len(x))-1 && x[i]&(x[i]-1) == 0 {
+               return i*_W + uint(bits.TrailingZeros(uint(x[i]))), true
+       }
+       return 0, false
+}
+
 func same(x, y nat) bool {
        return len(x) == len(y) && len(x) > 0 && &x[0] == &y[0]
 }
@@ -803,6 +823,20 @@ func (z nat) and(x, y nat) nat {
        return z.norm()
 }
 
+// trunc returns z = x mod 2ⁿ.
+func (z nat) trunc(x nat, n uint) nat {
+       w := (n + _W - 1) / _W
+       if uint(len(x)) < w {
+               return z.set(x)
+       }
+       z = z.make(int(w))
+       copy(z, x)
+       if n%_W != 0 {
+               z[len(z)-1] &= 1<<(n%_W) - 1
+       }
+       return z.norm()
+}
+
 func (z nat) andNot(x, y nat) nat {
        m := len(x)
        n := len(y)
@@ -896,7 +930,7 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
 
 // If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
 // otherwise it sets z to x**y. The result is the value of z.
-func (z nat) expNN(x, y, m nat) nat {
+func (z nat) expNN(x, y, m nat, slow bool) nat {
        if alias(z, x) || alias(z, y) {
                // We cannot allow in-place modification of x or y.
                z = nil
@@ -914,31 +948,48 @@ func (z nat) expNN(x, y, m nat) nat {
        }
        // y > 0
 
-       // x**1 mod m == x mod m
-       if len(y) == 1 && y[0] == 1 && len(m) != 0 {
-               _, z = nat(nil).div(z, x, m)
-               return z
+       // 0**y = 0
+       if len(x) == 0 {
+               return z.setWord(0)
+       }
+       // x > 0
+
+       // 1**y = 1
+       if len(x) == 1 && x[0] == 1 {
+               return z.setWord(1)
+       }
+       // x > 1
+
+       // x**1 == x
+       if len(y) == 1 && y[0] == 1 {
+               if len(m) != 0 {
+                       return z.rem(x, m)
+               }
+               return z.set(x)
        }
        // y > 1
 
        if len(m) != 0 {
                // We likely end up being as long as the modulus.
                z = z.make(len(m))
-       }
-       z = z.set(x)
 
-       // If the base is non-trivial and the exponent is large, we use
-       // 4-bit, windowed exponentiation. This involves precomputing 14 values
-       // (x^2...x^15) but then reduces the number of multiply-reduces by a
-       // third. Even for a 32-bit exponent, this reduces the number of
-       // operations. Uses Montgomery method for odd moduli.
-       if x.cmp(natOne) > 0 && len(y) > 1 && len(m) > 0 {
-               if m[0]&1 == 1 {
-                       return z.expNNMontgomery(x, y, m)
+               // If the exponent is large, we use the Montgomery method for odd values,
+               // and a 4-bit, windowed exponentiation for powers of two,
+               // and a CRT-decomposed Montgomery method for the remaining values
+               // (even values times non-trivial odd values, which decompose into one
+               // instance of each of the first two cases).
+               if len(y) > 1 && !slow {
+                       if m[0]&1 == 1 {
+                               return z.expNNMontgomery(x, y, m)
+                       }
+                       if logM, ok := m.isPow2(); ok {
+                               return z.expNNWindowed(x, y, logM)
+                       }
+                       return z.expNNMontgomeryEven(x, y, m)
                }
-               return z.expNNWindowed(x, y, m)
        }
 
+       z = z.set(x)
        v := y[len(y)-1] // v > 0 because y is normalized and y > 0
        shift := nlz(v) + 1
        v <<= shift
@@ -995,66 +1046,151 @@ func (z nat) expNN(x, y, m nat) nat {
        return z.norm()
 }
 
-// expNNWindowed calculates x**y mod m using a fixed, 4-bit window.
-func (z nat) expNNWindowed(x, y, m nat) nat {
-       // zz and r are used to avoid allocating in mul and div as otherwise
+// expNNMontgomeryEven calculates x**y mod m where m = m1 × m2 for m1 = 2ⁿ and m2 odd.
+// It uses two recursive calls to expNN for x**y mod m1 and x**y mod m2
+// and then uses the Chinese Remainder Theorem to combine the results.
+// The recursive call using m1 will use expNNWindowed,
+// while the recursive call using m2 will use expNNMontgomery.
+// For more details, see Ç. K. Koç, “Montgomery Reduction with Even Modulus”,
+// IEE Proceedings: Computers and Digital Techniques, 141(5) 314-316, September 1994.
+// http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf
+func (z nat) expNNMontgomeryEven(x, y, m nat) nat {
+       // Split m = m₁ × m₂ where m₁ = 2ⁿ
+       n := m.trailingZeroBits()
+       m1 := nat(nil).shl(natOne, n)
+       m2 := nat(nil).shr(m, n)
+
+       // We want z = x**y mod m.
+       // z₁ = x**y mod m1 = (x**y mod m) mod m1 = z mod m1
+       // z₂ = x**y mod m2 = (x**y mod m) mod m2 = z mod m2
+       // (We are using the math/big convention for names here,
+       // where the computation is z = x**y mod m, so its parts are z1 and z2.
+       // The paper is computing x = a**e mod n; it refers to these as x2 and z1.)
+       z1 := nat(nil).expNN(x, y, m1, false)
+       z2 := nat(nil).expNN(x, y, m2, false)
+
+       // Reconstruct z from z₁, z₂ using CRT, using algorithm from paper,
+       // which uses only a single modInverse (and an easy one at that).
+       //      p = (z₁ - z₂) × m₂⁻¹ (mod m₁)
+       //      z = z₂ + p × m₂
+       // The final addition is in range because:
+       //      z = z₂ + p × m₂
+       //        ≤ z₂ + (m₁-1) × m₂
+       //        < m₂ + (m₁-1) × m₂
+       //        = m₁ × m₂
+       //        = m.
+       z = z.set(z2)
+
+       // Compute (z₁ - z₂) mod m1 [m1 == 2**n] into z1.
+       z1 = z1.subMod2N(z1, z2, n)
+
+       // Reuse z2 for p = (z₁ - z₂) [in z1] * m2⁻¹ (mod m₁ [= 2ⁿ]).
+       m2inv := nat(nil).modInverse(m2, m1)
+       z2 = z2.mul(z1, m2inv)
+       z2 = z2.trunc(z2, n)
+
+       // Reuse z1 for p * m2.
+       z = z.add(z, z1.mul(z2, m2))
+
+       return z
+}
+
+// expNNWindowed calculates x**y mod m using a fixed, 4-bit window,
+// where m = 2**logM.
+func (z nat) expNNWindowed(x, y nat, logM uint) nat {
+       if len(y) <= 1 {
+               panic("big: misuse of expNNWindowed")
+       }
+       if x[0]&1 == 0 {
+               // len(y) > 1, so y  > logM.
+               // x is even, so x**y is a multiple of 2**y which is a multiple of 2**logM.
+               return z.setWord(0)
+       }
+       if logM == 1 {
+               return z.setWord(1)
+       }
+
+       // zz is used to avoid allocating in mul as otherwise
        // the arguments would alias.
-       var zz, r nat
+       w := int((logM + _W - 1) / _W)
+       zzp := getNat(w)
+       zz := *zzp
 
        const n = 4
        // powers[i] contains x^i.
-       var powers [1 << n]nat
-       powers[0] = natOne
-       powers[1] = x
+       var powers [1 << n]*nat
+       for i := range powers {
+               powers[i] = getNat(w)
+       }
+       *powers[0] = powers[0].set(natOne)
+       *powers[1] = powers[1].trunc(x, logM)
        for i := 2; i < 1<<n; i += 2 {
-               p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
+               p2, p, p1 := powers[i/2], powers[i], powers[i+1]
                *p = p.sqr(*p2)
-               zz, r = zz.div(r, *p, m)
-               *p, r = r, *p
+               *p = p.trunc(*p, logM)
                *p1 = p1.mul(*p, x)
-               zz, r = zz.div(r, *p1, m)
-               *p1, r = r, *p1
+               *p1 = p1.trunc(*p1, logM)
        }
 
+       // Because phi(2**logM) = 2**(logM-1), x**(2**(logM-1)) = 1,
+       // so we can compute x**(y mod 2**(logM-1)) instead of x**y.
+       // That is, we can throw away all but the bottom logM-1 bits of y.
+       // Instead of allocating a new y, we start reading y at the right word
+       // and truncate it appropriately at the start of the loop.
+       i := len(y) - 1
+       mtop := int((logM - 2) / _W) // -2 because the top word of N bits is the (N-1)/W'th word.
+       mmask := ^Word(0)
+       if mbits := (logM - 1) & (_W - 1); mbits != 0 {
+               mmask = (1 << mbits) - 1
+       }
+       if i > mtop {
+               i = mtop
+       }
+       advance := false
        z = z.setWord(1)
-
-       for i := len(y) - 1; i >= 0; i-- {
+       for ; i >= 0; i-- {
                yi := y[i]
+               if i == mtop {
+                       yi &= mmask
+               }
                for j := 0; j < _W; j += n {
-                       if i != len(y)-1 || j != 0 {
+                       if advance {
+                               // Account for use of 4 bits in previous iteration.
                                // Unrolled loop for significant performance
                                // gain. Use go test -bench=".*" in crypto/rsa
                                // to check performance before making changes.
                                zz = zz.sqr(z)
                                zz, z = z, zz
-                               zz, r = zz.div(r, z, m)
-                               z, r = r, z
+                               z = z.trunc(z, logM)
 
                                zz = zz.sqr(z)
                                zz, z = z, zz
-                               zz, r = zz.div(r, z, m)
-                               z, r = r, z
+                               z = z.trunc(z, logM)
 
                                zz = zz.sqr(z)
                                zz, z = z, zz
-                               zz, r = zz.div(r, z, m)
-                               z, r = r, z
+                               z = z.trunc(z, logM)
 
                                zz = zz.sqr(z)
                                zz, z = z, zz
-                               zz, r = zz.div(r, z, m)
-                               z, r = r, z
+                               z = z.trunc(z, logM)
                        }
 
-                       zz = zz.mul(z, powers[yi>>(_W-n)])
+                       zz = zz.mul(z, *powers[yi>>(_W-n)])
                        zz, z = z, zz
-                       zz, r = zz.div(r, z, m)
-                       z, r = r, z
+                       z = z.trunc(z, logM)
 
                        yi <<= n
+                       advance = true
                }
        }
 
+       *zzp = zz
+       putNat(zzp)
+       for i := range powers {
+               putNat(powers[i])
+       }
+
        return z.norm()
 }
 
@@ -1242,3 +1378,36 @@ func (z nat) sqrt(x nat) nat {
                z1, z2 = z2, z1
        }
 }
+
+// subMod2N returns z = (x - y) mod 2ⁿ.
+func (z nat) subMod2N(x, y nat, n uint) nat {
+       if uint(x.bitLen()) > n {
+               if alias(z, x) {
+                       // ok to overwrite x in place
+                       x = x.trunc(x, n)
+               } else {
+                       x = nat(nil).trunc(x, n)
+               }
+       }
+       if uint(y.bitLen()) > n {
+               if alias(z, y) {
+                       // ok to overwrite y in place
+                       y = y.trunc(y, n)
+               } else {
+                       y = nat(nil).trunc(y, n)
+               }
+       }
+       if x.cmp(y) >= 0 {
+               return z.sub(x, y)
+       }
+       // x - y < 0; x - y mod 2ⁿ = x - y + 2ⁿ = 2ⁿ - (y - x) = 1 + 2ⁿ-1 - (y - x) = 1 + ^(y - x).
+       z = z.sub(y, x)
+       for uint(len(z))*_W < n {
+               z = append(z, 0)
+       }
+       for i := range z {
+               z[i] = ^z[i]
+       }
+       z = z.trunc(z, n)
+       return z.add(z, natOne)
+}
index 08508189321795872257cf110963455a39c5c32c..b84a7be5bcbcac0ec81e92d12c4a9662b364ec2a 100644 (file)
@@ -523,6 +523,12 @@ var expNNTests = []struct {
                "444747819283133684179",
                "42",
        },
+       {"375", "249", "388", "175"},
+       {"375", "18446744073709551801", "388", "175"},
+       {"0", "0x40000000000000", "0x200", "0"},
+       {"0xeffffff900002f00", "0x40000000000000", "0x200", "0"},
+       {"5", "1435700818", "72", "49"},
+       {"0xffff", "0x300030003000300030003000300030003000302a3000300030003000300030003000300030003000300030003000300030003030623066307f3030783062303430383064303630343036", "0x300000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", "0xa3f94c08b0b90e87af637cacc9383f7ea032352b8961fc036a52b659b6c9b33491b335ffd74c927f64ddd62cfca0001"},
 }
 
 func TestExpNN(t *testing.T) {
@@ -536,13 +542,29 @@ func TestExpNN(t *testing.T) {
                        m = natFromString(test.m)
                }
 
-               z := nat(nil).expNN(x, y, m)
+               z := nat(nil).expNN(x, y, m, false)
                if z.cmp(out) != 0 {
                        t.Errorf("#%d got %s want %s", i, z.utoa(10), out.utoa(10))
                }
        }
 }
 
+func FuzzExpMont(f *testing.F) {
+       f.Fuzz(func(t *testing.T, x1, x2, x3, y1, y2, y3, m1, m2, m3 uint) {
+               if m1 == 0 && m2 == 0 && m3 == 0 {
+                       return
+               }
+               x := new(Int).SetBits([]Word{Word(x1), Word(x2), Word(x3)})
+               y := new(Int).SetBits([]Word{Word(y1), Word(y2), Word(y3)})
+               m := new(Int).SetBits([]Word{Word(m1), Word(m2), Word(m3)})
+               out := new(Int).Exp(x, y, m)
+               want := new(Int).expSlow(x, y, m)
+               if out.Cmp(want) != 0 {
+                       t.Errorf("x = %#x\ny=%#x\nz=%#x\nout=%#x\nwant=%#x\ndc: 16o 16i %X %X %X |p", x, y, m, out, want, x, y, m)
+               }
+       })
+}
+
 func BenchmarkExp3Power(b *testing.B) {
        const x = 3
        for _, y := range []Word{
@@ -733,6 +755,54 @@ func BenchmarkNatSqr(b *testing.B) {
        }
 }
 
+var subMod2NTests = []struct {
+       x string
+       y string
+       n uint
+       z string
+}{
+       {"1", "2", 0, "0"},
+       {"1", "0", 1, "1"},
+       {"0", "1", 1, "1"},
+       {"3", "5", 3, "6"},
+       {"5", "3", 3, "2"},
+       // 2^65, 2^66-1, 2^65 - (2^66-1) + 2^67
+       {"36893488147419103232", "73786976294838206463", 67, "110680464442257309697"},
+       // 2^66-1, 2^65, 2^65-1
+       {"73786976294838206463", "36893488147419103232", 67, "36893488147419103231"},
+}
+
+func TestNatSubMod2N(t *testing.T) {
+       for _, mode := range []string{"noalias", "aliasX", "aliasY"} {
+               t.Run(mode, func(t *testing.T) {
+                       for _, tt := range subMod2NTests {
+                               x0 := natFromString(tt.x)
+                               y0 := natFromString(tt.y)
+                               want := natFromString(tt.z)
+                               x := nat(nil).set(x0)
+                               y := nat(nil).set(y0)
+                               var z nat
+                               switch mode {
+                               case "aliasX":
+                                       z = x
+                               case "aliasY":
+                                       z = y
+                               }
+                               z = z.subMod2N(x, y, tt.n)
+                               if z.cmp(want) != 0 {
+                                       t.Fatalf("subMod2N(%d, %d, %d) = %d, want %d", x0, y0, tt.n, z, want)
+                               }
+                               if mode != "aliasX" && x.cmp(x0) != 0 {
+                                       t.Fatalf("subMod2N(%d, %d, %d) modified x", x0, y0, tt.n)
+                               }
+                               if mode != "aliasY" && y.cmp(y0) != 0 {
+                                       t.Fatalf("subMod2N(%d, %d, %d) modified y", x0, y0, tt.n)
+                               }
+                       }
+               })
+       }
+}
+
 func BenchmarkNatSetBytes(b *testing.B) {
        const maxLength = 128
        lengths := []int{
index 21fdab53fd46cc9025ddb6411726e58fc3ec8c3f..da59bd6e4c5755909f607fcf12f7f64004a78dd4 100644 (file)
@@ -452,7 +452,7 @@ var cacheBase10 struct {
 
 // expWW computes x**y
 func (z nat) expWW(x, y Word) nat {
-       return z.expNN(nat(nil).setWord(x), nat(nil).setWord(y), nil)
+       return z.expNN(nat(nil).setWord(x), nat(nil).setWord(y), nil, false)
 }
 
 // construct table of powers of bb*leafSize to use in subdivisions
index 882bb6d3bace1394e1b6a9a67a6b60ddeeb5d43f..14233a2ddb57e2406d2b762255cead3ee0b73d46 100644 (file)
@@ -500,6 +500,19 @@ package big
 
 import "math/bits"
 
+// rem returns r such that r = u%v.
+// It uses z as the storage for r.
+func (z nat) rem(u, v nat) (r nat) {
+       if alias(z, u) {
+               z = nil
+       }
+       qp := getNat(0)
+       q, r := qp.div(z, u, v)
+       *qp = q
+       putNat(qp)
+       return r
+}
+
 // div returns q, r such that q = ⌊u/v⌋ and r = u%v = u - q·v.
 // It uses z and z2 as the storage for q and r.
 func (z nat) div(z2, u, v nat) (q, r nat) {
index a06378956a6bfb45f43cd4a0ed6492f2c03fc0e1..26688bbd64e9f2221d2a64a696ed05d037d4b5e1 100644 (file)
@@ -103,7 +103,7 @@ NextRandom:
                        x = x.random(rand, nm3, nm3Len)
                        x = x.add(x, natTwo)
                }
-               y = y.expNN(x, q, n)
+               y = y.expNN(x, q, n, false)
                if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
                        continue
                }
index 794a51d007a70a5b03290a5045831af46973b94c..8537a6795f0b0bc0de1f67d2745cbe3cb1f60501 100644 (file)
@@ -178,7 +178,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
                if n > 1e6 {
                        return nil, false // avoid excessively large exponents
                }
-               pow5 := z.b.abs.expNN(natFive, nat(nil).setWord(Word(n)), nil) // use underlying array of z.b.abs
+               pow5 := z.b.abs.expNN(natFive, nat(nil).setWord(Word(n)), nil, false) // use underlying array of z.b.abs
                if exp5 > 0 {
                        z.a.abs = z.a.abs.mul(z.a.abs, pow5)
                        z.b.abs = z.b.abs.setWord(1)
@@ -346,7 +346,7 @@ func (x *Rat) FloatString(prec int) string {
 
        p := natOne
        if prec > 0 {
-               p = nat(nil).expNN(natTen, nat(nil).setUint64(uint64(prec)), nil)
+               p = nat(nil).expNN(natTen, nat(nil).setUint64(uint64(prec)), nil, false)
        }
 
        r = r.mul(r, p)