]> Cypherpunks repositories - gostls13.git/commitdiff
big: add Div, Mod, Exp, GcdExt and several other fixes.
authorAdam Langley <agl@golang.org>
Thu, 5 Nov 2009 23:55:41 +0000 (15:55 -0800)
committerAdam Langley <agl@golang.org>
Thu, 5 Nov 2009 23:55:41 +0000 (15:55 -0800)
R=gri, rsc
CC=go-dev
http://go/go-review/1017036

src/pkg/big/arith.go
src/pkg/big/arith_test.go
src/pkg/big/int.go
src/pkg/big/int_test.go
src/pkg/big/nat.go
src/pkg/big/nat_test.go

index 26acb0e6b97a058b78eb6ba516dc2a1459e004b0..2b89638a5d578e7575b71eff8e42d6f0dce47587 100644 (file)
@@ -50,10 +50,6 @@ func subWW_g(x, y, c Word) (z1, z0 Word) {
 }
 
 
-// TODO(gri) mulWW_g is not needed anymore. Keep around for
-//           now since mulAddWWW_g should use some of the
-//           optimizations from mulWW_g eventually.
-
 // z1<<_W + z0 = x*y
 func mulWW_g(x, y Word) (z1, z0 Word) {
        // Split x and y into 2 halfWords each, multiply
@@ -86,6 +82,9 @@ func mulWW_g(x, y Word) (z1, z0 Word) {
                // z = z[1]*_B + z[0] = x*y
                z0 = t1<<_W2 + t0;
                z1 = (t1 + t0>>_W2)>>_W2;
+               if z0 < t0 {
+                       z1++;
+               }
                return;
        }
 
@@ -98,13 +97,31 @@ func mulWW_g(x, y Word) (z1, z0 Word) {
 
        // x*y = t2*_B2*_B2 + t1*_B2 + t0
        t0 := x0*y0;
-       t1 := x1*y0 + x0*y1;
-       t2 := x1*y1;
+       // t1 := x1*y0 + x0*y1;
+       var c Word;
+       t1 := x1*y0;
+       t1a := t1;
+       t1 += x0*y1;
+       if t1 < t1a {
+               c++;
+       }
+       t2 := x1*y1 + c*_B2;
 
        // compute result digits but avoid overflow
        // z = z[1]*_B + z[0] = x*y
+       // This may overflow, but that's ok because we also sum t1 and t0 above
+       // and we take care of the overflow there.
        z0 = t1<<_W2 + t0;
-       z1 = t2 + (t1 + t0>>_W2)>>_W2;
+
+       // z1 = t2 + (t1 + t0>>_W2)>>_W2;
+       var c3 Word;
+       z1 = t1 + t0>>_W2;
+       if z1 < t1 {
+               c3++;
+       }
+       z1 >>= _W2;
+       z1 += c3*_B2;
+       z1 += t2;
        return;
 }
 
@@ -126,14 +143,35 @@ func mulAddWWW_g(x, y, c Word) (z1, z0 Word) {
        c1, c0 := c>>_W2, c&_M2;
 
        // x*y + c = t2*_B2*_B2 + t1*_B2 + t0
+       // (1<<32-1)^2 == 1<<64 - 1<<33 + 1, so there's space to add c0 in here.
        t0 := x0*y0 + c0;
-       t1 := x1*y0 + x0*y1 + c1;
-       t2 := x1*y1;
+
+       // t1 := x1*y0 + x0*y1 + c1;
+       var c2 Word;    // extra carry
+       t1 := x1*y0 + c1;
+       t1a := t1;
+       t1 += x0*y1;
+       if t1 < t1a {   // If the number got smaller then we overflowed.
+               c2++;
+       }
+
+       t2 := x1*y1 + c2*_B2;
 
        // compute result digits but avoid overflow
        // z = z[1]*_B + z[0] = x*y
+       // z0 = t1<<_W2 + t0;
+       // This may overflow, but that's ok because we also sum t1 and t0 below
+       // and we take care of the overflow there.
        z0 = t1<<_W2 + t0;
-       z1 = t2 + (t1 + t0>>_W2)>>_W2;
+
+       var c3 Word;
+       z1 = t1 + t0>>_W2;
+       if z1 < t1 {
+               c3++;
+       }
+       z1 >>= _W2;
+       z1 += t2 + c3*_B2;
+
        return;
 }
 
index 27f640304b7099f49ff48208602399d81ca4815c..af28884e0d039a2b43dd0e0dcbfce10329a4166c 100644 (file)
@@ -268,3 +268,49 @@ func TestFunVWW(t *testing.T) {
                }
        }
 }
