]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: fix big.Exp and document better
authorRobert Griesemer <gri@golang.org>
Tue, 16 Oct 2012 20:46:27 +0000 (13:46 -0700)
committerRobert Griesemer <gri@golang.org>
Tue, 16 Oct 2012 20:46:27 +0000 (13:46 -0700)
- always return 1 for y <= 0
- document that the sign of m is ignored
- protect against div-0 panics by treating
  m == 0 the same way as m == nil
- added extra tests

Fixes #4239.

R=agl, remyoudompheng, agl
CC=golang-dev
https://golang.org/cl/6724046

src/pkg/math/big/int.go
src/pkg/math/big/int_test.go
src/pkg/math/big/nat.go

index 4bd7958ae5045c8cf063fcffb5b92250af6d813f..caa23ae3d2e60c16a27c0feb5aaa3cdac6882231 100644 (file)
@@ -561,19 +561,18 @@ func (x *Int) BitLen() int {
        return x.abs.bitLen()
 }
 
-// Exp sets z = x**y mod m and returns z. If m is nil, z = x**y.
+// Exp sets z = x**y mod |m| (i.e. the sign of m is ignored), and returns z.
+// If y <= 0, the result is 1; if m == nil or m == 0, z = x**y.
 // See Knuth, volume 2, section 4.6.3.
 func (z *Int) Exp(x, y, m *Int) *Int {
        if y.neg || len(y.abs) == 0 {
-               neg := x.neg
-               z.SetInt64(1)
-               z.neg = neg
-               return z
+               return z.SetInt64(1)
        }
+       // y > 0
 
        var mWords nat
        if m != nil {
-               mWords = m.abs
+               mWords = m.abs // m.abs may be nil for m == 0
        }
 
        z.abs = z.abs.expNN(x.abs, y.abs, mWords)
index 27834cec6af1c8cd6a4afe431e3174cf0d0ce55b..d3c5a0e8bfe7481a4dd444d20bad7d49b2b8f54e 100644 (file)
@@ -767,8 +767,10 @@ var expTests = []struct {
        x, y, m string
        out     string
 }{
+       {"5", "-7", "", "1"},
+       {"-5", "-7", "", "1"},
        {"5", "0", "", "1"},
-       {"-5", "0", "", "-1"},
+       {"-5", "0", "", "1"},
        {"5", "1", "", "5"},
        {"-5", "1", "", "-5"},
        {"-2", "3", "2", "0"},
@@ -779,6 +781,7 @@ var expTests = []struct {
        {"0x8000000000000000", "3", "6719", "5447"},
        {"0x8000000000000000", "1000", "6719", "1603"},
        {"0x8000000000000000", "1000000", "6719", "3199"},
+       {"0x8000000000000000", "-1000000", "6719", "1"},
        {
                "2938462938472983472983659726349017249287491026512746239764525612965293865296239471239874193284792387498274256129746192347",
                "298472983472983471903246121093472394872319615612417471234712061",
@@ -807,12 +810,22 @@ func TestExp(t *testing.T) {
                        continue
                }
 
-               z := y.Exp(x, y, m)
-               if !isNormalized(z) {
-                       t.Errorf("#%d: %v is not normalized", i, *z)
+               z1 := new(Int).Exp(x, y, m)
+               if !isNormalized(z1) {
+                       t.Errorf("#%d: %v is not normalized", i, *z1)
+               }
+               if z1.Cmp(out) != 0 {
+                       t.Errorf("#%d: got %s want %s", i, z1, out)
                }
-               if z.Cmp(out) != 0 {
-                       t.Errorf("#%d: got %s want %s", i, z, out)
+
+               if m == nil {
+                       // the result should be the same as for m == 0;
+                       // specifically, there should be no div-zero panic
+                       m = &Int{abs: nat{}} // m != nil && len(m.abs) == 0
+                       z2 := new(Int).Exp(x, y, m)
+                       if z2.Cmp(z1) != 0 {
+                               t.Errorf("#%d: got %s want %s", i, z1, z2)
+                       }
                }
        }
 }
index b2d6cd96c63f0bbd9a833f593844c9f8d36ff121..2e5c56d461b8c8a8e44df9e24fd101a6b1aab68f 100644 (file)
@@ -1227,8 +1227,8 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
        return z.norm()
 }
 
-// If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It
-// reuses the storage of z if possible.
+// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
+// otherwise it sets z to x**y. The result is the value of z.
 func (z nat) expNN(x, y, m nat) nat {
        if alias(z, x) || alias(z, y) {
                // We cannot allow in-place modification of x or y.
@@ -1240,15 +1240,15 @@ func (z nat) expNN(x, y, m nat) nat {
                z[0] = 1
                return z
        }
+       // y > 0
 
-       if m != nil {
+       if len(m) != 0 {
                // We likely end up being as long as the modulus.
                z = z.make(len(m))
        }
        z = z.set(x)
-       v := y[len(y)-1]
-       // It's invalid for the most significant word to be zero, therefore we
-       // will find a one bit.
+
+       v := y[len(y)-1] // v > 0 because y is normalized and y > 0
        shift := leadingZeros(v) + 1
        v <<= shift
        var q nat
@@ -1272,7 +1272,7 @@ func (z nat) expNN(x, y, m nat) nat {
                        zz, z = z, zz
                }
 
-               if m != nil {
+               if len(m) != 0 {
                        zz, r = zz.div(r, z, m)
                        zz, r, q, z = q, z, zz, r
                }
@@ -1292,7 +1292,7 @@ func (z nat) expNN(x, y, m nat) nat {
                                zz, z = z, zz
                        }
 
-                       if m != nil {
+                       if len(m) != 0 {
                                zz, r = zz.div(r, z, m)
                                zz, r, q, z = q, z, zz, r
                        }