In the early days of math/big, algorithms that needed more space
grew the result larger than it needed to be and then used the
high words as extra space. This made results their own temporary
space caches, at the cost that saving a result in a data structure
might hold significantly more memory than necessary.
Specifically, new(big.Int).Mul(x, y) returned a big.Int with a
backing slice 3X as big as it strictly needed to be.
If you are storing many multiplication results, or even a single
large result, the 3X overhead can add up.
This approach to storage for temporaries also requires being able
to analyze the algorithms to predict the exact amount they need,
which can be difficult.
For both these reasons, the implementation of recursive long division,
which came later, introduced a “nat pool” where temporaries could be
stored and reused, or reclaimed by the GC when no longer used.
This avoids the storage and bookkeeping overheads but introduces a
per-temporary sync.Pool overhead. divRecursiveStep takes an array
of cached temporaries to remove some of that overhead.
The nat pool was better but is still not quite right.
This CL introduces something even better than the nat pool
(still probably not quite right, but the best I can see for now):
a sync.Pool holding stacks for allocating temporaries.
Now an operation can get one stack out of the pool and then
allocate as many temporaries as it needs during the operation,
eventually returning the stack back to the pool. The sync.Pool
operations are now per-exported-operation (like big.Int.Mul),
not per-temporary.
This CL converts both the pre-allocation in nat.mul and the
uses of the nat pool to use stack pools instead. This simplifies
some code and sets us up better for more complex algorithms
(such as Toom-Cook or FFT-based multiplication) that need
more temporaries. It is also a little bit faster.
goos: linux
goarch: amd64
pkg: math/big
cpu: Intel(R) Xeon(R) CPU @ 3.10GHz
│ old │ new │
│ sec/op │ sec/op vs base │
Div/20/10-16 23.68n ± 0% 22.21n ± 0% -6.21% (p=0.000 n=15)
Div/40/20-16 23.68n ± 0% 22.21n ± 0% -6.21% (p=0.000 n=15)
Div/100/50-16 56.65n ± 0% 55.53n ± 0% -1.98% (p=0.000 n=15)
Div/200/100-16 194.6n ± 1% 172.8n ± 0% -11.20% (p=0.000 n=15)
Div/400/200-16 232.1n ± 0% 206.7n ± 0% -10.94% (p=0.000 n=15)
Div/1000/500-16 405.3n ± 1% 383.8n ± 0% -5.30% (p=0.000 n=15)
Div/2000/1000-16 810.4n ± 1% 795.2n ± 0% -1.88% (p=0.000 n=15)
Div/20000/10000-16 25.88µ ± 0% 25.39µ ± 0% -1.89% (p=0.000 n=15)
Div/200000/100000-16 931.5µ ± 0% 924.3µ ± 0% -0.77% (p=0.000 n=15)
Div/
2000000/
1000000-16 37.77m ± 0% 37.75m ± 0% ~ (p=0.098 n=15)
Div/
20000000/
10000000-16 1.367 ± 0% 1.377 ± 0% +0.72% (p=0.003 n=15)
NatMul/10-16 168.5n ± 3% 164.0n ± 4% ~ (p=0.751 n=15)
NatMul/100-16 6.086µ ± 3% 5.380µ ± 3% -11.60% (p=0.000 n=15)
NatMul/1000-16 238.1µ ± 3% 228.3µ ± 1% -4.12% (p=0.000 n=15)
NatMul/10000-16 8.721m ± 2% 8.518m ± 1% -2.33% (p=0.000 n=15)
NatMul/100000-16 369.6m ± 0% 371.1m ± 0% +0.42% (p=0.000 n=15)
geomean 19.57µ 18.74µ -4.21%
│ old │ new │
│ B/op │ B/op vs base │
NatMul/10-16 192.0 ± 0% 192.0 ± 0% ~ (p=1.000 n=15) ¹
NatMul/100-16 4.750Ki ± 0% 1.751Ki ± 0% -63.14% (p=0.000 n=15)
NatMul/1000-16 48.16Ki ± 0% 16.02Ki ± 0% -66.73% (p=0.000 n=15)
NatMul/10000-16 482.9Ki ± 1% 165.4Ki ± 3% -65.75% (p=0.000 n=15)
NatMul/100000-16 5.747Mi ± 7% 4.197Mi ± 0% -26.97% (p=0.000 n=15)
geomean 41.42Ki 20.63Ki -50.18%
¹ all samples are equal
│ old │ new │
│ allocs/op │ allocs/op vs base │
NatMul/10-16 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹
NatMul/100-16 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹
NatMul/1000-16 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹
NatMul/10000-16 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹
NatMul/100000-16 7.000 ± 14% 7.000 ± 14% ~ (p=0.668 n=15)
geomean 1.476 1.476 +0.00%
¹ all samples are equal
goos: linux
goarch: amd64
pkg: math/big
cpu: Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz
│ old │ new │
│ sec/op │ sec/op vs base │
Div/20/10-88 15.84n ± 1% 13.12n ± 0% -17.17% (p=0.000 n=15)
Div/40/20-88 15.88n ± 1% 13.12n ± 0% -17.38% (p=0.000 n=15)
Div/100/50-88 26.42n ± 0% 25.47n ± 0% -3.60% (p=0.000 n=15)
Div/200/100-88 132.4n ± 0% 114.9n ± 0% -13.22% (p=0.000 n=15)
Div/400/200-88 150.1n ± 0% 135.6n ± 0% -9.66% (p=0.000 n=15)
Div/1000/500-88 275.5n ± 0% 264.1n ± 0% -4.14% (p=0.000 n=15)
Div/2000/1000-88 586.5n ± 0% 581.1n ± 0% -0.92% (p=0.000 n=15)
Div/20000/10000-88 25.87µ ± 0% 25.72µ ± 0% -0.59% (p=0.000 n=15)
Div/200000/100000-88 772.2µ ± 0% 779.0µ ± 0% +0.88% (p=0.000 n=15)
Div/
2000000/
1000000-88 33.36m ± 0% 33.63m ± 0% +0.80% (p=0.000 n=15)
Div/
20000000/
10000000-88 1.307 ± 0% 1.320 ± 0% +1.03% (p=0.000 n=15)
NatMul/10-88 140.4n ± 0% 148.8n ± 4% +5.98% (p=0.000 n=15)
NatMul/100-88 4.663µ ± 1% 4.388µ ± 1% -5.90% (p=0.000 n=15)
NatMul/1000-88 207.7µ ± 0% 205.8µ ± 0% -0.89% (p=0.000 n=15)
NatMul/10000-88 8.456m ± 0% 8.468m ± 0% +0.14% (p=0.021 n=15)
NatMul/100000-88 295.1m ± 0% 297.9m ± 0% +0.94% (p=0.000 n=15)
geomean 14.96µ 14.33µ -4.23%
│ old │ new │
│ B/op │ B/op vs base │
NatMul/10-88 192.0 ± 0% 192.0 ± 0% ~ (p=1.000 n=15) ¹
NatMul/100-88 4.750Ki ± 0% 1.758Ki ± 0% -62.99% (p=0.000 n=15)
NatMul/1000-88 48.44Ki ± 0% 16.08Ki ± 0% -66.80% (p=0.000 n=15)
NatMul/10000-88 489.7Ki ± 1% 166.1Ki ± 3% -66.08% (p=0.000 n=15)
NatMul/100000-88 5.546Mi ± 0% 3.819Mi ± 60% -31.15% (p=0.000 n=15)
geomean 41.29Ki 20.30Ki -50.85%
¹ all samples are equal
│ old │ new │
│ allocs/op │ allocs/op vs base │
NatMul/10-88 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹
NatMul/100-88 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹
NatMul/1000-88 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹
NatMul/10000-88 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹
NatMul/100000-88 5.000 ± 20% 6.000 ± 67% ~ (p=0.672 n=15)
geomean 1.380 1.431 +3.71%
¹ all samples are equal
goos: linux
goarch: arm64
pkg: math/big
│ old │ new │
│ sec/op │ sec/op vs base │
Div/20/10-16 15.85n ± 0% 15.23n ± 0% -3.91% (p=0.000 n=15)
Div/40/20-16 15.88n ± 0% 15.22n ± 0% -4.16% (p=0.000 n=15)
Div/100/50-16 29.69n ± 0% 26.39n ± 0% -11.11% (p=0.000 n=15)
Div/200/100-16 149.2n ± 0% 123.3n ± 0% -17.36% (p=0.000 n=15)
Div/400/200-16 160.3n ± 0% 139.2n ± 0% -13.16% (p=0.000 n=15)
Div/1000/500-16 271.0n ± 0% 256.1n ± 0% -5.50% (p=0.000 n=15)
Div/2000/1000-16 545.3n ± 0% 527.0n ± 0% -3.36% (p=0.000 n=15)
Div/20000/10000-16 22.60µ ± 0% 22.20µ ± 0% -1.77% (p=0.000 n=15)
Div/200000/100000-16 889.0µ ± 0% 892.2µ ± 0% +0.35% (p=0.000 n=15)
Div/
2000000/
1000000-16 38.01m ± 0% 38.12m ± 0% +0.30% (p=0.000 n=15)
Div/
20000000/
10000000-16 1.437 ± 0% 1.444 ± 0% +0.50% (p=0.000 n=15)
NatMul/10-16 166.4n ± 2% 169.5n ± 1% +1.86% (p=0.000 n=15)
NatMul/100-16 5.733µ ± 1% 5.570µ ± 1% -2.84% (p=0.000 n=15)
NatMul/1000-16 232.6µ ± 1% 229.8µ ± 0% -1.22% (p=0.000 n=15)
NatMul/10000-16 9.039m ± 1% 8.969m ± 0% -0.77% (p=0.000 n=15)
NatMul/100000-16 367.0m ± 0% 368.8m ± 0% +0.48% (p=0.000 n=15)
geomean 16.15µ 15.50µ -4.01%
│ old │ new │
│ B/op │ B/op vs base │
NatMul/10-16 192.0 ± 0% 192.0 ± 0% ~ (p=1.000 n=15) ¹
NatMul/100-16 4.750Ki ± 0% 1.751Ki ± 0% -63.14% (p=0.000 n=15)
NatMul/1000-16 48.33Ki ± 0% 16.02Ki ± 0% -66.85% (p=0.000 n=15)
NatMul/10000-16 536.5Ki ± 1% 165.7Ki ± 3% -69.12% (p=0.000 n=15)
NatMul/100000-16 6.078Mi ± 6% 4.197Mi ± 0% -30.94% (p=0.000 n=15)
geomean 42.81Ki 20.64Ki -51.78%
¹ all samples are equal
│ old │ new │
│ allocs/op │ allocs/op vs base │
NatMul/10-16 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹
NatMul/100-16 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹
NatMul/1000-16 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹
NatMul/10000-16 2.000 ± 50% 1.000 ± 0% -50.00% (p=0.001 n=15)
NatMul/100000-16 9.000 ± 11% 8.000 ± 12% -11.11% (p=0.001 n=15)
geomean 1.783 1.516 -14.97%
¹ all samples are equal
goos: darwin
goarch: arm64
pkg: math/big
cpu: Apple M3 Pro
│ old │ new │
│ sec/op │ sec/op vs base │
Div/20/10-12 9.850n ± 1% 9.405n ± 1% -4.52% (p=0.000 n=15)
Div/40/20-12 9.858n ± 0% 9.403n ± 1% -4.62% (p=0.000 n=15)
Div/100/50-12 16.40n ± 1% 14.81n ± 0% -9.70% (p=0.000 n=15)
Div/200/100-12 88.48n ± 2% 80.88n ± 0% -8.59% (p=0.000 n=15)
Div/400/200-12 107.90n ± 1% 99.28n ± 1% -7.99% (p=0.000 n=15)
Div/1000/500-12 188.8n ± 1% 178.6n ± 1% -5.40% (p=0.000 n=15)
Div/2000/1000-12 399.9n ± 0% 389.1n ± 0% -2.70% (p=0.000 n=15)
Div/20000/10000-12 13.94µ ± 2% 13.81µ ± 1% ~ (p=0.574 n=15)
Div/200000/100000-12 523.8µ ± 0% 521.7µ ± 0% -0.40% (p=0.000 n=15)
Div/
2000000/
1000000-12 21.46m ± 0% 21.48m ± 0% ~ (p=0.067 n=15)
Div/
20000000/
10000000-12 812.5m ± 0% 812.9m ± 0% ~ (p=0.061 n=15)
NatMul/10-12 77.14n ± 0% 78.35n ± 1% +1.57% (p=0.000 n=15)
NatMul/100-12 2.999µ ± 0% 2.871µ ± 1% -4.27% (p=0.000 n=15)
NatMul/1000-12 126.2µ ± 0% 126.8µ ± 0% +0.51% (p=0.011 n=15)
NatMul/10000-12 5.099m ± 0% 5.125m ± 0% +0.51% (p=0.000 n=15)
NatMul/100000-12 206.7m ± 0% 208.4m ± 0% +0.80% (p=0.000 n=15)
geomean 9.512µ 9.236µ -2.91%
│ old │ new │
│ B/op │ B/op vs base │
NatMul/10-12 192.0 ± 0% 192.0 ± 0% ~ (p=1.000 n=15) ¹
NatMul/100-12 4.750Ki ± 0% 1.750Ki ± 0% -63.16% (p=0.000 n=15)
NatMul/1000-12 48.13Ki ± 0% 16.01Ki ± 0% -66.73% (p=0.000 n=15)
NatMul/10000-12 483.5Ki ± 1% 163.2Ki ± 2% -66.24% (p=0.000 n=15)
NatMul/100000-12 5.480Mi ± 4% 1.532Mi ± 104% -72.05% (p=0.000 n=15)
geomean 41.03Ki 16.82Ki -59.01%
¹ all samples are equal
│ old │ new │
│ allocs/op │ allocs/op vs base │
NatMul/10-12 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹
NatMul/100-12 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹
NatMul/1000-12 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹
NatMul/10000-12 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹
NatMul/100000-12 5.000 ± 0% 1.000 ± 400% -80.00% (p=0.007 n=15)
geomean 1.380 1.000 -27.52%
¹ all samples are equal
Change-Id: I7efa6fe37971ed26ae120a32250fcb47ece0a011
Reviewed-on: https://go-review.googlesource.com/c/go/+/650638
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Russ Cox <rsc@golang.org>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Reviewed-by: Alan Donovan <adonovan@google.com>
}
func TestIssue31084(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
// compute 10^n via 5^n << n.
const n = 165
- p := nat(nil).expNN(nat{5}, nat{n}, nil, false)
+ p := nat(nil).expNN(stk, nat{5}, nat{n}, nil, false)
p = p.shl(p, n)
got := string(p.utoa(10))
want := "1" + strings.Repeat("0", n)
e := int64(x.exp) + int64(y.exp)
if x == y {
- z.mant = z.mant.sqr(x.mant)
+ z.mant = z.mant.sqr(nil, x.mant)
} else {
- z.mant = z.mant.mul(x.mant, y.mant)
+ z.mant = z.mant.mul(nil, x.mant, y.mant)
}
z.setExpAndRound(e-fnorm(z.mant), 0)
}
d := len(xadj) - len(y.mant)
// divide
+ stk := getStack()
+ defer stk.free()
var r nat
- z.mant, r = z.mant.div(nil, xadj, y.mant)
+ z.mant, r = z.mant.div(stk, nil, xadj, y.mant)
e := int64(x.exp) - int64(y.exp) - int64(d-len(z.mant))*_W
// The result is long enough to include (at least) the rounding bit.
// Mul sets z to the product x*y and returns z.
func (z *Int) Mul(x, y *Int) *Int {
+ return z.mul(nil, x, y)
+}
+
+func (z *Int) mul(stk *stack, x, y *Int) *Int {
// x * y == x * y
// x * (-y) == -(x * y)
// (-x) * y == -(x * y)
// (-x) * (-y) == x * y
if x == y {
- z.abs = z.abs.sqr(x.abs)
+ z.abs = z.abs.sqr(stk, x.abs)
z.neg = false
return z
}
- z.abs = z.abs.mul(x.abs, y.abs)
+ z.abs = z.abs.mul(stk, x.abs, y.abs)
z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign
return z
}
a, b = -b, -a
}
- z.abs = z.abs.mulRange(uint64(a), uint64(b))
+ z.abs = z.abs.mulRange(nil, uint64(a), uint64(b))
z.neg = neg
return z
}
// If y == 0, a division-by-zero run-time panic occurs.
// Quo implements truncated division (like Go); see [Int.QuoRem] for more details.
func (z *Int) Quo(x, y *Int) *Int {
- z.abs, _ = z.abs.div(nil, x.abs, y.abs)
+ z.abs, _ = z.abs.div(nil, nil, x.abs, y.abs)
z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign
return z
}
// If y == 0, a division-by-zero run-time panic occurs.
// Rem implements truncated modulus (like Go); see [Int.QuoRem] for more details.
func (z *Int) Rem(x, y *Int) *Int {
- _, z.abs = nat(nil).div(z.abs, x.abs, y.abs)
+ _, z.abs = nat(nil).div(nil, z.abs, x.abs, y.abs)
z.neg = len(z.abs) > 0 && x.neg // 0 has no sign
return z
}
// (See Daan Leijen, “Division and Modulus for Computer Scientists”.)
// See [Int.DivMod] for Euclidean division and modulus (unlike Go).
func (z *Int) QuoRem(x, y, r *Int) (*Int, *Int) {
- z.abs, r.abs = z.abs.div(r.abs, x.abs, y.abs)
+ z.abs, r.abs = z.abs.div(nil, r.abs, x.abs, y.abs)
z.neg, r.neg = len(z.abs) > 0 && x.neg != y.neg, len(r.abs) > 0 && x.neg // 0 has no sign
return z, r
}
mWords = m.abs // m.abs may be nil for m == 0
}
- z.abs = z.abs.expNN(xWords, yWords, mWords, slow)
+ z.abs = z.abs.expNN(nil, xWords, yWords, mWords, slow)
z.neg = len(z.abs) > 0 && x.neg && len(yWords) > 0 && yWords[0]&1 == 1 // 0 has no sign
if z.neg && len(mWords) > 0 {
// make modulus result positive
panic("square root of negative number")
}
z.neg = false
- z.abs = z.abs.sqrt(x.abs)
+ z.abs = z.abs.sqrt(nil, x.abs)
return z
}
"internal/byteorder"
"math/bits"
"math/rand"
+ "slices"
"sync"
)
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
-// power of 2. The result vector z must have len(z) >= 6*n.
-// The (non-normalized) result is placed in z[0 : 2*n].
-func karatsuba(z, x, y nat) {
+// power of 2. The result vector z must have len(z) == len(x)+len(y).
+// The (non-normalized) result is placed in z.
+func karatsuba(stk *stack, z, x, y nat) {
n := len(y)
// Switch to basic multiplication if numbers are odd or small.
x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0
y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0
- // z is used for the result and temporary storage:
- //
- // 6*n 5*n 4*n 3*n 2*n 1*n 0*n
- // z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
- //
- // For each recursive call of karatsuba, an unused slice of
- // z is passed in that has (at least) half the length of the
- // caller's z.
-
// compute z0 and z2 with the result "in place" in z
- karatsuba(z, x0, y0) // z0 = x0*y0
- karatsuba(z[n:], x1, y1) // z2 = x1*y1
+ karatsuba(stk, z, x0, y0) // z0 = x0*y0
+ karatsuba(stk, z[n:], x1, y1) // z2 = x1*y1
- // compute xd (or the negative value if underflow occurs)
+ // compute xd, yd (or the negative value if underflow occurs)
s := 1 // sign of product xd*yd
- xd := z[2*n : 2*n+n2]
+ defer stk.restore(stk.save())
+ xd := stk.nat(n2)
+ yd := stk.nat(n2)
if subVV(xd, x1, x0) != 0 { // x1-x0
s = -s
subVV(xd, x0, x1) // x0-x1
}
-
- // compute yd (or the negative value if underflow occurs)
- yd := z[2*n+n2 : 3*n]
if subVV(yd, y0, y1) != 0 { // y0-y1
s = -s
subVV(yd, y1, y0) // y1-y0
// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
- p := z[n*3:]
- karatsuba(p, xd, yd)
+ p := stk.nat(2 * n2)
+ karatsuba(stk, p, xd, yd)
// save original z2:z0
// (ok to use upper half of z since we're done recurring)
- r := z[n*4:]
+ r := stk.nat(n * 2)
copy(r, z[:n*2])
// add up all partial products
return n << i
}
-func (z nat) mul(x, y nat) nat {
+// mul sets z = x*y, using stk for temporary storage.
+// The caller may pass stk == nil to request that mul obtain and release one itself.
+func (z nat) mul(stk *stack, x, y nat) nat {
m := len(x)
n := len(y)
switch {
case m < n:
- return z.mul(y, x)
+ return z.mul(stk, y, x)
case m == 0 || n == 0:
return z[:0]
case n == 1:
k := karatsubaLen(n, karatsubaThreshold)
// k <= n
+ if stk == nil {
+ stk = getStack()
+ defer stk.free()
+ }
+
// multiply x0 and y0 via Karatsuba
- x0 := x[0:k] // x0 is not normalized
- y0 := y[0:k] // y0 is not normalized
- z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
- karatsuba(z, x0, y0)
- z = z[0 : m+n] // z has final length but may be incomplete
+ x0 := x[0:k] // x0 is not normalized
+ y0 := y[0:k] // y0 is not normalized
+ z = z.make(m + n) // enough space for full result of x*y
+ karatsuba(stk, z, x0, y0)
clear(z[2*k:]) // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
// If xh != 0 or yh != 0, add the missing terms to z. For
// be a larger valid threshold contradicting the assumption about k.
//
if k < n || m != n {
- tp := getNat(3 * k)
- t := *tp
+ defer stk.restore(stk.save())
+ t := stk.nat(3 * k)
// add x0*y1*b
x0 := x0.norm()
- y1 := y[k:] // y1 is normalized because y is
- t = t.mul(x0, y1) // update t so we don't lose t's underlying array
+ y1 := y[k:] // y1 is normalized because y is
+ t = t.mul(stk, x0, y1) // update t so we don't lose t's underlying array
addAt(z, t, k)
// add xi*y0<<i, xi*y1*b<<(i+k)
xi = xi[:k]
}
xi = xi.norm()
- t = t.mul(xi, y0)
+ t = t.mul(stk, xi, y0)
addAt(z, t, i)
- t = t.mul(xi, y1)
+ t = t.mul(stk, xi, y1)
addAt(z, t, i+k)
}
-
- putNat(tp)
}
return z.norm()
// by about a factor of 2, but slower for small arguments due to overhead.
// Requirements: len(x) > 0, len(z) == 2*len(x)
// The (non-normalized) result is placed in z.
-func basicSqr(z, x nat) {
+func basicSqr(stk *stack, z, x nat) {
n := len(x)
- tp := getNat(2 * n)
- t := *tp // temporary variable to hold the products
+ defer stk.restore(stk.save())
+ t := stk.nat(2 * n)
clear(t)
z[1], z[0] = mulWW(x[0], x[0]) // the initial square
for i := 1; i < n; i++ {
}
t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
addVV(z, z, t) // combine the result
- putNat(tp)
}
// karatsubaSqr squares x and leaves the result in z.
-// len(x) must be a power of 2 and len(z) >= 6*len(x).
-// The (non-normalized) result is placed in z[0 : 2*len(x)].
+// len(x) must be a power of 2 and len(z) == 2*len(x).
+// The (non-normalized) result is placed in z.
//
// The algorithm and the layout of z are the same as for karatsuba.
-func karatsubaSqr(z, x nat) {
+func karatsubaSqr(stk *stack, z, x nat) {
n := len(x)
if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 {
- basicSqr(z[:2*n], x)
+ basicSqr(stk, z[:2*n], x)
return
}
n2 := n >> 1
x1, x0 := x[n2:], x[0:n2]
- karatsubaSqr(z, x0)
- karatsubaSqr(z[n:], x1)
+ karatsubaSqr(stk, z, x0)
+ karatsubaSqr(stk, z[n:], x1)
// s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0
- xd := z[2*n : 2*n+n2]
+ defer stk.restore(stk.save())
+ p := stk.nat(2 * n2)
+ r := stk.nat(n * 2)
+ xd := r[:n2]
if subVV(xd, x1, x0) != 0 {
subVV(xd, x0, x1)
}
- p := z[n*3:]
- karatsubaSqr(p, xd)
-
- r := z[n*4:]
+ karatsubaSqr(stk, p, xd)
copy(r, z[:n*2])
karatsubaAdd(z[n2:], r, n)
var basicSqrThreshold = 20 // computed by calibrate_test.go
var karatsubaSqrThreshold = 260 // computed by calibrate_test.go
-// z = x*x
-func (z nat) sqr(x nat) nat {
+// sqr sets z = x*x, using stk for temporary storage.
+// The caller may pass stk == nil to request that sqr obtain and release one itself.
+func (z nat) sqr(stk *stack, x nat) nat {
n := len(x)
switch {
case n == 0:
if alias(z, x) {
z = nil // z is an alias for x - cannot reuse
}
+ z = z.make(2 * n)
if n < basicSqrThreshold {
- z = z.make(2 * n)
basicMul(z, x, x)
return z.norm()
}
+
+ if stk == nil {
+ stk = getStack()
+ defer stk.free()
+ }
+
if n < karatsubaSqrThreshold {
- z = z.make(2 * n)
- basicSqr(z, x)
+ basicSqr(stk, z, x)
return z.norm()
}
k := karatsubaLen(n, karatsubaSqrThreshold)
x0 := x[0:k]
- z = z.make(max(6*k, 2*n))
- karatsubaSqr(z, x0) // z = x0^2
- z = z[0 : 2*n]
+ karatsubaSqr(stk, z, x0) // z = x0^2
clear(z[2*k:])
if k < n {
- tp := getNat(2 * k)
- t := *tp
+ t := stk.nat(2 * k)
x0 := x0.norm()
x1 := x[k:]
- t = t.mul(x0, x1)
+ t = t.mul(stk, x0, x1)
addAt(z, t, k)
addAt(z, t, k) // z = 2*x1*x0*b + x0^2
- t = t.sqr(x1)
+ t = t.sqr(stk, x1)
addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2
- putNat(tp)
}
return z.norm()
// mulRange computes the product of all the unsigned integers in the
// range [a, b] inclusively. If a > b (empty range), the result is 1.
-func (z nat) mulRange(a, b uint64) nat {
+// The caller may pass stk == nil to request that mulRange obtain and release one itself.
+func (z nat) mulRange(stk *stack, a, b uint64) nat {
switch {
case a == 0:
// cut long ranges short (optimization)
case a == b:
return z.setUint64(a)
case a+1 == b:
- return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
+ return z.mul(stk, nat(nil).setUint64(a), nat(nil).setUint64(b))
+ }
+
+ if stk == nil {
+ stk = getStack()
+ defer stk.free()
}
+
m := a + (b-a)/2 // avoid overflow
- return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
+ return z.mul(stk, nat(nil).mulRange(stk, a, m), nat(nil).mulRange(stk, m+1, b))
}
-// getNat returns a *nat of len n. The contents may not be zero.
-// The pool holds *nat to avoid allocation when converting to interface{}.
-func getNat(n int) *nat {
- var z *nat
- if v := natPool.Get(); v != nil {
- z = v.(*nat)
- }
- if z == nil {
- z = new(nat)
- }
- *z = z.make(n)
- if n > 0 {
- (*z)[0] = 0xfedcb // break code expecting zero
+// A stack provides temporary storage for complex calculations
+// such as multiplication and division.
+// The stack is a simple slice of words, extended as needed
+// to hold all the temporary storage for a calculation.
+// In general, if a function takes a *stack, it expects a non-nil *stack.
+// However, certain functions may allow passing a nil *stack instead,
+// so that they can handle trivial stack-free cases without forcing the
+// caller to obtain and free a stack that will be unused. These functions
+// document that they accept a nil *stack in their doc comments.
+type stack struct {
+ w []Word
+}
+
+var stackPool sync.Pool
+
+// getStack returns a temporary stack.
+// The caller must call [stack.free] to give up use of the stack when finished.
+func getStack() *stack {
+ s, _ := stackPool.Get().(*stack)
+ if s == nil {
+ s = new(stack)
}
- return z
+ return s
+}
+
+// free returns the stack for use by another calculation.
+func (s *stack) free() {
+ s.w = s.w[:0]
+ stackPool.Put(s)
}
-func putNat(x *nat) {
- natPool.Put(x)
+// save returns the current stack pointer.
+// A future call to restore with the same value
+// frees any temporaries allocated on the stack after the call to save.
+func (s *stack) save() int {
+ return len(s.w)
}
-var natPool sync.Pool
+// restore restores the stack pointer to n.
+// It is almost always invoked as
+//
+// defer stk.restore(stk.save())
+//
+// which makes sure to pop any temporaries allocated in the current function
+// from the stack before returning.
+func (s *stack) restore(n int) {
+ s.w = s.w[:n]
+}
+
+// nat returns a nat of n words, allocated on the stack.
+func (s *stack) nat(n int) nat {
+ nr := (n + 3) &^ 3 // round up to multiple of 4
+ off := len(s.w)
+ s.w = slices.Grow(s.w, nr)
+ s.w = s.w[:off+nr]
+ x := s.w[off : off+n : off+n]
+ if n > 0 {
+ x[0] = 0xfedcb
+ }
+ return x
+}
// bitLen returns the length of x in bits.
// Unlike most methods, it works even if x is not normalized.
// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
// otherwise it sets z to x**y. The result is the value of z.
-func (z nat) expNN(x, y, m nat, slow bool) nat {
+// The caller may pass stk == nil to request that expNN obtain and release one itself.
+func (z nat) expNN(stk *stack, x, y, m nat, slow bool) nat {
if alias(z, x) || alias(z, y) {
// We cannot allow in-place modification of x or y.
z = nil
// x > 1
// x**1 == x
- if len(y) == 1 && y[0] == 1 {
- if len(m) != 0 {
- return z.rem(x, m)
- }
+ if len(y) == 1 && y[0] == 1 && len(m) == 0 {
return z.set(x)
}
+ if stk == nil {
+ stk = getStack()
+ defer stk.free()
+ }
+ if len(y) == 1 && y[0] == 1 { // len(m) > 0
+ return z.rem(stk, x, m)
+ }
+
// y > 1
if len(m) != 0 {
// instance of each of the first two cases).
if len(y) > 1 && !slow {
if m[0]&1 == 1 {
- return z.expNNMontgomery(x, y, m)
+ return z.expNNMontgomery(stk, x, y, m)
}
if logM, ok := m.isPow2(); ok {
- return z.expNNWindowed(x, y, logM)
+ return z.expNNWindowed(stk, x, y, logM)
}
- return z.expNNMontgomeryEven(x, y, m)
+ return z.expNNMontgomeryEven(stk, x, y, m)
}
}
// otherwise the arguments would alias.
var zz, r nat
for j := 0; j < w; j++ {
- zz = zz.sqr(z)
+ zz = zz.sqr(stk, z)
zz, z = z, zz
if v&mask != 0 {
- zz = zz.mul(z, x)
+ zz = zz.mul(stk, z, x)
zz, z = z, zz
}
if len(m) != 0 {
- zz, r = zz.div(r, z, m)
+ zz, r = zz.div(stk, r, z, m)
zz, r, q, z = q, z, zz, r
}
v = y[i]
for j := 0; j < _W; j++ {
- zz = zz.sqr(z)
+ zz = zz.sqr(stk, z)
zz, z = z, zz
if v&mask != 0 {
- zz = zz.mul(z, x)
+ zz = zz.mul(stk, z, x)
zz, z = z, zz
}
if len(m) != 0 {
- zz, r = zz.div(r, z, m)
+ zz, r = zz.div(stk, r, z, m)
zz, r, q, z = q, z, zz, r
}
// For more details, see Ç. K. Koç, “Montgomery Reduction with Even Modulus”,
// IEE Proceedings: Computers and Digital Techniques, 141(5) 314-316, September 1994.
// http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf
-func (z nat) expNNMontgomeryEven(x, y, m nat) nat {
+func (z nat) expNNMontgomeryEven(stk *stack, x, y, m nat) nat {
// Split m = m₁ × m₂ where m₁ = 2ⁿ
n := m.trailingZeroBits()
m1 := nat(nil).shl(natOne, n)
// (We are using the math/big convention for names here,
// where the computation is z = x**y mod m, so its parts are z1 and z2.
// The paper is computing x = a**e mod n; it refers to these as x2 and z1.)
- z1 := nat(nil).expNN(x, y, m1, false)
- z2 := nat(nil).expNN(x, y, m2, false)
+ z1 := nat(nil).expNN(stk, x, y, m1, false)
+ z2 := nat(nil).expNN(stk, x, y, m2, false)
// Reconstruct z from z₁, z₂ using CRT, using algorithm from paper,
// which uses only a single modInverse (and an easy one at that).
// Reuse z2 for p = (z₁ - z₂) [in z1] * m2⁻¹ (mod m₁ [= 2ⁿ]).
m2inv := nat(nil).modInverse(m2, m1)
- z2 = z2.mul(z1, m2inv)
+ z2 = z2.mul(stk, z1, m2inv)
z2 = z2.trunc(z2, n)
// Reuse z1 for p * m2.
- z = z.add(z, z1.mul(z2, m2))
+ z = z.add(z, z1.mul(stk, z2, m2))
return z
}
// expNNWindowed calculates x**y mod m using a fixed, 4-bit window,
// where m = 2**logM.
-func (z nat) expNNWindowed(x, y nat, logM uint) nat {
+func (z nat) expNNWindowed(stk *stack, x, y nat, logM uint) nat {
if len(y) <= 1 {
panic("big: misuse of expNNWindowed")
}
// zz is used to avoid allocating in mul as otherwise
// the arguments would alias.
+ defer stk.restore(stk.save())
w := int((logM + _W - 1) / _W)
- zzp := getNat(w)
- zz := *zzp
+ zz := stk.nat(w)
const n = 4
// powers[i] contains x^i.
- var powers [1 << n]*nat
+ var powers [1 << n]nat
for i := range powers {
- powers[i] = getNat(w)
+ powers[i] = stk.nat(w)
}
- *powers[0] = powers[0].set(natOne)
- *powers[1] = powers[1].trunc(x, logM)
+ powers[0] = powers[0].set(natOne)
+ powers[1] = powers[1].trunc(x, logM)
for i := 2; i < 1<<n; i += 2 {
- p2, p, p1 := powers[i/2], powers[i], powers[i+1]
- *p = p.sqr(*p2)
+ p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
+ *p = p.sqr(stk, *p2)
*p = p.trunc(*p, logM)
- *p1 = p1.mul(*p, x)
+ *p1 = p1.mul(stk, *p, x)
*p1 = p1.trunc(*p1, logM)
}
// Unrolled loop for significant performance
// gain. Use go test -bench=".*" in crypto/rsa
// to check performance before making changes.
- zz = zz.sqr(z)
+ zz = zz.sqr(stk, z)
zz, z = z, zz
z = z.trunc(z, logM)
- zz = zz.sqr(z)
+ zz = zz.sqr(stk, z)
zz, z = z, zz
z = z.trunc(z, logM)
- zz = zz.sqr(z)
+ zz = zz.sqr(stk, z)
zz, z = z, zz
z = z.trunc(z, logM)
- zz = zz.sqr(z)
+ zz = zz.sqr(stk, z)
zz, z = z, zz
z = z.trunc(z, logM)
}
- zz = zz.mul(z, *powers[yi>>(_W-n)])
+ zz = zz.mul(stk, z, powers[yi>>(_W-n)])
zz, z = z, zz
z = z.trunc(z, logM)
}
}
- *zzp = zz
- putNat(zzp)
- for i := range powers {
- putNat(powers[i])
- }
-
return z.norm()
}
// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
// Uses Montgomery representation.
-func (z nat) expNNMontgomery(x, y, m nat) nat {
+func (z nat) expNNMontgomery(stk *stack, x, y, m nat) nat {
numWords := len(m)
// We want the lengths of x and m to be equal.
// It is OK if x >= m as long as len(x) == len(m).
if len(x) > numWords {
- _, x = nat(nil).div(nil, x, m)
+ _, x = nat(nil).div(stk, nil, x, m)
// Note: now len(x) <= numWords, not guaranteed ==.
}
if len(x) < numWords {
// RR = 2**(2*_W*len(m)) mod m
RR := nat(nil).setWord(1)
zz := nat(nil).shl(RR, uint(2*numWords*_W))
- _, RR = nat(nil).div(RR, zz, m)
+ _, RR = nat(nil).div(stk, RR, zz, m)
if len(RR) < numWords {
zz = zz.make(numWords)
copy(zz, RR)
// The div is not expected to be reached.
zz = zz.sub(zz, m)
if zz.cmp(m) >= 0 {
- _, zz = nat(nil).div(nil, zz, m)
+ _, zz = nat(nil).div(stk, nil, zz, m)
}
}
}
// sqrt sets z = ⌊√x⌋
-func (z nat) sqrt(x nat) nat {
+// The caller may pass stk == nil to request that sqrt obtain and release one itself.
+func (z nat) sqrt(stk *stack, x nat) nat {
if x.cmp(natOne) <= 0 {
return z.set(x)
}
z = nil
}
+ if stk == nil {
+ stk = getStack()
+ defer stk.free()
+ }
+
// Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
// See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt).
// https://members.loria.fr/PZimmermann/mca/pub226.html
z1 = z1.setUint64(1)
z1 = z1.shl(z1, uint(x.bitLen()+1)/2) // must be ≥ √x
for n := 0; ; n++ {
- z2, _ = z2.div(nil, x, z1)
+ z2, _ = z2.div(stk, nil, x, z1)
z2 = z2.add(z2, z1)
z2 = z2.shr(z2, 1)
if z2.cmp(z1) >= 0 {
}
type funNN func(z, x, y nat) nat
+type funSNN func(z nat, stk *stack, x, y nat) nat
type argNN struct {
z, x, y nat
}
}
}
+func testFunSNN(t *testing.T, msg string, f funSNN, a argNN) {
+ stk := getStack()
+ defer stk.free()
+ z := f(nil, stk, a.x, a.y)
+ if z.cmp(a.z) != 0 {
+ t.Errorf("%s%+v\n\tgot z = %v; want %v", msg, a, z, a.z)
+ }
+}
+
func TestFunNN(t *testing.T) {
for _, a := range sumNN {
arg := a
for _, a := range prodNN {
arg := a
- testFunNN(t, "mul", nat.mul, arg)
+ testFunSNN(t, "mul", nat.mul, arg)
arg = argNN{a.z, a.y, a.x}
- testFunNN(t, "mul symmetric", nat.mul, arg)
+ testFunSNN(t, "mul symmetric", nat.mul, arg)
}
}
}
func TestMulRangeN(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
for i, r := range mulRangesN {
- prod := string(nat(nil).mulRange(r.a, r.b).utoa(10))
+ prod := string(nat(nil).mulRange(stk, r.a, r.b).utoa(10))
if prod != r.prod {
t.Errorf("#%d: got %s; want %s", i, prod, r.prod)
}
// does not cause deep recursion and in turn allocate too much memory.
// Test case for issue 3807.
func TestMulUnbalanced(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
x := rndNat(50000)
y := rndNat(40)
allocSize := allocBytes(func() {
- nat(nil).mul(x, y)
+ nat(nil).mul(stk, x, y)
})
inputSize := uint64(len(x)+len(y)) * _S
if ratio := allocSize / uint64(inputSize); ratio > 10 {
}
func BenchmarkMul(b *testing.B) {
+ stk := getStack()
+ defer stk.free()
+
mulx := rndNat(1e4)
muly := rndNat(1e4)
b.ResetTimer()
for i := 0; i < b.N; i++ {
var z nat
- z.mul(mulx, muly)
+ z.mul(stk, mulx, muly)
}
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
- z.mul(x, y)
+ z.mul(nil, x, y)
}
}
}
func TestMontgomery(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
one := NewInt(1)
_B := new(Int).Lsh(one, _W)
for i, test := range montgomeryTests {
}
if x.cmp(m) > 0 {
- _, r := nat(nil).div(nil, x, m)
+ _, r := nat(nil).div(stk, nil, x, m)
t.Errorf("#%d: x > m (0x%s > 0x%s; use 0x%s)", i, x.utoa(16), m.utoa(16), r.utoa(16))
}
if y.cmp(m) > 0 {
- _, r := nat(nil).div(nil, x, m)
+ _, r := nat(nil).div(stk, nil, x, m)
t.Errorf("#%d: y > m (0x%s > 0x%s; use 0x%s)", i, y.utoa(16), m.utoa(16), r.utoa(16))
}
}
func TestExpNN(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
for i, test := range expNNTests {
x := natFromString(test.x)
y := natFromString(test.y)
m = natFromString(test.m)
}
- z := nat(nil).expNN(x, y, m, false)
+ z := nat(nil).expNN(stk, x, y, m, false)
if z.cmp(out) != 0 {
t.Errorf("#%d got %s want %s", i, z.utoa(10), out.utoa(10))
}
}
func BenchmarkExp3Power(b *testing.B) {
+ stk := getStack()
+ defer stk.free()
+
const x = 3
for _, y := range []Word{
0x10, 0x40, 0x100, 0x400, 0x1000, 0x4000, 0x10000, 0x40000, 0x100000, 0x400000,
b.Run(fmt.Sprintf("%#x", y), func(b *testing.B) {
var z nat
for i := 0; i < b.N; i++ {
- z.expWW(x, y)
+ z.expWW(stk, x, y)
}
})
}
}
func testSqr(t *testing.T, x nat) {
+ stk := getStack()
+ defer stk.free()
+
got := make(nat, 2*len(x))
want := make(nat, 2*len(x))
- got = got.sqr(x)
- want = want.mul(x, x)
+ got = got.sqr(stk, x)
+ want = want.mul(stk, x, x)
if got.cmp(want) != 0 {
t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
- z.sqr(x)
+ z.sqr(nil, x)
}
}
}
func TestNatDiv(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
sizes := []int{
1, 2, 5, 8, 15, 25, 40, 65, 100,
200, 500, 800, 1500, 2500, 4000, 6500, 10000,
c = c.norm()
}
// compute x = a*b+c
- x := nat(nil).mul(a, b)
+ x := nat(nil).mul(stk, a, b)
x = x.add(x, c)
var q, r nat
- q, r = q.div(r, x, b)
+ q, r = q.div(stk, r, x, b)
if q.cmp(a) != 0 {
t.Fatalf("wrong quotient: got %s; want %s for %s/%s", q.utoa(10), a.utoa(10), x.utoa(10), b.utoa(10))
}
// the inaccurate estimate of the first word's quotient
// happens at the very beginning of the loop.
func TestIssue37499(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
// Choose u and v such that v is slightly larger than u >> N.
// This tricks divBasic into choosing 1 as the first word
// of the quotient. This works in both 32-bit and 64-bit settings.
v := natFromString("0x2b6c385a05be027f5c22005b63c42a1165b79ff510e1706c")
q := nat(nil).make(8)
- q.divBasic(u, v)
+ q.divBasic(stk, u, v)
q = q.norm()
if s := string(q.utoa(16)); s != "fffffffffffffffffffffffffffffffffffffffffffffffb" {
t.Fatalf("incorrect quotient: %s", s)
// where the first division loop is never entered, and correcting
// the remainder takes exactly two iterations in the final loop.
func TestIssue42552(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
u := natFromString("0xc23b166884c3869092a520eceedeced2b00847bd256c9cf3b2c5e2227c15bd5e6ee7ef8a2f49236ad0eedf2c8a3b453cf6e0706f64285c526b372c4b1321245519d430540804a50b7ca8b6f1b34a2ec05cdbc24de7599af112d3e3c8db347e8799fe70f16e43c6566ba3aeb169463a3ecc486172deb2d9b80a3699c776e44fef20036bd946f1b4d054dd88a2c1aeb986199b0b2b7e58c42288824b74934d112fe1fc06e06b4d99fe1c5e725946b23210521e209cd507cce90b5f39a523f27e861f9e232aee50c3f585208b4573dcc0b897b6177f2ba20254fd5c50a033e849dee1b3a93bd2dc44ba8ca836cab2c2ae50e50b126284524fa0187af28628ff0face68d87709200329db1392852c8b8963fbe3d05fb1efe19f0ed5ca9fadc2f96f82187c24bb2512b2e85a66333a7e176605695211e1c8e0b9b9e82813e50654964945b1e1e66a90840396c7d10e23e47f364d2d3f660fa54598e18d1ca2ea4fe4f35a40a11f69f201c80b48eaee3e2e9b0eda63decf92bec08a70f731587d4ed0f218d5929285c8b2ccbc497e20db42de73885191fa453350335990184d8df805072f958d5354debda38f5421effaaafd6cb9b721ace74be0892d77679f62a4a126697cd35797f6858193da4ba1770c06aea2e5c59ec04b8ea26749e61b72ecdde403f3bc7e5e546cd799578cc939fa676dfd5e648576d4a06cbadb028adc2c0b461f145b2321f42e5e0f3b4fb898ecd461df07a6f5154067787bf74b5cc5c03704a1ce47494961931f0263b0aac32505102595957531a2de69dd71aac51f8a49902f81f21283dbe8e21e01e5d82517868826f86acf338d935aa6b4d5a25c8d540389b277dd9d64569d68baf0f71bd03dba45b92a7fc052601d1bd011a2fc6790a23f97c6fa5caeea040ab86841f268d39ce4f7caf01069df78bba098e04366492f0c2ac24f1bf16828752765fa523c9a4d42b71109d123e6be8c7b1ab3ccf8ea03404075fe1a9596f1bba1d267f9a7879ceece514818316c9c0583469d2367831fc42b517ea028a28df7c18d783d16ea2436cee2b15d52db68b5dfdee6b4d26f0905f9b030c911a04d078923a4136afea96eed6874462a482917353264cc9bee298f167ac65a6db4e4eda88044b39cc0b33183843eaa946564a00c3a0ab661f2c915e70bf0bb65bfbb6fa2eea20aed16bf2c1a1d00ec55fb4ff2f76b8e462ea70c19efa579c9ee78194b86708fdae66a9ce6e2cf3d366037798cfb50277ba6d2fd4866361022fd788ab7735b40b8b61d55e32243e06719e53992e9ac16c9c4b6e6933635c3c47c8f7e73e17dd54d0dd8aeba5d76de46894e7b3f9d3ec25ad78ee82297ba69905ea0fa094b8667faa2b8885e2187b3da80268aa1164761d7b0d6de206b676777348152b8ae1d4afed753bc63c739a5ca8ce7afb2b241a226bd9e502baba391b5b13f5054f070b65a9cf3a67063bfaa803ba390732cd03888f664023f888741d04d564e0b5674b0a183ace81452001b3fbb4214c77d42ca75376742c471e58f67307726d56a1032bd236610cbcbcd03d0d7a452900136897dc55bb3ce959d10d4e6a10fb635006bd8c41cd9ded2d3dfdd8f2e229590324a7370cb2124210b2330f4c56155caa09a2564932ceded8d92c79664dcdeb87faad7d3da006cc2ea267ee3df41e9677789cc5a8cc3b83add6491561b3047919e0648b1b2e97d7ad6f6c2aa80cab8e9ae10e1f75b1fdd0246151af709d259a6a0ed0b26bd711024965ecad7c41387de45443defce53f66612948694a6032279131c257119ed876a8e805dfb49576ef5c563574115ee87050d92d191bc761ef51d966918e2ef925639400069e3959d8fe19f36136e947ff430bf74e71da0aa5923b00000000")
v := natFromString("0x838332321d443a3d30373d47301d47073847473a383d3030f25b3d3d3e00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002e00000000000000000041603038331c3d32f5303441e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e01c0a5459bfc7b9be9fcbb9d2383840464319434707303030f43a32f53034411c0a5459413820878787878787878787878787878787878787878787878787878787878787878787870630303a3a30334036605b923a6101f83638413943413960204337602043323801526040523241846038414143015238604060328452413841413638523c0240384141364036605b923a6101f83638413943413960204334602043323801526040523241846038414143015238604060328452413841413638523c02403841413638433030f25a8b83838383838383838383838383838383837d838383ffffffffffffffff838383838383838383000000000000000000030000007d26e27c7c8b83838383838383838383838383838383837d838383ffffffffffffffff83838383838383838383838383838383838383838383435960f535073030f3343200000000000000011881301938343030fa398383300000002300000000000000000000f11af4600c845252904141364138383c60406032414443095238010241414303364443434132305b595a15434160b042385341ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff47476043410536613603593a6005411c437405fcfcfcfcfcfcfc0000000000005a3b075815054359000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
q := nat(nil).make(16)
- q.div(q, u, v)
+ q.div(stk, q, u, v)
}
}
} else {
+ stk := getStack()
+ defer stk.free()
+
bb, ndigits := maxPow(b)
// construct table of successive squares of bb*leafSize to use in subdivisions
// result (table != nil) <=> (len(x) > leafSize > 0)
- table := divisors(len(x), b, ndigits, bb)
+ table := divisors(stk, len(x), b, ndigits, bb)
// preserve x, create local copy for use by convertWords
q := nat(nil).set(x)
// convert q to string s in base b
- q.convertWords(s, b, ndigits, bb, table)
+ q.convertWords(stk, s, b, ndigits, bb, table)
// strip leading zeros
// (x != 0; thus s must contain at least one non-zero digit
// range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and
// ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for
// specific hardware.
-func (q nat) convertWords(s []byte, b Word, ndigits int, bb Word, table []divisor) {
+func (q nat) convertWords(stk *stack, s []byte, b Word, ndigits int, bb Word, table []divisor) {
// split larger blocks recursively
if table != nil {
// len(q) > leafSize > 0
}
// split q into the two digit number (q'*bbb + r) to form independent subblocks
- q, r = q.div(r, q, table[index].bbb)
+ q, r = q.div(stk, r, q, table[index].bbb)
// convert subblocks and collect results in s[:h] and s[h:]
h := len(s) - table[index].ndigits
- r.convertWords(s[h:], b, ndigits, bb, table[0:index])
- s = s[:h] // == q.convertWords(s, b, ndigits, bb, table[0:index+1])
+ r.convertWords(stk, s[h:], b, ndigits, bb, table[0:index])
+ s = s[:h] // == q.convertWords(stk, s, b, ndigits, bb, table[0:index+1])
}
}
}
// expWW computes x**y
-func (z nat) expWW(x, y Word) nat {
- return z.expNN(nat(nil).setWord(x), nat(nil).setWord(y), nil, false)
+func (z nat) expWW(stk *stack, x, y Word) nat {
+ return z.expNN(stk, nat(nil).setWord(x), nat(nil).setWord(y), nil, false)
}
// construct table of powers of bb*leafSize to use in subdivisions.
-func divisors(m int, b Word, ndigits int, bb Word) []divisor {
+func divisors(stk *stack, m int, b Word, ndigits int, bb Word) []divisor {
// only compute table when recursive conversion is enabled and x is large
if leafSize == 0 || m <= leafSize {
return nil
for i := 0; i < k; i++ {
if table[i].ndigits == 0 {
if i == 0 {
- table[0].bbb = nat(nil).expWW(bb, Word(leafSize))
+ table[0].bbb = nat(nil).expWW(stk, bb, Word(leafSize))
table[0].ndigits = ndigits * leafSize
} else {
- table[i].bbb = nat(nil).sqr(table[i-1].bbb)
+ table[i].bbb = nat(nil).sqr(stk, table[i-1].bbb)
table[i].ndigits = 2 * table[i-1].ndigits
}
}
func BenchmarkScan(b *testing.B) {
+ stk := getStack()
+ defer stk.free()
+
const x = 10
for _, base := range []int{2, 8, 10, 16} {
for _, y := range []Word{10, 100, 1000, 10000, 100000} {
b.Run(fmt.Sprintf("%d/Base%d", y, base), func(b *testing.B) {
b.StopTimer()
var z nat
- z = z.expWW(x, y)
+ z = z.expWW(stk, x, y)
s := z.utoa(base)
if t := itoa(z, base); !bytes.Equal(s, t) {
}
func BenchmarkString(b *testing.B) {
+ stk := getStack()
+ defer stk.free()
+
const x = 10
for _, base := range []int{2, 8, 10, 16} {
for _, y := range []Word{10, 100, 1000, 10000, 100000} {
b.Run(fmt.Sprintf("%d/Base%d", y, base), func(b *testing.B) {
b.StopTimer()
var z nat
- z = z.expWW(x, y)
+ z = z.expWW(stk, x, y)
z.utoa(base) // warm divisor cache
b.StartTimer()
for d := 1; d <= 10000; d *= 10 {
b.StopTimer()
+ stk := getStack()
var z nat
- z = z.expWW(Word(base), Word(d)) // build target number
- _ = z.utoa(base) // warm divisor cache
+ z = z.expWW(stk, Word(base), Word(d)) // build target number
+ _ = z.utoa(base) // warm divisor cache
+ stk.free()
b.StartTimer()
for i := 0; i < b.N; i++ {
}
func TestStringPowers(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
var p Word
for b := 2; b <= 16; b++ {
for p = 0; p <= 512; p++ {
if testing.Short() && p > 10 {
break
}
- x := nat(nil).expWW(Word(b), p)
+ x := nat(nil).expWW(stk, Word(b), p)
xs := x.utoa(b)
xs2 := itoa(x, b)
if !bytes.Equal(xs, xs2) {
// rem returns r such that r = u%v.
// It uses z as the storage for r.
-func (z nat) rem(u, v nat) (r nat) {
+func (z nat) rem(stk *stack, u, v nat) (r nat) {
if alias(z, u) {
z = nil
}
- qp := getNat(0)
- q, r := qp.div(z, u, v)
- *qp = q
- putNat(qp)
+ defer stk.restore(stk.save())
+ q := stk.nat(len(u) - (len(v) - 1))
+ _, r = q.div(stk, z, u, v)
return r
}
// div returns q, r such that q = ⌊u/v⌋ and r = u%v = u - q·v.
// It uses z and z2 as the storage for q and r.
-func (z nat) div(z2, u, v nat) (q, r nat) {
+// The caller may pass stk == nil to request that div obtain and release one itself.
+func (z nat) div(stk *stack, z2, u, v nat) (q, r nat) {
if len(v) == 0 {
panic("division by zero")
}
- if u.cmp(v) < 0 {
- q = z[:0]
- r = z2.set(u)
- return
- }
-
if len(v) == 1 {
// Short division: long optimized for a single-word divisor.
// In that case, the 2-by-1 guess is all we need at each step.
return
}
- q, r = z.divLarge(z2, u, v)
+ if u.cmp(v) < 0 {
+ q = z[:0]
+ r = z2.set(u)
+ return
+ }
+
+ if stk == nil {
+ stk = getStack()
+ defer stk.free()
+ }
+
+ q, r = z.divLarge(stk, z2, u, v)
return
}
// It uses z and u as the storage for q and r.
// The caller must ensure that len(vIn) ≥ 2 (use divW otherwise)
// and that len(uIn) ≥ len(vIn) (the answer is 0, uIn otherwise).
-func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
+func (z nat) divLarge(stk *stack, u, uIn, vIn nat) (q, r nat) {
n := len(vIn)
m := len(uIn) - n
// vIn is treated as a read-only input (it may be in use by another
// goroutine), so we must make a copy.
// uIn is copied to u.
+ defer stk.restore(stk.save())
shift := nlz(vIn[n-1])
- vp := getNat(n)
- v := *vp
+ v := stk.nat(n)
shlVU(v, vIn, shift)
u = u.make(len(uIn) + 1)
u[len(uIn)] = shlVU(u[:len(uIn)], uIn, shift)
// Use basic or recursive long division depending on size.
if n < divRecursiveThreshold {
- q.divBasic(u, v)
+ q.divBasic(stk, u, v)
} else {
- q.divRecursive(u, v)
+ q.divRecursive(stk, u, v)
}
- putNat(vp)
q = q.norm()
// divBasic implements long division as described above.
// It overwrites q with ⌊u/v⌋ and overwrites u with the remainder r.
// q must be large enough to hold ⌊u/v⌋.
-func (q nat) divBasic(u, v nat) {
+func (q nat) divBasic(stk *stack, u, v nat) {
n := len(v)
m := len(u) - n
- qhatvp := getNat(n + 1)
- qhatv := *qhatvp
+ defer stk.restore(stk.save())
+ qhatv := stk.nat(n + 1)
// Set up for divWW below, precomputing reciprocal argument.
vn1 := v[n-1]
}
q[j] = qhat
}
-
- putNat(qhatvp)
}
// greaterThan reports whether the two digit numbers x1 x2 > y1 y2.
// z must be large enough to hold ⌊u/v⌋.
// This function is just for allocating and freeing temporaries
// around divRecursiveStep, the real implementation.
-func (z nat) divRecursive(u, v nat) {
- // Recursion depth is (much) less than 2 log₂(len(v)).
- // Allocate a slice of temporaries to be reused across recursion,
- // plus one extra temporary not live across the recursion.
- recDepth := 2 * bits.Len(uint(len(v)))
- tmp := getNat(3 * len(v))
- temps := make([]*nat, recDepth)
-
+func (z nat) divRecursive(stk *stack, u, v nat) {
clear(z)
- z.divRecursiveStep(u, v, 0, tmp, temps)
-
- // Free temporaries.
- for _, n := range temps {
- if n != nil {
- putNat(n)
- }
- }
- putNat(tmp)
+ z.divRecursiveStep(stk, u, v, 0)
}
// divRecursiveStep is the actual implementation of recursive division.
// z must be large enough to hold ⌊u/v⌋.
// It uses temps[depth] (allocating if needed) as a temporary live across
// the recursive call. It also uses tmp, but not live across the recursion.
-func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
+func (z nat) divRecursiveStep(stk *stack, u, v nat, depth int) {
// u is a subsection of the original and may have leading zeros.
// TODO(rsc): The v = v.norm() is useless and should be removed.
// We know (and require) that v's top digit is ≥ B/2.
// Fall back to basic division if the problem is now small enough.
n := len(v)
if n < divRecursiveThreshold {
- z.divBasic(u, v)
+ z.divBasic(stk, u, v)
return
}
B := n / 2
// Allocate a nat for qhat below.
- if temps[depth] == nil {
- temps[depth] = getNat(n) // TODO(rsc): Can be just B+1.
- } else {
- *temps[depth] = temps[depth].make(B + 1)
- }
+ defer stk.restore(stk.save())
+ qhat0 := stk.nat(B + 1)
// Compute each wide digit of the quotient.
//
uu := u[j-B:]
// Compute the 2-by-1 guess q̂, leaving r̂ in uu[s:B+n].
- qhat := *temps[depth]
+ qhat := qhat0
clear(qhat)
- qhat.divRecursiveStep(uu[s:B+n], v[s:], depth+1, tmp, temps)
+ qhat.divRecursiveStep(stk, uu[s:B+n], v[s:], depth+1)
qhat = qhat.norm()
// Extend to a 3-by-2 quotient and remainder.
// q̂·vₙ₋₂ and decrementing q̂ until that product is ≤ u.
// But we can do the subtraction directly, as in the comment above
// and in long division, because we know that q̂ is wrong by at most one.
- qhatv := tmp.make(3 * n)
+ mark := stk.save()
+ qhatv := stk.nat(3 * n)
clear(qhatv)
- qhatv = qhatv.mul(qhat, v[:s])
+ qhatv = qhatv.mul(stk, qhat, v[:s])
for i := 0; i < 2; i++ {
e := qhatv.cmp(uu.norm())
if e <= 0 {
}
addAt(z, qhat, j-B)
j -= B
+ stk.restore(mark)
}
// TODO(rsc): Rewrite loop as described above and delete all this code.
// Now u < (v<<B), compute lower bits in the same way.
// Choose shift = B-1 again.
s := B - 1
- qhat := *temps[depth]
+ qhat := qhat0
clear(qhat)
- qhat.divRecursiveStep(u[s:].norm(), v[s:], depth+1, tmp, temps)
+ qhat.divRecursiveStep(stk, u[s:].norm(), v[s:], depth+1)
qhat = qhat.norm()
- qhatv := tmp.make(3 * n)
+ qhatv := stk.nat(3 * n)
clear(qhatv)
- qhatv = qhatv.mul(qhat, v[:s])
+ qhatv = qhatv.mul(stk, qhat, v[:s])
// Set the correct remainder as before.
for i := 0; i < 2; i++ {
if e := qhatv.cmp(u.norm()); e > 0 {
return false
}
- return x.abs.probablyPrimeMillerRabin(n+1, true) && x.abs.probablyPrimeLucas()
+ stk := getStack()
+ defer stk.free()
+ return x.abs.probablyPrimeMillerRabin(stk, n+1, true) && x.abs.probablyPrimeLucas(stk)
}
// probablyPrimeMillerRabin reports whether n passes reps rounds of the
// If force2 is true, one of the rounds is forced to use base 2.
// See Handbook of Applied Cryptography, p. 139, Algorithm 4.24.
// The number n is known to be non-zero.
-func (n nat) probablyPrimeMillerRabin(reps int, force2 bool) bool {
+func (n nat) probablyPrimeMillerRabin(stk *stack, reps int, force2 bool) bool {
nm1 := nat(nil).sub(n, natOne)
// determine q, k such that nm1 = q << k
k := nm1.trailingZeroBits()
x = x.random(rand, nm3, nm3Len)
x = x.add(x, natTwo)
}
- y = y.expNN(x, q, n, false)
+ y = y.expNN(stk, x, q, n, false)
if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
continue
}
for j := uint(1); j < k; j++ {
- y = y.sqr(y)
- quotient, y = quotient.div(y, y, n)
+ y = y.sqr(stk, y)
+ quotient, y = quotient.div(stk, y, y, n)
if y.cmp(nm1) == 0 {
continue NextRandom
}
//
// Crandall and Pomerance, Prime Numbers: A Computational Perspective, 2nd ed.
// Springer, 2005.
-func (n nat) probablyPrimeLucas() bool {
+func (n nat) probablyPrimeLucas(stk *stack) bool {
// Discard 0, 1.
if len(n) == 0 || n.cmp(natOne) == 0 {
return false
// We'll never find (d/n) = -1 if n is a square.
// If n is a non-square we expect to find a d in just a few attempts on average.
// After 40 attempts, take a moment to check if n is indeed a square.
- t1 = t1.sqrt(n)
- t1 = t1.sqr(t1)
+ t1 = t1.sqrt(stk, n)
+ t1 = t1.sqr(stk, t1)
if t1.cmp(n) == 0 {
return false
}
if s.bit(uint(i)) != 0 {
// k' = 2k+1
// V(k') = V(2k+1) = V(k) V(k+1) - P.
- t1 = t1.mul(vk, vk1)
+ t1 = t1.mul(stk, vk, vk1)
t1 = t1.add(t1, n)
t1 = t1.sub(t1, natP)
- t2, vk = t2.div(vk, t1, n)
+ t2, vk = t2.div(stk, vk, t1, n)
// V(k'+1) = V(2k+2) = V(k+1)² - 2.
- t1 = t1.sqr(vk1)
+ t1 = t1.sqr(stk, vk1)
t1 = t1.add(t1, nm2)
- t2, vk1 = t2.div(vk1, t1, n)
+ t2, vk1 = t2.div(stk, vk1, t1, n)
} else {
// k' = 2k
// V(k'+1) = V(2k+1) = V(k) V(k+1) - P.
- t1 = t1.mul(vk, vk1)
+ t1 = t1.mul(stk, vk, vk1)
t1 = t1.add(t1, n)
t1 = t1.sub(t1, natP)
- t2, vk1 = t2.div(vk1, t1, n)
+ t2, vk1 = t2.div(stk, vk1, t1, n)
// V(k') = V(2k) = V(k)² - 2
- t1 = t1.sqr(vk)
+ t1 = t1.sqr(stk, vk)
t1 = t1.add(t1, nm2)
- t2, vk = t2.div(vk, t1, n)
+ t2, vk = t2.div(stk, vk, t1, n)
}
}
//
// Since we are checking for U(k) == 0 it suffices to check 2 V(k+1) == P V(k) mod n,
// or P V(k) - 2 V(k+1) == 0 mod n.
- t1 := t1.mul(vk, natP)
+ t1 := t1.mul(stk, vk, natP)
t2 := t2.shl(vk1, 1)
if t1.cmp(t2) < 0 {
t1, t2 = t2, t1
t3 := vk1 // steal vk1, no longer needed below
vk1 = nil
_ = vk1
- t2, t3 = t2.div(t3, t1, n)
+ t2, t3 = t2.div(stk, t3, t1, n)
if len(t3) == 0 {
return true
}
}
// k' = 2k
// V(k') = V(2k) = V(k)² - 2
- t1 = t1.sqr(vk)
+ t1 = t1.sqr(stk, vk)
t1 = t1.sub(t1, natTwo)
- t2, vk = t2.div(vk, t1, n)
+ t2, vk = t2.div(stk, vk, t1, n)
}
return false
}
}
func BenchmarkProbablyPrime(b *testing.B) {
+ stk := getStack()
+ defer stk.free()
+
p, _ := new(Int).SetString("203956878356401977405765866929034577280193993314348263094772646453283062722701277632936616063144088173312372882677123879538709400158306567338328279154499698366071906766440037074217117805690872792848149112022286332144876183376326512083574821647933992961249917319836219304274280243803104015000563790123", 10)
for _, n := range []int{0, 1, 5, 10, 20} {
b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) {
b.Run("Lucas", func(b *testing.B) {
for i := 0; i < b.N; i++ {
- p.abs.probablyPrimeLucas()
+ p.abs.probablyPrimeLucas(stk)
}
})
b.Run("MillerRabinBase2", func(b *testing.B) {
for i := 0; i < b.N; i++ {
- p.abs.probablyPrimeMillerRabin(1, true)
+ p.abs.probablyPrimeMillerRabin(stk, 1, true)
}
})
}
func TestMillerRabinPseudoprimes(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
testPseudoprimes(t, "probablyPrimeMillerRabin",
- func(n nat) bool { return n.probablyPrimeMillerRabin(1, true) && !n.probablyPrimeLucas() },
+ func(n nat) bool { return n.probablyPrimeMillerRabin(stk, 1, true) && !n.probablyPrimeLucas(stk) },
// https://oeis.org/A001262
[]int{2047, 3277, 4033, 4681, 8321, 15841, 29341, 42799, 49141, 52633, 65281, 74665, 80581, 85489, 88357, 90751})
}
func TestLucasPseudoprimes(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
testPseudoprimes(t, "probablyPrimeLucas",
- func(n nat) bool { return n.probablyPrimeLucas() && !n.probablyPrimeMillerRabin(1, true) },
+ func(n nat) bool { return n.probablyPrimeLucas(stk) && !n.probablyPrimeMillerRabin(stk, 1, true) },
// https://oeis.org/A217719
[]int{989, 3239, 5777, 10877, 27971, 29681, 30739, 31631, 39059, 72389, 73919, 75077})
}
// nearest to the quotient a/b, using round-to-even in
// halfway cases. It does not mutate its arguments.
// Preconditions: b is non-zero; a and b have no common factors.
-func quotToFloat32(a, b nat) (f float32, exact bool) {
+func quotToFloat32(stk *stack, a, b nat) (f float32, exact bool) {
const (
// float size in bits
Fsize = 32
// extra shift, the low-order bit of q is logically the
// high-order bit of r.
var q nat
- q, r := q.div(a2, a2, b2) // (recycle a2)
+ q, r := q.div(stk, a2, a2, b2) // (recycle a2)
mantissa := low32(q)
haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half
// nearest to the quotient a/b, using round-to-even in
// halfway cases. It does not mutate its arguments.
// Preconditions: b is non-zero; a and b have no common factors.
-func quotToFloat64(a, b nat) (f float64, exact bool) {
+func quotToFloat64(stk *stack, a, b nat) (f float64, exact bool) {
const (
// float size in bits
Fsize = 64
// extra shift, the low-order bit of q is logically the
// high-order bit of r.
var q nat
- q, r := q.div(a2, a2, b2) // (recycle a2)
+ q, r := q.div(stk, a2, a2, b2) // (recycle a2)
mantissa := low64(q)
haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half
if len(b) == 0 {
b = natOne
}
- f, exact = quotToFloat32(x.a.abs, b)
+ stk := getStack()
+ defer stk.free()
+ f, exact = quotToFloat32(stk, x.a.abs, b)
if x.a.neg {
f = -f
}
if len(b) == 0 {
b = natOne
}
- f, exact = quotToFloat64(x.a.abs, b)
+ stk := getStack()
+ defer stk.free()
+ f, exact = quotToFloat64(stk, x.a.abs, b)
if x.a.neg {
f = -f
}
z.b.abs = z.b.abs.setWord(1)
default:
// z is fraction; normalize numerator and denominator
+ stk := getStack()
+ defer stk.free()
neg := z.a.neg
z.a.neg = false
z.b.neg = false
if f := NewInt(0).lehmerGCD(nil, nil, &z.a, &z.b); f.Cmp(intOne) != 0 {
- z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f.abs)
- z.b.abs, _ = z.b.abs.div(nil, z.b.abs, f.abs)
+ z.a.abs, _ = z.a.abs.div(stk, nil, z.a.abs, f.abs)
+ z.b.abs, _ = z.b.abs.div(stk, nil, z.b.abs, f.abs)
}
z.a.neg = neg
}
// mulDenom sets z to the denominator product x*y (by taking into
// account that 0 values for x or y must be interpreted as 1) and
// returns z.
-func mulDenom(z, x, y nat) nat {
+func mulDenom(stk *stack, z, x, y nat) nat {
switch {
case len(x) == 0 && len(y) == 0:
return z.setWord(1)
case len(y) == 0:
return z.set(x)
}
- return z.mul(x, y)
+ return z.mul(stk, x, y)
}
// scaleDenom sets z to the product x*f.
// If f == 0 (zero value of denominator), z is set to (a copy of) x.
-func (z *Int) scaleDenom(x *Int, f nat) {
+func (z *Int) scaleDenom(stk *stack, x *Int, f nat) {
if len(f) == 0 {
z.Set(x)
return
}
- z.abs = z.abs.mul(x.abs, f)
+ z.abs = z.abs.mul(stk, x.abs, f)
z.neg = x.neg
}
// - +1 if x > y.
func (x *Rat) Cmp(y *Rat) int {
var a, b Int
- a.scaleDenom(&x.a, y.b.abs)
- b.scaleDenom(&y.a, x.b.abs)
+ stk := getStack()
+ defer stk.free()
+ a.scaleDenom(stk, &x.a, y.b.abs)
+ b.scaleDenom(stk, &y.a, x.b.abs)
return a.Cmp(&b)
}
// Add sets z to the sum x+y and returns z.
func (z *Rat) Add(x, y *Rat) *Rat {
+ stk := getStack()
+ defer stk.free()
+
var a1, a2 Int
- a1.scaleDenom(&x.a, y.b.abs)
- a2.scaleDenom(&y.a, x.b.abs)
+ a1.scaleDenom(stk, &x.a, y.b.abs)
+ a2.scaleDenom(stk, &y.a, x.b.abs)
z.a.Add(&a1, &a2)
- z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
+ z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs)
return z.norm()
}
// Sub sets z to the difference x-y and returns z.
func (z *Rat) Sub(x, y *Rat) *Rat {
+ stk := getStack()
+ defer stk.free()
+
var a1, a2 Int
- a1.scaleDenom(&x.a, y.b.abs)
- a2.scaleDenom(&y.a, x.b.abs)
+ a1.scaleDenom(stk, &x.a, y.b.abs)
+ a2.scaleDenom(stk, &y.a, x.b.abs)
z.a.Sub(&a1, &a2)
- z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
+ z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs)
return z.norm()
}
// Mul sets z to the product x*y and returns z.
func (z *Rat) Mul(x, y *Rat) *Rat {
+ stk := getStack()
+ defer stk.free()
+
if x == y {
// a squared Rat is positive and can't be reduced (no need to call norm())
z.a.neg = false
- z.a.abs = z.a.abs.sqr(x.a.abs)
+ z.a.abs = z.a.abs.sqr(stk, x.a.abs)
if len(x.b.abs) == 0 {
z.b.abs = z.b.abs.setWord(1)
} else {
- z.b.abs = z.b.abs.sqr(x.b.abs)
+ z.b.abs = z.b.abs.sqr(stk, x.b.abs)
}
return z
}
- z.a.Mul(&x.a, &y.a)
- z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
+
+ z.a.mul(stk, &x.a, &y.a)
+ z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs)
return z.norm()
}
// Quo sets z to the quotient x/y and returns z.
// If y == 0, Quo panics.
func (z *Rat) Quo(x, y *Rat) *Rat {
+ stk := getStack()
+ defer stk.free()
+
if len(y.a.abs) == 0 {
panic("division by zero")
}
var a, b Int
- a.scaleDenom(&x.a, y.b.abs)
- b.scaleDenom(&y.a, x.b.abs)
+ a.scaleDenom(stk, &x.a, y.b.abs)
+ b.scaleDenom(stk, &y.a, x.b.abs)
z.a.abs = a.abs
z.b.abs = b.abs
z.a.neg = a.neg != b.neg
}
// exp consumed - not needed anymore
+ stk := getStack()
+ defer stk.free()
+
// apply exp5 contributions
// (start with exp5 so the numbers to multiply are smaller)
if exp5 != 0 {
if n > 1e6 {
return nil, false // avoid excessively large exponents
}
- pow5 := z.b.abs.expNN(natFive, nat(nil).setWord(Word(n)), nil, false) // use underlying array of z.b.abs
+ pow5 := z.b.abs.expNN(stk, natFive, nat(nil).setWord(Word(n)), nil, false) // use underlying array of z.b.abs
if exp5 > 0 {
- z.a.abs = z.a.abs.mul(z.a.abs, pow5)
+ z.a.abs = z.a.abs.mul(stk, z.a.abs, pow5)
z.b.abs = z.b.abs.setWord(1)
} else {
z.b.abs = pow5
}
// x.b.abs != 0
- q, r := nat(nil).div(nat(nil), x.a.abs, x.b.abs)
+ stk := getStack()
+ defer stk.free()
+ q, r := nat(nil).div(stk, nat(nil), x.a.abs, x.b.abs)
p := natOne
if prec > 0 {
- p = nat(nil).expNN(natTen, nat(nil).setUint64(uint64(prec)), nil, false)
+ p = nat(nil).expNN(stk, natTen, nat(nil).setUint64(uint64(prec)), nil, false)
}
- r = r.mul(r, p)
- r, r2 := r.div(nat(nil), r, x.b.abs)
+ r = r.mul(stk, r, p)
+ r, r2 := r.div(stk, nat(nil), r, x.b.abs)
// see if we need to round up
r2 = r2.add(r2, r2)
// 1/4 2 true 0.25
// 1/6 1 false 0.2 (0.166... rounded)
func (x *Rat) FloatPrec() (n int, exact bool) {
+ stk := getStack()
+ defer stk.free()
+
// Determine q and largest p2, p5 such that d = q·2^p2·5^p5.
// The results n, exact are:
//
f := nat{1220703125} // == 5^fp (must fit into a uint32 Word)
var t, r nat // temporaries
for {
- if _, r = t.div(r, q, f); len(r) != 0 {
+ if _, r = t.div(stk, r, q, f); len(r) != 0 {
break // f doesn't divide q evenly
}
tab = append(tab, f)
- f = nat(nil).sqr(f) // nat(nil) to ensure a new f for each table entry
+ f = nat(nil).sqr(stk, f) // nat(nil) to ensure a new f for each table entry
}
// Factor q using the table entries, if any.
// The same reasoning applies to the subsequent factors.
var p5 uint
for i := len(tab) - 1; i >= 0; i-- {
- if t, r = t.div(r, q, tab[i]); len(r) == 0 {
+ if t, r = t.div(stk, r, q, tab[i]); len(r) == 0 {
p5 += fp * (1 << i) // tab[i] == 5^(fp·2^i)
q = q.set(t)
}
// If fp != 1, we may still have multiples of 5 left.
for {
- if t, r = t.div(r, q, natFive); len(r) != 0 {
+ if t, r = t.div(stk, r, q, natFive); len(r) != 0 {
break
}
p5++