]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: simplify, speed up Karatsuba multiplication
authorRuss Cox <rsc@golang.org>
Fri, 17 Jan 2025 22:26:59 +0000 (17:26 -0500)
committerGopher Robot <gobot@golang.org>
Wed, 12 Mar 2025 12:41:17 +0000 (05:41 -0700)
The old Karatsuba implementation only operated on lengths that are
a power of two times a number smaller than karatsubaThreshold.
For example, when karatsubaThreshold = 40, multiplying a pair
of 99-word numbers runs karatsuba on the low 96 (= 39<<2) words
and then has to fix up the answer to include the high 3 words of each.

I suspect this requirement was needed to make the analysis of
how many temporary words to reserve easier, back when the
answer was 3*n and depended on exactly halving the size at
each Karatsuba step.

Now that we have the more flexible temporary allocation stack,
we can change Karatsuba to accept operands of odd length.
Doing so avoids most of the fixup that the old approach required.
For example, multiplying a pair of 99-word numbers runs
karatsuba on all 99 words now.

This is simpler and about the same speed or, for large cases, faster.

goos: linux
goarch: amd64
pkg: math/big
cpu: Intel(R) Xeon(R) CPU @ 3.10GHz
                            │     old     │                 new                 │
                            │   sec/op    │   sec/op     vs base                │
GCD10x10/WithoutXY-16         99.62n ± 3%   99.10n ± 3%        ~ (p=0.009 n=15)
GCD10x10/WithXY-16            243.4n ± 1%   245.2n ± 1%        ~ (p=0.009 n=15)
GCD100x100/WithoutXY-16       921.9n ± 1%   919.2n ± 1%        ~ (p=0.076 n=15)
GCD100x100/WithXY-16          1.527µ ± 1%   1.526µ ± 0%        ~ (p=0.813 n=15)
GCD1000x1000/WithoutXY-16     9.704µ ± 1%   9.696µ ± 0%        ~ (p=0.532 n=15)
GCD1000x1000/WithXY-16        14.03µ ± 1%   13.96µ ± 0%        ~ (p=0.014 n=15)
GCD10000x10000/WithoutXY-16   206.5µ ± 2%   206.5µ ± 0%        ~ (p=0.967 n=15)
GCD10000x10000/WithXY-16      398.0µ ± 1%   397.4µ ± 0%        ~ (p=0.683 n=15)
Div/20/10-16                  22.22n ± 0%   22.23n ± 0%        ~ (p=0.105 n=15)
Div/40/20-16                  22.23n ± 0%   22.23n ± 0%        ~ (p=0.307 n=15)
Div/100/50-16                 55.47n ± 0%   55.47n ± 0%        ~ (p=0.573 n=15)
Div/200/100-16                174.9n ± 1%   174.6n ± 1%        ~ (p=0.814 n=15)
Div/400/200-16                209.5n ± 1%   210.5n ± 1%        ~ (p=0.454 n=15)
Div/1000/500-16               379.9n ± 0%   383.5n ± 2%        ~ (p=0.123 n=15)
Div/2000/1000-16              780.1n ± 0%   784.6n ± 1%   +0.58% (p=0.000 n=15)
Div/20000/10000-16            25.22µ ± 1%   25.15µ ± 0%        ~ (p=0.213 n=15)
Div/200000/100000-16          921.8µ ± 1%   926.1µ ± 0%        ~ (p=0.009 n=15)
Div/2000000/1000000-16        37.91m ± 0%   35.63m ± 0%   -6.02% (p=0.000 n=15)
Div/20000000/10000000-16       1.378 ± 0%    1.336 ± 0%   -3.03% (p=0.000 n=15)
NatMul/10-16                  166.8n ± 4%   168.9n ± 3%        ~ (p=0.008 n=15)
NatMul/100-16                 5.519µ ± 2%   5.548µ ± 4%        ~ (p=0.032 n=15)
NatMul/1000-16                230.4µ ± 1%   220.2µ ± 1%   -4.43% (p=0.000 n=15)
NatMul/10000-16               8.569m ± 1%   8.640m ± 1%        ~ (p=0.005 n=15)
NatMul/100000-16              376.5m ± 1%   334.1m ± 0%  -11.26% (p=0.000 n=15)
NatSqr/1-16                   27.85n ± 5%   28.60n ± 2%        ~ (p=0.123 n=15)
NatSqr/2-16                   47.99n ± 2%   48.84n ± 1%        ~ (p=0.008 n=15)
NatSqr/3-16                   59.41n ± 2%   60.87n ± 2%   +2.46% (p=0.001 n=15)
NatSqr/5-16                   87.27n ± 2%   89.31n ± 3%        ~ (p=0.087 n=15)
NatSqr/8-16                   124.6n ± 3%   128.9n ± 3%        ~ (p=0.006 n=15)
NatSqr/10-16                  166.3n ± 3%   172.7n ± 3%        ~ (p=0.002 n=15)
NatSqr/20-16                  385.2n ± 2%   394.7n ± 3%        ~ (p=0.036 n=15)
NatSqr/30-16                  622.7n ± 3%   642.9n ± 3%        ~ (p=0.032 n=15)
NatSqr/50-16                  1.274µ ± 3%   1.323µ ± 4%        ~ (p=0.003 n=15)
NatSqr/80-16                  2.606µ ± 4%   2.714µ ± 4%        ~ (p=0.044 n=15)
NatSqr/100-16                 3.731µ ± 4%   3.871µ ± 4%        ~ (p=0.038 n=15)
NatSqr/200-16                 12.99µ ± 2%   13.09µ ± 3%        ~ (p=0.838 n=15)
NatSqr/300-16                 22.87µ ± 2%   23.25µ ± 2%        ~ (p=0.285 n=15)
NatSqr/500-16                 58.43µ ± 1%   58.25µ ± 2%        ~ (p=0.345 n=15)
NatSqr/800-16                 115.3µ ± 3%   116.2µ ± 3%        ~ (p=0.126 n=15)
NatSqr/1000-16                173.9µ ± 1%   174.3µ ± 1%        ~ (p=0.935 n=15)
NatSqr/10000-16               6.133m ± 2%   6.034m ± 1%   -1.62% (p=0.000 n=15)
NatSqr/100000-16              253.8m ± 1%   241.5m ± 0%   -4.87% (p=0.000 n=15)
geomean                       7.745µ        7.760µ        +0.19%

goos: linux
goarch: amd64
pkg: math/big
cpu: Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz
                            │     old     │                 new                  │
                            │   sec/op    │    sec/op     vs base                │
