From: Rémy Oudompheng Date: Thu, 12 Jul 2012 17:18:24 +0000 (-0700) Subject: math/big: correct quadratic space complexity in Mul. X-Git-Tag: go1.1rc2~2816 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=ac12131649391c6303f96514aee4424cb7a0b7d7;p=gostls13.git math/big: correct quadratic space complexity in Mul. The previous implementation used to have a O(n) recursion depth for unbalanced inputs. A test is added to check that a reasonable amount of bytes is allocated in this case. Fixes #3807. R=golang-dev, dsymonds, gri CC=golang-dev, remy https://golang.org/cl/6345075 --- diff --git a/src/pkg/math/big/nat.go b/src/pkg/math/big/nat.go index 66f14b4ee7..43d53d17a6 100644 --- a/src/pkg/math/big/nat.go +++ b/src/pkg/math/big/nat.go @@ -342,7 +342,7 @@ func alias(x, y nat) bool { return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] } -// addAt implements z += x*(1<<(_W*i)); z must be long enough. +// addAt implements z += x<<(_W*i); z must be long enough. // (we don't use nat.add because we need z to stay the same // slice, and we don't need to normalize z after each addition) func addAt(z, x nat, i int) { @@ -405,8 +405,8 @@ func (z nat) mul(x, y nat) nat { // determine Karatsuba length k such that // - // x = x1*b + x0 - // y = y1*b + y0 (and k <= len(y), which implies k <= len(x)) + // x = xh*b + x0 (0 <= x0 < b) + // y = yh*b + y0 (0 <= y0 < b) // b = 1<<(_W*k) ("base" of digits xi, yi) // k := karatsubaLen(n) @@ -417,27 +417,41 @@ func (z nat) mul(x, y nat) nat { y0 := y[0:k] // y0 is not normalized z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y karatsuba(z, x0, y0) - z = z[0 : m+n] // z has final length but may be incomplete, upper portion is garbage - - // If x1 and/or y1 are not 0, add missing terms to z explicitly: - // - // m+n 2*k 0 - // z = [ ... | x0*y0 ] - // + [ x1*y1 ] - // + [ x1*y0 ] - // + [ x0*y1 ] + z = z[0 : m+n] // z has final length but may be incomplete + z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m) + + // If xh != 0 or yh != 0, add the missing terms to z. For + // + // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b) + // yh = y1*b (0 <= y1 < b) + // + // the missing terms are + // + // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0 + // + // since all the yi for i > 1 are 0 by choice of k: If any of them + // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would + // be a larger valid threshold contradicting the assumption about k. // if k < n || m != n { - x1 := x[k:] // x1 is normalized because x is - y1 := y[k:] // y1 is normalized because y is var t nat - t = t.mul(x1, y1) - copy(z[2*k:], t) - z[2*k+len(t):].clear() // upper portion of z is garbage - t = t.mul(x1, y0.norm()) - addAt(z, t, k) - t = t.mul(x0.norm(), y1) - addAt(z, t, k) + + // add x0*y1*b + x0 := x0.norm() + y1 := y[k:] // y1 is normalized because y is + addAt(z, t.mul(x0, y1), k) + + // add xi*y0< k { + xi = xi[:k] + } + xi = xi.norm() + addAt(z, t.mul(xi, y0), i) + addAt(z, t.mul(xi, y1), i+k) + } } return z.norm() diff --git a/src/pkg/math/big/nat_test.go b/src/pkg/math/big/nat_test.go index dee64174a1..e4ea1ca441 100644 --- a/src/pkg/math/big/nat_test.go +++ b/src/pkg/math/big/nat_test.go @@ -7,6 +7,7 @@ package big import ( "io" "math/rand" + "runtime" "strings" "testing" ) @@ -63,6 +64,36 @@ var prodNN = []argNN{ {nat{0, 0, 991 * 991}, nat{0, 991}, nat{0, 991}}, {nat{1 * 991, 2 * 991, 3 * 991, 4 * 991}, nat{1, 2, 3, 4}, nat{991}}, {nat{4, 11, 20, 30, 20, 11, 4}, nat{1, 2, 3, 4}, nat{4, 3, 2, 1}}, + // 3^100 * 3^28 = 3^128 + { + natFromString("11790184577738583171520872861412518665678211592275841109096961"), + natFromString("515377520732011331036461129765621272702107522001"), + natFromString("22876792454961"), + }, + // z = 111....1 (70000 digits) + // x = 10^(99*700) + ... + 10^1400 + 10^700 + 1 + // y = 111....1 (700 digits, larger than Karatsuba threshold on 32-bit and 64-bit) + { + natFromString(strings.Repeat("1", 70000)), + natFromString("1" + strings.Repeat(strings.Repeat("0", 699)+"1", 99)), + natFromString(strings.Repeat("1", 700)), + }, + // z = 111....1 (20000 digits) + // x = 10^10000 + 1 + // y = 111....1 (10000 digits) + { + natFromString(strings.Repeat("1", 20000)), + natFromString("1" + strings.Repeat("0", 9999) + "1"), + natFromString(strings.Repeat("1", 10000)), + }, +} + +func natFromString(s string) nat { + x, _, err := nat(nil).scan(strings.NewReader(s), 0) + if err != nil { + panic(err) + } + return x } func TestSet(t *testing.T) { @@ -136,6 +167,31 @@ func TestMulRangeN(t *testing.T) { } } +// allocBytes returns the number of bytes allocated by invoking f. +func allocBytes(f func()) uint64 { + var stats runtime.MemStats + runtime.ReadMemStats(&stats) + t := stats.TotalAlloc + f() + runtime.ReadMemStats(&stats) + return stats.TotalAlloc - t +} + +// TestMulUnbalanced tests that multiplying numbers of different lengths +// does not cause deep recursion and in turn allocate too much memory. +// test case for issue 3807 +func TestMulUnbalanced(t *testing.T) { + x := rndNat(50000) + y := rndNat(40) + allocSize := allocBytes(func() { + nat(nil).mul(x, y) + }) + inputSize := uint64(len(x)+len(y)) * _S + if ratio := allocSize / uint64(inputSize); ratio > 10 { + t.Errorf("multiplication uses too much memory (%d > %d times the size of inputs)", allocSize, ratio) + } +} + var rnd = rand.New(rand.NewSource(0x43de683f473542af)) var mulx = rndNat(1e4) var muly = rndNat(1e4)