]> Cypherpunks repositories - gostls13.git/commitdiff
big: improved computation of "karatsuba length" for faster multiplies
authorRobert Griesemer <gri@golang.org>
Fri, 30 Apr 2010 18:54:27 +0000 (11:54 -0700)
committerRobert Griesemer <gri@golang.org>
Fri, 30 Apr 2010 18:54:27 +0000 (11:54 -0700)
This results in an improvement of > 35% for the existing Mul benchmark
using the same karatsuba threshold, and an improvement of > 50% with
a slightly higher threshold (32 instead of 30):

big.BenchmarkMul           500    6731846 ns/op (old alg.)
big.BenchmarkMul    500    4351122 ns/op (new alg.)
big.BenchmarkMul           500    3133782 ns/op (new alg., new theshold)

Also:
- tweaked calibrate.go, use same benchmark as for Mul benchmark

R=rsc
CC=golang-dev
https://golang.org/cl/1037041

src/pkg/big/calibrate_test.go
src/pkg/big/nat.go
src/pkg/big/nat_test.go

index 04da8af89179e6eccbf1dd30f2b4d76a412608db..c6cd2e693b742e83aa5f1c366659dcb34b299496 100644 (file)
@@ -2,7 +2,12 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// This file computes the Karatsuba threshold as a "test".
+// This file prints execution times for the Mul benchmark
+// given different Karatsuba thresholds. The result may be
+// used to manually fine-tune the threshold constant. The
+// results are somewhat fragile; use repeated runs to get
+// a clear picture.
+
 // Usage: gotest -calibrate
 
 package big
@@ -12,27 +17,13 @@ import (
        "fmt"
        "testing"
        "time"
-       "unsafe" // for Sizeof
 )
 
 
 var calibrate = flag.Bool("calibrate", false, "run calibration test")
 
 
