]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: specialize Karatsuba implementation for squaring
authorAlexander Döring <email@alexd.ch>
Fri, 11 May 2018 18:06:53 +0000 (20:06 +0200)
committerRobert Griesemer <gri@golang.org>
Wed, 23 May 2018 20:46:02 +0000 (20:46 +0000)
Currently we use three different algorithms for squaring:
1. basic multiplication for small numbers
2. basic squaring for medium numbers
3. Karatsuba multiplication for large numbers

Change 3. to a version of Karatsuba multiplication specialized
for x == y.

Increasing the performance of 3. lets us lower the threshold
between 2. and 3.

Adapt TestCalibrate to the change that 3. isn't independent
of the threshold between 1. and 2. any more.

Fixes #23221.

benchstat old.txt new.txt
name           old time/op  new time/op  delta
NatSqr/1-4     29.6ns ± 7%  29.5ns ± 5%     ~     (p=0.103 n=50+50)
NatSqr/2-4     51.9ns ± 1%  51.9ns ± 1%     ~     (p=0.693 n=42+49)
NatSqr/3-4     64.3ns ± 1%  64.1ns ± 0%   -0.26%  (p=0.000 n=46+43)
NatSqr/5-4     93.5ns ± 2%  93.1ns ± 1%   -0.39%  (p=0.000 n=48+49)
NatSqr/8-4      131ns ± 1%   131ns ± 1%     ~     (p=0.870 n=46+49)
NatSqr/10-4     175ns ± 1%   175ns ± 1%   +0.38%  (p=0.000 n=49+47)
NatSqr/20-4     426ns ± 1%   429ns ± 1%   +0.84%  (p=0.000 n=46+48)
NatSqr/30-4     702ns ± 2%   699ns ± 1%   -0.38%  (p=0.011 n=46+44)
NatSqr/50-4    1.44µs ± 2%  1.43µs ± 1%   -0.54%  (p=0.010 n=48+48)
NatSqr/80-4    2.85µs ± 1%  2.87µs ± 1%   +0.68%  (p=0.000 n=47+47)
NatSqr/100-4   4.06µs ± 1%  4.07µs ± 1%   +0.29%  (p=0.000 n=46+45)
NatSqr/200-4   13.4µs ± 1%  13.5µs ± 1%   +0.73%  (p=0.000 n=48+48)
NatSqr/300-4   28.5µs ± 1%  28.2µs ± 1%   -1.22%  (p=0.000 n=46+48)
NatSqr/500-4   81.9µs ± 1%  67.0µs ± 1%  -18.25%  (p=0.000 n=48+48)
NatSqr/800-4    161µs ± 1%   140µs ± 1%  -13.29%  (p=0.000 n=47+48)
NatSqr/1000-4   245µs ± 1%   207µs ± 1%  -15.17%  (p=0.000 n=49+49)

go test -v -calibrate --run TestCalibrate
...
Calibrating threshold between basicSqr(x) and karatsubaSqr(x)
Looking for a timing difference for x between 200 - 500 words by 10 step
words = 200 deltaT =     -980ns (  -7%) is karatsubaSqr(x) better: false
words = 210 deltaT =     -773ns (  -5%) is karatsubaSqr(x) better: false
words = 220 deltaT =     -695ns (  -4%) is karatsubaSqr(x) better: false
words = 230 deltaT =     -570ns (  -3%) is karatsubaSqr(x) better: false
words = 240 deltaT =     -458ns (  -2%) is karatsubaSqr(x) better: false
words = 250 deltaT =      -63ns (   0%) is karatsubaSqr(x) better: false
words = 260 deltaT =      118ns (   0%) is karatsubaSqr(x) better: true  threshold  found
words = 270 deltaT =      377ns (   1%) is karatsubaSqr(x) better: true
words = 280 deltaT =      765ns (   3%) is karatsubaSqr(x) better: true
words = 290 deltaT =      673ns (   2%) is karatsubaSqr(x) better: true
words = 300 deltaT =      502ns (   1%) is karatsubaSqr(x) better: true
words = 310 deltaT =      629ns (   2%) is karatsubaSqr(x) better: true
words = 320 deltaT =    1.011µs (   3%) is karatsubaSqr(x) better: true
words = 330 deltaT =     1.36µs (   4%) is karatsubaSqr(x) better: true
words = 340 deltaT =    3.001µs (   8%) is karatsubaSqr(x) better: true
words = 350 deltaT =    3.178µs (   8%) is karatsubaSqr(x) better: true
...

Change-Id: I6f13c23d94d042539ac28e77fd2618cdc37a429e
Reviewed-on: https://go-review.googlesource.com/105075
Run-TryBot: Robert Griesemer <gri@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
src/math/big/calibrate_test.go
src/math/big/nat.go
src/math/big/nat_test.go

