]> Cypherpunks repositories - gostls13.git/commitdiff
- div and mod (arbitrary precision)
authorRobert Griesemer <gri@golang.org>
Fri, 31 Oct 2008 06:37:34 +0000 (23:37 -0700)
committerRobert Griesemer <gri@golang.org>
Fri, 31 Oct 2008 06:37:34 +0000 (23:37 -0700)
- more tests
- some global renames

R=r
OCL=18219
CL=18219

usr/gri/bignum/bignum.go
usr/gri/bignum/bignum_test.go

index 8ca1a0d75e38105f19fb29c3ded37992c9d9c8a4..7b15bf47636cbf887d128eeccdd0673aba7a7657 100755 (executable)
@@ -20,41 +20,41 @@ package Bignum
 //   x = x[n-1]*B^(n-1) + x[n-2]*B^(n-2) + ... + x[1]*B + x[0]
 //
 // with 0 <= x[i] < B and 0 <= i < n is stored in an array of length n,
-// with the digits x[i] as the array elements. 0 is represented as an
-// empty array (length == 0).
+// with the digits x[i] as the array elements.
 //
 // A natural number is normalized if the array contains no leading 0 digits.
 // During arithmetic operations, denormalized values may occur which are
-// always normalized before returning the final result.
+// always normalized before returning the final result. The normalized
+// representation of 0 is the empty array (length = 0).
 //
 // The base B is chosen as large as possible on a given platform but there
-// are a few constraints besides the largest unsigned integer type available.
+// are a few constraints besides the size of the largest unsigned integer
+// type available.
 // TODO describe the constraints.
 
-type Word uint64;
 const LogW = 64;
-
 const LogH = 4;  // bits for a hex digit (= "small" number)
-const H = 1 << LogH;
-
 const LogB = LogW - LogH;
-const L = LogB;
-const B = 1 << LogB;
-const M = B - 1;
 
 
-// For division
-
 const (
-       L3 = L / 3;
+       L3 = LogB / 3;
        B3 = 1 << L3;
        M3 = B3 - 1;
+
+       L2 = L3 * 2;
+       B2 = 1 << L2;
+       M2 = B2 - 1;
+
+       L = L3 * 3;
+       B = 1 << L;
+       M = B - 1;
 )
 
 
 type (
-       Word3 uint32;
-       Natural3 [] Word3;
+       Digit3 uint32;
+       Digit  uint64;
 )
 
 
@@ -69,17 +69,26 @@ func assert(p bool) {
 }
 
 
