From: Evgeniy Kulikov Date: Thu, 13 Feb 2020 12:55:07 +0000 (+0000) Subject: crypto/elliptic: implement MarshalCompressed and UnmarshalCompressed X-Git-Tag: go1.15beta1~191 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=5c13cab36b4667cc1a42667b16b8f049016586e0;p=gostls13.git crypto/elliptic: implement MarshalCompressed and UnmarshalCompressed Fixes #34105 Co-authored-by: Filippo Valsorda Change-Id: I3470343ec9ce9a0bd5593a04d3ba5816b75d3332 GitHub-Last-Rev: 9b38b0a7f8cef7f001fe9126a1cfcb4990f7b996 GitHub-Pull-Request: golang/go#35110 Reviewed-on: https://go-review.googlesource.com/c/go/+/202819 Run-TryBot: Filippo Valsorda TryBot-Result: Gobot Gobot Reviewed-by: Filippo Valsorda --- diff --git a/src/crypto/elliptic/elliptic.go b/src/crypto/elliptic/elliptic.go index bd5168c5fd..8735d3acf6 100644 --- a/src/crypto/elliptic/elliptic.go +++ b/src/crypto/elliptic/elliptic.go @@ -52,11 +52,8 @@ func (curve *CurveParams) Params() *CurveParams { return curve } -func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool { - // y² = x³ - 3x + b - y2 := new(big.Int).Mul(y, y) - y2.Mod(y2, curve.P) - +// polynomial returns x³ - 3x + b. +func (curve *CurveParams) polynomial(x *big.Int) *big.Int { x3 := new(big.Int).Mul(x, x) x3.Mul(x3, x) @@ -67,7 +64,15 @@ func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool { x3.Add(x3, curve.B) x3.Mod(x3, curve.P) - return x3.Cmp(y2) == 0 + return x3 +} + +func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool { + // y² = x³ - 3x + b + y2 := new(big.Int).Mul(y, y) + y2.Mod(y2, curve.P) + + return curve.polynomial(x).Cmp(y2) == 0 } // zForAffine returns a Jacobian Z value for the affine point (x, y). If x and @@ -315,16 +320,25 @@ func Marshal(curve Curve, x, y *big.Int) []byte { return ret } +// MarshalCompressed converts a point into the compressed form specified in section 4.3.6 of ANSI X9.62. +func MarshalCompressed(curve Curve, x, y *big.Int) []byte { + byteLen := (curve.Params().BitSize + 7) / 8 + compressed := make([]byte, 1+byteLen) + compressed[0] = byte(y.Bit(0)) | 2 + x.FillBytes(compressed[1:]) + return compressed +} + // Unmarshal converts a point, serialized by Marshal, into an x, y pair. // It is an error if the point is not in uncompressed form or is not on the curve. // On error, x = nil. func Unmarshal(curve Curve, data []byte) (x, y *big.Int) { byteLen := (curve.Params().BitSize + 7) / 8 if len(data) != 1+2*byteLen { - return + return nil, nil } if data[0] != 4 { // uncompressed form - return + return nil, nil } p := curve.Params().P x = new(big.Int).SetBytes(data[1 : 1+byteLen]) @@ -338,6 +352,37 @@ func Unmarshal(curve Curve, data []byte) (x, y *big.Int) { return } +// UnmarshalCompressed converts a point, serialized by MarshalCompressed, into an x, y pair. +// It is an error if the point is not in compressed form or is not on the curve. +// On error, x = nil. +func UnmarshalCompressed(curve Curve, data []byte) (x, y *big.Int) { + byteLen := (curve.Params().BitSize + 7) / 8 + if len(data) != 1+byteLen { + return nil, nil + } + if data[0] != 2 && data[0] != 3 { // compressed form + return nil, nil + } + p := curve.Params().P + x = new(big.Int).SetBytes(data[1:]) + if x.Cmp(p) >= 0 { + return nil, nil + } + // y² = x³ - 3x + b + y = curve.Params().polynomial(x) + y = y.ModSqrt(y, p) + if y == nil { + return nil, nil + } + if byte(y.Bit(0)) != data[0]&1 { + y.Neg(y).Mod(y, p) + } + if !curve.IsOnCurve(x, y) { + return nil, nil + } + return +} + var initonce sync.Once var p384 *CurveParams var p521 *CurveParams diff --git a/src/crypto/elliptic/elliptic_test.go b/src/crypto/elliptic/elliptic_test.go index 09c5483520..45c2fb63f5 100644 --- a/src/crypto/elliptic/elliptic_test.go +++ b/src/crypto/elliptic/elliptic_test.go @@ -5,6 +5,7 @@ package elliptic import ( + "bytes" "crypto/rand" "encoding/hex" "fmt" @@ -628,3 +629,74 @@ func TestUnmarshalToLargeCoordinates(t *testing.T) { t.Errorf("Unmarshal accepts invalid Y coordinate") } } + +func TestMarshalCompressed(t *testing.T) { + t.Run("P-256/03", func(t *testing.T) { + data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79") + x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10) + y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10) + testMarshalCompressed(t, P256(), x, y, data) + }) + t.Run("P-256/02", func(t *testing.T) { + data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79") + x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10) + y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10) + testMarshalCompressed(t, P256(), x, y, data) + }) + + t.Run("Invalid", func(t *testing.T) { + data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535") + X, Y := UnmarshalCompressed(P256(), data) + if X != nil || Y != nil { + t.Error("expected an error for invalid encoding") + } + }) + + if testing.Short() { + t.Skip("skipping other curves on short test") + } + + t.Run("P-224", func(t *testing.T) { + _, x, y, err := GenerateKey(P224(), rand.Reader) + if err != nil { + t.Fatal(err) + } + testMarshalCompressed(t, P224(), x, y, nil) + }) + t.Run("P-384", func(t *testing.T) { + _, x, y, err := GenerateKey(P384(), rand.Reader) + if err != nil { + t.Fatal(err) + } + testMarshalCompressed(t, P384(), x, y, nil) + }) + t.Run("P-521", func(t *testing.T) { + _, x, y, err := GenerateKey(P521(), rand.Reader) + if err != nil { + t.Fatal(err) + } + testMarshalCompressed(t, P521(), x, y, nil) + }) +} + +func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) { + if !curve.IsOnCurve(x, y) { + t.Fatal("invalid test point") + } + got := MarshalCompressed(curve, x, y) + if want != nil && !bytes.Equal(got, want) { + t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want) + } + + X, Y := UnmarshalCompressed(curve, got) + if X == nil || Y == nil { + t.Fatalf("UnmarshalCompressed failed unexpectedly") + } + + if !curve.IsOnCurve(X, Y) { + t.Error("UnmarshalCompressed returned a point not on the curve") + } + if X.Cmp(x) != 0 || Y.Cmp(y) != 0 { + t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y) + } +}