index 2b96e74a6586fe1ca2afce1f6b1cedaebf80c81a..4fa663ff08331e850b0412edd336b388c7f35809 100644 (file)
@@ -28,24 +28,32 @@ import (
 
 var calibrate = flag.Bool("calibrate", false, "run calibration test")
 
+const (
+       sqrModeMul       = "mul(x, x)"
+       sqrModeBasic     = "basicSqr(x)"
+       sqrModeKaratsuba = "karatsubaSqr(x)"
+)
+
 func TestCalibrate(t *testing.T) {
-       if *calibrate {
-               computeKaratsubaThresholds()
-
-               // compute basicSqrThreshold where overhead becomes negligible
-               minSqr := computeSqrThreshold(10, 30, 1, 3)
-               // compute karatsubaSqrThreshold where karatsuba is faster
-               maxSqr := computeSqrThreshold(300, 500, 10, 3)
-               if minSqr != 0 {
-                       fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
-               } else {
-                       fmt.Println("no basicSqrThreshold found")
-               }
-               if maxSqr != 0 {
-                       fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
-               } else {
-                       fmt.Println("no karatsubaSqrThreshold found")
-               }
+       if !*calibrate {
+               return
+       }
+
+       computeKaratsubaThresholds()
+
+       // compute basicSqrThreshold where overhead becomes negligible
+       minSqr := computeSqrThreshold(10, 30, 1, 3, sqrModeMul, sqrModeBasic)
+       // compute karatsubaSqrThreshold where karatsuba is faster
+       maxSqr := computeSqrThreshold(200, 500, 10, 3, sqrModeBasic, sqrModeKaratsuba)
+       if minSqr != 0 {
+               fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
+       } else {
+               fmt.Println("no basicSqrThreshold found")
+       }
+       if maxSqr != 0 {
+               fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
+       } else {
+               fmt.Println("no karatsubaSqrThreshold found")
        }
 }
 
@@ -109,16 +117,17 @@ func computeKaratsubaThresholds() {
        }
 }
 
-func measureBasicSqr(words, nruns int, enable bool) time.Duration {
+func measureSqr(words, nruns int, mode string) time.Duration {
        // more runs for better statistics
        initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold
 
-       if enable {
-               // set thresholds to use basicSqr at this number of words
+       switch mode {
+       case sqrModeMul:
+               basicSqrThreshold = words + 1
+       case sqrModeBasic:
                basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1
-       } else {
-               // set thresholds to disable basicSqr for any number of words
-               basicSqrThreshold, karatsubaSqrThreshold = -1, -1
+       case sqrModeKaratsuba:
+               karatsubaSqrThreshold = words - 1
        }
 
        var testval int64
@@ -133,18 +142,18 @@ func measureBasicSqr(words, nruns int, enable bool) time.Duration {
        return time.Duration(testval)
 }
 
-func computeSqrThreshold(from, to, step, nruns int) int {
-       fmt.Println("Calibrating thresholds for basicSqr via benchmarks of z.mul(x,x)")
+func computeSqrThreshold(from, to, step, nruns int, lower, upper string) int {
+       fmt.Printf("Calibrating threshold between %s and %s\n", lower, upper)
        fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step)
        var initPos bool
        var threshold int
        for i := from; i <= to; i += step {
-               baseline := measureBasicSqr(i, nruns, false)
-               testval := measureBasicSqr(i, nruns, true)
+               baseline := measureSqr(i, nruns, lower)
+               testval := measureSqr(i, nruns, upper)
                pos := baseline > testval
                delta := baseline - testval
                percent := delta * 100 / baseline
-               fmt.Printf("words = %3d deltaT = %10s (%4d%%) is basicSqr better: %v", i, delta, percent, pos)
+               fmt.Printf("words = %3d deltaT = %10s (%4d%%) is %s better: %v", i, delta, percent, upper, pos)
                if i == from {
                        initPos = pos
                }
index 9ec81270a3f214bd084df6b8005a9c5b21c82bd4..dc292b4e7c96abb2e823259f170ede328af8030a 100644 (file)
@@ -388,12 +388,12 @@ func max(x, y int) int {
 }
 
 // karatsubaLen computes an approximation to the maximum k <= n such that
-// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
+// k = p<<i for a number p <= threshold and an i >= 0. Thus, the
 // result is the largest number that can be divided repeatedly by 2 before
-// becoming about the value of karatsubaThreshold.
-func karatsubaLen(n int) int {
+// becoming about the value of threshold.
+func karatsubaLen(n, threshold int) int {
        i := uint(0)
-       for n > karatsubaThreshold {
+       for n > threshold {
                n >>= 1
                i++
        }
@@ -433,7 +433,7 @@ func (z nat) mul(x, y nat) nat {
        //   y = yh*b + y0  (0 <= y0 < b)
        //   b = 1<<(_W*k)  ("base" of digits xi, yi)
        //
-       k := karatsubaLen(n)
+       k := karatsubaLen(n, karatsubaThreshold)
        // k <= n
 
        // multiply x0 and y0 via Karatsuba
@@ -486,8 +486,8 @@ func (z nat) mul(x, y nat) nat {
 
 // basicSqr sets z = x*x and is asymptotically faster than basicMul
 // by about a factor of 2, but slower for small arguments due to overhead.
-// Requirements: len(x) > 0, len(z) >= 2*len(x)
-// The (non-normalized) result is placed in z[0 : 2 * len(x)].
+// Requirements: len(x) > 0, len(z) == 2*len(x)
+// The (non-normalized) result is placed in z.
 func basicSqr(z, x nat) {
        n := len(x)
        t := make(nat, 2*n)            // temporary variable to hold the products
@@ -503,11 +503,48 @@ func basicSqr(z, x nat) {
        addVV(z, z, t)                              // combine the result
 }
 
+// karatsubaSqr squares x and leaves the result in z.
+// len(x) must be a power of 2 and len(z) >= 6*len(x).
+// The (non-normalized) result is placed in z[0 : 2*len(x)].
+//
+// The algorithm and the layout of z are the same as for karatsuba.
+func karatsubaSqr(z, x nat) {
+       n := len(x)
+
+       if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 {
+               z = z[:2*n]
+               basicSqr(z, x)
+               return
+       }
+
+       n2 := n >> 1
+       x1, x0 := x[n2:], x[0:n2]
+
+       karatsubaSqr(z, x0)
+       karatsubaSqr(z[n:], x1)
+
+       // s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0
+       xd := z[2*n : 2*n+n2]
+       if subVV(xd, x1, x0) != 0 {
+               subVV(xd, x0, x1)
+       }
+
+       p := z[n*3:]
+       karatsubaSqr(p, xd)
+
+       r := z[n*4:]
+       copy(r, z[:n*2])
+
+       karatsubaAdd(z[n2:], r, n)
+       karatsubaAdd(z[n2:], r[n:], n)
+       karatsubaSub(z[n2:], p, n) // s == -1 for p != 0; s == 1 for p == 0
+}
+
 // Operands that are shorter than basicSqrThreshold are squared using
 // "grade school" multiplication; for operands longer than karatsubaSqrThreshold
-// the Karatsuba algorithm is used.
+// we use the Karatsuba algorithm optimized for x == y.
 var basicSqrThreshold = 20      // computed by calibrate_test.go
-var karatsubaSqrThreshold = 400 // computed by calibrate_test.go
+var karatsubaSqrThreshold = 260 // computed by calibrate_test.go
 
 // z = x*x
 func (z nat) sqr(x nat) nat {
@@ -536,7 +573,31 @@ func (z nat) sqr(x nat) nat {
                return z.norm()
        }
 
-       return z.mul(x, x)
+       // Use Karatsuba multiplication optimized for x == y.
+       // The algorithm and layout of z are the same as for mul.
+
+       // z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2
+
+       k := karatsubaLen(n, karatsubaSqrThreshold)
+
+       x0 := x[0:k]
+       z = z.make(max(6*k, 2*n))
+       karatsubaSqr(z, x0) // z = x0^2
+       z = z[0 : 2*n]
+       z[2*k:].clear()
+
+       if k < n {
+               var t nat
+               x0 := x0.norm()
+               x1 := x[k:]
+               t = t.mul(x0, x1)
+               addAt(z, t, k)
+               addAt(z, t, k) // z = 2*x1*x0*b + x0^2
+               t = t.sqr(x1)
+               addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2
+       }
+
+       return z.norm()
 }
 
 // mulRange computes the product of all the unsigned integers in the
index 0b94db3476b73192ea649e79fbdd7fdf8c25387b..3c794954dc388486e0dbca08178b27b22d5b1f03 100644 (file)
@@ -648,26 +648,26 @@ func TestSticky(t *testing.T) {
        }
 }
 
-func testBasicSqr(t *testing.T, x nat) {
+func testSqr(t *testing.T, x nat) {
        got := make(nat, 2*len(x))
        want := make(nat, 2*len(x))
-       basicSqr(got, x)
-       basicMul(want, x, x)
+       got = got.sqr(x)
+       want = want.mul(x, x)
        if got.cmp(want) != 0 {
                t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
        }
 }
 
-func TestBasicSqr(t *testing.T) {
+func TestSqr(t *testing.T) {
        for _, a := range prodNN {
                if a.x != nil {
-                       testBasicSqr(t, a.x)
+                       testSqr(t, a.x)
                }
                if a.y != nil {
-                       testBasicSqr(t, a.y)
+                       testSqr(t, a.y)
                }
                if a.z != nil {
-                       testBasicSqr(t, a.z)
+                       testSqr(t, a.z)
                }
        }
 }