-// makeNumber creates an n-word number 0xffff...ffff
-func makeNumber(n int) *Int {
-       var w Word
-       b := make([]byte, n*unsafe.Sizeof(w))
-       for i := range b {
-               b[i] = 0xff
-       }
-       var x Int
-       x.SetBytes(b)
-       return &x
-}
-
-
-// measure returns the time to compute x*x in nanoseconds
+// measure returns the time to run f
 func measure(f func()) int64 {
        const N = 100
        start := time.Nanoseconds()
@@ -44,48 +35,58 @@ func measure(f func()) int64 {
 }
 
 
-func computeThreshold(t *testing.T) int {
-       // use a mix of numbers as work load
-       x := make([]*Int, 20)
-       for i := range x {
-               x[i] = makeNumber(10 * (i + 1))
-       }
+func computeThresholds() {
+       fmt.Printf("Multiplication times for varying Karatsuba thresholds\n")
+       fmt.Printf("(run repeatedly for good results)\n")
 
-       threshold := -1
-       for n := 8; threshold < 0 || n <= threshold+20; n += 2 {
-               // set work load
-               f := func() {
-                       var t Int
-                       for _, x := range x {
-                               t.Mul(x, x)
-                       }
-               }
+       // determine Tk, the work load execution time using basic multiplication
+       karatsubaThreshold = 1e9 // disable karatsuba
+       Tb := measure(benchmarkMulLoad)
+       fmt.Printf("Tb = %dns\n", Tb)
 
-               karatsubaThreshold = 1e9 // disable karatsuba
-               t1 := measure(f)
+       // thresholds
+       n := 8 // any lower values for the threshold lead to very slow multiplies
+       th1 := -1
+       th2 := -1
 
+       var deltaOld int64
+       for count := -1; count != 0; count-- {
+               // determine Tk, the work load execution time using Karatsuba multiplication
                karatsubaThreshold = n // enable karatsuba
-               t2 := measure(f)
-
-               c := '<'
-               mark := ""
-               if t1 > t2 {
-                       c = '>'
-                       if threshold < 0 {
-                               threshold = n
-                               mark = " *"
-                       }
+               Tk := measure(benchmarkMulLoad)
+
+               // improvement over Tb
+               delta := (Tb - Tk) * 100 / Tb
+
+               fmt.Printf("n = %3d  Tk = %8dns  %4d%%", n, Tk, delta)
+
+               // determine break-even point
+               if Tk < Tb && th1 < 0 {
+                       th1 = n
+                       fmt.Print("  break-even point")
+               }
+
+               // determine diminishing return
+               if 0 < delta && delta < deltaOld && th2 < 0 {
+                       th2 = n
+                       fmt.Print("  diminishing return")
+               }
+               deltaOld = delta
+
+               fmt.Println()
+
+               // trigger counter
+               if th1 >= 0 && th2 >= 0 && count < 0 {
+                       count = 20 // this many extra measurements after we got both thresholds
                }
 
-               fmt.Printf("%4d: %8d %c %8d%s\n", n, t1, c, t2, mark)
+               n++
        }
-       return threshold
 }
 
 
 func TestCalibrate(t *testing.T) {
        if *calibrate {
-               fmt.Printf("Computing Karatsuba threshold\n")
-               fmt.Printf("threshold = %d\n", computeThreshold(t))
+               computeThresholds()
        }
 }
index 0675416e58bd07779047de5183415779b39581fa..2db9e59f8eb02b2de8fbfd9221eb77141f26ecb9 100644 (file)
@@ -253,7 +253,7 @@ func karatsubaSub(z, x nat, n int) {
 // Operands that are shorter than karatsubaThreshold are multiplied using
 // "grade school" multiplication; for longer operands the Karatsuba algorithm
 // is used.
-var karatsubaThreshold int = 30 // modified by calibrate.go
+var karatsubaThreshold int = 32 // computed by calibrate.go
 
 // karatsuba multiplies x and y and leaves the result in z.
 // Both x and y must have the same length n and n must be a
@@ -384,6 +384,20 @@ func max(x, y int) int {
 }
 
 
+// karatsubaLen computes an approximation to the maximum k <= n such that
+// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
+// result is the largest number that can be divided repeatedly by 2 before
+// becoming about the value of karatsubaThreshold.
+func karatsubaLen(n int) int {
+       i := uint(0)
+       for n > karatsubaThreshold {
+               n >>= 1
+               i++
+       }
+       return n << i
+}
+
+
 func (z nat) mul(x, y nat) nat {
        m := len(x)
        n := len(y)
@@ -411,17 +425,13 @@ func (z nat) mul(x, y nat) nat {
        }
        // m >= n && n >= karatsubaThreshold && n >= 2
 
-       // determine largest k such that
+       // determine Karatsuba length k such that
        //
        //   x = x1*b + x0
        //   y = y1*b + y0  (and k <= len(y), which implies k <= len(x))
        //   b = 1<<(_W*k)  ("base" of digits xi, yi)
        //
-       // and k is karatsubaThreshold multiplied by a power of 2
-       k := max(karatsubaThreshold, 2)
-       for k*2 <= n {
-               k *= 2
-       }
+       k := karatsubaLen(n)
        // k <= n
 
        // multiply x0 and y0 via Karatsuba
@@ -972,10 +982,8 @@ func (n nat) probablyPrime(reps int) bool {
 
                // We have to exclude these cases because we reject all
                // multiples of these numbers below.
-               if n[0] == 3 || n[0] == 5 || n[0] == 7 || n[0] == 11 ||
-                       n[0] == 13 || n[0] == 17 || n[0] == 19 || n[0] == 23 ||
-                       n[0] == 29 || n[0] == 31 || n[0] == 37 || n[0] == 41 ||
-                       n[0] == 43 || n[0] == 47 || n[0] == 53 {
+               switch n[0] {
+               case 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53:
                        return true
                }
        }
index 52f712f66c84f4a94f01d3213aed216fcff27f4c..e1039c48a110641c80815202ae606c9e32d4d02e 100644 (file)
@@ -147,7 +147,7 @@ func TestMulRange(t *testing.T) {
 }
 
 
-var mulArg nat
+var mulArg, mulTmp nat
 
 func init() {
        const n = 1000
@@ -158,13 +158,17 @@ func init() {
 }
 
 
+func benchmarkMulLoad() {
+       for j := 1; j <= 10; j++ {
+               x := mulArg[0 : j*100]
+               mulTmp.mul(x, x)
+       }
+}
+
+
 func BenchmarkMul(b *testing.B) {
        for i := 0; i < b.N; i++ {
-               var t nat
-               for j := 1; j <= 10; j++ {
-                       x := mulArg[0 : j*100]
-                       t.mul(x, x)
-               }
+               benchmarkMulLoad()
        }
 }