+
+
+type mulWWTest struct {
+       x, y    Word;
+       q, r    Word;
+}
+
+
+var mulWWTests = []mulWWTest{
+       mulWWTest{_M, _M, _M - 1, 1},
+}
+
+
+func TestMulWWW(t *testing.T) {
+       for i, test := range mulWWTests {
+               q, r := mulWW_g(test.x, test.y);
+               if q != test.q || r != test.r {
+                       t.Errorf("#%d got (%x, %x) want (%x, %x)", i, q, r, test.q, test.r);
+               }
+       }
+}
+
+
+type mulAddWWWTest struct {
+       x, y, c Word;
+       q, r    Word;
+}
+
+
+var mulAddWWWTests = []mulAddWWWTest{
+       // TODO(agl): These will only work on 64-bit platforms.
+       // mulAddWWWTest{15064310297182388543, 0xe7df04d2d35d5d80, 13537600649892366549, 13644450054494335067, 10832252001440893781},
+       // mulAddWWWTest{15064310297182388543, 0xdab2f18048baa68d, 13644450054494335067, 12869334219691522700, 14233854684711418382},
+       mulAddWWWTest{_M, _M, 0, _M - 1, 1},
+       mulAddWWWTest{_M, _M, _M, _M, 0},
+}
+
+
+func TestMulAddWWW(t *testing.T) {
+       for i, test := range mulAddWWWTests {
+               q, r := mulAddWWW_g(test.x, test.y, test.c);
+               if q != test.q || r != test.r {
+                       t.Errorf("#%d got (%x, %x) want (%x, %x)", i, q, r, test.q, test.r);
+               }
+       }
+}
index efa678c05d49142b2ed5dae0d999722ec0840b59..5db4f0901aee272bcfe10f79a70e1ce1084110cf 100644 (file)
@@ -26,6 +26,12 @@ func (z *Int) New(x int64) *Int {
 }
 
 
