]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: replace nat pool with Word stack
authorRuss Cox <rsc@golang.org>
Fri, 17 Jan 2025 17:28:58 +0000 (12:28 -0500)
committerGopher Robot <gobot@golang.org>
Thu, 27 Feb 2025 13:58:09 +0000 (05:58 -0800)
In the early days of math/big, algorithms that needed more space
grew the result larger than it needed to be and then used the
high words as extra space. This made results their own temporary
space caches, at the cost that saving a result in a data structure
might hold significantly more memory than necessary.
Specifically, new(big.Int).Mul(x, y) returned a big.Int with a
backing slice 3X as big as it strictly needed to be.
If you are storing many multiplication results, or even a single
large result, the 3X overhead can add up.

This approach to storage for temporaries also requires being able
to analyze the algorithms to predict the exact amount they need,
which can be difficult.

For both these reasons, the implementation of recursive long division,
which came later, introduced a “nat pool” where temporaries could be
stored and reused, or reclaimed by the GC when no longer used.
This avoids the storage and bookkeeping overheads but introduces a
per-temporary sync.Pool overhead. divRecursiveStep takes an array
of cached temporaries to remove some of that overhead.
The nat pool was better but is still not quite right.

This CL introduces something even better than the nat pool
(still probably not quite right, but the best I can see for now):
a sync.Pool holding stacks for allocating temporaries.
Now an operation can get one stack out of the pool and then
allocate as many temporaries as it needs during the operation,
eventually returning the stack back to the pool. The sync.Pool
operations are now per-exported-operation (like big.Int.Mul),
not per-temporary.

This CL converts both the pre-allocation in nat.mul and the
uses of the nat pool to use stack pools instead. This simplifies
some code and sets us up better for more complex algorithms
(such as Toom-Cook or FFT-based multiplication) that need
more temporaries. It is also a little bit faster.

goos: linux
goarch: amd64
pkg: math/big
cpu: Intel(R) Xeon(R) CPU @ 3.10GHz
                         │     old     │                 new                 │
                         │   sec/op    │   sec/op     vs base                │
Div/20/10-16               23.68n ± 0%   22.21n ± 0%   -6.21% (p=0.000 n=15)
Div/40/20-16               23.68n ± 0%   22.21n ± 0%   -6.21% (p=0.000 n=15)
Div/100/50-16              56.65n ± 0%   55.53n ± 0%   -1.98% (p=0.000 n=15)
Div/200/100-16             194.6n ± 1%   172.8n ± 0%  -11.20% (p=0.000 n=15)
Div/400/200-16             232.1n ± 0%   206.7n ± 0%  -10.94% (p=0.000 n=15)
Div/1000/500-16            405.3n ± 1%   383.8n ± 0%   -5.30% (p=0.000 n=15)
Div/2000/1000-16           810.4n ± 1%   795.2n ± 0%   -1.88% (p=0.000 n=15)
Div/20000/10000-16         25.88µ ± 0%   25.39µ ± 0%   -1.89% (p=0.000 n=15)
Div/200000/100000-16       931.5µ ± 0%   924.3µ ± 0%   -0.77% (p=0.000 n=15)
Div/2000000/1000000-16     37.77m ± 0%   37.75m ± 0%        ~ (p=0.098 n=15)
Div/20000000/10000000-16    1.367 ± 0%    1.377 ± 0%   +0.72% (p=0.003 n=15)
NatMul/10-16               168.5n ± 3%   164.0n ± 4%        ~ (p=0.751 n=15)
NatMul/100-16              6.086µ ± 3%   5.380µ ± 3%  -11.60% (p=0.000 n=15)
NatMul/1000-16             238.1µ ± 3%   228.3µ ± 1%   -4.12% (p=0.000 n=15)
NatMul/10000-16            8.721m ± 2%   8.518m ± 1%   -2.33% (p=0.000 n=15)
NatMul/100000-16           369.6m ± 0%   371.1m ± 0%   +0.42% (p=0.000 n=15)
geomean                    19.57µ        18.74µ        -4.21%

                 │     old      │                  new                   │
                 │     B/op     │     B/op      vs base                  │
NatMul/10-16         192.0 ± 0%     192.0 ± 0%        ~ (p=1.000 n=15) ¹
NatMul/100-16      4.750Ki ± 0%   1.751Ki ± 0%  -63.14% (p=0.000 n=15)
NatMul/1000-16     48.16Ki ± 0%   16.02Ki ± 0%  -66.73% (p=0.000 n=15)
NatMul/10000-16    482.9Ki ± 1%   165.4Ki ± 3%  -65.75% (p=0.000 n=15)
NatMul/100000-16   5.747Mi ± 7%   4.197Mi ± 0%  -26.97% (p=0.000 n=15)
geomean            41.42Ki        20.63Ki       -50.18%
¹ all samples are equal

                 │     old     │                 new                  │
                 │  allocs/op  │  allocs/op   vs base                 │
NatMul/10-16       1.000 ±  0%   1.000 ±  0%       ~ (p=1.000 n=15) ¹
NatMul/100-16      1.000 ±  0%   1.000 ±  0%       ~ (p=1.000 n=15) ¹
NatMul/1000-16     1.000 ±  0%   1.000 ±  0%       ~ (p=1.000 n=15) ¹
NatMul/10000-16    1.000 ±  0%   1.000 ±  0%       ~ (p=1.000 n=15) ¹
NatMul/100000-16   7.000 ± 14%   7.000 ± 14%       ~ (p=0.668 n=15)
geomean            1.476         1.476        +0.00%
¹ all samples are equal

goos: linux
goarch: amd64
pkg: math/big
cpu: Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz
                         │     old     │                 new                 │
                         │   sec/op    │   sec/op     vs base                │
Div/20/10-88               15.84n ± 1%   13.12n ± 0%  -17.17% (p=0.000 n=15)
Div/40/20-88               15.88n ± 1%   13.12n ± 0%  -17.38% (p=0.000 n=15)
Div/100/50-88              26.42n ± 0%   25.47n ± 0%   -3.60% (p=0.000 n=15)
Div/200/100-88             132.4n ± 0%   114.9n ± 0%  -13.22% (p=0.000 n=15)
Div/400/200-88             150.1n ± 0%   135.6n ± 0%   -9.66% (p=0.000 n=15)
Div/1000/500-88            275.5n ± 0%   264.1n ± 0%   -4.14% (p=0.000 n=15)
Div/2000/1000-88           586.5n ± 0%   581.1n ± 0%   -0.92% (p=0.000 n=15)
Div/20000/10000-88         25.87µ ± 0%   25.72µ ± 0%   -0.59% (p=0.000 n=15)
Div/200000/100000-88       772.2µ ± 0%   779.0µ ± 0%   +0.88% (p=0.000 n=15)
Div/2000000/1000000-88     33.36m ± 0%   33.63m ± 0%   +0.80% (p=0.000 n=15)
Div/20000000/10000000-88    1.307 ± 0%    1.320 ± 0%   +1.03% (p=0.000 n=15)
NatMul/10-88               140.4n ± 0%   148.8n ± 4%   +5.98% (p=0.000 n=15)
NatMul/100-88              4.663µ ± 1%   4.388µ ± 1%   -5.90% (p=0.000 n=15)
NatMul/1000-88             207.7µ ± 0%   205.8µ ± 0%   -0.89% (p=0.000 n=15)
NatMul/10000-88            8.456m ± 0%   8.468m ± 0%   +0.14% (p=0.021 n=15)
NatMul/100000-88           295.1m ± 0%   297.9m ± 0%   +0.94% (p=0.000 n=15)
geomean                    14.96µ        14.33µ        -4.23%

                 │     old      │                   new                   │
                 │     B/op     │     B/op       vs base                  │
NatMul/10-88         192.0 ± 0%     192.0 ±  0%        ~ (p=1.000 n=15) ¹
NatMul/100-88      4.750Ki ± 0%   1.758Ki ±  0%  -62.99% (p=0.000 n=15)
NatMul/1000-88     48.44Ki ± 0%   16.08Ki ±  0%  -66.80% (p=0.000 n=15)
NatMul/10000-88    489.7Ki ± 1%   166.1Ki ±  3%  -66.08% (p=0.000 n=15)
NatMul/100000-88   5.546Mi ± 0%   3.819Mi ± 60%  -31.15% (p=0.000 n=15)
geomean            41.29Ki        20.30Ki        -50.85%
¹ all samples are equal

                 │     old     │                 new                  │
                 │  allocs/op  │  allocs/op   vs base                 │
NatMul/10-88       1.000 ±  0%   1.000 ±  0%       ~ (p=1.000 n=15) ¹
NatMul/100-88      1.000 ±  0%   1.000 ±  0%       ~ (p=1.000 n=15) ¹
NatMul/1000-88     1.000 ±  0%   1.000 ±  0%       ~ (p=1.000 n=15) ¹
NatMul/10000-88    1.000 ±  0%   1.000 ±  0%       ~ (p=1.000 n=15) ¹
NatMul/100000-88   5.000 ± 20%   6.000 ± 67%       ~ (p=0.672 n=15)
geomean            1.380         1.431        +3.71%
¹ all samples are equal

goos: linux
goarch: arm64
pkg: math/big
                         │     old     │                 new                 │
                         │   sec/op    │   sec/op     vs base                │
Div/20/10-16               15.85n ± 0%   15.23n ± 0%   -3.91% (p=0.000 n=15)
Div/40/20-16               15.88n ± 0%   15.22n ± 0%   -4.16% (p=0.000 n=15)
Div/100/50-16              29.69n ± 0%   26.39n ± 0%  -11.11% (p=0.000 n=15)
Div/200/100-16             149.2n ± 0%   123.3n ± 0%  -17.36% (p=0.000 n=15)
Div/400/200-16             160.3n ± 0%   139.2n ± 0%  -13.16% (p=0.000 n=15)
Div/1000/500-16            271.0n ± 0%   256.1n ± 0%   -5.50% (p=0.000 n=15)
Div/2000/1000-16           545.3n ± 0%   527.0n ± 0%   -3.36% (p=0.000 n=15)
Div/20000/10000-16         22.60µ ± 0%   22.20µ ± 0%   -1.77% (p=0.000 n=15)
Div/200000/100000-16       889.0µ ± 0%   892.2µ ± 0%   +0.35% (p=0.000 n=15)
Div/2000000/1000000-16     38.01m ± 0%   38.12m ± 0%   +0.30% (p=0.000 n=15)
Div/20000000/10000000-16    1.437 ± 0%    1.444 ± 0%   +0.50% (p=0.000 n=15)
NatMul/10-16               166.4n ± 2%   169.5n ± 1%   +1.86% (p=0.000 n=15)
NatMul/100-16              5.733µ ± 1%   5.570µ ± 1%   -2.84% (p=0.000 n=15)
NatMul/1000-16             232.6µ ± 1%   229.8µ ± 0%   -1.22% (p=0.000 n=15)
NatMul/10000-16            9.039m ± 1%   8.969m ± 0%   -0.77% (p=0.000 n=15)
NatMul/100000-16           367.0m ± 0%   368.8m ± 0%   +0.48% (p=0.000 n=15)
geomean                    16.15µ        15.50µ        -4.01%

                 │     old      │                  new                   │
                 │     B/op     │     B/op      vs base                  │
NatMul/10-16         192.0 ± 0%     192.0 ± 0%        ~ (p=1.000 n=15) ¹
NatMul/100-16      4.750Ki ± 0%   1.751Ki ± 0%  -63.14% (p=0.000 n=15)
NatMul/1000-16     48.33Ki ± 0%   16.02Ki ± 0%  -66.85% (p=0.000 n=15)
NatMul/10000-16    536.5Ki ± 1%   165.7Ki ± 3%  -69.12% (p=0.000 n=15)
NatMul/100000-16   6.078Mi ± 6%   4.197Mi ± 0%  -30.94% (p=0.000 n=15)
geomean            42.81Ki        20.64Ki       -51.78%
¹ all samples are equal

                 │     old     │                  new                  │
                 │  allocs/op  │  allocs/op   vs base                  │