GCD10x10/WithoutXY-88         62.17n ± 4%   61.44n ±  0%   -1.17% (p=0.000 n=15)
GCD10x10/WithXY-88            173.4n ± 2%   172.4n ±  4%        ~ (p=0.615 n=15)
GCD100x100/WithoutXY-88       584.0n ± 1%   582.9n ±  0%        ~ (p=0.009 n=15)
GCD100x100/WithXY-88          1.098µ ± 1%   1.091µ ±  2%        ~ (p=0.002 n=15)
GCD1000x1000/WithoutXY-88     6.055µ ± 0%   6.049µ ±  0%        ~ (p=0.007 n=15)
GCD1000x1000/WithXY-88        9.430µ ± 0%   9.417µ ±  1%        ~ (p=0.123 n=15)
GCD10000x10000/WithoutXY-88   153.4µ ± 2%   149.0µ ±  2%   -2.85% (p=0.000 n=15)
GCD10000x10000/WithXY-88      350.6µ ± 3%   349.0µ ±  2%        ~ (p=0.126 n=15)
Div/20/10-88                  13.12n ± 0%   13.12n ±  1%    0.00% (p=0.042 n=15)
Div/40/20-88                  13.12n ± 0%   13.13n ±  0%        ~ (p=0.004 n=15)
Div/100/50-88                 25.49n ± 0%   25.49n ±  0%        ~ (p=0.452 n=15)
Div/200/100-88                115.7n ± 2%   113.8n ±  2%        ~ (p=0.212 n=15)
Div/400/200-88                135.0n ± 1%   136.1n ±  1%        ~ (p=0.005 n=15)
Div/1000/500-88               257.5n ± 1%   259.9n ±  1%        ~ (p=0.004 n=15)
Div/2000/1000-88              567.5n ± 1%   572.4n ±  2%        ~ (p=0.616 n=15)
Div/20000/10000-88            25.65µ ± 0%   25.77µ ±  1%        ~ (p=0.032 n=15)
Div/200000/100000-88          777.4µ ± 1%   754.3µ ±  1%   -2.97% (p=0.000 n=15)
Div/2000000/1000000-88        33.66m ± 0%   31.37m ±  0%   -6.81% (p=0.000 n=15)
Div/20000000/10000000-88       1.320 ± 0%    1.266 ±  0%   -4.04% (p=0.000 n=15)
NatMul/10-88                  151.9n ± 7%   143.3n ±  7%        ~ (p=0.878 n=15)
NatMul/100-88                 4.418µ ± 2%   4.337µ ±  3%        ~ (p=0.512 n=15)
NatMul/1000-88                206.8µ ± 1%   189.8µ ±  1%   -8.25% (p=0.000 n=15)
NatMul/10000-88               8.531m ± 1%   8.095m ±  0%   -5.12% (p=0.000 n=15)
NatMul/100000-88              298.9m ± 0%   260.5m ±  1%  -12.85% (p=0.000 n=15)
NatSqr/1-88                   27.55n ± 6%   28.25n ±  7%        ~ (p=0.024 n=15)
NatSqr/2-88                   44.71n ± 6%   46.21n ±  9%        ~ (p=0.024 n=15)
NatSqr/3-88                   55.44n ± 4%   58.41n ± 10%        ~ (p=0.126 n=15)
NatSqr/5-88                   80.71n ± 5%   81.41n ±  5%        ~ (p=0.032 n=15)
NatSqr/8-88                   115.7n ± 4%   115.4n ±  5%        ~ (p=0.814 n=15)
NatSqr/10-88                  147.4n ± 4%   147.3n ±  4%        ~ (p=0.505 n=15)
NatSqr/20-88                  337.8n ± 3%   337.3n ±  4%        ~ (p=0.814 n=15)
NatSqr/30-88                  556.9n ± 3%   557.6n ±  4%        ~ (p=0.814 n=15)
NatSqr/50-88                  1.208µ ± 4%   1.208µ ±  3%        ~ (p=0.910 n=15)
NatSqr/80-88                  2.591µ ± 3%   2.581µ ±  3%        ~ (p=0.705 n=15)
NatSqr/100-88                 3.870µ ± 3%   3.858µ ±  3%        ~ (p=0.846 n=15)
NatSqr/200-88                 14.43µ ± 3%   14.28µ ±  2%        ~ (p=0.383 n=15)
NatSqr/300-88                 24.68µ ± 2%   24.49µ ±  2%        ~ (p=0.624 n=15)
NatSqr/500-88                 66.27µ ± 1%   66.18µ ±  1%        ~ (p=0.735 n=15)
NatSqr/800-88                 128.7µ ± 1%   127.4µ ±  1%        ~ (p=0.050 n=15)
NatSqr/1000-88                198.7µ ± 1%   197.7µ ±  1%        ~ (p=0.229 n=15)
NatSqr/10000-88               6.582m ± 1%   6.426m ±  1%   -2.37% (p=0.000 n=15)
NatSqr/100000-88              274.3m ± 0%   267.3m ±  0%   -2.57% (p=0.000 n=15)
geomean                       6.518µ        6.438µ         -1.22%

goos: linux
goarch: arm64
pkg: math/big
                            │     old     │                 new                 │
                            │   sec/op    │   sec/op     vs base                │
