]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: implement recursive algorithm for division
authorRémy Oudompheng <remyoudompheng@gmail.com>
Sun, 14 Apr 2019 06:16:13 +0000 (08:16 +0200)
committerRobert Griesemer <gri@golang.org>
Tue, 12 Nov 2019 05:18:25 +0000 (05:18 +0000)
The current division algorithm produces one word of result at a time,
using 2-word division to compute the top word and mulAddVWW to compute
the remainder. The top word may need to be adjusted by 1 or 2 units.

The recursive version, based on Burnikel, Ziegler, "Fast Recursive Division",
uses the same principles, but in a multi-word setting, so that
multiplication benefits from the Karatsuba algorithm (and possibly later
improvements).

benchmark                             old ns/op        new ns/op      delta
BenchmarkDiv/20/10-4                  38.2             38.3           +0.26%
BenchmarkDiv/40/20-4                  38.7             38.5           -0.52%
BenchmarkDiv/100/50-4                 62.5             62.6           +0.16%
BenchmarkDiv/200/100-4                238              259            +8.82%
BenchmarkDiv/400/200-4                311              338            +8.68%
BenchmarkDiv/1000/500-4               604              649            +7.45%
BenchmarkDiv/2000/1000-4              1214             1278           +5.27%
BenchmarkDiv/20000/10000-4            38279            36510          -4.62%
BenchmarkDiv/200000/100000-4          3022057          1359615        -55.01%
BenchmarkDiv/2000000/1000000-4        310827664        54012939       -82.62%
BenchmarkDiv/20000000/10000000-4      33272829421      1965401359     -94.09%
BenchmarkString/10/Base10-4           158              156            -1.27%
BenchmarkString/100/Base10-4          797              792            -0.63%
BenchmarkString/1000/Base10-4         3677             3814           +3.73%
BenchmarkString/10000/Base10-4        16633            17116          +2.90%
BenchmarkString/100000/Base10-4       5779029          1793808        -68.96%
BenchmarkString/1000000/Base10-4      889840820        85524031       -90.39%
BenchmarkString/10000000/Base10-4     134338236860     4935657026     -96.33%

Fixes #21960
Updates #30943

Change-Id: I134c6f81a47870c688ca95b6081eb9211def15a2
Reviewed-on: https://go-review.googlesource.com/c/go/+/172018
Reviewed-by: Robert Griesemer <gri@golang.org>
Run-TryBot: Robert Griesemer <gri@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/math/big/int_test.go
src/math/big/nat.go
src/math/big/nat_test.go

