]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: Allow non-prime modulus for ModInverse
authorKeith Randall <khr@golang.org>
Tue, 14 Oct 2014 21:09:56 +0000 (14:09 -0700)
committerKeith Randall <khr@golang.org>
Tue, 14 Oct 2014 21:09:56 +0000 (14:09 -0700)
The inverse is defined whenever the element and the
modulus are relatively prime.  The code already handles
this situation, but the spec does not.

Test that it does indeed work.

Fixes #8875

LGTM=agl
R=agl
CC=golang-codereviews
https://golang.org/cl/155010043

src/math/big/int.go
src/math/big/int_test.go

index fc53719d710a7fe4a8632945fa6f7379474a5af0..d22e39e7c94fa3591b61236eb433e4130bdf0d86 100644 (file)
@@ -752,15 +752,16 @@ func (z *Int) Rand(rnd *rand.Rand, n *Int) *Int {
        return z
 }
 
-// ModInverse sets z to the multiplicative inverse of g in the group ℤ/pℤ (where
-// p is a prime) and returns z.
-func (z *Int) ModInverse(g, p *Int) *Int {
+// ModInverse sets z to the multiplicative inverse of g in the ring ℤ/nℤ
+// and returns z. If g and n are not relatively prime, the result is undefined.
+func (z *Int) ModInverse(g, n *Int) *Int {
        var d Int
-       d.GCD(z, nil, g, p)
-       // x and y are such that g*x + p*y = d. Since p is prime, d = 1. Taking
-       // that modulo p results in g*x = 1, therefore x is the inverse element.
+       d.GCD(z, nil, g, n)
+       // x and y are such that g*x + n*y = d. Since g and n are
+       // relatively prime, d = 1. Taking that modulo n results in
+       // g*x = 1, therefore x is the inverse element.
        if z.neg {
-               z.Add(z, p)
+               z.Add(z, n)
        }
        return z
 }
index ec05fbb1c0cebde946db3d50b9c9a2b9928003dd..6070cf325d206ed6ec508cf3805ac3dfac9c23c6 100644 (file)
@@ -1448,24 +1448,40 @@ func TestNot(t *testing.T) {
 
 var modInverseTests = []struct {
        element string
-       prime   string
+       modulus string
 }{
-       {"1", "7"},
-       {"1", "13"},
+       {"1234567", "458948883992"},
        {"239487239847", "2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919"},
 }
 
 func TestModInverse(t *testing.T) {
-       var element, prime Int
+       var element, modulus, gcd, inverse Int
        one := NewInt(1)
        for i, test := range modInverseTests {
                (&element).SetString(test.element, 10)
-               (&prime).SetString(test.prime, 10)
-               inverse := new(Int).ModInverse(&element, &prime)
-               inverse.Mul(inverse, &element)
-               inverse.Mod(inverse, &prime)
-               if inverse.Cmp(one) != 0 {
-                       t.Errorf("#%d: failed (e·e^(-1)=%s)", i, inverse)
+               (&modulus).SetString(test.modulus, 10)
+               (&inverse).ModInverse(&element, &modulus)
+               (&inverse).Mul(&inverse, &element)
+               (&inverse).Mod(&inverse, &modulus)
+               if (&inverse).Cmp(one) != 0 {
+                       t.Errorf("#%d: failed (e·e^(-1)=%s)", i, &inverse)
+               }
+       }
+       // exhaustive test for small values
+       for n := 2; n < 100; n++ {
+               (&modulus).SetInt64(int64(n))
+               for x := 1; x < n; x++ {
+                       (&element).SetInt64(int64(x))
+                       (&gcd).GCD(nil, nil, &element, &modulus)
+                       if (&gcd).Cmp(one) != 0 {
+                               continue
+                       }
+                       (&inverse).ModInverse(&element, &modulus)
+                       (&inverse).Mul(&inverse, &element)
+                       (&inverse).Mod(&inverse, &modulus)
+                       if (&inverse).Cmp(one) != 0 {
+                               t.Errorf("ModInverse(%d,%d)*%d%%%d=%d, not 1", &element, &modulus, &element, &modulus, &inverse)
+                       }
                }
        }
 }