+// NewInt allocates and returns a new Int set to x.
+func NewInt(x int64) *Int {
+       return new(Int).New(x);
+}
+
+
 // Set sets z to x.
 func (z *Int) Set(x *Int) *Int {
        z.neg = x.neg;
@@ -96,6 +102,66 @@ func (z *Int) Mul(x, y *Int) *Int {
 }
 
 
+// Div calculates q = (x-r)/y where 0 <= r < y. The receiver is set to q.
+func (z *Int) Div(x, y *Int) (q, r *Int) {
+       q = z;
+       r = new(Int);
+       div(q, r, x, y);
+       return;
+}
+
+
+// Mod calculates q = (x-r)/y and returns r.
+func (z *Int) Mod(x, y *Int) (r *Int) {
+       q := new(Int);
+       r = z;
+       div(q, r, x, y);
+       return;
+}
+
+
+func div(q, r, x, y *Int) {
+       if len(y.abs) == 0 {
+               panic("Divide by zero undefined");
+       }
+
+       if cmpNN(x.abs, y.abs) < 0 {
+               q.neg = false;
+               q.abs = nil;
+               r.neg = y.neg;
+
+               src := x.abs;
+               dst := x.abs;
+               if r == x {
+                       dst = nil;
+               }
+
+               r.abs = makeN(dst, len(src), false);
+               for i, v := range src {
+                       r.abs[i] = v;
+               }
+               return;
+       }
+
+       if len(y.abs) == 1 {
+               var rprime Word;
+               q.abs, rprime = divNW(q.abs, x.abs, y.abs[0]);
+               if rprime > 0 {
+                       r.abs = makeN(r.abs, 1, false);
+                       r.abs[0] = rprime;
+                       r.neg = x.neg;
+               }
+               q.neg = len(q.abs) > 0 && x.neg != y.neg;
+               return;
+       }
+
+       q.neg = x.neg != y.neg;
+       r.neg = x.neg;
+       q.abs, r.abs = divNN(q.abs, r.abs, x.abs, y.abs);
+       return;
+}
+
+
 // Neg computes z = -x.
 func (z *Int) Neg(x *Int) *Int {
        z.abs = setN(z.abs, x.abs);
@@ -132,10 +198,234 @@ func CmpInt(x, y *Int) (r int) {
 }
 
 
-func (x *Int) String() string {
+func (z *Int) String() string {
        s := "";
-       if x.neg {
+       if z.neg {
                s = "-";
        }
-       return s + stringN(x.abs, 10);
+       return s + stringN(z.abs, 10);
+}
+
+
+// SetString sets z to the value of s, interpreted in the given base.
+// If base is 0 then SetString attempts to detect the base by at the prefix of
+// s. '0x' implies base 16, '0' implies base 8. Otherwise base 10 is assumed.
+func (z *Int) SetString(s string, base int) (*Int, bool) {
+       var scanned int;
+
+       if base == 1 || base > 16 {
+               goto Error;
+       }
+
+       if len(s) == 0 {
+               goto Error;
+       }
+
+       if s[0] == '-' {
+               z.neg = true;
+               s = s[1:len(s)];
+       } else {
+               z.neg = false;
+       }
+
+       z.abs, _, scanned = scanN(z.abs, s, base);
+       if scanned != len(s) {
+               goto Error;
+       }
+
+       return z, true;
+
+Error:
+       z.neg = false;
+       z.abs = nil;
+       return nil, false;
+}
+
+
+// SetBytes interprets b as the bytes of a big-endian, unsigned integer and
+// sets x to that value.
+func (z *Int) SetBytes(b []byte) *Int {
+       s := int(_S);
+       z.abs = makeN(z.abs, (len(b)+s-1)/s, false);
+       z.neg = false;
+
+       j := 0;
+       for len(b) >= s {
+               var w Word;
+
+               for i := s; i > 0; i-- {
+                       w <<= 8;
+                       w |= Word(b[len(b)-i]);
+               }
+
+               z.abs[j] = w;
+               j++;
+               b = b[0 : len(b)-s];
+       }
+
+       if len(b) > 0 {
+               var w Word;
+
+               for i := len(b); i > 0; i-- {
+                       w <<= 8;
+                       w |= Word(b[len(b)-i]);
+               }
+
+               z.abs[j] = w;
+       }
+
+       z.abs = normN(z.abs);
+
+       return z;
+}
+
+
+// Bytes returns the absolute value of x as a big-endian byte array.
+func (z *Int) Bytes() []byte {
+       s := int(_S);
+       b := make([]byte, len(z.abs)*s);
+
+       for i, w := range z.abs {
+               wordBytes := b[(len(z.abs)-i-1)*s : (len(z.abs)-i)*s];
+               for j := s-1; j >= 0; j-- {
+                       wordBytes[j] = byte(w);
+                       w >>= 8;
+               }
+       }
+
+       i := 0;
+       for i < len(b) && b[i] == 0 {
+               i++;
+       }
+
+       return b[i:len(b)];
+}
+
+
+// Len returns the length of the absolute value of x in bits. Zero is
+// considered to have a length of one.
+func (z *Int) Len() int {
+       if len(z.abs) == 0 {
+               return 0;
+       }
+
+       return len(z.abs)*int(_W) - int(leadingZeros(z.abs[len(z.abs)-1]));
+}
+
+
+// Exp sets z = x**y mod m. If m is nil, z = x**y.
+// See Knuth, volume 2, section 4.6.3.
+func (z *Int) Exp(x, y, m *Int) *Int {
+       if y.neg || len(y.abs) == 0 {
+               z.New(1);
+               z.neg = x.neg;
+               return z;
+       }
+
+       z.Set(x);
+       v := y.abs[len(y.abs)-1];
+       // It's invalid for the most significant word to be zero, therefore we
+       // will find a one bit.
+       shift := leadingZeros(v) + 1;
+       v <<= shift;
+
+       const mask = 1<<(_W-1);
+
+       // We walk through the bits of the exponent one by one. Each time we see
+       // a bit, we square, thus doubling the power. If the bit is a one, we
+       // also multiply by x, thus adding one to the power.
+
+       w := int(_W)-int(shift);
+       for j := 0; j < w; j++ {
+               z.Mul(z, z);
+
+               if v&mask != 0 {
+                       z.Mul(z, x);
+               }
+
+               if m != nil {
+                       z.Mod(z, m);
+               }
+
+               v <<= 1;
+       }
+
+       for i := len(y.abs)-2; i >= 0; i-- {
+               v = y.abs[i];
+
+               for j := 0; j < int(_W); j++ {
+                       z.Mul(z, z);
+
+                       if v&mask != 0 {
+                               z.Mul(z, x);
+                       }
+
+                       if m != nil {
+                               z.Mod(z, m);
+                       }
+
+                       v <<= 1;
+               }
+       }
+
+       z.neg = x.neg && y.abs[0] & 1 == 1;
+       return z;
+}
+
+
+// GcdInt sets d to the greatest common divisor of a and b, which must be
+// positive numbers.
+// If x and y are not nil, GcdInt sets x and y such that d = a*x + b*y.
+// If either a or b is not positive, GcdInt sets d = x = y = 0.
+func GcdInt(d, x, y, a, b *Int) {
+       if a.neg || b.neg {
+               d.New(0);
+               if x != nil {
+                       x.New(0);
+               }
+               if y != nil {
+                       y.New(0);
+               }
+               return;
+       }
+
+       A := new(Int).Set(a);
+       B := new(Int).Set(b);
+
+       X := new(Int);
+       Y := new(Int).New(1);
+
+       lastX := new(Int).New(1);
+       lastY := new(Int);
+
+       q := new(Int);
+       temp := new(Int);
+
+       for len(B.abs) > 0 {
+               q, r := q.Div(A, B);
+
+               A, B = B, r;
+
+               temp.Set(X);
+               X.Mul(X, q);
+               X.neg = !X.neg;
+               X.Add(X, lastX);
+               lastX.Set(temp);
+
+               temp.Set(Y);
+               Y.Mul(Y, q);
+               Y.neg = !Y.neg;
+               Y.Add(Y, lastY);
+               lastY.Set(temp);
+       }
+
+       if x != nil {
+               *x = *lastX;
+       }
+
+       if y != nil {
+               *y = *lastY;
+       }
+
+       *d = *A;
 }
index 30d3cd68138b494ce2694afde873354f1c1ab13a..43f7eedf264145f180cca53bd7ee3c27fc9ac8f8 100644 (file)
@@ -4,8 +4,12 @@
 
 package big
 
-import "testing"
-
+import (
+       "bytes";
+       "encoding/hex";
+       "testing";
+       "testing/quick";
+)
 
 func newZ(x int64) *Int {
        var z Int;
@@ -18,6 +22,7 @@ type argZZ struct {
        z, x, y *Int;
 }
 
+
 var sumZZ = []argZZ{
        argZZ{newZ(0), newZ(0), newZ(0)},
        argZZ{newZ(1), newZ(1), newZ(0)},
@@ -27,12 +32,13 @@ var sumZZ = []argZZ{
        argZZ{newZ(-1111111110), newZ(-123456789), newZ(-987654321)},
 }
 
+
 var prodZZ = []argZZ{
        argZZ{newZ(0), newZ(0), newZ(0)},
        argZZ{newZ(0), newZ(1), newZ(0)},
        argZZ{newZ(1), newZ(1), newZ(1)},
        argZZ{newZ(-991 * 991), newZ(991), newZ(-991)},
-// TODO(gri) add larger products
+       // TODO(gri) add larger products
 }
 
 
@@ -57,12 +63,8 @@ func testFunZZ(t *testing.T, msg string, f funZZ, a argZZ) {
 
 
 func TestSumZZ(t *testing.T) {
-       AddZZ := func(z, x, y *Int) *Int {
-               return z.Add(x, y);
-       };
-       SubZZ := func(z, x, y *Int) *Int {
-               return z.Sub(x, y);
-       };
+       AddZZ := func(z, x, y *Int) *Int { return z.Add(x, y) };
+       SubZZ := func(z, x, y *Int) *Int { return z.Sub(x, y) };
        for _, a := range sumZZ {
                arg := a;
                testFunZZ(t, "AddZZ", AddZZ, arg);
@@ -80,9 +82,7 @@ func TestSumZZ(t *testing.T) {
 
 
 func TestProdZZ(t *testing.T) {
-       MulZZ := func(z, x, y *Int) *Int {
-               return z.Mul(x, y);
-       };
+       MulZZ := func(z, x, y *Int) *Int { return z.Mul(x, y) };
        for _, a := range prodZZ {
                arg := a;
                testFunZZ(t, "MulZZ", MulZZ, arg);
@@ -100,8 +100,8 @@ var facts = map[int]string{
        10: "3628800",
        20: "2432902008176640000",
        100: "933262154439441526816992388562667004907159682643816214685929"
-               "638952175999932299156089414639761565182862536979208272237582"
-               "51185210916864000000000000000000000000",
+       "638952175999932299156089414639761565182862536979208272237582"
+       "51185210916864000000000000000000000000",
 }
 
 
@@ -125,3 +125,306 @@ func TestFact(t *testing.T) {
                }
        }
 }
+
+
+type fromStringTest struct {
+       in      string;
+       base    int;
+       out     int64;
+       ok      bool;
+}
+
+
+var fromStringTests = []fromStringTest{
+       fromStringTest{in: "", ok: false},
+       fromStringTest{in: "a", ok: false},
+       fromStringTest{in: "z", ok: false},
+       fromStringTest{in: "+", ok: false},
+       fromStringTest{"0", 0, 0, true},
+       fromStringTest{"0", 10, 0, true},
+       fromStringTest{"0", 16, 0, true},
+       fromStringTest{"10", 0, 10, true},
+       fromStringTest{"10", 10, 10, true},
+       fromStringTest{"10", 16, 16, true},
+       fromStringTest{"-10", 16, -16, true},
+       fromStringTest{in: "0x", ok: false},
+       fromStringTest{"0x10", 0, 16, true},
+       fromStringTest{in: "0x10", base: 16, ok: false},
+       fromStringTest{"-0x10", 0, -16, true},
+}
+
+
+func TestSetString(t *testing.T) {
+       for i, test := range fromStringTests {
+               n, ok := new(Int).SetString(test.in, test.base);
+               if ok != test.ok {
+                       t.Errorf("#%d (input '%s') ok incorrect (should be %t)", i, test.in, test.ok);
+                       continue;
+               }
+               if !ok {
+                       continue;
+               }
+
+               if CmpInt(n, new(Int).New(test.out)) != 0 {
+                       t.Errorf("#%d (input '%s') got: %s want: %d\n", i, test.in, n, test.out);
+               }
+       }
+}
+
+
+type divSignsTest struct {
+       x, y    int64;
+       q, r    int64;
+}
+
+
+// These examples taken from the Go Language Spec, section "Arithmetic operators"
+var divSignsTests = []divSignsTest{
+       divSignsTest{5, 3, 1, 2},
+       divSignsTest{-5, 3, -1, -2},
+       divSignsTest{5, -3, -1, 2},
+       divSignsTest{-5, -3, 1, -2},
+       divSignsTest{1, 2, 0, 1},
+}
+
+
+func TestDivSigns(t *testing.T) {
+       for i, test := range divSignsTests {
+               x := new(Int).New(test.x);
+               y := new(Int).New(test.y);
+               q, r := new(Int).Div(x, y);
+               expectedQ := new(Int).New(test.q);
+               expectedR := new(Int).New(test.r);
+
+               if CmpInt(q, expectedQ) != 0 || CmpInt(r, expectedR) != 0 {
+                       t.Errorf("#%d: got (%s, %s) want (%s, %s)", i, q, r, expectedQ, expectedR);
+               }
+       }
+}
+
+
+func checkSetBytes(b []byte) bool {
+       hex1 := hex.EncodeToString(new(Int).SetBytes(b).Bytes());
+       hex2 := hex.EncodeToString(b);
+
+       for len(hex1) < len(hex2) {
+               hex1 = "0"+hex1;
+       }
+
+       for len(hex1) > len(hex2) {
+               hex2 = "0"+hex2;
+       }
+
+       return hex1 == hex2;
+}
+
+
+func TestSetBytes(t *testing.T) {
+       err := quick.Check(checkSetBytes, nil);
+       if err != nil {
+               t.Error(err);
+       }
+}
+
+
+func checkBytes(b []byte) bool {
+       b2 := new(Int).SetBytes(b).Bytes();
+       return bytes.Compare(b, b2) == 0;
+}
+
+
+func TestBytes(t *testing.T) {
+       err := quick.Check(checkSetBytes, nil);
+       if err != nil {
+               t.Error(err);
+       }
+}
+
+
+func checkDiv(x, y []byte) bool {
+       u := new(Int).SetBytes(x);
+       v := new(Int).SetBytes(y);
+
+       if len(v.abs) == 0 {
+               return true;
+       }
+
+       q, r := new(Int).Div(u, v);
+
+       if CmpInt(r, v) >= 0 {
+               return false;
+       }
+
+       uprime := new(Int).Set(q);
+       uprime.Mul(uprime, v);
+       uprime.Add(uprime, r);
+
+       return CmpInt(uprime, u) == 0;
+}
+
+
+func TestDiv(t *testing.T) {
+       err := quick.Check(checkDiv, nil);
+       if err != nil {
+               t.Error(err);
+       }
+}
+
+
+func TestDivStepD6(t *testing.T) {
+       // See Knuth, Volume 2, section 4.3.1, exercise 21. This code exercises
+       // a code path which only triggers 1 in 10^{-19} cases.
+
+       u := &Int{false, []Word{0, 0, 0x8000000000000001, 0x7fffffffffffffff}};
+       v := &Int{false, []Word{5, 0x8000000000000002, 0x8000000000000000}};
+
+       q, r := new(Int).Div(u, v);
+       const expectedQ = "18446744073709551613";
+       const expectedR = "3138550867693340382088035895064302439801311770021610913807";
+       if q.String() != expectedQ || r.String() != expectedR {
+               t.Errorf("got (%s, %s) want (%s, %s)", q, r, expectedQ, expectedR);
+       }
+}
+
+
+type lenTest struct {
+       in      string;
+       out     int;
+}
+
+
+var lenTests = []lenTest{
+       lenTest{"0", 0},
+       lenTest{"1", 1},
+       lenTest{"2", 2},
+       lenTest{"4", 3},
+       lenTest{"0x8000", 16},
+       lenTest{"0x80000000", 32},
+       lenTest{"0x800000000000", 48},
+       lenTest{"0x8000000000000000", 64},
+       lenTest{"0x80000000000000000000", 80},
+}
+
+
+func TestLen(t *testing.T) {
+       for i, test := range lenTests {
+               n, ok := new(Int).SetString(test.in, 0);
+               if !ok {
+                       t.Errorf("#%d test input invalid: %s", i, test.in);
+                       continue;
+               }
+
+               if n.Len() != test.out {
+                       t.Errorf("#%d got %d want %d\n", i, n.Len(), test.out);
+               }
+       }
+}
+
+
+type expTest struct {
+       x, y, m string;
+       out     string;
+}
+
+
+var expTests = []expTest{
+       /*expTest{"5", "0", "", "1"},
+       expTest{"-5", "0", "", "-1"},
+       expTest{"5", "1", "", "5"},
+       expTest{"-5", "1", "", "-5"},
+       expTest{"5", "2", "", "25"},*/
+       expTest{"1", "65537", "2", "1"},
+       /*expTest{"0x8000000000000000", "2", "", "0x40000000000000000000000000000000"},
+       expTest{"0x8000000000000000", "2", "6719", "4944"},
+       expTest{"0x8000000000000000", "3", "6719", "5447"},
+       expTest{"0x8000000000000000", "1000", "6719", "1603"},
+       expTest{"0x8000000000000000", "1000000", "6719", "3199"},
+       expTest{
+               "2938462938472983472983659726349017249287491026512746239764525612965293865296239471239874193284792387498274256129746192347",
+               "298472983472983471903246121093472394872319615612417471234712061",
+               "29834729834729834729347290846729561262544958723956495615629569234729836259263598127342374289365912465901365498236492183464",
+               "23537740700184054162508175125554701713153216681790245129157191391322321508055833908509185839069455749219131480588829346291",
+       },*/
+}
+
+
+func TestExp(t *testing.T) {
+       for i, test := range expTests {
+               x, ok1 := new(Int).SetString(test.x, 0);
+               y, ok2 := new(Int).SetString(test.y, 0);
+               out, ok3 := new(Int).SetString(test.out, 0);
+
+               var ok4 bool;
+               var m *Int;
+
+               if len(test.m) == 0 {
+                       m, ok4 = nil, true;
+               } else {
+                       m, ok4 = new(Int).SetString(test.m, 0);
+               }
+
+               if !ok1 || !ok2 || !ok3 || !ok4 {
+                       t.Errorf("#%d error in input", i);
+                       continue;
+               }
+
+               z := new(Int).Exp(x, y, m);
+               if CmpInt(z, out) != 0 {
+                       t.Errorf("#%d got %s want %s", i, z, out);
+               }
+       }
+}
+
+
+func checkGcd(aBytes, bBytes []byte) bool {
+       a := new(Int).SetBytes(aBytes);
+       b := new(Int).SetBytes(bBytes);
+
+       x := new(Int);
+       y := new(Int);
+       d := new(Int);
+
+       GcdInt(d, x, y, a, b);
+       x.Mul(x, a);
+       y.Mul(y, b);
+       x.Add(x, y);
+
+       return CmpInt(x, d) == 0;
+}
+
+
+type gcdTest struct {
+       a, b    int64;
+       d, x, y int64;
+}
+
+
+var gcdTests = []gcdTest{
+       gcdTest{120, 23, 1, -9, 47},
+}
+
+
+func TestGcd(t *testing.T) {
+       for i, test := range gcdTests {
+               a := new(Int).New(test.a);
+               b := new(Int).New(test.b);
+
+               x := new(Int);
+               y := new(Int);
+               d := new(Int);
+
+               expectedX := new(Int).New(test.x);
+               expectedY := new(Int).New(test.y);
+               expectedD := new(Int).New(test.d);
+
+               GcdInt(d, x, y, a, b);
+
+               if CmpInt(expectedX, x) != 0 ||
+                       CmpInt(expectedY, y) != 0 ||
+                       CmpInt(expectedD, d) != 0 {
+                       t.Errorf("#%d got (%s %s %s) want (%s %s %s)", i, x, y, d, expectedX, expectedY, expectedD);
+               }
+       }
+
+       quick.Check(checkGcd, nil);
+}
index 62e649be9fa5b51c755bd37eede1637ced570643..71f6565a237e037ed268097e62117562e5182351 100644 (file)
@@ -257,6 +257,66 @@ func divNW(z, x []Word, y Word) (q []Word, r Word) {
 }
 
 
+// q = (uIn-r)/v, with 0 <= r < y
+// See Knuth, Volume 2, section 4.3.1, Algorithm D.
+// Preconditions:
+//    len(v) >= 2
+//    len(uIn) >= 1 + len(vIn)
+func divNN(z, z2, uIn, v []Word) (q, r []Word) {
+       n := len(v);
+       m := len(uIn)-len(v);
+
+       u := makeN(z2, len(uIn)+1, false);
+       qhatv := make([]Word, len(v)+1);
+       q = makeN(z, m+1, false);
+
+       // D1.
+       shift := leadingZeroBits(v[n-1]);
+       shiftLeft(v, v, shift);
+       shiftLeft(u, uIn, shift);
+       u[len(uIn)] = uIn[len(uIn)-1]>>(uint(_W)-uint(shift));
+
+       // D2.
+       for j := m; j >= 0; j-- {
+               // D3.
+               qhat, rhat := divWW_g(u[j+n], u[j+n-1], v[n-1]);
+
+               // x1 | x2 = q̂v_{n-2}
+               x1, x2 := mulWW_g(qhat, v[n-2]);
+               // test if q̂v_{n-2} > br̂ + u_{j+n-2}
+               for greaterThan(x1, x2, rhat, u[j+n-2]) {
+                       qhat--;
+                       prevRhat := rhat;
+                       rhat += v[n-1];
+                       // v[n-1] >= 0, so this tests for overflow.
+                       if rhat < prevRhat {
+                               break;
+                       }
+                       x1, x2 = mulWW_g(qhat, v[n-2]);
+               }
+
+               // D4.
+               qhatv[len(v)] = mulAddVWW(&qhatv[0], &v[0], qhat, 0, len(v));
+
+               c := subVV(&u[j], &u[j], &qhatv[0], len(qhatv));
+               if c != 0 {
+                       c := addVV(&u[j], &u[j], &v[0], len(v));
+                       u[j+len(v)] += c;
+                       qhat--;
+               }
+
+               q[j] = qhat;
+       }
+
+       q = normN(q);
+       shiftRight(u, u, shift);
+       shiftRight(v, v, shift);
+       r = normN(u);
+
+       return q, r;
+}
+
+
 // log2 computes the integer binary logarithm of x.
 // The result is the integer n for which 2^n <= x < 2^(n+1).
 // If x == 0, the result is -1.
@@ -314,6 +374,10 @@ func scanN(z []Word, s string, base int) ([]Word, int, int) {
                base = 10;
                if n > 0 && s[0] == '0' {
                        if n > 1 && (s[1] == 'x' || s[1] == 'X') {
+                               if n == 2 {
+                                       // Reject a string which is just '0x' as nonsense.
+                                       return nil, 0, 0;
+                               }
                                base, i = 16, 2;
                        } else {
                                base, i = 8, 1;
@@ -362,9 +426,62 @@ func stringN(x []Word, base int) string {
        for len(q) > 0 {
                i--;
                var r Word;
-               q, r = divNW(q, q, 10);
+               q, r = divNW(q, q, Word(base));
                s[i] = "0123456789abcdef"[r];
        }
 
        return string(s[i:len(s)]);
 }
+
+
+// leadingZeroBits returns the number of leading zero bits in x.
+func leadingZeroBits(x Word) int {
+       c := 0;
+       if x < 1 << (_W/2) {
+               x <<= _W/2;
+               c = int(_W/2);
+       }
+
+       for i := 0; x != 0; i++ {
+               if x&(1<<(_W-1)) != 0 {
+                       return i + c;
+               }
+               x <<= 1;
+       }
+
+       return int(_W);
+}
+
+
+func shiftLeft(dst, src []Word, n int) {
+       if len(src) == 0 {
+               return;
+       }
+
+       ñ := uint(_W) - uint(n);
+       for i := len(src)-1; i >= 1; i-- {
+               dst[i] = src[i]<<uint(n);
+               dst[i] |= src[i-1]>>ñ;
+       }
+       dst[0] = src[0]<<uint(n);
+}
+
+
+func shiftRight(dst, src []Word, n int) {
+       if len(src) == 0 {
+               return;
+       }
+
+       ñ := uint(_W) - uint(n);
+       for i := 0; i < len(src)-1; i++ {
+               dst[i] = src[i]>>uint(n);
+               dst[i] |= src[i+1]<<ñ;
+       }
+       dst[len(src)-1] = src[len(src)-1]>>uint(n);
+}
+
+
+// greaterThan returns true iff (x1<<_W + x2) > (y1<<_W + y2)
+func greaterThan(x1, x2, y1, y2 Word) bool {
+       return x1 > y1 || x1 == y1 && x2 > y2;
+}
index ba3a27cf6aa25f2becfbdf0776925f64e320f394..098575a17224ab2c99f9c6a7634d061850341345 100644 (file)
@@ -7,7 +7,7 @@ package big
 import "testing"
 
 func TestCmpNN(t *testing.T) {
-// TODO(gri) write this test - all other tests depends on it
+       // TODO(gri) write this test - all other tests depends on it
 }
 
 
@@ -16,6 +16,7 @@ type argNN struct {
        z, x, y []Word;
 }
 
+
 var sumNN = []argNN{
        argNN{},
        argNN{[]Word{1}, nil, []Word{1}},
@@ -25,6 +26,7 @@ var sumNN = []argNN{
        argNN{[]Word{0, 0, 0, 1}, []Word{0, 0, _M}, []Word{0, 0, 1}},
 }
 
+
 var prodNN = []argNN{
        argNN{},
        argNN{nil, nil, nil},
@@ -86,6 +88,7 @@ type strN struct {
        s       string;
 }
 
+
 var tabN = []strN{
        strN{nil, 10, "0"},
        strN{[]Word{1}, 10, "1"},
@@ -93,6 +96,7 @@ var tabN = []strN{
        strN{[]Word{1234567890}, 10, "1234567890"},
 }
 
+
 func TestStringN(t *testing.T) {
        for _, a := range tabN {
                s := stringN(a.x, a.b);
@@ -112,3 +116,72 @@ func TestStringN(t *testing.T) {
                }
        }
 }
+
+
+func TestLeadingZeroBits(t *testing.T) {
+       var x Word = 1<<(_W-1);
+       for i := 0; i <= int(_W); i++ {
+               if leadingZeroBits(x) != i {
+                       t.Errorf("failed at %x: got %d want %d", x, leadingZeroBits(x), i);
+               }
+               x >>= 1;
+       }
+}
+
+
+type shiftTest struct {
+       in      []Word;
+       shift   int;
+       out     []Word;
+}
+
+
+var leftShiftTests = []shiftTest{
+       shiftTest{nil, 0, nil},
+       shiftTest{nil, 1, nil},
+       shiftTest{[]Word{0}, 0, []Word{0}},
+       shiftTest{[]Word{1}, 0, []Word{1}},
+       shiftTest{[]Word{1}, 1, []Word{2}},
+       shiftTest{[]Word{1<<(_W-1)}, 1, []Word{0}},
+       shiftTest{[]Word{1<<(_W-1), 0}, 1, []Word{0, 1}},
+}
+
+
+func TestShiftLeft(t *testing.T) {
+       for i, test := range leftShiftTests {
+               dst := make([]Word, len(test.out));
+               shiftLeft(dst, test.in, test.shift);
+               for j, v := range dst {
+                       if test.out[j] != v {
+                               t.Errorf("#%d: got: %v want: %v", i, dst, test.out);
+                               break;
+                       }
+               }
+       }
+}
+
+
+var rightShiftTests = []shiftTest{
+       shiftTest{nil, 0, nil},
+       shiftTest{nil, 1, nil},
+       shiftTest{[]Word{0}, 0, []Word{0}},
+       shiftTest{[]Word{1}, 0, []Word{1}},
+       shiftTest{[]Word{1}, 1, []Word{0}},
+       shiftTest{[]Word{2}, 1, []Word{1}},
+       shiftTest{[]Word{0, 1}, 1, []Word{1<<(_W-1), 0}},
+       shiftTest{[]Word{2, 1, 1}, 1, []Word{1<<(_W-1) + 1, 1<<(_W-1), 0}},
+}
+
+
+func TestShiftRight(t *testing.T) {
+       for i, test := range rightShiftTests {
+               dst := make([]Word, len(test.out));
+               shiftRight(dst, test.in, test.shift);
+               for j, v := range dst {
+                       if test.out[j] != v {
+                               t.Errorf("#%d: got: %v want: %v", i, dst, test.out);
+                               break;
+                       }
+               }
+       }
+}