GCD10x10/WithoutXY-16         61.70n ± 1%   61.32n ± 1%        ~ (p=0.361 n=15)
GCD10x10/WithXY-16            217.3n ± 1%   217.0n ± 1%        ~ (p=0.395 n=15)
GCD100x100/WithoutXY-16       569.7n ± 0%   572.6n ± 2%        ~ (p=0.213 n=15)
GCD100x100/WithXY-16          1.241µ ± 1%   1.236µ ± 1%        ~ (p=0.157 n=15)
GCD1000x1000/WithoutXY-16     5.558µ ± 0%   5.566µ ± 0%        ~ (p=0.228 n=15)
GCD1000x1000/WithXY-16        9.319µ ± 0%   9.326µ ± 0%        ~ (p=0.233 n=15)
GCD10000x10000/WithoutXY-16   126.4µ ± 2%   128.7µ ± 3%        ~ (p=0.081 n=15)
GCD10000x10000/WithXY-16      279.3µ ± 0%   278.3µ ± 5%        ~ (p=0.187 n=15)
Div/20/10-16                  15.12n ± 1%   15.21n ± 1%        ~ (p=0.490 n=15)
Div/40/20-16                  15.11n ± 0%   15.23n ± 1%        ~ (p=0.107 n=15)
Div/100/50-16                 26.53n ± 0%   26.50n ± 0%        ~ (p=0.299 n=15)
Div/200/100-16                123.7n ± 0%   124.0n ± 0%        ~ (p=0.086 n=15)
Div/400/200-16                142.5n ± 0%   142.4n ± 0%        ~ (p=0.039 n=15)
Div/1000/500-16               259.9n ± 1%   261.2n ± 1%        ~ (p=0.044 n=15)
Div/2000/1000-16              539.4n ± 1%   532.3n ± 1%   -1.32% (p=0.001 n=15)
Div/20000/10000-16            22.43µ ± 0%   22.32µ ± 0%   -0.49% (p=0.000 n=15)
Div/200000/100000-16          898.3µ ± 0%   889.6µ ± 0%   -0.96% (p=0.000 n=15)
Div/2000000/1000000-16        38.37m ± 0%   35.11m ± 0%   -8.49% (p=0.000 n=15)
Div/20000000/10000000-16       1.449 ± 0%    1.384 ± 0%   -4.48% (p=0.000 n=15)
NatMul/10-16                  182.0n ± 1%   177.8n ± 1%   -2.31% (p=0.000 n=15)
NatMul/100-16                 5.537µ ± 0%   5.693µ ± 0%   +2.82% (p=0.000 n=15)
NatMul/1000-16                229.9µ ± 0%   224.8µ ± 0%   -2.24% (p=0.000 n=15)
NatMul/10000-16               8.985m ± 0%   8.751m ± 0%   -2.61% (p=0.000 n=15)
NatMul/100000-16              371.1m ± 0%   331.5m ± 0%  -10.66% (p=0.000 n=15)
NatSqr/1-16                   46.77n ± 6%   42.76n ± 1%   -8.57% (p=0.000 n=15)
NatSqr/2-16                   66.99n ± 4%   63.62n ± 1%   -5.03% (p=0.000 n=15)
NatSqr/3-16                   76.79n ± 4%   73.42n ± 1%        ~ (p=0.007 n=15)
NatSqr/5-16                   99.00n ± 3%   95.35n ± 1%   -3.69% (p=0.000 n=15)
NatSqr/8-16                   160.0n ± 3%   155.1n ± 1%   -3.06% (p=0.001 n=15)
NatSqr/10-16                  178.4n ± 2%   175.9n ± 0%   -1.40% (p=0.001 n=15)
NatSqr/20-16                  361.9n ± 2%   361.3n ± 0%        ~ (p=0.083 n=15)
NatSqr/30-16                  584.7n ± 0%   586.8n ± 0%   +0.36% (p=0.000 n=15)
NatSqr/50-16                  1.327µ ± 0%   1.329µ ± 0%        ~ (p=0.349 n=15)
NatSqr/80-16                  2.893µ ± 1%   2.925µ ± 0%   +1.11% (p=0.000 n=15)
NatSqr/100-16                 4.330µ ± 1%   4.381µ ± 0%   +1.18% (p=0.000 n=15)
NatSqr/200-16                 16.25µ ± 1%   16.43µ ± 0%   +1.07% (p=0.000 n=15)
NatSqr/300-16                 27.85µ ± 1%   28.06µ ± 0%   +0.77% (p=0.000 n=15)
NatSqr/500-16                 76.01µ ± 0%   76.34µ ± 0%        ~ (p=0.002 n=15)
NatSqr/800-16                 146.8µ ± 0%   148.1µ ± 0%   +0.83% (p=0.000 n=15)
NatSqr/1000-16                228.2µ ± 0%   228.6µ ± 0%        ~ (p=0.123 n=15)
NatSqr/10000-16               7.524m ± 0%   7.426m ± 0%   -1.31% (p=0.000 n=15)
NatSqr/100000-16              316.7m ± 0%   309.2m ± 0%   -2.36% (p=0.000 n=15)
geomean                       7.264µ        7.172µ        -1.27%

goos: darwin
goarch: arm64
pkg: math/big
cpu: Apple M3 Pro
                            │     old     │                new                 │
                            │   sec/op    │   sec/op     vs base               │
GCD10x10/WithoutXY-12         32.61n ± 1%   32.42n ± 1%       ~ (p=0.021 n=15)
GCD10x10/WithXY-12            87.70n ± 1%   88.42n ± 1%       ~ (p=0.010 n=15)
GCD100x100/WithoutXY-12       305.9n ± 0%   306.4n ± 0%       ~ (p=0.003 n=15)
GCD100x100/WithXY-12          560.3n ± 2%   556.6n ± 1%       ~ (p=0.018 n=15)
GCD1000x1000/WithoutXY-12     3.509µ ± 2%   3.464µ ± 1%       ~ (p=0.145 n=15)
GCD1000x1000/WithXY-12        5.347µ ± 2%   5.372µ ± 1%       ~ (p=0.046 n=15)
GCD10000x10000/WithoutXY-12   73.75µ ± 1%   73.99µ ± 1%       ~ (p=0.004 n=15)
GCD10000x10000/WithXY-12      148.4µ ± 0%   147.8µ ± 1%       ~ (p=0.076 n=15)
Div/20/10-12                  9.481n ± 0%   9.462n ± 1%       ~ (p=0.631 n=15)
Div/40/20-12                  9.457n ± 0%   9.462n ± 1%       ~ (p=0.798 n=15)
Div/100/50-12                 14.91n ± 0%   14.79n ± 1%  -0.80% (p=0.000 n=15)
Div/200/100-12                84.56n ± 1%   84.60n ± 1%       ~ (p=0.271 n=15)
Div/400/200-12                103.8n ± 0%   102.8n ± 0%  -0.96% (p=0.000 n=15)
Div/1000/500-12               181.3n ± 1%   184.2n ± 2%       ~ (p=0.091 n=15)
Div/2000/1000-12              397.5n ± 0%   397.4n ± 0%       ~ (p=0.299 n=15)
Div/20000/10000-12            14.04µ ± 1%   13.99µ ± 0%       ~ (p=0.221 n=15)
Div/200000/100000-12          523.1µ ± 0%   514.0µ ± 3%       ~ (p=0.775 n=15)
Div/2000000/1000000-12        21.58m ± 0%   20.01m ± 1%  -7.29% (p=0.000 n=15)
Div/20000000/10000000-12      813.5m ± 0%   796.2m ± 1%  -2.13% (p=0.000 n=15)
NatMul/10-12                  80.46n ± 1%   80.02n ± 1%       ~ (p=0.063 n=15)
NatMul/100-12                 2.904µ ± 0%   2.979µ ± 1%  +2.58% (p=0.000 n=15)
NatMul/1000-12                127.8µ ± 0%   122.3µ ± 0%  -4.28% (p=0.000 n=15)
NatMul/10000-12               5.141m ± 0%   4.975m ± 1%  -3.23% (p=0.000 n=15)
NatMul/100000-12              208.8m ± 0%   189.6m ± 3%  -9.21% (p=0.000 n=15)
NatSqr/1-12                   11.90n ± 1%   11.76n ± 1%       ~ (p=0.059 n=15)
NatSqr/2-12                   21.33n ± 1%   21.12n ± 0%       ~ (p=0.063 n=15)
NatSqr/3-12                   26.05n ± 1%   25.79n ± 0%       ~ (p=0.002 n=15)
NatSqr/5-12                   37.31n ± 0%   36.98n ± 1%       ~ (p=0.008 n=15)
NatSqr/8-12                   63.07n ± 0%   62.75n ± 1%       ~ (p=0.061 n=15)
NatSqr/10-12                  79.48n ± 0%   79.59n ± 0%       ~ (p=0.455 n=15)
NatSqr/20-12                  173.1n ± 0%   173.2n ± 1%       ~ (p=0.518 n=15)
NatSqr/30-12                  288.6n ± 1%   289.2n ± 0%       ~ (p=0.030 n=15)
NatSqr/50-12                  653.3n ± 0%   653.3n ± 0%       ~ (p=0.361 n=15)
NatSqr/80-12                  1.492µ ± 0%   1.496µ ± 0%       ~ (p=0.018 n=15)
NatSqr/100-12                 2.270µ ± 1%   2.270µ ± 0%       ~ (p=0.326 n=15)
NatSqr/200-12                 8.776µ ± 1%   8.784µ ± 1%       ~ (p=0.083 n=15)
NatSqr/300-12                 15.07µ ± 0%   15.09µ ± 0%       ~ (p=0.455 n=15)
NatSqr/500-12                 41.71µ ± 0%   41.77µ ± 1%       ~ (p=0.305 n=15)
NatSqr/800-12                 80.77µ ± 1%   80.59µ ± 0%       ~ (p=0.113 n=15)
NatSqr/1000-12                126.4µ ± 1%   126.5µ ± 0%       ~ (p=0.683 n=15)
NatSqr/10000-12               4.204m ± 0%   4.119m ± 0%  -2.02% (p=0.000 n=15)
NatSqr/100000-12              177.0m ± 0%   172.9m ± 0%  -2.31% (p=0.000 n=15)
geomean                       3.790µ        3.757µ       -0.87%

