]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: recognize z.Mul(x, x) as squaring of x
authorBrian Kessler <brian.m.kessler@gmail.com>
Sat, 29 Jul 2017 05:29:30 +0000 (22:29 -0700)
committerRobert Griesemer <gri@golang.org>
Wed, 16 Aug 2017 10:07:47 +0000 (10:07 +0000)
updates #13745

Multiprecision squaring can be done in a straightforward manner
with about half the multiplications of a basic multiplication
due to the symmetry of the operands.  This change implements
basic squaring for nat types and uses it for Int multiplication
when the same variable is supplied to both arguments of
z.Mul(x, x). This has some overhead to allocate a temporary
variable to hold the cross products, shift them to double and
add them to the diagonal terms.  There is a speed benefit in
the intermediate range when the overhead is neglible and the
asymptotic performance of karatsuba multiplication has not been
reached.

basicSqrThreshold = 20
karatsubaSqrThreshold = 400

Were set by running calibrate_test.go to measure timing differences
between the algorithms.  Benchmarks for squaring:

name           old time/op  new time/op  delta
IntSqr/1-4     51.5ns ±25%  25.1ns ± 7%  -51.38%  (p=0.008 n=5+5)
IntSqr/2-4     79.1ns ± 4%  72.4ns ± 2%   -8.47%  (p=0.008 n=5+5)
IntSqr/3-4      102ns ± 4%    97ns ± 5%     ~     (p=0.056 n=5+5)
IntSqr/5-4      161ns ± 4%   163ns ± 7%     ~     (p=0.952 n=5+5)
IntSqr/8-4      277ns ± 5%   267ns ± 6%     ~     (p=0.087 n=5+5)
IntSqr/10-4     358ns ± 3%   360ns ± 4%     ~     (p=0.730 n=5+5)
IntSqr/20-4    1.07µs ± 3%  1.01µs ± 6%     ~     (p=0.056 n=5+5)
IntSqr/30-4    2.36µs ± 4%  1.72µs ± 2%  -27.03%  (p=0.008 n=5+5)
IntSqr/50-4    5.19µs ± 3%  3.88µs ± 4%  -25.37%  (p=0.008 n=5+5)
IntSqr/80-4    11.3µs ± 4%   8.6µs ± 3%  -23.78%  (p=0.008 n=5+5)
IntSqr/100-4   16.2µs ± 4%  12.8µs ± 3%  -21.49%  (p=0.008 n=5+5)
IntSqr/200-4   50.1µs ± 5%  44.7µs ± 3%  -10.65%  (p=0.008 n=5+5)
IntSqr/300-4    105µs ±11%    95µs ± 3%   -9.50%  (p=0.008 n=5+5)
IntSqr/500-4    231µs ± 5%   227µs ± 2%     ~     (p=0.310 n=5+5)
IntSqr/800-4    496µs ± 9%   459µs ± 3%   -7.40%  (p=0.016 n=5+5)
IntSqr/1000-4   700µs ± 3%   710µs ± 5%     ~     (p=0.841 n=5+5)

Show a speed up of 10-25% in the range where basicSqr is optimal,
improved single word squaring and no significant difference when
the fallback to standard multiplication is used.

Change-Id: Iae2c82ca91cf890823f91e5c83bbe9a2c534b72b
Reviewed-on: https://go-review.googlesource.com/53638
Reviewed-by: Robert Griesemer <gri@golang.org>
Run-TryBot: Robert Griesemer <gri@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/math/big/calibrate_test.go
src/math/big/int.go
src/math/big/int_test.go
src/math/big/nat.go
src/math/big/nat_test.go

index f69ffbf5cf13b8d6f0e8b0190038c3d868298872..11ce064c156527c8139982e88492e16f34c57e27 100644 (file)
@@ -2,13 +2,20 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
+// Calibration used to determine thresholds for using
+// different algorithms.  Ideally, this would be converted
+// to go generate to create thresholds.go
+
 // This file prints execution times for the Mul benchmark
 // given different Karatsuba thresholds. The result may be
 // used to manually fine-tune the threshold constant. The
 // results are somewhat fragile; use repeated runs to get
 // a clear picture.
 
-// Usage: go test -run=TestCalibrate -calibrate
+// Calculates lower and upper thresholds for when basicSqr
+// is faster than standard multiplication.
+
+// Usage: go test -run=TestCalibrate -v -calibrate
 
 package big
 
