]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: do not panic in Exp when y < 0 and x doesn't have an inverse
authorFilippo Valsorda <filippo@golang.org>
Thu, 4 Apr 2019 16:46:50 +0000 (12:46 -0400)
committerFilippo Valsorda <filippo@golang.org>
Thu, 4 Apr 2019 23:02:09 +0000 (23:02 +0000)
If x does not have an inverse modulo m, and a negative exponent is used,
return nil just like ModInverse does now.

Change-Id: I8fa72f7a851e8cf77c5fab529ede88408740626f
Reviewed-on: https://go-review.googlesource.com/c/go/+/170757
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
src/math/big/int.go
src/math/big/int_test.go

index afad1bc96178ca1f76de44a1c652621e7abb87b3..8e52f0ab27b0aa3ef7cb63e432ef4f46f14dcd7c 100644 (file)
@@ -463,7 +463,8 @@ func (x *Int) TrailingZeroBits() uint {
 }
 
 // Exp sets z = x**y mod |m| (i.e. the sign of m is ignored), and returns z.
-// If m == nil or m == 0, z = x**y unless y <= 0 then z = 1.
+// If m == nil or m == 0, z = x**y unless y <= 0 then z = 1. If m > 0, y < 0,
+// and x and n are not relatively prime, z is unchanged and nil is returned.
 //
 // Modular exponentation of inputs of a particular size is not a
 // cryptographically constant-time operation.
@@ -475,7 +476,11 @@ func (z *Int) Exp(x, y, m *Int) *Int {
                        return z.SetInt64(1)
                }
                // for y < 0: x**y mod m == (x**(-1))**|y| mod m
-               xWords = new(Int).ModInverse(x, m).abs
+               inverse := new(Int).ModInverse(x, m)
+               if inverse == nil {
+                       return nil
+               }
+               xWords = inverse.abs
        }
        yWords := y.abs
 
index 2435b3610c03bb9b9e1d428e2d59af9e65cc9748..ade973b20743543b0ef1a4d60db23d84ff294542 100644 (file)
@@ -533,6 +533,9 @@ var expTests = []struct {
        {"1", "0", "", "1"},
        {"-10", "0", "", "1"},
        {"1234", "-1", "", "1"},
+       {"1234", "-1", "0", "1"},
+       {"17", "-100", "1234", "865"},
+       {"2", "-100", "1234", ""},
 
        // m == 1
        {"0", "0", "1", "0"},
@@ -605,10 +608,15 @@ func TestExp(t *testing.T) {
        for i, test := range expTests {
                x, ok1 := new(Int).SetString(test.x, 0)
                y, ok2 := new(Int).SetString(test.y, 0)
-               out, ok3 := new(Int).SetString(test.out, 0)
 
-               var ok4 bool
-               var m *Int
+               var ok3, ok4 bool
+               var out, m *Int
+
+               if len(test.out) == 0 {
+                       out, ok3 = nil, true
+               } else {
+                       out, ok3 = new(Int).SetString(test.out, 0)
+               }
 
                if len(test.m) == 0 {
                        m, ok4 = nil, true
@@ -622,10 +630,10 @@ func TestExp(t *testing.T) {
                }
 
                z1 := new(Int).Exp(x, y, m)
-               if !isNormalized(z1) {
+               if z1 != nil && !isNormalized(z1) {
                        t.Errorf("#%d: %v is not normalized", i, *z1)
                }
-               if z1.Cmp(out) != 0 {
+               if !(z1 == nil && out == nil || z1.Cmp(out) == 0) {
                        t.Errorf("#%d: got %x want %x", i, z1, out)
                }