From 65063bc61d66a092f499e3bef1d5a681b7cde814 Mon Sep 17 00:00:00 2001 From: Adam Langley Date: Thu, 5 Nov 2009 15:55:41 -0800 Subject: [PATCH] big: add Div, Mod, Exp, GcdExt and several other fixes. R=gri, rsc CC=go-dev http://go/go-review/1017036 --- src/pkg/big/arith.go | 58 +++++-- src/pkg/big/arith_test.go | 46 ++++++ src/pkg/big/int.go | 296 +++++++++++++++++++++++++++++++++- src/pkg/big/int_test.go | 331 ++++++++++++++++++++++++++++++++++++-- src/pkg/big/nat.go | 119 +++++++++++++- src/pkg/big/nat_test.go | 75 ++++++++- 6 files changed, 896 insertions(+), 29 deletions(-) diff --git a/src/pkg/big/arith.go b/src/pkg/big/arith.go index 26acb0e6b9..2b89638a5d 100644 --- a/src/pkg/big/arith.go +++ b/src/pkg/big/arith.go @@ -50,10 +50,6 @@ func subWW_g(x, y, c Word) (z1, z0 Word) { } -// TODO(gri) mulWW_g is not needed anymore. Keep around for -// now since mulAddWWW_g should use some of the -// optimizations from mulWW_g eventually. - // z1<<_W + z0 = x*y func mulWW_g(x, y Word) (z1, z0 Word) { // Split x and y into 2 halfWords each, multiply @@ -86,6 +82,9 @@ func mulWW_g(x, y Word) (z1, z0 Word) { // z = z[1]*_B + z[0] = x*y z0 = t1<<_W2 + t0; z1 = (t1 + t0>>_W2)>>_W2; + if z0 < t0 { + z1++; + } return; } @@ -98,13 +97,31 @@ func mulWW_g(x, y Word) (z1, z0 Word) { // x*y = t2*_B2*_B2 + t1*_B2 + t0 t0 := x0*y0; - t1 := x1*y0 + x0*y1; - t2 := x1*y1; + // t1 := x1*y0 + x0*y1; + var c Word; + t1 := x1*y0; + t1a := t1; + t1 += x0*y1; + if t1 < t1a { + c++; + } + t2 := x1*y1 + c*_B2; // compute result digits but avoid overflow // z = z[1]*_B + z[0] = x*y + // This may overflow, but that's ok because we also sum t1 and t0 above + // and we take care of the overflow there. z0 = t1<<_W2 + t0; - z1 = t2 + (t1 + t0>>_W2)>>_W2; + + // z1 = t2 + (t1 + t0>>_W2)>>_W2; + var c3 Word; + z1 = t1 + t0>>_W2; + if z1 < t1 { + c3++; + } + z1 >>= _W2; + z1 += c3*_B2; + z1 += t2; return; } @@ -126,14 +143,35 @@ func mulAddWWW_g(x, y, c Word) (z1, z0 Word) { c1, c0 := c>>_W2, c&_M2; // x*y + c = t2*_B2*_B2 + t1*_B2 + t0 + // (1<<32-1)^2 == 1<<64 - 1<<33 + 1, so there's space to add c0 in here. t0 := x0*y0 + c0; - t1 := x1*y0 + x0*y1 + c1; - t2 := x1*y1; + + // t1 := x1*y0 + x0*y1 + c1; + var c2 Word; // extra carry + t1 := x1*y0 + c1; + t1a := t1; + t1 += x0*y1; + if t1 < t1a { // If the number got smaller then we overflowed. + c2++; + } + + t2 := x1*y1 + c2*_B2; // compute result digits but avoid overflow // z = z[1]*_B + z[0] = x*y + // z0 = t1<<_W2 + t0; + // This may overflow, but that's ok because we also sum t1 and t0 below + // and we take care of the overflow there. z0 = t1<<_W2 + t0; - z1 = t2 + (t1 + t0>>_W2)>>_W2; + + var c3 Word; + z1 = t1 + t0>>_W2; + if z1 < t1 { + c3++; + } + z1 >>= _W2; + z1 += t2 + c3*_B2; + return; } diff --git a/src/pkg/big/arith_test.go b/src/pkg/big/arith_test.go index 27f640304b..af28884e0d 100644 --- a/src/pkg/big/arith_test.go +++ b/src/pkg/big/arith_test.go @@ -268,3 +268,49 @@ func TestFunVWW(t *testing.T) { } } } + + +type mulWWTest struct { + x, y Word; + q, r Word; +} + + +var mulWWTests = []mulWWTest{ + mulWWTest{_M, _M, _M - 1, 1}, +} + + +func TestMulWWW(t *testing.T) { + for i, test := range mulWWTests { + q, r := mulWW_g(test.x, test.y); + if q != test.q || r != test.r { + t.Errorf("#%d got (%x, %x) want (%x, %x)", i, q, r, test.q, test.r); + } + } +} + + +type mulAddWWWTest struct { + x, y, c Word; + q, r Word; +} + + +var mulAddWWWTests = []mulAddWWWTest{ + // TODO(agl): These will only work on 64-bit platforms. + // mulAddWWWTest{15064310297182388543, 0xe7df04d2d35d5d80, 13537600649892366549, 13644450054494335067, 10832252001440893781}, + // mulAddWWWTest{15064310297182388543, 0xdab2f18048baa68d, 13644450054494335067, 12869334219691522700, 14233854684711418382}, + mulAddWWWTest{_M, _M, 0, _M - 1, 1}, + mulAddWWWTest{_M, _M, _M, _M, 0}, +} + + +func TestMulAddWWW(t *testing.T) { + for i, test := range mulAddWWWTests { + q, r := mulAddWWW_g(test.x, test.y, test.c); + if q != test.q || r != test.r { + t.Errorf("#%d got (%x, %x) want (%x, %x)", i, q, r, test.q, test.r); + } + } +} diff --git a/src/pkg/big/int.go b/src/pkg/big/int.go index efa678c05d..5db4f0901a 100644 --- a/src/pkg/big/int.go +++ b/src/pkg/big/int.go @@ -26,6 +26,12 @@ func (z *Int) New(x int64) *Int { } +// NewInt allocates and returns a new Int set to x. +func NewInt(x int64) *Int { + return new(Int).New(x); +} + + // Set sets z to x. func (z *Int) Set(x *Int) *Int { z.neg = x.neg; @@ -96,6 +102,66 @@ func (z *Int) Mul(x, y *Int) *Int { } +// Div calculates q = (x-r)/y where 0 <= r < y. The receiver is set to q. +func (z *Int) Div(x, y *Int) (q, r *Int) { + q = z; + r = new(Int); + div(q, r, x, y); + return; +} + + +// Mod calculates q = (x-r)/y and returns r. +func (z *Int) Mod(x, y *Int) (r *Int) { + q := new(Int); + r = z; + div(q, r, x, y); + return; +} + + +func div(q, r, x, y *Int) { + if len(y.abs) == 0 { + panic("Divide by zero undefined"); + } + + if cmpNN(x.abs, y.abs) < 0 { + q.neg = false; + q.abs = nil; + r.neg = y.neg; + + src := x.abs; + dst := x.abs; + if r == x { + dst = nil; + } + + r.abs = makeN(dst, len(src), false); + for i, v := range src { + r.abs[i] = v; + } + return; + } + + if len(y.abs) == 1 { + var rprime Word; + q.abs, rprime = divNW(q.abs, x.abs, y.abs[0]); + if rprime > 0 { + r.abs = makeN(r.abs, 1, false); + r.abs[0] = rprime; + r.neg = x.neg; + } + q.neg = len(q.abs) > 0 && x.neg != y.neg; + return; + } + + q.neg = x.neg != y.neg; + r.neg = x.neg; + q.abs, r.abs = divNN(q.abs, r.abs, x.abs, y.abs); + return; +} + + // Neg computes z = -x. func (z *Int) Neg(x *Int) *Int { z.abs = setN(z.abs, x.abs); @@ -132,10 +198,234 @@ func CmpInt(x, y *Int) (r int) { } -func (x *Int) String() string { +func (z *Int) String() string { s := ""; - if x.neg { + if z.neg { s = "-"; } - return s + stringN(x.abs, 10); + return s + stringN(z.abs, 10); +} + + +// SetString sets z to the value of s, interpreted in the given base. +// If base is 0 then SetString attempts to detect the base by at the prefix of +// s. '0x' implies base 16, '0' implies base 8. Otherwise base 10 is assumed. +func (z *Int) SetString(s string, base int) (*Int, bool) { + var scanned int; + + if base == 1 || base > 16 { + goto Error; + } + + if len(s) == 0 { + goto Error; + } + + if s[0] == '-' { + z.neg = true; + s = s[1:len(s)]; + } else { + z.neg = false; + } + + z.abs, _, scanned = scanN(z.abs, s, base); + if scanned != len(s) { + goto Error; + } + + return z, true; + +Error: + z.neg = false; + z.abs = nil; + return nil, false; +} + + +// SetBytes interprets b as the bytes of a big-endian, unsigned integer and +// sets x to that value. +func (z *Int) SetBytes(b []byte) *Int { + s := int(_S); + z.abs = makeN(z.abs, (len(b)+s-1)/s, false); + z.neg = false; + + j := 0; + for len(b) >= s { + var w Word; + + for i := s; i > 0; i-- { + w <<= 8; + w |= Word(b[len(b)-i]); + } + + z.abs[j] = w; + j++; + b = b[0 : len(b)-s]; + } + + if len(b) > 0 { + var w Word; + + for i := len(b); i > 0; i-- { + w <<= 8; + w |= Word(b[len(b)-i]); + } + + z.abs[j] = w; + } + + z.abs = normN(z.abs); + + return z; +} + + +// Bytes returns the absolute value of x as a big-endian byte array. +func (z *Int) Bytes() []byte { + s := int(_S); + b := make([]byte, len(z.abs)*s); + + for i, w := range z.abs { + wordBytes := b[(len(z.abs)-i-1)*s : (len(z.abs)-i)*s]; + for j := s-1; j >= 0; j-- { + wordBytes[j] = byte(w); + w >>= 8; + } + } + + i := 0; + for i < len(b) && b[i] == 0 { + i++; + } + + return b[i:len(b)]; +} + + +// Len returns the length of the absolute value of x in bits. Zero is +// considered to have a length of one. +func (z *Int) Len() int { + if len(z.abs) == 0 { + return 0; + } + + return len(z.abs)*int(_W) - int(leadingZeros(z.abs[len(z.abs)-1])); +} + + +// Exp sets z = x**y mod m. If m is nil, z = x**y. +// See Knuth, volume 2, section 4.6.3. +func (z *Int) Exp(x, y, m *Int) *Int { + if y.neg || len(y.abs) == 0 { + z.New(1); + z.neg = x.neg; + return z; + } + + z.Set(x); + v := y.abs[len(y.abs)-1]; + // It's invalid for the most significant word to be zero, therefore we + // will find a one bit. + shift := leadingZeros(v) + 1; + v <<= shift; + + const mask = 1<<(_W-1); + + // We walk through the bits of the exponent one by one. Each time we see + // a bit, we square, thus doubling the power. If the bit is a one, we + // also multiply by x, thus adding one to the power. + + w := int(_W)-int(shift); + for j := 0; j < w; j++ { + z.Mul(z, z); + + if v&mask != 0 { + z.Mul(z, x); + } + + if m != nil { + z.Mod(z, m); + } + + v <<= 1; + } + + for i := len(y.abs)-2; i >= 0; i-- { + v = y.abs[i]; + + for j := 0; j < int(_W); j++ { + z.Mul(z, z); + + if v&mask != 0 { + z.Mul(z, x); + } + + if m != nil { + z.Mod(z, m); + } + + v <<= 1; + } + } + + z.neg = x.neg && y.abs[0] & 1 == 1; + return z; +} + + +// GcdInt sets d to the greatest common divisor of a and b, which must be +// positive numbers. +// If x and y are not nil, GcdInt sets x and y such that d = a*x + b*y. +// If either a or b is not positive, GcdInt sets d = x = y = 0. +func GcdInt(d, x, y, a, b *Int) { + if a.neg || b.neg { + d.New(0); + if x != nil { + x.New(0); + } + if y != nil { + y.New(0); + } + return; + } + + A := new(Int).Set(a); + B := new(Int).Set(b); + + X := new(Int); + Y := new(Int).New(1); + + lastX := new(Int).New(1); + lastY := new(Int); + + q := new(Int); + temp := new(Int); + + for len(B.abs) > 0 { + q, r := q.Div(A, B); + + A, B = B, r; + + temp.Set(X); + X.Mul(X, q); + X.neg = !X.neg; + X.Add(X, lastX); + lastX.Set(temp); + + temp.Set(Y); + Y.Mul(Y, q); + Y.neg = !Y.neg; + Y.Add(Y, lastY); + lastY.Set(temp); + } + + if x != nil { + *x = *lastX; + } + + if y != nil { + *y = *lastY; + } + + *d = *A; } diff --git a/src/pkg/big/int_test.go b/src/pkg/big/int_test.go index 30d3cd6813..43f7eedf26 100644 --- a/src/pkg/big/int_test.go +++ b/src/pkg/big/int_test.go @@ -4,8 +4,12 @@ package big -import "testing" - +import ( + "bytes"; + "encoding/hex"; + "testing"; + "testing/quick"; +) func newZ(x int64) *Int { var z Int; @@ -18,6 +22,7 @@ type argZZ struct { z, x, y *Int; } + var sumZZ = []argZZ{ argZZ{newZ(0), newZ(0), newZ(0)}, argZZ{newZ(1), newZ(1), newZ(0)}, @@ -27,12 +32,13 @@ var sumZZ = []argZZ{ argZZ{newZ(-1111111110), newZ(-123456789), newZ(-987654321)}, } + var prodZZ = []argZZ{ argZZ{newZ(0), newZ(0), newZ(0)}, argZZ{newZ(0), newZ(1), newZ(0)}, argZZ{newZ(1), newZ(1), newZ(1)}, argZZ{newZ(-991 * 991), newZ(991), newZ(-991)}, -// TODO(gri) add larger products + // TODO(gri) add larger products } @@ -57,12 +63,8 @@ func testFunZZ(t *testing.T, msg string, f funZZ, a argZZ) { func TestSumZZ(t *testing.T) { - AddZZ := func(z, x, y *Int) *Int { - return z.Add(x, y); - }; - SubZZ := func(z, x, y *Int) *Int { - return z.Sub(x, y); - }; + AddZZ := func(z, x, y *Int) *Int { return z.Add(x, y) }; + SubZZ := func(z, x, y *Int) *Int { return z.Sub(x, y) }; for _, a := range sumZZ { arg := a; testFunZZ(t, "AddZZ", AddZZ, arg); @@ -80,9 +82,7 @@ func TestSumZZ(t *testing.T) { func TestProdZZ(t *testing.T) { - MulZZ := func(z, x, y *Int) *Int { - return z.Mul(x, y); - }; + MulZZ := func(z, x, y *Int) *Int { return z.Mul(x, y) }; for _, a := range prodZZ { arg := a; testFunZZ(t, "MulZZ", MulZZ, arg); @@ -100,8 +100,8 @@ var facts = map[int]string{ 10: "3628800", 20: "2432902008176640000", 100: "933262154439441526816992388562667004907159682643816214685929" - "638952175999932299156089414639761565182862536979208272237582" - "51185210916864000000000000000000000000", + "638952175999932299156089414639761565182862536979208272237582" + "51185210916864000000000000000000000000", } @@ -125,3 +125,306 @@ func TestFact(t *testing.T) { } } } + + +type fromStringTest struct { + in string; + base int; + out int64; + ok bool; +} + + +var fromStringTests = []fromStringTest{ + fromStringTest{in: "", ok: false}, + fromStringTest{in: "a", ok: false}, + fromStringTest{in: "z", ok: false}, + fromStringTest{in: "+", ok: false}, + fromStringTest{"0", 0, 0, true}, + fromStringTest{"0", 10, 0, true}, + fromStringTest{"0", 16, 0, true}, + fromStringTest{"10", 0, 10, true}, + fromStringTest{"10", 10, 10, true}, + fromStringTest{"10", 16, 16, true}, + fromStringTest{"-10", 16, -16, true}, + fromStringTest{in: "0x", ok: false}, + fromStringTest{"0x10", 0, 16, true}, + fromStringTest{in: "0x10", base: 16, ok: false}, + fromStringTest{"-0x10", 0, -16, true}, +} + + +func TestSetString(t *testing.T) { + for i, test := range fromStringTests { + n, ok := new(Int).SetString(test.in, test.base); + if ok != test.ok { + t.Errorf("#%d (input '%s') ok incorrect (should be %t)", i, test.in, test.ok); + continue; + } + if !ok { + continue; + } + + if CmpInt(n, new(Int).New(test.out)) != 0 { + t.Errorf("#%d (input '%s') got: %s want: %d\n", i, test.in, n, test.out); + } + } +} + + +type divSignsTest struct { + x, y int64; + q, r int64; +} + + +// These examples taken from the Go Language Spec, section "Arithmetic operators" +var divSignsTests = []divSignsTest{ + divSignsTest{5, 3, 1, 2}, + divSignsTest{-5, 3, -1, -2}, + divSignsTest{5, -3, -1, 2}, + divSignsTest{-5, -3, 1, -2}, + divSignsTest{1, 2, 0, 1}, +} + + +func TestDivSigns(t *testing.T) { + for i, test := range divSignsTests { + x := new(Int).New(test.x); + y := new(Int).New(test.y); + q, r := new(Int).Div(x, y); + expectedQ := new(Int).New(test.q); + expectedR := new(Int).New(test.r); + + if CmpInt(q, expectedQ) != 0 || CmpInt(r, expectedR) != 0 { + t.Errorf("#%d: got (%s, %s) want (%s, %s)", i, q, r, expectedQ, expectedR); + } + } +} + + +func checkSetBytes(b []byte) bool { + hex1 := hex.EncodeToString(new(Int).SetBytes(b).Bytes()); + hex2 := hex.EncodeToString(b); + + for len(hex1) < len(hex2) { + hex1 = "0"+hex1; + } + + for len(hex1) > len(hex2) { + hex2 = "0"+hex2; + } + + return hex1 == hex2; +} + + +func TestSetBytes(t *testing.T) { + err := quick.Check(checkSetBytes, nil); + if err != nil { + t.Error(err); + } +} + + +func checkBytes(b []byte) bool { + b2 := new(Int).SetBytes(b).Bytes(); + return bytes.Compare(b, b2) == 0; +} + + +func TestBytes(t *testing.T) { + err := quick.Check(checkSetBytes, nil); + if err != nil { + t.Error(err); + } +} + + +func checkDiv(x, y []byte) bool { + u := new(Int).SetBytes(x); + v := new(Int).SetBytes(y); + + if len(v.abs) == 0 { + return true; + } + + q, r := new(Int).Div(u, v); + + if CmpInt(r, v) >= 0 { + return false; + } + + uprime := new(Int).Set(q); + uprime.Mul(uprime, v); + uprime.Add(uprime, r); + + return CmpInt(uprime, u) == 0; +} + + +func TestDiv(t *testing.T) { + err := quick.Check(checkDiv, nil); + if err != nil { + t.Error(err); + } +} + + +func TestDivStepD6(t *testing.T) { + // See Knuth, Volume 2, section 4.3.1, exercise 21. This code exercises + // a code path which only triggers 1 in 10^{-19} cases. + + u := &Int{false, []Word{0, 0, 0x8000000000000001, 0x7fffffffffffffff}}; + v := &Int{false, []Word{5, 0x8000000000000002, 0x8000000000000000}}; + + q, r := new(Int).Div(u, v); + const expectedQ = "18446744073709551613"; + const expectedR = "3138550867693340382088035895064302439801311770021610913807"; + if q.String() != expectedQ || r.String() != expectedR { + t.Errorf("got (%s, %s) want (%s, %s)", q, r, expectedQ, expectedR); + } +} + + +type lenTest struct { + in string; + out int; +} + + +var lenTests = []lenTest{ + lenTest{"0", 0}, + lenTest{"1", 1}, + lenTest{"2", 2}, + lenTest{"4", 3}, + lenTest{"0x8000", 16}, + lenTest{"0x80000000", 32}, + lenTest{"0x800000000000", 48}, + lenTest{"0x8000000000000000", 64}, + lenTest{"0x80000000000000000000", 80}, +} + + +func TestLen(t *testing.T) { + for i, test := range lenTests { + n, ok := new(Int).SetString(test.in, 0); + if !ok { + t.Errorf("#%d test input invalid: %s", i, test.in); + continue; + } + + if n.Len() != test.out { + t.Errorf("#%d got %d want %d\n", i, n.Len(), test.out); + } + } +} + + +type expTest struct { + x, y, m string; + out string; +} + + +var expTests = []expTest{ + /*expTest{"5", "0", "", "1"}, + expTest{"-5", "0", "", "-1"}, + expTest{"5", "1", "", "5"}, + expTest{"-5", "1", "", "-5"}, + expTest{"5", "2", "", "25"},*/ + expTest{"1", "65537", "2", "1"}, + /*expTest{"0x8000000000000000", "2", "", "0x40000000000000000000000000000000"}, + expTest{"0x8000000000000000", "2", "6719", "4944"}, + expTest{"0x8000000000000000", "3", "6719", "5447"}, + expTest{"0x8000000000000000", "1000", "6719", "1603"}, + expTest{"0x8000000000000000", "1000000", "6719", "3199"}, + expTest{ + "2938462938472983472983659726349017249287491026512746239764525612965293865296239471239874193284792387498274256129746192347", + "298472983472983471903246121093472394872319615612417471234712061", + "29834729834729834729347290846729561262544958723956495615629569234729836259263598127342374289365912465901365498236492183464", + "23537740700184054162508175125554701713153216681790245129157191391322321508055833908509185839069455749219131480588829346291", + },*/ +} + + +func TestExp(t *testing.T) { + for i, test := range expTests { + x, ok1 := new(Int).SetString(test.x, 0); + y, ok2 := new(Int).SetString(test.y, 0); + out, ok3 := new(Int).SetString(test.out, 0); + + var ok4 bool; + var m *Int; + + if len(test.m) == 0 { + m, ok4 = nil, true; + } else { + m, ok4 = new(Int).SetString(test.m, 0); + } + + if !ok1 || !ok2 || !ok3 || !ok4 { + t.Errorf("#%d error in input", i); + continue; + } + + z := new(Int).Exp(x, y, m); + if CmpInt(z, out) != 0 { + t.Errorf("#%d got %s want %s", i, z, out); + } + } +} + + +func checkGcd(aBytes, bBytes []byte) bool { + a := new(Int).SetBytes(aBytes); + b := new(Int).SetBytes(bBytes); + + x := new(Int); + y := new(Int); + d := new(Int); + + GcdInt(d, x, y, a, b); + x.Mul(x, a); + y.Mul(y, b); + x.Add(x, y); + + return CmpInt(x, d) == 0; +} + + +type gcdTest struct { + a, b int64; + d, x, y int64; +} + + +var gcdTests = []gcdTest{ + gcdTest{120, 23, 1, -9, 47}, +} + + +func TestGcd(t *testing.T) { + for i, test := range gcdTests { + a := new(Int).New(test.a); + b := new(Int).New(test.b); + + x := new(Int); + y := new(Int); + d := new(Int); + + expectedX := new(Int).New(test.x); + expectedY := new(Int).New(test.y); + expectedD := new(Int).New(test.d); + + GcdInt(d, x, y, a, b); + + if CmpInt(expectedX, x) != 0 || + CmpInt(expectedY, y) != 0 || + CmpInt(expectedD, d) != 0 { + t.Errorf("#%d got (%s %s %s) want (%s %s %s)", i, x, y, d, expectedX, expectedY, expectedD); + } + } + + quick.Check(checkGcd, nil); +} diff --git a/src/pkg/big/nat.go b/src/pkg/big/nat.go index 62e649be9f..71f6565a23 100644 --- a/src/pkg/big/nat.go +++ b/src/pkg/big/nat.go @@ -257,6 +257,66 @@ func divNW(z, x []Word, y Word) (q []Word, r Word) { } +// q = (uIn-r)/v, with 0 <= r < y +// See Knuth, Volume 2, section 4.3.1, Algorithm D. +// Preconditions: +// len(v) >= 2 +// len(uIn) >= 1 + len(vIn) +func divNN(z, z2, uIn, v []Word) (q, r []Word) { + n := len(v); + m := len(uIn)-len(v); + + u := makeN(z2, len(uIn)+1, false); + qhatv := make([]Word, len(v)+1); + q = makeN(z, m+1, false); + + // D1. + shift := leadingZeroBits(v[n-1]); + shiftLeft(v, v, shift); + shiftLeft(u, uIn, shift); + u[len(uIn)] = uIn[len(uIn)-1]>>(uint(_W)-uint(shift)); + + // D2. + for j := m; j >= 0; j-- { + // D3. + qhat, rhat := divWW_g(u[j+n], u[j+n-1], v[n-1]); + + // x1 | x2 = q̂v_{n-2} + x1, x2 := mulWW_g(qhat, v[n-2]); + // test if q̂v_{n-2} > br̂ + u_{j+n-2} + for greaterThan(x1, x2, rhat, u[j+n-2]) { + qhat--; + prevRhat := rhat; + rhat += v[n-1]; + // v[n-1] >= 0, so this tests for overflow. + if rhat < prevRhat { + break; + } + x1, x2 = mulWW_g(qhat, v[n-2]); + } + + // D4. + qhatv[len(v)] = mulAddVWW(&qhatv[0], &v[0], qhat, 0, len(v)); + + c := subVV(&u[j], &u[j], &qhatv[0], len(qhatv)); + if c != 0 { + c := addVV(&u[j], &u[j], &v[0], len(v)); + u[j+len(v)] += c; + qhat--; + } + + q[j] = qhat; + } + + q = normN(q); + shiftRight(u, u, shift); + shiftRight(v, v, shift); + r = normN(u); + + return q, r; +} + + // log2 computes the integer binary logarithm of x. // The result is the integer n for which 2^n <= x < 2^(n+1). // If x == 0, the result is -1. @@ -314,6 +374,10 @@ func scanN(z []Word, s string, base int) ([]Word, int, int) { base = 10; if n > 0 && s[0] == '0' { if n > 1 && (s[1] == 'x' || s[1] == 'X') { + if n == 2 { + // Reject a string which is just '0x' as nonsense. + return nil, 0, 0; + } base, i = 16, 2; } else { base, i = 8, 1; @@ -362,9 +426,62 @@ func stringN(x []Word, base int) string { for len(q) > 0 { i--; var r Word; - q, r = divNW(q, q, 10); + q, r = divNW(q, q, Word(base)); s[i] = "0123456789abcdef"[r]; } return string(s[i:len(s)]); } + + +// leadingZeroBits returns the number of leading zero bits in x. +func leadingZeroBits(x Word) int { + c := 0; + if x < 1 << (_W/2) { + x <<= _W/2; + c = int(_W/2); + } + + for i := 0; x != 0; i++ { + if x&(1<<(_W-1)) != 0 { + return i + c; + } + x <<= 1; + } + + return int(_W); +} + + +func shiftLeft(dst, src []Word, n int) { + if len(src) == 0 { + return; + } + + ñ := uint(_W) - uint(n); + for i := len(src)-1; i >= 1; i-- { + dst[i] = src[i]<>ñ; + } + dst[0] = src[0]<>uint(n); + dst[i] |= src[i+1]<<ñ; + } + dst[len(src)-1] = src[len(src)-1]>>uint(n); +} + + +// greaterThan returns true iff (x1<<_W + x2) > (y1<<_W + y2) +func greaterThan(x1, x2, y1, y2 Word) bool { + return x1 > y1 || x1 == y1 && x2 > y2; +} diff --git a/src/pkg/big/nat_test.go b/src/pkg/big/nat_test.go index ba3a27cf6a..098575a172 100644 --- a/src/pkg/big/nat_test.go +++ b/src/pkg/big/nat_test.go @@ -7,7 +7,7 @@ package big import "testing" func TestCmpNN(t *testing.T) { -// TODO(gri) write this test - all other tests depends on it + // TODO(gri) write this test - all other tests depends on it } @@ -16,6 +16,7 @@ type argNN struct { z, x, y []Word; } + var sumNN = []argNN{ argNN{}, argNN{[]Word{1}, nil, []Word{1}}, @@ -25,6 +26,7 @@ var sumNN = []argNN{ argNN{[]Word{0, 0, 0, 1}, []Word{0, 0, _M}, []Word{0, 0, 1}}, } + var prodNN = []argNN{ argNN{}, argNN{nil, nil, nil}, @@ -86,6 +88,7 @@ type strN struct { s string; } + var tabN = []strN{ strN{nil, 10, "0"}, strN{[]Word{1}, 10, "1"}, @@ -93,6 +96,7 @@ var tabN = []strN{ strN{[]Word{1234567890}, 10, "1234567890"}, } + func TestStringN(t *testing.T) { for _, a := range tabN { s := stringN(a.x, a.b); @@ -112,3 +116,72 @@ func TestStringN(t *testing.T) { } } } + + +func TestLeadingZeroBits(t *testing.T) { + var x Word = 1<<(_W-1); + for i := 0; i <= int(_W); i++ { + if leadingZeroBits(x) != i { + t.Errorf("failed at %x: got %d want %d", x, leadingZeroBits(x), i); + } + x >>= 1; + } +} + + +type shiftTest struct { + in []Word; + shift int; + out []Word; +} + + +var leftShiftTests = []shiftTest{ + shiftTest{nil, 0, nil}, + shiftTest{nil, 1, nil}, + shiftTest{[]Word{0}, 0, []Word{0}}, + shiftTest{[]Word{1}, 0, []Word{1}}, + shiftTest{[]Word{1}, 1, []Word{2}}, + shiftTest{[]Word{1<<(_W-1)}, 1, []Word{0}}, + shiftTest{[]Word{1<<(_W-1), 0}, 1, []Word{0, 1}}, +} + + +func TestShiftLeft(t *testing.T) { + for i, test := range leftShiftTests { + dst := make([]Word, len(test.out)); + shiftLeft(dst, test.in, test.shift); + for j, v := range dst { + if test.out[j] != v { + t.Errorf("#%d: got: %v want: %v", i, dst, test.out); + break; + } + } + } +} + + +var rightShiftTests = []shiftTest{ + shiftTest{nil, 0, nil}, + shiftTest{nil, 1, nil}, + shiftTest{[]Word{0}, 0, []Word{0}}, + shiftTest{[]Word{1}, 0, []Word{1}}, + shiftTest{[]Word{1}, 1, []Word{0}}, + shiftTest{[]Word{2}, 1, []Word{1}}, + shiftTest{[]Word{0, 1}, 1, []Word{1<<(_W-1), 0}}, + shiftTest{[]Word{2, 1, 1}, 1, []Word{1<<(_W-1) + 1, 1<<(_W-1), 0}}, +} + + +func TestShiftRight(t *testing.T) { + for i, test := range rightShiftTests { + dst := make([]Word, len(test.out)); + shiftRight(dst, test.in, test.shift); + for j, v := range dst { + if test.out[j] != v { + t.Errorf("#%d: got: %v want: %v", i, dst, test.out); + break; + } + } + } +} -- 2.48.1