Change-Id: Ifc7a9b61f678df216690511ac8bb9143189a795e
Reviewed-on: https://go-review.googlesource.com/c/go/+/652057
Auto-Submit: Russ Cox <rsc@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Robert Griesemer <gri@google.com>
src/math/big/int.go
src/math/big/nat.go
src/math/big/nat_test.go
src/math/big/natdiv.go
src/math/big/natmul.go

index cb7221250da10756eef5adca5823f50b5ebebde9..4abfd192781c1a56223ca8162b40cc62715a8427 100644 (file)
@@ -181,10 +181,14 @@ func (z *Int) Sub(x, y *Int) *Int {
 
 // Mul sets z to the product x*y and returns z.
 func (z *Int) Mul(x, y *Int) *Int {
-       return z.mul(nil, x, y)
+       z.mul(nil, x, y)
+       return z
 }
 
-func (z *Int) mul(stk *stack, x, y *Int) *Int {
+// mul is like Mul but takes an explicit stack to use, for internal use.
+// It does not return a *Int because doing so makes the stack-allocated Ints
+// used in natmul.go escape to the heap (even though the result is unused).
+func (z *Int) mul(stk *stack, x, y *Int) {
        // x * y == x * y
        // x * (-y) == -(x * y)
        // (-x) * y == -(x * y)
@@ -192,11 +196,10 @@ func (z *Int) mul(stk *stack, x, y *Int) *Int {
        if x == y {
                z.abs = z.abs.sqr(stk, x.abs)
                z.neg = false
-               return z
+               return
        }
        z.abs = z.abs.mul(stk, x.abs, y.abs)
        z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign
-       return z
 }
 
 // MulRange sets z to the product of all integers
index a091608f3e28cbc5e15e7915078d533300a917ab..922cdb43067738922cd2995350202c9cfb831a2f 100644 (file)
@@ -227,15 +227,14 @@ func alias(x, y nat) bool {
        return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
 }
 
-// addAt implements z += x<<(_W*i); z must be long enough.
+// addTo implements z += x; z must be long enough.
 // (we don't use nat.add because we need z to stay the same
 // slice, and we don't need to normalize z after each addition)
-func addAt(z, x nat, i int) {
+func addTo(z, x nat) {
        if n := len(x); n > 0 {
-               if c := addVV(z[i:i+n], z[i:], x); c != 0 {
-                       j := i + n
-                       if j < len(z) {
-                               addVW(z[j:], z[j:], c)
+               if c := addVV(z[:n], z, x); c != 0 {
+                       if n < len(z) {
+                               addVW(z[n:], z[n:], c)
                        }
                }
        }
index 1811dccfe33d6188ca2b22b17fd812bacb301659..251877b50674781cb5c8734b2955bd4e0b8477db 100644 (file)
@@ -7,6 +7,8 @@ package big
 import (
        "fmt"
        "math"
+       "math/bits"
+       "math/rand/v2"
        "runtime"
        "strings"
        "testing"
@@ -56,12 +58,84 @@ var sumNN = []argNN{
        {nat{0, 0, 0, 1}, nat{0, 0, _M}, nat{0, 0, 1}},
 }
 
-var prodNN = []argNN{
-       {},
-       {nil, nil, nil},
+var prodNN = append(prodTests(), prodNNExtra...)
+
+func permute[E any](x []E) {
+       out := make([]E, len(x))
+       for i, j := range rand.Perm(len(x)) {
+               out[i] = x[j]
+       }
+       copy(x, out)
+}
+
+// testMul returns the product of x and y using the grade-school algorithm,
+// as a reference implementation.
+func testMul(x, y nat) nat {
+       z := make(nat, len(x)+len(y))
+       for i, xi := range x {
+               for j, yj := range y {
+                       hi, lo := bits.Mul(uint(xi), uint(yj))
+                       k := i + j
+                       s, c := bits.Add(uint(z[k]), lo, 0)
+                       z[k] = Word(s)
+                       k++
+                       for hi != 0 || c != 0 {
+                               s, c = bits.Add(uint(z[k]), hi, c)
+                               hi = 0
+                               z[k] = Word(s)
+                               k++
+                       }
+               }
+       }
+       return z.norm()
+}
+
+func prodTests() []argNN {
+       var tests []argNN
+       for size := range 10 {
+               var x, y nat
+               for i := range size {
+                       x = append(x, Word(i+1))
+                       y = append(y, Word(i+1+size))
+               }
+               permute(x)
+               permute(y)
+               x = x.norm()
+               y = y.norm()
+               tests = append(tests, argNN{testMul(x, y), x, y})
+       }
+
+       words := []Word{0, 1, 2, 3, 4, ^Word(0), ^Word(1), ^Word(2), ^Word(3)}
+       for size := range 10 {
+               if size == 0 {
+                       continue // already tested the only 0-length possibility above
+               }
+               for range 10 {
+                       x := make(nat, size)
+                       y := make(nat, size)
+                       for i := range size {
+                               x[i] = words[rand.N(len(words))]
+                               y[i] = words[rand.N(len(words))]
+                       }
+                       x = x.norm()
+                       y = y.norm()
+                       tests = append(tests, argNN{testMul(x, y), x, y})
+               }
+       }
+       return tests
+}
+
+var prodNNExtra = []argNN{
        {nil, nat{991}, nil},
        {nat{991}, nat{991}, nat{1}},
        {nat{991 * 991}, nat{991}, nat{991}},
+       {nat{8, 22, 15}, nat{2, 3}, nat{4, 5}},
+       {nat{10, 27, 52, 45, 28}, nat{2, 3, 4}, nat{5, 6, 7}},
+       {nat{12, 32, 61, 100, 94, 76, 45}, nat{2, 3, 4, 5}, nat{6, 7, 8, 9}},
+       {nat{12, 32, 61, 100, 94, 76, 45}, nat{2, 3, 4, 5}, nat{6, 7, 8, 9}},
+       {nat{14, 37, 70, 114, 170, 166, 148, 115, 66}, nat{2, 3, 4, 5, 6}, nat{7, 8, 9, 10, 11}},
+       {nat{991 * 991, 991 * 2, 1}, nat{991, 1}, nat{991, 1}},
+       {nat{991 * 991, 991 * 777 * 2, 777 * 777}, nat{991, 777}, nat{991, 777}},
        {nat{0, 0, 991 * 991}, nat{0, 991}, nat{0, 991}},
        {nat{1 * 991, 2 * 991, 3 * 991, 4 * 991}, nat{1, 2, 3, 4}, nat{991}},
        {nat{4, 11, 20, 30, 20, 11, 4}, nat{1, 2, 3, 4}, nat{4, 3, 2, 1}},
@@ -114,38 +188,113 @@ func testFunNN(t *testing.T, msg string, f funNN, a argNN) {
 }
 
 func testFunSNN(t *testing.T, msg string, f funSNN, a argNN) {
+       t.Helper()
        stk := getStack()
        defer stk.free()
        z := f(nil, stk, a.x, a.y)
        if z.cmp(a.z) != 0 {
-               t.Errorf("%s%+v\n\tgot z = %v; want %v", msg, a, z, a.z)
+               t.Fatalf("%s%+v\n\tgot z = %v; want %v", msg, a, z, a.z)
        }
 }
 
-func TestFunNN(t *testing.T) {
-       for _, a := range sumNN {
-               arg := a
-               testFunNN(t, "add", nat.add, arg)
+func setDuringTest[V any](t *testing.T, p *V, v V) {
+       old := *p
+       *p = v
+       t.Cleanup(func() { *p = old })
+}
 
-               arg = argNN{a.z, a.y, a.x}
-               testFunNN(t, "add symmetric", nat.add, arg)
+func TestAdd(t *testing.T) {
+       for _, a := range sumNN {
+               testFunNN(t, "add", nat.add, a)
+               a.x, a.y = a.y, a.x
+               testFunNN(t, "add", nat.add, a)
+       }
+}
 
-               arg = argNN{a.x, a.z, a.y}
-               testFunNN(t, "sub", nat.sub, arg)
+func TestSub(t *testing.T) {
+       for _, a := range sumNN {
+               a.x, a.z = a.z, a.x
+               testFunNN(t, "sub", nat.sub, a)
 
-               arg = argNN{a.y, a.z, a.x}
-               testFunNN(t, "sub symmetric", nat.sub, arg)
+               a.y, a.z = a.z, a.y
+               testFunNN(t, "sub", nat.sub, a)
        }
+}
 
-       for _, a := range prodNN {
-               arg := a
-               testFunSNN(t, "mul", nat.mul, arg)
+func TestNatMul(t *testing.T) {
+       t.Run("Basic", func(t *testing.T) {
+               setDuringTest(t, &karatsubaThreshold, 1e9)
+               for _, a := range prodNN {
+                       if len(a.z) >= 100 {
+                               continue
+                       }
+                       testFunSNN(t, "mul", nat.mul, a)
+                       a.x, a.y = a.y, a.x
+                       testFunSNN(t, "mul", nat.mul, a)
+               }
+       })
+       t.Run("Karatsuba", func(t *testing.T) {
+               setDuringTest(t, &karatsubaThreshold, 2)
+               for _, a := range prodNN {
+                       testFunSNN(t, "mul", nat.mul, a)
+                       a.x, a.y = a.y, a.x
+                       testFunSNN(t, "mul", nat.mul, a)
+               }
+       })
 
-               arg = argNN{a.z, a.y, a.x}
-               testFunSNN(t, "mul symmetric", nat.mul, arg)
+       t.Run("Mul", func(t *testing.T) {
+               for _, a := range prodNN {
+                       testFunSNN(t, "mul", nat.mul, a)
+                       a.x, a.y = a.y, a.x
+                       testFunSNN(t, "mul", nat.mul, a)
+               }
+       })
+}
+
+func testSqr(t *testing.T, x nat) {
+       stk := getStack()
+       defer stk.free()
+
+       got := make(nat, 2*len(x))
+       want := make(nat, 2*len(x))
+       got = got.sqr(stk, x)
+       want = want.mul(stk, x, x)
+       if got.cmp(want) != 0 {
+               t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
        }
 }
 
+func TestNatSqr(t *testing.T) {
+       t.Run("Basic", func(t *testing.T) {
+               setDuringTest(t, &basicSqrThreshold, 0)
+               setDuringTest(t, &karatsubaSqrThreshold, 1e9)
+               for _, a := range prodNN {
+                       if len(a.z) >= 100 {
+                               continue
+                       }
+                       testSqr(t, a.x)
+                       testSqr(t, a.y)
+                       testSqr(t, a.z)
+               }
+       })
+       t.Run("Karatsuba", func(t *testing.T) {
+               setDuringTest(t, &basicSqrThreshold, 2)
+               setDuringTest(t, &karatsubaSqrThreshold, 2)
+               for _, a := range prodNN {
+                       testSqr(t, a.x)
+                       testSqr(t, a.y)
+                       testSqr(t, a.z)
+               }
+       })
+       t.Run("Sqr", func(t *testing.T) {
+               for _, a := range prodNN {
+                       testSqr(t, a.x)
+                       testSqr(t, a.y)
+                       testSqr(t, a.z)
+               }
+       })
+}
+
 var mulRangesN = []struct {
        a, b uint64
        prod string
@@ -739,33 +888,6 @@ func TestSticky(t *testing.T) {
        }
 }
 
-func testSqr(t *testing.T, x nat) {
-       stk := getStack()
-       defer stk.free()
-
-       got := make(nat, 2*len(x))
-       want := make(nat, 2*len(x))
-       got = got.sqr(stk, x)
-       want = want.mul(stk, x, x)
-       if got.cmp(want) != 0 {
-               t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
-       }
-}
-
-func TestSqr(t *testing.T) {
-       for _, a := range prodNN {
-               if a.x != nil {
-                       testSqr(t, a.x)
-               }
-               if a.y != nil {
-                       testSqr(t, a.y)
-               }
-               if a.z != nil {
-                       testSqr(t, a.z)
-               }
-       }
-}
-
 func benchmarkNatSqr(b *testing.B, nwords int) {
        x := rndNat(nwords)
        var z nat
index 7f6a1bbb075a36aa1357a9ea376525200530aca6..b67d6afeda8577eb71fadf6181e77bd6d25e234e 100644 (file)
@@ -831,7 +831,7 @@ func (z nat) divRecursiveStep(stk *stack, u, v nat, depth int) {
                        if len(qhatv) > s {
                                subVW(qhatv[s:], qhatv[s:], c)
                        }
-                       addAt(uu[s:], v[s:], 0)
+                       addTo(uu[s:], v[s:])
                }
                if qhatv.cmp(uu.norm()) > 0 {
                        panic("impossible")
@@ -840,7 +840,7 @@ func (z nat) divRecursiveStep(stk *stack, u, v nat, depth int) {
                if c > 0 {
                        subVW(uu[len(qhatv):], uu[len(qhatv):], c)
                }
-               addAt(z, qhat, j-B)
+               addTo(z[j-B:], qhat)
                j -= B
                stk.restore(mark)
        }
@@ -865,7 +865,7 @@ func (z nat) divRecursiveStep(stk *stack, u, v nat, depth int) {
                        if len(qhatv) > s {
                                subVW(qhatv[s:], qhatv[s:], c)
                        }
-                       addAt(u[s:], v[s:], 0)
+                       addTo(u[s:], v[s:])
                }
        }
        if qhatv.cmp(u.norm()) > 0 {
@@ -880,5 +880,5 @@ func (z nat) divRecursiveStep(stk *stack, u, v nat, depth int) {
        }
 
        // Done!
-       addAt(z, qhat.norm(), 0)
+       addTo(z, qhat.norm())
 }
index 001e78d2c13b235f2d486228607e069a26246268..8ab4d13cbaba70dd40677d55ed291d63fd5f76b9 100644 (file)
@@ -31,71 +31,31 @@ func (z nat) mul(stk *stack, x, y nat) nat {
        if alias(z, x) || alias(z, y) {
                z = nil // z is an alias for x or y - cannot reuse
        }
+       z = z.make(m + n)
 
        // use basic multiplication if the numbers are small
        if n < karatsubaThreshold {
-               z = z.make(m + n)
                basicMul(z, x, y)
                return z.norm()
        }
-       // m >= n && n >= karatsubaThreshold && n >= 2
-
-       // determine Karatsuba length k such that
-       //
-       //   x = xh*b + x0  (0 <= x0 < b)
-       //   y = yh*b + y0  (0 <= y0 < b)
-       //   b = 1<<(_W*k)  ("base" of digits xi, yi)
-       //
-       k := karatsubaLen(n, karatsubaThreshold)
-       // k <= n
 
        if stk == nil {
                stk = getStack()
                defer stk.free()
        }
 
-       // multiply x0 and y0 via Karatsuba
-       x0 := x[0:k]      // x0 is not normalized
-       y0 := y[0:k]      // y0 is not normalized
-       z = z.make(m + n) // enough space for full result of x*y
-       karatsuba(stk, z, x0, y0)
-       clear(z[2*k:]) // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
+       // Let x = x1:x0 where x0 is the same length as y.
+       // Compute z = x0*y and then add in x1*y in sections
+       // if needed.
+       karatsuba(stk, z[:2*n], x[:n], y)
 
-       // If xh != 0 or yh != 0, add the missing terms to z. For
-       //
-       //   xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
-       //   yh =                         y1*b (0 <= y1 < b)
-       //
-       // the missing terms are
-       //
-       //   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
-       //
-       // since all the yi for i > 1 are 0 by choice of k: If any of them
-       // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
-       // be a larger valid threshold contradicting the assumption about k.
-       //
-       if k < n || m != n {
+       if n < m {
+               clear(z[2*n:])
                defer stk.restore(stk.save())
-               t := stk.nat(3 * k)
-
-               // add x0*y1*b
-               x0 := x0.norm()
-               y1 := y[k:]            // y1 is normalized because y is
-               t = t.mul(stk, x0, y1) // update t so we don't lose t's underlying array
-               addAt(z, t, k)
-
-               // add xi*y0<<i, xi*y1*b<<(i+k)
-               y0 := y0.norm()
-               for i := k; i < len(x); i += k {
-                       xi := x[i:]
-                       if len(xi) > k {
-                               xi = xi[:k]
-                       }
-                       xi = xi.norm()
-                       t = t.mul(stk, xi, y0)
-                       addAt(z, t, i)
-                       t = t.mul(stk, xi, y1)
-                       addAt(z, t, i+k)
+               t := stk.nat(2 * n)
+               for i := n; i < m; i += n {
+                       t = t.mul(stk, x[i:min(i+n, len(x))], y)
+                       addTo(z[i:], t)
                }
        }
 
@@ -142,28 +102,7 @@ func (z nat) sqr(stk *stack, x nat) nat {
                return z.norm()
        }
 
-       // Use Karatsuba multiplication optimized for x == y.
-       // The algorithm and layout of z are the same as for mul.
-
-       // z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2
-
-       k := karatsubaLen(n, karatsubaSqrThreshold)
-
-       x0 := x[0:k]
-       karatsubaSqr(stk, z, x0) // z = x0^2
-       clear(z[2*k:])
-
-       if k < n {
-               t := stk.nat(2 * k)
-               x0 := x0.norm()
-               x1 := x[k:]
-               t = t.mul(stk, x0, x1)
-               addAt(z, t, k)
-               addAt(z, t, k) // z = 2*x1*x0*b + x0^2
-               t = t.sqr(stk, x1)
-               addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2
-       }
-
+       karatsubaSqr(stk, z, x)
        return z.norm()
 }
 
@@ -188,6 +127,7 @@ func basicSqr(stk *stack, z, x nat) {
        addVV(z, z, t)                              // combine the result
 }
 
+// mulAddWW returns z = x*y + r.
 func (z nat) mulAddWW(x nat, y, r Word) nat {
        m := len(x)
        if m == 0 || y == 0 {
@@ -212,155 +152,199 @@ func basicMul(z, x, y nat) {
        }
 }
 
-// karatsubaLen computes an approximation to the maximum k <= n such that
-// k = p<<i for a number p <= threshold and an i >= 0. Thus, the
-// result is the largest number that can be divided repeatedly by 2 before
-// becoming about the value of threshold.
-func karatsubaLen(n, threshold int) int {
-       i := uint(0)
-       for n > threshold {
-               n >>= 1
-               i++
-       }
-       return n << i
-}
-
-// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
-// Factored out for readability - do not use outside karatsuba.
-func karatsubaAdd(z, x nat, n int) {
-       if c := addVV(z[0:n], z, x); c != 0 {
-               addVW(z[n:n+n>>1], z[n:], c)
-       }
-}
-
-// Like karatsubaAdd, but does subtract.
-func karatsubaSub(z, x nat, n int) {
-       if c := subVV(z[0:n], z, x); c != 0 {
-               subVW(z[n:n+n>>1], z[n:], c)
-       }
-}
-
-// karatsuba multiplies x and y and leaves the result in z.
-// Both x and y must have the same length n and n must be a
-// power of 2. The result vector z must have len(z) == len(x)+len(y).
-// The (non-normalized) result is placed in z.
+// karatsuba multiplies x and y,
+// writing the (non-normalized) result to z.
+// x and y must have the same length n,
+// and z must have length twice that.
 func karatsuba(stk *stack, z, x, y nat) {
        n := len(y)
+       if len(x) != n || len(z) != 2*n {
+               panic("bad karatsuba length")
+       }
 
-       // Switch to basic multiplication if numbers are odd or small.
-       // (n is always even if karatsubaThreshold is even, but be
-       // conservative)
-       if n&1 != 0 || n < karatsubaThreshold || n < 2 {
+       // Fall back to basic algorithm if small enough.
+       if n < karatsubaThreshold || n < 2 {
                basicMul(z, x, y)
                return
        }
-       // n&1 == 0 && n >= karatsubaThreshold && n >= 2
 
-       // Karatsuba multiplication is based on the observation that
-       // for two numbers x and y with:
+       // Let the notation x1:x0 denote the nat (x1<<N)+x0 for some N,
+       // and similarly z2:z1:z0 = (z2<<2N)+(z1<<N)+z0.
        //
-       //   x = x1*b + x0
-       //   y = y1*b + y0
+       // (Note that z0, z1, z2 might be ≥ 2**N, in which case the high
+       // bits of, say, z0 are being added to the low bits of z1 in this notation.)
        //
-       // the product x*y can be obtained with 3 products z2, z1, z0
-       // instead of 4:
+       // Karatsuba multiplication is based on the observation that
        //
-       //   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
-       //       =    z2*b*b +              z1*b +    z0
+       //      x1:x0 * y1:y0 = x1*y1:(x0*y1+y0*x1):x0*y0
+       //                    = x1*y1:((x0-x1)*(y1-y0)+x1*y1+x0*y0):x0*y0
        //
-       // with:
+       // The second form uses only three half-width multiplications
+       // instead of the four that the straightforward first form does.
        //
-       //   xd = x1 - x0
-       //   yd = y0 - y1
+       // We call the three pieces z0, z1, z2:
        //
-       //   z1 =      xd*yd                    + z2 + z0
-       //      = (x1-x0)*(y0 - y1)             + z2 + z0
-       //      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
-       //      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
-       //      = x1*y0                 + x0*y1
-
-       // split x, y into "digits"
-       n2 := n >> 1              // n2 >= 1
-       x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0
-       y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0
-
-       // compute z0 and z2 with the result "in place" in z
-       karatsuba(stk, z, x0, y0)     // z0 = x0*y0
-       karatsuba(stk, z[n:], x1, y1) // z2 = x1*y1
-
-       // compute xd, yd (or the negative value if underflow occurs)
-       s := 1 // sign of product xd*yd
-       defer stk.restore(stk.save())
-       xd := stk.nat(n2)
-       yd := stk.nat(n2)
-       if subVV(xd, x1, x0) != 0 { // x1-x0
-               s = -s
-               subVV(xd, x0, x1) // x0-x1
-       }
-       if subVV(yd, y0, y1) != 0 { // y0-y1
-               s = -s
-               subVV(yd, y1, y0) // y1-y0
-       }
-
-       // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
-       // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
-       p := stk.nat(2 * n2)
-       karatsuba(stk, p, xd, yd)
+       //      z0 = x0*y0
+       //      z2 = x1*y1
+       //      z1 = (x0-x1)*(y1-y0) + z0 + z2
 
-       // save original z2:z0
-       // (ok to use upper half of z since we're done recurring)
-       r := stk.nat(n * 2)
-       copy(r, z[:n*2])
+       n2 := (n + 1) / 2
+       x0, x1 := &Int{abs: x[:n2].norm()}, &Int{abs: x[n2:].norm()}
+       y0, y1 := &Int{abs: y[:n2].norm()}, &Int{abs: y[n2:].norm()}
+       z0 := &Int{abs: z[0 : 2*n2]}
+       z2 := &Int{abs: z[2*n2:]}
 
-       // add up all partial products
-       //
-       //   2*n     n     0
-       // z = [ z2  | z0  ]
-       //   +    [ z0  ]
-       //   +    [ z2  ]
-       //   +    [  p  ]
-       //
-       karatsubaAdd(z[n2:], r, n)
-       karatsubaAdd(z[n2:], r[n:], n)
-       if s > 0 {
-               karatsubaAdd(z[n2:], p, n)
-       } else {
-               karatsubaSub(z[n2:], p, n)
+       // Allocate temporary storage for z1; repurpose z0 to hold tx and ty.
+       defer stk.restore(stk.save())
+       z1 := &Int{abs: stk.nat(2*n2 + 1)}
+       tx := &Int{abs: z[0:n2]}
+       ty := &Int{abs: z[n2 : 2*n2]}
+
+       tx.Sub(x0, x1)
+       ty.Sub(y1, y0)
+       z1.mul(stk, tx, ty)
+
+       clear(z)
+       z0.mul(stk, x0, y0)
+       z2.mul(stk, x1, y1)
+       z1.Add(z1, z0)
+       z1.Add(z1, z2)
+       addTo(z[n2:], z1.abs)
+
+       // Debug mode: double-check answer and print trace on failure.
+       const debug = false
+       if debug {
+               zz := make(nat, len(z))
+               basicMul(zz, x, y)
+               if z.cmp(zz) != 0 {
+                       // All the temps were aliased to z and gone. Recompute.
+                       z0 = new(Int)
+                       z0.mul(stk, x0, y0)
+                       tx = new(Int).Sub(x1, x0)
+                       ty = new(Int).Sub(y0, y1)
+                       z2 = new(Int)
+                       z2.mul(stk, x1, y1)
+                       print("karatsuba wrong\n")
+                       trace("x ", &Int{abs: x})
+                       trace("y ", &Int{abs: y})
+                       trace("z ", &Int{abs: z})
+                       trace("zz", &Int{abs: zz})
+                       trace("x0", x0)
+                       trace("x1", x1)
+                       trace("y0", y0)
+                       trace("y1", y1)
+                       trace("tx", tx)
+                       trace("ty", ty)
+                       trace("z0", z0)
+                       trace("z1", z1)
+                       trace("z2", z2)
+                       panic("karatsuba")
+               }
        }
+
 }
 
-// karatsubaSqr squares x and leaves the result in z.
-// len(x) must be a power of 2 and len(z) == 2*len(x).
-// The (non-normalized) result is placed in z.
-//
-// The algorithm and the layout of z are the same as for karatsuba.
+// karatsubaSqr squares x,
+// writing the (non-normalized) result to z.
+// z must have length 2*len(x).
+// It is analogous to [karatsuba] but can run faster
+// knowing both multiplicands are the same value.
 func karatsubaSqr(stk *stack, z, x nat) {
        n := len(x)
+       if len(z) != 2*n {
+               panic("bad karatsubaSqr length")
+       }
 
-       if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 {
-               basicSqr(stk, z[:2*n], x)
+       if n < karatsubaSqrThreshold || n < 2 {
+               basicSqr(stk, z, x)
                return
        }
 
-       n2 := n >> 1
-       x1, x0 := x[n2:], x[0:n2]
+       // Recall that for karatsuba we want to compute:
+       //
+       //      x1:x0 * y1:y0 = x1y1:(x0y1+y0x1):x0y0
+       //                = x1y1:((x0-x1)*(y1-y0)+x1y1+x0y0):x0y0
+       //                    = z2:z1:z0
+       // where:
+       //
+       //      z0 = x0y0
+       //      z2 = x1y1
+       //      z1 = (x0-x1)*(y1-y0) + z0 + z2
+       //
+       // When x = y, these simplify to:
+       //
+       //      z0 = x0²
+       //      z2 = x1²
+       //      z1 = z0 + z2 - (x0-x1)²
 
-       karatsubaSqr(stk, z, x0)
-       karatsubaSqr(stk, z[n:], x1)
+       n2 := (n + 1) / 2
+       x0, x1 := &Int{abs: x[:n2].norm()}, &Int{abs: x[n2:].norm()}
+       z0 := &Int{abs: z[0 : 2*n2]}
+       z2 := &Int{abs: z[2*n2:]}
 
-       // s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0
+       // Allocate temporary storage for z1; repurpose z0 to hold tx.
        defer stk.restore(stk.save())
-       p := stk.nat(2 * n2)
-       r := stk.nat(n * 2)
-       xd := r[:n2]
-       if subVV(xd, x1, x0) != 0 {
-               subVV(xd, x0, x1)
+       z1 := &Int{abs: stk.nat(2*n2 + 1)}
+       tx := &Int{abs: z[0:n2]}
+
+       tx.Sub(x0, x1)
+       z1.abs = z1.abs.sqr(stk, tx.abs)
+       z1.neg = true
+
+       clear(z)
+       z0.abs = z0.abs.sqr(stk, x0.abs)
+       z2.abs = z2.abs.sqr(stk, x1.abs)
+       z1.Add(z1, z0)
+       z1.Add(z1, z2)
+       addTo(z[n2:], z1.abs)
+
+       // Debug mode: double-check answer and print trace on failure.
+       const debug = false
+       if debug {
+               zz := make(nat, len(z))
+               basicSqr(stk, zz, x)
+               if z.cmp(zz) != 0 {
+                       // All the temps were aliased to z and gone. Recompute.
+                       tx = new(Int).Sub(x0, x1)
+                       z0 = new(Int).Mul(x0, x0)
+                       z2 = new(Int).Mul(x1, x1)
+                       z1 = new(Int).Mul(tx, tx)
+                       z1.Neg(z1)
+                       z1.Add(z1, z0)
+                       z1.Add(z1, z2)
+                       print("karatsubaSqr wrong\n")
+                       trace("x ", &Int{abs: x})
+                       trace("z ", &Int{abs: z})
+                       trace("zz", &Int{abs: zz})
+                       trace("x0", x0)
+                       trace("x1", x1)
+                       trace("z0", z0)
+                       trace("z1", z1)
+                       trace("z2", z2)
+                       panic("karatsubaSqr")
+               }
+       }
+}
+
+// ifmt returns the debug formatting of the Int x: 0xHEX.
+func ifmt(x *Int) string {
+       neg, s, t := "", x.Text(16), ""
+       if s == "" { // happens for denormalized zero
+               s = "0x0"
+       }
+       if s[0] == '-' {
+               neg, s = "-", s[1:]
        }
 
-       karatsubaSqr(stk, p, xd)
-       copy(r, z[:n*2])
+       // Add _ between words.
+       const D = _W / 4 // digits per chunk
+       for len(s) > D {
+               s, t = s[:len(s)-D], s[len(s)-D:]+"_"+t
+       }
+       return neg + s + t
+}
 
-       karatsubaAdd(z[n2:], r, n)
-       karatsubaAdd(z[n2:], r[n:], n)
-       karatsubaSub(z[n2:], p, n) // s == -1 for p != 0; s == 1 for p == 0
+// trace prints a single debug value.
+func trace(name string, x *Int) {
+       print(name, "=", ifmt(x), "\n")
 }