]> Cypherpunks repositories - gostls13.git/commitdiff
- changed general div/mod implementation to a faster algorithm
authorRobert Griesemer <gri@golang.org>
Mon, 3 Nov 2008 17:21:10 +0000 (09:21 -0800)
committerRobert Griesemer <gri@golang.org>
Mon, 3 Nov 2008 17:21:10 +0000 (09:21 -0800)
  (operates on 30bit values at a time instead of 20bit values)
- refactored and cleaned up lots of code
- more tests
- close to check-in as complete library

R=r
OCL=18326
CL=18326

usr/gri/bignum/bignum.go
usr/gri/bignum/bignum_test.go

index 08217200485dcf74ae6a10621298ad978ae0c48c..0852f42262338a90dc88595cf29d2a7a4e8b0f8a 100755 (executable)
@@ -38,22 +38,18 @@ const LogB = LogW - LogH;
 
 
 const (
-       L3 = LogB / 3;
-       B3 = 1 << L3;
-       M3 = B3 - 1;
-
-       L2 = L3 * 2;
+       L2 = LogB / 2;
        B2 = 1 << L2;
        M2 = B2 - 1;
 
-       L = L3 * 3;
+       L = L2 * 2;
        B = 1 << L;
        M = B - 1;
 )
 
 
 type (
-       Digit3 uint32;
+       Digit2 uint32;
        Digit  uint64;
 )
 
@@ -69,26 +65,205 @@ func assert(p bool) {
 }
 
 