NatMul/10-16       1.000 ±  0%   1.000 ±  0%        ~ (p=1.000 n=15) ¹
NatMul/100-16      1.000 ±  0%   1.000 ±  0%        ~ (p=1.000 n=15) ¹
NatMul/1000-16     1.000 ±  0%   1.000 ±  0%        ~ (p=1.000 n=15) ¹
NatMul/10000-16    2.000 ± 50%   1.000 ±  0%  -50.00% (p=0.001 n=15)
NatMul/100000-16   9.000 ± 11%   8.000 ± 12%  -11.11% (p=0.001 n=15)
geomean            1.783         1.516        -14.97%
¹ all samples are equal

goos: darwin
goarch: arm64
pkg: math/big
cpu: Apple M3 Pro
                         │     old      │                new                 │
                         │    sec/op    │   sec/op     vs base               │
Div/20/10-12                9.850n ± 1%   9.405n ± 1%  -4.52% (p=0.000 n=15)
Div/40/20-12                9.858n ± 0%   9.403n ± 1%  -4.62% (p=0.000 n=15)
Div/100/50-12               16.40n ± 1%   14.81n ± 0%  -9.70% (p=0.000 n=15)
Div/200/100-12              88.48n ± 2%   80.88n ± 0%  -8.59% (p=0.000 n=15)
Div/400/200-12             107.90n ± 1%   99.28n ± 1%  -7.99% (p=0.000 n=15)
Div/1000/500-12             188.8n ± 1%   178.6n ± 1%  -5.40% (p=0.000 n=15)
Div/2000/1000-12            399.9n ± 0%   389.1n ± 0%  -2.70% (p=0.000 n=15)
Div/20000/10000-12          13.94µ ± 2%   13.81µ ± 1%       ~ (p=0.574 n=15)
Div/200000/100000-12        523.8µ ± 0%   521.7µ ± 0%  -0.40% (p=0.000 n=15)
Div/2000000/1000000-12      21.46m ± 0%   21.48m ± 0%       ~ (p=0.067 n=15)
Div/20000000/10000000-12    812.5m ± 0%   812.9m ± 0%       ~ (p=0.061 n=15)
NatMul/10-12                77.14n ± 0%   78.35n ± 1%  +1.57% (p=0.000 n=15)
NatMul/100-12               2.999µ ± 0%   2.871µ ± 1%  -4.27% (p=0.000 n=15)
NatMul/1000-12              126.2µ ± 0%   126.8µ ± 0%  +0.51% (p=0.011 n=15)
NatMul/10000-12             5.099m ± 0%   5.125m ± 0%  +0.51% (p=0.000 n=15)
NatMul/100000-12            206.7m ± 0%   208.4m ± 0%  +0.80% (p=0.000 n=15)
geomean                     9.512µ        9.236µ       -2.91%

                 │     old      │                   new                    │
                 │     B/op     │      B/op       vs base                  │
NatMul/10-12         192.0 ± 0%     192.0 ±   0%        ~ (p=1.000 n=15) ¹
NatMul/100-12      4.750Ki ± 0%   1.750Ki ±   0%  -63.16% (p=0.000 n=15)
NatMul/1000-12     48.13Ki ± 0%   16.01Ki ±   0%  -66.73% (p=0.000 n=15)
NatMul/10000-12    483.5Ki ± 1%   163.2Ki ±   2%  -66.24% (p=0.000 n=15)
NatMul/100000-12   5.480Mi ± 4%   1.532Mi ± 104%  -72.05% (p=0.000 n=15)
geomean            41.03Ki        16.82Ki         -59.01%
¹ all samples are equal

                 │    old     │                  new                   │
                 │ allocs/op  │  allocs/op    vs base                  │
NatMul/10-12       1.000 ± 0%   1.000 ±   0%        ~ (p=1.000 n=15) ¹
NatMul/100-12      1.000 ± 0%   1.000 ±   0%        ~ (p=1.000 n=15) ¹
NatMul/1000-12     1.000 ± 0%   1.000 ±   0%        ~ (p=1.000 n=15) ¹
NatMul/10000-12    1.000 ± 0%   1.000 ±   0%        ~ (p=1.000 n=15) ¹
NatMul/100000-12   5.000 ± 0%   1.000 ± 400%  -80.00% (p=0.007 n=15)
geomean            1.380        1.000         -27.52%
¹ all samples are equal

Change-Id: I7efa6fe37971ed26ae120a32250fcb47ece0a011
Reviewed-on: https://go-review.googlesource.com/c/go/+/650638
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Russ Cox <rsc@golang.org>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Reviewed-by: Alan Donovan <adonovan@google.com>
12 files changed:
src/math/big/arith_test.go
src/math/big/float.go
src/math/big/int.go
src/math/big/nat.go
src/math/big/nat_test.go
src/math/big/natconv.go
src/math/big/natconv_test.go
src/math/big/natdiv.go
src/math/big/prime.go
src/math/big/prime_test.go
src/math/big/rat.go
src/math/big/ratconv.go

