var calibrate = flag.Bool("calibrate", false, "run calibration test")
+const (
+ sqrModeMul = "mul(x, x)"
+ sqrModeBasic = "basicSqr(x)"
+ sqrModeKaratsuba = "karatsubaSqr(x)"
+)
+
func TestCalibrate(t *testing.T) {
- if *calibrate {
- computeKaratsubaThresholds()
-
- // compute basicSqrThreshold where overhead becomes negligible
- minSqr := computeSqrThreshold(10, 30, 1, 3)
- // compute karatsubaSqrThreshold where karatsuba is faster
- maxSqr := computeSqrThreshold(300, 500, 10, 3)
- if minSqr != 0 {
- fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
- } else {
- fmt.Println("no basicSqrThreshold found")
- }
- if maxSqr != 0 {
- fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
- } else {
- fmt.Println("no karatsubaSqrThreshold found")
- }
+ if !*calibrate {
+ return
+ }
+
+ computeKaratsubaThresholds()
+
+ // compute basicSqrThreshold where overhead becomes negligible
+ minSqr := computeSqrThreshold(10, 30, 1, 3, sqrModeMul, sqrModeBasic)
+ // compute karatsubaSqrThreshold where karatsuba is faster
+ maxSqr := computeSqrThreshold(200, 500, 10, 3, sqrModeBasic, sqrModeKaratsuba)
+ if minSqr != 0 {
+ fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
+ } else {
+ fmt.Println("no basicSqrThreshold found")
+ }
+ if maxSqr != 0 {
+ fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
+ } else {
+ fmt.Println("no karatsubaSqrThreshold found")
}
}
}
}
-func measureBasicSqr(words, nruns int, enable bool) time.Duration {
+func measureSqr(words, nruns int, mode string) time.Duration {
// more runs for better statistics
initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold
- if enable {
- // set thresholds to use basicSqr at this number of words
+ switch mode {
+ case sqrModeMul:
+ basicSqrThreshold = words + 1
+ case sqrModeBasic:
basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1
- } else {
- // set thresholds to disable basicSqr for any number of words
- basicSqrThreshold, karatsubaSqrThreshold = -1, -1
+ case sqrModeKaratsuba:
+ karatsubaSqrThreshold = words - 1
}
var testval int64
return time.Duration(testval)
}
-func computeSqrThreshold(from, to, step, nruns int) int {
- fmt.Println("Calibrating thresholds for basicSqr via benchmarks of z.mul(x,x)")
+func computeSqrThreshold(from, to, step, nruns int, lower, upper string) int {
+ fmt.Printf("Calibrating threshold between %s and %s\n", lower, upper)
fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step)
var initPos bool
var threshold int
for i := from; i <= to; i += step {
- baseline := measureBasicSqr(i, nruns, false)
- testval := measureBasicSqr(i, nruns, true)
+ baseline := measureSqr(i, nruns, lower)
+ testval := measureSqr(i, nruns, upper)
pos := baseline > testval
delta := baseline - testval
percent := delta * 100 / baseline
- fmt.Printf("words = %3d deltaT = %10s (%4d%%) is basicSqr better: %v", i, delta, percent, pos)
+ fmt.Printf("words = %3d deltaT = %10s (%4d%%) is %s better: %v", i, delta, percent, upper, pos)
if i == from {
initPos = pos
}
}
// karatsubaLen computes an approximation to the maximum k <= n such that
-// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
+// k = p<<i for a number p <= threshold and an i >= 0. Thus, the
// result is the largest number that can be divided repeatedly by 2 before
-// becoming about the value of karatsubaThreshold.
-func karatsubaLen(n int) int {
+// becoming about the value of threshold.
+func karatsubaLen(n, threshold int) int {
i := uint(0)
- for n > karatsubaThreshold {
+ for n > threshold {
n >>= 1
i++
}
// y = yh*b + y0 (0 <= y0 < b)
// b = 1<<(_W*k) ("base" of digits xi, yi)
//
- k := karatsubaLen(n)
+ k := karatsubaLen(n, karatsubaThreshold)
// k <= n
// multiply x0 and y0 via Karatsuba
// basicSqr sets z = x*x and is asymptotically faster than basicMul
// by about a factor of 2, but slower for small arguments due to overhead.
-// Requirements: len(x) > 0, len(z) >= 2*len(x)
-// The (non-normalized) result is placed in z[0 : 2 * len(x)].
+// Requirements: len(x) > 0, len(z) == 2*len(x)
+// The (non-normalized) result is placed in z.
func basicSqr(z, x nat) {
n := len(x)
t := make(nat, 2*n) // temporary variable to hold the products
addVV(z, z, t) // combine the result
}
+// karatsubaSqr squares x and leaves the result in z.
+// len(x) must be a power of 2 and len(z) >= 6*len(x).
+// The (non-normalized) result is placed in z[0 : 2*len(x)].
+//
+// The algorithm and the layout of z are the same as for karatsuba.
+func karatsubaSqr(z, x nat) {
+ n := len(x)
+
+ if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 {
+ z = z[:2*n]
+ basicSqr(z, x)
+ return
+ }
+
+ n2 := n >> 1
+ x1, x0 := x[n2:], x[0:n2]
+
+ karatsubaSqr(z, x0)
+ karatsubaSqr(z[n:], x1)
+
+ // s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0
+ xd := z[2*n : 2*n+n2]
+ if subVV(xd, x1, x0) != 0 {
+ subVV(xd, x0, x1)
+ }
+
+ p := z[n*3:]
+ karatsubaSqr(p, xd)
+
+ r := z[n*4:]
+ copy(r, z[:n*2])
+
+ karatsubaAdd(z[n2:], r, n)
+ karatsubaAdd(z[n2:], r[n:], n)
+ karatsubaSub(z[n2:], p, n) // s == -1 for p != 0; s == 1 for p == 0
+}
+
// Operands that are shorter than basicSqrThreshold are squared using
// "grade school" multiplication; for operands longer than karatsubaSqrThreshold
-// the Karatsuba algorithm is used.
+// we use the Karatsuba algorithm optimized for x == y.
var basicSqrThreshold = 20 // computed by calibrate_test.go
-var karatsubaSqrThreshold = 400 // computed by calibrate_test.go
+var karatsubaSqrThreshold = 260 // computed by calibrate_test.go
// z = x*x
func (z nat) sqr(x nat) nat {
return z.norm()
}
- return z.mul(x, x)
+ // Use Karatsuba multiplication optimized for x == y.
+ // The algorithm and layout of z are the same as for mul.
+
+ // z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2
+
+ k := karatsubaLen(n, karatsubaSqrThreshold)
+
+ x0 := x[0:k]
+ z = z.make(max(6*k, 2*n))
+ karatsubaSqr(z, x0) // z = x0^2
+ z = z[0 : 2*n]
+ z[2*k:].clear()
+
+ if k < n {
+ var t nat
+ x0 := x0.norm()
+ x1 := x[k:]
+ t = t.mul(x0, x1)
+ addAt(z, t, k)
+ addAt(z, t, k) // z = 2*x1*x0*b + x0^2
+ t = t.sqr(x1)
+ addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2
+ }
+
+ return z.norm()
}
// mulRange computes the product of all the unsigned integers in the
}
}
-func testBasicSqr(t *testing.T, x nat) {
+func testSqr(t *testing.T, x nat) {
got := make(nat, 2*len(x))
want := make(nat, 2*len(x))
- basicSqr(got, x)
- basicMul(want, x, x)
+ got = got.sqr(x)
+ want = want.mul(x, x)
if got.cmp(want) != 0 {
t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
}
}
-func TestBasicSqr(t *testing.T) {
+func TestSqr(t *testing.T) {
for _, a := range prodNN {
if a.x != nil {
- testBasicSqr(t, a.x)
+ testSqr(t, a.x)
}
if a.y != nil {
- testBasicSqr(t, a.y)
+ testSqr(t, a.y)
}
if a.z != nil {
- testBasicSqr(t, a.z)
+ testSqr(t, a.z)
}
}
}