From 872496da2bbda59ef598e8c16f0ebfa41c98a462 Mon Sep 17 00:00:00 2001 From: Russ Cox Date: Fri, 17 Jan 2025 12:28:58 -0500 Subject: [PATCH] math/big: replace nat pool with Word stack MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit In the early days of math/big, algorithms that needed more space grew the result larger than it needed to be and then used the high words as extra space. This made results their own temporary space caches, at the cost that saving a result in a data structure might hold significantly more memory than necessary. Specifically, new(big.Int).Mul(x, y) returned a big.Int with a backing slice 3X as big as it strictly needed to be. If you are storing many multiplication results, or even a single large result, the 3X overhead can add up. This approach to storage for temporaries also requires being able to analyze the algorithms to predict the exact amount they need, which can be difficult. For both these reasons, the implementation of recursive long division, which came later, introduced a “nat pool” where temporaries could be stored and reused, or reclaimed by the GC when no longer used. This avoids the storage and bookkeeping overheads but introduces a per-temporary sync.Pool overhead. divRecursiveStep takes an array of cached temporaries to remove some of that overhead. The nat pool was better but is still not quite right. This CL introduces something even better than the nat pool (still probably not quite right, but the best I can see for now): a sync.Pool holding stacks for allocating temporaries. Now an operation can get one stack out of the pool and then allocate as many temporaries as it needs during the operation, eventually returning the stack back to the pool. The sync.Pool operations are now per-exported-operation (like big.Int.Mul), not per-temporary. This CL converts both the pre-allocation in nat.mul and the uses of the nat pool to use stack pools instead. This simplifies some code and sets us up better for more complex algorithms (such as Toom-Cook or FFT-based multiplication) that need more temporaries. It is also a little bit faster. goos: linux goarch: amd64 pkg: math/big cpu: Intel(R) Xeon(R) CPU @ 3.10GHz │ old │ new │ │ sec/op │ sec/op vs base │ Div/20/10-16 23.68n ± 0% 22.21n ± 0% -6.21% (p=0.000 n=15) Div/40/20-16 23.68n ± 0% 22.21n ± 0% -6.21% (p=0.000 n=15) Div/100/50-16 56.65n ± 0% 55.53n ± 0% -1.98% (p=0.000 n=15) Div/200/100-16 194.6n ± 1% 172.8n ± 0% -11.20% (p=0.000 n=15) Div/400/200-16 232.1n ± 0% 206.7n ± 0% -10.94% (p=0.000 n=15) Div/1000/500-16 405.3n ± 1% 383.8n ± 0% -5.30% (p=0.000 n=15) Div/2000/1000-16 810.4n ± 1% 795.2n ± 0% -1.88% (p=0.000 n=15) Div/20000/10000-16 25.88µ ± 0% 25.39µ ± 0% -1.89% (p=0.000 n=15) Div/200000/100000-16 931.5µ ± 0% 924.3µ ± 0% -0.77% (p=0.000 n=15) Div/2000000/1000000-16 37.77m ± 0% 37.75m ± 0% ~ (p=0.098 n=15) Div/20000000/10000000-16 1.367 ± 0% 1.377 ± 0% +0.72% (p=0.003 n=15) NatMul/10-16 168.5n ± 3% 164.0n ± 4% ~ (p=0.751 n=15) NatMul/100-16 6.086µ ± 3% 5.380µ ± 3% -11.60% (p=0.000 n=15) NatMul/1000-16 238.1µ ± 3% 228.3µ ± 1% -4.12% (p=0.000 n=15) NatMul/10000-16 8.721m ± 2% 8.518m ± 1% -2.33% (p=0.000 n=15) NatMul/100000-16 369.6m ± 0% 371.1m ± 0% +0.42% (p=0.000 n=15) geomean 19.57µ 18.74µ -4.21% │ old │ new │ │ B/op │ B/op vs base │ NatMul/10-16 192.0 ± 0% 192.0 ± 0% ~ (p=1.000 n=15) ¹ NatMul/100-16 4.750Ki ± 0% 1.751Ki ± 0% -63.14% (p=0.000 n=15) NatMul/1000-16 48.16Ki ± 0% 16.02Ki ± 0% -66.73% (p=0.000 n=15) NatMul/10000-16 482.9Ki ± 1% 165.4Ki ± 3% -65.75% (p=0.000 n=15) NatMul/100000-16 5.747Mi ± 7% 4.197Mi ± 0% -26.97% (p=0.000 n=15) geomean 41.42Ki 20.63Ki -50.18% ¹ all samples are equal │ old │ new │ │ allocs/op │ allocs/op vs base │ NatMul/10-16 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹ NatMul/100-16 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹ NatMul/1000-16 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹ NatMul/10000-16 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹ NatMul/100000-16 7.000 ± 14% 7.000 ± 14% ~ (p=0.668 n=15) geomean 1.476 1.476 +0.00% ¹ all samples are equal goos: linux goarch: amd64 pkg: math/big cpu: Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz │ old │ new │ │ sec/op │ sec/op vs base │ Div/20/10-88 15.84n ± 1% 13.12n ± 0% -17.17% (p=0.000 n=15) Div/40/20-88 15.88n ± 1% 13.12n ± 0% -17.38% (p=0.000 n=15) Div/100/50-88 26.42n ± 0% 25.47n ± 0% -3.60% (p=0.000 n=15) Div/200/100-88 132.4n ± 0% 114.9n ± 0% -13.22% (p=0.000 n=15) Div/400/200-88 150.1n ± 0% 135.6n ± 0% -9.66% (p=0.000 n=15) Div/1000/500-88 275.5n ± 0% 264.1n ± 0% -4.14% (p=0.000 n=15) Div/2000/1000-88 586.5n ± 0% 581.1n ± 0% -0.92% (p=0.000 n=15) Div/20000/10000-88 25.87µ ± 0% 25.72µ ± 0% -0.59% (p=0.000 n=15) Div/200000/100000-88 772.2µ ± 0% 779.0µ ± 0% +0.88% (p=0.000 n=15) Div/2000000/1000000-88 33.36m ± 0% 33.63m ± 0% +0.80% (p=0.000 n=15) Div/20000000/10000000-88 1.307 ± 0% 1.320 ± 0% +1.03% (p=0.000 n=15) NatMul/10-88 140.4n ± 0% 148.8n ± 4% +5.98% (p=0.000 n=15) NatMul/100-88 4.663µ ± 1% 4.388µ ± 1% -5.90% (p=0.000 n=15) NatMul/1000-88 207.7µ ± 0% 205.8µ ± 0% -0.89% (p=0.000 n=15) NatMul/10000-88 8.456m ± 0% 8.468m ± 0% +0.14% (p=0.021 n=15) NatMul/100000-88 295.1m ± 0% 297.9m ± 0% +0.94% (p=0.000 n=15) geomean 14.96µ 14.33µ -4.23% │ old │ new │ │ B/op │ B/op vs base │ NatMul/10-88 192.0 ± 0% 192.0 ± 0% ~ (p=1.000 n=15) ¹ NatMul/100-88 4.750Ki ± 0% 1.758Ki ± 0% -62.99% (p=0.000 n=15) NatMul/1000-88 48.44Ki ± 0% 16.08Ki ± 0% -66.80% (p=0.000 n=15) NatMul/10000-88 489.7Ki ± 1% 166.1Ki ± 3% -66.08% (p=0.000 n=15) NatMul/100000-88 5.546Mi ± 0% 3.819Mi ± 60% -31.15% (p=0.000 n=15) geomean 41.29Ki 20.30Ki -50.85% ¹ all samples are equal │ old │ new │ │ allocs/op │ allocs/op vs base │ NatMul/10-88 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹ NatMul/100-88 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹ NatMul/1000-88 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹ NatMul/10000-88 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹ NatMul/100000-88 5.000 ± 20% 6.000 ± 67% ~ (p=0.672 n=15) geomean 1.380 1.431 +3.71% ¹ all samples are equal goos: linux goarch: arm64 pkg: math/big │ old │ new │ │ sec/op │ sec/op vs base │ Div/20/10-16 15.85n ± 0% 15.23n ± 0% -3.91% (p=0.000 n=15) Div/40/20-16 15.88n ± 0% 15.22n ± 0% -4.16% (p=0.000 n=15) Div/100/50-16 29.69n ± 0% 26.39n ± 0% -11.11% (p=0.000 n=15) Div/200/100-16 149.2n ± 0% 123.3n ± 0% -17.36% (p=0.000 n=15) Div/400/200-16 160.3n ± 0% 139.2n ± 0% -13.16% (p=0.000 n=15) Div/1000/500-16 271.0n ± 0% 256.1n ± 0% -5.50% (p=0.000 n=15) Div/2000/1000-16 545.3n ± 0% 527.0n ± 0% -3.36% (p=0.000 n=15) Div/20000/10000-16 22.60µ ± 0% 22.20µ ± 0% -1.77% (p=0.000 n=15) Div/200000/100000-16 889.0µ ± 0% 892.2µ ± 0% +0.35% (p=0.000 n=15) Div/2000000/1000000-16 38.01m ± 0% 38.12m ± 0% +0.30% (p=0.000 n=15) Div/20000000/10000000-16 1.437 ± 0% 1.444 ± 0% +0.50% (p=0.000 n=15) NatMul/10-16 166.4n ± 2% 169.5n ± 1% +1.86% (p=0.000 n=15) NatMul/100-16 5.733µ ± 1% 5.570µ ± 1% -2.84% (p=0.000 n=15) NatMul/1000-16 232.6µ ± 1% 229.8µ ± 0% -1.22% (p=0.000 n=15) NatMul/10000-16 9.039m ± 1% 8.969m ± 0% -0.77% (p=0.000 n=15) NatMul/100000-16 367.0m ± 0% 368.8m ± 0% +0.48% (p=0.000 n=15) geomean 16.15µ 15.50µ -4.01% │ old │ new │ │ B/op │ B/op vs base │ NatMul/10-16 192.0 ± 0% 192.0 ± 0% ~ (p=1.000 n=15) ¹ NatMul/100-16 4.750Ki ± 0% 1.751Ki ± 0% -63.14% (p=0.000 n=15) NatMul/1000-16 48.33Ki ± 0% 16.02Ki ± 0% -66.85% (p=0.000 n=15) NatMul/10000-16 536.5Ki ± 1% 165.7Ki ± 3% -69.12% (p=0.000 n=15) NatMul/100000-16 6.078Mi ± 6% 4.197Mi ± 0% -30.94% (p=0.000 n=15) geomean 42.81Ki 20.64Ki -51.78% ¹ all samples are equal │ old │ new │ │ allocs/op │ allocs/op vs base │ NatMul/10-16 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹ NatMul/100-16 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹ NatMul/1000-16 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹ NatMul/10000-16 2.000 ± 50% 1.000 ± 0% -50.00% (p=0.001 n=15) NatMul/100000-16 9.000 ± 11% 8.000 ± 12% -11.11% (p=0.001 n=15) geomean 1.783 1.516 -14.97% ¹ all samples are equal goos: darwin goarch: arm64 pkg: math/big cpu: Apple M3 Pro │ old │ new │ │ sec/op │ sec/op vs base │ Div/20/10-12 9.850n ± 1% 9.405n ± 1% -4.52% (p=0.000 n=15) Div/40/20-12 9.858n ± 0% 9.403n ± 1% -4.62% (p=0.000 n=15) Div/100/50-12 16.40n ± 1% 14.81n ± 0% -9.70% (p=0.000 n=15) Div/200/100-12 88.48n ± 2% 80.88n ± 0% -8.59% (p=0.000 n=15) Div/400/200-12 107.90n ± 1% 99.28n ± 1% -7.99% (p=0.000 n=15) Div/1000/500-12 188.8n ± 1% 178.6n ± 1% -5.40% (p=0.000 n=15) Div/2000/1000-12 399.9n ± 0% 389.1n ± 0% -2.70% (p=0.000 n=15) Div/20000/10000-12 13.94µ ± 2% 13.81µ ± 1% ~ (p=0.574 n=15) Div/200000/100000-12 523.8µ ± 0% 521.7µ ± 0% -0.40% (p=0.000 n=15) Div/2000000/1000000-12 21.46m ± 0% 21.48m ± 0% ~ (p=0.067 n=15) Div/20000000/10000000-12 812.5m ± 0% 812.9m ± 0% ~ (p=0.061 n=15) NatMul/10-12 77.14n ± 0% 78.35n ± 1% +1.57% (p=0.000 n=15) NatMul/100-12 2.999µ ± 0% 2.871µ ± 1% -4.27% (p=0.000 n=15) NatMul/1000-12 126.2µ ± 0% 126.8µ ± 0% +0.51% (p=0.011 n=15) NatMul/10000-12 5.099m ± 0% 5.125m ± 0% +0.51% (p=0.000 n=15) NatMul/100000-12 206.7m ± 0% 208.4m ± 0% +0.80% (p=0.000 n=15) geomean 9.512µ 9.236µ -2.91% │ old │ new │ │ B/op │ B/op vs base │ NatMul/10-12 192.0 ± 0% 192.0 ± 0% ~ (p=1.000 n=15) ¹ NatMul/100-12 4.750Ki ± 0% 1.750Ki ± 0% -63.16% (p=0.000 n=15) NatMul/1000-12 48.13Ki ± 0% 16.01Ki ± 0% -66.73% (p=0.000 n=15) NatMul/10000-12 483.5Ki ± 1% 163.2Ki ± 2% -66.24% (p=0.000 n=15) NatMul/100000-12 5.480Mi ± 4% 1.532Mi ± 104% -72.05% (p=0.000 n=15) geomean 41.03Ki 16.82Ki -59.01% ¹ all samples are equal │ old │ new │ │ allocs/op │ allocs/op vs base │ NatMul/10-12 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹ NatMul/100-12 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹ NatMul/1000-12 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹ NatMul/10000-12 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=15) ¹ NatMul/100000-12 5.000 ± 0% 1.000 ± 400% -80.00% (p=0.007 n=15) geomean 1.380 1.000 -27.52% ¹ all samples are equal Change-Id: I7efa6fe37971ed26ae120a32250fcb47ece0a011 Reviewed-on: https://go-review.googlesource.com/c/go/+/650638 LUCI-TryBot-Result: Go LUCI Auto-Submit: Russ Cox Reviewed-by: Ian Lance Taylor Reviewed-by: Alan Donovan --- src/math/big/arith_test.go | 5 +- src/math/big/float.go | 8 +- src/math/big/int.go | 20 ++- src/math/big/nat.go | 310 ++++++++++++++++++++--------------- src/math/big/nat_test.go | 74 +++++++-- src/math/big/natconv.go | 25 +-- src/math/big/natconv_test.go | 21 ++- src/math/big/natdiv.go | 98 +++++------ src/math/big/prime.go | 42 ++--- src/math/big/prime_test.go | 17 +- src/math/big/rat.go | 73 ++++++--- src/math/big/ratconv.go | 28 ++-- 12 files changed, 429 insertions(+), 292 deletions(-) diff --git a/src/math/big/arith_test.go b/src/math/big/arith_test.go index 64225bbd53..feffa1bc95 100644 --- a/src/math/big/arith_test.go +++ b/src/math/big/arith_test.go @@ -368,9 +368,12 @@ func TestShiftOverlap(t *testing.T) { } func TestIssue31084(t *testing.T) { + stk := getStack() + defer stk.free() + // compute 10^n via 5^n << n. const n = 165 - p := nat(nil).expNN(nat{5}, nat{n}, nil, false) + p := nat(nil).expNN(stk, nat{5}, nat{n}, nil, false) p = p.shl(p, n) got := string(p.utoa(10)) want := "1" + strings.Repeat("0", n) diff --git a/src/math/big/float.go b/src/math/big/float.go index e1d20d8bb4..2c5234a4ce 100644 --- a/src/math/big/float.go +++ b/src/math/big/float.go @@ -1327,9 +1327,9 @@ func (z *Float) umul(x, y *Float) { e := int64(x.exp) + int64(y.exp) if x == y { - z.mant = z.mant.sqr(x.mant) + z.mant = z.mant.sqr(nil, x.mant) } else { - z.mant = z.mant.mul(x.mant, y.mant) + z.mant = z.mant.mul(nil, x.mant, y.mant) } z.setExpAndRound(e-fnorm(z.mant), 0) } @@ -1363,8 +1363,10 @@ func (z *Float) uquo(x, y *Float) { d := len(xadj) - len(y.mant) // divide + stk := getStack() + defer stk.free() var r nat - z.mant, r = z.mant.div(nil, xadj, y.mant) + z.mant, r = z.mant.div(stk, nil, xadj, y.mant) e := int64(x.exp) - int64(y.exp) - int64(d-len(z.mant))*_W // The result is long enough to include (at least) the rounding bit. diff --git a/src/math/big/int.go b/src/math/big/int.go index 0b710c6968..cb7221250d 100644 --- a/src/math/big/int.go +++ b/src/math/big/int.go @@ -181,16 +181,20 @@ func (z *Int) Sub(x, y *Int) *Int { // Mul sets z to the product x*y and returns z. func (z *Int) Mul(x, y *Int) *Int { + return z.mul(nil, x, y) +} + +func (z *Int) mul(stk *stack, x, y *Int) *Int { // x * y == x * y // x * (-y) == -(x * y) // (-x) * y == -(x * y) // (-x) * (-y) == x * y if x == y { - z.abs = z.abs.sqr(x.abs) + z.abs = z.abs.sqr(stk, x.abs) z.neg = false return z } - z.abs = z.abs.mul(x.abs, y.abs) + z.abs = z.abs.mul(stk, x.abs, y.abs) z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign return z } @@ -213,7 +217,7 @@ func (z *Int) MulRange(a, b int64) *Int { a, b = -b, -a } - z.abs = z.abs.mulRange(uint64(a), uint64(b)) + z.abs = z.abs.mulRange(nil, uint64(a), uint64(b)) z.neg = neg return z } @@ -264,7 +268,7 @@ func (z *Int) Binomial(n, k int64) *Int { // If y == 0, a division-by-zero run-time panic occurs. // Quo implements truncated division (like Go); see [Int.QuoRem] for more details. func (z *Int) Quo(x, y *Int) *Int { - z.abs, _ = z.abs.div(nil, x.abs, y.abs) + z.abs, _ = z.abs.div(nil, nil, x.abs, y.abs) z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign return z } @@ -273,7 +277,7 @@ func (z *Int) Quo(x, y *Int) *Int { // If y == 0, a division-by-zero run-time panic occurs. // Rem implements truncated modulus (like Go); see [Int.QuoRem] for more details. func (z *Int) Rem(x, y *Int) *Int { - _, z.abs = nat(nil).div(z.abs, x.abs, y.abs) + _, z.abs = nat(nil).div(nil, z.abs, x.abs, y.abs) z.neg = len(z.abs) > 0 && x.neg // 0 has no sign return z } @@ -290,7 +294,7 @@ func (z *Int) Rem(x, y *Int) *Int { // (See Daan Leijen, “Division and Modulus for Computer Scientists”.) // See [Int.DivMod] for Euclidean division and modulus (unlike Go). func (z *Int) QuoRem(x, y, r *Int) (*Int, *Int) { - z.abs, r.abs = z.abs.div(r.abs, x.abs, y.abs) + z.abs, r.abs = z.abs.div(nil, r.abs, x.abs, y.abs) z.neg, r.neg = len(z.abs) > 0 && x.neg != y.neg, len(r.abs) > 0 && x.neg // 0 has no sign return z, r } @@ -589,7 +593,7 @@ func (z *Int) exp(x, y, m *Int, slow bool) *Int { mWords = m.abs // m.abs may be nil for m == 0 } - z.abs = z.abs.expNN(xWords, yWords, mWords, slow) + z.abs = z.abs.expNN(nil, xWords, yWords, mWords, slow) z.neg = len(z.abs) > 0 && x.neg && len(yWords) > 0 && yWords[0]&1 == 1 // 0 has no sign if z.neg && len(mWords) > 0 { // make modulus result positive @@ -1298,6 +1302,6 @@ func (z *Int) Sqrt(x *Int) *Int { panic("square root of negative number") } z.neg = false - z.abs = z.abs.sqrt(x.abs) + z.abs = z.abs.sqrt(nil, x.abs) return z } diff --git a/src/math/big/nat.go b/src/math/big/nat.go index 541da229d6..ec75c8f6fd 100644 --- a/src/math/big/nat.go +++ b/src/math/big/nat.go @@ -17,6 +17,7 @@ import ( "internal/byteorder" "math/bits" "math/rand" + "slices" "sync" ) @@ -262,9 +263,9 @@ var karatsubaThreshold = 40 // computed by calibrate_test.go // karatsuba multiplies x and y and leaves the result in z. // Both x and y must have the same length n and n must be a -// power of 2. The result vector z must have len(z) >= 6*n. -// The (non-normalized) result is placed in z[0 : 2*n]. -func karatsuba(z, x, y nat) { +// power of 2. The result vector z must have len(z) == len(x)+len(y). +// The (non-normalized) result is placed in z. +func karatsuba(stk *stack, z, x, y nat) { n := len(y) // Switch to basic multiplication if numbers are odd or small. @@ -304,29 +305,19 @@ func karatsuba(z, x, y nat) { x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0 y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0 - // z is used for the result and temporary storage: - // - // 6*n 5*n 4*n 3*n 2*n 1*n 0*n - // z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ] - // - // For each recursive call of karatsuba, an unused slice of - // z is passed in that has (at least) half the length of the - // caller's z. - // compute z0 and z2 with the result "in place" in z - karatsuba(z, x0, y0) // z0 = x0*y0 - karatsuba(z[n:], x1, y1) // z2 = x1*y1 + karatsuba(stk, z, x0, y0) // z0 = x0*y0 + karatsuba(stk, z[n:], x1, y1) // z2 = x1*y1 - // compute xd (or the negative value if underflow occurs) + // compute xd, yd (or the negative value if underflow occurs) s := 1 // sign of product xd*yd - xd := z[2*n : 2*n+n2] + defer stk.restore(stk.save()) + xd := stk.nat(n2) + yd := stk.nat(n2) if subVV(xd, x1, x0) != 0 { // x1-x0 s = -s subVV(xd, x0, x1) // x0-x1 } - - // compute yd (or the negative value if underflow occurs) - yd := z[2*n+n2 : 3*n] if subVV(yd, y0, y1) != 0 { // y0-y1 s = -s subVV(yd, y1, y0) // y1-y0 @@ -334,12 +325,12 @@ func karatsuba(z, x, y nat) { // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0 // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0 - p := z[n*3:] - karatsuba(p, xd, yd) + p := stk.nat(2 * n2) + karatsuba(stk, p, xd, yd) // save original z2:z0 // (ok to use upper half of z since we're done recurring) - r := z[n*4:] + r := stk.nat(n * 2) copy(r, z[:n*2]) // add up all partial products @@ -396,13 +387,15 @@ func karatsubaLen(n, threshold int) int { return n << i } -func (z nat) mul(x, y nat) nat { +// mul sets z = x*y, using stk for temporary storage. +// The caller may pass stk == nil to request that mul obtain and release one itself. +func (z nat) mul(stk *stack, x, y nat) nat { m := len(x) n := len(y) switch { case m < n: - return z.mul(y, x) + return z.mul(stk, y, x) case m == 0 || n == 0: return z[:0] case n == 1: @@ -432,12 +425,16 @@ func (z nat) mul(x, y nat) nat { k := karatsubaLen(n, karatsubaThreshold) // k <= n + if stk == nil { + stk = getStack() + defer stk.free() + } + // multiply x0 and y0 via Karatsuba - x0 := x[0:k] // x0 is not normalized - y0 := y[0:k] // y0 is not normalized - z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y - karatsuba(z, x0, y0) - z = z[0 : m+n] // z has final length but may be incomplete + x0 := x[0:k] // x0 is not normalized + y0 := y[0:k] // y0 is not normalized + z = z.make(m + n) // enough space for full result of x*y + karatsuba(stk, z, x0, y0) clear(z[2*k:]) // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m) // If xh != 0 or yh != 0, add the missing terms to z. For @@ -454,13 +451,13 @@ func (z nat) mul(x, y nat) nat { // be a larger valid threshold contradicting the assumption about k. // if k < n || m != n { - tp := getNat(3 * k) - t := *tp + defer stk.restore(stk.save()) + t := stk.nat(3 * k) // add x0*y1*b x0 := x0.norm() - y1 := y[k:] // y1 is normalized because y is - t = t.mul(x0, y1) // update t so we don't lose t's underlying array + y1 := y[k:] // y1 is normalized because y is + t = t.mul(stk, x0, y1) // update t so we don't lose t's underlying array addAt(z, t, k) // add xi*y0< 0, len(z) == 2*len(x) // The (non-normalized) result is placed in z. -func basicSqr(z, x nat) { +func basicSqr(stk *stack, z, x nat) { n := len(x) - tp := getNat(2 * n) - t := *tp // temporary variable to hold the products + defer stk.restore(stk.save()) + t := stk.nat(2 * n) clear(t) z[1], z[0] = mulWW(x[0], x[0]) // the initial square for i := 1; i < n; i++ { @@ -502,38 +497,37 @@ func basicSqr(z, x nat) { } t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products addVV(z, z, t) // combine the result - putNat(tp) } // karatsubaSqr squares x and leaves the result in z. -// len(x) must be a power of 2 and len(z) >= 6*len(x). -// The (non-normalized) result is placed in z[0 : 2*len(x)]. +// len(x) must be a power of 2 and len(z) == 2*len(x). +// The (non-normalized) result is placed in z. // // The algorithm and the layout of z are the same as for karatsuba. -func karatsubaSqr(z, x nat) { +func karatsubaSqr(stk *stack, z, x nat) { n := len(x) if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 { - basicSqr(z[:2*n], x) + basicSqr(stk, z[:2*n], x) return } n2 := n >> 1 x1, x0 := x[n2:], x[0:n2] - karatsubaSqr(z, x0) - karatsubaSqr(z[n:], x1) + karatsubaSqr(stk, z, x0) + karatsubaSqr(stk, z[n:], x1) // s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0 - xd := z[2*n : 2*n+n2] + defer stk.restore(stk.save()) + p := stk.nat(2 * n2) + r := stk.nat(n * 2) + xd := r[:n2] if subVV(xd, x1, x0) != 0 { subVV(xd, x0, x1) } - p := z[n*3:] - karatsubaSqr(p, xd) - - r := z[n*4:] + karatsubaSqr(stk, p, xd) copy(r, z[:n*2]) karatsubaAdd(z[n2:], r, n) @@ -547,8 +541,9 @@ func karatsubaSqr(z, x nat) { var basicSqrThreshold = 20 // computed by calibrate_test.go var karatsubaSqrThreshold = 260 // computed by calibrate_test.go -// z = x*x -func (z nat) sqr(x nat) nat { +// sqr sets z = x*x, using stk for temporary storage. +// The caller may pass stk == nil to request that sqr obtain and release one itself. +func (z nat) sqr(stk *stack, x nat) nat { n := len(x) switch { case n == 0: @@ -563,15 +558,20 @@ func (z nat) sqr(x nat) nat { if alias(z, x) { z = nil // z is an alias for x - cannot reuse } + z = z.make(2 * n) if n < basicSqrThreshold { - z = z.make(2 * n) basicMul(z, x, x) return z.norm() } + + if stk == nil { + stk = getStack() + defer stk.free() + } + if n < karatsubaSqrThreshold { - z = z.make(2 * n) - basicSqr(z, x) + basicSqr(stk, z, x) return z.norm() } @@ -583,22 +583,18 @@ func (z nat) sqr(x nat) nat { k := karatsubaLen(n, karatsubaSqrThreshold) x0 := x[0:k] - z = z.make(max(6*k, 2*n)) - karatsubaSqr(z, x0) // z = x0^2 - z = z[0 : 2*n] + karatsubaSqr(stk, z, x0) // z = x0^2 clear(z[2*k:]) if k < n { - tp := getNat(2 * k) - t := *tp + t := stk.nat(2 * k) x0 := x0.norm() x1 := x[k:] - t = t.mul(x0, x1) + t = t.mul(stk, x0, x1) addAt(z, t, k) addAt(z, t, k) // z = 2*x1*x0*b + x0^2 - t = t.sqr(x1) + t = t.sqr(stk, x1) addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2 - putNat(tp) } return z.norm() @@ -606,7 +602,8 @@ func (z nat) sqr(x nat) nat { // mulRange computes the product of all the unsigned integers in the // range [a, b] inclusively. If a > b (empty range), the result is 1. -func (z nat) mulRange(a, b uint64) nat { +// The caller may pass stk == nil to request that mulRange obtain and release one itself. +func (z nat) mulRange(stk *stack, a, b uint64) nat { switch { case a == 0: // cut long ranges short (optimization) @@ -616,34 +613,79 @@ func (z nat) mulRange(a, b uint64) nat { case a == b: return z.setUint64(a) case a+1 == b: - return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b)) + return z.mul(stk, nat(nil).setUint64(a), nat(nil).setUint64(b)) + } + + if stk == nil { + stk = getStack() + defer stk.free() } + m := a + (b-a)/2 // avoid overflow - return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b)) + return z.mul(stk, nat(nil).mulRange(stk, a, m), nat(nil).mulRange(stk, m+1, b)) } -// getNat returns a *nat of len n. The contents may not be zero. -// The pool holds *nat to avoid allocation when converting to interface{}. -func getNat(n int) *nat { - var z *nat - if v := natPool.Get(); v != nil { - z = v.(*nat) - } - if z == nil { - z = new(nat) - } - *z = z.make(n) - if n > 0 { - (*z)[0] = 0xfedcb // break code expecting zero +// A stack provides temporary storage for complex calculations +// such as multiplication and division. +// The stack is a simple slice of words, extended as needed +// to hold all the temporary storage for a calculation. +// In general, if a function takes a *stack, it expects a non-nil *stack. +// However, certain functions may allow passing a nil *stack instead, +// so that they can handle trivial stack-free cases without forcing the +// caller to obtain and free a stack that will be unused. These functions +// document that they accept a nil *stack in their doc comments. +type stack struct { + w []Word +} + +var stackPool sync.Pool + +// getStack returns a temporary stack. +// The caller must call [stack.free] to give up use of the stack when finished. +func getStack() *stack { + s, _ := stackPool.Get().(*stack) + if s == nil { + s = new(stack) } - return z + return s +} + +// free returns the stack for use by another calculation. +func (s *stack) free() { + s.w = s.w[:0] + stackPool.Put(s) } -func putNat(x *nat) { - natPool.Put(x) +// save returns the current stack pointer. +// A future call to restore with the same value +// frees any temporaries allocated on the stack after the call to save. +func (s *stack) save() int { + return len(s.w) } -var natPool sync.Pool +// restore restores the stack pointer to n. +// It is almost always invoked as +// +// defer stk.restore(stk.save()) +// +// which makes sure to pop any temporaries allocated in the current function +// from the stack before returning. +func (s *stack) restore(n int) { + s.w = s.w[:n] +} + +// nat returns a nat of n words, allocated on the stack. +func (s *stack) nat(n int) nat { + nr := (n + 3) &^ 3 // round up to multiple of 4 + off := len(s.w) + s.w = slices.Grow(s.w, nr) + s.w = s.w[:off+nr] + x := s.w[off : off+n : off+n] + if n > 0 { + x[0] = 0xfedcb + } + return x +} // bitLen returns the length of x in bits. // Unlike most methods, it works even if x is not normalized. @@ -930,7 +972,8 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat { // If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m; // otherwise it sets z to x**y. The result is the value of z. -func (z nat) expNN(x, y, m nat, slow bool) nat { +// The caller may pass stk == nil to request that expNN obtain and release one itself. +func (z nat) expNN(stk *stack, x, y, m nat, slow bool) nat { if alias(z, x) || alias(z, y) { // We cannot allow in-place modification of x or y. z = nil @@ -961,12 +1004,17 @@ func (z nat) expNN(x, y, m nat, slow bool) nat { // x > 1 // x**1 == x - if len(y) == 1 && y[0] == 1 { - if len(m) != 0 { - return z.rem(x, m) - } + if len(y) == 1 && y[0] == 1 && len(m) == 0 { return z.set(x) } + if stk == nil { + stk = getStack() + defer stk.free() + } + if len(y) == 1 && y[0] == 1 { // len(m) > 0 + return z.rem(stk, x, m) + } + // y > 1 if len(m) != 0 { @@ -980,12 +1028,12 @@ func (z nat) expNN(x, y, m nat, slow bool) nat { // instance of each of the first two cases). if len(y) > 1 && !slow { if m[0]&1 == 1 { - return z.expNNMontgomery(x, y, m) + return z.expNNMontgomery(stk, x, y, m) } if logM, ok := m.isPow2(); ok { - return z.expNNWindowed(x, y, logM) + return z.expNNWindowed(stk, x, y, logM) } - return z.expNNMontgomeryEven(x, y, m) + return z.expNNMontgomeryEven(stk, x, y, m) } } @@ -1006,16 +1054,16 @@ func (z nat) expNN(x, y, m nat, slow bool) nat { // otherwise the arguments would alias. var zz, r nat for j := 0; j < w; j++ { - zz = zz.sqr(z) + zz = zz.sqr(stk, z) zz, z = z, zz if v&mask != 0 { - zz = zz.mul(z, x) + zz = zz.mul(stk, z, x) zz, z = z, zz } if len(m) != 0 { - zz, r = zz.div(r, z, m) + zz, r = zz.div(stk, r, z, m) zz, r, q, z = q, z, zz, r } @@ -1026,16 +1074,16 @@ func (z nat) expNN(x, y, m nat, slow bool) nat { v = y[i] for j := 0; j < _W; j++ { - zz = zz.sqr(z) + zz = zz.sqr(stk, z) zz, z = z, zz if v&mask != 0 { - zz = zz.mul(z, x) + zz = zz.mul(stk, z, x) zz, z = z, zz } if len(m) != 0 { - zz, r = zz.div(r, z, m) + zz, r = zz.div(stk, r, z, m) zz, r, q, z = q, z, zz, r } @@ -1054,7 +1102,7 @@ func (z nat) expNN(x, y, m nat, slow bool) nat { // For more details, see Ç. K. Koç, “Montgomery Reduction with Even Modulus”, // IEE Proceedings: Computers and Digital Techniques, 141(5) 314-316, September 1994. // http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf -func (z nat) expNNMontgomeryEven(x, y, m nat) nat { +func (z nat) expNNMontgomeryEven(stk *stack, x, y, m nat) nat { // Split m = m₁ × m₂ where m₁ = 2ⁿ n := m.trailingZeroBits() m1 := nat(nil).shl(natOne, n) @@ -1066,8 +1114,8 @@ func (z nat) expNNMontgomeryEven(x, y, m nat) nat { // (We are using the math/big convention for names here, // where the computation is z = x**y mod m, so its parts are z1 and z2. // The paper is computing x = a**e mod n; it refers to these as x2 and z1.) - z1 := nat(nil).expNN(x, y, m1, false) - z2 := nat(nil).expNN(x, y, m2, false) + z1 := nat(nil).expNN(stk, x, y, m1, false) + z2 := nat(nil).expNN(stk, x, y, m2, false) // Reconstruct z from z₁, z₂ using CRT, using algorithm from paper, // which uses only a single modInverse (and an easy one at that). @@ -1086,18 +1134,18 @@ func (z nat) expNNMontgomeryEven(x, y, m nat) nat { // Reuse z2 for p = (z₁ - z₂) [in z1] * m2⁻¹ (mod m₁ [= 2ⁿ]). m2inv := nat(nil).modInverse(m2, m1) - z2 = z2.mul(z1, m2inv) + z2 = z2.mul(stk, z1, m2inv) z2 = z2.trunc(z2, n) // Reuse z1 for p * m2. - z = z.add(z, z1.mul(z2, m2)) + z = z.add(z, z1.mul(stk, z2, m2)) return z } // expNNWindowed calculates x**y mod m using a fixed, 4-bit window, // where m = 2**logM. -func (z nat) expNNWindowed(x, y nat, logM uint) nat { +func (z nat) expNNWindowed(stk *stack, x, y nat, logM uint) nat { if len(y) <= 1 { panic("big: misuse of expNNWindowed") } @@ -1112,23 +1160,23 @@ func (z nat) expNNWindowed(x, y nat, logM uint) nat { // zz is used to avoid allocating in mul as otherwise // the arguments would alias. + defer stk.restore(stk.save()) w := int((logM + _W - 1) / _W) - zzp := getNat(w) - zz := *zzp + zz := stk.nat(w) const n = 4 // powers[i] contains x^i. - var powers [1 << n]*nat + var powers [1 << n]nat for i := range powers { - powers[i] = getNat(w) + powers[i] = stk.nat(w) } - *powers[0] = powers[0].set(natOne) - *powers[1] = powers[1].trunc(x, logM) + powers[0] = powers[0].set(natOne) + powers[1] = powers[1].trunc(x, logM) for i := 2; i < 1<>(_W-n)]) + zz = zz.mul(stk, z, powers[yi>>(_W-n)]) zz, z = z, zz z = z.trunc(z, logM) @@ -1185,24 +1233,18 @@ func (z nat) expNNWindowed(x, y nat, logM uint) nat { } } - *zzp = zz - putNat(zzp) - for i := range powers { - putNat(powers[i]) - } - return z.norm() } // expNNMontgomery calculates x**y mod m using a fixed, 4-bit window. // Uses Montgomery representation. -func (z nat) expNNMontgomery(x, y, m nat) nat { +func (z nat) expNNMontgomery(stk *stack, x, y, m nat) nat { numWords := len(m) // We want the lengths of x and m to be equal. // It is OK if x >= m as long as len(x) == len(m). if len(x) > numWords { - _, x = nat(nil).div(nil, x, m) + _, x = nat(nil).div(stk, nil, x, m) // Note: now len(x) <= numWords, not guaranteed ==. } if len(x) < numWords { @@ -1225,7 +1267,7 @@ func (z nat) expNNMontgomery(x, y, m nat) nat { // RR = 2**(2*_W*len(m)) mod m RR := nat(nil).setWord(1) zz := nat(nil).shl(RR, uint(2*numWords*_W)) - _, RR = nat(nil).div(RR, zz, m) + _, RR = nat(nil).div(stk, RR, zz, m) if len(RR) < numWords { zz = zz.make(numWords) copy(zz, RR) @@ -1280,7 +1322,7 @@ func (z nat) expNNMontgomery(x, y, m nat) nat { // The div is not expected to be reached. zz = zz.sub(zz, m) if zz.cmp(m) >= 0 { - _, zz = nat(nil).div(nil, zz, m) + _, zz = nat(nil).div(stk, nil, zz, m) } } @@ -1349,7 +1391,8 @@ func (z nat) setBytes(buf []byte) nat { } // sqrt sets z = ⌊√x⌋ -func (z nat) sqrt(x nat) nat { +// The caller may pass stk == nil to request that sqrt obtain and release one itself. +func (z nat) sqrt(stk *stack, x nat) nat { if x.cmp(natOne) <= 0 { return z.set(x) } @@ -1357,6 +1400,11 @@ func (z nat) sqrt(x nat) nat { z = nil } + if stk == nil { + stk = getStack() + defer stk.free() + } + // Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller. // See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt). // https://members.loria.fr/PZimmermann/mca/pub226.html @@ -1367,7 +1415,7 @@ func (z nat) sqrt(x nat) nat { z1 = z1.setUint64(1) z1 = z1.shl(z1, uint(x.bitLen()+1)/2) // must be ≥ √x for n := 0; ; n++ { - z2, _ = z2.div(nil, x, z1) + z2, _ = z2.div(stk, nil, x, z1) z2 = z2.add(z2, z1) z2 = z2.shr(z2, 1) if z2.cmp(z1) >= 0 { diff --git a/src/math/big/nat_test.go b/src/math/big/nat_test.go index 46231f7976..1811dccfe3 100644 --- a/src/math/big/nat_test.go +++ b/src/math/big/nat_test.go @@ -42,6 +42,7 @@ func TestCmp(t *testing.T) { } type funNN func(z, x, y nat) nat +type funSNN func(z nat, stk *stack, x, y nat) nat type argNN struct { z, x, y nat } @@ -112,6 +113,15 @@ func testFunNN(t *testing.T, msg string, f funNN, a argNN) { } } +func testFunSNN(t *testing.T, msg string, f funSNN, a argNN) { + stk := getStack() + defer stk.free() + z := f(nil, stk, a.x, a.y) + if z.cmp(a.z) != 0 { + t.Errorf("%s%+v\n\tgot z = %v; want %v", msg, a, z, a.z) + } +} + func TestFunNN(t *testing.T) { for _, a := range sumNN { arg := a @@ -129,10 +139,10 @@ func TestFunNN(t *testing.T) { for _, a := range prodNN { arg := a - testFunNN(t, "mul", nat.mul, arg) + testFunSNN(t, "mul", nat.mul, arg) arg = argNN{a.z, a.y, a.x} - testFunNN(t, "mul symmetric", nat.mul, arg) + testFunSNN(t, "mul symmetric", nat.mul, arg) } } @@ -163,8 +173,11 @@ var mulRangesN = []struct { } func TestMulRangeN(t *testing.T) { + stk := getStack() + defer stk.free() + for i, r := range mulRangesN { - prod := string(nat(nil).mulRange(r.a, r.b).utoa(10)) + prod := string(nat(nil).mulRange(stk, r.a, r.b).utoa(10)) if prod != r.prod { t.Errorf("#%d: got %s; want %s", i, prod, r.prod) } @@ -185,11 +198,14 @@ func allocBytes(f func()) uint64 { // does not cause deep recursion and in turn allocate too much memory. // Test case for issue 3807. func TestMulUnbalanced(t *testing.T) { + stk := getStack() + defer stk.free() + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) x := rndNat(50000) y := rndNat(40) allocSize := allocBytes(func() { - nat(nil).mul(x, y) + nat(nil).mul(stk, x, y) }) inputSize := uint64(len(x)+len(y)) * _S if ratio := allocSize / uint64(inputSize); ratio > 10 { @@ -214,12 +230,15 @@ func rndNat1(n int) nat { } func BenchmarkMul(b *testing.B) { + stk := getStack() + defer stk.free() + mulx := rndNat(1e4) muly := rndNat(1e4) b.ResetTimer() for i := 0; i < b.N; i++ { var z nat - z.mul(mulx, muly) + z.mul(stk, mulx, muly) } } @@ -230,7 +249,7 @@ func benchmarkNatMul(b *testing.B, nwords int) { b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { - z.mul(x, y) + z.mul(nil, x, y) } } @@ -444,6 +463,9 @@ var montgomeryTests = []struct { } func TestMontgomery(t *testing.T) { + stk := getStack() + defer stk.free() + one := NewInt(1) _B := new(Int).Lsh(one, _W) for i, test := range montgomeryTests { @@ -458,11 +480,11 @@ func TestMontgomery(t *testing.T) { } if x.cmp(m) > 0 { - _, r := nat(nil).div(nil, x, m) + _, r := nat(nil).div(stk, nil, x, m) t.Errorf("#%d: x > m (0x%s > 0x%s; use 0x%s)", i, x.utoa(16), m.utoa(16), r.utoa(16)) } if y.cmp(m) > 0 { - _, r := nat(nil).div(nil, x, m) + _, r := nat(nil).div(stk, nil, x, m) t.Errorf("#%d: y > m (0x%s > 0x%s; use 0x%s)", i, y.utoa(16), m.utoa(16), r.utoa(16)) } @@ -538,6 +560,9 @@ var expNNTests = []struct { } func TestExpNN(t *testing.T) { + stk := getStack() + defer stk.free() + for i, test := range expNNTests { x := natFromString(test.x) y := natFromString(test.y) @@ -548,7 +573,7 @@ func TestExpNN(t *testing.T) { m = natFromString(test.m) } - z := nat(nil).expNN(x, y, m, false) + z := nat(nil).expNN(stk, x, y, m, false) if z.cmp(out) != 0 { t.Errorf("#%d got %s want %s", i, z.utoa(10), out.utoa(10)) } @@ -572,6 +597,9 @@ func FuzzExpMont(f *testing.F) { } func BenchmarkExp3Power(b *testing.B) { + stk := getStack() + defer stk.free() + const x = 3 for _, y := range []Word{ 0x10, 0x40, 0x100, 0x400, 0x1000, 0x4000, 0x10000, 0x40000, 0x100000, 0x400000, @@ -579,7 +607,7 @@ func BenchmarkExp3Power(b *testing.B) { b.Run(fmt.Sprintf("%#x", y), func(b *testing.B) { var z nat for i := 0; i < b.N; i++ { - z.expWW(x, y) + z.expWW(stk, x, y) } }) } @@ -712,10 +740,13 @@ func TestSticky(t *testing.T) { } func testSqr(t *testing.T, x nat) { + stk := getStack() + defer stk.free() + got := make(nat, 2*len(x)) want := make(nat, 2*len(x)) - got = got.sqr(x) - want = want.mul(x, x) + got = got.sqr(stk, x) + want = want.mul(stk, x, x) if got.cmp(want) != 0 { t.Errorf("basicSqr(%v), got %v, want %v", x, got, want) } @@ -741,7 +772,7 @@ func benchmarkNatSqr(b *testing.B, nwords int) { b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { - z.sqr(x) + z.sqr(nil, x) } } @@ -830,6 +861,9 @@ func BenchmarkNatSetBytes(b *testing.B) { } func TestNatDiv(t *testing.T) { + stk := getStack() + defer stk.free() + sizes := []int{ 1, 2, 5, 8, 15, 25, 40, 65, 100, 200, 500, 800, 1500, 2500, 4000, 6500, 10000, @@ -849,11 +883,11 @@ func TestNatDiv(t *testing.T) { c = c.norm() } // compute x = a*b+c - x := nat(nil).mul(a, b) + x := nat(nil).mul(stk, a, b) x = x.add(x, c) var q, r nat - q, r = q.div(r, x, b) + q, r = q.div(stk, r, x, b) if q.cmp(a) != 0 { t.Fatalf("wrong quotient: got %s; want %s for %s/%s", q.utoa(10), a.utoa(10), x.utoa(10), b.utoa(10)) } @@ -868,6 +902,9 @@ func TestNatDiv(t *testing.T) { // the inaccurate estimate of the first word's quotient // happens at the very beginning of the loop. func TestIssue37499(t *testing.T) { + stk := getStack() + defer stk.free() + // Choose u and v such that v is slightly larger than u >> N. // This tricks divBasic into choosing 1 as the first word // of the quotient. This works in both 32-bit and 64-bit settings. @@ -875,7 +912,7 @@ func TestIssue37499(t *testing.T) { v := natFromString("0x2b6c385a05be027f5c22005b63c42a1165b79ff510e1706c") q := nat(nil).make(8) - q.divBasic(u, v) + q.divBasic(stk, u, v) q = q.norm() if s := string(q.utoa(16)); s != "fffffffffffffffffffffffffffffffffffffffffffffffb" { t.Fatalf("incorrect quotient: %s", s) @@ -886,8 +923,11 @@ func TestIssue37499(t *testing.T) { // where the first division loop is never entered, and correcting // the remainder takes exactly two iterations in the final loop. func TestIssue42552(t *testing.T) { + stk := getStack() + defer stk.free() + u := natFromString("0xc23b166884c3869092a520eceedeced2b00847bd256c9cf3b2c5e2227c15bd5e6ee7ef8a2f49236ad0eedf2c8a3b453cf6e0706f64285c526b372c4b1321245519d430540804a50b7ca8b6f1b34a2ec05cdbc24de7599af112d3e3c8db347e8799fe70f16e43c6566ba3aeb169463a3ecc486172deb2d9b80a3699c776e44fef20036bd946f1b4d054dd88a2c1aeb986199b0b2b7e58c42288824b74934d112fe1fc06e06b4d99fe1c5e725946b23210521e209cd507cce90b5f39a523f27e861f9e232aee50c3f585208b4573dcc0b897b6177f2ba20254fd5c50a033e849dee1b3a93bd2dc44ba8ca836cab2c2ae50e50b126284524fa0187af28628ff0face68d87709200329db1392852c8b8963fbe3d05fb1efe19f0ed5ca9fadc2f96f82187c24bb2512b2e85a66333a7e176605695211e1c8e0b9b9e82813e50654964945b1e1e66a90840396c7d10e23e47f364d2d3f660fa54598e18d1ca2ea4fe4f35a40a11f69f201c80b48eaee3e2e9b0eda63decf92bec08a70f731587d4ed0f218d5929285c8b2ccbc497e20db42de73885191fa453350335990184d8df805072f958d5354debda38f5421effaaafd6cb9b721ace74be0892d77679f62a4a126697cd35797f6858193da4ba1770c06aea2e5c59ec04b8ea26749e61b72ecdde403f3bc7e5e546cd799578cc939fa676dfd5e648576d4a06cbadb028adc2c0b461f145b2321f42e5e0f3b4fb898ecd461df07a6f5154067787bf74b5cc5c03704a1ce47494961931f0263b0aac32505102595957531a2de69dd71aac51f8a49902f81f21283dbe8e21e01e5d82517868826f86acf338d935aa6b4d5a25c8d540389b277dd9d64569d68baf0f71bd03dba45b92a7fc052601d1bd011a2fc6790a23f97c6fa5caeea040ab86841f268d39ce4f7caf01069df78bba098e04366492f0c2ac24f1bf16828752765fa523c9a4d42b71109d123e6be8c7b1ab3ccf8ea03404075fe1a9596f1bba1d267f9a7879ceece514818316c9c0583469d2367831fc42b517ea028a28df7c18d783d16ea2436cee2b15d52db68b5dfdee6b4d26f0905f9b030c911a04d078923a4136afea96eed6874462a482917353264cc9bee298f167ac65a6db4e4eda88044b39cc0b33183843eaa946564a00c3a0ab661f2c915e70bf0bb65bfbb6fa2eea20aed16bf2c1a1d00ec55fb4ff2f76b8e462ea70c19efa579c9ee78194b86708fdae66a9ce6e2cf3d366037798cfb50277ba6d2fd4866361022fd788ab7735b40b8b61d55e32243e06719e53992e9ac16c9c4b6e6933635c3c47c8f7e73e17dd54d0dd8aeba5d76de46894e7b3f9d3ec25ad78ee82297ba69905ea0fa094b8667faa2b8885e2187b3da80268aa1164761d7b0d6de206b676777348152b8ae1d4afed753bc63c739a5ca8ce7afb2b241a226bd9e502baba391b5b13f5054f070b65a9cf3a67063bfaa803ba390732cd03888f664023f888741d04d564e0b5674b0a183ace81452001b3fbb4214c77d42ca75376742c471e58f67307726d56a1032bd236610cbcbcd03d0d7a452900136897dc55bb3ce959d10d4e6a10fb635006bd8c41cd9ded2d3dfdd8f2e229590324a7370cb2124210b2330f4c56155caa09a2564932ceded8d92c79664dcdeb87faad7d3da006cc2ea267ee3df41e9677789cc5a8cc3b83add6491561b3047919e0648b1b2e97d7ad6f6c2aa80cab8e9ae10e1f75b1fdd0246151af709d259a6a0ed0b26bd711024965ecad7c41387de45443defce53f66612948694a6032279131c257119ed876a8e805dfb49576ef5c563574115ee87050d92d191bc761ef51d966918e2ef925639400069e3959d8fe19f36136e947ff430bf74e71da0aa5923b00000000") v := natFromString("0x838332321d443a3d30373d47301d47073847473a383d3030f25b3d3d3e00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002e00000000000000000041603038331c3d32f5303441e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e01c0a5459bfc7b9be9fcbb9d2383840464319434707303030f43a32f53034411c0a5459413820878787878787878787878787878787878787878787878787878787878787878787870630303a3a30334036605b923a6101f83638413943413960204337602043323801526040523241846038414143015238604060328452413841413638523c0240384141364036605b923a6101f83638413943413960204334602043323801526040523241846038414143015238604060328452413841413638523c02403841413638433030f25a8b83838383838383838383838383838383837d838383ffffffffffffffff838383838383838383000000000000000000030000007d26e27c7c8b83838383838383838383838383838383837d838383ffffffffffffffff83838383838383838383838383838383838383838383435960f535073030f3343200000000000000011881301938343030fa398383300000002300000000000000000000f11af4600c845252904141364138383c60406032414443095238010241414303364443434132305b595a15434160b042385341ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff47476043410536613603593a6005411c437405fcfcfcfcfcfcfc0000000000005a3b075815054359000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") q := nat(nil).make(16) - q.div(q, u, v) + q.div(stk, q, u, v) } diff --git a/src/math/big/natconv.go b/src/math/big/natconv.go index ce94f2cf72..8a47ec9f9c 100644 --- a/src/math/big/natconv.go +++ b/src/math/big/natconv.go @@ -321,17 +321,20 @@ func (x nat) itoa(neg bool, base int) []byte { } } else { + stk := getStack() + defer stk.free() + bb, ndigits := maxPow(b) // construct table of successive squares of bb*leafSize to use in subdivisions // result (table != nil) <=> (len(x) > leafSize > 0) - table := divisors(len(x), b, ndigits, bb) + table := divisors(stk, len(x), b, ndigits, bb) // preserve x, create local copy for use by convertWords q := nat(nil).set(x) // convert q to string s in base b - q.convertWords(s, b, ndigits, bb, table) + q.convertWords(stk, s, b, ndigits, bb, table) // strip leading zeros // (x != 0; thus s must contain at least one non-zero digit @@ -365,7 +368,7 @@ func (x nat) itoa(neg bool, base int) []byte { // range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and // ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for // specific hardware. -func (q nat) convertWords(s []byte, b Word, ndigits int, bb Word, table []divisor) { +func (q nat) convertWords(stk *stack, s []byte, b Word, ndigits int, bb Word, table []divisor) { // split larger blocks recursively if table != nil { // len(q) > leafSize > 0 @@ -386,12 +389,12 @@ func (q nat) convertWords(s []byte, b Word, ndigits int, bb Word, table []diviso } // split q into the two digit number (q'*bbb + r) to form independent subblocks - q, r = q.div(r, q, table[index].bbb) + q, r = q.div(stk, r, q, table[index].bbb) // convert subblocks and collect results in s[:h] and s[h:] h := len(s) - table[index].ndigits - r.convertWords(s[h:], b, ndigits, bb, table[0:index]) - s = s[:h] // == q.convertWords(s, b, ndigits, bb, table[0:index+1]) + r.convertWords(stk, s[h:], b, ndigits, bb, table[0:index]) + s = s[:h] // == q.convertWords(stk, s, b, ndigits, bb, table[0:index+1]) } } @@ -451,12 +454,12 @@ var cacheBase10 struct { } // expWW computes x**y -func (z nat) expWW(x, y Word) nat { - return z.expNN(nat(nil).setWord(x), nat(nil).setWord(y), nil, false) +func (z nat) expWW(stk *stack, x, y Word) nat { + return z.expNN(stk, nat(nil).setWord(x), nat(nil).setWord(y), nil, false) } // construct table of powers of bb*leafSize to use in subdivisions. -func divisors(m int, b Word, ndigits int, bb Word) []divisor { +func divisors(stk *stack, m int, b Word, ndigits int, bb Word) []divisor { // only compute table when recursive conversion is enabled and x is large if leafSize == 0 || m <= leafSize { return nil @@ -484,10 +487,10 @@ func divisors(m int, b Word, ndigits int, bb Word) []divisor { for i := 0; i < k; i++ { if table[i].ndigits == 0 { if i == 0 { - table[0].bbb = nat(nil).expWW(bb, Word(leafSize)) + table[0].bbb = nat(nil).expWW(stk, bb, Word(leafSize)) table[0].ndigits = ndigits * leafSize } else { - table[i].bbb = nat(nil).sqr(table[i-1].bbb) + table[i].bbb = nat(nil).sqr(stk, table[i-1].bbb) table[i].ndigits = 2 * table[i-1].ndigits } diff --git a/src/math/big/natconv_test.go b/src/math/big/natconv_test.go index d390272108..66300e412b 100644 --- a/src/math/big/natconv_test.go +++ b/src/math/big/natconv_test.go @@ -350,6 +350,9 @@ func BenchmarkStringPiParallel(b *testing.B) { } func BenchmarkScan(b *testing.B) { + stk := getStack() + defer stk.free() + const x = 10 for _, base := range []int{2, 8, 10, 16} { for _, y := range []Word{10, 100, 1000, 10000, 100000} { @@ -359,7 +362,7 @@ func BenchmarkScan(b *testing.B) { b.Run(fmt.Sprintf("%d/Base%d", y, base), func(b *testing.B) { b.StopTimer() var z nat - z = z.expWW(x, y) + z = z.expWW(stk, x, y) s := z.utoa(base) if t := itoa(z, base); !bytes.Equal(s, t) { @@ -376,6 +379,9 @@ func BenchmarkScan(b *testing.B) { } func BenchmarkString(b *testing.B) { + stk := getStack() + defer stk.free() + const x = 10 for _, base := range []int{2, 8, 10, 16} { for _, y := range []Word{10, 100, 1000, 10000, 100000} { @@ -385,7 +391,7 @@ func BenchmarkString(b *testing.B) { b.Run(fmt.Sprintf("%d/Base%d", y, base), func(b *testing.B) { b.StopTimer() var z nat - z = z.expWW(x, y) + z = z.expWW(stk, x, y) z.utoa(base) // warm divisor cache b.StartTimer() @@ -416,9 +422,11 @@ func LeafSizeHelper(b *testing.B, base, size int) { for d := 1; d <= 10000; d *= 10 { b.StopTimer() + stk := getStack() var z nat - z = z.expWW(Word(base), Word(d)) // build target number - _ = z.utoa(base) // warm divisor cache + z = z.expWW(stk, Word(base), Word(d)) // build target number + _ = z.utoa(base) // warm divisor cache + stk.free() b.StartTimer() for i := 0; i < b.N; i++ { @@ -443,13 +451,16 @@ func resetTable(table []divisor) { } func TestStringPowers(t *testing.T) { + stk := getStack() + defer stk.free() + var p Word for b := 2; b <= 16; b++ { for p = 0; p <= 512; p++ { if testing.Short() && p > 10 { break } - x := nat(nil).expWW(Word(b), p) + x := nat(nil).expWW(stk, Word(b), p) xs := x.utoa(b) xs2 := itoa(x, b) if !bytes.Equal(xs, xs2) { diff --git a/src/math/big/natdiv.go b/src/math/big/natdiv.go index 2e66e3425c..b514e2ce21 100644 --- a/src/math/big/natdiv.go +++ b/src/math/big/natdiv.go @@ -502,30 +502,24 @@ import "math/bits" // rem returns r such that r = u%v. // It uses z as the storage for r. -func (z nat) rem(u, v nat) (r nat) { +func (z nat) rem(stk *stack, u, v nat) (r nat) { if alias(z, u) { z = nil } - qp := getNat(0) - q, r := qp.div(z, u, v) - *qp = q - putNat(qp) + defer stk.restore(stk.save()) + q := stk.nat(len(u) - (len(v) - 1)) + _, r = q.div(stk, z, u, v) return r } // div returns q, r such that q = ⌊u/v⌋ and r = u%v = u - q·v. // It uses z and z2 as the storage for q and r. -func (z nat) div(z2, u, v nat) (q, r nat) { +// The caller may pass stk == nil to request that div obtain and release one itself. +func (z nat) div(stk *stack, z2, u, v nat) (q, r nat) { if len(v) == 0 { panic("division by zero") } - if u.cmp(v) < 0 { - q = z[:0] - r = z2.set(u) - return - } - if len(v) == 1 { // Short division: long optimized for a single-word divisor. // In that case, the 2-by-1 guess is all we need at each step. @@ -535,7 +529,18 @@ func (z nat) div(z2, u, v nat) (q, r nat) { return } - q, r = z.divLarge(z2, u, v) + if u.cmp(v) < 0 { + q = z[:0] + r = z2.set(u) + return + } + + if stk == nil { + stk = getStack() + defer stk.free() + } + + q, r = z.divLarge(stk, z2, u, v) return } @@ -589,7 +594,7 @@ func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) { // It uses z and u as the storage for q and r. // The caller must ensure that len(vIn) ≥ 2 (use divW otherwise) // and that len(uIn) ≥ len(vIn) (the answer is 0, uIn otherwise). -func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) { +func (z nat) divLarge(stk *stack, u, uIn, vIn nat) (q, r nat) { n := len(vIn) m := len(uIn) - n @@ -597,9 +602,9 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) { // vIn is treated as a read-only input (it may be in use by another // goroutine), so we must make a copy. // uIn is copied to u. + defer stk.restore(stk.save()) shift := nlz(vIn[n-1]) - vp := getNat(n) - v := *vp + v := stk.nat(n) shlVU(v, vIn, shift) u = u.make(len(uIn) + 1) u[len(uIn)] = shlVU(u[:len(uIn)], uIn, shift) @@ -613,11 +618,10 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) { // Use basic or recursive long division depending on size. if n < divRecursiveThreshold { - q.divBasic(u, v) + q.divBasic(stk, u, v) } else { - q.divRecursive(u, v) + q.divRecursive(stk, u, v) } - putNat(vp) q = q.norm() @@ -631,12 +635,12 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) { // divBasic implements long division as described above. // It overwrites q with ⌊u/v⌋ and overwrites u with the remainder r. // q must be large enough to hold ⌊u/v⌋. -func (q nat) divBasic(u, v nat) { +func (q nat) divBasic(stk *stack, u, v nat) { n := len(v) m := len(u) - n - qhatvp := getNat(n + 1) - qhatv := *qhatvp + defer stk.restore(stk.save()) + qhatv := stk.nat(n + 1) // Set up for divWW below, precomputing reciprocal argument. vn1 := v[n-1] @@ -707,8 +711,6 @@ func (q nat) divBasic(u, v nat) { } q[j] = qhat } - - putNat(qhatvp) } // greaterThan reports whether the two digit numbers x1 x2 > y1 y2. @@ -727,24 +729,9 @@ const divRecursiveThreshold = 100 // z must be large enough to hold ⌊u/v⌋. // This function is just for allocating and freeing temporaries // around divRecursiveStep, the real implementation. -func (z nat) divRecursive(u, v nat) { - // Recursion depth is (much) less than 2 log₂(len(v)). - // Allocate a slice of temporaries to be reused across recursion, - // plus one extra temporary not live across the recursion. - recDepth := 2 * bits.Len(uint(len(v))) - tmp := getNat(3 * len(v)) - temps := make([]*nat, recDepth) - +func (z nat) divRecursive(stk *stack, u, v nat) { clear(z) - z.divRecursiveStep(u, v, 0, tmp, temps) - - // Free temporaries. - for _, n := range temps { - if n != nil { - putNat(n) - } - } - putNat(tmp) + z.divRecursiveStep(stk, u, v, 0) } // divRecursiveStep is the actual implementation of recursive division. @@ -752,7 +739,7 @@ func (z nat) divRecursive(u, v nat) { // z must be large enough to hold ⌊u/v⌋. // It uses temps[depth] (allocating if needed) as a temporary live across // the recursive call. It also uses tmp, but not live across the recursion. -func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { +func (z nat) divRecursiveStep(stk *stack, u, v nat, depth int) { // u is a subsection of the original and may have leading zeros. // TODO(rsc): The v = v.norm() is useless and should be removed. // We know (and require) that v's top digit is ≥ B/2. @@ -766,7 +753,7 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { // Fall back to basic division if the problem is now small enough. n := len(v) if n < divRecursiveThreshold { - z.divBasic(u, v) + z.divBasic(stk, u, v) return } @@ -785,11 +772,8 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { B := n / 2 // Allocate a nat for qhat below. - if temps[depth] == nil { - temps[depth] = getNat(n) // TODO(rsc): Can be just B+1. - } else { - *temps[depth] = temps[depth].make(B + 1) - } + defer stk.restore(stk.save()) + qhat0 := stk.nat(B + 1) // Compute each wide digit of the quotient. // @@ -816,9 +800,9 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { uu := u[j-B:] // Compute the 2-by-1 guess q̂, leaving r̂ in uu[s:B+n]. - qhat := *temps[depth] + qhat := qhat0 clear(qhat) - qhat.divRecursiveStep(uu[s:B+n], v[s:], depth+1, tmp, temps) + qhat.divRecursiveStep(stk, uu[s:B+n], v[s:], depth+1) qhat = qhat.norm() // Extend to a 3-by-2 quotient and remainder. @@ -833,9 +817,10 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { // q̂·vₙ₋₂ and decrementing q̂ until that product is ≤ u. // But we can do the subtraction directly, as in the comment above // and in long division, because we know that q̂ is wrong by at most one. - qhatv := tmp.make(3 * n) + mark := stk.save() + qhatv := stk.nat(3 * n) clear(qhatv) - qhatv = qhatv.mul(qhat, v[:s]) + qhatv = qhatv.mul(stk, qhat, v[:s]) for i := 0; i < 2; i++ { e := qhatv.cmp(uu.norm()) if e <= 0 { @@ -857,6 +842,7 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { } addAt(z, qhat, j-B) j -= B + stk.restore(mark) } // TODO(rsc): Rewrite loop as described above and delete all this code. @@ -864,13 +850,13 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { // Now u < (v< 0 { diff --git a/src/math/big/prime.go b/src/math/big/prime.go index 26688bbd64..bba5a07685 100644 --- a/src/math/big/prime.go +++ b/src/math/big/prime.go @@ -75,7 +75,9 @@ func (x *Int) ProbablyPrime(n int) bool { return false } - return x.abs.probablyPrimeMillerRabin(n+1, true) && x.abs.probablyPrimeLucas() + stk := getStack() + defer stk.free() + return x.abs.probablyPrimeMillerRabin(stk, n+1, true) && x.abs.probablyPrimeLucas(stk) } // probablyPrimeMillerRabin reports whether n passes reps rounds of the @@ -83,7 +85,7 @@ func (x *Int) ProbablyPrime(n int) bool { // If force2 is true, one of the rounds is forced to use base 2. // See Handbook of Applied Cryptography, p. 139, Algorithm 4.24. // The number n is known to be non-zero. -func (n nat) probablyPrimeMillerRabin(reps int, force2 bool) bool { +func (n nat) probablyPrimeMillerRabin(stk *stack, reps int, force2 bool) bool { nm1 := nat(nil).sub(n, natOne) // determine q, k such that nm1 = q << k k := nm1.trailingZeroBits() @@ -103,13 +105,13 @@ NextRandom: x = x.random(rand, nm3, nm3Len) x = x.add(x, natTwo) } - y = y.expNN(x, q, n, false) + y = y.expNN(stk, x, q, n, false) if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 { continue } for j := uint(1); j < k; j++ { - y = y.sqr(y) - quotient, y = quotient.div(y, y, n) + y = y.sqr(stk, y) + quotient, y = quotient.div(stk, y, y, n) if y.cmp(nm1) == 0 { continue NextRandom } @@ -147,7 +149,7 @@ NextRandom: // // Crandall and Pomerance, Prime Numbers: A Computational Perspective, 2nd ed. // Springer, 2005. -func (n nat) probablyPrimeLucas() bool { +func (n nat) probablyPrimeLucas(stk *stack) bool { // Discard 0, 1. if len(n) == 0 || n.cmp(natOne) == 0 { return false @@ -193,8 +195,8 @@ func (n nat) probablyPrimeLucas() bool { // We'll never find (d/n) = -1 if n is a square. // If n is a non-square we expect to find a d in just a few attempts on average. // After 40 attempts, take a moment to check if n is indeed a square. - t1 = t1.sqrt(n) - t1 = t1.sqr(t1) + t1 = t1.sqrt(stk, n) + t1 = t1.sqr(stk, t1) if t1.cmp(n) == 0 { return false } @@ -254,25 +256,25 @@ func (n nat) probablyPrimeLucas() bool { if s.bit(uint(i)) != 0 { // k' = 2k+1 // V(k') = V(2k+1) = V(k) V(k+1) - P. - t1 = t1.mul(vk, vk1) + t1 = t1.mul(stk, vk, vk1) t1 = t1.add(t1, n) t1 = t1.sub(t1, natP) - t2, vk = t2.div(vk, t1, n) + t2, vk = t2.div(stk, vk, t1, n) // V(k'+1) = V(2k+2) = V(k+1)² - 2. - t1 = t1.sqr(vk1) + t1 = t1.sqr(stk, vk1) t1 = t1.add(t1, nm2) - t2, vk1 = t2.div(vk1, t1, n) + t2, vk1 = t2.div(stk, vk1, t1, n) } else { // k' = 2k // V(k'+1) = V(2k+1) = V(k) V(k+1) - P. - t1 = t1.mul(vk, vk1) + t1 = t1.mul(stk, vk, vk1) t1 = t1.add(t1, n) t1 = t1.sub(t1, natP) - t2, vk1 = t2.div(vk1, t1, n) + t2, vk1 = t2.div(stk, vk1, t1, n) // V(k') = V(2k) = V(k)² - 2 - t1 = t1.sqr(vk) + t1 = t1.sqr(stk, vk) t1 = t1.add(t1, nm2) - t2, vk = t2.div(vk, t1, n) + t2, vk = t2.div(stk, vk, t1, n) } } @@ -285,7 +287,7 @@ func (n nat) probablyPrimeLucas() bool { // // Since we are checking for U(k) == 0 it suffices to check 2 V(k+1) == P V(k) mod n, // or P V(k) - 2 V(k+1) == 0 mod n. - t1 := t1.mul(vk, natP) + t1 := t1.mul(stk, vk, natP) t2 := t2.shl(vk1, 1) if t1.cmp(t2) < 0 { t1, t2 = t2, t1 @@ -294,7 +296,7 @@ func (n nat) probablyPrimeLucas() bool { t3 := vk1 // steal vk1, no longer needed below vk1 = nil _ = vk1 - t2, t3 = t2.div(t3, t1, n) + t2, t3 = t2.div(stk, t3, t1, n) if len(t3) == 0 { return true } @@ -312,9 +314,9 @@ func (n nat) probablyPrimeLucas() bool { } // k' = 2k // V(k') = V(2k) = V(k)² - 2 - t1 = t1.sqr(vk) + t1 = t1.sqr(stk, vk) t1 = t1.sub(t1, natTwo) - t2, vk = t2.div(vk, t1, n) + t2, vk = t2.div(stk, vk, t1, n) } return false } diff --git a/src/math/big/prime_test.go b/src/math/big/prime_test.go index 8596e33a13..2b1995bcb2 100644 --- a/src/math/big/prime_test.go +++ b/src/math/big/prime_test.go @@ -159,6 +159,9 @@ func TestProbablyPrime(t *testing.T) { } func BenchmarkProbablyPrime(b *testing.B) { + stk := getStack() + defer stk.free() + p, _ := new(Int).SetString("203956878356401977405765866929034577280193993314348263094772646453283062722701277632936616063144088173312372882677123879538709400158306567338328279154499698366071906766440037074217117805690872792848149112022286332144876183376326512083574821647933992961249917319836219304274280243803104015000563790123", 10) for _, n := range []int{0, 1, 5, 10, 20} { b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { @@ -170,26 +173,32 @@ func BenchmarkProbablyPrime(b *testing.B) { b.Run("Lucas", func(b *testing.B) { for i := 0; i < b.N; i++ { - p.abs.probablyPrimeLucas() + p.abs.probablyPrimeLucas(stk) } }) b.Run("MillerRabinBase2", func(b *testing.B) { for i := 0; i < b.N; i++ { - p.abs.probablyPrimeMillerRabin(1, true) + p.abs.probablyPrimeMillerRabin(stk, 1, true) } }) } func TestMillerRabinPseudoprimes(t *testing.T) { + stk := getStack() + defer stk.free() + testPseudoprimes(t, "probablyPrimeMillerRabin", - func(n nat) bool { return n.probablyPrimeMillerRabin(1, true) && !n.probablyPrimeLucas() }, + func(n nat) bool { return n.probablyPrimeMillerRabin(stk, 1, true) && !n.probablyPrimeLucas(stk) }, // https://oeis.org/A001262 []int{2047, 3277, 4033, 4681, 8321, 15841, 29341, 42799, 49141, 52633, 65281, 74665, 80581, 85489, 88357, 90751}) } func TestLucasPseudoprimes(t *testing.T) { + stk := getStack() + defer stk.free() + testPseudoprimes(t, "probablyPrimeLucas", - func(n nat) bool { return n.probablyPrimeLucas() && !n.probablyPrimeMillerRabin(1, true) }, + func(n nat) bool { return n.probablyPrimeLucas(stk) && !n.probablyPrimeMillerRabin(stk, 1, true) }, // https://oeis.org/A217719 []int{989, 3239, 5777, 10877, 27971, 29681, 30739, 31631, 39059, 72389, 73919, 75077}) } diff --git a/src/math/big/rat.go b/src/math/big/rat.go index e58433ecea..ac94056a83 100644 --- a/src/math/big/rat.go +++ b/src/math/big/rat.go @@ -74,7 +74,7 @@ func (z *Rat) SetFloat64(f float64) *Rat { // nearest to the quotient a/b, using round-to-even in // halfway cases. It does not mutate its arguments. // Preconditions: b is non-zero; a and b have no common factors. -func quotToFloat32(a, b nat) (f float32, exact bool) { +func quotToFloat32(stk *stack, a, b nat) (f float32, exact bool) { const ( // float size in bits Fsize = 32 @@ -121,7 +121,7 @@ func quotToFloat32(a, b nat) (f float32, exact bool) { // extra shift, the low-order bit of q is logically the // high-order bit of r. var q nat - q, r := q.div(a2, a2, b2) // (recycle a2) + q, r := q.div(stk, a2, a2, b2) // (recycle a2) mantissa := low32(q) haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half @@ -172,7 +172,7 @@ func quotToFloat32(a, b nat) (f float32, exact bool) { // nearest to the quotient a/b, using round-to-even in // halfway cases. It does not mutate its arguments. // Preconditions: b is non-zero; a and b have no common factors. -func quotToFloat64(a, b nat) (f float64, exact bool) { +func quotToFloat64(stk *stack, a, b nat) (f float64, exact bool) { const ( // float size in bits Fsize = 64 @@ -219,7 +219,7 @@ func quotToFloat64(a, b nat) (f float64, exact bool) { // extra shift, the low-order bit of q is logically the // high-order bit of r. var q nat - q, r := q.div(a2, a2, b2) // (recycle a2) + q, r := q.div(stk, a2, a2, b2) // (recycle a2) mantissa := low64(q) haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half @@ -275,7 +275,9 @@ func (x *Rat) Float32() (f float32, exact bool) { if len(b) == 0 { b = natOne } - f, exact = quotToFloat32(x.a.abs, b) + stk := getStack() + defer stk.free() + f, exact = quotToFloat32(stk, x.a.abs, b) if x.a.neg { f = -f } @@ -291,7 +293,9 @@ func (x *Rat) Float64() (f float64, exact bool) { if len(b) == 0 { b = natOne } - f, exact = quotToFloat64(x.a.abs, b) + stk := getStack() + defer stk.free() + f, exact = quotToFloat64(stk, x.a.abs, b) if x.a.neg { f = -f } @@ -437,12 +441,14 @@ func (z *Rat) norm() *Rat { z.b.abs = z.b.abs.setWord(1) default: // z is fraction; normalize numerator and denominator + stk := getStack() + defer stk.free() neg := z.a.neg z.a.neg = false z.b.neg = false if f := NewInt(0).lehmerGCD(nil, nil, &z.a, &z.b); f.Cmp(intOne) != 0 { - z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f.abs) - z.b.abs, _ = z.b.abs.div(nil, z.b.abs, f.abs) + z.a.abs, _ = z.a.abs.div(stk, nil, z.a.abs, f.abs) + z.b.abs, _ = z.b.abs.div(stk, nil, z.b.abs, f.abs) } z.a.neg = neg } @@ -452,7 +458,7 @@ func (z *Rat) norm() *Rat { // mulDenom sets z to the denominator product x*y (by taking into // account that 0 values for x or y must be interpreted as 1) and // returns z. -func mulDenom(z, x, y nat) nat { +func mulDenom(stk *stack, z, x, y nat) nat { switch { case len(x) == 0 && len(y) == 0: return z.setWord(1) @@ -461,17 +467,17 @@ func mulDenom(z, x, y nat) nat { case len(y) == 0: return z.set(x) } - return z.mul(x, y) + return z.mul(stk, x, y) } // scaleDenom sets z to the product x*f. // If f == 0 (zero value of denominator), z is set to (a copy of) x. -func (z *Int) scaleDenom(x *Int, f nat) { +func (z *Int) scaleDenom(stk *stack, x *Int, f nat) { if len(f) == 0 { z.Set(x) return } - z.abs = z.abs.mul(x.abs, f) + z.abs = z.abs.mul(stk, x.abs, f) z.neg = x.neg } @@ -481,58 +487,73 @@ func (z *Int) scaleDenom(x *Int, f nat) { // - +1 if x > y. func (x *Rat) Cmp(y *Rat) int { var a, b Int - a.scaleDenom(&x.a, y.b.abs) - b.scaleDenom(&y.a, x.b.abs) + stk := getStack() + defer stk.free() + a.scaleDenom(stk, &x.a, y.b.abs) + b.scaleDenom(stk, &y.a, x.b.abs) return a.Cmp(&b) } // Add sets z to the sum x+y and returns z. func (z *Rat) Add(x, y *Rat) *Rat { + stk := getStack() + defer stk.free() + var a1, a2 Int - a1.scaleDenom(&x.a, y.b.abs) - a2.scaleDenom(&y.a, x.b.abs) + a1.scaleDenom(stk, &x.a, y.b.abs) + a2.scaleDenom(stk, &y.a, x.b.abs) z.a.Add(&a1, &a2) - z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs) + z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs) return z.norm() } // Sub sets z to the difference x-y and returns z. func (z *Rat) Sub(x, y *Rat) *Rat { + stk := getStack() + defer stk.free() + var a1, a2 Int - a1.scaleDenom(&x.a, y.b.abs) - a2.scaleDenom(&y.a, x.b.abs) + a1.scaleDenom(stk, &x.a, y.b.abs) + a2.scaleDenom(stk, &y.a, x.b.abs) z.a.Sub(&a1, &a2) - z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs) + z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs) return z.norm() } // Mul sets z to the product x*y and returns z. func (z *Rat) Mul(x, y *Rat) *Rat { + stk := getStack() + defer stk.free() + if x == y { // a squared Rat is positive and can't be reduced (no need to call norm()) z.a.neg = false - z.a.abs = z.a.abs.sqr(x.a.abs) + z.a.abs = z.a.abs.sqr(stk, x.a.abs) if len(x.b.abs) == 0 { z.b.abs = z.b.abs.setWord(1) } else { - z.b.abs = z.b.abs.sqr(x.b.abs) + z.b.abs = z.b.abs.sqr(stk, x.b.abs) } return z } - z.a.Mul(&x.a, &y.a) - z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs) + + z.a.mul(stk, &x.a, &y.a) + z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs) return z.norm() } // Quo sets z to the quotient x/y and returns z. // If y == 0, Quo panics. func (z *Rat) Quo(x, y *Rat) *Rat { + stk := getStack() + defer stk.free() + if len(y.a.abs) == 0 { panic("division by zero") } var a, b Int - a.scaleDenom(&x.a, y.b.abs) - b.scaleDenom(&y.a, x.b.abs) + a.scaleDenom(stk, &x.a, y.b.abs) + b.scaleDenom(stk, &y.a, x.b.abs) z.a.abs = a.abs z.b.abs = b.abs z.a.neg = a.neg != b.neg diff --git a/src/math/big/ratconv.go b/src/math/big/ratconv.go index 12f9888c37..84602ff455 100644 --- a/src/math/big/ratconv.go +++ b/src/math/big/ratconv.go @@ -163,6 +163,9 @@ func (z *Rat) SetString(s string) (*Rat, bool) { } // exp consumed - not needed anymore + stk := getStack() + defer stk.free() + // apply exp5 contributions // (start with exp5 so the numbers to multiply are smaller) if exp5 != 0 { @@ -178,9 +181,9 @@ func (z *Rat) SetString(s string) (*Rat, bool) { if n > 1e6 { return nil, false // avoid excessively large exponents } - pow5 := z.b.abs.expNN(natFive, nat(nil).setWord(Word(n)), nil, false) // use underlying array of z.b.abs + pow5 := z.b.abs.expNN(stk, natFive, nat(nil).setWord(Word(n)), nil, false) // use underlying array of z.b.abs if exp5 > 0 { - z.a.abs = z.a.abs.mul(z.a.abs, pow5) + z.a.abs = z.a.abs.mul(stk, z.a.abs, pow5) z.b.abs = z.b.abs.setWord(1) } else { z.b.abs = pow5 @@ -343,15 +346,17 @@ func (x *Rat) FloatString(prec int) string { } // x.b.abs != 0 - q, r := nat(nil).div(nat(nil), x.a.abs, x.b.abs) + stk := getStack() + defer stk.free() + q, r := nat(nil).div(stk, nat(nil), x.a.abs, x.b.abs) p := natOne if prec > 0 { - p = nat(nil).expNN(natTen, nat(nil).setUint64(uint64(prec)), nil, false) + p = nat(nil).expNN(stk, natTen, nat(nil).setUint64(uint64(prec)), nil, false) } - r = r.mul(r, p) - r, r2 := r.div(nat(nil), r, x.b.abs) + r = r.mul(stk, r, p) + r, r2 := r.div(stk, nat(nil), r, x.b.abs) // see if we need to round up r2 = r2.add(r2, r2) @@ -398,6 +403,9 @@ func (x *Rat) FloatString(prec int) string { // 1/4 2 true 0.25 // 1/6 1 false 0.2 (0.166... rounded) func (x *Rat) FloatPrec() (n int, exact bool) { + stk := getStack() + defer stk.free() + // Determine q and largest p2, p5 such that d = q·2^p2·5^p5. // The results n, exact are: // @@ -425,11 +433,11 @@ func (x *Rat) FloatPrec() (n int, exact bool) { f := nat{1220703125} // == 5^fp (must fit into a uint32 Word) var t, r nat // temporaries for { - if _, r = t.div(r, q, f); len(r) != 0 { + if _, r = t.div(stk, r, q, f); len(r) != 0 { break // f doesn't divide q evenly } tab = append(tab, f) - f = nat(nil).sqr(f) // nat(nil) to ensure a new f for each table entry + f = nat(nil).sqr(stk, f) // nat(nil) to ensure a new f for each table entry } // Factor q using the table entries, if any. @@ -441,7 +449,7 @@ func (x *Rat) FloatPrec() (n int, exact bool) { // The same reasoning applies to the subsequent factors. var p5 uint for i := len(tab) - 1; i >= 0; i-- { - if t, r = t.div(r, q, tab[i]); len(r) == 0 { + if t, r = t.div(stk, r, q, tab[i]); len(r) == 0 { p5 += fp * (1 << i) // tab[i] == 5^(fp·2^i) q = q.set(t) } @@ -449,7 +457,7 @@ func (x *Rat) FloatPrec() (n int, exact bool) { // If fp != 1, we may still have multiples of 5 left. for { - if t, r = t.div(r, q, natFive); len(r) != 0 { + if t, r = t.div(stk, r, q, natFive); len(r) != 0 { break } p5++ -- 2.50.0