return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
}
-// addAt implements z += x*(1<<(_W*i)); z must be long enough.
+// addAt implements z += x<<(_W*i); z must be long enough.
// (we don't use nat.add because we need z to stay the same
// slice, and we don't need to normalize z after each addition)
func addAt(z, x nat, i int) {
// determine Karatsuba length k such that
//
- // x = x1*b + x0
- // y = y1*b + y0 (and k <= len(y), which implies k <= len(x))
+ // x = xh*b + x0 (0 <= x0 < b)
+ // y = yh*b + y0 (0 <= y0 < b)
// b = 1<<(_W*k) ("base" of digits xi, yi)
//
k := karatsubaLen(n)
y0 := y[0:k] // y0 is not normalized
z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
karatsuba(z, x0, y0)
- z = z[0 : m+n] // z has final length but may be incomplete, upper portion is garbage
-
- // If x1 and/or y1 are not 0, add missing terms to z explicitly:
- //
- // m+n 2*k 0
- // z = [ ... | x0*y0 ]
- // + [ x1*y1 ]
- // + [ x1*y0 ]
- // + [ x0*y1 ]
+ z = z[0 : m+n] // z has final length but may be incomplete
+ z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
+
+ // If xh != 0 or yh != 0, add the missing terms to z. For
+ //
+ // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
+ // yh = y1*b (0 <= y1 < b)
+ //
+ // the missing terms are
+ //
+ // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
+ //
+ // since all the yi for i > 1 are 0 by choice of k: If any of them
+ // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
+ // be a larger valid threshold contradicting the assumption about k.
//
if k < n || m != n {
- x1 := x[k:] // x1 is normalized because x is
- y1 := y[k:] // y1 is normalized because y is
var t nat
- t = t.mul(x1, y1)
- copy(z[2*k:], t)
- z[2*k+len(t):].clear() // upper portion of z is garbage
- t = t.mul(x1, y0.norm())
- addAt(z, t, k)
- t = t.mul(x0.norm(), y1)
- addAt(z, t, k)
+
+ // add x0*y1*b
+ x0 := x0.norm()
+ y1 := y[k:] // y1 is normalized because y is
+ addAt(z, t.mul(x0, y1), k)
+
+ // add xi*y0<<i, xi*y1*b<<(i+k)
+ y0 := y0.norm()
+ for i := k; i < len(x); i += k {
+ xi := x[i:]
+ if len(xi) > k {
+ xi = xi[:k]
+ }
+ xi = xi.norm()
+ addAt(z, t.mul(xi, y0), i)
+ addAt(z, t.mul(xi, y1), i+k)
+ }
}
return z.norm()
import (
"io"
"math/rand"
+ "runtime"
"strings"
"testing"
)
{nat{0, 0, 991 * 991}, nat{0, 991}, nat{0, 991}},
{nat{1 * 991, 2 * 991, 3 * 991, 4 * 991}, nat{1, 2, 3, 4}, nat{991}},
{nat{4, 11, 20, 30, 20, 11, 4}, nat{1, 2, 3, 4}, nat{4, 3, 2, 1}},
+ // 3^100 * 3^28 = 3^128
+ {
+ natFromString("11790184577738583171520872861412518665678211592275841109096961"),
+ natFromString("515377520732011331036461129765621272702107522001"),
+ natFromString("22876792454961"),
+ },
+ // z = 111....1 (70000 digits)
+ // x = 10^(99*700) + ... + 10^1400 + 10^700 + 1
+ // y = 111....1 (700 digits, larger than Karatsuba threshold on 32-bit and 64-bit)
+ {
+ natFromString(strings.Repeat("1", 70000)),
+ natFromString("1" + strings.Repeat(strings.Repeat("0", 699)+"1", 99)),
+ natFromString(strings.Repeat("1", 700)),
+ },
+ // z = 111....1 (20000 digits)
+ // x = 10^10000 + 1
+ // y = 111....1 (10000 digits)
+ {
+ natFromString(strings.Repeat("1", 20000)),
+ natFromString("1" + strings.Repeat("0", 9999) + "1"),
+ natFromString(strings.Repeat("1", 10000)),
+ },
+}
+
+func natFromString(s string) nat {
+ x, _, err := nat(nil).scan(strings.NewReader(s), 0)
+ if err != nil {
+ panic(err)
+ }
+ return x
}
func TestSet(t *testing.T) {
}
}
+// allocBytes returns the number of bytes allocated by invoking f.
+func allocBytes(f func()) uint64 {
+ var stats runtime.MemStats
+ runtime.ReadMemStats(&stats)
+ t := stats.TotalAlloc
+ f()
+ runtime.ReadMemStats(&stats)
+ return stats.TotalAlloc - t
+}
+
+// TestMulUnbalanced tests that multiplying numbers of different lengths
+// does not cause deep recursion and in turn allocate too much memory.
+// test case for issue 3807
+func TestMulUnbalanced(t *testing.T) {
+ x := rndNat(50000)
+ y := rndNat(40)
+ allocSize := allocBytes(func() {
+ nat(nil).mul(x, y)
+ })
+ inputSize := uint64(len(x)+len(y)) * _S
+ if ratio := allocSize / uint64(inputSize); ratio > 10 {
+ t.Errorf("multiplication uses too much memory (%d > %d times the size of inputs)", allocSize, ratio)
+ }
+}
+
var rnd = rand.New(rand.NewSource(0x43de683f473542af))
var mulx = rndNat(1e4)
var muly = rndNat(1e4)