-func IsSmall(x Digit) bool {
-       return x < 1<<LogH;
+// ----------------------------------------------------------------------------
+// Raw operations
+
+func And1(z, x *[]Digit, y Digit) {
+       for i := len(x) - 1; i >= 0; i-- {
+               z[i] = x[i] & y;
+       }
 }
 
 
-func Split(x Digit) (Digit, Digit) {
-       return x>>L, x&M;
+func And(z, x, y *[]Digit) {
+       for i := len(x) - 1; i >= 0; i-- {
+               z[i] = x[i] & y[i];
+       }
 }
 
 
-export func Dump(x *[]Digit) {
-       print("[", len(x), "]");
+func Or1(z, x *[]Digit, y Digit) {
        for i := len(x) - 1; i >= 0; i-- {
-               print(" ", x[i]);
+               z[i] = x[i] | y;
        }
-       println();
 }
 
 
-export func Dump3(x *[]Digit3) {
+func Or(z, x, y *[]Digit) {
+       for i := len(x) - 1; i >= 0; i-- {
+               z[i] = x[i] | y[i];
+       }
+}
+
+
+func Xor1(z, x *[]Digit, y Digit) {
+       for i := len(x) - 1; i >= 0; i-- {
+               z[i] = x[i] ^ y;
+       }
+}
+
+
+func Xor(z, x, y *[]Digit) {
+       for i := len(x) - 1; i >= 0; i-- {
+               z[i] = x[i] ^ y[i];
+       }
+}
+
+
+func Add1(z, x *[]Digit, c Digit) Digit {
+       n := len(x);
+       for i := 0; i < n; i++ {
+               t := c + x[i];
+               c, z[i] = t>>L, t&M
+       }
+       return c;
+}
+
+
+func Add(z, x, y *[]Digit) Digit {
+       var c Digit;
+       n := len(x);
+       for i := 0; i < n; i++ {
+               t := c + x[i] + y[i];
+               c, z[i] = t>>L, t&M
+       }
+       return c;
+}
+
+
+func Sub1(z, x *[]Digit, c Digit) Digit {
+       n := len(x);
+       for i := 0; i < n; i++ {
+               t := c + x[i];
+               c, z[i] = Digit(int64(t)>>L), t&M;  // arithmetic shift!
+       }
+       return c;
+}
+
+
+func Sub(z, x, y *[]Digit) Digit {
+       var c Digit;
+       n := len(x);
+       for i := 0; i < n; i++ {
+               t := c + x[i] - y[i];
+               c, z[i] = Digit(int64(t)>>L), t&M;  // arithmetic shift!
+       }
+       return c;
+}
+
+
+// Returns c = x*y div B, z = x*y mod B.
+func Mul11(x, y Digit) (Digit, Digit) {
+       // Split x and y into 2 sub-digits each (in base sqrt(B)),
+       // multiply the digits separately while avoiding overflow,
+       // and return the product as two separate digits.
+
+       const L0 = (L + 1)/2;
+       const L1 = L - L0;
+       const DL = L0 - L1;  // 0 or 1
+       const b  = 1<<L0;
+       const m  = b - 1;
+
+       // split x and y into sub-digits
+       // x = (x1*b + x0)
+       // y = (y1*b + y0)
+       x1, x0 := x>>L0, x&m;
+       y1, y0 := y>>L0, y&m;
+
+       // x*y = t2*b^2 + t1*b + t0
+       t0 := x0*y0;
+       t1 := x1*y0 + x0*y1;
+       t2 := x1*y1;
+
+       // compute the result digits but avoid overflow
+       // z = z1*B + z0 = x*y
+       z0 := (t1<<L0 + t0)&M;
+       z1 := t2<<DL + (t1 + t0>>L0)>>L1;
+       
+       return z1, z0;
+}
+
+
+func Mul(z, x, y *[]Digit) {
+       n := len(x);
+       m := len(y);
+       for j := 0; j < m; j++ {
+               d := y[j];
+               if d != 0 {
+                       c := Digit(0);
+                       for i := 0; i < n; i++ {
+                               // z[i+j] += c + x[i]*d;
+                               z1, z0 := Mul11(x[i], d);
+                               t := c + z[i+j] + z0;
+                               c, z[i+j] = t>>L, t&M;
+                               c += z1;
+                       }
+                       z[n+j] = c;
+               }
+       }
+}
+
+
+func Mul1(z, x *[]Digit2, y Digit2) Digit2 {
+       n := len(x);
+       var c Digit;
+       f := Digit(y);
+       for i := 0; i < n; i++ {
+               t := c + Digit(x[i])*f;
+               c, z[i] = t>>L2, Digit2(t&M2);
+       }
+       return Digit2(c);
+}
+
+
+func Div1(z, x *[]Digit2, y Digit2) Digit2 {
+       n := len(x);
+       var c Digit;
+       d := Digit(y);
+       for i := n-1; i >= 0; i-- {
+               t := c*B2 + Digit(x[i]);
+               c, z[i] = t%d, Digit2(t/d);
+       }
+       return Digit2(c);
+}
+
+
+func Shl(z, x *[]Digit, s uint) Digit {
+       assert(s <= L);
+       n := len(x);
+       var c Digit;
+       for i := 0; i < n; i++ {
+               c, z[i] = x[i] >> (L-s), x[i] << s & M | c;
+       }
+       return c;
+}
+
+
+func Shr(z, x *[]Digit, s uint) Digit {
+       assert(s <= L);
+       n := len(x);
+       var c Digit;
+       for i := n - 1; i >= 0; i-- {
+               c, z[i] = x[i] << (L-s) & M, x[i] >> s | c;
+       }
+       return c;
+}
+
+
+// ----------------------------------------------------------------------------
+// Support
+
+func IsSmall(x Digit) bool {
+       return x < 1<<LogH;
+}
+
+
+func Split(x Digit) (Digit, Digit) {
+       return x>>L, x&M;
+}
+
+
+export func Dump(x *[]Digit) {
        print("[", len(x), "]");
        for i := len(x) - 1; i >= 0; i-- {
                print(" ", x[i]);
@@ -140,7 +315,7 @@ func Normalize(x *Natural) *Natural {
 }
 
 
-func Normalize3(x *[]Digit3) *[]Digit3 {
+func Normalize2(x *[]Digit2) *[]Digit2 {
        n := len(x);
        for n > 0 && x[n - 1] == 0 { n-- }
        if n < len(x) {
@@ -150,9 +325,10 @@ func Normalize3(x *[]Digit3) *[]Digit3 {
 }
 
 
-func (x *Natural) IsZero() bool {
-       return len(x) == 0;
-}
+// Predicates
+
+func (x *Natural) IsZero() bool { return len(x) == 0; }
+func (x *Natural) IsOdd() bool { return len(x) > 0 && x[0]&1 != 0; }
 
 
 func (x *Natural) Add(y *Natural) *Natural {
@@ -161,13 +337,10 @@ func (x *Natural) Add(y *Natural) *Natural {
        if n < m {
                return y.Add(x);
        }
-       assert(n >= m);
-       z := new(Natural, n + 1);
 
-       c := Digit(0);
-       for i := 0; i < m; i++ { c, z[i] = Split(c + x[i] + y[i]); }
-       for i := m; i < n; i++ { c, z[i] = Split(c + x[i]); }
-       z[n] = c;
+       z := new(Natural, n + 1);
+       c := Add(z[0 : m], x[0 : m], y);
+       z[n] = Add1(z[m : n], x[m : n], c);
 
        return Normalize(z);
 }
@@ -176,19 +349,15 @@ func (x *Natural) Add(y *Natural) *Natural {
 func (x *Natural) Sub(y *Natural) *Natural {
        n := len(x);
        m := len(y);
-       assert(n >= m);
-       z := new(Natural, n);
-
-       c := Digit(0);
-       for i := 0; i < m; i++ {
-               t := c + x[i] - y[i];
-               c, z[i] = Digit(int64(t)>>L), t&M;  // arithmetic shift!
+       if n < m {
+               panic("underflow")
        }
-       for i := m; i < n; i++ {
-               t := c + x[i];
-               c, z[i] = Digit(int64(t)>>L), t&M;  // arithmetic shift!
+
+       z := new(Natural, n);
+       c := Sub(z[0 : m], x[0 : m], y);
+       if Sub1(z[m : n], x[m : n], c) != 0 {
+               panic("underflow");
        }
-       assert(c == 0);  // x.Sub(y) must be called with x >= y
 
        return Normalize(z);
 }
@@ -207,56 +376,12 @@ func (x* Natural) MulAdd1(a, c Digit) *Natural {
 }
 
 
-// Returns c = x*y div B, z = x*y mod B.
-func Mul1(x, y Digit) (Digit, Digit) {
-       // Split x and y into 2 sub-digits each (in base sqrt(B)),
-       // multiply the digits separately while avoiding overflow,
-       // and return the product as two separate digits.
-
-       const L0 = (L + 1)/2;
-       const L1 = L - L0;
-       const DL = L0 - L1;  // 0 or 1
-       const b  = 1<<L0;
-       const m  = b - 1;
-
-       // split x and y into sub-digits
-       // x = (x1*b + x0)
-       // y = (y1*b + y0)
-       x1, x0 := x>>L0, x&m;
-       y1, y0 := y>>L0, y&m;
-
-       // x*y = t2*b^2 + t1*b + t0
-       t0 := x0*y0;
-       t1 := x1*y0 + x0*y1;
-       t2 := x1*y1;
-
-       // compute the result digits but avoid overflow
-       // z = z1*B + z0 = x*y
-       z0 := (t1<<L0 + t0)&M;
-       z1 := t2<<DL + (t1 + t0>>L0)>>L1;
-       
-       return z1, z0;
-}
-
-
 func (x *Natural) Mul(y *Natural) *Natural {
        n := len(x);
        m := len(y);
-       z := new(Natural, n + m);
 
-       for j := 0; j < m; j++ {
-               d := y[j];
-               if d != 0 {
-                       c := Digit(0);
-                       for i := 0; i < n; i++ {
-                               // z[i+j] += c + x[i]*d;
-                               z1, z0 := Mul1(x[i], d);
-                               c, z[i+j] = Split(c + z[i+j] + z0);
-                               c += z1;
-                       }
-                       z[n+j] = c;
-               }
-       }
+       z := new(Natural, n + m);
+       Mul(z, x, y);
 
        return Normalize(z);
 }
@@ -294,42 +419,26 @@ func (x *Natural) Pow(n uint) *Natural {
 }
 
 
-func Shl1(x, c Digit, s uint) (Digit, Digit) {
-       assert(s <= LogB);
-       return x >> (LogB - s), x << s & M | c
-}
-
-
-func Shr1(x, c Digit, s uint) (Digit, Digit) {
-       assert(s <= LogB);
-       return x << (LogB - s) & M, x >> s | c
-}
-
-
 func (x *Natural) Shl(s uint) *Natural {
-       n := len(x);
-       si := int(s / LogB);
-       s = s % LogB;
-       z := new(Natural, n + si + 1);
+       n := uint(len(x));
+       m := n + s/L;
+       z := new(Natural, m+1);
        
-       c := Digit(0);
-       for i := 0; i < n; i++ { c, z[i+si] = Shl1(x[i], c, s); }
-       z[n+si] = c;
+       z[m] = Shl(z[m-n : m], x, s%L);
        
        return Normalize(z);
 }
 
 
 func (x *Natural) Shr(s uint) *Natural {
-       n := len(x);
-       si := int(s / LogB);
-       if si >= n { si = n; }
-       s = s % LogB;
-       assert(si <= n);
-       z := new(Natural, n - si);
+       n := uint(len(x));
+       m := n - s/L;
+       if m > n {  // check for underflow
+               m = 0;
+       }
+       z := new(Natural, m);
        
-       c := Digit(0);
-       for i := n - 1; i >= si; i-- { c, z[i-si] = Shr1(x[i], c, s); }
+       Shr(z, x[n-m : n], s%L);
        
        return Normalize(z);
 }
@@ -337,68 +446,36 @@ func (x *Natural) Shr(s uint) *Natural {
 
 // DivMod needs multi-precision division which is not available if Digit
 // is already using the largest uint size. Split base before division,
-// and merge again after. Each Digit is split into 3 Digit3's.
+// and merge again after. Each Digit is split into 2 Digit2's.
 
-func SplitBase(x *Natural) *[]Digit3 {
-       // TODO Use Log() for better result - don't need Normalize3 at the end!
+func Unpack(x *Natural) *[]Digit2 {
+       // TODO Use Log() for better result - don't need Normalize2 at the end!
        n := len(x);
-       z := new([]Digit3, n*3 + 1);  // add space for extra digit (used by DivMod)
-       for i, j := 0, 0; i < n; i, j = i+1, j+3 {
+       z := new([]Digit2, n*2 + 1);  // add space for extra digit (used by DivMod)
+       for i := 0; i < n; i++ {
                t := x[i];
-               z[j+0] = Digit3(t >> (L3*0) & M3);
-               z[j+1] = Digit3(t >> (L3*1) & M3);
-               z[j+2] = Digit3(t >> (L3*2) & M3);
+               z[i*2] = Digit2(t & M2);
+               z[i*2 + 1] = Digit2(t >> L2 & M2);
        }
-       return Normalize3(z);
+       return Normalize2(z);
 }
 
 
-func MergeBase(x *[]Digit3) *Natural {
-       i := len(x);
-       j := (i+2)/3;
-       z := new(Natural, j);
-
-       switch i%3 {
-       case 1: z[j-1] = Digit(x[i-1]); i--; j--;
-       case 2: z[j-1] = Digit(x[i-1])<<L3 | Digit(x[i-2]); i -= 2; j--;
-       case 0:
+func Pack(x *[]Digit2) *Natural {
+       n := (len(x) + 1) / 2;
+       z := new(Natural, n);
+       if len(x) & 1 == 1 {
+               // handle odd len(x)
+               n--;
+               z[n] = Digit(x[n*2]);
        }
-       
-       for i >= 3 {
-               z[j-1] = ((Digit(x[i-1])<<L3) | Digit(x[i-2]))<<L3 | Digit(x[i-3]);
-               i -= 3;
-               j--;
+       for i := 0; i < n; i++ {
+               z[i] = Digit(x[i*2 + 1]) << L2 | Digit(x[i*2]);
        }
-       assert(j == 0);
-
        return Normalize(z);
 }
 
 
-func Split3(x Digit) (Digit, Digit3) {
-       return uint64(int64(x)>>L3), Digit3(x&M3)
-}
-
-
-func Product(x *[]Digit3, y Digit) {
-       n := len(x);
-       c := Digit(0);
-       for i := 0; i < n; i++ { c, x[i] = Split3(c + Digit(x[i])*y) }
-       assert(c == 0);
-}
-
-
-func Quotient(x *[]Digit3, y Digit) {
-       n := len(x);
-       c := Digit(0);
-       for i := n-1; i >= 0; i-- {
-               t := c*B3 + Digit(x[i]);
-               c, x[i] = t%y, Digit3(t/y);
-       }
-       assert(c == 0);
-}
-
-
 // Division and modulo computation - destroys x and y. Based on the
 // algorithms described in:
 //
@@ -413,8 +490,8 @@ func Quotient(x *[]Digit3, y Digit) {
 // is described in 1), while 2) provides the background for a more
 // accurate initial guess of the trial digit.
 
-func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) {
-       const b = B3;
+func DivMod2(x, y *[]Digit2) (*[]Digit2, *[]Digit2) {
+       const b = B2;
        
        n := len(x);
        m := len(y);
@@ -424,14 +501,9 @@ func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) {
        
        if m == 1 {
                // division by single digit
-               d := Digit(y[0]);
-               c := Digit(0);
-               for i := n; i > 0; i-- {
-                       t := c*b + Digit(x[i-1]);
-                       c, x[i] = t%d, Digit3(t/d);
-               }
-               x[0] = Digit3(c);
-
+               // result is shifted left by 1 in place!
+               x[0] = Div1(x[1 : n+1], x[0 : n], y[0]);
+               
        } else if m > n {
                // quotient = 0, remainder = x
                // TODO in this case we shouldn't even split base - FIX THIS
@@ -444,24 +516,34 @@ func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) {
                
                // normalize x and y
                f := b/(Digit(y[m-1]) + 1);
-               Product(x, f);
-               Product(y, f);
+               Mul1(x, x, Digit2(f));
+               Mul1(y, y, Digit2(f));
                assert(b/2 <= y[m-1] && y[m-1] < b);  // incorrect scaling
                
-               d2 := Digit(y[m-1])*b + Digit(y[m-2]);
+               y1, y2 := Digit(y[m-1]), Digit(y[m-2]);
+               d2 := Digit(y1)*b + Digit(y2);
                for i := n-m; i >= 0; i-- {
                        k := i+m;
                        
                        // compute trial digit
-                       r3 := (Digit(x[k])*b + Digit(x[k-1]))*b + Digit(x[k-2]);
-                       q := r3/d2;
-                       if q >= b { q = b-1 }
+                       var q Digit;
+                       {       // Knuth
+                               x0, x1, x2 := Digit(x[k]), Digit(x[k-1]), Digit(x[k-2]);
+                               if x0 != y1 {
+                                       q = (x0*b + x1)/y1;
+                               } else {
+                                       q = b-1;
+                               }
+                               for y2 * q > (x0*b + x1 - y1*q)*b + x2 {
+                                       q--
+                               }
+                       }
                        
                        // subtract y*q
                        c := Digit(0);
                        for j := 0; j < m; j++ {
                                t := c + Digit(x[i+j]) - Digit(y[j])*q;  // arithmetic shift!
-                               c, x[i+j] = Digit(int64(t)>>L3), Digit3(t&M3);
+                               c, x[i+j] = Digit(int64(t)>>L2), Digit2(t&M2);
                        }
                        
                        // correct if trial digit was too large
@@ -469,18 +551,20 @@ func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) {
                                // add y
                                c := Digit(0);
                                for j := 0; j < m; j++ {
-                                       c, x[i+j] = Split3(c + Digit(x[i+j]) + Digit(y[j]));
+                                       t := c + Digit(x[i+j]) + Digit(y[j]);
+                                       c, x[i+j] = uint64(int64(t) >> L2), Digit2(t & M2)
                                }
                                assert(c + Digit(x[k]) == 0);
                                // correct trial digit
                                q--;
                        }
                        
-                       x[k] = Digit3(q);
+                       x[k] = Digit2(q);
                }
                
                // undo normalization for remainder
-               Quotient(x[0 : m], f);
+               c := Div1(x[0 : m], x[0 : m], Digit2(f));
+               assert(c == 0);
        }
 
        return x[m : n+1], x[0 : m];
@@ -488,14 +572,20 @@ func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) {
 
 
 func (x *Natural) Div(y *Natural) *Natural {
-       q, r := DivMod(SplitBase(x), SplitBase(y));
-       return MergeBase(q);
+       q, r := DivMod2(Unpack(x), Unpack(y));
+       return Pack(q);
 }
 
 
 func (x *Natural) Mod(y *Natural) *Natural {
-       q, r := DivMod(SplitBase(x), SplitBase(y));
-       return MergeBase(r);
+       q, r := DivMod2(Unpack(x), Unpack(y));
+       return Pack(r);
+}
+
+
+func (x *Natural) DivMod(y *Natural) (*Natural, *Natural) {
+       q, r := DivMod2(Unpack(x), Unpack(y));
+       return Pack(q), Pack(r);
 }
 
 
@@ -520,17 +610,17 @@ func (x *Natural) Cmp(y *Natural) int {
 }
 
 
-func Log1(x Digit) int {
+func Log2(x Digit) int {
        n := -1;
        for x != 0 { x = x >> 1; n++; }  // BUG >>= broken for uint64
        return n;
 }
 
 
-func (x *Natural) Log() int {
+func (x *Natural) Log2() int {
        n := len(x);
        if n > 0 {
-               n = (n - 1)*L + Log1(x[n - 1]);
+               n = (n - 1)*L + Log2(x[n - 1]);
        } else {
                n = -1;
        }
@@ -544,11 +634,10 @@ func (x *Natural) And(y *Natural) *Natural {
        if n < m {
                return y.And(x);
        }
-       assert(n >= m);
-       z := new(Natural, n);
 
-       for i := 0; i < m; i++ { z[i] = x[i] & y[i]; }
-       for i := m; i < n; i++ { z[i] = x[i]; }
+       z := new(Natural, n);
+       And(z[0 : m], x[0 : m], y);
+       Or1(z[m : n], x[m : n], 0);
 
        return Normalize(z);
 }
@@ -560,11 +649,10 @@ func (x *Natural) Or(y *Natural) *Natural {
        if n < m {
                return y.Or(x);
        }
-       assert(n >= m);
-       z := new(Natural, n);
 
-       for i := 0; i < m; i++ { z[i] = x[i] | y[i]; }
-       for i := m; i < n; i++ { z[i] = x[i]; }
+       z := new(Natural, n);
+       Or(z[0 : m], x[0 : m], y);
+       Or1(z[m : n], x[m : n], 0);
 
        return Normalize(z);
 }
@@ -576,24 +664,15 @@ func (x *Natural) Xor(y *Natural) *Natural {
        if n < m {
                return y.Xor(x);
        }
-       assert(n >= m);
-       z := new(Natural, n);
 
-       for i := 0; i < m; i++ { z[i] = x[i] ^ y[i]; }
-       for i := m; i < n; i++ { z[i] = x[i]; }
+       z := new(Natural, n);
+       Xor(z[0 : m], x[0 : m], y);
+       Or1(z[m : n], x[m : n], 0);
 
        return Normalize(z);
 }
 
 
-func Copy(x *Natural) *Natural {
-       z := new(Natural, len(x));
-       //*z = *x;  // BUG assignment does't work yet
-       for i := len(x) - 1; i >= 0; i-- { z[i] = x[i]; }
-       return z;
-}
-
-
 // Computes x = x div d (in place - the recv maybe modified) for "small" d's.
 // Returns updated x and x mod d.
 func (x *Natural) DivMod1(d Digit) (*Natural, Digit) {
@@ -601,36 +680,36 @@ func (x *Natural) DivMod1(d Digit) (*Natural, Digit) {
 
        c := Digit(0);
        for i := len(x) - 1; i >= 0; i-- {
-               c = c<<L + x[i];
-               x[i] = c/d;
-               c %= d;
+               t := c<<L + x[i];
+               c, x[i] = t%d, t/d;
        }
 
        return Normalize(x), c;
 }
 
 
-func (x *Natural) String(base Digit) string {
+func (x *Natural) String(base uint) string {
        if x.IsZero() {
                return "0";
        }
        
        // allocate string
-       // TODO n is too small for bases < 10!!!
-       assert(base >= 10);  // for now
-       // approx. length: 1 char for 3 bits
-       n := x.Log()/3 + 10;  // +10 (round up) - what is the right number?
+       assert(2 <= base && base <= 16);
+       n := (x.Log2() + 1) / Log2(Digit(base)) + 1;  // TODO why the +1?
        s := new([]byte, n);
 
        // convert
-       const hex = "0123456789abcdef";
+
+       // don't destroy x, make a copy
+       t := new(Natural, len(x));
+       Or1(t, x, 0);  // copy x
+       
        i := n;
-       x = Copy(x);  // don't destroy recv
-       for !x.IsZero() {
+       for !t.IsZero() {
                i--;
                var d Digit;
-               x, d = x.DivMod1(base);
-               s[i] = hex[d];
+               t, d = t.DivMod1(Digit(base));
+               s[i] = "0123456789abcdef"[d];
        };
 
        return string(s[i : n]);
@@ -665,24 +744,24 @@ func (x *Natural) Gcd(y *Natural) *Natural {
 }
 
 
-func HexValue(ch byte) Digit {
-       d := Digit(1 << LogH);
+func HexValue(ch byte) uint {
+       d := uint(1 << LogH);
        switch {
-       case '0' <= ch && ch <= '9': d = Digit(ch - '0');
-       case 'a' <= ch && ch <= 'f': d = Digit(ch - 'a') + 10;
-       case 'A' <= ch && ch <= 'F': d = Digit(ch - 'A') + 10;
+       case '0' <= ch && ch <= '9': d = uint(ch - '0');
+       case 'a' <= ch && ch <= 'f': d = uint(ch - 'a') + 10;
+       case 'A' <= ch && ch <= 'F': d = uint(ch - 'A') + 10;
        }
        return d;
 }
 
 
 // TODO auto-detect base if base argument is 0
-export func NatFromString(s string, base Digit) *Natural {
+export func NatFromString(s string, base uint) *Natural {
        x := NatZero;
        for i := 0; i < len(s); i++ {
                d := HexValue(s[i]);
                if d < base {
-                       x = x.MulAdd1(base, d);
+                       x = x.MulAdd1(Digit(base), Digit(d));
                } else {
                        break;
                }
@@ -776,19 +855,31 @@ func (x *Integer) Mul(y *Integer) *Integer {
 
 
 func (x *Integer) Quo(y *Integer) *Integer {
-       panic("UNIMPLEMENTED");
-       return nil;
+       // x / y == x / y
+       // x / (-y) == -(x / y)
+       // (-x) / y == -(x / y)
+       // (-x) / (-y) == x / y
+       return &Integer{x.sign != y.sign, x.mant.Div(y.mant)};
 }
 
 
 func (x *Integer) Rem(y *Integer) *Integer {
-       panic("UNIMPLEMENTED");
-       return nil;
+       // x % y == x % y
+       // x % (-y) == x % y
+       // (-x) % y == -(x % y)
+       // (-x) % (-y) == -(x % y)
+       return &Integer{y.sign, x.mant.Mod(y.mant)};
+}
+
+
+func (x *Integer) QuoRem(y *Integer) (*Integer, *Integer) {
+       q, r := x.mant.DivMod(y.mant);
+       return &Integer{x.sign != y.sign, q}, &Integer{y.sign, q};
 }
 
 
 func (x *Integer) Div(y *Integer) *Integer {
-       panic("UNIMPLEMENTED");
+       q, r := x.mant.DivMod(y.mant);
        return nil;
 }
 
@@ -805,7 +896,7 @@ func (x *Integer) Cmp(y *Integer) int {
 }
 
 
-func (x *Integer) String(base Digit) string {
+func (x *Integer) String(base uint) string {
        if x.mant.IsZero() {
                return "0";
        }
@@ -817,7 +908,7 @@ func (x *Integer) String(base Digit) string {
 }
 
        
-export func IntFromString(s string, base Digit) *Integer {
+export func IntFromString(s string, base uint) *Integer {
        // get sign, if any
        sign := false;
        if len(s) > 0 && (s[0] == '-' || s[0] == '+') {
index ae81036b0fd997166a89ac07afc7bd41696bcf29..10d8da7db7ca60fef138ad34b5aa3107f9842852 100644 (file)
@@ -38,14 +38,54 @@ func TEST_EQ(n uint, x, y *Big.Natural) {
 }
 
 
+func TestLog2() {
+       test_msg = "TestLog2A";
+       TEST(0, Big.Nat(1).Log2() == 0);
+       TEST(1, Big.Nat(2).Log2() == 1);
+       TEST(2, Big.Nat(3).Log2() == 1);
+       TEST(3, Big.Nat(4).Log2() == 2);
+       
+       test_msg = "TestLog2B";
+       for i := uint(0); i < 100; i++ {
+               TEST(i, Big.Nat(1).Shl(i).Log2() == int(i));
+       }
+}
+
+
 func TestConv() {
-       test_msg = "TestConv";
+       test_msg = "TestConvA";
        TEST(0, a.Cmp(Big.Nat(991)) == 0);
        TEST(1, b.Cmp(Big.Fact(20)) == 0);
        TEST(2, c.Cmp(Big.Fact(100)) == 0);
        TEST(3, a.String(10) == sa);
        TEST(4, b.String(10) == sb);
        TEST(5, c.String(10) == sc);
+
+       test_msg = "TestConvB";
+       t := c.Mul(c);
+       for base := uint(2); base <= 16; base++ {
+               TEST_EQ(base, Big.NatFromString(t.String(base), base), t);
+       }
+}
+
+
+func Sum(n uint, scale *Big.Natural) *Big.Natural {
+       s := Big.Nat(0);
+       for ; n > 0; n-- {
+               s = s.Add(Big.Nat(uint64(n)).Mul(scale));
+       }
+       return s;
+}
+
+
+func TestAdd() {
+       test_msg = "TestAddA";
+
+       test_msg = "TestAddB";
+       for i := uint(0); i < 100; i++ {
+               t := Big.Nat(uint64(i));
+               TEST_EQ(i, Sum(i, c), t.Mul(t).Add(t).Shr(1).Mul(c));
+       }
 }
 
 
@@ -163,7 +203,9 @@ func TestPop() {
 
 
 func main() {
+       TestLog2();
        TestConv();
+       TestAdd();
        TestShift();
        TestMul();
        TestDiv();