]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/elliptic: implement MarshalCompressed and UnmarshalCompressed
authorEvgeniy Kulikov <tuxuls@gmail.com>
Thu, 13 Feb 2020 12:55:07 +0000 (12:55 +0000)
committerFilippo Valsorda <filippo@golang.org>
Thu, 7 May 2020 23:41:27 +0000 (23:41 +0000)
Fixes #34105

Co-authored-by: Filippo Valsorda <filippo@golang.org>
Change-Id: I3470343ec9ce9a0bd5593a04d3ba5816b75d3332
GitHub-Last-Rev: 9b38b0a7f8cef7f001fe9126a1cfcb4990f7b996
GitHub-Pull-Request: golang/go#35110
Reviewed-on: https://go-review.googlesource.com/c/go/+/202819
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
src/crypto/elliptic/elliptic.go
src/crypto/elliptic/elliptic_test.go

index bd5168c5fd842cf44b5db9c5498198a9760999ad..8735d3acf62c0f00ca4ad396131c4ff528a30d4d 100644 (file)
@@ -52,11 +52,8 @@ func (curve *CurveParams) Params() *CurveParams {
        return curve
 }
 
-func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool {
-       // y² = x³ - 3x + b
-       y2 := new(big.Int).Mul(y, y)
-       y2.Mod(y2, curve.P)
-
+// polynomial returns x³ - 3x + b.
+func (curve *CurveParams) polynomial(x *big.Int) *big.Int {
        x3 := new(big.Int).Mul(x, x)
        x3.Mul(x3, x)
 
@@ -67,7 +64,15 @@ func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool {
        x3.Add(x3, curve.B)
        x3.Mod(x3, curve.P)
 
-       return x3.Cmp(y2) == 0
+       return x3
+}
+
+func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool {
+       // y² = x³ - 3x + b
+       y2 := new(big.Int).Mul(y, y)
+       y2.Mod(y2, curve.P)
+
+       return curve.polynomial(x).Cmp(y2) == 0
 }
 
 // zForAffine returns a Jacobian Z value for the affine point (x, y). If x and
@@ -315,16 +320,25 @@ func Marshal(curve Curve, x, y *big.Int) []byte {
        return ret
 }
 
+// MarshalCompressed converts a point into the compressed form specified in section 4.3.6 of ANSI X9.62.
+func MarshalCompressed(curve Curve, x, y *big.Int) []byte {
+       byteLen := (curve.Params().BitSize + 7) / 8
+       compressed := make([]byte, 1+byteLen)
+       compressed[0] = byte(y.Bit(0)) | 2
+       x.FillBytes(compressed[1:])
+       return compressed
+}
+
 // Unmarshal converts a point, serialized by Marshal, into an x, y pair.
 // It is an error if the point is not in uncompressed form or is not on the curve.
 // On error, x = nil.
 func Unmarshal(curve Curve, data []byte) (x, y *big.Int) {
        byteLen := (curve.Params().BitSize + 7) / 8
        if len(data) != 1+2*byteLen {
-               return
+               return nil, nil
        }
        if data[0] != 4 { // uncompressed form
-               return
+               return nil, nil
        }
        p := curve.Params().P
        x = new(big.Int).SetBytes(data[1 : 1+byteLen])
@@ -338,6 +352,37 @@ func Unmarshal(curve Curve, data []byte) (x, y *big.Int) {
        return
 }
 
+// UnmarshalCompressed converts a point, serialized by MarshalCompressed, into an x, y pair.
+// It is an error if the point is not in compressed form or is not on the curve.
+// On error, x = nil.
+func UnmarshalCompressed(curve Curve, data []byte) (x, y *big.Int) {
+       byteLen := (curve.Params().BitSize + 7) / 8
+       if len(data) != 1+byteLen {
+               return nil, nil
+       }
+       if data[0] != 2 && data[0] != 3 { // compressed form
+               return nil, nil
+       }
+       p := curve.Params().P
+       x = new(big.Int).SetBytes(data[1:])
+       if x.Cmp(p) >= 0 {
+               return nil, nil
+       }
+       // y² = x³ - 3x + b
+       y = curve.Params().polynomial(x)
+       y = y.ModSqrt(y, p)
+       if y == nil {
+               return nil, nil
+       }
+       if byte(y.Bit(0)) != data[0]&1 {
+               y.Neg(y).Mod(y, p)
+       }
+       if !curve.IsOnCurve(x, y) {
+               return nil, nil
+       }
+       return
+}
+
 var initonce sync.Once
 var p384 *CurveParams
 var p521 *CurveParams
index 09c5483520ee5cbe99430c453586e7f69395d8a1..45c2fb63f5bc5e4c9f5240cf6f126b167a71aee2 100644 (file)
@@ -5,6 +5,7 @@
 package elliptic
 
 import (
+       "bytes"
        "crypto/rand"
        "encoding/hex"
        "fmt"
@@ -628,3 +629,74 @@ func TestUnmarshalToLargeCoordinates(t *testing.T) {
                t.Errorf("Unmarshal accepts invalid Y coordinate")
        }
 }
+
+func TestMarshalCompressed(t *testing.T) {
+       t.Run("P-256/03", func(t *testing.T) {
+               data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
+               x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
+               y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10)
+               testMarshalCompressed(t, P256(), x, y, data)
+       })
+       t.Run("P-256/02", func(t *testing.T) {
+               data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
+               x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
+               y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10)
+               testMarshalCompressed(t, P256(), x, y, data)
+       })
+
+       t.Run("Invalid", func(t *testing.T) {
+               data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535")
+               X, Y := UnmarshalCompressed(P256(), data)
+               if X != nil || Y != nil {
+                       t.Error("expected an error for invalid encoding")
+               }
+       })
+
+       if testing.Short() {
+               t.Skip("skipping other curves on short test")
+       }
+
+       t.Run("P-224", func(t *testing.T) {
+               _, x, y, err := GenerateKey(P224(), rand.Reader)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               testMarshalCompressed(t, P224(), x, y, nil)
+       })
+       t.Run("P-384", func(t *testing.T) {
+               _, x, y, err := GenerateKey(P384(), rand.Reader)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               testMarshalCompressed(t, P384(), x, y, nil)
+       })
+       t.Run("P-521", func(t *testing.T) {
+               _, x, y, err := GenerateKey(P521(), rand.Reader)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               testMarshalCompressed(t, P521(), x, y, nil)
+       })
+}
+
+func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) {
+       if !curve.IsOnCurve(x, y) {
+               t.Fatal("invalid test point")
+       }
+       got := MarshalCompressed(curve, x, y)
+       if want != nil && !bytes.Equal(got, want) {
+               t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want)
+       }
+
+       X, Y := UnmarshalCompressed(curve, got)
+       if X == nil || Y == nil {
+               t.Fatalf("UnmarshalCompressed failed unexpectedly")
+       }
+
+       if !curve.IsOnCurve(X, Y) {
+               t.Error("UnmarshalCompressed returned a point not on the curve")
+       }
+       if X.Cmp(x) != 0 || Y.Cmp(y) != 0 {
+               t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y)
+       }
+}