]> Cypherpunks repositories - gostls13.git/commitdiff
big: use fast shift routines
authorRobert Griesemer <gri@golang.org>
Sat, 1 May 2010 04:25:48 +0000 (21:25 -0700)
committerRobert Griesemer <gri@golang.org>
Sat, 1 May 2010 04:25:48 +0000 (21:25 -0700)
- fixed a couple of bugs in the process
  (shift right was incorrect for negative numbers)
- added more tests and made some tests more robust
- changed pidigits back to using shifts to multiply
  by 2 instead of add

  This improves pidigit -s -n 10000 by approx. 5%:

  user 0m6.496s (old)
  user 0m6.156s (new)

R=rsc
CC=golang-dev
https://golang.org/cl/963044

src/pkg/big/int.go
src/pkg/big/int_test.go
src/pkg/big/nat.go
src/pkg/big/nat_test.go
test/bench/pidigits.go

index e5e589a8528a62994cce0d6aa7c1bff8e913e9e0..2b7a628052dddb72c235d8e2a7580910938800ba 100644 (file)
@@ -216,13 +216,16 @@ func (z *Int) SetString(s string, base int) (*Int, bool) {
        if scanned != len(s) {
                goto Error
        }
+       if len(z.abs) == 0 {
+               z.neg = false // 0 has no sign
+       }
 
        return z, true
 
 Error:
        z.neg = false
        z.abs = nil
-       return nil, false
+       return z, false
 }
 
 
