// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// Calibration used to determine thresholds for using
+// different algorithms. Ideally, this would be converted
+// to go generate to create thresholds.go
+
// This file prints execution times for the Mul benchmark
// given different Karatsuba thresholds. The result may be
// used to manually fine-tune the threshold constant. The
// results are somewhat fragile; use repeated runs to get
// a clear picture.
-// Usage: go test -run=TestCalibrate -calibrate
+// Calculates lower and upper thresholds for when basicSqr
+// is faster than standard multiplication.
+
+// Usage: go test -run=TestCalibrate -v -calibrate
package big
var calibrate = flag.Bool("calibrate", false, "run calibration test")
+func TestCalibrate(t *testing.T) {
+ if *calibrate {
+ computeKaratsubaThresholds()
+
+ // compute basicSqrThreshold where overhead becomes neglible
+ minSqr := computeSqrThreshold(10, 30, 1, 3)
+ // compute karatsubaSqrThreshold where karatsuba is faster
+ maxSqr := computeSqrThreshold(300, 500, 10, 3)
+ if minSqr != 0 {
+ fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
+ } else {
+ fmt.Println("no basicSqrThreshold found")
+ }
+ if maxSqr != 0 {
+ fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
+ } else {
+ fmt.Println("no karatsubaSqrThreshold found")
+ }
+ }
+}
+
func karatsubaLoad(b *testing.B) {
BenchmarkMul(b)
}
return time.Duration(res.NsPerOp())
}
-func computeThresholds() {
+func computeKaratsubaThresholds() {
fmt.Printf("Multiplication times for varying Karatsuba thresholds\n")
fmt.Printf("(run repeatedly for good results)\n")
}
}
-func TestCalibrate(t *testing.T) {
- if *calibrate {
- computeThresholds()
+func measureBasicSqr(words, nruns int, enable bool) time.Duration {
+ // more runs for better statistics
+ initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold
+
+ if enable {
+ // set thresholds to use basicSqr at this number of words
+ basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1
+ } else {
+ // set thresholds to disable basicSqr for any number of words
+ basicSqrThreshold, karatsubaSqrThreshold = -1, -1
+ }
+
+ var testval int64
+ for i := 0; i < nruns; i++ {
+ res := testing.Benchmark(func(b *testing.B) { benchmarkNatSqr(b, words) })
+ testval += res.NsPerOp()
+ }
+ testval /= int64(nruns)
+
+ basicSqrThreshold, karatsubaSqrThreshold = initBasicSqr, initKaratsubaSqr
+
+ return time.Duration(testval)
+}
+
+func computeSqrThreshold(from, to, step, nruns int) int {
+ fmt.Println("Calibrating thresholds for basicSqr via benchmarks of z.mul(x,x)")
+ fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step)
+ var initPos bool
+ var threshold int
+ for i := from; i <= to; i += step {
+ baseline := measureBasicSqr(i, nruns, false)
+ testval := measureBasicSqr(i, nruns, true)
+ pos := baseline > testval
+ delta := baseline - testval
+ percent := delta * 100 / baseline
+ fmt.Printf("words = %3d deltaT = %10s (%4d%%) is basicSqr better: %v", i, delta, percent, pos)
+ if i == from {
+ initPos = pos
+ }
+ if threshold == 0 && pos != initPos {
+ threshold = i
+ fmt.Printf(" threshold found")
+ }
+ fmt.Println()
+
+ }
+ if threshold != 0 {
+ fmt.Printf("Found threshold = %d between %d - %d\n", threshold, from, to)
+ } else {
+ fmt.Printf("Found NO threshold between %d - %d\n", from, to)
}
+ return threshold
}
// Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used.
-var karatsubaThreshold int = 40 // computed by calibrate.go
+var karatsubaThreshold = 40 // computed by calibrate_test.go
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
return z.norm()
}
+// basicSqr sets z = x*x and is asymptotically faster than basicMul
+// by about a factor of 2, but slower for small arguments due to overhead.
+// Requirements: len(x) > 0, len(z) >= 2*len(x)
+// The (non-normalized) result is placed in z[0 : 2 * len(x)].
+func basicSqr(z, x nat) {
+ n := len(x)
+ t := make(nat, 2*n) // temporary variable to hold the products
+ z[1], z[0] = mulWW(x[0], x[0]) // the initial square
+ for i := 1; i < n; i++ {
+ d := x[i]
+ // z collects the squares x[i] * x[i]
+ z[2*i+1], z[2*i] = mulWW(d, d)
+ // t collects the products x[i] * x[j] where j < i
+ t[2*i] = addMulVVW(t[i:2*i], x[0:i], d)
+ }
+ t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
+ addVV(z, z, t) // combine the result
+}
+
+// Operands that are shorter than basicSqrThreshold are squared using
+// "grade school" multiplication; for operands longer than karatsubaSqrThreshold
+// the Karatsuba algorithm is used.
+var basicSqrThreshold = 20 // computed by calibrate_test.go
+var karatsubaSqrThreshold = 400 // computed by calibrate_test.go
+
+// z = x*x
+func (z nat) sqr(x nat) nat {
+ n := len(x)
+ switch {
+ case n == 0:
+ return z[:0]
+ case n == 1:
+ d := x[0]
+ z = z.make(2)
+ z[1], z[0] = mulWW(d, d)
+ return z.norm()
+ }
+
+ if alias(z, x) {
+ z = nil // z is an alias for x - cannot reuse
+ }
+ z = z.make(2 * n)
+
+ if n < basicSqrThreshold {
+ basicMul(z, x, x)
+ return z.norm()
+ }
+ if n < karatsubaSqrThreshold {
+ basicSqr(z, x)
+ return z.norm()
+ }
+
+ return z.mul(x, x)
+}
+
// mulRange computes the product of all the unsigned integers in the
// range [a, b] inclusively. If a > b (empty range), the result is 1.
func (z nat) mulRange(a, b uint64) nat {
}
}
}
+
+func testBasicSqr(t *testing.T, x nat) {
+ got := make(nat, 2*len(x))
+ want := make(nat, 2*len(x))
+ basicSqr(got, x)
+ basicMul(want, x, x)
+ if got.cmp(want) != 0 {
+ t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
+ }
+}
+
+func TestBasicSqr(t *testing.T) {
+ for _, a := range prodNN {
+ if a.x != nil {
+ testBasicSqr(t, a.x)
+ }
+ if a.y != nil {
+ testBasicSqr(t, a.y)
+ }
+ if a.z != nil {
+ testBasicSqr(t, a.z)
+ }
+ }
+}
+
+func benchmarkNatSqr(b *testing.B, nwords int) {
+ x := rndNat(nwords)
+ var z nat
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ z.sqr(x)
+ }
+}
+
+var sqrBenchSizes = []int{1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100, 200, 300, 500, 800, 1000}
+
+func BenchmarkNatSqr(b *testing.B) {
+ for _, n := range sqrBenchSizes {
+ if isRaceBuilder && n > 1e3 {
+ continue
+ }
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ benchmarkNatSqr(b, n)
+ })
+ }
+}