index 64225bbd53661ce9ad1fd034c5d135bf6c4caafa..feffa1bc95d80c4bb82ca76e6c9508159dd97582 100644 (file)
@@ -368,9 +368,12 @@ func TestShiftOverlap(t *testing.T) {
 }
 
 func TestIssue31084(t *testing.T) {
+       stk := getStack()
+       defer stk.free()
+
        // compute 10^n via 5^n << n.
        const n = 165
-       p := nat(nil).expNN(nat{5}, nat{n}, nil, false)
+       p := nat(nil).expNN(stk, nat{5}, nat{n}, nil, false)
        p = p.shl(p, n)
        got := string(p.utoa(10))
        want := "1" + strings.Repeat("0", n)
index e1d20d8bb4c0088920e3069d73bdac84968dbcc9..2c5234a4ceac77cc8e2628532b066577c3541abd 100644 (file)
@@ -1327,9 +1327,9 @@ func (z *Float) umul(x, y *Float) {
 
        e := int64(x.exp) + int64(y.exp)
        if x == y {
-               z.mant = z.mant.sqr(x.mant)
+               z.mant = z.mant.sqr(nil, x.mant)
        } else {
-               z.mant = z.mant.mul(x.mant, y.mant)
+               z.mant = z.mant.mul(nil, x.mant, y.mant)
        }
        z.setExpAndRound(e-fnorm(z.mant), 0)
 }
@@ -1363,8 +1363,10 @@ func (z *Float) uquo(x, y *Float) {
        d := len(xadj) - len(y.mant)
 
        // divide
+       stk := getStack()
+       defer stk.free()
        var r nat
-       z.mant, r = z.mant.div(nil, xadj, y.mant)
+       z.mant, r = z.mant.div(stk, nil, xadj, y.mant)
        e := int64(x.exp) - int64(y.exp) - int64(d-len(z.mant))*_W
 
        // The result is long enough to include (at least) the rounding bit.
index 0b710c69681539fe4493813e62dcba70042763bf..cb7221250da10756eef5adca5823f50b5ebebde9 100644 (file)
@@ -181,16 +181,20 @@ func (z *Int) Sub(x, y *Int) *Int {
 
 // Mul sets z to the product x*y and returns z.
 func (z *Int) Mul(x, y *Int) *Int {
+       return z.mul(nil, x, y)
+}
+
+func (z *Int) mul(stk *stack, x, y *Int) *Int {
        // x * y == x * y
        // x * (-y) == -(x * y)
        // (-x) * y == -(x * y)
        // (-x) * (-y) == x * y
        if x == y {
-               z.abs = z.abs.sqr(x.abs)
+               z.abs = z.abs.sqr(stk, x.abs)
                z.neg = false
                return z
        }
-       z.abs = z.abs.mul(x.abs, y.abs)
+       z.abs = z.abs.mul(stk, x.abs, y.abs)
        z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign
        return z
 }
@@ -213,7 +217,7 @@ func (z *Int) MulRange(a, b int64) *Int {
                a, b = -b, -a
        }
 
-       z.abs = z.abs.mulRange(uint64(a), uint64(b))
+       z.abs = z.abs.mulRange(nil, uint64(a), uint64(b))
        z.neg = neg
        return z
 }
@@ -264,7 +268,7 @@ func (z *Int) Binomial(n, k int64) *Int {
 // If y == 0, a division-by-zero run-time panic occurs.
 // Quo implements truncated division (like Go); see [Int.QuoRem] for more details.
 func (z *Int) Quo(x, y *Int) *Int {
-       z.abs, _ = z.abs.div(nil, x.abs, y.abs)
+       z.abs, _ = z.abs.div(nil, nil, x.abs, y.abs)
        z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign
        return z
 }
@@ -273,7 +277,7 @@ func (z *Int) Quo(x, y *Int) *Int {
 // If y == 0, a division-by-zero run-time panic occurs.
 // Rem implements truncated modulus (like Go); see [Int.QuoRem] for more details.
 func (z *Int) Rem(x, y *Int) *Int {
-       _, z.abs = nat(nil).div(z.abs, x.abs, y.abs)
+       _, z.abs = nat(nil).div(nil, z.abs, x.abs, y.abs)
        z.neg = len(z.abs) > 0 && x.neg // 0 has no sign
        return z
 }
@@ -290,7 +294,7 @@ func (z *Int) Rem(x, y *Int) *Int {
 // (See Daan Leijen, “Division and Modulus for Computer Scientists”.)
 // See [Int.DivMod] for Euclidean division and modulus (unlike Go).
 func (z *Int) QuoRem(x, y, r *Int) (*Int, *Int) {
-       z.abs, r.abs = z.abs.div(r.abs, x.abs, y.abs)
+       z.abs, r.abs = z.abs.div(nil, r.abs, x.abs, y.abs)
        z.neg, r.neg = len(z.abs) > 0 && x.neg != y.neg, len(r.abs) > 0 && x.neg // 0 has no sign
        return z, r
 }
@@ -589,7 +593,7 @@ func (z *Int) exp(x, y, m *Int, slow bool) *Int {
                mWords = m.abs // m.abs may be nil for m == 0
        }
 
-       z.abs = z.abs.expNN(xWords, yWords, mWords, slow)
+       z.abs = z.abs.expNN(nil, xWords, yWords, mWords, slow)
        z.neg = len(z.abs) > 0 && x.neg && len(yWords) > 0 && yWords[0]&1 == 1 // 0 has no sign
        if z.neg && len(mWords) > 0 {
                // make modulus result positive
@@ -1298,6 +1302,6 @@ func (z *Int) Sqrt(x *Int) *Int {
                panic("square root of negative number")
        }
        z.neg = false
-       z.abs = z.abs.sqrt(x.abs)
+       z.abs = z.abs.sqrt(nil, x.abs)
        return z
 }
index 541da229d6154ab68c44d02e41d40233967beb17..ec75c8f6fdafe8ed79280ba55fa2108a13a2c38e 100644 (file)
@@ -17,6 +17,7 @@ import (
        "internal/byteorder"
        "math/bits"
        "math/rand"
+       "slices"
        "sync"
 )
 
@@ -262,9 +263,9 @@ var karatsubaThreshold = 40 // computed by calibrate_test.go
 
 // karatsuba multiplies x and y and leaves the result in z.
 // Both x and y must have the same length n and n must be a
-// power of 2. The result vector z must have len(z) >= 6*n.
-// The (non-normalized) result is placed in z[0 : 2*n].
-func karatsuba(z, x, y nat) {
+// power of 2. The result vector z must have len(z) == len(x)+len(y).
+// The (non-normalized) result is placed in z.
+func karatsuba(stk *stack, z, x, y nat) {
        n := len(y)
 
        // Switch to basic multiplication if numbers are odd or small.
@@ -304,29 +305,19 @@ func karatsuba(z, x, y nat) {
        x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0
        y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0
 
-       // z is used for the result and temporary storage:
-       //
-       //   6*n     5*n     4*n     3*n     2*n     1*n     0*n
-       // z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
-       //
-       // For each recursive call of karatsuba, an unused slice of
-       // z is passed in that has (at least) half the length of the
-       // caller's z.
-
        // compute z0 and z2 with the result "in place" in z
-       karatsuba(z, x0, y0)     // z0 = x0*y0
-       karatsuba(z[n:], x1, y1) // z2 = x1*y1
+       karatsuba(stk, z, x0, y0)     // z0 = x0*y0
+       karatsuba(stk, z[n:], x1, y1) // z2 = x1*y1
 
-       // compute xd (or the negative value if underflow occurs)
+       // compute xd, yd (or the negative value if underflow occurs)
        s := 1 // sign of product xd*yd
-       xd := z[2*n : 2*n+n2]
+       defer stk.restore(stk.save())
+       xd := stk.nat(n2)
+       yd := stk.nat(n2)
        if subVV(xd, x1, x0) != 0 { // x1-x0
                s = -s
                subVV(xd, x0, x1) // x0-x1
        }
-
-       // compute yd (or the negative value if underflow occurs)
-       yd := z[2*n+n2 : 3*n]
        if subVV(yd, y0, y1) != 0 { // y0-y1
                s = -s
                subVV(yd, y1, y0) // y1-y0
@@ -334,12 +325,12 @@ func karatsuba(z, x, y nat) {
 
        // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
        // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
-       p := z[n*3:]
-       karatsuba(p, xd, yd)
+       p := stk.nat(2 * n2)
+       karatsuba(stk, p, xd, yd)
 
        // save original z2:z0
        // (ok to use upper half of z since we're done recurring)
-       r := z[n*4:]
+       r := stk.nat(n * 2)
        copy(r, z[:n*2])
 
        // add up all partial products
@@ -396,13 +387,15 @@ func karatsubaLen(n, threshold int) int {
        return n << i
 }
 
-func (z nat) mul(x, y nat) nat {
+// mul sets z = x*y, using stk for temporary storage.
+// The caller may pass stk == nil to request that mul obtain and release one itself.
+func (z nat) mul(stk *stack, x, y nat) nat {
        m := len(x)
        n := len(y)
 
        switch {
        case m < n:
-               return z.mul(y, x)
+               return z.mul(stk, y, x)
        case m == 0 || n == 0:
                return z[:0]
        case n == 1:
@@ -432,12 +425,16 @@ func (z nat) mul(x, y nat) nat {
        k := karatsubaLen(n, karatsubaThreshold)
        // k <= n
 
+       if stk == nil {
+               stk = getStack()
+               defer stk.free()
+       }
+
        // multiply x0 and y0 via Karatsuba
-       x0 := x[0:k]              // x0 is not normalized
-       y0 := y[0:k]              // y0 is not normalized
-       z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
-       karatsuba(z, x0, y0)
-       z = z[0 : m+n] // z has final length but may be incomplete
+       x0 := x[0:k]      // x0 is not normalized
+       y0 := y[0:k]      // y0 is not normalized
+       z = z.make(m + n) // enough space for full result of x*y
+       karatsuba(stk, z, x0, y0)
        clear(z[2*k:]) // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
 
        // If xh != 0 or yh != 0, add the missing terms to z. For
@@ -454,13 +451,13 @@ func (z nat) mul(x, y nat) nat {
        // be a larger valid threshold contradicting the assumption about k.
        //
        if k < n || m != n {
-               tp := getNat(3 * k)
-               t := *tp
+               defer stk.restore(stk.save())
+               t := stk.nat(3 * k)
 
                // add x0*y1*b
                x0 := x0.norm()
-               y1 := y[k:]       // y1 is normalized because y is
-               t = t.mul(x0, y1) // update t so we don't lose t's underlying array
+               y1 := y[k:]            // y1 is normalized because y is
+               t = t.mul(stk, x0, y1) // update t so we don't lose t's underlying array
                addAt(z, t, k)
 
                // add xi*y0<<i, xi*y1*b<<(i+k)
@@ -471,13 +468,11 @@ func (z nat) mul(x, y nat) nat {
                                xi = xi[:k]
                        }
                        xi = xi.norm()
-                       t = t.mul(xi, y0)
+                       t = t.mul(stk, xi, y0)
                        addAt(z, t, i)
-                       t = t.mul(xi, y1)
+                       t = t.mul(stk, xi, y1)
                        addAt(z, t, i+k)
                }
-
-               putNat(tp)
        }
 
        return z.norm()
@@ -487,10 +482,10 @@ func (z nat) mul(x, y nat) nat {
 // by about a factor of 2, but slower for small arguments due to overhead.
 // Requirements: len(x) > 0, len(z) == 2*len(x)
 // The (non-normalized) result is placed in z.
-func basicSqr(z, x nat) {
+func basicSqr(stk *stack, z, x nat) {
        n := len(x)
-       tp := getNat(2 * n)
-       t := *tp // temporary variable to hold the products
+       defer stk.restore(stk.save())
+       t := stk.nat(2 * n)
        clear(t)
        z[1], z[0] = mulWW(x[0], x[0]) // the initial square
        for i := 1; i < n; i++ {
@@ -502,38 +497,37 @@ func basicSqr(z, x nat) {
        }
        t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
        addVV(z, z, t)                              // combine the result
-       putNat(tp)
 }
 
 // karatsubaSqr squares x and leaves the result in z.
-// len(x) must be a power of 2 and len(z) >= 6*len(x).
-// The (non-normalized) result is placed in z[0 : 2*len(x)].
+// len(x) must be a power of 2 and len(z) == 2*len(x).
+// The (non-normalized) result is placed in z.
 //
 // The algorithm and the layout of z are the same as for karatsuba.
-func karatsubaSqr(z, x nat) {
+func karatsubaSqr(stk *stack, z, x nat) {
        n := len(x)
 
        if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 {
-               basicSqr(z[:2*n], x)
+               basicSqr(stk, z[:2*n], x)
                return
        }
 
        n2 := n >> 1
        x1, x0 := x[n2:], x[0:n2]
 
-       karatsubaSqr(z, x0)
-       karatsubaSqr(z[n:], x1)
+       karatsubaSqr(stk, z, x0)
+       karatsubaSqr(stk, z[n:], x1)
 
        // s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0
-       xd := z[2*n : 2*n+n2]
+       defer stk.restore(stk.save())
+       p := stk.nat(2 * n2)
+       r := stk.nat(n * 2)
+       xd := r[:n2]
        if subVV(xd, x1, x0) != 0 {
                subVV(xd, x0, x1)
        }
 
-       p := z[n*3:]
-       karatsubaSqr(p, xd)
-
-       r := z[n*4:]
+       karatsubaSqr(stk, p, xd)
        copy(r, z[:n*2])
 
        karatsubaAdd(z[n2:], r, n)
@@ -547,8 +541,9 @@ func karatsubaSqr(z, x nat) {
 var basicSqrThreshold = 20      // computed by calibrate_test.go
 var karatsubaSqrThreshold = 260 // computed by calibrate_test.go
 
-// z = x*x
-func (z nat) sqr(x nat) nat {
+// sqr sets z = x*x, using stk for temporary storage.
+// The caller may pass stk == nil to request that sqr obtain and release one itself.
+func (z nat) sqr(stk *stack, x nat) nat {
        n := len(x)
        switch {
        case n == 0:
@@ -563,15 +558,20 @@ func (z nat) sqr(x nat) nat {
        if alias(z, x) {
                z = nil // z is an alias for x - cannot reuse
        }
+       z = z.make(2 * n)
 
        if n < basicSqrThreshold {
-               z = z.make(2 * n)
                basicMul(z, x, x)
                return z.norm()
        }
+
+       if stk == nil {
+               stk = getStack()
+               defer stk.free()
+       }
+
        if n < karatsubaSqrThreshold {
-               z = z.make(2 * n)
-               basicSqr(z, x)
+               basicSqr(stk, z, x)
                return z.norm()
        }
 
@@ -583,22 +583,18 @@ func (z nat) sqr(x nat) nat {
        k := karatsubaLen(n, karatsubaSqrThreshold)
 
        x0 := x[0:k]
-       z = z.make(max(6*k, 2*n))
-       karatsubaSqr(z, x0) // z = x0^2
-       z = z[0 : 2*n]
+       karatsubaSqr(stk, z, x0) // z = x0^2
        clear(z[2*k:])
 
        if k < n {
-               tp := getNat(2 * k)
-               t := *tp
+               t := stk.nat(2 * k)
                x0 := x0.norm()
                x1 := x[k:]
-               t = t.mul(x0, x1)
+               t = t.mul(stk, x0, x1)
                addAt(z, t, k)
                addAt(z, t, k) // z = 2*x1*x0*b + x0^2
-               t = t.sqr(x1)
+               t = t.sqr(stk, x1)
                addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2
-               putNat(tp)
        }
 
        return z.norm()
@@ -606,7 +602,8 @@ func (z nat) sqr(x nat) nat {
 
 // mulRange computes the product of all the unsigned integers in the
 // range [a, b] inclusively. If a > b (empty range), the result is 1.
-func (z nat) mulRange(a, b uint64) nat {
+// The caller may pass stk == nil to request that mulRange obtain and release one itself.
+func (z nat) mulRange(stk *stack, a, b uint64) nat {
        switch {
        case a == 0:
                // cut long ranges short (optimization)
@@ -616,34 +613,79 @@ func (z nat) mulRange(a, b uint64) nat {
        case a == b:
                return z.setUint64(a)
        case a+1 == b:
-               return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
+               return z.mul(stk, nat(nil).setUint64(a), nat(nil).setUint64(b))
+       }
+
+       if stk == nil {
+               stk = getStack()
+               defer stk.free()
        }
+
        m := a + (b-a)/2 // avoid overflow
-       return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
+       return z.mul(stk, nat(nil).mulRange(stk, a, m), nat(nil).mulRange(stk, m+1, b))
 }
 
-// getNat returns a *nat of len n. The contents may not be zero.
-// The pool holds *nat to avoid allocation when converting to interface{}.
-func getNat(n int) *nat {
-       var z *nat
-       if v := natPool.Get(); v != nil {
-               z = v.(*nat)
-       }
-       if z == nil {
-               z = new(nat)
-       }
-       *z = z.make(n)
-       if n > 0 {
-               (*z)[0] = 0xfedcb // break code expecting zero
+// A stack provides temporary storage for complex calculations
+// such as multiplication and division.
+// The stack is a simple slice of words, extended as needed
+// to hold all the temporary storage for a calculation.
+// In general, if a function takes a *stack, it expects a non-nil *stack.
+// However, certain functions may allow passing a nil *stack instead,
+// so that they can handle trivial stack-free cases without forcing the
+// caller to obtain and free a stack that will be unused. These functions
+// document that they accept a nil *stack in their doc comments.
+type stack struct {
+       w []Word
+}
+
+var stackPool sync.Pool
+
+// getStack returns a temporary stack.
+// The caller must call [stack.free] to give up use of the stack when finished.
+func getStack() *stack {
+       s, _ := stackPool.Get().(*stack)
+       if s == nil {
+               s = new(stack)
        }
-       return z
+       return s
+}
+
+// free returns the stack for use by another calculation.
+func (s *stack) free() {
+       s.w = s.w[:0]
+       stackPool.Put(s)
 }
 
-func putNat(x *nat) {
-       natPool.Put(x)
+// save returns the current stack pointer.
+// A future call to restore with the same value
+// frees any temporaries allocated on the stack after the call to save.
+func (s *stack) save() int {
+       return len(s.w)
 }
 
-var natPool sync.Pool
+// restore restores the stack pointer to n.
+// It is almost always invoked as
+//
+//     defer stk.restore(stk.save())
+//
+// which makes sure to pop any temporaries allocated in the current function
+// from the stack before returning.
+func (s *stack) restore(n int) {
+       s.w = s.w[:n]
+}
+
+// nat returns a nat of n words, allocated on the stack.
+func (s *stack) nat(n int) nat {
+       nr := (n + 3) &^ 3 // round up to multiple of 4
+       off := len(s.w)
+       s.w = slices.Grow(s.w, nr)
+       s.w = s.w[:off+nr]
+       x := s.w[off : off+n : off+n]
+       if n > 0 {
+               x[0] = 0xfedcb
+       }
+       return x
+}
 
 // bitLen returns the length of x in bits.
 // Unlike most methods, it works even if x is not normalized.
@@ -930,7 +972,8 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
 
 // If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
 // otherwise it sets z to x**y. The result is the value of z.
-func (z nat) expNN(x, y, m nat, slow bool) nat {
+// The caller may pass stk == nil to request that expNN obtain and release one itself.
+func (z nat) expNN(stk *stack, x, y, m nat, slow bool) nat {
        if alias(z, x) || alias(z, y) {
                // We cannot allow in-place modification of x or y.
                z = nil
@@ -961,12 +1004,17 @@ func (z nat) expNN(x, y, m nat, slow bool) nat {
        // x > 1
 
        // x**1 == x
-       if len(y) == 1 && y[0] == 1 {
-               if len(m) != 0 {
-                       return z.rem(x, m)
-               }
+       if len(y) == 1 && y[0] == 1 && len(m) == 0 {
                return z.set(x)
        }
+       if stk == nil {
+               stk = getStack()
+               defer stk.free()
+       }
+       if len(y) == 1 && y[0] == 1 { // len(m) > 0
+               return z.rem(stk, x, m)
+       }
+
        // y > 1
 
        if len(m) != 0 {
@@ -980,12 +1028,12 @@ func (z nat) expNN(x, y, m nat, slow bool) nat {
                // instance of each of the first two cases).
                if len(y) > 1 && !slow {
                        if m[0]&1 == 1 {
-                               return z.expNNMontgomery(x, y, m)
+                               return z.expNNMontgomery(stk, x, y, m)
                        }
                        if logM, ok := m.isPow2(); ok {
-                               return z.expNNWindowed(x, y, logM)
+                               return z.expNNWindowed(stk, x, y, logM)
                        }
-                       return z.expNNMontgomeryEven(x, y, m)
+                       return z.expNNMontgomeryEven(stk, x, y, m)
                }
        }
 
@@ -1006,16 +1054,16 @@ func (z nat) expNN(x, y, m nat, slow bool) nat {
        // otherwise the arguments would alias.
        var zz, r nat
        for j := 0; j < w; j++ {
-               zz = zz.sqr(z)
+               zz = zz.sqr(stk, z)
                zz, z = z, zz
 
                if v&mask != 0 {
-                       zz = zz.mul(z, x)
+                       zz = zz.mul(stk, z, x)
                        zz, z = z, zz
                }
 
                if len(m) != 0 {
-                       zz, r = zz.div(r, z, m)
+                       zz, r = zz.div(stk, r, z, m)
                        zz, r, q, z = q, z, zz, r
                }
 
@@ -1026,16 +1074,16 @@ func (z nat) expNN(x, y, m nat, slow bool) nat {
                v = y[i]
 
                for j := 0; j < _W; j++ {
-                       zz = zz.sqr(z)
+                       zz = zz.sqr(stk, z)
                        zz, z = z, zz
 
                        if v&mask != 0 {
-                               zz = zz.mul(z, x)
+                               zz = zz.mul(stk, z, x)
                                zz, z = z, zz
                        }
 
                        if len(m) != 0 {
-                               zz, r = zz.div(r, z, m)
+                               zz, r = zz.div(stk, r, z, m)
                                zz, r, q, z = q, z, zz, r
                        }
 
@@ -1054,7 +1102,7 @@ func (z nat) expNN(x, y, m nat, slow bool) nat {
 // For more details, see Ç. K. Koç, “Montgomery Reduction with Even Modulus”,
 // IEE Proceedings: Computers and Digital Techniques, 141(5) 314-316, September 1994.
 // http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf
-func (z nat) expNNMontgomeryEven(x, y, m nat) nat {
+func (z nat) expNNMontgomeryEven(stk *stack, x, y, m nat) nat {
        // Split m = m₁ × m₂ where m₁ = 2ⁿ
        n := m.trailingZeroBits()
        m1 := nat(nil).shl(natOne, n)
@@ -1066,8 +1114,8 @@ func (z nat) expNNMontgomeryEven(x, y, m nat) nat {
        // (We are using the math/big convention for names here,
        // where the computation is z = x**y mod m, so its parts are z1 and z2.
        // The paper is computing x = a**e mod n; it refers to these as x2 and z1.)
-       z1 := nat(nil).expNN(x, y, m1, false)
-       z2 := nat(nil).expNN(x, y, m2, false)
+       z1 := nat(nil).expNN(stk, x, y, m1, false)
+       z2 := nat(nil).expNN(stk, x, y, m2, false)
 
        // Reconstruct z from z₁, z₂ using CRT, using algorithm from paper,
        // which uses only a single modInverse (and an easy one at that).
@@ -1086,18 +1134,18 @@ func (z nat) expNNMontgomeryEven(x, y, m nat) nat {
 
        // Reuse z2 for p = (z₁ - z₂) [in z1] * m2⁻¹ (mod m₁ [= 2ⁿ]).
        m2inv := nat(nil).modInverse(m2, m1)
-       z2 = z2.mul(z1, m2inv)
+       z2 = z2.mul(stk, z1, m2inv)
        z2 = z2.trunc(z2, n)
 
        // Reuse z1 for p * m2.
-       z = z.add(z, z1.mul(z2, m2))
+       z = z.add(z, z1.mul(stk, z2, m2))
 
        return z
 }
 
 // expNNWindowed calculates x**y mod m using a fixed, 4-bit window,
 // where m = 2**logM.
-func (z nat) expNNWindowed(x, y nat, logM uint) nat {
+func (z nat) expNNWindowed(stk *stack, x, y nat, logM uint) nat {
        if len(y) <= 1 {
                panic("big: misuse of expNNWindowed")
        }
@@ -1112,23 +1160,23 @@ func (z nat) expNNWindowed(x, y nat, logM uint) nat {
 
        // zz is used to avoid allocating in mul as otherwise
        // the arguments would alias.
+       defer stk.restore(stk.save())
        w := int((logM + _W - 1) / _W)
-       zzp := getNat(w)
-       zz := *zzp
+       zz := stk.nat(w)
 
        const n = 4
        // powers[i] contains x^i.
-       var powers [1 << n]*nat
+       var powers [1 << n]nat
        for i := range powers {
-               powers[i] = getNat(w)
+               powers[i] = stk.nat(w)
        }
-       *powers[0] = powers[0].set(natOne)
-       *powers[1] = powers[1].trunc(x, logM)
+       powers[0] = powers[0].set(natOne)
+       powers[1] = powers[1].trunc(x, logM)
        for i := 2; i < 1<<n; i += 2 {
-               p2, p, p1 := powers[i/2], powers[i], powers[i+1]
-               *p = p.sqr(*p2)
+               p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
+               *p = p.sqr(stk, *p2)
                *p = p.trunc(*p, logM)
-               *p1 = p1.mul(*p, x)
+               *p1 = p1.mul(stk, *p, x)
                *p1 = p1.trunc(*p1, logM)
        }
 
@@ -1159,24 +1207,24 @@ func (z nat) expNNWindowed(x, y nat, logM uint) nat {
                                // Unrolled loop for significant performance
                                // gain. Use go test -bench=".*" in crypto/rsa
                                // to check performance before making changes.
-                               zz = zz.sqr(z)
+                               zz = zz.sqr(stk, z)
                                zz, z = z, zz
                                z = z.trunc(z, logM)
 
-                               zz = zz.sqr(z)
+                               zz = zz.sqr(stk, z)
                                zz, z = z, zz
                                z = z.trunc(z, logM)
 
-                               zz = zz.sqr(z)
+                               zz = zz.sqr(stk, z)
                                zz, z = z, zz
                                z = z.trunc(z, logM)
 
-                               zz = zz.sqr(z)
+                               zz = zz.sqr(stk, z)
                                zz, z = z, zz
                                z = z.trunc(z, logM)
                        }
 
-                       zz = zz.mul(z, *powers[yi>>(_W-n)])
+                       zz = zz.mul(stk, z, powers[yi>>(_W-n)])
                        zz, z = z, zz
                        z = z.trunc(z, logM)
 
@@ -1185,24 +1233,18 @@ func (z nat) expNNWindowed(x, y nat, logM uint) nat {
                }
        }
 
-       *zzp = zz
-       putNat(zzp)
-       for i := range powers {
-               putNat(powers[i])
-       }
-
        return z.norm()
 }
 
 // expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
 // Uses Montgomery representation.
-func (z nat) expNNMontgomery(x, y, m nat) nat {
+func (z nat) expNNMontgomery(stk *stack, x, y, m nat) nat {
        numWords := len(m)
 
        // We want the lengths of x and m to be equal.
        // It is OK if x >= m as long as len(x) == len(m).
        if len(x) > numWords {
-               _, x = nat(nil).div(nil, x, m)
+               _, x = nat(nil).div(stk, nil, x, m)
                // Note: now len(x) <= numWords, not guaranteed ==.
        }
        if len(x) < numWords {
@@ -1225,7 +1267,7 @@ func (z nat) expNNMontgomery(x, y, m nat) nat {
        // RR = 2**(2*_W*len(m)) mod m
        RR := nat(nil).setWord(1)
        zz := nat(nil).shl(RR, uint(2*numWords*_W))
-       _, RR = nat(nil).div(RR, zz, m)
+       _, RR = nat(nil).div(stk, RR, zz, m)
        if len(RR) < numWords {
                zz = zz.make(numWords)
                copy(zz, RR)
@@ -1280,7 +1322,7 @@ func (z nat) expNNMontgomery(x, y, m nat) nat {
                // The div is not expected to be reached.
                zz = zz.sub(zz, m)
                if zz.cmp(m) >= 0 {
-                       _, zz = nat(nil).div(nil, zz, m)
+                       _, zz = nat(nil).div(stk, nil, zz, m)
                }
        }
 
@@ -1349,7 +1391,8 @@ func (z nat) setBytes(buf []byte) nat {
 }
 
 // sqrt sets z = ⌊√x⌋
-func (z nat) sqrt(x nat) nat {
+// The caller may pass stk == nil to request that sqrt obtain and release one itself.
+func (z nat) sqrt(stk *stack, x nat) nat {
        if x.cmp(natOne) <= 0 {
                return z.set(x)
        }
@@ -1357,6 +1400,11 @@ func (z nat) sqrt(x nat) nat {
                z = nil
        }
 
+       if stk == nil {
+               stk = getStack()
+               defer stk.free()
+       }
+
        // Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
        // See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt).
        // https://members.loria.fr/PZimmermann/mca/pub226.html
@@ -1367,7 +1415,7 @@ func (z nat) sqrt(x nat) nat {
        z1 = z1.setUint64(1)
        z1 = z1.shl(z1, uint(x.bitLen()+1)/2) // must be ≥ √x
        for n := 0; ; n++ {
-               z2, _ = z2.div(nil, x, z1)
+               z2, _ = z2.div(stk, nil, x, z1)
                z2 = z2.add(z2, z1)
                z2 = z2.shr(z2, 1)
                if z2.cmp(z1) >= 0 {
index 46231f79769a1810950f8f1624f3736748e08673..1811dccfe33d6188ca2b22b17fd812bacb301659 100644 (file)
@@ -42,6 +42,7 @@ func TestCmp(t *testing.T) {
 }
 
 type funNN func(z, x, y nat) nat
+type funSNN func(z nat, stk *stack, x, y nat) nat
 type argNN struct {
        z, x, y nat
 }
@@ -112,6 +113,15 @@ func testFunNN(t *testing.T, msg string, f funNN, a argNN) {
        }
 }
 
+func testFunSNN(t *testing.T, msg string, f funSNN, a argNN) {
+       stk := getStack()
+       defer stk.free()
+       z := f(nil, stk, a.x, a.y)
+       if z.cmp(a.z) != 0 {
+               t.Errorf("%s%+v\n\tgot z = %v; want %v", msg, a, z, a.z)
+       }
+}
+
 func TestFunNN(t *testing.T) {
        for _, a := range sumNN {
                arg := a
@@ -129,10 +139,10 @@ func TestFunNN(t *testing.T) {
 
        for _, a := range prodNN {
                arg := a
-               testFunNN(t, "mul", nat.mul, arg)
+               testFunSNN(t, "mul", nat.mul, arg)
 
                arg = argNN{a.z, a.y, a.x}
-               testFunNN(t, "mul symmetric", nat.mul, arg)
+               testFunSNN(t, "mul symmetric", nat.mul, arg)
        }
 }
 
@@ -163,8 +173,11 @@ var mulRangesN = []struct {
 }
 
 func TestMulRangeN(t *testing.T) {
+       stk := getStack()
+       defer stk.free()
+
        for i, r := range mulRangesN {
-               prod := string(nat(nil).mulRange(r.a, r.b).utoa(10))
+               prod := string(nat(nil).mulRange(stk, r.a, r.b).utoa(10))
                if prod != r.prod {
                        t.Errorf("#%d: got %s; want %s", i, prod, r.prod)
                }
@@ -185,11 +198,14 @@ func allocBytes(f func()) uint64 {
 // does not cause deep recursion and in turn allocate too much memory.
 // Test case for issue 3807.
 func TestMulUnbalanced(t *testing.T) {
+       stk := getStack()
+       defer stk.free()
+
        defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
        x := rndNat(50000)
        y := rndNat(40)
        allocSize := allocBytes(func() {
-               nat(nil).mul(x, y)
+               nat(nil).mul(stk, x, y)
        })
        inputSize := uint64(len(x)+len(y)) * _S
        if ratio := allocSize / uint64(inputSize); ratio > 10 {
@@ -214,12 +230,15 @@ func rndNat1(n int) nat {
 }
 
 func BenchmarkMul(b *testing.B) {
+       stk := getStack()
+       defer stk.free()
+
        mulx := rndNat(1e4)
        muly := rndNat(1e4)
        b.ResetTimer()
        for i := 0; i < b.N; i++ {
                var z nat
-               z.mul(mulx, muly)
+               z.mul(stk, mulx, muly)
        }
 }
 
@@ -230,7 +249,7 @@ func benchmarkNatMul(b *testing.B, nwords int) {
        b.ResetTimer()
        b.ReportAllocs()
        for i := 0; i < b.N; i++ {
-               z.mul(x, y)
+               z.mul(nil, x, y)
        }
 }
 
@@ -444,6 +463,9 @@ var montgomeryTests = []struct {
 }
 
 func TestMontgomery(t *testing.T) {
+       stk := getStack()
+       defer stk.free()
+
        one := NewInt(1)
        _B := new(Int).Lsh(one, _W)
        for i, test := range montgomeryTests {
@@ -458,11 +480,11 @@ func TestMontgomery(t *testing.T) {
                }
 
                if x.cmp(m) > 0 {
-                       _, r := nat(nil).div(nil, x, m)
+                       _, r := nat(nil).div(stk, nil, x, m)
                        t.Errorf("#%d: x > m (0x%s > 0x%s; use 0x%s)", i, x.utoa(16), m.utoa(16), r.utoa(16))
                }
                if y.cmp(m) > 0 {
-                       _, r := nat(nil).div(nil, x, m)
+                       _, r := nat(nil).div(stk, nil, x, m)
                        t.Errorf("#%d: y > m (0x%s > 0x%s; use 0x%s)", i, y.utoa(16), m.utoa(16), r.utoa(16))
                }
 
@@ -538,6 +560,9 @@ var expNNTests = []struct {
 }
 
 func TestExpNN(t *testing.T) {
+       stk := getStack()
+       defer stk.free()
+
        for i, test := range expNNTests {
                x := natFromString(test.x)
                y := natFromString(test.y)
@@ -548,7 +573,7 @@ func TestExpNN(t *testing.T) {
                        m = natFromString(test.m)
                }
 
-               z := nat(nil).expNN(x, y, m, false)
+               z := nat(nil).expNN(stk, x, y, m, false)
                if z.cmp(out) != 0 {
                        t.Errorf("#%d got %s want %s", i, z.utoa(10), out.utoa(10))
                }
@@ -572,6 +597,9 @@ func FuzzExpMont(f *testing.F) {
 }
 
 func BenchmarkExp3Power(b *testing.B) {
+       stk := getStack()
+       defer stk.free()
+
        const x = 3
        for _, y := range []Word{
                0x10, 0x40, 0x100, 0x400, 0x1000, 0x4000, 0x10000, 0x40000, 0x100000, 0x400000,
@@ -579,7 +607,7 @@ func BenchmarkExp3Power(b *testing.B) {
                b.Run(fmt.Sprintf("%#x", y), func(b *testing.B) {
                        var z nat
                        for i := 0; i < b.N; i++ {
-                               z.expWW(x, y)
+                               z.expWW(stk, x, y)
                        }
                })
        }
@@ -712,10 +740,13 @@ func TestSticky(t *testing.T) {
 }
 
 func testSqr(t *testing.T, x nat) {
+       stk := getStack()
+       defer stk.free()
+
        got := make(nat, 2*len(x))
        want := make(nat, 2*len(x))
-       got = got.sqr(x)
-       want = want.mul(x, x)
+       got = got.sqr(stk, x)
+       want = want.mul(stk, x, x)
        if got.cmp(want) != 0 {
                t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
        }
@@ -741,7 +772,7 @@ func benchmarkNatSqr(b *testing.B, nwords int) {
        b.ResetTimer()
        b.ReportAllocs()
        for i := 0; i < b.N; i++ {
-               z.sqr(x)
+               z.sqr(nil, x)
        }
 }
 
@@ -830,6 +861,9 @@ func BenchmarkNatSetBytes(b *testing.B) {
 }
 
 func TestNatDiv(t *testing.T) {
+       stk := getStack()
+       defer stk.free()
+
        sizes := []int{
                1, 2, 5, 8, 15, 25, 40, 65, 100,
                200, 500, 800, 1500, 2500, 4000, 6500, 10000,
@@ -849,11 +883,11 @@ func TestNatDiv(t *testing.T) {
                                c = c.norm()
                        }
                        // compute x = a*b+c
-                       x := nat(nil).mul(a, b)
+                       x := nat(nil).mul(stk, a, b)
                        x = x.add(x, c)
 
                        var q, r nat
-                       q, r = q.div(r, x, b)
+                       q, r = q.div(stk, r, x, b)
                        if q.cmp(a) != 0 {
                                t.Fatalf("wrong quotient: got %s; want %s for %s/%s", q.utoa(10), a.utoa(10), x.utoa(10), b.utoa(10))
                        }
@@ -868,6 +902,9 @@ func TestNatDiv(t *testing.T) {
 // the inaccurate estimate of the first word's quotient
 // happens at the very beginning of the loop.
 func TestIssue37499(t *testing.T) {
+       stk := getStack()
+       defer stk.free()
+
        // Choose u and v such that v is slightly larger than u >> N.
        // This tricks divBasic into choosing 1 as the first word
        // of the quotient. This works in both 32-bit and 64-bit settings.
@@ -875,7 +912,7 @@ func TestIssue37499(t *testing.T) {
        v := natFromString("0x2b6c385a05be027f5c22005b63c42a1165b79ff510e1706c")
 
        q := nat(nil).make(8)
-       q.divBasic(u, v)
+       q.divBasic(stk, u, v)
        q = q.norm()
        if s := string(q.utoa(16)); s != "fffffffffffffffffffffffffffffffffffffffffffffffb" {
                t.Fatalf("incorrect quotient: %s", s)
@@ -886,8 +923,11 @@ func TestIssue37499(t *testing.T) {
 // where the first division loop is never entered, and correcting
 // the remainder takes exactly two iterations in the final loop.
 func TestIssue42552(t *testing.T) {
+       stk := getStack()
+       defer stk.free()
+
        u := natFromString("0xc23b166884c3869092a520eceedeced2b00847bd256c9cf3b2c5e2227c15bd5e6ee7ef8a2f49236ad0eedf2c8a3b453cf6e0706f64285c526b372c4b1321245519d430540804a50b7ca8b6f1b34a2ec05cdbc24de7599af112d3e3c8db347e8799fe70f16e43c6566ba3aeb169463a3ecc486172deb2d9b80a3699c776e44fef20036bd946f1b4d054dd88a2c1aeb986199b0b2b7e58c42288824b74934d112fe1fc06e06b4d99fe1c5e725946b23210521e209cd507cce90b5f39a523f27e861f9e232aee50c3f585208b4573dcc0b897b6177f2ba20254fd5c50a033e849dee1b3a93bd2dc44ba8ca836cab2c2ae50e50b126284524fa0187af28628ff0face68d87709200329db1392852c8b8963fbe3d05fb1efe19f0ed5ca9fadc2f96f82187c24bb2512b2e85a66333a7e176605695211e1c8e0b9b9e82813e50654964945b1e1e66a90840396c7d10e23e47f364d2d3f660fa54598e18d1ca2ea4fe4f35a40a11f69f201c80b48eaee3e2e9b0eda63decf92bec08a70f731587d4ed0f218d5929285c8b2ccbc497e20db42de73885191fa453350335990184d8df805072f958d5354debda38f5421effaaafd6cb9b721ace74be0892d77679f62a4a126697cd35797f6858193da4ba1770c06aea2e5c59ec04b8ea26749e61b72ecdde403f3bc7e5e546cd799578cc939fa676dfd5e648576d4a06cbadb028adc2c0b461f145b2321f42e5e0f3b4fb898ecd461df07a6f5154067787bf74b5cc5c03704a1ce47494961931f0263b0aac32505102595957531a2de69dd71aac51f8a49902f81f21283dbe8e21e01e5d82517868826f86acf338d935aa6b4d5a25c8d540389b277dd9d64569d68baf0f71bd03dba45b92a7fc052601d1bd011a2fc6790a23f97c6fa5caeea040ab86841f268d39ce4f7caf01069df78bba098e04366492f0c2ac24f1bf16828752765fa523c9a4d42b71109d123e6be8c7b1ab3ccf8ea03404075fe1a9596f1bba1d267f9a7879ceece514818316c9c0583469d2367831fc42b517ea028a28df7c18d783d16ea2436cee2b15d52db68b5dfdee6b4d26f0905f9b030c911a04d078923a4136afea96eed6874462a482917353264cc9bee298f167ac65a6db4e4eda88044b39cc0b33183843eaa946564a00c3a0ab661f2c915e70bf0bb65bfbb6fa2eea20aed16bf2c1a1d00ec55fb4ff2f76b8e462ea70c19efa579c9ee78194b86708fdae66a9ce6e2cf3d366037798cfb50277ba6d2fd4866361022fd788ab7735b40b8b61d55e32243e06719e53992e9ac16c9c4b6e6933635c3c47c8f7e73e17dd54d0dd8aeba5d76de46894e7b3f9d3ec25ad78ee82297ba69905ea0fa094b8667faa2b8885e2187b3da80268aa1164761d7b0d6de206b676777348152b8ae1d4afed753bc63c739a5ca8ce7afb2b241a226bd9e502baba391b5b13f5054f070b65a9cf3a67063bfaa803ba390732cd03888f664023f888741d04d564e0b5674b0a183ace81452001b3fbb4214c77d42ca75376742c471e58f67307726d56a1032bd236610cbcbcd03d0d7a452900136897dc55bb3ce959d10d4e6a10fb635006bd8c41cd9ded2d3dfdd8f2e229590324a7370cb2124210b2330f4c56155caa09a2564932ceded8d92c79664dcdeb87faad7d3da006cc2ea267ee3df41e9677789cc5a8cc3b83add6491561b3047919e0648b1b2e97d7ad6f6c2aa80cab8e9ae10e1f75b1fdd0246151af709d259a6a0ed0b26bd711024965ecad7c41387de45443defce53f66612948694a6032279131c257119ed876a8e805dfb49576ef5c563574115ee87050d92d191bc761ef51d966918e2ef925639400069e3959d8fe19f36136e947ff430bf74e71da0aa5923b00000000")
        v := natFromString("0x838332321d443a3d30373d47301d47073847473a383d3030f25b3d3d3e00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002e00000000000000000041603038331c3d32f5303441e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e01c0a5459bfc7b9be9fcbb9d2383840464319434707303030f43a32f53034411c0a5459413820878787878787878787878787878787878787878787878787878787878787878787870630303a3a30334036605b923a6101f83638413943413960204337602043323801526040523241846038414143015238604060328452413841413638523c0240384141364036605b923a6101f83638413943413960204334602043323801526040523241846038414143015238604060328452413841413638523c02403841413638433030f25a8b83838383838383838383838383838383837d838383ffffffffffffffff838383838383838383000000000000000000030000007d26e27c7c8b83838383838383838383838383838383837d838383ffffffffffffffff83838383838383838383838383838383838383838383435960f535073030f3343200000000000000011881301938343030fa398383300000002300000000000000000000f11af4600c845252904141364138383c60406032414443095238010241414303364443434132305b595a15434160b042385341ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff47476043410536613603593a6005411c437405fcfcfcfcfcfcfc0000000000005a3b075815054359000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
        q := nat(nil).make(16)
-       q.div(q, u, v)
+       q.div(stk, q, u, v)
 }
index ce94f2cf72e9a1b3fb385c01ac4dd10694f3fb8c..8a47ec9f9c5aa0fb08bba28a74fc72c63c57c8c8 100644 (file)
@@ -321,17 +321,20 @@ func (x nat) itoa(neg bool, base int) []byte {
                }
 
        } else {
+               stk := getStack()
+               defer stk.free()
+
                bb, ndigits := maxPow(b)
 
                // construct table of successive squares of bb*leafSize to use in subdivisions
                // result (table != nil) <=> (len(x) > leafSize > 0)
-               table := divisors(len(x), b, ndigits, bb)
+               table := divisors(stk, len(x), b, ndigits, bb)
 
                // preserve x, create local copy for use by convertWords
                q := nat(nil).set(x)
 
                // convert q to string s in base b
-               q.convertWords(s, b, ndigits, bb, table)
+               q.convertWords(stk, s, b, ndigits, bb, table)
 
                // strip leading zeros
                // (x != 0; thus s must contain at least one non-zero digit
@@ -365,7 +368,7 @@ func (x nat) itoa(neg bool, base int) []byte {
 // range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and
 // ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for
 // specific hardware.
-func (q nat) convertWords(s []byte, b Word, ndigits int, bb Word, table []divisor) {
+func (q nat) convertWords(stk *stack, s []byte, b Word, ndigits int, bb Word, table []divisor) {
        // split larger blocks recursively
        if table != nil {
                // len(q) > leafSize > 0
@@ -386,12 +389,12 @@ func (q nat) convertWords(s []byte, b Word, ndigits int, bb Word, table []diviso
                        }
 
                        // split q into the two digit number (q'*bbb + r) to form independent subblocks
-                       q, r = q.div(r, q, table[index].bbb)
+                       q, r = q.div(stk, r, q, table[index].bbb)
 
                        // convert subblocks and collect results in s[:h] and s[h:]
                        h := len(s) - table[index].ndigits
-                       r.convertWords(s[h:], b, ndigits, bb, table[0:index])
-                       s = s[:h] // == q.convertWords(s, b, ndigits, bb, table[0:index+1])
+                       r.convertWords(stk, s[h:], b, ndigits, bb, table[0:index])
+                       s = s[:h] // == q.convertWords(stk, s, b, ndigits, bb, table[0:index+1])
                }
        }
 
@@ -451,12 +454,12 @@ var cacheBase10 struct {
 }
 
 // expWW computes x**y
-func (z nat) expWW(x, y Word) nat {
-       return z.expNN(nat(nil).setWord(x), nat(nil).setWord(y), nil, false)
+func (z nat) expWW(stk *stack, x, y Word) nat {
+       return z.expNN(stk, nat(nil).setWord(x), nat(nil).setWord(y), nil, false)
 }
 
 // construct table of powers of bb*leafSize to use in subdivisions.
-func divisors(m int, b Word, ndigits int, bb Word) []divisor {
+func divisors(stk *stack, m int, b Word, ndigits int, bb Word) []divisor {
        // only compute table when recursive conversion is enabled and x is large
        if leafSize == 0 || m <= leafSize {
                return nil
@@ -484,10 +487,10 @@ func divisors(m int, b Word, ndigits int, bb Word) []divisor {
                for i := 0; i < k; i++ {
                        if table[i].ndigits == 0 {
                                if i == 0 {
-                                       table[0].bbb = nat(nil).expWW(bb, Word(leafSize))
+                                       table[0].bbb = nat(nil).expWW(stk, bb, Word(leafSize))
                                        table[0].ndigits = ndigits * leafSize
                                } else {
-                                       table[i].bbb = nat(nil).sqr(table[i-1].bbb)
+                                       table[i].bbb = nat(nil).sqr(stk, table[i-1].bbb)
                                        table[i].ndigits = 2 * table[i-1].ndigits
                                }
 
index d39027210851529a124f9dcc1b827124243693b8..66300e412b736ae79707f15cc744b0ae113d5e58 100644 (file)
@@ -350,6 +350,9 @@ func BenchmarkStringPiParallel(b *testing.B) {
 }
 
 func BenchmarkScan(b *testing.B) {
+       stk := getStack()
+       defer stk.free()
+
        const x = 10
        for _, base := range []int{2, 8, 10, 16} {
                for _, y := range []Word{10, 100, 1000, 10000, 100000} {
@@ -359,7 +362,7 @@ func BenchmarkScan(b *testing.B) {
                        b.Run(fmt.Sprintf("%d/Base%d", y, base), func(b *testing.B) {
                                b.StopTimer()
                                var z nat
-                               z = z.expWW(x, y)
+                               z = z.expWW(stk, x, y)
 
                                s := z.utoa(base)
                                if t := itoa(z, base); !bytes.Equal(s, t) {
@@ -376,6 +379,9 @@ func BenchmarkScan(b *testing.B) {
 }
 
 func BenchmarkString(b *testing.B) {
+       stk := getStack()
+       defer stk.free()
+
        const x = 10
        for _, base := range []int{2, 8, 10, 16} {
                for _, y := range []Word{10, 100, 1000, 10000, 100000} {
@@ -385,7 +391,7 @@ func BenchmarkString(b *testing.B) {
                        b.Run(fmt.Sprintf("%d/Base%d", y, base), func(b *testing.B) {
                                b.StopTimer()
                                var z nat
-                               z = z.expWW(x, y)
+                               z = z.expWW(stk, x, y)
                                z.utoa(base) // warm divisor cache
                                b.StartTimer()
 
@@ -416,9 +422,11 @@ func LeafSizeHelper(b *testing.B, base, size int) {
 
        for d := 1; d <= 10000; d *= 10 {
                b.StopTimer()
+               stk := getStack()
                var z nat
-               z = z.expWW(Word(base), Word(d)) // build target number
-               _ = z.utoa(base)                 // warm divisor cache
+               z = z.expWW(stk, Word(base), Word(d)) // build target number
+               _ = z.utoa(base)                      // warm divisor cache
+               stk.free()
                b.StartTimer()
 
                for i := 0; i < b.N; i++ {
@@ -443,13 +451,16 @@ func resetTable(table []divisor) {
 }
 
 func TestStringPowers(t *testing.T) {
+       stk := getStack()
+       defer stk.free()
+
        var p Word
        for b := 2; b <= 16; b++ {
                for p = 0; p <= 512; p++ {
                        if testing.Short() && p > 10 {
                                break
                        }
-                       x := nat(nil).expWW(Word(b), p)
+                       x := nat(nil).expWW(stk, Word(b), p)
                        xs := x.utoa(b)
                        xs2 := itoa(x, b)
                        if !bytes.Equal(xs, xs2) {
index 2e66e3425c61cf2701fa32ad6d96766e64a40087..b514e2ce217dd30b18becc5dfbe1a7ba8dbebe8c 100644 (file)
@@ -502,30 +502,24 @@ import "math/bits"
 
 // rem returns r such that r = u%v.
 // It uses z as the storage for r.
-func (z nat) rem(u, v nat) (r nat) {
+func (z nat) rem(stk *stack, u, v nat) (r nat) {
        if alias(z, u) {
                z = nil
        }
-       qp := getNat(0)
-       q, r := qp.div(z, u, v)
-       *qp = q
-       putNat(qp)
+       defer stk.restore(stk.save())
+       q := stk.nat(len(u) - (len(v) - 1))
+       _, r = q.div(stk, z, u, v)
        return r
 }
 
 // div returns q, r such that q = ⌊u/v⌋ and r = u%v = u - q·v.
 // It uses z and z2 as the storage for q and r.
-func (z nat) div(z2, u, v nat) (q, r nat) {
+// The caller may pass stk == nil to request that div obtain and release one itself.
+func (z nat) div(stk *stack, z2, u, v nat) (q, r nat) {
        if len(v) == 0 {
                panic("division by zero")
        }
 
-       if u.cmp(v) < 0 {
-               q = z[:0]
-               r = z2.set(u)
-               return
-       }
-
        if len(v) == 1 {
                // Short division: long optimized for a single-word divisor.
                // In that case, the 2-by-1 guess is all we need at each step.
@@ -535,7 +529,18 @@ func (z nat) div(z2, u, v nat) (q, r nat) {
                return
        }
 
-       q, r = z.divLarge(z2, u, v)
+       if u.cmp(v) < 0 {
+               q = z[:0]
+               r = z2.set(u)
+               return
+       }
+
+       if stk == nil {
+               stk = getStack()
+               defer stk.free()
+       }
+
+       q, r = z.divLarge(stk, z2, u, v)
        return
 }
 
@@ -589,7 +594,7 @@ func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) {
 // It uses z and u as the storage for q and r.
 // The caller must ensure that len(vIn) ≥ 2 (use divW otherwise)
 // and that len(uIn) ≥ len(vIn) (the answer is 0, uIn otherwise).
-func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
+func (z nat) divLarge(stk *stack, u, uIn, vIn nat) (q, r nat) {
        n := len(vIn)
        m := len(uIn) - n
 
@@ -597,9 +602,9 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
        // vIn is treated as a read-only input (it may be in use by another
        // goroutine), so we must make a copy.
        // uIn is copied to u.
+       defer stk.restore(stk.save())
        shift := nlz(vIn[n-1])
-       vp := getNat(n)
-       v := *vp
+       v := stk.nat(n)
        shlVU(v, vIn, shift)
        u = u.make(len(uIn) + 1)
        u[len(uIn)] = shlVU(u[:len(uIn)], uIn, shift)
@@ -613,11 +618,10 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
 
        // Use basic or recursive long division depending on size.
        if n < divRecursiveThreshold {
-               q.divBasic(u, v)
+               q.divBasic(stk, u, v)
        } else {
-               q.divRecursive(u, v)
+               q.divRecursive(stk, u, v)
        }
-       putNat(vp)
 
        q = q.norm()
 
@@ -631,12 +635,12 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
 // divBasic implements long division as described above.
 // It overwrites q with ⌊u/v⌋ and overwrites u with the remainder r.
 // q must be large enough to hold ⌊u/v⌋.
-func (q nat) divBasic(u, v nat) {
+func (q nat) divBasic(stk *stack, u, v nat) {
        n := len(v)
        m := len(u) - n
 
-       qhatvp := getNat(n + 1)
-       qhatv := *qhatvp
+       defer stk.restore(stk.save())
+       qhatv := stk.nat(n + 1)
 
        // Set up for divWW below, precomputing reciprocal argument.
        vn1 := v[n-1]
@@ -707,8 +711,6 @@ func (q nat) divBasic(u, v nat) {
                }
                q[j] = qhat
        }
-
-       putNat(qhatvp)
 }
 
 // greaterThan reports whether the two digit numbers x1 x2 > y1 y2.
@@ -727,24 +729,9 @@ const divRecursiveThreshold = 100
 // z must be large enough to hold ⌊u/v⌋.
 // This function is just for allocating and freeing temporaries
 // around divRecursiveStep, the real implementation.
-func (z nat) divRecursive(u, v nat) {
-       // Recursion depth is (much) less than 2 log₂(len(v)).
-       // Allocate a slice of temporaries to be reused across recursion,
-       // plus one extra temporary not live across the recursion.
-       recDepth := 2 * bits.Len(uint(len(v)))
-       tmp := getNat(3 * len(v))
-       temps := make([]*nat, recDepth)
-
+func (z nat) divRecursive(stk *stack, u, v nat) {
        clear(z)
-       z.divRecursiveStep(u, v, 0, tmp, temps)
-
-       // Free temporaries.
-       for _, n := range temps {
-               if n != nil {
-                       putNat(n)
-               }
-       }
-       putNat(tmp)
+       z.divRecursiveStep(stk, u, v, 0)
 }
 
 // divRecursiveStep is the actual implementation of recursive division.
@@ -752,7 +739,7 @@ func (z nat) divRecursive(u, v nat) {
 // z must be large enough to hold ⌊u/v⌋.
 // It uses temps[depth] (allocating if needed) as a temporary live across
 // the recursive call. It also uses tmp, but not live across the recursion.
-func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
+func (z nat) divRecursiveStep(stk *stack, u, v nat, depth int) {
        // u is a subsection of the original and may have leading zeros.
        // TODO(rsc): The v = v.norm() is useless and should be removed.
        // We know (and require) that v's top digit is ≥ B/2.
@@ -766,7 +753,7 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
        // Fall back to basic division if the problem is now small enough.
        n := len(v)
        if n < divRecursiveThreshold {
-               z.divBasic(u, v)
+               z.divBasic(stk, u, v)
                return
        }
 
@@ -785,11 +772,8 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
        B := n / 2
 
        // Allocate a nat for qhat below.
-       if temps[depth] == nil {
-               temps[depth] = getNat(n) // TODO(rsc): Can be just B+1.
-       } else {
-               *temps[depth] = temps[depth].make(B + 1)
-       }
+       defer stk.restore(stk.save())
+       qhat0 := stk.nat(B + 1)
 
        // Compute each wide digit of the quotient.
        //
@@ -816,9 +800,9 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
                uu := u[j-B:]
 
                // Compute the 2-by-1 guess q̂, leaving r̂ in uu[s:B+n].
-               qhat := *temps[depth]
+               qhat := qhat0
                clear(qhat)
-               qhat.divRecursiveStep(uu[s:B+n], v[s:], depth+1, tmp, temps)
+               qhat.divRecursiveStep(stk, uu[s:B+n], v[s:], depth+1)
                qhat = qhat.norm()
 
                // Extend to a 3-by-2 quotient and remainder.
@@ -833,9 +817,10 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
                // q̂·vₙ₋₂ and decrementing q̂ until that product is ≤ u.
                // But we can do the subtraction directly, as in the comment above
                // and in long division, because we know that q̂ is wrong by at most one.
-               qhatv := tmp.make(3 * n)
+               mark := stk.save()
+               qhatv := stk.nat(3 * n)
                clear(qhatv)
-               qhatv = qhatv.mul(qhat, v[:s])
+               qhatv = qhatv.mul(stk, qhat, v[:s])
                for i := 0; i < 2; i++ {
                        e := qhatv.cmp(uu.norm())
                        if e <= 0 {
@@ -857,6 +842,7 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
                }
                addAt(z, qhat, j-B)
                j -= B
+               stk.restore(mark)
        }
 
        // TODO(rsc): Rewrite loop as described above and delete all this code.
@@ -864,13 +850,13 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
        // Now u < (v<<B), compute lower bits in the same way.
        // Choose shift = B-1 again.
        s := B - 1
-       qhat := *temps[depth]
+       qhat := qhat0
        clear(qhat)
-       qhat.divRecursiveStep(u[s:].norm(), v[s:], depth+1, tmp, temps)
+       qhat.divRecursiveStep(stk, u[s:].norm(), v[s:], depth+1)
        qhat = qhat.norm()
-       qhatv := tmp.make(3 * n)
+       qhatv := stk.nat(3 * n)
        clear(qhatv)
-       qhatv = qhatv.mul(qhat, v[:s])
+       qhatv = qhatv.mul(stk, qhat, v[:s])
        // Set the correct remainder as before.
        for i := 0; i < 2; i++ {
                if e := qhatv.cmp(u.norm()); e > 0 {
index 26688bbd64e9f2221d2a64a696ed05d037d4b5e1..bba5a0768566d3c8944d43247258663ec00823b7 100644 (file)
@@ -75,7 +75,9 @@ func (x *Int) ProbablyPrime(n int) bool {
                return false
        }
 
-       return x.abs.probablyPrimeMillerRabin(n+1, true) && x.abs.probablyPrimeLucas()
+       stk := getStack()
+       defer stk.free()
+       return x.abs.probablyPrimeMillerRabin(stk, n+1, true) && x.abs.probablyPrimeLucas(stk)
 }
 
 // probablyPrimeMillerRabin reports whether n passes reps rounds of the
@@ -83,7 +85,7 @@ func (x *Int) ProbablyPrime(n int) bool {
 // If force2 is true, one of the rounds is forced to use base 2.
 // See Handbook of Applied Cryptography, p. 139, Algorithm 4.24.
 // The number n is known to be non-zero.
-func (n nat) probablyPrimeMillerRabin(reps int, force2 bool) bool {
+func (n nat) probablyPrimeMillerRabin(stk *stack, reps int, force2 bool) bool {
        nm1 := nat(nil).sub(n, natOne)
        // determine q, k such that nm1 = q << k
        k := nm1.trailingZeroBits()
@@ -103,13 +105,13 @@ NextRandom:
                        x = x.random(rand, nm3, nm3Len)
                        x = x.add(x, natTwo)
                }
-               y = y.expNN(x, q, n, false)
+               y = y.expNN(stk, x, q, n, false)
                if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
                        continue
                }
                for j := uint(1); j < k; j++ {
-                       y = y.sqr(y)
-                       quotient, y = quotient.div(y, y, n)
+                       y = y.sqr(stk, y)
+                       quotient, y = quotient.div(stk, y, y, n)
                        if y.cmp(nm1) == 0 {
                                continue NextRandom
                        }
@@ -147,7 +149,7 @@ NextRandom:
 //
 // Crandall and Pomerance, Prime Numbers: A Computational Perspective, 2nd ed.
 // Springer, 2005.
-func (n nat) probablyPrimeLucas() bool {
+func (n nat) probablyPrimeLucas(stk *stack) bool {
        // Discard 0, 1.
        if len(n) == 0 || n.cmp(natOne) == 0 {
                return false
@@ -193,8 +195,8 @@ func (n nat) probablyPrimeLucas() bool {
                        // We'll never find (d/n) = -1 if n is a square.
                        // If n is a non-square we expect to find a d in just a few attempts on average.
                        // After 40 attempts, take a moment to check if n is indeed a square.
-                       t1 = t1.sqrt(n)
-                       t1 = t1.sqr(t1)
+                       t1 = t1.sqrt(stk, n)
+                       t1 = t1.sqr(stk, t1)
                        if t1.cmp(n) == 0 {
                                return false
                        }
@@ -254,25 +256,25 @@ func (n nat) probablyPrimeLucas() bool {
                if s.bit(uint(i)) != 0 {
                        // k' = 2k+1
                        // V(k') = V(2k+1) = V(k) V(k+1) - P.
-                       t1 = t1.mul(vk, vk1)
+                       t1 = t1.mul(stk, vk, vk1)
                        t1 = t1.add(t1, n)
                        t1 = t1.sub(t1, natP)
-                       t2, vk = t2.div(vk, t1, n)
+                       t2, vk = t2.div(stk, vk, t1, n)
                        // V(k'+1) = V(2k+2) = V(k+1)² - 2.
-                       t1 = t1.sqr(vk1)
+                       t1 = t1.sqr(stk, vk1)
                        t1 = t1.add(t1, nm2)
-                       t2, vk1 = t2.div(vk1, t1, n)
+                       t2, vk1 = t2.div(stk, vk1, t1, n)
                } else {
                        // k' = 2k
                        // V(k'+1) = V(2k+1) = V(k) V(k+1) - P.
-                       t1 = t1.mul(vk, vk1)
+                       t1 = t1.mul(stk, vk, vk1)
                        t1 = t1.add(t1, n)
                        t1 = t1.sub(t1, natP)
-                       t2, vk1 = t2.div(vk1, t1, n)
+                       t2, vk1 = t2.div(stk, vk1, t1, n)
                        // V(k') = V(2k) = V(k)² - 2
-                       t1 = t1.sqr(vk)
+                       t1 = t1.sqr(stk, vk)
                        t1 = t1.add(t1, nm2)
-                       t2, vk = t2.div(vk, t1, n)
+                       t2, vk = t2.div(stk, vk, t1, n)
                }
        }
 
@@ -285,7 +287,7 @@ func (n nat) probablyPrimeLucas() bool {
                //
                // Since we are checking for U(k) == 0 it suffices to check 2 V(k+1) == P V(k) mod n,
                // or P V(k) - 2 V(k+1) == 0 mod n.
-               t1 := t1.mul(vk, natP)
+               t1 := t1.mul(stk, vk, natP)
                t2 := t2.shl(vk1, 1)
                if t1.cmp(t2) < 0 {
                        t1, t2 = t2, t1
@@ -294,7 +296,7 @@ func (n nat) probablyPrimeLucas() bool {
                t3 := vk1 // steal vk1, no longer needed below
                vk1 = nil
                _ = vk1
-               t2, t3 = t2.div(t3, t1, n)
+               t2, t3 = t2.div(stk, t3, t1, n)
                if len(t3) == 0 {
                        return true
                }
@@ -312,9 +314,9 @@ func (n nat) probablyPrimeLucas() bool {
                }
                // k' = 2k
                // V(k') = V(2k) = V(k)² - 2
-               t1 = t1.sqr(vk)
+               t1 = t1.sqr(stk, vk)
                t1 = t1.sub(t1, natTwo)
-               t2, vk = t2.div(vk, t1, n)
+               t2, vk = t2.div(stk, vk, t1, n)
        }
        return false
 }
index 8596e33a13b86480df855ec7626a7891edb8b76b..2b1995bcb224ec33c9025724a7a003908604b126 100644 (file)
@@ -159,6 +159,9 @@ func TestProbablyPrime(t *testing.T) {
 }
 
 func BenchmarkProbablyPrime(b *testing.B) {
+       stk := getStack()
+       defer stk.free()
+
        p, _ := new(Int).SetString("203956878356401977405765866929034577280193993314348263094772646453283062722701277632936616063144088173312372882677123879538709400158306567338328279154499698366071906766440037074217117805690872792848149112022286332144876183376326512083574821647933992961249917319836219304274280243803104015000563790123", 10)
        for _, n := range []int{0, 1, 5, 10, 20} {
                b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) {
@@ -170,26 +173,32 @@ func BenchmarkProbablyPrime(b *testing.B) {
 
        b.Run("Lucas", func(b *testing.B) {
                for i := 0; i < b.N; i++ {
-                       p.abs.probablyPrimeLucas()
+                       p.abs.probablyPrimeLucas(stk)
                }
        })
        b.Run("MillerRabinBase2", func(b *testing.B) {
                for i := 0; i < b.N; i++ {
-                       p.abs.probablyPrimeMillerRabin(1, true)
+                       p.abs.probablyPrimeMillerRabin(stk, 1, true)
                }
        })
 }
 
 func TestMillerRabinPseudoprimes(t *testing.T) {
+       stk := getStack()
+       defer stk.free()
+
        testPseudoprimes(t, "probablyPrimeMillerRabin",
-               func(n nat) bool { return n.probablyPrimeMillerRabin(1, true) && !n.probablyPrimeLucas() },
+               func(n nat) bool { return n.probablyPrimeMillerRabin(stk, 1, true) && !n.probablyPrimeLucas(stk) },
                // https://oeis.org/A001262
                []int{2047, 3277, 4033, 4681, 8321, 15841, 29341, 42799, 49141, 52633, 65281, 74665, 80581, 85489, 88357, 90751})
 }
 
 func TestLucasPseudoprimes(t *testing.T) {
+       stk := getStack()
+       defer stk.free()
+
        testPseudoprimes(t, "probablyPrimeLucas",
-               func(n nat) bool { return n.probablyPrimeLucas() && !n.probablyPrimeMillerRabin(1, true) },
+               func(n nat) bool { return n.probablyPrimeLucas(stk) && !n.probablyPrimeMillerRabin(stk, 1, true) },
                // https://oeis.org/A217719
                []int{989, 3239, 5777, 10877, 27971, 29681, 30739, 31631, 39059, 72389, 73919, 75077})
 }
index e58433ecea333c5fc1b7fb24f9eabb1f68768f58..ac94056a833f84f71cb1ecb7e5234043c69bd410 100644 (file)
@@ -74,7 +74,7 @@ func (z *Rat) SetFloat64(f float64) *Rat {
 // nearest to the quotient a/b, using round-to-even in
 // halfway cases. It does not mutate its arguments.
 // Preconditions: b is non-zero; a and b have no common factors.
-func quotToFloat32(a, b nat) (f float32, exact bool) {
+func quotToFloat32(stk *stack, a, b nat) (f float32, exact bool) {
        const (
                // float size in bits
                Fsize = 32
@@ -121,7 +121,7 @@ func quotToFloat32(a, b nat) (f float32, exact bool) {
        // extra shift, the low-order bit of q is logically the
        // high-order bit of r.
        var q nat
-       q, r := q.div(a2, a2, b2) // (recycle a2)
+       q, r := q.div(stk, a2, a2, b2) // (recycle a2)
        mantissa := low32(q)
        haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half
 
@@ -172,7 +172,7 @@ func quotToFloat32(a, b nat) (f float32, exact bool) {
 // nearest to the quotient a/b, using round-to-even in
 // halfway cases. It does not mutate its arguments.
 // Preconditions: b is non-zero; a and b have no common factors.
-func quotToFloat64(a, b nat) (f float64, exact bool) {
+func quotToFloat64(stk *stack, a, b nat) (f float64, exact bool) {
        const (
                // float size in bits
                Fsize = 64
@@ -219,7 +219,7 @@ func quotToFloat64(a, b nat) (f float64, exact bool) {
        // extra shift, the low-order bit of q is logically the
        // high-order bit of r.
        var q nat
-       q, r := q.div(a2, a2, b2) // (recycle a2)
+       q, r := q.div(stk, a2, a2, b2) // (recycle a2)
        mantissa := low64(q)
        haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half
 
@@ -275,7 +275,9 @@ func (x *Rat) Float32() (f float32, exact bool) {
        if len(b) == 0 {
                b = natOne
        }
-       f, exact = quotToFloat32(x.a.abs, b)
+       stk := getStack()
+       defer stk.free()
+       f, exact = quotToFloat32(stk, x.a.abs, b)
        if x.a.neg {
                f = -f
        }
@@ -291,7 +293,9 @@ func (x *Rat) Float64() (f float64, exact bool) {
        if len(b) == 0 {
                b = natOne
        }
-       f, exact = quotToFloat64(x.a.abs, b)
+       stk := getStack()
+       defer stk.free()
+       f, exact = quotToFloat64(stk, x.a.abs, b)
        if x.a.neg {
                f = -f
        }
@@ -437,12 +441,14 @@ func (z *Rat) norm() *Rat {
                z.b.abs = z.b.abs.setWord(1)
        default:
                // z is fraction; normalize numerator and denominator
+               stk := getStack()
+               defer stk.free()
                neg := z.a.neg
                z.a.neg = false
                z.b.neg = false
                if f := NewInt(0).lehmerGCD(nil, nil, &z.a, &z.b); f.Cmp(intOne) != 0 {
-                       z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f.abs)
-                       z.b.abs, _ = z.b.abs.div(nil, z.b.abs, f.abs)
+                       z.a.abs, _ = z.a.abs.div(stk, nil, z.a.abs, f.abs)
+                       z.b.abs, _ = z.b.abs.div(stk, nil, z.b.abs, f.abs)
                }
                z.a.neg = neg
        }
@@ -452,7 +458,7 @@ func (z *Rat) norm() *Rat {
 // mulDenom sets z to the denominator product x*y (by taking into
 // account that 0 values for x or y must be interpreted as 1) and
 // returns z.
-func mulDenom(z, x, y nat) nat {
+func mulDenom(stk *stack, z, x, y nat) nat {
        switch {
        case len(x) == 0 && len(y) == 0:
                return z.setWord(1)
@@ -461,17 +467,17 @@ func mulDenom(z, x, y nat) nat {
        case len(y) == 0:
                return z.set(x)
        }
-       return z.mul(x, y)
+       return z.mul(stk, x, y)
 }
 
 // scaleDenom sets z to the product x*f.
 // If f == 0 (zero value of denominator), z is set to (a copy of) x.
-func (z *Int) scaleDenom(x *Int, f nat) {
+func (z *Int) scaleDenom(stk *stack, x *Int, f nat) {
        if len(f) == 0 {
                z.Set(x)
                return
        }
-       z.abs = z.abs.mul(x.abs, f)
+       z.abs = z.abs.mul(stk, x.abs, f)
        z.neg = x.neg
 }
 
@@ -481,58 +487,73 @@ func (z *Int) scaleDenom(x *Int, f nat) {
 //   - +1 if x > y.
 func (x *Rat) Cmp(y *Rat) int {
        var a, b Int
-       a.scaleDenom(&x.a, y.b.abs)
-       b.scaleDenom(&y.a, x.b.abs)
+       stk := getStack()
+       defer stk.free()
+       a.scaleDenom(stk, &x.a, y.b.abs)
+       b.scaleDenom(stk, &y.a, x.b.abs)
        return a.Cmp(&b)
 }
 
 // Add sets z to the sum x+y and returns z.
 func (z *Rat) Add(x, y *Rat) *Rat {
+       stk := getStack()
+       defer stk.free()
+
        var a1, a2 Int
-       a1.scaleDenom(&x.a, y.b.abs)
-       a2.scaleDenom(&y.a, x.b.abs)
+       a1.scaleDenom(stk, &x.a, y.b.abs)
+       a2.scaleDenom(stk, &y.a, x.b.abs)
        z.a.Add(&a1, &a2)
-       z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
+       z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs)
        return z.norm()
 }
 
 // Sub sets z to the difference x-y and returns z.
 func (z *Rat) Sub(x, y *Rat) *Rat {
+       stk := getStack()
+       defer stk.free()
+
        var a1, a2 Int
-       a1.scaleDenom(&x.a, y.b.abs)
-       a2.scaleDenom(&y.a, x.b.abs)
+       a1.scaleDenom(stk, &x.a, y.b.abs)
+       a2.scaleDenom(stk, &y.a, x.b.abs)
        z.a.Sub(&a1, &a2)
-       z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
+       z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs)
        return z.norm()
 }
 
 // Mul sets z to the product x*y and returns z.
 func (z *Rat) Mul(x, y *Rat) *Rat {
+       stk := getStack()
+       defer stk.free()
+
        if x == y {
                // a squared Rat is positive and can't be reduced (no need to call norm())
                z.a.neg = false
-               z.a.abs = z.a.abs.sqr(x.a.abs)
+               z.a.abs = z.a.abs.sqr(stk, x.a.abs)
                if len(x.b.abs) == 0 {
                        z.b.abs = z.b.abs.setWord(1)
                } else {
-                       z.b.abs = z.b.abs.sqr(x.b.abs)
+                       z.b.abs = z.b.abs.sqr(stk, x.b.abs)
                }
                return z
        }
-       z.a.Mul(&x.a, &y.a)
-       z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
+
+       z.a.mul(stk, &x.a, &y.a)
+       z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs)
        return z.norm()
 }
 
 // Quo sets z to the quotient x/y and returns z.
 // If y == 0, Quo panics.
 func (z *Rat) Quo(x, y *Rat) *Rat {
+       stk := getStack()
+       defer stk.free()
+
        if len(y.a.abs) == 0 {
                panic("division by zero")
        }
        var a, b Int
-       a.scaleDenom(&x.a, y.b.abs)
-       b.scaleDenom(&y.a, x.b.abs)
+       a.scaleDenom(stk, &x.a, y.b.abs)
+       b.scaleDenom(stk, &y.a, x.b.abs)
        z.a.abs = a.abs
        z.b.abs = b.abs
        z.a.neg = a.neg != b.neg
index 12f9888c3701ce446a99a6aa1c9b425d8966e06d..84602ff45543e01074c1d07ecafc9fdfef5db243 100644 (file)
@@ -163,6 +163,9 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
        }
        // exp consumed - not needed anymore
 
+       stk := getStack()
+       defer stk.free()
+
        // apply exp5 contributions
        // (start with exp5 so the numbers to multiply are smaller)
        if exp5 != 0 {
@@ -178,9 +181,9 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
                if n > 1e6 {
                        return nil, false // avoid excessively large exponents
                }
-               pow5 := z.b.abs.expNN(natFive, nat(nil).setWord(Word(n)), nil, false) // use underlying array of z.b.abs
+               pow5 := z.b.abs.expNN(stk, natFive, nat(nil).setWord(Word(n)), nil, false) // use underlying array of z.b.abs
                if exp5 > 0 {
-                       z.a.abs = z.a.abs.mul(z.a.abs, pow5)
+                       z.a.abs = z.a.abs.mul(stk, z.a.abs, pow5)
                        z.b.abs = z.b.abs.setWord(1)
                } else {
                        z.b.abs = pow5
@@ -343,15 +346,17 @@ func (x *Rat) FloatString(prec int) string {
        }
        // x.b.abs != 0
 
-       q, r := nat(nil).div(nat(nil), x.a.abs, x.b.abs)
+       stk := getStack()
+       defer stk.free()
+       q, r := nat(nil).div(stk, nat(nil), x.a.abs, x.b.abs)
 
        p := natOne
        if prec > 0 {
-               p = nat(nil).expNN(natTen, nat(nil).setUint64(uint64(prec)), nil, false)
+               p = nat(nil).expNN(stk, natTen, nat(nil).setUint64(uint64(prec)), nil, false)
        }
 
-       r = r.mul(r, p)
-       r, r2 := r.div(nat(nil), r, x.b.abs)
+       r = r.mul(stk, r, p)
+       r, r2 := r.div(stk, nat(nil), r, x.b.abs)
 
        // see if we need to round up
        r2 = r2.add(r2, r2)
@@ -398,6 +403,9 @@ func (x *Rat) FloatString(prec int) string {
 //     1/4    2    true     0.25
 //     1/6    1    false    0.2     (0.166... rounded)
 func (x *Rat) FloatPrec() (n int, exact bool) {
+       stk := getStack()
+       defer stk.free()
+
        // Determine q and largest p2, p5 such that d = q·2^p2·5^p5.
        // The results n, exact are:
        //
@@ -425,11 +433,11 @@ func (x *Rat) FloatPrec() (n int, exact bool) {
        f := nat{1220703125} // == 5^fp (must fit into a uint32 Word)
        var t, r nat         // temporaries
        for {
-               if _, r = t.div(r, q, f); len(r) != 0 {
+               if _, r = t.div(stk, r, q, f); len(r) != 0 {
                        break // f doesn't divide q evenly
                }
                tab = append(tab, f)
-               f = nat(nil).sqr(f) // nat(nil) to ensure a new f for each table entry
+               f = nat(nil).sqr(stk, f) // nat(nil) to ensure a new f for each table entry
        }
 
        // Factor q using the table entries, if any.
@@ -441,7 +449,7 @@ func (x *Rat) FloatPrec() (n int, exact bool) {
        // The same reasoning applies to the subsequent factors.
        var p5 uint
        for i := len(tab) - 1; i >= 0; i-- {
-               if t, r = t.div(r, q, tab[i]); len(r) == 0 {
+               if t, r = t.div(stk, r, q, tab[i]); len(r) == 0 {
                        p5 += fp * (1 << i) // tab[i] == 5^(fp·2^i)
                        q = q.set(t)
                }
@@ -449,7 +457,7 @@ func (x *Rat) FloatPrec() (n int, exact bool) {
 
        // If fp != 1, we may still have multiples of 5 left.
        for {
-               if t, r = t.div(r, q, natFive); len(r) != 0 {
+               if t, r = t.div(stk, r, q, natFive); len(r) != 0 {
                        break
                }
                p5++