-func IsSmall(x Word) bool {
-       return x < H;
+func IsSmall(x Digit) bool {
+       return x < 1<<LogH;
 }
 
 
-func Split(x Word) (Word, Word) {
+func Split(x Digit) (Digit, Digit) {
        return x>>L, x&M;
 }
 
 
-export func Dump(x *[]Word) {
+export func Dump(x *[]Digit) {
+       print("[", len(x), "]");
+       for i := len(x) - 1; i >= 0; i-- {
+               print(" ", x[i]);
+       }
+       println();
+}
+
+
+export func Dump3(x *[]Digit3) {
        print("[", len(x), "]");
        for i := len(x) - 1; i >= 0; i-- {
                print(" ", x[i]);
@@ -91,11 +100,11 @@ export func Dump(x *[]Word) {
 // ----------------------------------------------------------------------------
 // Natural numbers
 
-export type Natural []Word;
+export type Natural []Digit;
 export var NatZero *Natural = new(Natural, 0);
 
 
-export func NewNat(x Word) *Natural {
+export func Nat(x Digit) *Natural {
        var z *Natural;
        switch {
        case x == 0:
@@ -122,7 +131,7 @@ func Normalize(x *Natural) *Natural {
 }
 
 
-func Normalize3(x *Natural3) *Natural3 {
+func Normalize3(x *[]Digit3) *[]Digit3 {
        n := len(x);
        for n > 0 && x[n - 1] == 0 { n-- }
        if n < len(x) {
@@ -146,7 +155,7 @@ func (x *Natural) Add(y *Natural) *Natural {
        assert(n >= m);
        z := new(Natural, n + 1);
 
-       c := Word(0);
+       c := Digit(0);
        for i := 0; i < m; i++ { c, z[i] = Split(x[i] + y[i] + c); }
        for i := m; i < n; i++ { c, z[i] = Split(x[i] + c); }
        z[n] = c;
@@ -161,8 +170,8 @@ func (x *Natural) Sub(y *Natural) *Natural {
        assert(n >= m);
        z := new(Natural, n);
 
-       c := Word(0);
-       for i := 0; i < m; i++ { c, z[i] = Split(x[i] - y[i] + c); }
+       c := Digit(0);
+       for i := 0; i < m; i++ { c, z[i] = Split(x[i] - y[i] + c); }  // TODO verify asr!!!
        for i := m; i < n; i++ { c, z[i] = Split(x[i] + c); }
        assert(c == 0);  // x.Sub(y) must be called with x >= y
 
@@ -171,7 +180,7 @@ func (x *Natural) Sub(y *Natural) *Natural {
 
 
 // Computes x = x*a + c (in place) for "small" a's.
-func (x* Natural) MulAdd1(a, c Word) *Natural {
+func (x* Natural) MulAdd1(a, c Digit) *Natural {
        assert(IsSmall(a-1) && IsSmall(c));
        n := len(x);
        z := new(Natural, n + 1);
@@ -184,7 +193,7 @@ func (x* Natural) MulAdd1(a, c Word) *Natural {
 
 
 // Returns c = x*y div B, z = x*y mod B.
-func Mul1(x, y Word) (Word, Word) {
+func Mul1(x, y Digit) (Digit, Digit) {
        // Split x and y into 2 sub-digits each (in base sqrt(B)),
        // multiply the digits separately while avoiding overflow,
        // and return the product as two separate digits.
@@ -223,7 +232,7 @@ func (x *Natural) Mul(y *Natural) *Natural {
        for j := 0; j < m; j++ {
                d := y[j];
                if d != 0 {
-                       c := Word(0);
+                       c := Digit(0);
                        for i := 0; i < n; i++ {
                                // z[i+j] += x[i]*d + c;
                                z1, z0 := Mul1(x[i], d);
@@ -238,13 +247,13 @@ func (x *Natural) Mul(y *Natural) *Natural {
 }
 
 
-func Shl1(x, c Word, s uint) (Word, Word) {
+func Shl1(x, c Digit, s uint) (Digit, Digit) {
        assert(s <= LogB);
        return x >> (LogB - s), x << s & M | c
 }
 
 
-func Shr1(x, c Word, s uint) (Word, Word) {
+func Shr1(x, c Digit, s uint) (Digit, Digit) {
        assert(s <= LogB);
        return x << (LogB - s) & M, x >> s | c
 }
@@ -256,7 +265,7 @@ func (x *Natural) Shl(s uint) *Natural {
        s = s % LogB;
        z := new(Natural, n + si + 1);
        
-       c := Word(0);
+       c := Digit(0);
        for i := 0; i < n; i++ { c, z[i+si] = Shl1(x[i], c, s); }
        z[n+si] = c;
        
@@ -272,83 +281,184 @@ func (x *Natural) Shr(s uint) *Natural {
        assert(si <= n);
        z := new(Natural, n - si);
        
-       c := Word(0);
+       c := Digit(0);
        for i := n - 1; i >= si; i-- { c, z[i-si] = Shr1(x[i], c, s); }
        
        return Normalize(z);
 }
 
 
-func SplitBase(x *Natural) *Natural3 {
-       xl := len(x);
-       z := new(Natural3, xl * 3);
-       for i, j := 0, 0; i < xl; i, j = i + 1, j + 3 {
+// DivMod needs multi-precision division which is not available if Digit
+// is already using the largest uint size. Split base before division,
+// and merge again after. Each Digit is split into 3 Digit3's.
+
+func SplitBase(x *Natural) *[]Digit3 {
+       // TODO Use Log() for better result - don't need Normalize3 at the end!
+       n := len(x);
+       z := new([]Digit3, n*3 + 1);  // add space for extra digit (used by DivMod)
+       for i, j := 0, 0; i < n; i, j = i+1, j+3 {
                t := x[i];
-               z[j] = Word3(t & M3);  t >>= L3;  j++;
-               z[j] = Word3(t & M3);  t >>= L3;  j++;
-               z[j] = Word3(t & M3);  t >>= L3;  j++;
+               z[j+0] = Digit3(t >> (L3*0) & M3);
+               z[j+1] = Digit3(t >> (L3*1) & M3);
+               z[j+2] = Digit3(t >> (L3*2) & M3);
        }
        return Normalize3(z);
 }
 
 
-func Scale(x *Natural, f Word) *Natural3 {
-       return nil;
+func MergeBase(x *[]Digit3) *Natural {
+       i := len(x);
+       j := (i+2)/3;
+       z := new(Natural, j);
+
+       switch i%3 {
+       case 1: z[j-1] = Digit(x[i-1]); i--; j--;
+       case 2: z[j-1] = Digit(x[i-1])<<L3 | Digit(x[i-2]); i -= 2; j--;
+       case 0:
+       }
+       
+       for i >= 3 {
+               z[j-1] = ((Digit(x[i-1])<<L3) | Digit(x[i-2]))<<L3 | Digit(x[i-3]);
+               i -= 3;
+               j--;
+       }
+       assert(j == 0);
+
+       return Normalize(z);
+}
+
+
+func Split3(x Digit) (Digit, Digit3) {
+       return uint64(int64(x)>>L3), Digit3(x&M3)
+}
+
+
+func Product(x *[]Digit3, y Digit) {
+       n := len(x);
+       c := Digit(0);
+       for i := 0; i < n; i++ { c, x[i] = Split3(Digit(x[i])*y + c) }
+       assert(c == 0);
 }
 
 
-func TrialDigit(r, d *Natural3, k, m int) Word {
-       km := k + m;
-       assert(2 <= m && m <= km);
-       r3 := (Word(r[km]) << L3 + Word(r[km - 1])) << L3 + Word(r[km - 2]);
-       d2 := Word(d[m - 1]) << L3 + Word(d[m - 2]);
-       qt := r3 / d2;
-       if qt >= B {
-               qt = B - 1;
+func Quotient(x *[]Digit3, y Digit) {
+       n := len(x);
+       c := Digit(0);
+       for i := n-1; i >= 0; i-- {
+               t := c*B3 + Digit(x[i]);
+               c, x[i] = t%y, Digit3(t/y);
        }
-       return qt;
+       assert(c == 0);
 }
 
 
-func DivMod(x, y *Natural) {
-       xl := len(x);
-       yl := len(y);
-       assert(2 <= yl && yl <= xl);  // use special-case algorithm otherwise
+// Division and modulo computation - destroys x and y. Based on the
+// algorithms described in:
+//
+// 1) D. Knuth, "The Art of Computer Programming. Volume 2. Seminumerical
+//    Algorithms." Addison-Wesley, Reading, 1969.
+//
+// 2) P. Brinch Hansen, Multiple-length division revisited: A tour of the
+//    minefield. "Software - Practice and Experience 24", (June 1994),
+//    579-601. John Wiley & Sons, Ltd.
+//
+// Specifically, the inplace computation of quotient and remainder
+// is described in 1), while 2) provides the background for a more
+// accurate initial guess of the trial digit.
+
+func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) {
+       const b = B3;
        
-       f := B / (y[yl - 1] + 1);
-       r := Scale(x, f);
-       d := Scale(y, f);
-       n := len(r);
-       m := len(d);
+       n := len(x);
+       m := len(y);
+       assert(m > 0);  // division by zero
+       assert(n+1 <= cap(x));  // space for one extra digit (should it be == ?)
+       x = x[0 : n + 1];
        
-       for k := n - m; k >= 0; k-- {
-               qt := TrialDigit(r, d, k, m);
+       if m == 1 {
+               // division by single digit
+               d := Digit(y[0]);
+               c := Digit(0);
+               for i := n; i > 0; i-- {
+                       t := c*b + Digit(x[i-1]);
+                       c, x[i] = t%d, Digit3(t/d);
+               }
+               x[0] = Digit3(c);
+
+       } else if m > n {
+               // quotient = 0, remainder = x
+               // TODO in this case we shouldn't even split base - FIX THIS
+               m = n;
+               
+       } else {
+               // general case
+               assert(2 <= m && m <= n);
+               assert(x[n] == 0);
+               
+               // normalize x and y
+               f := b/(Digit(y[m-1]) + 1);
+               Product(x, f);
+               Product(y, f);
+               assert(b/2 <= y[m-1] && y[m-1] < b);  // incorrect scaling
+               
+               d2 := Digit(y[m-1])*b + Digit(y[m-2]);
+               for i := n-m; i >= 0; i-- {
+                       k := i+m;
+                       
+                       // compute trial digit
+                       r3 := (Digit(x[k])*b + Digit(x[k-1]))*b + Digit(x[k-2]);
+                       q := r3/d2;
+                       if q >= b { q = b-1 }
+                       
+                       // subtract y*q
+                       c := Digit(0);
+                       for j := 0; j < m; j++ {
+                               c, x[i+j] = Split3(c + Digit(x[i+j]) - Digit(y[j])*q);
+                       }
+                       
+                       // correct if trial digit was too large
+                       if c + Digit(x[k]) != 0 {
+                               // add y
+                               c := Digit(0);
+                               for j := 0; j < m; j++ {
+                                       c, x[i+j] = Split3(c + Digit(x[i+j]) + Digit(y[j]));
+                               }
+                               // correct trial digit
+                               q--;
+                       }
+                       
+                       x[k] = Digit3(q);
+               }
                
+               // undo normalization for remainder
+               Quotient(x[0 : m], f);
        }
+
+       return x[m : n+1], x[0 : m];
 }
 
 
 func (x *Natural) Div(y *Natural) *Natural {
-       panic("UNIMPLEMENTED");
-       return nil;
+       q, r := DivMod(SplitBase(x), SplitBase(y));
+       return MergeBase(q);
 }
 
 
 func (x *Natural) Mod(y *Natural) *Natural {
-       panic("UNIMPLEMENTED");
-       return nil;
+       q, r := DivMod(SplitBase(x), SplitBase(y));
+       return MergeBase(r);
 }
 
 
 func (x *Natural) Cmp(y *Natural) int {
-       xl := len(x);
-       yl := len(y);
+       n := len(x);
+       m := len(y);
 
-       if xl != yl || xl == 0 {
-               return xl - yl;
+       if n != m || n == 0 {
+               return n - m;
        }
 
-       i := xl - 1;
+       i := n - 1;
        for i > 0 && x[i] == y[i] { i--; }
        
        d := 0;
@@ -361,7 +471,7 @@ func (x *Natural) Cmp(y *Natural) int {
 }
 
 
-func Log1(x Word) int {
+func Log1(x Digit) int {
        n := -1;
        for x != 0 { x >>= 1; n++; }
        return n;
@@ -437,10 +547,10 @@ func Copy(x *Natural) *Natural {
 
 // Computes x = x div d (in place - the recv maybe modified) for "small" d's.
 // Returns updated x and x mod d.
-func (x *Natural) DivMod1(d Word) (*Natural, Word) {
+func (x *Natural) DivMod1(d Digit) (*Natural, Digit) {
        assert(0 < d && IsSmall(d - 1));
 
-       c := Word(0);
+       c := Digit(0);
        for i := len(x) - 1; i >= 0; i-- {
                c = c<<L + x[i];
                x[i] = c/d;
@@ -451,7 +561,7 @@ func (x *Natural) DivMod1(d Word) (*Natural, Word) {
 }
 
 
-func (x *Natural) String(base Word) string {
+func (x *Natural) String(base Digit) string {
        if x.IsZero() {
                return "0";
        }
@@ -469,7 +579,7 @@ func (x *Natural) String(base Word) string {
        x = Copy(x);  // don't destroy recv
        for !x.IsZero() {
                i--;
-               var d Word;
+               var d Digit;
                x, d = x.DivMod1(base);
                s[i] = hex[d];
        };
@@ -478,11 +588,11 @@ func (x *Natural) String(base Word) string {
 }
 
 
-func MulRange(a, b Word) *Natural {
+export func MulRange(a, b Digit) *Natural {
        switch {
-       case a > b: return NewNat(1);
-       case a == b: return NewNat(a);
-       case a + 1 == b: return NewNat(a).Mul(NewNat(b));
+       case a > b: return Nat(1);
+       case a == b: return Nat(a);
+       case a + 1 == b: return Nat(a).Mul(Nat(b));
        }
        m := (a + b)>>1;
        assert(a <= m && m < b);
@@ -490,26 +600,26 @@ func MulRange(a, b Word) *Natural {
 }
 
 
-export func Fact(n Word) *Natural {
+export func Fact(n Digit) *Natural {
        // Using MulRange() instead of the basic for-loop
        // lead to faster factorial computation.
        return MulRange(2, n);
 }
 
 
-func HexValue(ch byte) Word {
-       d := Word(H);
+func HexValue(ch byte) Digit {
+       d := Digit(1 << LogH);
        switch {
-       case '0' <= ch && ch <= '9': d = Word(ch - '0');
-       case 'a' <= ch && ch <= 'f': d = Word(ch - 'a') + 10;
-       case 'A' <= ch && ch <= 'F': d = Word(ch - 'A') + 10;
+       case '0' <= ch && ch <= '9': d = Digit(ch - '0');
+       case 'a' <= ch && ch <= 'f': d = Digit(ch - 'a') + 10;
+       case 'A' <= ch && ch <= 'F': d = Digit(ch - 'A') + 10;
        }
        return d;
 }
 
 
 // TODO auto-detect base if base argument is 0
-export func NatFromString(s string, base Word) *Natural {
+export func NatFromString(s string, base Digit) *Natural {
        x := NatZero;
        for i := 0; i < len(s); i++ {
                d := HexValue(s[i]);
@@ -532,6 +642,11 @@ export type Integer struct {
 }
 
 
+export func Int(x int64) *Integer {
+       return nil;
+}
+
+
 func (x *Integer) Add(y *Integer) *Integer {
        var z *Integer;
        if x.sign == y.sign {
@@ -603,7 +718,7 @@ func (x *Integer) Cmp(y *Integer) int {
 }
 
 
-func (x *Integer) String(base Word) string {
+func (x *Integer) String(base Digit) string {
        if x.mant.IsZero() {
                return "0";
        }
@@ -615,7 +730,7 @@ func (x *Integer) String(base Word) string {
 }
 
        
-export func IntFromString(s string, base Word) *Integer {
+export func IntFromString(s string, base Digit) *Integer {
        // get sign, if any
        sign := false;
        if len(s) > 0 && (s[0] == '-' || s[0] == '+') {
index dd9706a58e100aae2387c150d9e9c10357b93dbb..338c70bc85a83f6bc8c7d22cd06928d28758150c 100644 (file)
@@ -4,7 +4,7 @@
 
 package main
 
-import Bignum "bignum"
+import Big "bignum"
 
 const (
        sa = "991";
@@ -14,25 +14,35 @@ const (
 
 
 var (
-       a = Bignum.NatFromString(sa, 10);
-       b = Bignum.NatFromString(sb, 10);
-       c = Bignum.NatFromString(sc, 10);
+       a = Big.NatFromString(sa, 10);
+       b = Big.NatFromString(sb, 10);
+       c = Big.NatFromString(sc, 10);
 )
 
 
 var test_msg string;
-func TEST(n int, b bool) {
+func TEST(n uint, b bool) {
        if !b {
                panic("TEST failed: ", test_msg, "(", n, ")\n");
        }
 }
 
 
+func TEST_EQ(n uint, x, y *Big.Natural) {
+       if x.Cmp(y) != 0 {
+               println("TEST failed: ", test_msg, "(", n, ")\n");
+               println("x = ", x.String(10));
+               println("y = ", y.String(10));
+               panic();
+       }
+}
+
+
 func TestConv() {
        test_msg = "TestConv";
-       TEST(0, a.Cmp(Bignum.NewNat(991)) == 0);
-       TEST(1, b.Cmp(Bignum.Fact(20)) == 0);
-       TEST(2, c.Cmp(Bignum.Fact(100)) == 0);
+       TEST(0, a.Cmp(Big.Nat(991)) == 0);
+       TEST(1, b.Cmp(Big.Fact(20)) == 0);
+       TEST(2, c.Cmp(Big.Fact(100)) == 0);
        TEST(3, a.String(10) == sa);
        TEST(4, b.String(10) == sb);
        TEST(5, c.String(10) == sc);
@@ -49,32 +59,81 @@ func TestShift() {
        TEST(1, c.Shr(1).Cmp(c) < 0);
 
        test_msg = "TestShift2";
-       for i := 0; i < 100; i++ {
-               TEST(i, c.Shl(uint(i)).Shr(uint(i)).Cmp(c) == 0);
+       for i := uint(0); i < 100; i++ {
+               TEST(i, c.Shl(i).Shr(i).Cmp(c) == 0);
        }
 
        test_msg = "TestShift3L";
        {       const m = 3;
                p := b;
-               f := Bignum.NewNat(1<<m);
-               for i := 0; i < 100; i++ {
-                       TEST(i, b.Shl(uint(i*m)).Cmp(p) == 0);
+               f := Big.Nat(1<<m);
+               for i := uint(0); i < 100; i++ {
+                       TEST_EQ(i, b.Shl(i*m), p);
                        p = p.Mul(f);
                }
        }
 
        test_msg = "TestShift3R";
        {       p := c;
-               for i := 0; c.Cmp(Bignum.NatZero) == 0; i++ {
-                       TEST(i, c.Shr(uint(i)).Cmp(p) == 0);
+               for i := uint(0); c.Cmp(Big.NatZero) == 0; i++ {
+                       TEST_EQ(i, c.Shr(i), p);
                        p = p.Shr(1);
                }
        }
 }
 
 
+func TestMul() {
+       test_msg = "TestMulA";
+       TEST_EQ(0, b.Mul(Big.MulRange(0, 100)), Big.Nat(0));
+       TEST_EQ(0, b.Mul(Big.MulRange(21, 100)), c);
+       
+       test_msg = "TestMulB";
+       const n = 100;
+       p := b.Mul(c).Shl(n);
+       for i := uint(0); i < n; i++ {
+               TEST_EQ(i, b.Shl(i).Mul(c.Shl(n-i)), p);
+       }
+}
+
+
+func TestDiv() {
+       test_msg = "TestDivA";
+       TEST_EQ(0, c.Div(Big.Nat(1)), c);
+       TEST_EQ(1, c.Div(Big.Nat(100)), Big.Fact(99));
+       TEST_EQ(2, b.Div(c), Big.Nat(0));
+       TEST_EQ(4, Big.Nat(1).Shl(100).Div(Big.Nat(1).Shl(90)), Big.Nat(1).Shl(10));
+       TEST_EQ(5, c.Div(b), Big.MulRange(21, 100));
+       
+       test_msg = "TestDivB";
+       const n = 100;
+       p := Big.Fact(n);
+       for i := uint(0); i < n; i++ {
+               TEST_EQ(i, p.Div(Big.MulRange(1, uint64(i))), Big.MulRange(uint64(i+1), n));
+       }
+}
+
+
+func TestMod() {
+       test_msg = "TestModA";
+       for i := uint(0); ; i++ {
+               d := Big.Nat(1).Shl(i);
+               if d.Cmp(c) < 0 {
+                       TEST_EQ(i, c.Add(d).Mod(c), d);
+               } else {
+                       TEST_EQ(i, c.Add(d).Div(c), Big.Nat(2));
+                       //TEST_EQ(i, c.Add(d).Mod(c), d.Sub(c));
+                       break;
+               }
+       }
+}
+
+
 func main() {
        TestConv();
        TestShift();
+       TestMul();
+       TestDiv();
+       TestMod();
        print("PASSED\n");
 }