From: Russ Cox Date: Fri, 17 Jan 2025 22:26:59 +0000 (-0500) Subject: math/big: simplify, speed up Karatsuba multiplication X-Git-Tag: go1.25rc1~738 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=d037ed62bc583af358b2cc5aeb151872a6ba7c2e;p=gostls13.git math/big: simplify, speed up Karatsuba multiplication The old Karatsuba implementation only operated on lengths that are a power of two times a number smaller than karatsubaThreshold. For example, when karatsubaThreshold = 40, multiplying a pair of 99-word numbers runs karatsuba on the low 96 (= 39<<2) words and then has to fix up the answer to include the high 3 words of each. I suspect this requirement was needed to make the analysis of how many temporary words to reserve easier, back when the answer was 3*n and depended on exactly halving the size at each Karatsuba step. Now that we have the more flexible temporary allocation stack, we can change Karatsuba to accept operands of odd length. Doing so avoids most of the fixup that the old approach required. For example, multiplying a pair of 99-word numbers runs karatsuba on all 99 words now. This is simpler and about the same speed or, for large cases, faster. goos: linux goarch: amd64 pkg: math/big cpu: Intel(R) Xeon(R) CPU @ 3.10GHz │ old │ new │ │ sec/op │ sec/op vs base │ GCD10x10/WithoutXY-16 99.62n ± 3% 99.10n ± 3% ~ (p=0.009 n=15) GCD10x10/WithXY-16 243.4n ± 1% 245.2n ± 1% ~ (p=0.009 n=15) GCD100x100/WithoutXY-16 921.9n ± 1% 919.2n ± 1% ~ (p=0.076 n=15) GCD100x100/WithXY-16 1.527µ ± 1% 1.526µ ± 0% ~ (p=0.813 n=15) GCD1000x1000/WithoutXY-16 9.704µ ± 1% 9.696µ ± 0% ~ (p=0.532 n=15) GCD1000x1000/WithXY-16 14.03µ ± 1% 13.96µ ± 0% ~ (p=0.014 n=15) GCD10000x10000/WithoutXY-16 206.5µ ± 2% 206.5µ ± 0% ~ (p=0.967 n=15) GCD10000x10000/WithXY-16 398.0µ ± 1% 397.4µ ± 0% ~ (p=0.683 n=15) Div/20/10-16 22.22n ± 0% 22.23n ± 0% ~ (p=0.105 n=15) Div/40/20-16 22.23n ± 0% 22.23n ± 0% ~ (p=0.307 n=15) Div/100/50-16 55.47n ± 0% 55.47n ± 0% ~ (p=0.573 n=15) Div/200/100-16 174.9n ± 1% 174.6n ± 1% ~ (p=0.814 n=15) Div/400/200-16 209.5n ± 1% 210.5n ± 1% ~ (p=0.454 n=15) Div/1000/500-16 379.9n ± 0% 383.5n ± 2% ~ (p=0.123 n=15) Div/2000/1000-16 780.1n ± 0% 784.6n ± 1% +0.58% (p=0.000 n=15) Div/20000/10000-16 25.22µ ± 1% 25.15µ ± 0% ~ (p=0.213 n=15) Div/200000/100000-16 921.8µ ± 1% 926.1µ ± 0% ~ (p=0.009 n=15) Div/2000000/1000000-16 37.91m ± 0% 35.63m ± 0% -6.02% (p=0.000 n=15) Div/20000000/10000000-16 1.378 ± 0% 1.336 ± 0% -3.03% (p=0.000 n=15) NatMul/10-16 166.8n ± 4% 168.9n ± 3% ~ (p=0.008 n=15) NatMul/100-16 5.519µ ± 2% 5.548µ ± 4% ~ (p=0.032 n=15) NatMul/1000-16 230.4µ ± 1% 220.2µ ± 1% -4.43% (p=0.000 n=15) NatMul/10000-16 8.569m ± 1% 8.640m ± 1% ~ (p=0.005 n=15) NatMul/100000-16 376.5m ± 1% 334.1m ± 0% -11.26% (p=0.000 n=15) NatSqr/1-16 27.85n ± 5% 28.60n ± 2% ~ (p=0.123 n=15) NatSqr/2-16 47.99n ± 2% 48.84n ± 1% ~ (p=0.008 n=15) NatSqr/3-16 59.41n ± 2% 60.87n ± 2% +2.46% (p=0.001 n=15) NatSqr/5-16 87.27n ± 2% 89.31n ± 3% ~ (p=0.087 n=15) NatSqr/8-16 124.6n ± 3% 128.9n ± 3% ~ (p=0.006 n=15) NatSqr/10-16 166.3n ± 3% 172.7n ± 3% ~ (p=0.002 n=15) NatSqr/20-16 385.2n ± 2% 394.7n ± 3% ~ (p=0.036 n=15) NatSqr/30-16 622.7n ± 3% 642.9n ± 3% ~ (p=0.032 n=15) NatSqr/50-16 1.274µ ± 3% 1.323µ ± 4% ~ (p=0.003 n=15) NatSqr/80-16 2.606µ ± 4% 2.714µ ± 4% ~ (p=0.044 n=15) NatSqr/100-16 3.731µ ± 4% 3.871µ ± 4% ~ (p=0.038 n=15) NatSqr/200-16 12.99µ ± 2% 13.09µ ± 3% ~ (p=0.838 n=15) NatSqr/300-16 22.87µ ± 2% 23.25µ ± 2% ~ (p=0.285 n=15) NatSqr/500-16 58.43µ ± 1% 58.25µ ± 2% ~ (p=0.345 n=15) NatSqr/800-16 115.3µ ± 3% 116.2µ ± 3% ~ (p=0.126 n=15) NatSqr/1000-16 173.9µ ± 1% 174.3µ ± 1% ~ (p=0.935 n=15) NatSqr/10000-16 6.133m ± 2% 6.034m ± 1% -1.62% (p=0.000 n=15) NatSqr/100000-16 253.8m ± 1% 241.5m ± 0% -4.87% (p=0.000 n=15) geomean 7.745µ 7.760µ +0.19% goos: linux goarch: amd64 pkg: math/big cpu: Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz │ old │ new │ │ sec/op │ sec/op vs base │ GCD10x10/WithoutXY-88 62.17n ± 4% 61.44n ± 0% -1.17% (p=0.000 n=15) GCD10x10/WithXY-88 173.4n ± 2% 172.4n ± 4% ~ (p=0.615 n=15) GCD100x100/WithoutXY-88 584.0n ± 1% 582.9n ± 0% ~ (p=0.009 n=15) GCD100x100/WithXY-88 1.098µ ± 1% 1.091µ ± 2% ~ (p=0.002 n=15) GCD1000x1000/WithoutXY-88 6.055µ ± 0% 6.049µ ± 0% ~ (p=0.007 n=15) GCD1000x1000/WithXY-88 9.430µ ± 0% 9.417µ ± 1% ~ (p=0.123 n=15) GCD10000x10000/WithoutXY-88 153.4µ ± 2% 149.0µ ± 2% -2.85% (p=0.000 n=15) GCD10000x10000/WithXY-88 350.6µ ± 3% 349.0µ ± 2% ~ (p=0.126 n=15) Div/20/10-88 13.12n ± 0% 13.12n ± 1% 0.00% (p=0.042 n=15) Div/40/20-88 13.12n ± 0% 13.13n ± 0% ~ (p=0.004 n=15) Div/100/50-88 25.49n ± 0% 25.49n ± 0% ~ (p=0.452 n=15) Div/200/100-88 115.7n ± 2% 113.8n ± 2% ~ (p=0.212 n=15) Div/400/200-88 135.0n ± 1% 136.1n ± 1% ~ (p=0.005 n=15) Div/1000/500-88 257.5n ± 1% 259.9n ± 1% ~ (p=0.004 n=15) Div/2000/1000-88 567.5n ± 1% 572.4n ± 2% ~ (p=0.616 n=15) Div/20000/10000-88 25.65µ ± 0% 25.77µ ± 1% ~ (p=0.032 n=15) Div/200000/100000-88 777.4µ ± 1% 754.3µ ± 1% -2.97% (p=0.000 n=15) Div/2000000/1000000-88 33.66m ± 0% 31.37m ± 0% -6.81% (p=0.000 n=15) Div/20000000/10000000-88 1.320 ± 0% 1.266 ± 0% -4.04% (p=0.000 n=15) NatMul/10-88 151.9n ± 7% 143.3n ± 7% ~ (p=0.878 n=15) NatMul/100-88 4.418µ ± 2% 4.337µ ± 3% ~ (p=0.512 n=15) NatMul/1000-88 206.8µ ± 1% 189.8µ ± 1% -8.25% (p=0.000 n=15) NatMul/10000-88 8.531m ± 1% 8.095m ± 0% -5.12% (p=0.000 n=15) NatMul/100000-88 298.9m ± 0% 260.5m ± 1% -12.85% (p=0.000 n=15) NatSqr/1-88 27.55n ± 6% 28.25n ± 7% ~ (p=0.024 n=15) NatSqr/2-88 44.71n ± 6% 46.21n ± 9% ~ (p=0.024 n=15) NatSqr/3-88 55.44n ± 4% 58.41n ± 10% ~ (p=0.126 n=15) NatSqr/5-88 80.71n ± 5% 81.41n ± 5% ~ (p=0.032 n=15) NatSqr/8-88 115.7n ± 4% 115.4n ± 5% ~ (p=0.814 n=15) NatSqr/10-88 147.4n ± 4% 147.3n ± 4% ~ (p=0.505 n=15) NatSqr/20-88 337.8n ± 3% 337.3n ± 4% ~ (p=0.814 n=15) NatSqr/30-88 556.9n ± 3% 557.6n ± 4% ~ (p=0.814 n=15) NatSqr/50-88 1.208µ ± 4% 1.208µ ± 3% ~ (p=0.910 n=15) NatSqr/80-88 2.591µ ± 3% 2.581µ ± 3% ~ (p=0.705 n=15) NatSqr/100-88 3.870µ ± 3% 3.858µ ± 3% ~ (p=0.846 n=15) NatSqr/200-88 14.43µ ± 3% 14.28µ ± 2% ~ (p=0.383 n=15) NatSqr/300-88 24.68µ ± 2% 24.49µ ± 2% ~ (p=0.624 n=15) NatSqr/500-88 66.27µ ± 1% 66.18µ ± 1% ~ (p=0.735 n=15) NatSqr/800-88 128.7µ ± 1% 127.4µ ± 1% ~ (p=0.050 n=15) NatSqr/1000-88 198.7µ ± 1% 197.7µ ± 1% ~ (p=0.229 n=15) NatSqr/10000-88 6.582m ± 1% 6.426m ± 1% -2.37% (p=0.000 n=15) NatSqr/100000-88 274.3m ± 0% 267.3m ± 0% -2.57% (p=0.000 n=15) geomean 6.518µ 6.438µ -1.22% goos: linux goarch: arm64 pkg: math/big │ old │ new │ │ sec/op │ sec/op vs base │ GCD10x10/WithoutXY-16 61.70n ± 1% 61.32n ± 1% ~ (p=0.361 n=15) GCD10x10/WithXY-16 217.3n ± 1% 217.0n ± 1% ~ (p=0.395 n=15) GCD100x100/WithoutXY-16 569.7n ± 0% 572.6n ± 2% ~ (p=0.213 n=15) GCD100x100/WithXY-16 1.241µ ± 1% 1.236µ ± 1% ~ (p=0.157 n=15) GCD1000x1000/WithoutXY-16 5.558µ ± 0% 5.566µ ± 0% ~ (p=0.228 n=15) GCD1000x1000/WithXY-16 9.319µ ± 0% 9.326µ ± 0% ~ (p=0.233 n=15) GCD10000x10000/WithoutXY-16 126.4µ ± 2% 128.7µ ± 3% ~ (p=0.081 n=15) GCD10000x10000/WithXY-16 279.3µ ± 0% 278.3µ ± 5% ~ (p=0.187 n=15) Div/20/10-16 15.12n ± 1% 15.21n ± 1% ~ (p=0.490 n=15) Div/40/20-16 15.11n ± 0% 15.23n ± 1% ~ (p=0.107 n=15) Div/100/50-16 26.53n ± 0% 26.50n ± 0% ~ (p=0.299 n=15) Div/200/100-16 123.7n ± 0% 124.0n ± 0% ~ (p=0.086 n=15) Div/400/200-16 142.5n ± 0% 142.4n ± 0% ~ (p=0.039 n=15) Div/1000/500-16 259.9n ± 1% 261.2n ± 1% ~ (p=0.044 n=15) Div/2000/1000-16 539.4n ± 1% 532.3n ± 1% -1.32% (p=0.001 n=15) Div/20000/10000-16 22.43µ ± 0% 22.32µ ± 0% -0.49% (p=0.000 n=15) Div/200000/100000-16 898.3µ ± 0% 889.6µ ± 0% -0.96% (p=0.000 n=15) Div/2000000/1000000-16 38.37m ± 0% 35.11m ± 0% -8.49% (p=0.000 n=15) Div/20000000/10000000-16 1.449 ± 0% 1.384 ± 0% -4.48% (p=0.000 n=15) NatMul/10-16 182.0n ± 1% 177.8n ± 1% -2.31% (p=0.000 n=15) NatMul/100-16 5.537µ ± 0% 5.693µ ± 0% +2.82% (p=0.000 n=15) NatMul/1000-16 229.9µ ± 0% 224.8µ ± 0% -2.24% (p=0.000 n=15) NatMul/10000-16 8.985m ± 0% 8.751m ± 0% -2.61% (p=0.000 n=15) NatMul/100000-16 371.1m ± 0% 331.5m ± 0% -10.66% (p=0.000 n=15) NatSqr/1-16 46.77n ± 6% 42.76n ± 1% -8.57% (p=0.000 n=15) NatSqr/2-16 66.99n ± 4% 63.62n ± 1% -5.03% (p=0.000 n=15) NatSqr/3-16 76.79n ± 4% 73.42n ± 1% ~ (p=0.007 n=15) NatSqr/5-16 99.00n ± 3% 95.35n ± 1% -3.69% (p=0.000 n=15) NatSqr/8-16 160.0n ± 3% 155.1n ± 1% -3.06% (p=0.001 n=15) NatSqr/10-16 178.4n ± 2% 175.9n ± 0% -1.40% (p=0.001 n=15) NatSqr/20-16 361.9n ± 2% 361.3n ± 0% ~ (p=0.083 n=15) NatSqr/30-16 584.7n ± 0% 586.8n ± 0% +0.36% (p=0.000 n=15) NatSqr/50-16 1.327µ ± 0% 1.329µ ± 0% ~ (p=0.349 n=15) NatSqr/80-16 2.893µ ± 1% 2.925µ ± 0% +1.11% (p=0.000 n=15) NatSqr/100-16 4.330µ ± 1% 4.381µ ± 0% +1.18% (p=0.000 n=15) NatSqr/200-16 16.25µ ± 1% 16.43µ ± 0% +1.07% (p=0.000 n=15) NatSqr/300-16 27.85µ ± 1% 28.06µ ± 0% +0.77% (p=0.000 n=15) NatSqr/500-16 76.01µ ± 0% 76.34µ ± 0% ~ (p=0.002 n=15) NatSqr/800-16 146.8µ ± 0% 148.1µ ± 0% +0.83% (p=0.000 n=15) NatSqr/1000-16 228.2µ ± 0% 228.6µ ± 0% ~ (p=0.123 n=15) NatSqr/10000-16 7.524m ± 0% 7.426m ± 0% -1.31% (p=0.000 n=15) NatSqr/100000-16 316.7m ± 0% 309.2m ± 0% -2.36% (p=0.000 n=15) geomean 7.264µ 7.172µ -1.27% goos: darwin goarch: arm64 pkg: math/big cpu: Apple M3 Pro │ old │ new │ │ sec/op │ sec/op vs base │ GCD10x10/WithoutXY-12 32.61n ± 1% 32.42n ± 1% ~ (p=0.021 n=15) GCD10x10/WithXY-12 87.70n ± 1% 88.42n ± 1% ~ (p=0.010 n=15) GCD100x100/WithoutXY-12 305.9n ± 0% 306.4n ± 0% ~ (p=0.003 n=15) GCD100x100/WithXY-12 560.3n ± 2% 556.6n ± 1% ~ (p=0.018 n=15) GCD1000x1000/WithoutXY-12 3.509µ ± 2% 3.464µ ± 1% ~ (p=0.145 n=15) GCD1000x1000/WithXY-12 5.347µ ± 2% 5.372µ ± 1% ~ (p=0.046 n=15) GCD10000x10000/WithoutXY-12 73.75µ ± 1% 73.99µ ± 1% ~ (p=0.004 n=15) GCD10000x10000/WithXY-12 148.4µ ± 0% 147.8µ ± 1% ~ (p=0.076 n=15) Div/20/10-12 9.481n ± 0% 9.462n ± 1% ~ (p=0.631 n=15) Div/40/20-12 9.457n ± 0% 9.462n ± 1% ~ (p=0.798 n=15) Div/100/50-12 14.91n ± 0% 14.79n ± 1% -0.80% (p=0.000 n=15) Div/200/100-12 84.56n ± 1% 84.60n ± 1% ~ (p=0.271 n=15) Div/400/200-12 103.8n ± 0% 102.8n ± 0% -0.96% (p=0.000 n=15) Div/1000/500-12 181.3n ± 1% 184.2n ± 2% ~ (p=0.091 n=15) Div/2000/1000-12 397.5n ± 0% 397.4n ± 0% ~ (p=0.299 n=15) Div/20000/10000-12 14.04µ ± 1% 13.99µ ± 0% ~ (p=0.221 n=15) Div/200000/100000-12 523.1µ ± 0% 514.0µ ± 3% ~ (p=0.775 n=15) Div/2000000/1000000-12 21.58m ± 0% 20.01m ± 1% -7.29% (p=0.000 n=15) Div/20000000/10000000-12 813.5m ± 0% 796.2m ± 1% -2.13% (p=0.000 n=15) NatMul/10-12 80.46n ± 1% 80.02n ± 1% ~ (p=0.063 n=15) NatMul/100-12 2.904µ ± 0% 2.979µ ± 1% +2.58% (p=0.000 n=15) NatMul/1000-12 127.8µ ± 0% 122.3µ ± 0% -4.28% (p=0.000 n=15) NatMul/10000-12 5.141m ± 0% 4.975m ± 1% -3.23% (p=0.000 n=15) NatMul/100000-12 208.8m ± 0% 189.6m ± 3% -9.21% (p=0.000 n=15) NatSqr/1-12 11.90n ± 1% 11.76n ± 1% ~ (p=0.059 n=15) NatSqr/2-12 21.33n ± 1% 21.12n ± 0% ~ (p=0.063 n=15) NatSqr/3-12 26.05n ± 1% 25.79n ± 0% ~ (p=0.002 n=15) NatSqr/5-12 37.31n ± 0% 36.98n ± 1% ~ (p=0.008 n=15) NatSqr/8-12 63.07n ± 0% 62.75n ± 1% ~ (p=0.061 n=15) NatSqr/10-12 79.48n ± 0% 79.59n ± 0% ~ (p=0.455 n=15) NatSqr/20-12 173.1n ± 0% 173.2n ± 1% ~ (p=0.518 n=15) NatSqr/30-12 288.6n ± 1% 289.2n ± 0% ~ (p=0.030 n=15) NatSqr/50-12 653.3n ± 0% 653.3n ± 0% ~ (p=0.361 n=15) NatSqr/80-12 1.492µ ± 0% 1.496µ ± 0% ~ (p=0.018 n=15) NatSqr/100-12 2.270µ ± 1% 2.270µ ± 0% ~ (p=0.326 n=15) NatSqr/200-12 8.776µ ± 1% 8.784µ ± 1% ~ (p=0.083 n=15) NatSqr/300-12 15.07µ ± 0% 15.09µ ± 0% ~ (p=0.455 n=15) NatSqr/500-12 41.71µ ± 0% 41.77µ ± 1% ~ (p=0.305 n=15) NatSqr/800-12 80.77µ ± 1% 80.59µ ± 0% ~ (p=0.113 n=15) NatSqr/1000-12 126.4µ ± 1% 126.5µ ± 0% ~ (p=0.683 n=15) NatSqr/10000-12 4.204m ± 0% 4.119m ± 0% -2.02% (p=0.000 n=15) NatSqr/100000-12 177.0m ± 0% 172.9m ± 0% -2.31% (p=0.000 n=15) geomean 3.790µ 3.757µ -0.87% Change-Id: Ifc7a9b61f678df216690511ac8bb9143189a795e Reviewed-on: https://go-review.googlesource.com/c/go/+/652057 Auto-Submit: Russ Cox LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Griesemer --- diff --git a/src/math/big/int.go b/src/math/big/int.go index cb7221250d..4abfd19278 100644 --- a/src/math/big/int.go +++ b/src/math/big/int.go @@ -181,10 +181,14 @@ func (z *Int) Sub(x, y *Int) *Int { // Mul sets z to the product x*y and returns z. func (z *Int) Mul(x, y *Int) *Int { - return z.mul(nil, x, y) + z.mul(nil, x, y) + return z } -func (z *Int) mul(stk *stack, x, y *Int) *Int { +// mul is like Mul but takes an explicit stack to use, for internal use. +// It does not return a *Int because doing so makes the stack-allocated Ints +// used in natmul.go escape to the heap (even though the result is unused). +func (z *Int) mul(stk *stack, x, y *Int) { // x * y == x * y // x * (-y) == -(x * y) // (-x) * y == -(x * y) @@ -192,11 +196,10 @@ func (z *Int) mul(stk *stack, x, y *Int) *Int { if x == y { z.abs = z.abs.sqr(stk, x.abs) z.neg = false - return z + return } z.abs = z.abs.mul(stk, x.abs, y.abs) z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign - return z } // MulRange sets z to the product of all integers diff --git a/src/math/big/nat.go b/src/math/big/nat.go index a091608f3e..922cdb4306 100644 --- a/src/math/big/nat.go +++ b/src/math/big/nat.go @@ -227,15 +227,14 @@ func alias(x, y nat) bool { return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] } -// addAt implements z += x<<(_W*i); z must be long enough. +// addTo implements z += x; z must be long enough. // (we don't use nat.add because we need z to stay the same // slice, and we don't need to normalize z after each addition) -func addAt(z, x nat, i int) { +func addTo(z, x nat) { if n := len(x); n > 0 { - if c := addVV(z[i:i+n], z[i:], x); c != 0 { - j := i + n - if j < len(z) { - addVW(z[j:], z[j:], c) + if c := addVV(z[:n], z, x); c != 0 { + if n < len(z) { + addVW(z[n:], z[n:], c) } } } diff --git a/src/math/big/nat_test.go b/src/math/big/nat_test.go index 1811dccfe3..251877b506 100644 --- a/src/math/big/nat_test.go +++ b/src/math/big/nat_test.go @@ -7,6 +7,8 @@ package big import ( "fmt" "math" + "math/bits" + "math/rand/v2" "runtime" "strings" "testing" @@ -56,12 +58,84 @@ var sumNN = []argNN{ {nat{0, 0, 0, 1}, nat{0, 0, _M}, nat{0, 0, 1}}, } -var prodNN = []argNN{ - {}, - {nil, nil, nil}, +var prodNN = append(prodTests(), prodNNExtra...) + +func permute[E any](x []E) { + out := make([]E, len(x)) + for i, j := range rand.Perm(len(x)) { + out[i] = x[j] + } + copy(x, out) +} + +// testMul returns the product of x and y using the grade-school algorithm, +// as a reference implementation. +func testMul(x, y nat) nat { + z := make(nat, len(x)+len(y)) + for i, xi := range x { + for j, yj := range y { + hi, lo := bits.Mul(uint(xi), uint(yj)) + k := i + j + s, c := bits.Add(uint(z[k]), lo, 0) + z[k] = Word(s) + k++ + for hi != 0 || c != 0 { + s, c = bits.Add(uint(z[k]), hi, c) + hi = 0 + z[k] = Word(s) + k++ + } + } + } + return z.norm() +} + +func prodTests() []argNN { + var tests []argNN + for size := range 10 { + var x, y nat + for i := range size { + x = append(x, Word(i+1)) + y = append(y, Word(i+1+size)) + } + permute(x) + permute(y) + x = x.norm() + y = y.norm() + tests = append(tests, argNN{testMul(x, y), x, y}) + } + + words := []Word{0, 1, 2, 3, 4, ^Word(0), ^Word(1), ^Word(2), ^Word(3)} + for size := range 10 { + if size == 0 { + continue // already tested the only 0-length possibility above + } + for range 10 { + x := make(nat, size) + y := make(nat, size) + for i := range size { + x[i] = words[rand.N(len(words))] + y[i] = words[rand.N(len(words))] + } + x = x.norm() + y = y.norm() + tests = append(tests, argNN{testMul(x, y), x, y}) + } + } + return tests +} + +var prodNNExtra = []argNN{ {nil, nat{991}, nil}, {nat{991}, nat{991}, nat{1}}, {nat{991 * 991}, nat{991}, nat{991}}, + {nat{8, 22, 15}, nat{2, 3}, nat{4, 5}}, + {nat{10, 27, 52, 45, 28}, nat{2, 3, 4}, nat{5, 6, 7}}, + {nat{12, 32, 61, 100, 94, 76, 45}, nat{2, 3, 4, 5}, nat{6, 7, 8, 9}}, + {nat{12, 32, 61, 100, 94, 76, 45}, nat{2, 3, 4, 5}, nat{6, 7, 8, 9}}, + {nat{14, 37, 70, 114, 170, 166, 148, 115, 66}, nat{2, 3, 4, 5, 6}, nat{7, 8, 9, 10, 11}}, + {nat{991 * 991, 991 * 2, 1}, nat{991, 1}, nat{991, 1}}, + {nat{991 * 991, 991 * 777 * 2, 777 * 777}, nat{991, 777}, nat{991, 777}}, {nat{0, 0, 991 * 991}, nat{0, 991}, nat{0, 991}}, {nat{1 * 991, 2 * 991, 3 * 991, 4 * 991}, nat{1, 2, 3, 4}, nat{991}}, {nat{4, 11, 20, 30, 20, 11, 4}, nat{1, 2, 3, 4}, nat{4, 3, 2, 1}}, @@ -114,38 +188,113 @@ func testFunNN(t *testing.T, msg string, f funNN, a argNN) { } func testFunSNN(t *testing.T, msg string, f funSNN, a argNN) { + t.Helper() stk := getStack() defer stk.free() z := f(nil, stk, a.x, a.y) if z.cmp(a.z) != 0 { - t.Errorf("%s%+v\n\tgot z = %v; want %v", msg, a, z, a.z) + t.Fatalf("%s%+v\n\tgot z = %v; want %v", msg, a, z, a.z) } } -func TestFunNN(t *testing.T) { - for _, a := range sumNN { - arg := a - testFunNN(t, "add", nat.add, arg) +func setDuringTest[V any](t *testing.T, p *V, v V) { + old := *p + *p = v + t.Cleanup(func() { *p = old }) +} - arg = argNN{a.z, a.y, a.x} - testFunNN(t, "add symmetric", nat.add, arg) +func TestAdd(t *testing.T) { + for _, a := range sumNN { + testFunNN(t, "add", nat.add, a) + a.x, a.y = a.y, a.x + testFunNN(t, "add", nat.add, a) + } +} - arg = argNN{a.x, a.z, a.y} - testFunNN(t, "sub", nat.sub, arg) +func TestSub(t *testing.T) { + for _, a := range sumNN { + a.x, a.z = a.z, a.x + testFunNN(t, "sub", nat.sub, a) - arg = argNN{a.y, a.z, a.x} - testFunNN(t, "sub symmetric", nat.sub, arg) + a.y, a.z = a.z, a.y + testFunNN(t, "sub", nat.sub, a) } +} - for _, a := range prodNN { - arg := a - testFunSNN(t, "mul", nat.mul, arg) +func TestNatMul(t *testing.T) { + t.Run("Basic", func(t *testing.T) { + setDuringTest(t, &karatsubaThreshold, 1e9) + for _, a := range prodNN { + if len(a.z) >= 100 { + continue + } + testFunSNN(t, "mul", nat.mul, a) + a.x, a.y = a.y, a.x + testFunSNN(t, "mul", nat.mul, a) + } + }) + t.Run("Karatsuba", func(t *testing.T) { + setDuringTest(t, &karatsubaThreshold, 2) + for _, a := range prodNN { + testFunSNN(t, "mul", nat.mul, a) + a.x, a.y = a.y, a.x + testFunSNN(t, "mul", nat.mul, a) + } + }) - arg = argNN{a.z, a.y, a.x} - testFunSNN(t, "mul symmetric", nat.mul, arg) + t.Run("Mul", func(t *testing.T) { + for _, a := range prodNN { + testFunSNN(t, "mul", nat.mul, a) + a.x, a.y = a.y, a.x + testFunSNN(t, "mul", nat.mul, a) + } + }) +} + +func testSqr(t *testing.T, x nat) { + stk := getStack() + defer stk.free() + + got := make(nat, 2*len(x)) + want := make(nat, 2*len(x)) + got = got.sqr(stk, x) + want = want.mul(stk, x, x) + if got.cmp(want) != 0 { + t.Errorf("basicSqr(%v), got %v, want %v", x, got, want) } } +func TestNatSqr(t *testing.T) { + t.Run("Basic", func(t *testing.T) { + setDuringTest(t, &basicSqrThreshold, 0) + setDuringTest(t, &karatsubaSqrThreshold, 1e9) + for _, a := range prodNN { + if len(a.z) >= 100 { + continue + } + testSqr(t, a.x) + testSqr(t, a.y) + testSqr(t, a.z) + } + }) + t.Run("Karatsuba", func(t *testing.T) { + setDuringTest(t, &basicSqrThreshold, 2) + setDuringTest(t, &karatsubaSqrThreshold, 2) + for _, a := range prodNN { + testSqr(t, a.x) + testSqr(t, a.y) + testSqr(t, a.z) + } + }) + t.Run("Sqr", func(t *testing.T) { + for _, a := range prodNN { + testSqr(t, a.x) + testSqr(t, a.y) + testSqr(t, a.z) + } + }) +} + var mulRangesN = []struct { a, b uint64 prod string @@ -739,33 +888,6 @@ func TestSticky(t *testing.T) { } } -func testSqr(t *testing.T, x nat) { - stk := getStack() - defer stk.free() - - got := make(nat, 2*len(x)) - want := make(nat, 2*len(x)) - got = got.sqr(stk, x) - want = want.mul(stk, x, x) - if got.cmp(want) != 0 { - t.Errorf("basicSqr(%v), got %v, want %v", x, got, want) - } -} - -func TestSqr(t *testing.T) { - for _, a := range prodNN { - if a.x != nil { - testSqr(t, a.x) - } - if a.y != nil { - testSqr(t, a.y) - } - if a.z != nil { - testSqr(t, a.z) - } - } -} - func benchmarkNatSqr(b *testing.B, nwords int) { x := rndNat(nwords) var z nat diff --git a/src/math/big/natdiv.go b/src/math/big/natdiv.go index 7f6a1bbb07..b67d6afeda 100644 --- a/src/math/big/natdiv.go +++ b/src/math/big/natdiv.go @@ -831,7 +831,7 @@ func (z nat) divRecursiveStep(stk *stack, u, v nat, depth int) { if len(qhatv) > s { subVW(qhatv[s:], qhatv[s:], c) } - addAt(uu[s:], v[s:], 0) + addTo(uu[s:], v[s:]) } if qhatv.cmp(uu.norm()) > 0 { panic("impossible") @@ -840,7 +840,7 @@ func (z nat) divRecursiveStep(stk *stack, u, v nat, depth int) { if c > 0 { subVW(uu[len(qhatv):], uu[len(qhatv):], c) } - addAt(z, qhat, j-B) + addTo(z[j-B:], qhat) j -= B stk.restore(mark) } @@ -865,7 +865,7 @@ func (z nat) divRecursiveStep(stk *stack, u, v nat, depth int) { if len(qhatv) > s { subVW(qhatv[s:], qhatv[s:], c) } - addAt(u[s:], v[s:], 0) + addTo(u[s:], v[s:]) } } if qhatv.cmp(u.norm()) > 0 { @@ -880,5 +880,5 @@ func (z nat) divRecursiveStep(stk *stack, u, v nat, depth int) { } // Done! - addAt(z, qhat.norm(), 0) + addTo(z, qhat.norm()) } diff --git a/src/math/big/natmul.go b/src/math/big/natmul.go index 001e78d2c1..8ab4d13cba 100644 --- a/src/math/big/natmul.go +++ b/src/math/big/natmul.go @@ -31,71 +31,31 @@ func (z nat) mul(stk *stack, x, y nat) nat { if alias(z, x) || alias(z, y) { z = nil // z is an alias for x or y - cannot reuse } + z = z.make(m + n) // use basic multiplication if the numbers are small if n < karatsubaThreshold { - z = z.make(m + n) basicMul(z, x, y) return z.norm() } - // m >= n && n >= karatsubaThreshold && n >= 2 - - // determine Karatsuba length k such that - // - // x = xh*b + x0 (0 <= x0 < b) - // y = yh*b + y0 (0 <= y0 < b) - // b = 1<<(_W*k) ("base" of digits xi, yi) - // - k := karatsubaLen(n, karatsubaThreshold) - // k <= n if stk == nil { stk = getStack() defer stk.free() } - // multiply x0 and y0 via Karatsuba - x0 := x[0:k] // x0 is not normalized - y0 := y[0:k] // y0 is not normalized - z = z.make(m + n) // enough space for full result of x*y - karatsuba(stk, z, x0, y0) - clear(z[2*k:]) // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m) + // Let x = x1:x0 where x0 is the same length as y. + // Compute z = x0*y and then add in x1*y in sections + // if needed. + karatsuba(stk, z[:2*n], x[:n], y) - // If xh != 0 or yh != 0, add the missing terms to z. For - // - // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b) - // yh = y1*b (0 <= y1 < b) - // - // the missing terms are - // - // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0 - // - // since all the yi for i > 1 are 0 by choice of k: If any of them - // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would - // be a larger valid threshold contradicting the assumption about k. - // - if k < n || m != n { + if n < m { + clear(z[2*n:]) defer stk.restore(stk.save()) - t := stk.nat(3 * k) - - // add x0*y1*b - x0 := x0.norm() - y1 := y[k:] // y1 is normalized because y is - t = t.mul(stk, x0, y1) // update t so we don't lose t's underlying array - addAt(z, t, k) - - // add xi*y0< k { - xi = xi[:k] - } - xi = xi.norm() - t = t.mul(stk, xi, y0) - addAt(z, t, i) - t = t.mul(stk, xi, y1) - addAt(z, t, i+k) + t := stk.nat(2 * n) + for i := n; i < m; i += n { + t = t.mul(stk, x[i:min(i+n, len(x))], y) + addTo(z[i:], t) } } @@ -142,28 +102,7 @@ func (z nat) sqr(stk *stack, x nat) nat { return z.norm() } - // Use Karatsuba multiplication optimized for x == y. - // The algorithm and layout of z are the same as for mul. - - // z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2 - - k := karatsubaLen(n, karatsubaSqrThreshold) - - x0 := x[0:k] - karatsubaSqr(stk, z, x0) // z = x0^2 - clear(z[2*k:]) - - if k < n { - t := stk.nat(2 * k) - x0 := x0.norm() - x1 := x[k:] - t = t.mul(stk, x0, x1) - addAt(z, t, k) - addAt(z, t, k) // z = 2*x1*x0*b + x0^2 - t = t.sqr(stk, x1) - addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2 - } - + karatsubaSqr(stk, z, x) return z.norm() } @@ -188,6 +127,7 @@ func basicSqr(stk *stack, z, x nat) { addVV(z, z, t) // combine the result } +// mulAddWW returns z = x*y + r. func (z nat) mulAddWW(x nat, y, r Word) nat { m := len(x) if m == 0 || y == 0 { @@ -212,155 +152,199 @@ func basicMul(z, x, y nat) { } } -// karatsubaLen computes an approximation to the maximum k <= n such that -// k = p<= 0. Thus, the -// result is the largest number that can be divided repeatedly by 2 before -// becoming about the value of threshold. -func karatsubaLen(n, threshold int) int { - i := uint(0) - for n > threshold { - n >>= 1 - i++ - } - return n << i -} - -// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks. -// Factored out for readability - do not use outside karatsuba. -func karatsubaAdd(z, x nat, n int) { - if c := addVV(z[0:n], z, x); c != 0 { - addVW(z[n:n+n>>1], z[n:], c) - } -} - -// Like karatsubaAdd, but does subtract. -func karatsubaSub(z, x nat, n int) { - if c := subVV(z[0:n], z, x); c != 0 { - subVW(z[n:n+n>>1], z[n:], c) - } -} - -// karatsuba multiplies x and y and leaves the result in z. -// Both x and y must have the same length n and n must be a -// power of 2. The result vector z must have len(z) == len(x)+len(y). -// The (non-normalized) result is placed in z. +// karatsuba multiplies x and y, +// writing the (non-normalized) result to z. +// x and y must have the same length n, +// and z must have length twice that. func karatsuba(stk *stack, z, x, y nat) { n := len(y) + if len(x) != n || len(z) != 2*n { + panic("bad karatsuba length") + } - // Switch to basic multiplication if numbers are odd or small. - // (n is always even if karatsubaThreshold is even, but be - // conservative) - if n&1 != 0 || n < karatsubaThreshold || n < 2 { + // Fall back to basic algorithm if small enough. + if n < karatsubaThreshold || n < 2 { basicMul(z, x, y) return } - // n&1 == 0 && n >= karatsubaThreshold && n >= 2 - // Karatsuba multiplication is based on the observation that - // for two numbers x and y with: + // Let the notation x1:x0 denote the nat (x1<> 1 // n2 >= 1 - x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0 - y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0 - - // compute z0 and z2 with the result "in place" in z - karatsuba(stk, z, x0, y0) // z0 = x0*y0 - karatsuba(stk, z[n:], x1, y1) // z2 = x1*y1 - - // compute xd, yd (or the negative value if underflow occurs) - s := 1 // sign of product xd*yd - defer stk.restore(stk.save()) - xd := stk.nat(n2) - yd := stk.nat(n2) - if subVV(xd, x1, x0) != 0 { // x1-x0 - s = -s - subVV(xd, x0, x1) // x0-x1 - } - if subVV(yd, y0, y1) != 0 { // y0-y1 - s = -s - subVV(yd, y1, y0) // y1-y0 - } - - // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0 - // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0 - p := stk.nat(2 * n2) - karatsuba(stk, p, xd, yd) + // z0 = x0*y0 + // z2 = x1*y1 + // z1 = (x0-x1)*(y1-y0) + z0 + z2 - // save original z2:z0 - // (ok to use upper half of z since we're done recurring) - r := stk.nat(n * 2) - copy(r, z[:n*2]) + n2 := (n + 1) / 2 + x0, x1 := &Int{abs: x[:n2].norm()}, &Int{abs: x[n2:].norm()} + y0, y1 := &Int{abs: y[:n2].norm()}, &Int{abs: y[n2:].norm()} + z0 := &Int{abs: z[0 : 2*n2]} + z2 := &Int{abs: z[2*n2:]} - // add up all partial products - // - // 2*n n 0 - // z = [ z2 | z0 ] - // + [ z0 ] - // + [ z2 ] - // + [ p ] - // - karatsubaAdd(z[n2:], r, n) - karatsubaAdd(z[n2:], r[n:], n) - if s > 0 { - karatsubaAdd(z[n2:], p, n) - } else { - karatsubaSub(z[n2:], p, n) + // Allocate temporary storage for z1; repurpose z0 to hold tx and ty. + defer stk.restore(stk.save()) + z1 := &Int{abs: stk.nat(2*n2 + 1)} + tx := &Int{abs: z[0:n2]} + ty := &Int{abs: z[n2 : 2*n2]} + + tx.Sub(x0, x1) + ty.Sub(y1, y0) + z1.mul(stk, tx, ty) + + clear(z) + z0.mul(stk, x0, y0) + z2.mul(stk, x1, y1) + z1.Add(z1, z0) + z1.Add(z1, z2) + addTo(z[n2:], z1.abs) + + // Debug mode: double-check answer and print trace on failure. + const debug = false + if debug { + zz := make(nat, len(z)) + basicMul(zz, x, y) + if z.cmp(zz) != 0 { + // All the temps were aliased to z and gone. Recompute. + z0 = new(Int) + z0.mul(stk, x0, y0) + tx = new(Int).Sub(x1, x0) + ty = new(Int).Sub(y0, y1) + z2 = new(Int) + z2.mul(stk, x1, y1) + print("karatsuba wrong\n") + trace("x ", &Int{abs: x}) + trace("y ", &Int{abs: y}) + trace("z ", &Int{abs: z}) + trace("zz", &Int{abs: zz}) + trace("x0", x0) + trace("x1", x1) + trace("y0", y0) + trace("y1", y1) + trace("tx", tx) + trace("ty", ty) + trace("z0", z0) + trace("z1", z1) + trace("z2", z2) + panic("karatsuba") + } } + } -// karatsubaSqr squares x and leaves the result in z. -// len(x) must be a power of 2 and len(z) == 2*len(x). -// The (non-normalized) result is placed in z. -// -// The algorithm and the layout of z are the same as for karatsuba. +// karatsubaSqr squares x, +// writing the (non-normalized) result to z. +// z must have length 2*len(x). +// It is analogous to [karatsuba] but can run faster +// knowing both multiplicands are the same value. func karatsubaSqr(stk *stack, z, x nat) { n := len(x) + if len(z) != 2*n { + panic("bad karatsubaSqr length") + } - if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 { - basicSqr(stk, z[:2*n], x) + if n < karatsubaSqrThreshold || n < 2 { + basicSqr(stk, z, x) return } - n2 := n >> 1 - x1, x0 := x[n2:], x[0:n2] + // Recall that for karatsuba we want to compute: + // + // x1:x0 * y1:y0 = x1y1:(x0y1+y0x1):x0y0 + // = x1y1:((x0-x1)*(y1-y0)+x1y1+x0y0):x0y0 + // = z2:z1:z0 + // where: + // + // z0 = x0y0 + // z2 = x1y1 + // z1 = (x0-x1)*(y1-y0) + z0 + z2 + // + // When x = y, these simplify to: + // + // z0 = x0² + // z2 = x1² + // z1 = z0 + z2 - (x0-x1)² - karatsubaSqr(stk, z, x0) - karatsubaSqr(stk, z[n:], x1) + n2 := (n + 1) / 2 + x0, x1 := &Int{abs: x[:n2].norm()}, &Int{abs: x[n2:].norm()} + z0 := &Int{abs: z[0 : 2*n2]} + z2 := &Int{abs: z[2*n2:]} - // s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0 + // Allocate temporary storage for z1; repurpose z0 to hold tx. defer stk.restore(stk.save()) - p := stk.nat(2 * n2) - r := stk.nat(n * 2) - xd := r[:n2] - if subVV(xd, x1, x0) != 0 { - subVV(xd, x0, x1) + z1 := &Int{abs: stk.nat(2*n2 + 1)} + tx := &Int{abs: z[0:n2]} + + tx.Sub(x0, x1) + z1.abs = z1.abs.sqr(stk, tx.abs) + z1.neg = true + + clear(z) + z0.abs = z0.abs.sqr(stk, x0.abs) + z2.abs = z2.abs.sqr(stk, x1.abs) + z1.Add(z1, z0) + z1.Add(z1, z2) + addTo(z[n2:], z1.abs) + + // Debug mode: double-check answer and print trace on failure. + const debug = false + if debug { + zz := make(nat, len(z)) + basicSqr(stk, zz, x) + if z.cmp(zz) != 0 { + // All the temps were aliased to z and gone. Recompute. + tx = new(Int).Sub(x0, x1) + z0 = new(Int).Mul(x0, x0) + z2 = new(Int).Mul(x1, x1) + z1 = new(Int).Mul(tx, tx) + z1.Neg(z1) + z1.Add(z1, z0) + z1.Add(z1, z2) + print("karatsubaSqr wrong\n") + trace("x ", &Int{abs: x}) + trace("z ", &Int{abs: z}) + trace("zz", &Int{abs: zz}) + trace("x0", x0) + trace("x1", x1) + trace("z0", z0) + trace("z1", z1) + trace("z2", z2) + panic("karatsubaSqr") + } + } +} + +// ifmt returns the debug formatting of the Int x: 0xHEX. +func ifmt(x *Int) string { + neg, s, t := "", x.Text(16), "" + if s == "" { // happens for denormalized zero + s = "0x0" + } + if s[0] == '-' { + neg, s = "-", s[1:] } - karatsubaSqr(stk, p, xd) - copy(r, z[:n*2]) + // Add _ between words. + const D = _W / 4 // digits per chunk + for len(s) > D { + s, t = s[:len(s)-D], s[len(s)-D:]+"_"+t + } + return neg + s + t +} - karatsubaAdd(z[n2:], r, n) - karatsubaAdd(z[n2:], r[n:], n) - karatsubaSub(z[n2:], p, n) // s == -1 for p != 0; s == 1 for p == 0 +// trace prints a single debug value. +func trace(name string, x *Int) { + print(name, "=", ifmt(x), "\n") }