]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: correct quadratic space complexity in Mul.
authorRémy Oudompheng <oudomphe@phare.normalesup.org>
Thu, 12 Jul 2012 17:18:24 +0000 (10:18 -0700)
committerRobert Griesemer <gri@golang.org>
Thu, 12 Jul 2012 17:18:24 +0000 (10:18 -0700)
The previous implementation used to have a O(n) recursion
depth for unbalanced inputs. A test is added to check that a
reasonable amount of bytes is allocated in this case.

Fixes #3807.

R=golang-dev, dsymonds, gri
CC=golang-dev, remy
https://golang.org/cl/6345075

src/pkg/math/big/nat.go
src/pkg/math/big/nat_test.go

index 66f14b4ee7526b7d966e686ce7f062478ab06447..43d53d17a6e1bf01bab10f0613d72df706cf2c70 100644 (file)
@@ -342,7 +342,7 @@ func alias(x, y nat) bool {
        return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
 }
 
-// addAt implements z += x*(1<<(_W*i)); z must be long enough.
+// addAt implements z += x<<(_W*i); z must be long enough.
 // (we don't use nat.add because we need z to stay the same
 // slice, and we don't need to normalize z after each addition)
 func addAt(z, x nat, i int) {
@@ -405,8 +405,8 @@ func (z nat) mul(x, y nat) nat {
 
        // determine Karatsuba length k such that
        //
-       //   x = x1*b + x0
-       //   y = y1*b + y0  (and k <= len(y), which implies k <= len(x))
+       //   x = xh*b + x0  (0 <= x0 < b)
+       //   y = yh*b + y0  (0 <= y0 < b)
        //   b = 1<<(_W*k)  ("base" of digits xi, yi)
        //
        k := karatsubaLen(n)
@@ -417,27 +417,41 @@ func (z nat) mul(x, y nat) nat {
        y0 := y[0:k]              // y0 is not normalized
        z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
        karatsuba(z, x0, y0)
-       z = z[0 : m+n] // z has final length but may be incomplete, upper portion is garbage
-
-       // If x1 and/or y1 are not 0, add missing terms to z explicitly:
-       //
-       //     m+n       2*k       0
-       //   z = [   ...   | x0*y0 ]
-       //     +   [ x1*y1 ]
-       //     +   [ x1*y0 ]
-       //     +   [ x0*y1 ]
+       z = z[0 : m+n]  // z has final length but may be incomplete
+       z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
+
+       // If xh != 0 or yh != 0, add the missing terms to z. For
+       // 
+       //   xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b) 
+       //   yh =                         y1*b (0 <= y1 < b) 
+       // 
+       // the missing terms are 
+       // 
+       //   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0 
+       // 
+       // since all the yi for i > 1 are 0 by choice of k: If any of them 
+       // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would 
+       // be a larger valid threshold contradicting the assumption about k. 
        //
        if k < n || m != n {
-               x1 := x[k:] // x1 is normalized because x is
-               y1 := y[k:] // y1 is normalized because y is
                var t nat
-               t = t.mul(x1, y1)
-               copy(z[2*k:], t)
-               z[2*k+len(t):].clear() // upper portion of z is garbage
-               t = t.mul(x1, y0.norm())
-               addAt(z, t, k)
-               t = t.mul(x0.norm(), y1)
-               addAt(z, t, k)
+
+               // add x0*y1*b
+               x0 := x0.norm()
+               y1 := y[k:] // y1 is normalized because y is
+               addAt(z, t.mul(x0, y1), k)
+
+               // add xi*y0<<i, xi*y1*b<<(i+k)
+               y0 := y0.norm()
+               for i := k; i < len(x); i += k {
+                       xi := x[i:]
+                       if len(xi) > k {
+                               xi = xi[:k]
+                       }
+                       xi = xi.norm()
+                       addAt(z, t.mul(xi, y0), i)
+                       addAt(z, t.mul(xi, y1), i+k)
+               }
        }
 
        return z.norm()
index dee64174a1ebf38774df0facf8463393e465dea5..e4ea1ca441f7f589e9e0a90bb8bd551de041c391 100644 (file)
@@ -7,6 +7,7 @@ package big
 import (
        "io"
        "math/rand"
+       "runtime"
        "strings"
        "testing"
 )
@@ -63,6 +64,36 @@ var prodNN = []argNN{
        {nat{0, 0, 991 * 991}, nat{0, 991}, nat{0, 991}},
        {nat{1 * 991, 2 * 991, 3 * 991, 4 * 991}, nat{1, 2, 3, 4}, nat{991}},
        {nat{4, 11, 20, 30, 20, 11, 4}, nat{1, 2, 3, 4}, nat{4, 3, 2, 1}},
+       // 3^100 * 3^28 = 3^128
+       {
+               natFromString("11790184577738583171520872861412518665678211592275841109096961"),
+               natFromString("515377520732011331036461129765621272702107522001"),
+               natFromString("22876792454961"),
+       },
+       // z = 111....1 (70000 digits)
+       // x = 10^(99*700) + ... + 10^1400 + 10^700 + 1
+       // y = 111....1 (700 digits, larger than Karatsuba threshold on 32-bit and 64-bit)
+       {
+               natFromString(strings.Repeat("1", 70000)),
+               natFromString("1" + strings.Repeat(strings.Repeat("0", 699)+"1", 99)),
+               natFromString(strings.Repeat("1", 700)),
+       },
+       // z = 111....1 (20000 digits)
+       // x = 10^10000 + 1
+       // y = 111....1 (10000 digits)
+       {
+               natFromString(strings.Repeat("1", 20000)),
+               natFromString("1" + strings.Repeat("0", 9999) + "1"),
+               natFromString(strings.Repeat("1", 10000)),
+       },
+}
+
+func natFromString(s string) nat {
+       x, _, err := nat(nil).scan(strings.NewReader(s), 0)
+       if err != nil {
+               panic(err)
+       }
+       return x
 }
 
 func TestSet(t *testing.T) {
@@ -136,6 +167,31 @@ func TestMulRangeN(t *testing.T) {
        }
 }
 
+// allocBytes returns the number of bytes allocated by invoking f. 
+func allocBytes(f func()) uint64 {
+       var stats runtime.MemStats
+       runtime.ReadMemStats(&stats)
+       t := stats.TotalAlloc
+       f()
+       runtime.ReadMemStats(&stats)
+       return stats.TotalAlloc - t
+}
+
+// TestMulUnbalanced tests that multiplying numbers of different lengths
+// does not cause deep recursion and in turn allocate too much memory.
+// test case for issue 3807
+func TestMulUnbalanced(t *testing.T) {
+       x := rndNat(50000)
+       y := rndNat(40)
+       allocSize := allocBytes(func() {
+               nat(nil).mul(x, y)
+       })
+       inputSize := uint64(len(x)+len(y)) * _S
+       if ratio := allocSize / uint64(inputSize); ratio > 10 {
+               t.Errorf("multiplication uses too much memory (%d > %d times the size of inputs)", allocSize, ratio)
+       }
+}
+
 var rnd = rand.New(rand.NewSource(0x43de683f473542af))
 var mulx = rndNat(1e4)
 var muly = rndNat(1e4)