@@ -384,26 +387,24 @@ func ProbablyPrime(z *Int, n int) bool { return !z.neg && z.abs.probablyPrime(n)
 
 // Lsh sets z = x << n and returns z.
 func (z *Int) Lsh(x *Int, n uint) *Int {
-       addedWords := int(n) / _W
-       // Don't assign z.abs yet, in case z == x
-       znew := z.abs.make(len(x.abs) + addedWords + 1)
        z.neg = x.neg
-       znew[addedWords:].shiftLeft(x.abs, n%_W)
-       for i := range znew[0:addedWords] {
-               znew[i] = 0
-       }
-       z.abs = znew.norm()
+       z.abs = z.abs.shl(x.abs, n)
        return z
 }
 
 
 // Rsh sets z = x >> n and returns z.
 func (z *Int) Rsh(x *Int, n uint) *Int {
-       removedWords := int(n) / _W
-       // Don't assign z.abs yet, in case z == x
-       znew := z.abs.make(len(x.abs) - removedWords)
-       z.neg = x.neg
-       znew.shiftRight(x.abs[removedWords:], n%_W)
-       z.abs = znew.norm()
+       if x.neg {
+               // (-x) >> s == ^(x-1) >> s == ^((x-1) >> s) == -(((x-1) >> s) + 1)
+               z.neg = true
+               t := z.abs.sub(x.abs, natOne) // no underflow because |x| > 0
+               t = t.shr(t, n)
+               z.abs = t.add(t, natOne)
+               return z
+       }
+
+       z.neg = false
+       z.abs = z.abs.shr(x.abs, n)
        return z
 }
index cdcd28eac7929753716a3bc8a7ee11ecb40ea7db..ceb31e069e0fec2e9d389219fd70bc842c20925c 100644 (file)
@@ -562,6 +562,7 @@ type intShiftTest struct {
 
 var rshTests = []intShiftTest{
        intShiftTest{"0", 0, "0"},
+       intShiftTest{"-0", 0, "0"},
        intShiftTest{"0", 1, "0"},
        intShiftTest{"0", 2, "0"},
        intShiftTest{"1", 0, "1"},
@@ -569,7 +570,12 @@ var rshTests = []intShiftTest{
        intShiftTest{"1", 2, "0"},
        intShiftTest{"2", 0, "2"},
        intShiftTest{"2", 1, "1"},
-       intShiftTest{"2", 2, "0"},
+       intShiftTest{"-1", 0, "-1"},
+       intShiftTest{"-1", 1, "-1"},
+       intShiftTest{"-1", 10, "-1"},
+       intShiftTest{"-100", 2, "-25"},
+       intShiftTest{"-100", 3, "-13"},
+       intShiftTest{"-100", 100, "-1"},
        intShiftTest{"4294967296", 0, "4294967296"},
        intShiftTest{"4294967296", 1, "2147483648"},
        intShiftTest{"4294967296", 2, "1073741824"},
index 2db9e59f8eb02b2de8fbfd9221eb77141f26ecb9..ff8e806b2429821051e00260f3a280372aa469ab 100644 (file)
@@ -554,8 +554,8 @@ func (z nat) divLarge(z2, uIn, v nat) (q, r nat) {
 
        // D1.
        shift := uint(leadingZeroBits(v[n-1]))
-       v.shiftLeft(v, shift)
-       u.shiftLeft(uIn, shift)
+       v.shiftLeftDeprecated(v, shift)
+       u.shiftLeftDeprecated(uIn, shift)
        u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift))
 
        // D2.
@@ -597,8 +597,8 @@ func (z nat) divLarge(z2, uIn, v nat) (q, r nat) {
        }
 
        q = q.norm()
-       u.shiftRight(u, shift)
-       v.shiftRight(v, shift)
+       u.shiftRightDeprecated(u, shift)
+       v.shiftRightDeprecated(v, shift)
        r = u.norm()
 
        return q, r
@@ -780,12 +780,56 @@ func trailingZeroBits(x Word) int {
 }
 
 
-// TODO(gri) Make the shift routines faster.
-//           Use pidigits.go benchmark as a test case.
+// z = x << s
+func (z nat) shl(x nat, s uint) nat {
+       m := len(x)
+       if m == 0 {
+               return z.make(0)
+       }
+       // m > 0
+
+       // determine if z can be reused
+       // TODO(gri) change shlVW so we don't need this
+       if len(z) > 0 && alias(z, x) {
+               z = nil // z is an alias for x - cannot reuse
+       }
+
+       n := m + int(s/_W)
+       z = z.make(n + 1)
+       z[n] = shlVW(&z[n-m], &x[0], Word(s%_W), m)
+
+       return z.norm()
+}
+
+
+// z = x >> s
+func (z nat) shr(x nat, s uint) nat {
+       m := len(x)
+       n := m - int(s/_W)
+       if n <= 0 {
+               return z.make(0)
+       }
+       // n > 0
 
+       // determine if z can be reused
+       // TODO(gri) change shrVW so we don't need this
+       if len(z) > 0 && alias(z, x) {
+               z = nil // z is an alias for x - cannot reuse
+       }
+
+       z = z.make(n)
+       shrVW(&z[0], &x[m-n], Word(s%_W), m)
+
+       return z.norm()
+}
+
+
+// TODO(gri) Remove these shift functions once shlVW and shrVW can be
+//           used directly in divLarge and powersOfTwoDecompose
+//
 // To avoid losing the top n bits, z should be sized so that
 // len(z) == len(x) + 1.
-func (z nat) shiftLeft(x nat, n uint) nat {
+func (z nat) shiftLeftDeprecated(x nat, n uint) nat {
        if len(x) == 0 {
                return x
        }
@@ -805,7 +849,7 @@ func (z nat) shiftLeft(x nat, n uint) nat {
 }
 
 
-func (z nat) shiftRight(x nat, n uint) nat {
+func (z nat) shiftRightDeprecated(x nat, n uint) nat {
        if len(x) == 0 {
                return x
        }
@@ -850,7 +894,7 @@ func (n nat) powersOfTwoDecompose() (q nat, k Word) {
        x := trailingZeroBits(n[zeroWords])
 
        q = q.make(len(n) - zeroWords)
-       q.shiftRight(n[zeroWords:], uint(x))
+       q.shiftRightDeprecated(n[zeroWords:], uint(x))
        q = q.norm()
 
        k = Word(_W*zeroWords + x)
index e1039c48a110641c80815202ae606c9e32d4d02e..bf637b0daa9b4363a197ab94863b7840a44e69ec 100644 (file)
@@ -230,9 +230,8 @@ type shiftTest struct {
 var leftShiftTests = []shiftTest{
        shiftTest{nil, 0, nil},
        shiftTest{nil, 1, nil},
-       shiftTest{nat{0}, 0, nat{0}},
-       shiftTest{nat{1}, 0, nat{1}},
-       shiftTest{nat{1}, 1, nat{2}},
+       shiftTest{natOne, 0, natOne},
+       shiftTest{natOne, 1, natTwo},
        shiftTest{nat{1 << (_W - 1)}, 1, nat{0}},
        shiftTest{nat{1 << (_W - 1), 0}, 1, nat{0, 1}},
 }
@@ -240,11 +239,11 @@ var leftShiftTests = []shiftTest{
 
 func TestShiftLeft(t *testing.T) {
        for i, test := range leftShiftTests {
-               dst := make(nat, len(test.out))
-               dst.shiftLeft(test.in, test.shift)
-               for j, v := range dst {
-                       if test.out[j] != v {
-                               t.Errorf("#%d: got: %v want: %v", i, dst, test.out)
+               var z nat
+               z = z.shl(test.in, test.shift)
+               for j, d := range test.out {
+                       if j >= len(z) || z[j] != d {
+                               t.Errorf("#%d: got: %v want: %v", i, z, test.out)
                                break
                        }
                }
@@ -255,22 +254,21 @@ func TestShiftLeft(t *testing.T) {
 var rightShiftTests = []shiftTest{
        shiftTest{nil, 0, nil},
        shiftTest{nil, 1, nil},
-       shiftTest{nat{0}, 0, nat{0}},
-       shiftTest{nat{1}, 0, nat{1}},
-       shiftTest{nat{1}, 1, nat{0}},
-       shiftTest{nat{2}, 1, nat{1}},
-       shiftTest{nat{0, 1}, 1, nat{1 << (_W - 1), 0}},
-       shiftTest{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1), 0}},
+       shiftTest{natOne, 0, natOne},
+       shiftTest{natOne, 1, nil},
+       shiftTest{natTwo, 1, natOne},
+       shiftTest{nat{0, 1}, 1, nat{1 << (_W - 1)}},
+       shiftTest{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1)}},
 }
 
 
 func TestShiftRight(t *testing.T) {
        for i, test := range rightShiftTests {
-               dst := make(nat, len(test.out))
-               dst.shiftRight(test.in, test.shift)
-               for j, v := range dst {
-                       if test.out[j] != v {
-                               t.Errorf("#%d: got: %v want: %v", i, dst, test.out)
+               var z nat
+               z = z.shr(test.in, test.shift)
+               for j, d := range test.out {
+                       if j >= len(z) || z[j] != d {
+                               t.Errorf("#%d: got: %v want: %v", i, z, test.out)
                                break
                        }
                }
index 3e455dc8385b76b9079b2c1fd3563d83c0129dc2..a05515028ae7a843763b01444c14fb219685df75 100644 (file)
@@ -63,7 +63,7 @@ func extract_digit() int64 {
        }
 
        // Compute (numer * 3 + accum) / denom
-       tmp1.Add(numer, numer) // tmp1.Lsh(numer, 1)
+       tmp1.Lsh(numer, 1)
        tmp1.Add(tmp1, numer)
        tmp1.Add(tmp1, accum)
        tmp1.DivMod(tmp1, denom, tmp2)
@@ -84,7 +84,7 @@ func next_term(k int64) {
        y2.New(k*2 + 1)
        bigk.New(k)
 
-       tmp1.Add(numer, numer) // tmp1.Lsh(numer, 1)
+       tmp1.Lsh(numer, 1)
        accum.Add(accum, tmp1)
        accum.Mul(accum, y2)
        numer.Mul(numer, bigk)