From 1dc20e91247d6013dd5c299ffe2fde45524decba Mon Sep 17 00:00:00 2001 From: =?utf8?q?Alexander=20D=C3=B6ring?= Date: Fri, 11 May 2018 20:06:53 +0200 Subject: [PATCH] math/big: specialize Karatsuba implementation for squaring MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Currently we use three different algorithms for squaring: 1. basic multiplication for small numbers 2. basic squaring for medium numbers 3. Karatsuba multiplication for large numbers Change 3. to a version of Karatsuba multiplication specialized for x == y. Increasing the performance of 3. lets us lower the threshold between 2. and 3. Adapt TestCalibrate to the change that 3. isn't independent of the threshold between 1. and 2. any more. Fixes #23221. benchstat old.txt new.txt name old time/op new time/op delta NatSqr/1-4 29.6ns ± 7% 29.5ns ± 5% ~ (p=0.103 n=50+50) NatSqr/2-4 51.9ns ± 1% 51.9ns ± 1% ~ (p=0.693 n=42+49) NatSqr/3-4 64.3ns ± 1% 64.1ns ± 0% -0.26% (p=0.000 n=46+43) NatSqr/5-4 93.5ns ± 2% 93.1ns ± 1% -0.39% (p=0.000 n=48+49) NatSqr/8-4 131ns ± 1% 131ns ± 1% ~ (p=0.870 n=46+49) NatSqr/10-4 175ns ± 1% 175ns ± 1% +0.38% (p=0.000 n=49+47) NatSqr/20-4 426ns ± 1% 429ns ± 1% +0.84% (p=0.000 n=46+48) NatSqr/30-4 702ns ± 2% 699ns ± 1% -0.38% (p=0.011 n=46+44) NatSqr/50-4 1.44µs ± 2% 1.43µs ± 1% -0.54% (p=0.010 n=48+48) NatSqr/80-4 2.85µs ± 1% 2.87µs ± 1% +0.68% (p=0.000 n=47+47) NatSqr/100-4 4.06µs ± 1% 4.07µs ± 1% +0.29% (p=0.000 n=46+45) NatSqr/200-4 13.4µs ± 1% 13.5µs ± 1% +0.73% (p=0.000 n=48+48) NatSqr/300-4 28.5µs ± 1% 28.2µs ± 1% -1.22% (p=0.000 n=46+48) NatSqr/500-4 81.9µs ± 1% 67.0µs ± 1% -18.25% (p=0.000 n=48+48) NatSqr/800-4 161µs ± 1% 140µs ± 1% -13.29% (p=0.000 n=47+48) NatSqr/1000-4 245µs ± 1% 207µs ± 1% -15.17% (p=0.000 n=49+49) go test -v -calibrate --run TestCalibrate ... Calibrating threshold between basicSqr(x) and karatsubaSqr(x) Looking for a timing difference for x between 200 - 500 words by 10 step words = 200 deltaT = -980ns ( -7%) is karatsubaSqr(x) better: false words = 210 deltaT = -773ns ( -5%) is karatsubaSqr(x) better: false words = 220 deltaT = -695ns ( -4%) is karatsubaSqr(x) better: false words = 230 deltaT = -570ns ( -3%) is karatsubaSqr(x) better: false words = 240 deltaT = -458ns ( -2%) is karatsubaSqr(x) better: false words = 250 deltaT = -63ns ( 0%) is karatsubaSqr(x) better: false words = 260 deltaT = 118ns ( 0%) is karatsubaSqr(x) better: true threshold found words = 270 deltaT = 377ns ( 1%) is karatsubaSqr(x) better: true words = 280 deltaT = 765ns ( 3%) is karatsubaSqr(x) better: true words = 290 deltaT = 673ns ( 2%) is karatsubaSqr(x) better: true words = 300 deltaT = 502ns ( 1%) is karatsubaSqr(x) better: true words = 310 deltaT = 629ns ( 2%) is karatsubaSqr(x) better: true words = 320 deltaT = 1.011µs ( 3%) is karatsubaSqr(x) better: true words = 330 deltaT = 1.36µs ( 4%) is karatsubaSqr(x) better: true words = 340 deltaT = 3.001µs ( 8%) is karatsubaSqr(x) better: true words = 350 deltaT = 3.178µs ( 8%) is karatsubaSqr(x) better: true ... Change-Id: I6f13c23d94d042539ac28e77fd2618cdc37a429e Reviewed-on: https://go-review.googlesource.com/105075 Run-TryBot: Robert Griesemer TryBot-Result: Gobot Gobot Reviewed-by: Robert Griesemer --- src/math/big/calibrate_test.go | 65 +++++++++++++++------------ src/math/big/nat.go | 81 +++++++++++++++++++++++++++++----- src/math/big/nat_test.go | 14 +++--- 3 files changed, 115 insertions(+), 45 deletions(-) diff --git a/src/math/big/calibrate_test.go b/src/math/big/calibrate_test.go index 2b96e74a65..4fa663ff08 100644 --- a/src/math/big/calibrate_test.go +++ b/src/math/big/calibrate_test.go @@ -28,24 +28,32 @@ import ( var calibrate = flag.Bool("calibrate", false, "run calibration test") +const ( + sqrModeMul = "mul(x, x)" + sqrModeBasic = "basicSqr(x)" + sqrModeKaratsuba = "karatsubaSqr(x)" +) + func TestCalibrate(t *testing.T) { - if *calibrate { - computeKaratsubaThresholds() - - // compute basicSqrThreshold where overhead becomes negligible - minSqr := computeSqrThreshold(10, 30, 1, 3) - // compute karatsubaSqrThreshold where karatsuba is faster - maxSqr := computeSqrThreshold(300, 500, 10, 3) - if minSqr != 0 { - fmt.Printf("found basicSqrThreshold = %d\n", minSqr) - } else { - fmt.Println("no basicSqrThreshold found") - } - if maxSqr != 0 { - fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr) - } else { - fmt.Println("no karatsubaSqrThreshold found") - } + if !*calibrate { + return + } + + computeKaratsubaThresholds() + + // compute basicSqrThreshold where overhead becomes negligible + minSqr := computeSqrThreshold(10, 30, 1, 3, sqrModeMul, sqrModeBasic) + // compute karatsubaSqrThreshold where karatsuba is faster + maxSqr := computeSqrThreshold(200, 500, 10, 3, sqrModeBasic, sqrModeKaratsuba) + if minSqr != 0 { + fmt.Printf("found basicSqrThreshold = %d\n", minSqr) + } else { + fmt.Println("no basicSqrThreshold found") + } + if maxSqr != 0 { + fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr) + } else { + fmt.Println("no karatsubaSqrThreshold found") } } @@ -109,16 +117,17 @@ func computeKaratsubaThresholds() { } } -func measureBasicSqr(words, nruns int, enable bool) time.Duration { +func measureSqr(words, nruns int, mode string) time.Duration { // more runs for better statistics initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold - if enable { - // set thresholds to use basicSqr at this number of words + switch mode { + case sqrModeMul: + basicSqrThreshold = words + 1 + case sqrModeBasic: basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1 - } else { - // set thresholds to disable basicSqr for any number of words - basicSqrThreshold, karatsubaSqrThreshold = -1, -1 + case sqrModeKaratsuba: + karatsubaSqrThreshold = words - 1 } var testval int64 @@ -133,18 +142,18 @@ func measureBasicSqr(words, nruns int, enable bool) time.Duration { return time.Duration(testval) } -func computeSqrThreshold(from, to, step, nruns int) int { - fmt.Println("Calibrating thresholds for basicSqr via benchmarks of z.mul(x,x)") +func computeSqrThreshold(from, to, step, nruns int, lower, upper string) int { + fmt.Printf("Calibrating threshold between %s and %s\n", lower, upper) fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step) var initPos bool var threshold int for i := from; i <= to; i += step { - baseline := measureBasicSqr(i, nruns, false) - testval := measureBasicSqr(i, nruns, true) + baseline := measureSqr(i, nruns, lower) + testval := measureSqr(i, nruns, upper) pos := baseline > testval delta := baseline - testval percent := delta * 100 / baseline - fmt.Printf("words = %3d deltaT = %10s (%4d%%) is basicSqr better: %v", i, delta, percent, pos) + fmt.Printf("words = %3d deltaT = %10s (%4d%%) is %s better: %v", i, delta, percent, upper, pos) if i == from { initPos = pos } diff --git a/src/math/big/nat.go b/src/math/big/nat.go index 9ec81270a3..dc292b4e7c 100644 --- a/src/math/big/nat.go +++ b/src/math/big/nat.go @@ -388,12 +388,12 @@ func max(x, y int) int { } // karatsubaLen computes an approximation to the maximum k <= n such that -// k = p<= 0. Thus, the +// k = p<= 0. Thus, the // result is the largest number that can be divided repeatedly by 2 before -// becoming about the value of karatsubaThreshold. -func karatsubaLen(n int) int { +// becoming about the value of threshold. +func karatsubaLen(n, threshold int) int { i := uint(0) - for n > karatsubaThreshold { + for n > threshold { n >>= 1 i++ } @@ -433,7 +433,7 @@ func (z nat) mul(x, y nat) nat { // y = yh*b + y0 (0 <= y0 < b) // b = 1<<(_W*k) ("base" of digits xi, yi) // - k := karatsubaLen(n) + k := karatsubaLen(n, karatsubaThreshold) // k <= n // multiply x0 and y0 via Karatsuba @@ -486,8 +486,8 @@ func (z nat) mul(x, y nat) nat { // basicSqr sets z = x*x and is asymptotically faster than basicMul // by about a factor of 2, but slower for small arguments due to overhead. -// Requirements: len(x) > 0, len(z) >= 2*len(x) -// The (non-normalized) result is placed in z[0 : 2 * len(x)]. +// Requirements: len(x) > 0, len(z) == 2*len(x) +// The (non-normalized) result is placed in z. func basicSqr(z, x nat) { n := len(x) t := make(nat, 2*n) // temporary variable to hold the products @@ -503,11 +503,48 @@ func basicSqr(z, x nat) { addVV(z, z, t) // combine the result } +// karatsubaSqr squares x and leaves the result in z. +// len(x) must be a power of 2 and len(z) >= 6*len(x). +// The (non-normalized) result is placed in z[0 : 2*len(x)]. +// +// The algorithm and the layout of z are the same as for karatsuba. +func karatsubaSqr(z, x nat) { + n := len(x) + + if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 { + z = z[:2*n] + basicSqr(z, x) + return + } + + n2 := n >> 1 + x1, x0 := x[n2:], x[0:n2] + + karatsubaSqr(z, x0) + karatsubaSqr(z[n:], x1) + + // s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0 + xd := z[2*n : 2*n+n2] + if subVV(xd, x1, x0) != 0 { + subVV(xd, x0, x1) + } + + p := z[n*3:] + karatsubaSqr(p, xd) + + r := z[n*4:] + copy(r, z[:n*2]) + + karatsubaAdd(z[n2:], r, n) + karatsubaAdd(z[n2:], r[n:], n) + karatsubaSub(z[n2:], p, n) // s == -1 for p != 0; s == 1 for p == 0 +} + // Operands that are shorter than basicSqrThreshold are squared using // "grade school" multiplication; for operands longer than karatsubaSqrThreshold -// the Karatsuba algorithm is used. +// we use the Karatsuba algorithm optimized for x == y. var basicSqrThreshold = 20 // computed by calibrate_test.go -var karatsubaSqrThreshold = 400 // computed by calibrate_test.go +var karatsubaSqrThreshold = 260 // computed by calibrate_test.go // z = x*x func (z nat) sqr(x nat) nat { @@ -536,7 +573,31 @@ func (z nat) sqr(x nat) nat { return z.norm() } - return z.mul(x, x) + // Use Karatsuba multiplication optimized for x == y. + // The algorithm and layout of z are the same as for mul. + + // z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2 + + k := karatsubaLen(n, karatsubaSqrThreshold) + + x0 := x[0:k] + z = z.make(max(6*k, 2*n)) + karatsubaSqr(z, x0) // z = x0^2 + z = z[0 : 2*n] + z[2*k:].clear() + + if k < n { + var t nat + x0 := x0.norm() + x1 := x[k:] + t = t.mul(x0, x1) + addAt(z, t, k) + addAt(z, t, k) // z = 2*x1*x0*b + x0^2 + t = t.sqr(x1) + addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2 + } + + return z.norm() } // mulRange computes the product of all the unsigned integers in the diff --git a/src/math/big/nat_test.go b/src/math/big/nat_test.go index 0b94db3476..3c794954dc 100644 --- a/src/math/big/nat_test.go +++ b/src/math/big/nat_test.go @@ -648,26 +648,26 @@ func TestSticky(t *testing.T) { } } -func testBasicSqr(t *testing.T, x nat) { +func testSqr(t *testing.T, x nat) { got := make(nat, 2*len(x)) want := make(nat, 2*len(x)) - basicSqr(got, x) - basicMul(want, x, x) + got = got.sqr(x) + want = want.mul(x, x) if got.cmp(want) != 0 { t.Errorf("basicSqr(%v), got %v, want %v", x, got, want) } } -func TestBasicSqr(t *testing.T) { +func TestSqr(t *testing.T) { for _, a := range prodNN { if a.x != nil { - testBasicSqr(t, a.x) + testSqr(t, a.x) } if a.y != nil { - testBasicSqr(t, a.y) + testSqr(t, a.y) } if a.z != nil { - testBasicSqr(t, a.z) + testSqr(t, a.z) } } } -- 2.50.0