const (
- L3 = LogB / 3;
- B3 = 1 << L3;
- M3 = B3 - 1;
-
- L2 = L3 * 2;
+ L2 = LogB / 2;
B2 = 1 << L2;
M2 = B2 - 1;
- L = L3 * 3;
+ L = L2 * 2;
B = 1 << L;
M = B - 1;
)
type (
- Digit3 uint32;
+ Digit2 uint32;
Digit uint64;
)
}
-func IsSmall(x Digit) bool {
- return x < 1<<LogH;
+// ----------------------------------------------------------------------------
+// Raw operations
+
+func And1(z, x *[]Digit, y Digit) {
+ for i := len(x) - 1; i >= 0; i-- {
+ z[i] = x[i] & y;
+ }
}
-func Split(x Digit) (Digit, Digit) {
- return x>>L, x&M;
+func And(z, x, y *[]Digit) {
+ for i := len(x) - 1; i >= 0; i-- {
+ z[i] = x[i] & y[i];
+ }
}
-export func Dump(x *[]Digit) {
- print("[", len(x), "]");
+func Or1(z, x *[]Digit, y Digit) {
for i := len(x) - 1; i >= 0; i-- {
- print(" ", x[i]);
+ z[i] = x[i] | y;
}
- println();
}
-export func Dump3(x *[]Digit3) {
+func Or(z, x, y *[]Digit) {
+ for i := len(x) - 1; i >= 0; i-- {
+ z[i] = x[i] | y[i];
+ }
+}
+
+
+func Xor1(z, x *[]Digit, y Digit) {
+ for i := len(x) - 1; i >= 0; i-- {
+ z[i] = x[i] ^ y;
+ }
+}
+
+
+func Xor(z, x, y *[]Digit) {
+ for i := len(x) - 1; i >= 0; i-- {
+ z[i] = x[i] ^ y[i];
+ }
+}
+
+
+func Add1(z, x *[]Digit, c Digit) Digit {
+ n := len(x);
+ for i := 0; i < n; i++ {
+ t := c + x[i];
+ c, z[i] = t>>L, t&M
+ }
+ return c;
+}
+
+
+func Add(z, x, y *[]Digit) Digit {
+ var c Digit;
+ n := len(x);
+ for i := 0; i < n; i++ {
+ t := c + x[i] + y[i];
+ c, z[i] = t>>L, t&M
+ }
+ return c;
+}
+
+
+func Sub1(z, x *[]Digit, c Digit) Digit {
+ n := len(x);
+ for i := 0; i < n; i++ {
+ t := c + x[i];
+ c, z[i] = Digit(int64(t)>>L), t&M; // arithmetic shift!
+ }
+ return c;
+}
+
+
+func Sub(z, x, y *[]Digit) Digit {
+ var c Digit;
+ n := len(x);
+ for i := 0; i < n; i++ {
+ t := c + x[i] - y[i];
+ c, z[i] = Digit(int64(t)>>L), t&M; // arithmetic shift!
+ }
+ return c;
+}
+
+
+// Returns c = x*y div B, z = x*y mod B.
+func Mul11(x, y Digit) (Digit, Digit) {
+ // Split x and y into 2 sub-digits each (in base sqrt(B)),
+ // multiply the digits separately while avoiding overflow,
+ // and return the product as two separate digits.
+
+ const L0 = (L + 1)/2;
+ const L1 = L - L0;
+ const DL = L0 - L1; // 0 or 1
+ const b = 1<<L0;
+ const m = b - 1;
+
+ // split x and y into sub-digits
+ // x = (x1*b + x0)
+ // y = (y1*b + y0)
+ x1, x0 := x>>L0, x&m;
+ y1, y0 := y>>L0, y&m;
+
+ // x*y = t2*b^2 + t1*b + t0
+ t0 := x0*y0;
+ t1 := x1*y0 + x0*y1;
+ t2 := x1*y1;
+
+ // compute the result digits but avoid overflow
+ // z = z1*B + z0 = x*y
+ z0 := (t1<<L0 + t0)&M;
+ z1 := t2<<DL + (t1 + t0>>L0)>>L1;
+
+ return z1, z0;
+}
+
+
+func Mul(z, x, y *[]Digit) {
+ n := len(x);
+ m := len(y);
+ for j := 0; j < m; j++ {
+ d := y[j];
+ if d != 0 {
+ c := Digit(0);
+ for i := 0; i < n; i++ {
+ // z[i+j] += c + x[i]*d;
+ z1, z0 := Mul11(x[i], d);
+ t := c + z[i+j] + z0;
+ c, z[i+j] = t>>L, t&M;
+ c += z1;
+ }
+ z[n+j] = c;
+ }
+ }
+}
+
+
+func Mul1(z, x *[]Digit2, y Digit2) Digit2 {
+ n := len(x);
+ var c Digit;
+ f := Digit(y);
+ for i := 0; i < n; i++ {
+ t := c + Digit(x[i])*f;
+ c, z[i] = t>>L2, Digit2(t&M2);
+ }
+ return Digit2(c);
+}
+
+
+func Div1(z, x *[]Digit2, y Digit2) Digit2 {
+ n := len(x);
+ var c Digit;
+ d := Digit(y);
+ for i := n-1; i >= 0; i-- {
+ t := c*B2 + Digit(x[i]);
+ c, z[i] = t%d, Digit2(t/d);
+ }
+ return Digit2(c);
+}
+
+
+func Shl(z, x *[]Digit, s uint) Digit {
+ assert(s <= L);
+ n := len(x);
+ var c Digit;
+ for i := 0; i < n; i++ {
+ c, z[i] = x[i] >> (L-s), x[i] << s & M | c;
+ }
+ return c;
+}
+
+
+func Shr(z, x *[]Digit, s uint) Digit {
+ assert(s <= L);
+ n := len(x);
+ var c Digit;
+ for i := n - 1; i >= 0; i-- {
+ c, z[i] = x[i] << (L-s) & M, x[i] >> s | c;
+ }
+ return c;
+}
+
+
+// ----------------------------------------------------------------------------
+// Support
+
+func IsSmall(x Digit) bool {
+ return x < 1<<LogH;
+}
+
+
+func Split(x Digit) (Digit, Digit) {
+ return x>>L, x&M;
+}
+
+
+export func Dump(x *[]Digit) {
print("[", len(x), "]");
for i := len(x) - 1; i >= 0; i-- {
print(" ", x[i]);
}
-func Normalize3(x *[]Digit3) *[]Digit3 {
+func Normalize2(x *[]Digit2) *[]Digit2 {
n := len(x);
for n > 0 && x[n - 1] == 0 { n-- }
if n < len(x) {
}
-func (x *Natural) IsZero() bool {
- return len(x) == 0;
-}
+// Predicates
+
+func (x *Natural) IsZero() bool { return len(x) == 0; }
+func (x *Natural) IsOdd() bool { return len(x) > 0 && x[0]&1 != 0; }
func (x *Natural) Add(y *Natural) *Natural {
if n < m {
return y.Add(x);
}
- assert(n >= m);
- z := new(Natural, n + 1);
- c := Digit(0);
- for i := 0; i < m; i++ { c, z[i] = Split(c + x[i] + y[i]); }
- for i := m; i < n; i++ { c, z[i] = Split(c + x[i]); }
- z[n] = c;
+ z := new(Natural, n + 1);
+ c := Add(z[0 : m], x[0 : m], y);
+ z[n] = Add1(z[m : n], x[m : n], c);
return Normalize(z);
}
func (x *Natural) Sub(y *Natural) *Natural {
n := len(x);
m := len(y);
- assert(n >= m);
- z := new(Natural, n);
-
- c := Digit(0);
- for i := 0; i < m; i++ {
- t := c + x[i] - y[i];
- c, z[i] = Digit(int64(t)>>L), t&M; // arithmetic shift!
+ if n < m {
+ panic("underflow")
}
- for i := m; i < n; i++ {
- t := c + x[i];
- c, z[i] = Digit(int64(t)>>L), t&M; // arithmetic shift!
+
+ z := new(Natural, n);
+ c := Sub(z[0 : m], x[0 : m], y);
+ if Sub1(z[m : n], x[m : n], c) != 0 {
+ panic("underflow");
}
- assert(c == 0); // x.Sub(y) must be called with x >= y
return Normalize(z);
}
}
-// Returns c = x*y div B, z = x*y mod B.
-func Mul1(x, y Digit) (Digit, Digit) {
- // Split x and y into 2 sub-digits each (in base sqrt(B)),
- // multiply the digits separately while avoiding overflow,
- // and return the product as two separate digits.
-
- const L0 = (L + 1)/2;
- const L1 = L - L0;
- const DL = L0 - L1; // 0 or 1
- const b = 1<<L0;
- const m = b - 1;
-
- // split x and y into sub-digits
- // x = (x1*b + x0)
- // y = (y1*b + y0)
- x1, x0 := x>>L0, x&m;
- y1, y0 := y>>L0, y&m;
-
- // x*y = t2*b^2 + t1*b + t0
- t0 := x0*y0;
- t1 := x1*y0 + x0*y1;
- t2 := x1*y1;
-
- // compute the result digits but avoid overflow
- // z = z1*B + z0 = x*y
- z0 := (t1<<L0 + t0)&M;
- z1 := t2<<DL + (t1 + t0>>L0)>>L1;
-
- return z1, z0;
-}
-
-
func (x *Natural) Mul(y *Natural) *Natural {
n := len(x);
m := len(y);
- z := new(Natural, n + m);
- for j := 0; j < m; j++ {
- d := y[j];
- if d != 0 {
- c := Digit(0);
- for i := 0; i < n; i++ {
- // z[i+j] += c + x[i]*d;
- z1, z0 := Mul1(x[i], d);
- c, z[i+j] = Split(c + z[i+j] + z0);
- c += z1;
- }
- z[n+j] = c;
- }
- }
+ z := new(Natural, n + m);
+ Mul(z, x, y);
return Normalize(z);
}
}
-func Shl1(x, c Digit, s uint) (Digit, Digit) {
- assert(s <= LogB);
- return x >> (LogB - s), x << s & M | c
-}
-
-
-func Shr1(x, c Digit, s uint) (Digit, Digit) {
- assert(s <= LogB);
- return x << (LogB - s) & M, x >> s | c
-}
-
-
func (x *Natural) Shl(s uint) *Natural {
- n := len(x);
- si := int(s / LogB);
- s = s % LogB;
- z := new(Natural, n + si + 1);
+ n := uint(len(x));
+ m := n + s/L;
+ z := new(Natural, m+1);
- c := Digit(0);
- for i := 0; i < n; i++ { c, z[i+si] = Shl1(x[i], c, s); }
- z[n+si] = c;
+ z[m] = Shl(z[m-n : m], x, s%L);
return Normalize(z);
}
func (x *Natural) Shr(s uint) *Natural {
- n := len(x);
- si := int(s / LogB);
- if si >= n { si = n; }
- s = s % LogB;
- assert(si <= n);
- z := new(Natural, n - si);
+ n := uint(len(x));
+ m := n - s/L;
+ if m > n { // check for underflow
+ m = 0;
+ }
+ z := new(Natural, m);
- c := Digit(0);
- for i := n - 1; i >= si; i-- { c, z[i-si] = Shr1(x[i], c, s); }
+ Shr(z, x[n-m : n], s%L);
return Normalize(z);
}
// DivMod needs multi-precision division which is not available if Digit
// is already using the largest uint size. Split base before division,
-// and merge again after. Each Digit is split into 3 Digit3's.
+// and merge again after. Each Digit is split into 2 Digit2's.
-func SplitBase(x *Natural) *[]Digit3 {
- // TODO Use Log() for better result - don't need Normalize3 at the end!
+func Unpack(x *Natural) *[]Digit2 {
+ // TODO Use Log() for better result - don't need Normalize2 at the end!
n := len(x);
- z := new([]Digit3, n*3 + 1); // add space for extra digit (used by DivMod)
- for i, j := 0, 0; i < n; i, j = i+1, j+3 {
+ z := new([]Digit2, n*2 + 1); // add space for extra digit (used by DivMod)
+ for i := 0; i < n; i++ {
t := x[i];
- z[j+0] = Digit3(t >> (L3*0) & M3);
- z[j+1] = Digit3(t >> (L3*1) & M3);
- z[j+2] = Digit3(t >> (L3*2) & M3);
+ z[i*2] = Digit2(t & M2);
+ z[i*2 + 1] = Digit2(t >> L2 & M2);
}
- return Normalize3(z);
+ return Normalize2(z);
}
-func MergeBase(x *[]Digit3) *Natural {
- i := len(x);
- j := (i+2)/3;
- z := new(Natural, j);
-
- switch i%3 {
- case 1: z[j-1] = Digit(x[i-1]); i--; j--;
- case 2: z[j-1] = Digit(x[i-1])<<L3 | Digit(x[i-2]); i -= 2; j--;
- case 0:
+func Pack(x *[]Digit2) *Natural {
+ n := (len(x) + 1) / 2;
+ z := new(Natural, n);
+ if len(x) & 1 == 1 {
+ // handle odd len(x)
+ n--;
+ z[n] = Digit(x[n*2]);
}
-
- for i >= 3 {
- z[j-1] = ((Digit(x[i-1])<<L3) | Digit(x[i-2]))<<L3 | Digit(x[i-3]);
- i -= 3;
- j--;
+ for i := 0; i < n; i++ {
+ z[i] = Digit(x[i*2 + 1]) << L2 | Digit(x[i*2]);
}
- assert(j == 0);
-
return Normalize(z);
}
-func Split3(x Digit) (Digit, Digit3) {
- return uint64(int64(x)>>L3), Digit3(x&M3)
-}
-
-
-func Product(x *[]Digit3, y Digit) {
- n := len(x);
- c := Digit(0);
- for i := 0; i < n; i++ { c, x[i] = Split3(c + Digit(x[i])*y) }
- assert(c == 0);
-}
-
-
-func Quotient(x *[]Digit3, y Digit) {
- n := len(x);
- c := Digit(0);
- for i := n-1; i >= 0; i-- {
- t := c*B3 + Digit(x[i]);
- c, x[i] = t%y, Digit3(t/y);
- }
- assert(c == 0);
-}
-
-
// Division and modulo computation - destroys x and y. Based on the
// algorithms described in:
//
// is described in 1), while 2) provides the background for a more
// accurate initial guess of the trial digit.
-func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) {
- const b = B3;
+func DivMod2(x, y *[]Digit2) (*[]Digit2, *[]Digit2) {
+ const b = B2;
n := len(x);
m := len(y);
if m == 1 {
// division by single digit
- d := Digit(y[0]);
- c := Digit(0);
- for i := n; i > 0; i-- {
- t := c*b + Digit(x[i-1]);
- c, x[i] = t%d, Digit3(t/d);
- }
- x[0] = Digit3(c);
-
+ // result is shifted left by 1 in place!
+ x[0] = Div1(x[1 : n+1], x[0 : n], y[0]);
+
} else if m > n {
// quotient = 0, remainder = x
// TODO in this case we shouldn't even split base - FIX THIS
// normalize x and y
f := b/(Digit(y[m-1]) + 1);
- Product(x, f);
- Product(y, f);
+ Mul1(x, x, Digit2(f));
+ Mul1(y, y, Digit2(f));
assert(b/2 <= y[m-1] && y[m-1] < b); // incorrect scaling
- d2 := Digit(y[m-1])*b + Digit(y[m-2]);
+ y1, y2 := Digit(y[m-1]), Digit(y[m-2]);
+ d2 := Digit(y1)*b + Digit(y2);
for i := n-m; i >= 0; i-- {
k := i+m;
// compute trial digit
- r3 := (Digit(x[k])*b + Digit(x[k-1]))*b + Digit(x[k-2]);
- q := r3/d2;
- if q >= b { q = b-1 }
+ var q Digit;
+ { // Knuth
+ x0, x1, x2 := Digit(x[k]), Digit(x[k-1]), Digit(x[k-2]);
+ if x0 != y1 {
+ q = (x0*b + x1)/y1;
+ } else {
+ q = b-1;
+ }
+ for y2 * q > (x0*b + x1 - y1*q)*b + x2 {
+ q--
+ }
+ }
// subtract y*q
c := Digit(0);
for j := 0; j < m; j++ {
t := c + Digit(x[i+j]) - Digit(y[j])*q; // arithmetic shift!
- c, x[i+j] = Digit(int64(t)>>L3), Digit3(t&M3);
+ c, x[i+j] = Digit(int64(t)>>L2), Digit2(t&M2);
}
// correct if trial digit was too large
// add y
c := Digit(0);
for j := 0; j < m; j++ {
- c, x[i+j] = Split3(c + Digit(x[i+j]) + Digit(y[j]));
+ t := c + Digit(x[i+j]) + Digit(y[j]);
+ c, x[i+j] = uint64(int64(t) >> L2), Digit2(t & M2)
}
assert(c + Digit(x[k]) == 0);
// correct trial digit
q--;
}
- x[k] = Digit3(q);
+ x[k] = Digit2(q);
}
// undo normalization for remainder
- Quotient(x[0 : m], f);
+ c := Div1(x[0 : m], x[0 : m], Digit2(f));
+ assert(c == 0);
}
return x[m : n+1], x[0 : m];
func (x *Natural) Div(y *Natural) *Natural {
- q, r := DivMod(SplitBase(x), SplitBase(y));
- return MergeBase(q);
+ q, r := DivMod2(Unpack(x), Unpack(y));
+ return Pack(q);
}
func (x *Natural) Mod(y *Natural) *Natural {
- q, r := DivMod(SplitBase(x), SplitBase(y));
- return MergeBase(r);
+ q, r := DivMod2(Unpack(x), Unpack(y));
+ return Pack(r);
+}
+
+
+func (x *Natural) DivMod(y *Natural) (*Natural, *Natural) {
+ q, r := DivMod2(Unpack(x), Unpack(y));
+ return Pack(q), Pack(r);
}
}
-func Log1(x Digit) int {
+func Log2(x Digit) int {
n := -1;
for x != 0 { x = x >> 1; n++; } // BUG >>= broken for uint64
return n;
}
-func (x *Natural) Log() int {
+func (x *Natural) Log2() int {
n := len(x);
if n > 0 {
- n = (n - 1)*L + Log1(x[n - 1]);
+ n = (n - 1)*L + Log2(x[n - 1]);
} else {
n = -1;
}
if n < m {
return y.And(x);
}
- assert(n >= m);
- z := new(Natural, n);
- for i := 0; i < m; i++ { z[i] = x[i] & y[i]; }
- for i := m; i < n; i++ { z[i] = x[i]; }
+ z := new(Natural, n);
+ And(z[0 : m], x[0 : m], y);
+ Or1(z[m : n], x[m : n], 0);
return Normalize(z);
}
if n < m {
return y.Or(x);
}
- assert(n >= m);
- z := new(Natural, n);
- for i := 0; i < m; i++ { z[i] = x[i] | y[i]; }
- for i := m; i < n; i++ { z[i] = x[i]; }
+ z := new(Natural, n);
+ Or(z[0 : m], x[0 : m], y);
+ Or1(z[m : n], x[m : n], 0);
return Normalize(z);
}
if n < m {
return y.Xor(x);
}
- assert(n >= m);
- z := new(Natural, n);
- for i := 0; i < m; i++ { z[i] = x[i] ^ y[i]; }
- for i := m; i < n; i++ { z[i] = x[i]; }
+ z := new(Natural, n);
+ Xor(z[0 : m], x[0 : m], y);
+ Or1(z[m : n], x[m : n], 0);
return Normalize(z);
}
-func Copy(x *Natural) *Natural {
- z := new(Natural, len(x));
- //*z = *x; // BUG assignment does't work yet
- for i := len(x) - 1; i >= 0; i-- { z[i] = x[i]; }
- return z;
-}
-
-
// Computes x = x div d (in place - the recv maybe modified) for "small" d's.
// Returns updated x and x mod d.
func (x *Natural) DivMod1(d Digit) (*Natural, Digit) {
c := Digit(0);
for i := len(x) - 1; i >= 0; i-- {
- c = c<<L + x[i];
- x[i] = c/d;
- c %= d;
+ t := c<<L + x[i];
+ c, x[i] = t%d, t/d;
}
return Normalize(x), c;
}
-func (x *Natural) String(base Digit) string {
+func (x *Natural) String(base uint) string {
if x.IsZero() {
return "0";
}
// allocate string
- // TODO n is too small for bases < 10!!!
- assert(base >= 10); // for now
- // approx. length: 1 char for 3 bits
- n := x.Log()/3 + 10; // +10 (round up) - what is the right number?
+ assert(2 <= base && base <= 16);
+ n := (x.Log2() + 1) / Log2(Digit(base)) + 1; // TODO why the +1?
s := new([]byte, n);
// convert
- const hex = "0123456789abcdef";
+
+ // don't destroy x, make a copy
+ t := new(Natural, len(x));
+ Or1(t, x, 0); // copy x
+
i := n;
- x = Copy(x); // don't destroy recv
- for !x.IsZero() {
+ for !t.IsZero() {
i--;
var d Digit;
- x, d = x.DivMod1(base);
- s[i] = hex[d];
+ t, d = t.DivMod1(Digit(base));
+ s[i] = "0123456789abcdef"[d];
};
return string(s[i : n]);
}
-func HexValue(ch byte) Digit {
- d := Digit(1 << LogH);
+func HexValue(ch byte) uint {
+ d := uint(1 << LogH);
switch {
- case '0' <= ch && ch <= '9': d = Digit(ch - '0');
- case 'a' <= ch && ch <= 'f': d = Digit(ch - 'a') + 10;
- case 'A' <= ch && ch <= 'F': d = Digit(ch - 'A') + 10;
+ case '0' <= ch && ch <= '9': d = uint(ch - '0');
+ case 'a' <= ch && ch <= 'f': d = uint(ch - 'a') + 10;
+ case 'A' <= ch && ch <= 'F': d = uint(ch - 'A') + 10;
}
return d;
}
// TODO auto-detect base if base argument is 0
-export func NatFromString(s string, base Digit) *Natural {
+export func NatFromString(s string, base uint) *Natural {
x := NatZero;
for i := 0; i < len(s); i++ {
d := HexValue(s[i]);
if d < base {
- x = x.MulAdd1(base, d);
+ x = x.MulAdd1(Digit(base), Digit(d));
} else {
break;
}
func (x *Integer) Quo(y *Integer) *Integer {
- panic("UNIMPLEMENTED");
- return nil;
+ // x / y == x / y
+ // x / (-y) == -(x / y)
+ // (-x) / y == -(x / y)
+ // (-x) / (-y) == x / y
+ return &Integer{x.sign != y.sign, x.mant.Div(y.mant)};
}
func (x *Integer) Rem(y *Integer) *Integer {
- panic("UNIMPLEMENTED");
- return nil;
+ // x % y == x % y
+ // x % (-y) == x % y
+ // (-x) % y == -(x % y)
+ // (-x) % (-y) == -(x % y)
+ return &Integer{y.sign, x.mant.Mod(y.mant)};
+}
+
+
+func (x *Integer) QuoRem(y *Integer) (*Integer, *Integer) {
+ q, r := x.mant.DivMod(y.mant);
+ return &Integer{x.sign != y.sign, q}, &Integer{y.sign, q};
}
func (x *Integer) Div(y *Integer) *Integer {
- panic("UNIMPLEMENTED");
+ q, r := x.mant.DivMod(y.mant);
return nil;
}
}
-func (x *Integer) String(base Digit) string {
+func (x *Integer) String(base uint) string {
if x.mant.IsZero() {
return "0";
}
}
-export func IntFromString(s string, base Digit) *Integer {
+export func IntFromString(s string, base uint) *Integer {
// get sign, if any
sign := false;
if len(s) > 0 && (s[0] == '-' || s[0] == '+') {