@@ -21,6 +28,27 @@ import (
 
 var calibrate = flag.Bool("calibrate", false, "run calibration test")
 
+func TestCalibrate(t *testing.T) {
+       if *calibrate {
+               computeKaratsubaThresholds()
+
+               // compute basicSqrThreshold where overhead becomes neglible
+               minSqr := computeSqrThreshold(10, 30, 1, 3)
+               // compute karatsubaSqrThreshold where karatsuba is faster
+               maxSqr := computeSqrThreshold(300, 500, 10, 3)
+               if minSqr != 0 {
+                       fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
+               } else {
+                       fmt.Println("no basicSqrThreshold found")
+               }
+               if maxSqr != 0 {
+                       fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
+               } else {
+                       fmt.Println("no karatsubaSqrThreshold found")
+               }
+       }
+}
+
 func karatsubaLoad(b *testing.B) {
        BenchmarkMul(b)
 }
@@ -34,7 +62,7 @@ func measureKaratsuba(th int) time.Duration {
        return time.Duration(res.NsPerOp())
 }
 
-func computeThresholds() {
+func computeKaratsubaThresholds() {
        fmt.Printf("Multiplication times for varying Karatsuba thresholds\n")
        fmt.Printf("(run repeatedly for good results)\n")
 
@@ -81,8 +109,56 @@ func computeThresholds() {
        }
 }
 
-func TestCalibrate(t *testing.T) {
-       if *calibrate {
-               computeThresholds()
+func measureBasicSqr(words, nruns int, enable bool) time.Duration {
+       // more runs for better statistics
+       initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold
+
+       if enable {
+               // set thresholds to use basicSqr at this number of words
+               basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1
+       } else {
+               // set thresholds to disable basicSqr for any number of words
+               basicSqrThreshold, karatsubaSqrThreshold = -1, -1
+       }
+
+       var testval int64
+       for i := 0; i < nruns; i++ {
+               res := testing.Benchmark(func(b *testing.B) { benchmarkNatSqr(b, words) })
+               testval += res.NsPerOp()
+       }
+       testval /= int64(nruns)
+
+       basicSqrThreshold, karatsubaSqrThreshold = initBasicSqr, initKaratsubaSqr
+
+       return time.Duration(testval)
+}
+
+func computeSqrThreshold(from, to, step, nruns int) int {
+       fmt.Println("Calibrating thresholds for basicSqr via benchmarks of z.mul(x,x)")
+       fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step)
+       var initPos bool
+       var threshold int
+       for i := from; i <= to; i += step {
+               baseline := measureBasicSqr(i, nruns, false)
+               testval := measureBasicSqr(i, nruns, true)
+               pos := baseline > testval
+               delta := baseline - testval
+               percent := delta * 100 / baseline
+               fmt.Printf("words = %3d deltaT = %10s (%4d%%) is basicSqr better: %v", i, delta, percent, pos)
+               if i == from {
+                       initPos = pos
+               }
+               if threshold == 0 && pos != initPos {
+                       threshold = i
+                       fmt.Printf("  threshold  found")
+               }
+               fmt.Println()
+
+       }
+       if threshold != 0 {
+               fmt.Printf("Found threshold = %d between %d - %d\n", threshold, from, to)
+       } else {
+               fmt.Printf("Found NO threshold between %d - %d\n", from, to)
        }
+       return threshold
 }
index 52b6423dfe25ee57fec4e5cc80009b42d5121292..63a750cb96907eee9fd8dfe12d9616718964a390 100644 (file)
@@ -153,6 +153,11 @@ func (z *Int) Mul(x, y *Int) *Int {
        // x * (-y) == -(x * y)
        // (-x) * y == -(x * y)
        // (-x) * (-y) == x * y
+       if x == y {
+               z.abs = z.abs.sqr(x.abs)
+               z.neg = false
+               return z
+       }
        z.abs = z.abs.mul(x.abs, y.abs)
        z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign
        return z
index 42e810b3b81c6ceb2701b1e928a5ac3c6375ce3c..65e24f1e4bfb99f1d8e84bea32f99317d9bb1703 100644 (file)
@@ -7,6 +7,7 @@ package big
 import (
        "bytes"
        "encoding/hex"
+       "fmt"
        "math/rand"
        "strconv"
        "strings"
@@ -1544,3 +1545,24 @@ func BenchmarkSqrt(b *testing.B) {
                t.Sqrt(n)
        }
 }
+
+func benchmarkIntSqr(b *testing.B, nwords int) {
+       x := new(Int)
+       x.abs = rndNat(nwords)
+       t := new(Int)
+       b.ResetTimer()
+       for i := 0; i < b.N; i++ {
+               t.Mul(x, x)
+       }
+}
+
+func BenchmarkIntSqr(b *testing.B) {
+       for _, n := range sqrBenchSizes {
+               if isRaceBuilder && n > 1e3 {
+                       continue
+               }
+               b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+                       benchmarkIntSqr(b, n)
+               })
+       }
+}
index 889eacb90f03c12911187d22d0f705beb12c81e5..3b5c0f6e72fe11d73516fa3d68f134873f8332ea 100644 (file)
@@ -249,7 +249,7 @@ func karatsubaSub(z, x nat, n int) {
 // Operands that are shorter than karatsubaThreshold are multiplied using
 // "grade school" multiplication; for longer operands the Karatsuba algorithm
 // is used.
-var karatsubaThreshold int = 40 // computed by calibrate.go
+var karatsubaThreshold = 40 // computed by calibrate_test.go
 
 // karatsuba multiplies x and y and leaves the result in z.
 // Both x and y must have the same length n and n must be a
@@ -473,6 +473,61 @@ func (z nat) mul(x, y nat) nat {
        return z.norm()
 }
 
+// basicSqr sets z = x*x and is asymptotically faster than basicMul
+// by about a factor of 2, but slower for small arguments due to overhead.
+// Requirements: len(x) > 0, len(z) >= 2*len(x)
+// The (non-normalized) result is placed in z[0 : 2 * len(x)].
+func basicSqr(z, x nat) {
+       n := len(x)
+       t := make(nat, 2*n)            // temporary variable to hold the products
+       z[1], z[0] = mulWW(x[0], x[0]) // the initial square
+       for i := 1; i < n; i++ {
+               d := x[i]
+               // z collects the squares x[i] * x[i]
+               z[2*i+1], z[2*i] = mulWW(d, d)
+               // t collects the products x[i] * x[j] where j < i
+               t[2*i] = addMulVVW(t[i:2*i], x[0:i], d)
+       }
+       t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
+       addVV(z, z, t)                              // combine the result
+}
+
+// Operands that are shorter than basicSqrThreshold are squared using
+// "grade school" multiplication; for operands longer than karatsubaSqrThreshold
+// the Karatsuba algorithm is used.
+var basicSqrThreshold = 20      // computed by calibrate_test.go
+var karatsubaSqrThreshold = 400 // computed by calibrate_test.go
+
+// z = x*x
+func (z nat) sqr(x nat) nat {
+       n := len(x)
+       switch {
+       case n == 0:
+               return z[:0]
+       case n == 1:
+               d := x[0]
+               z = z.make(2)
+               z[1], z[0] = mulWW(d, d)
+               return z.norm()
+       }
+
+       if alias(z, x) {
+               z = nil // z is an alias for x - cannot reuse
+       }
+       z = z.make(2 * n)
+
+       if n < basicSqrThreshold {
+               basicMul(z, x, x)
+               return z.norm()
+       }
+       if n < karatsubaSqrThreshold {
+               basicSqr(z, x)
+               return z.norm()
+       }
+
+       return z.mul(x, x)
+}
+
 // mulRange computes the product of all the unsigned integers in the
 // range [a, b] inclusively. If a > b (empty range), the result is 1.
 func (z nat) mulRange(a, b uint64) nat {
index 200a247f513aaaf1372d045cdb9273f412a776dc..c25cdf00a3df5f66531e83f44b8ed36380ef4aae 100644 (file)
@@ -619,3 +619,49 @@ func TestSticky(t *testing.T) {
                }
        }
 }
+
+func testBasicSqr(t *testing.T, x nat) {
+       got := make(nat, 2*len(x))
+       want := make(nat, 2*len(x))
+       basicSqr(got, x)
+       basicMul(want, x, x)
+       if got.cmp(want) != 0 {
+               t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
+       }
+}
+
+func TestBasicSqr(t *testing.T) {
+       for _, a := range prodNN {
+               if a.x != nil {
+                       testBasicSqr(t, a.x)
+               }
+               if a.y != nil {
+                       testBasicSqr(t, a.y)
+               }
+               if a.z != nil {
+                       testBasicSqr(t, a.z)
+               }
+       }
+}
+
+func benchmarkNatSqr(b *testing.B, nwords int) {
+       x := rndNat(nwords)
+       var z nat
+       b.ResetTimer()
+       for i := 0; i < b.N; i++ {
+               z.sqr(x)
+       }
+}
+
+var sqrBenchSizes = []int{1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100, 200, 300, 500, 800, 1000}
+
+func BenchmarkNatSqr(b *testing.B) {
+       for _, n := range sqrBenchSizes {
+               if isRaceBuilder && n > 1e3 {
+                       continue
+               }
+               b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+                       benchmarkNatSqr(b, n)
+               })
+       }
+}