]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/elliptic: document and test that IsOnCurve(∞) == false
authorFilippo Valsorda <filippo@golang.org>
Tue, 23 Jun 2020 22:14:25 +0000 (18:14 -0400)
committerFilippo Valsorda <filippo@golang.org>
Wed, 24 Jun 2020 16:12:25 +0000 (16:12 +0000)
This also implies it can't be passed to Marshal.

Fixes #37294

Change-Id: I1e6b6abd87ff31f323486958d5cb34a5c8f76b5f
Reviewed-on: https://go-review.googlesource.com/c/go/+/239562
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Katie Hockman <katie@golang.org>
src/crypto/elliptic/elliptic.go
src/crypto/elliptic/elliptic_test.go

index 8735d3acf62c0f00ca4ad396131c4ff528a30d4d..f93dc16419f8f95fc215ab969aa323c0c02f4409 100644 (file)
@@ -20,7 +20,10 @@ import (
 )
 
 // A Curve represents a short-form Weierstrass curve with a=-3.
-// See https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html
+//
+// Note that the point at infinity (0, 0) is not considered on the curve, and
+// although it can be returned by Add, Double, ScalarMult, or ScalarBaseMult, it
+// can't be marshaled or unmarshaled, and IsOnCurve will return false for it.
 type Curve interface {
        // Params returns the parameters for the curve.
        Params() *CurveParams
@@ -307,7 +310,8 @@ func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err e
        return
 }
 
-// Marshal converts a point into the uncompressed form specified in section 4.3.6 of ANSI X9.62.
+// Marshal converts a point on the curve into the uncompressed form specified in
+// section 4.3.6 of ANSI X9.62.
 func Marshal(curve Curve, x, y *big.Int) []byte {
        byteLen := (curve.Params().BitSize + 7) / 8
 
@@ -320,7 +324,8 @@ func Marshal(curve Curve, x, y *big.Int) []byte {
        return ret
 }
 
-// MarshalCompressed converts a point into the compressed form specified in section 4.3.6 of ANSI X9.62.
+// MarshalCompressed converts a point on the curve into the compressed form
+// specified in section 4.3.6 of ANSI X9.62.
 func MarshalCompressed(curve Curve, x, y *big.Int) []byte {
        byteLen := (curve.Params().BitSize + 7) / 8
        compressed := make([]byte, 1+byteLen)
index 45c2fb63f5bc5e4c9f5240cf6f126b167a71aee2..e80e7731aaaed03b15dcb698722738a99c4cf8e2 100644 (file)
@@ -418,41 +418,62 @@ func TestP256Mult(t *testing.T) {
        }
 }
 
-func TestInfinity(t *testing.T) {
-       tests := []struct {
-               name  string
-               curve Curve
-       }{
-               {"p224", P224()},
-               {"p256", P256()},
+func testInfinity(t *testing.T, curve Curve) {
+       _, x, y, _ := GenerateKey(curve, rand.Reader)
+       x, y = curve.ScalarMult(x, y, curve.Params().N.Bytes())
+       if x.Sign() != 0 || y.Sign() != 0 {
+               t.Errorf("x^q != ∞")
        }
 
-       for _, test := range tests {
-               curve := test.curve
-               x, y := curve.ScalarBaseMult(nil)
-               if x.Sign() != 0 || y.Sign() != 0 {
-                       t.Errorf("%s: x^0 != ∞", test.name)
-               }
+       x, y = curve.ScalarBaseMult([]byte{0})
+       if x.Sign() != 0 || y.Sign() != 0 {
+               t.Errorf("b^0 != ∞")
                x.SetInt64(0)
                y.SetInt64(0)
+       }
 
-               x2, y2 := curve.Double(x, y)
-               if x2.Sign() != 0 || y2.Sign() != 0 {
-                       t.Errorf("%s: 2∞ != ∞", test.name)
-               }
+       x2, y2 := curve.Double(x, y)
+       if x2.Sign() != 0 || y2.Sign() != 0 {
+               t.Errorf("2∞ != ∞")
+       }
 
-               baseX := curve.Params().Gx
-               baseY := curve.Params().Gy
+       baseX := curve.Params().Gx
+       baseY := curve.Params().Gy
 
-               x3, y3 := curve.Add(baseX, baseY, x, y)
-               if x3.Cmp(baseX) != 0 || y3.Cmp(baseY) != 0 {
-                       t.Errorf("%s: x+∞ != x", test.name)
-               }
+       x3, y3 := curve.Add(baseX, baseY, x, y)
+       if x3.Cmp(baseX) != 0 || y3.Cmp(baseY) != 0 {
+               t.Errorf("x+∞ != x")
+       }
 
-               x4, y4 := curve.Add(x, y, baseX, baseY)
-               if x4.Cmp(baseX) != 0 || y4.Cmp(baseY) != 0 {
-                       t.Errorf("%s: ∞+x != x", test.name)
-               }
+       x4, y4 := curve.Add(x, y, baseX, baseY)
+       if x4.Cmp(baseX) != 0 || y4.Cmp(baseY) != 0 {
+               t.Errorf("∞+x != x")
+       }
+
+       if curve.IsOnCurve(x, y) {
+               t.Errorf("IsOnCurve(∞) == true")
+       }
+}
+
+func TestInfinity(t *testing.T) {
+       tests := []struct {
+               name  string
+               curve Curve
+       }{
+               {"P-224", P224()},
+               {"P-256", P256()},
+               {"P-256/Generic", P256().Params()},
+               {"P-384", P384()},
+               {"P-521", P521()},
+       }
+       if testing.Short() {
+               tests = tests[:1]
+       }
+       for _, test := range tests {
+               curve := test.curve
+               t.Run(test.name, func(t *testing.T) {
+                       testInfinity(t, curve)
+               })
        }
 }