index a4285f3239c7c2aa18a231a414ac5b9b188f821d..e3a1587b3f0ad7acb0cb56091c9fa69902c89eb8 100644 (file)
@@ -1829,8 +1829,11 @@ func benchmarkDiv(b *testing.B, aSize, bSize int) {
 }
 
 func BenchmarkDiv(b *testing.B) {
-       min, max, step := 10, 100000, 10
-       for i := min; i <= max; i *= step {
+       sizes := []int{
+               10, 20, 50, 100, 200, 500, 1000,
+               1e4, 1e5, 1e6, 1e7,
+       }
+       for _, i := range sizes {
                j := 2 * i
                b.Run(fmt.Sprintf("%d/%d", j, i), func(b *testing.B) {
                        benchmarkDiv(b, j, i)
index 3b602320759f4b68673f2f3fe333dff2d909c1cd..6667319100f2af3d055bb76cd6f57c650efeceb2 100644 (file)
@@ -693,7 +693,7 @@ func putNat(x *nat) {
 
 var natPool sync.Pool
 
-// q = (uIn-r)/vIn, with 0 <= r < y
+// q = (uIn-r)/vIn, with 0 <= r < vIn
 // Uses z as storage for q, and u as storage for r if possible.
 // See Knuth, Volume 2, section 4.3.1, Algorithm D.
 // Preconditions:
@@ -721,6 +721,30 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
        }
        q = z.make(m + 1)
 
+       if n < divRecursiveThreshold {
+               q.divBasic(u, v)
+       } else {
+               q.divRecursive(u, v)
+       }
+       putNat(vp)
+
+       q = q.norm()
+       shrVU(u, u, shift)
+       r = u.norm()
+
+       return q, r
+}
+
+// divBasic performs word-by-word division of u by v.
+// The quotient is written in pre-allocated q.
+// The remainder overwrites input u.
+//
+// Precondition:
+// - len(q) >= len(u)-len(v)
+func (q nat) divBasic(u, v nat) {
+       n := len(v)
+       m := len(u) - n
+
        qhatvp := getNat(n + 1)
        qhatv := *qhatvp
 
@@ -729,7 +753,11 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
        for j := m; j >= 0; j-- {
                // D3.
                qhat := Word(_M)
-               if ujn := u[j+n]; ujn != vn1 {
+               var ujn Word
+               if j+n < len(u) {
+                       ujn = u[j+n]
+               }
+               if ujn != vn1 {
                        var rhat Word
                        qhat, rhat = divWW(ujn, u[j+n-1], vn1)
 
@@ -752,25 +780,175 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
 
                // D4.
                qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)
-
-               c := subVV(u[j:j+len(qhatv)], u[j:], qhatv)
+               qhl := len(qhatv)
+               if j+qhl > len(u) && qhatv[n] == 0 {
+                       qhl--
+               }
+               c := subVV(u[j:j+qhl], u[j:], qhatv)
                if c != 0 {
                        c := addVV(u[j:j+n], u[j:], v)
                        u[j+n] += c
                        qhat--
                }
 
+               if j == m && m == len(q) && qhat == 0 {
+                       continue
+               }
                q[j] = qhat
        }
 
-       putNat(vp)
        putNat(qhatvp)
+}
 
-       q = q.norm()
-       shrVU(u, u, shift)
-       r = u.norm()
+const divRecursiveThreshold = 100
 
-       return q, r
+// divRecursive performs word-by-word division of u by v.
+// The quotient is written in pre-allocated z.
+// The remainder overwrites input u.
+//
+// Precondition:
+// - len(z) >= len(u)-len(v)
+//
+// See Burnikel, Ziegler, "Fast Recursive Division", Algorithm 1 and 2.
+func (z nat) divRecursive(u, v nat) {
+       // Recursion depth is less than 2 log2(len(v))
+       // Allocate a slice of temporaries to be reused across recursion.
+       recDepth := 2 * bits.Len(uint(len(v)))
+       // large enough to perform Karatsuba on operands as large as v
+       tmp := getNat(3 * len(v))
+       temps := make([]*nat, recDepth)
+       z.clear()
+       z.divRecursiveStep(u, v, 0, tmp, temps)
+       for _, n := range temps {
+               if n != nil {
+                       putNat(n)
+               }
+       }
+       putNat(tmp)
+}
+
+func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
+       u = u.norm()
+       v = v.norm()
+
+       if len(u) == 0 {
+               z.clear()
+               return
+       }
+       n := len(v)
+       if n < divRecursiveThreshold {
+               z.divBasic(u, v)
+               return
+       }
+       m := len(u) - n
+       if m < 0 {
+               return
+       }
+
+       // Produce the quotient by blocks of B words.
+       // Division by v (length n) is done using a length n/2 division
+       // and a length n/2 multiplication for each block. The final
+       // complexity is driven by multiplication complexity.
+       B := n / 2
+
+       // Allocate a nat for qhat below.
+       if temps[depth] == nil {
+               temps[depth] = getNat(n)
+       } else {
+               *temps[depth] = temps[depth].make(B + 1)
+       }
+
+       j := m
+       for j > B {
+               // Divide u[j-B:j+n] by vIn. Keep remainder in u
+               // for next block.
+               //
+               // The following property will be used (Lemma 2):
+               // if u = u1 << s + u0
+               //    v = v1 << s + v0
+               // then floor(u1/v1) >= floor(u/v)
+               //
+               // Moreover, the difference is at most 2 if len(v1) >= len(u/v)
+               // We choose s = B-1 since len(v)-B >= B+1 >= len(u/v)
+               s := (B - 1)
+               // Except for the first step, the top bits are always
+               // a division remainder, so the quotient length is <= n.
+               uu := u[j-B:]
+
+               qhat := *temps[depth]
+               qhat.clear()
+               qhat.divRecursiveStep(uu[s:B+n], v[s:], depth+1, tmp, temps)
+               qhat = qhat.norm()
+               // Adjust the quotient:
+               //    u = u_h << s + u_l
+               //    v = v_h << s + v_l
+               //  u_h = q̂ v_h + rh
+               //    u = q̂ (v - v_l) + rh << s + u_l
+               // After the above step, u contains a remainder:
+               //    u = rh << s + u_l
+               // and we need to substract q̂ v_l
+               //
+               // But it may be a bit too large, in which case q̂ needs to be smaller.
+               qhatv := tmp.make(3 * n)
+               qhatv.clear()
+               qhatv = qhatv.mul(qhat, v[:s])
+               for i := 0; i < 2; i++ {
+                       e := qhatv.cmp(uu.norm())
+                       if e <= 0 {
+                               break
+                       }
+                       subVW(qhat, qhat, 1)
+                       c := subVV(qhatv[:s], qhatv[:s], v[:s])
+                       if len(qhatv) > s {
+                               subVW(qhatv[s:], qhatv[s:], c)
+                       }
+                       addAt(uu[s:], v[s:], 0)
+               }
+               if qhatv.cmp(uu.norm()) > 0 {
+                       panic("impossible")
+               }
+               c := subVV(uu[:len(qhatv)], uu[:len(qhatv)], qhatv)
+               if c > 0 {
+                       subVW(uu[len(qhatv):], uu[len(qhatv):], c)
+               }
+               addAt(z, qhat, j-B)
+               j -= B
+       }
+
+       // Now u < (v<<B), compute lower bits in the same way.
+       // Choose shift = B-1 again.
+       s := B
+       qhat := *temps[depth]
+       qhat.clear()
+       qhat.divRecursiveStep(u[s:].norm(), v[s:], depth+1, tmp, temps)
+       qhat = qhat.norm()
+       qhatv := tmp.make(3 * n)
+       qhatv.clear()
+       qhatv = qhatv.mul(qhat, v[:s])
+       // Set the correct remainder as before.
+       for i := 0; i < 2; i++ {
+               if e := qhatv.cmp(u.norm()); e > 0 {
+                       subVW(qhat, qhat, 1)
+                       c := subVV(qhatv[:s], qhatv[:s], v[:s])
+                       if len(qhatv) > s {
+                               subVW(qhatv[s:], qhatv[s:], c)
+                       }
+                       addAt(u[s:], v[s:], 0)
+               }
+       }
+       if qhatv.cmp(u.norm()) > 0 {
+               panic("impossible")
+       }
+       c := subVV(u[0:len(qhatv)], u[0:len(qhatv)], qhatv)
+       if c > 0 {
+               c = subVW(u[len(qhatv):B+n], u[len(qhatv):B+n], c)
+       }
+       if c > 0 {
+               panic("impossible")
+       }
+
+       // Done!
+       addAt(z, qhat.norm(), 0)
 }
 
 // Length of x in bits. x must be normalized.
index bb5e14b5fa4eef70cdcc2800aa0406eb71243315..da34e95c1fac86488ccb4dc138b0223effa72e5c 100644 (file)
@@ -739,3 +739,27 @@ func BenchmarkNatSetBytes(b *testing.B) {
                })
        }
 }
+
+func TestNatDiv(t *testing.T) {
+       sizes := []int{
+               1, 2, 5, 8, 15, 25, 40, 65, 100,
+               200, 500, 800, 1500, 2500, 4000, 6500, 10000,
+       }
+       for _, i := range sizes {
+               for _, j := range sizes {
+                       a := rndNat(i)
+                       b := rndNat(j)
+                       x := nat(nil).mul(a, b)
+                       addVW(x, x, 1)
+
+                       var q, r nat
+                       q, r = q.div(r, x, b)
+                       if q.cmp(a) != 0 {
+                               t.Fatal("wrong quotient", i, j)
+                       }
+                       if len(r) != 1 || r[0] != 1 {
+                               t.Fatal("wrong remainder")
+                       }
+               }
+       }
+}