From: Filippo Valsorda Date: Wed, 30 Mar 2022 19:58:12 +0000 (+0200) Subject: crypto/elliptic: implement UnmarshalCompressed in nistec X-Git-Tag: go1.19beta1~393 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=52e24b492dc9295f2833f32e4f3601e716e5c1ed;p=gostls13.git crypto/elliptic: implement UnmarshalCompressed in nistec For #52182 Change-Id: If9eace36b757ada6cb5123cc60f1e10d4e8280c5 Reviewed-on: https://go-review.googlesource.com/c/go/+/396935 Reviewed-by: Roland Shoemaker Reviewed-by: Fernando Lobato Meeser Run-TryBot: Filippo Valsorda TryBot-Result: Gopher Robot --- diff --git a/src/crypto/elliptic/elliptic.go b/src/crypto/elliptic/elliptic.go index 87947ffe8f..01838dd868 100644 --- a/src/crypto/elliptic/elliptic.go +++ b/src/crypto/elliptic/elliptic.go @@ -94,10 +94,26 @@ func MarshalCompressed(curve Curve, x, y *big.Int) []byte { return compressed } +// unmarshaler is implemented by curves with their own constant-time Unmarshal. +// +// There isn't an equivalent interface for Marshal/MarshalCompressed because +// that doesn't involve any mathematical operations, only FillBytes and Bit. +type unmarshaler interface { + Unmarshal([]byte) (x, y *big.Int) + UnmarshalCompressed([]byte) (x, y *big.Int) +} + +// Assert that the known curves implement unmarshaler. +var _ = []unmarshaler{p224, p256, p384, p521} + // Unmarshal converts a point, serialized by Marshal, into an x, y pair. It is // an error if the point is not in uncompressed form, is not on the curve, or is // the point at infinity. On error, x = nil. func Unmarshal(curve Curve, data []byte) (x, y *big.Int) { + if c, ok := curve.(unmarshaler); ok { + return c.Unmarshal(data) + } + byteLen := (curve.Params().BitSize + 7) / 8 if len(data) != 1+2*byteLen { return nil, nil @@ -121,6 +137,10 @@ func Unmarshal(curve Curve, data []byte) (x, y *big.Int) { // an x, y pair. It is an error if the point is not in compressed form, is not // on the curve, or is the point at infinity. On error, x = nil. func UnmarshalCompressed(curve Curve, data []byte) (x, y *big.Int) { + if c, ok := curve.(unmarshaler); ok { + return c.UnmarshalCompressed(data) + } + byteLen := (curve.Params().BitSize + 7) / 8 if len(data) != 1+byteLen { return nil, nil diff --git a/src/crypto/elliptic/internal/nistec/generate.go b/src/crypto/elliptic/internal/nistec/generate.go index fbca6c3741..30176cb804 100644 --- a/src/crypto/elliptic/internal/nistec/generate.go +++ b/src/crypto/elliptic/internal/nistec/generate.go @@ -6,13 +6,21 @@ package main +// Running this generator requires addchain v0.4.0, which can be installed with +// +// go install github.com/mmcloughlin/addchain/cmd/addchain@v0.4.0 +// + import ( "bytes" "crypto/elliptic" "fmt" "go/format" + "io" "log" + "math/big" "os" + "os/exec" "strings" "text/template" ) @@ -49,6 +57,18 @@ var curves = []struct { func main() { t := template.Must(template.New("tmplNISTEC").Parse(tmplNISTEC)) + tmplAddchainFile, err := os.CreateTemp("", "addchain-template") + if err != nil { + log.Fatal(err) + } + defer os.Remove(tmplAddchainFile.Name()) + if _, err := io.WriteString(tmplAddchainFile, tmplAddchain); err != nil { + log.Fatal(err) + } + if err := tmplAddchainFile.Close(); err != nil { + log.Fatal(err) + } + for _, c := range curves { p := strings.ToLower(c.P) elementLen := (c.Params.BitSize + 7) / 8 @@ -60,6 +80,7 @@ func main() { if err != nil { log.Fatal(err) } + defer f.Close() buf := &bytes.Buffer{} if err := t.Execute(buf, map[string]interface{}{ "P": c.P, "p": p, "B": B, "G": G, @@ -75,7 +96,43 @@ func main() { if _, err := f.Write(out); err != nil { log.Fatal(err) } - if err := f.Close(); err != nil { + + // If p = 3 mod 4, implement modular square root by exponentiation. + mod4 := new(big.Int).Mod(c.Params.P, big.NewInt(4)) + if mod4.Cmp(big.NewInt(3)) != 0 { + continue + } + + exp := new(big.Int).Add(c.Params.P, big.NewInt(1)) + exp.Div(exp, big.NewInt(4)) + + tmp, err := os.CreateTemp("", "addchain-"+p) + if err != nil { + log.Fatal(err) + } + defer os.Remove(tmp.Name()) + cmd := exec.Command("addchain", "search", fmt.Sprintf("%d", exp)) + cmd.Stderr = os.Stderr + cmd.Stdout = tmp + if err := cmd.Run(); err != nil { + log.Fatal(err) + } + if err := tmp.Close(); err != nil { + log.Fatal(err) + } + cmd = exec.Command("addchain", "gen", "-tmpl", tmplAddchainFile.Name(), tmp.Name()) + cmd.Stderr = os.Stderr + out, err = cmd.Output() + if err != nil { + log.Fatal(err) + } + out = bytes.Replace(out, []byte("Element"), []byte(c.Element), -1) + out = bytes.Replace(out, []byte("sqrtCandidate"), []byte(p+"SqrtCandidate"), -1) + out, err = format.Source(out) + if err != nil { + log.Fatal(err) + } + if _, err := f.Write(out); err != nil { log.Fatal(err) } } @@ -169,30 +226,53 @@ func (p *{{.P}}Point) SetBytes(b []byte) (*{{.P}}Point, error) { p.z.One() return p, nil - // Compressed form - case len(b) == 1+{{.p}}ElementLength && b[0] == 0: - return nil, errors.New("unimplemented") // TODO(filippo) + // Compressed form. + case len(b) == 1+{{.p}}ElementLength && (b[0] == 2 || b[0] == 3): + x, err := new({{.Element}}).SetBytes(b[1:]) + if err != nil { + return nil, err + } + + // y² = x³ - 3x + b + y := {{.p}}Polynomial(new({{.Element}}), x) + if !{{.p}}Sqrt(y, y) { + return nil, errors.New("invalid {{.P}} compressed point encoding") + } + + // Select the positive or negative root, as indicated by the least + // significant bit, based on the encoding type byte. + otherRoot := new({{.Element}}) + otherRoot.Sub(otherRoot, y) + cond := y.Bytes()[{{.p}}ElementLength-1]&1 ^ b[0]&1 + y.Select(otherRoot, y, int(cond)) + + p.x.Set(x) + p.y.Set(y) + p.z.One() + return p, nil default: return nil, errors.New("invalid {{.P}} point encoding") } } -func {{.p}}CheckOnCurve(x, y *{{.Element}}) error { - // x³ - 3x + b. - x3 := new({{.Element}}).Square(x) - x3.Mul(x3, x) +// {{.p}}Polynomial sets y2 to x³ - 3x + b, and returns y2. +func {{.p}}Polynomial(y2, x *{{.Element}}) *{{.Element}} { + y2.Square(x) + y2.Mul(y2, x) threeX := new({{.Element}}).Add(x, x) threeX.Add(threeX, x) - x3.Sub(x3, threeX) - x3.Add(x3, {{.p}}B) + y2.Sub(y2, threeX) + return y2.Add(y2, {{.p}}B) +} +func {{.p}}CheckOnCurve(x, y *{{.Element}}) error { // y² = x³ - 3x + b - y2 := new({{.Element}}).Square(y) - - if x3.Equal(y2) != 1 { + rhs := {{.p}}Polynomial(new({{.Element}}), x) + lhs := new({{.Element}}).Square(y) + if rhs.Equal(lhs) != 1 { return errors.New("{{.P}} point not on curve") } return nil @@ -204,22 +284,49 @@ func {{.p}}CheckOnCurve(x, y *{{.Element}}) error { func (p *{{.P}}Point) Bytes() []byte { // This function is outlined to make the allocations inline in the caller // rather than happen on the heap. - var out [133]byte + var out [1+2*{{.p}}ElementLength]byte return p.bytes(&out) } -func (p *{{.P}}Point) bytes(out *[133]byte) []byte { +func (p *{{.P}}Point) bytes(out *[1+2*{{.p}}ElementLength]byte) []byte { if p.z.IsZero() == 1 { return append(out[:0], 0) } zinv := new({{.Element}}).Invert(p.z) - xx := new({{.Element}}).Mul(p.x, zinv) - yy := new({{.Element}}).Mul(p.y, zinv) + x := new({{.Element}}).Mul(p.x, zinv) + y := new({{.Element}}).Mul(p.y, zinv) buf := append(out[:0], 4) - buf = append(buf, xx.Bytes()...) - buf = append(buf, yy.Bytes()...) + buf = append(buf, x.Bytes()...) + buf = append(buf, y.Bytes()...) + return buf +} + +// BytesCompressed returns the compressed or infinity encoding of p, as +// specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the +// point at infinity is shorter than all other encodings. +func (p *{{.P}}Point) BytesCompressed() []byte { + // This function is outlined to make the allocations inline in the caller + // rather than happen on the heap. + var out [1 + {{.p}}ElementLength]byte + return p.bytesCompressed(&out) +} + +func (p *{{.P}}Point) bytesCompressed(out *[1 + {{.p}}ElementLength]byte) []byte { + if p.z.IsZero() == 1 { + return append(out[:0], 0) + } + + zinv := new({{.Element}}).Invert(p.z) + x := new({{.Element}}).Mul(p.x, zinv) + y := new({{.Element}}).Mul(p.y, zinv) + + // Encode the sign of the y coordinate (indicated by the least significant + // bit) as the encoding type (2 or 3). + buf := append(out[:0], 2) + buf[0] |= y.Bytes()[{{.p}}ElementLength-1] & 1 + buf = append(buf, x.Bytes()...) return buf } @@ -450,4 +557,56 @@ func (p *{{.P}}Point) ScalarBaseMult(scalar []byte) (*{{.P}}Point, error) { return p, nil } + +// {{.p}}Sqrt sets e to a square root of x. If x is not a square, {{.p}}Sqrt returns +// false and e is unchanged. e and x can overlap. +func {{.p}}Sqrt(e, x *{{ .Element }}) (isSquare bool) { + candidate := new({{ .Element }}) + {{.p}}SqrtCandidate(candidate, x) + square := new({{ .Element }}).Square(candidate) + if square.Equal(x) != 1 { + return false + } + e.Set(candidate) + return true +} +` + +const tmplAddchain = ` +// sqrtCandidate sets z to a square root candidate for x. z and x must not overlap. +func sqrtCandidate(z, x *Element) { + // Since p = 3 mod 4, exponentiation by (p + 1) / 4 yields a square root candidate. + // + // The sequence of {{ .Ops.Adds }} multiplications and {{ .Ops.Doubles }} squarings is derived from the + // following addition chain generated with {{ .Meta.Module }} {{ .Meta.ReleaseTag }}. + // + {{- range lines (format .Script) }} + // {{ . }} + {{- end }} + // + + {{- range .Program.Temporaries }} + var {{ . }} = new(Element) + {{- end }} + {{ range $i := .Program.Instructions -}} + {{- with add $i.Op }} + {{ $i.Output }}.Mul({{ .X }}, {{ .Y }}) + {{- end -}} + + {{- with double $i.Op }} + {{ $i.Output }}.Square({{ .X }}) + {{- end -}} + + {{- with shift $i.Op -}} + {{- $first := 0 -}} + {{- if ne $i.Output.Identifier .X.Identifier }} + {{ $i.Output }}.Square({{ .X }}) + {{- $first = 1 -}} + {{- end }} + for s := {{ $first }}; s < {{ .S }}; s++ { + {{ $i.Output }}.Square({{ $i.Output }}) + } + {{- end -}} + {{- end }} +} ` diff --git a/src/crypto/elliptic/internal/nistec/nistec_test.go b/src/crypto/elliptic/internal/nistec/nistec_test.go index 68879d55d7..410e6b0b6c 100644 --- a/src/crypto/elliptic/internal/nistec/nistec_test.go +++ b/src/crypto/elliptic/internal/nistec/nistec_test.go @@ -30,6 +30,10 @@ func TestAllocations(t *testing.T) { if _, err := nistec.NewP224Point().SetBytes(out); err != nil { t.Fatal(err) } + out = p.BytesCompressed() + if _, err := p.SetBytes(out); err != nil { + t.Fatal(err) + } }); allocs > 0 { t.Errorf("expected zero allocations, got %0.1f", allocs) } @@ -45,6 +49,10 @@ func TestAllocations(t *testing.T) { if _, err := nistec.NewP256Point().SetBytes(out); err != nil { t.Fatal(err) } + out = p.BytesCompressed() + if _, err := p.SetBytes(out); err != nil { + t.Fatal(err) + } }); allocs > 0 { t.Errorf("expected zero allocations, got %0.1f", allocs) } @@ -60,6 +68,10 @@ func TestAllocations(t *testing.T) { if _, err := nistec.NewP384Point().SetBytes(out); err != nil { t.Fatal(err) } + out = p.BytesCompressed() + if _, err := p.SetBytes(out); err != nil { + t.Fatal(err) + } }); allocs > 0 { t.Errorf("expected zero allocations, got %0.1f", allocs) } @@ -75,6 +87,10 @@ func TestAllocations(t *testing.T) { if _, err := nistec.NewP521Point().SetBytes(out); err != nil { t.Fatal(err) } + out = p.BytesCompressed() + if _, err := p.SetBytes(out); err != nil { + t.Fatal(err) + } }); allocs > 0 { t.Errorf("expected zero allocations, got %0.1f", allocs) } diff --git a/src/crypto/elliptic/internal/nistec/p224.go b/src/crypto/elliptic/internal/nistec/p224.go index 0db4ba1316..83963a4a69 100644 --- a/src/crypto/elliptic/internal/nistec/p224.go +++ b/src/crypto/elliptic/internal/nistec/p224.go @@ -82,30 +82,53 @@ func (p *P224Point) SetBytes(b []byte) (*P224Point, error) { p.z.One() return p, nil - // Compressed form - case len(b) == 1+p224ElementLength && b[0] == 0: - return nil, errors.New("unimplemented") // TODO(filippo) + // Compressed form. + case len(b) == 1+p224ElementLength && (b[0] == 2 || b[0] == 3): + x, err := new(fiat.P224Element).SetBytes(b[1:]) + if err != nil { + return nil, err + } + + // y² = x³ - 3x + b + y := p224Polynomial(new(fiat.P224Element), x) + if !p224Sqrt(y, y) { + return nil, errors.New("invalid P224 compressed point encoding") + } + + // Select the positive or negative root, as indicated by the least + // significant bit, based on the encoding type byte. + otherRoot := new(fiat.P224Element) + otherRoot.Sub(otherRoot, y) + cond := y.Bytes()[p224ElementLength-1]&1 ^ b[0]&1 + y.Select(otherRoot, y, int(cond)) + + p.x.Set(x) + p.y.Set(y) + p.z.One() + return p, nil default: return nil, errors.New("invalid P224 point encoding") } } -func p224CheckOnCurve(x, y *fiat.P224Element) error { - // x³ - 3x + b. - x3 := new(fiat.P224Element).Square(x) - x3.Mul(x3, x) +// p224Polynomial sets y2 to x³ - 3x + b, and returns y2. +func p224Polynomial(y2, x *fiat.P224Element) *fiat.P224Element { + y2.Square(x) + y2.Mul(y2, x) threeX := new(fiat.P224Element).Add(x, x) threeX.Add(threeX, x) - x3.Sub(x3, threeX) - x3.Add(x3, p224B) + y2.Sub(y2, threeX) + return y2.Add(y2, p224B) +} +func p224CheckOnCurve(x, y *fiat.P224Element) error { // y² = x³ - 3x + b - y2 := new(fiat.P224Element).Square(y) - - if x3.Equal(y2) != 1 { + rhs := p224Polynomial(new(fiat.P224Element), x) + lhs := new(fiat.P224Element).Square(y) + if rhs.Equal(lhs) != 1 { return errors.New("P224 point not on curve") } return nil @@ -117,22 +140,49 @@ func p224CheckOnCurve(x, y *fiat.P224Element) error { func (p *P224Point) Bytes() []byte { // This function is outlined to make the allocations inline in the caller // rather than happen on the heap. - var out [133]byte + var out [1 + 2*p224ElementLength]byte return p.bytes(&out) } -func (p *P224Point) bytes(out *[133]byte) []byte { +func (p *P224Point) bytes(out *[1 + 2*p224ElementLength]byte) []byte { if p.z.IsZero() == 1 { return append(out[:0], 0) } zinv := new(fiat.P224Element).Invert(p.z) - xx := new(fiat.P224Element).Mul(p.x, zinv) - yy := new(fiat.P224Element).Mul(p.y, zinv) + x := new(fiat.P224Element).Mul(p.x, zinv) + y := new(fiat.P224Element).Mul(p.y, zinv) buf := append(out[:0], 4) - buf = append(buf, xx.Bytes()...) - buf = append(buf, yy.Bytes()...) + buf = append(buf, x.Bytes()...) + buf = append(buf, y.Bytes()...) + return buf +} + +// BytesCompressed returns the compressed or infinity encoding of p, as +// specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the +// point at infinity is shorter than all other encodings. +func (p *P224Point) BytesCompressed() []byte { + // This function is outlined to make the allocations inline in the caller + // rather than happen on the heap. + var out [1 + p224ElementLength]byte + return p.bytesCompressed(&out) +} + +func (p *P224Point) bytesCompressed(out *[1 + p224ElementLength]byte) []byte { + if p.z.IsZero() == 1 { + return append(out[:0], 0) + } + + zinv := new(fiat.P224Element).Invert(p.z) + x := new(fiat.P224Element).Mul(p.x, zinv) + y := new(fiat.P224Element).Mul(p.y, zinv) + + // Encode the sign of the y coordinate (indicated by the least significant + // bit) as the encoding type (2 or 3). + buf := append(out[:0], 2) + buf[0] |= y.Bytes()[p224ElementLength-1] & 1 + buf = append(buf, x.Bytes()...) return buf } @@ -363,3 +413,16 @@ func (p *P224Point) ScalarBaseMult(scalar []byte) (*P224Point, error) { return p, nil } + +// p224Sqrt sets e to a square root of x. If x is not a square, p224Sqrt returns +// false and e is unchanged. e and x can overlap. +func p224Sqrt(e, x *fiat.P224Element) (isSquare bool) { + candidate := new(fiat.P224Element) + p224SqrtCandidate(candidate, x) + square := new(fiat.P224Element).Square(candidate) + if square.Equal(x) != 1 { + return false + } + e.Set(candidate) + return true +} diff --git a/src/crypto/elliptic/internal/nistec/p224_sqrt.go b/src/crypto/elliptic/internal/nistec/p224_sqrt.go new file mode 100644 index 0000000000..0c82b7b2e0 --- /dev/null +++ b/src/crypto/elliptic/internal/nistec/p224_sqrt.go @@ -0,0 +1,132 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package nistec + +import ( + "crypto/elliptic/internal/fiat" + "sync" +) + +var p224GG *[96]fiat.P224Element +var p224GGOnce sync.Once + +var p224MinusOne = new(fiat.P224Element).Sub( + new(fiat.P224Element), new(fiat.P224Element).One()) + +// p224SqrtCandidate sets r to a square root candidate for x. r and x must not overlap. +func p224SqrtCandidate(r, x *fiat.P224Element) { + // Since p = 1 mod 4, we can't use the exponentiation by (p + 1) / 4 like + // for the other primes. Instead, implement a variation of Tonelli–Shanks. + // The contant-time implementation is adapted from Thomas Pornin's ecGFp5. + // + // https://github.com/pornin/ecgfp5/blob/82325b965/rust/src/field.rs#L337-L385 + + // p = q*2^n + 1 with q odd -> q = 2^128 - 1 and n = 96 + // g^(2^n) = 1 -> g = 11 ^ q (where 11 is the smallest non-square) + // GG[j] = g^(2^j) for j = 0 to n-1 + + p224GGOnce.Do(func() { + p224GG = new([96]fiat.P224Element) + for i := range p224GG { + if i == 0 { + p224GG[i].SetBytes([]byte{0x6a, 0x0f, 0xec, 0x67, + 0x85, 0x98, 0xa7, 0x92, 0x0c, 0x55, 0xb2, 0xd4, + 0x0b, 0x2d, 0x6f, 0xfb, 0xbe, 0xa3, 0xd8, 0xce, + 0xf3, 0xfb, 0x36, 0x32, 0xdc, 0x69, 0x1b, 0x74}) + } else { + p224GG[i].Square(&p224GG[i-1]) + } + } + }) + + // r <- x^((q+1)/2) = x^(2^127) + // v <- x^q = x^(2^128-1) + + // Compute x^(2^127-1) first. + // + // The sequence of 10 multiplications and 126 squarings is derived from the + // following addition chain generated with github.com/mmcloughlin/addchain v0.4.0. + // + // _10 = 2*1 + // _11 = 1 + _10 + // _110 = 2*_11 + // _111 = 1 + _110 + // _111000 = _111 << 3 + // _111111 = _111 + _111000 + // _1111110 = 2*_111111 + // _1111111 = 1 + _1111110 + // x12 = _1111110 << 5 + _111111 + // x24 = x12 << 12 + x12 + // i36 = x24 << 7 + // x31 = _1111111 + i36 + // x48 = i36 << 17 + x24 + // x96 = x48 << 48 + x48 + // return x96 << 31 + x31 + // + var t0 = new(fiat.P224Element) + var t1 = new(fiat.P224Element) + + r.Square(x) + r.Mul(x, r) + r.Square(r) + r.Mul(x, r) + t0.Square(r) + for s := 1; s < 3; s++ { + t0.Square(t0) + } + t0.Mul(r, t0) + t1.Square(t0) + r.Mul(x, t1) + for s := 0; s < 5; s++ { + t1.Square(t1) + } + t0.Mul(t0, t1) + t1.Square(t0) + for s := 1; s < 12; s++ { + t1.Square(t1) + } + t0.Mul(t0, t1) + t1.Square(t0) + for s := 1; s < 7; s++ { + t1.Square(t1) + } + r.Mul(r, t1) + for s := 0; s < 17; s++ { + t1.Square(t1) + } + t0.Mul(t0, t1) + t1.Square(t0) + for s := 1; s < 48; s++ { + t1.Square(t1) + } + t0.Mul(t0, t1) + for s := 0; s < 31; s++ { + t0.Square(t0) + } + r.Mul(r, t0) + + // v = x^(2^127-1)^2 * x + v := new(fiat.P224Element).Square(r) + v.Mul(v, x) + + // r = x^(2^127-1) * x + r.Mul(r, x) + + // for i = n-1 down to 1: + // w = v^(2^(i-1)) + // if w == -1 then: + // v <- v*GG[n-i] + // r <- r*GG[n-i-1] + + for i := 96 - 1; i >= 1; i-- { + w := new(fiat.P224Element).Set(v) + for j := 0; j < i-1; j++ { + w.Square(w) + } + cond := w.Equal(p224MinusOne) + v.Select(t0.Mul(v, &p224GG[96-i]), v, cond) + r.Select(t0.Mul(r, &p224GG[96-i-1]), r, cond) + } +} diff --git a/src/crypto/elliptic/internal/nistec/p256.go b/src/crypto/elliptic/internal/nistec/p256.go index 81812df159..1b9305d044 100644 --- a/src/crypto/elliptic/internal/nistec/p256.go +++ b/src/crypto/elliptic/internal/nistec/p256.go @@ -84,30 +84,53 @@ func (p *P256Point) SetBytes(b []byte) (*P256Point, error) { p.z.One() return p, nil - // Compressed form - case len(b) == 1+p256ElementLength && b[0] == 0: - return nil, errors.New("unimplemented") // TODO(filippo) + // Compressed form. + case len(b) == 1+p256ElementLength && (b[0] == 2 || b[0] == 3): + x, err := new(fiat.P256Element).SetBytes(b[1:]) + if err != nil { + return nil, err + } + + // y² = x³ - 3x + b + y := p256Polynomial(new(fiat.P256Element), x) + if !p256Sqrt(y, y) { + return nil, errors.New("invalid P256 compressed point encoding") + } + + // Select the positive or negative root, as indicated by the least + // significant bit, based on the encoding type byte. + otherRoot := new(fiat.P256Element) + otherRoot.Sub(otherRoot, y) + cond := y.Bytes()[p256ElementLength-1]&1 ^ b[0]&1 + y.Select(otherRoot, y, int(cond)) + + p.x.Set(x) + p.y.Set(y) + p.z.One() + return p, nil default: return nil, errors.New("invalid P256 point encoding") } } -func p256CheckOnCurve(x, y *fiat.P256Element) error { - // x³ - 3x + b. - x3 := new(fiat.P256Element).Square(x) - x3.Mul(x3, x) +// p256Polynomial sets y2 to x³ - 3x + b, and returns y2. +func p256Polynomial(y2, x *fiat.P256Element) *fiat.P256Element { + y2.Square(x) + y2.Mul(y2, x) threeX := new(fiat.P256Element).Add(x, x) threeX.Add(threeX, x) - x3.Sub(x3, threeX) - x3.Add(x3, p256B) + y2.Sub(y2, threeX) + return y2.Add(y2, p256B) +} +func p256CheckOnCurve(x, y *fiat.P256Element) error { // y² = x³ - 3x + b - y2 := new(fiat.P256Element).Square(y) - - if x3.Equal(y2) != 1 { + rhs := p256Polynomial(new(fiat.P256Element), x) + lhs := new(fiat.P256Element).Square(y) + if rhs.Equal(lhs) != 1 { return errors.New("P256 point not on curve") } return nil @@ -119,22 +142,49 @@ func p256CheckOnCurve(x, y *fiat.P256Element) error { func (p *P256Point) Bytes() []byte { // This function is outlined to make the allocations inline in the caller // rather than happen on the heap. - var out [133]byte + var out [1 + 2*p256ElementLength]byte return p.bytes(&out) } -func (p *P256Point) bytes(out *[133]byte) []byte { +func (p *P256Point) bytes(out *[1 + 2*p256ElementLength]byte) []byte { if p.z.IsZero() == 1 { return append(out[:0], 0) } zinv := new(fiat.P256Element).Invert(p.z) - xx := new(fiat.P256Element).Mul(p.x, zinv) - yy := new(fiat.P256Element).Mul(p.y, zinv) + x := new(fiat.P256Element).Mul(p.x, zinv) + y := new(fiat.P256Element).Mul(p.y, zinv) buf := append(out[:0], 4) - buf = append(buf, xx.Bytes()...) - buf = append(buf, yy.Bytes()...) + buf = append(buf, x.Bytes()...) + buf = append(buf, y.Bytes()...) + return buf +} + +// BytesCompressed returns the compressed or infinity encoding of p, as +// specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the +// point at infinity is shorter than all other encodings. +func (p *P256Point) BytesCompressed() []byte { + // This function is outlined to make the allocations inline in the caller + // rather than happen on the heap. + var out [1 + p256ElementLength]byte + return p.bytesCompressed(&out) +} + +func (p *P256Point) bytesCompressed(out *[1 + p256ElementLength]byte) []byte { + if p.z.IsZero() == 1 { + return append(out[:0], 0) + } + + zinv := new(fiat.P256Element).Invert(p.z) + x := new(fiat.P256Element).Mul(p.x, zinv) + y := new(fiat.P256Element).Mul(p.y, zinv) + + // Encode the sign of the y coordinate (indicated by the least significant + // bit) as the encoding type (2 or 3). + buf := append(out[:0], 2) + buf[0] |= y.Bytes()[p256ElementLength-1] & 1 + buf = append(buf, x.Bytes()...) return buf } @@ -365,3 +415,70 @@ func (p *P256Point) ScalarBaseMult(scalar []byte) (*P256Point, error) { return p, nil } + +// p256Sqrt sets e to a square root of x. If x is not a square, p256Sqrt returns +// false and e is unchanged. e and x can overlap. +func p256Sqrt(e, x *fiat.P256Element) (isSquare bool) { + candidate := new(fiat.P256Element) + p256SqrtCandidate(candidate, x) + square := new(fiat.P256Element).Square(candidate) + if square.Equal(x) != 1 { + return false + } + e.Set(candidate) + return true +} + +// p256SqrtCandidate sets z to a square root candidate for x. z and x must not overlap. +func p256SqrtCandidate(z, x *fiat.P256Element) { + // Since p = 3 mod 4, exponentiation by (p + 1) / 4 yields a square root candidate. + // + // The sequence of 7 multiplications and 253 squarings is derived from the + // following addition chain generated with github.com/mmcloughlin/addchain v0.4.0. + // + // _10 = 2*1 + // _11 = 1 + _10 + // _1100 = _11 << 2 + // _1111 = _11 + _1100 + // _11110000 = _1111 << 4 + // _11111111 = _1111 + _11110000 + // x16 = _11111111 << 8 + _11111111 + // x32 = x16 << 16 + x16 + // return ((x32 << 32 + 1) << 96 + 1) << 94 + // + var t0 = new(fiat.P256Element) + + z.Square(x) + z.Mul(x, z) + t0.Square(z) + for s := 1; s < 2; s++ { + t0.Square(t0) + } + z.Mul(z, t0) + t0.Square(z) + for s := 1; s < 4; s++ { + t0.Square(t0) + } + z.Mul(z, t0) + t0.Square(z) + for s := 1; s < 8; s++ { + t0.Square(t0) + } + z.Mul(z, t0) + t0.Square(z) + for s := 1; s < 16; s++ { + t0.Square(t0) + } + z.Mul(z, t0) + for s := 0; s < 32; s++ { + z.Square(z) + } + z.Mul(x, z) + for s := 0; s < 96; s++ { + z.Square(z) + } + z.Mul(x, z) + for s := 0; s < 94; s++ { + z.Square(z) + } +} diff --git a/src/crypto/elliptic/internal/nistec/p256_asm.go b/src/crypto/elliptic/internal/nistec/p256_asm.go index bf1badd5e0..927da2d217 100644 --- a/src/crypto/elliptic/internal/nistec/p256_asm.go +++ b/src/crypto/elliptic/internal/nistec/p256_asm.go @@ -76,6 +76,12 @@ const p256CompressedLength = 1 + p256ElementLength // the curve, it returns nil and an error, and the receiver is unchanged. // Otherwise, it returns p. func (p *P256Point) SetBytes(b []byte) (*P256Point, error) { + // p256Mul operates in the Montgomery domain with R = 2²⁵⁶ mod p. Thus rr + // here is R in the Montgomery domain, or R×R mod p. See comment in + // P256OrdInverse about how this is used. + rr := p256Element{0x0000000000000003, 0xfffffffbffffffff, + 0xfffffffffffffffe, 0x00000004fffffffd} + switch { // Point at infinity. case len(b) == 1 && b[0] == 0: @@ -89,11 +95,6 @@ func (p *P256Point) SetBytes(b []byte) (*P256Point, error) { if p256LessThanP(&r.x) == 0 || p256LessThanP(&r.y) == 0 { return nil, errors.New("invalid P256 element encoding") } - // p256Mul operates in the Montgomery domain with R = 2²⁵⁶ mod p. Thus rr - // here is R in the Montgomery domain, or R×R mod p. See comment in - // P256OrdInverse about how this is used. - rr := p256Element{0x0000000000000003, 0xfffffffbffffffff, - 0xfffffffffffffffe, 0x00000004fffffffd} p256Mul(&r.x, &r.x, &rr) p256Mul(&r.y, &r.y, &rr) if err := p256CheckOnCurve(&r.x, &r.y); err != nil { @@ -104,15 +105,36 @@ func (p *P256Point) SetBytes(b []byte) (*P256Point, error) { // Compressed form. case len(b) == p256CompressedLength && (b[0] == 2 || b[0] == 3): - return nil, errors.New("unimplemented") // TODO(filippo) + var r P256Point + p256BigToLittle(&r.x, (*[32]byte)(b[1:33])) + if p256LessThanP(&r.x) == 0 { + return nil, errors.New("invalid P256 element encoding") + } + p256Mul(&r.x, &r.x, &rr) + + // y² = x³ - 3x + b + p256Polynomial(&r.y, &r.x) + if !p256Sqrt(&r.y, &r.y) { + return nil, errors.New("invalid P256 compressed point encoding") + } + + // Select the positive or negative root, as indicated by the least + // significant bit, based on the encoding type byte. + yy := new(p256Element) + p256FromMont(yy, &r.y) + cond := int(yy[0]&1) ^ int(b[0]&1) + p256NegCond(&r.y, cond) + + r.z = p256One + return p.Set(&r), nil default: return nil, errors.New("invalid P256 point encoding") } } -func p256CheckOnCurve(x, y *p256Element) error { - // x³ - 3x + b +// p256Polynomial sets y2 to x³ - 3x + b, and returns y2. +func p256Polynomial(y2, x *p256Element) *p256Element { x3 := new(p256Element) p256Sqr(x3, x, 1) p256Mul(x3, x3, x) @@ -128,11 +150,16 @@ func p256CheckOnCurve(x, y *p256Element) error { p256Add(x3, x3, threeX) p256Add(x3, x3, p256B) - // y² = x³ - 3x + b - y2 := new(p256Element) - p256Sqr(y2, y, 1) + *y2 = *x3 + return y2 +} - if p256Equal(y2, x3) != 1 { +func p256CheckOnCurve(x, y *p256Element) error { + // y² = x³ - 3x + b + rhs := p256Polynomial(new(p256Element), x) + lhs := new(p256Element) + p256Sqr(lhs, y, 1) + if p256Equal(lhs, rhs) != 1 { return errors.New("P256 point not on curve") } return nil @@ -177,6 +204,50 @@ func p256Add(res, x, y *p256Element) { res[3] = (t1[3] & ^t2Mask) | (t2[3] & t2Mask) } +// p256Sqrt sets e to a square root of x. If x is not a square, p256Sqrt returns +// false and e is unchanged. e and x can overlap. +func p256Sqrt(e, x *p256Element) (isSquare bool) { + t0, t1 := new(p256Element), new(p256Element) + + // Since p = 3 mod 4, exponentiation by (p + 1) / 4 yields a square root candidate. + // + // The sequence of 7 multiplications and 253 squarings is derived from the + // following addition chain generated with github.com/mmcloughlin/addchain v0.4.0. + // + // _10 = 2*1 + // _11 = 1 + _10 + // _1100 = _11 << 2 + // _1111 = _11 + _1100 + // _11110000 = _1111 << 4 + // _11111111 = _1111 + _11110000 + // x16 = _11111111 << 8 + _11111111 + // x32 = x16 << 16 + x16 + // return ((x32 << 32 + 1) << 96 + 1) << 94 + // + p256Sqr(t0, x, 1) + p256Mul(t0, x, t0) + p256Sqr(t1, t0, 2) + p256Mul(t0, t0, t1) + p256Sqr(t1, t0, 4) + p256Mul(t0, t0, t1) + p256Sqr(t1, t0, 8) + p256Mul(t0, t0, t1) + p256Sqr(t1, t0, 16) + p256Mul(t0, t0, t1) + p256Sqr(t0, t0, 32) + p256Mul(t0, x, t0) + p256Sqr(t0, t0, 96) + p256Mul(t0, x, t0) + p256Sqr(t0, t0, 94) + + p256Sqr(t1, t0, 1) + if p256Equal(t1, x) != 1 { + return false + } + *e = *t0 + return true +} + // The following assembly functions are implemented in p256_asm_*.s // Montgomery multiplication. Sets res = in1 * in2 * R⁻¹ mod p. @@ -463,24 +534,53 @@ func (p *P256Point) Bytes() []byte { func (p *P256Point) bytes(out *[p256UncompressedLength]byte) []byte { // The proper representation of the point at infinity is a single zero byte. if p.isInfinity() == 1 { - return out[:1] + return append(out[:0], 0) } - zInv := new(p256Element) - zInvSq := new(p256Element) - p256Inverse(zInv, &p.z) - p256Sqr(zInvSq, zInv, 1) - p256Mul(zInv, zInv, zInvSq) + x, y := new(p256Element), new(p256Element) + p.affineFromMont(x, y) - p256Mul(zInvSq, &p.x, zInvSq) - p256Mul(zInv, &p.y, zInv) + out[0] = 4 // Uncompressed form. + p256LittleToBig((*[32]byte)(out[1:33]), x) + p256LittleToBig((*[32]byte)(out[33:65]), y) - p256FromMont(zInvSq, zInvSq) - p256FromMont(zInv, zInv) + return out[:] +} - out[0] = 4 // Uncompressed form. - p256LittleToBig((*[32]byte)(out[1:33]), zInvSq) - p256LittleToBig((*[32]byte)(out[33:65]), zInv) +// affineFromMont sets (x, y) to the affine coordinates of p, converted out of the +// Montgomery domain. +func (p *P256Point) affineFromMont(x, y *p256Element) { + p256Inverse(y, &p.z) + p256Sqr(x, y, 1) + p256Mul(y, y, x) + + p256Mul(x, &p.x, x) + p256Mul(y, &p.y, y) + + p256FromMont(x, x) + p256FromMont(y, y) +} + +// BytesCompressed returns the compressed or infinity encoding of p, as +// specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the +// point at infinity is shorter than all other encodings. +func (p *P256Point) BytesCompressed() []byte { + // This function is outlined to make the allocations inline in the caller + // rather than happen on the heap. + var out [p256CompressedLength]byte + return p.bytesCompressed(&out) +} + +func (p *P256Point) bytesCompressed(out *[p256CompressedLength]byte) []byte { + if p.isInfinity() == 1 { + return append(out[:0], 0) + } + + x, y := new(p256Element), new(p256Element) + p.affineFromMont(x, y) + + out[0] = 2 | byte(y[0]&1) + p256LittleToBig((*[32]byte)(out[1:33]), x) return out[:] } diff --git a/src/crypto/elliptic/internal/nistec/p384.go b/src/crypto/elliptic/internal/nistec/p384.go index 1830149b2b..13fe74c534 100644 --- a/src/crypto/elliptic/internal/nistec/p384.go +++ b/src/crypto/elliptic/internal/nistec/p384.go @@ -82,30 +82,53 @@ func (p *P384Point) SetBytes(b []byte) (*P384Point, error) { p.z.One() return p, nil - // Compressed form - case len(b) == 1+p384ElementLength && b[0] == 0: - return nil, errors.New("unimplemented") // TODO(filippo) + // Compressed form. + case len(b) == 1+p384ElementLength && (b[0] == 2 || b[0] == 3): + x, err := new(fiat.P384Element).SetBytes(b[1:]) + if err != nil { + return nil, err + } + + // y² = x³ - 3x + b + y := p384Polynomial(new(fiat.P384Element), x) + if !p384Sqrt(y, y) { + return nil, errors.New("invalid P384 compressed point encoding") + } + + // Select the positive or negative root, as indicated by the least + // significant bit, based on the encoding type byte. + otherRoot := new(fiat.P384Element) + otherRoot.Sub(otherRoot, y) + cond := y.Bytes()[p384ElementLength-1]&1 ^ b[0]&1 + y.Select(otherRoot, y, int(cond)) + + p.x.Set(x) + p.y.Set(y) + p.z.One() + return p, nil default: return nil, errors.New("invalid P384 point encoding") } } -func p384CheckOnCurve(x, y *fiat.P384Element) error { - // x³ - 3x + b. - x3 := new(fiat.P384Element).Square(x) - x3.Mul(x3, x) +// p384Polynomial sets y2 to x³ - 3x + b, and returns y2. +func p384Polynomial(y2, x *fiat.P384Element) *fiat.P384Element { + y2.Square(x) + y2.Mul(y2, x) threeX := new(fiat.P384Element).Add(x, x) threeX.Add(threeX, x) - x3.Sub(x3, threeX) - x3.Add(x3, p384B) + y2.Sub(y2, threeX) + return y2.Add(y2, p384B) +} +func p384CheckOnCurve(x, y *fiat.P384Element) error { // y² = x³ - 3x + b - y2 := new(fiat.P384Element).Square(y) - - if x3.Equal(y2) != 1 { + rhs := p384Polynomial(new(fiat.P384Element), x) + lhs := new(fiat.P384Element).Square(y) + if rhs.Equal(lhs) != 1 { return errors.New("P384 point not on curve") } return nil @@ -117,22 +140,49 @@ func p384CheckOnCurve(x, y *fiat.P384Element) error { func (p *P384Point) Bytes() []byte { // This function is outlined to make the allocations inline in the caller // rather than happen on the heap. - var out [133]byte + var out [1 + 2*p384ElementLength]byte return p.bytes(&out) } -func (p *P384Point) bytes(out *[133]byte) []byte { +func (p *P384Point) bytes(out *[1 + 2*p384ElementLength]byte) []byte { if p.z.IsZero() == 1 { return append(out[:0], 0) } zinv := new(fiat.P384Element).Invert(p.z) - xx := new(fiat.P384Element).Mul(p.x, zinv) - yy := new(fiat.P384Element).Mul(p.y, zinv) + x := new(fiat.P384Element).Mul(p.x, zinv) + y := new(fiat.P384Element).Mul(p.y, zinv) buf := append(out[:0], 4) - buf = append(buf, xx.Bytes()...) - buf = append(buf, yy.Bytes()...) + buf = append(buf, x.Bytes()...) + buf = append(buf, y.Bytes()...) + return buf +} + +// BytesCompressed returns the compressed or infinity encoding of p, as +// specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the +// point at infinity is shorter than all other encodings. +func (p *P384Point) BytesCompressed() []byte { + // This function is outlined to make the allocations inline in the caller + // rather than happen on the heap. + var out [1 + p384ElementLength]byte + return p.bytesCompressed(&out) +} + +func (p *P384Point) bytesCompressed(out *[1 + p384ElementLength]byte) []byte { + if p.z.IsZero() == 1 { + return append(out[:0], 0) + } + + zinv := new(fiat.P384Element).Invert(p.z) + x := new(fiat.P384Element).Mul(p.x, zinv) + y := new(fiat.P384Element).Mul(p.y, zinv) + + // Encode the sign of the y coordinate (indicated by the least significant + // bit) as the encoding type (2 or 3). + buf := append(out[:0], 2) + buf[0] |= y.Bytes()[p384ElementLength-1] & 1 + buf = append(buf, x.Bytes()...) return buf } @@ -363,3 +413,103 @@ func (p *P384Point) ScalarBaseMult(scalar []byte) (*P384Point, error) { return p, nil } + +// p384Sqrt sets e to a square root of x. If x is not a square, p384Sqrt returns +// false and e is unchanged. e and x can overlap. +func p384Sqrt(e, x *fiat.P384Element) (isSquare bool) { + candidate := new(fiat.P384Element) + p384SqrtCandidate(candidate, x) + square := new(fiat.P384Element).Square(candidate) + if square.Equal(x) != 1 { + return false + } + e.Set(candidate) + return true +} + +// p384SqrtCandidate sets z to a square root candidate for x. z and x must not overlap. +func p384SqrtCandidate(z, x *fiat.P384Element) { + // Since p = 3 mod 4, exponentiation by (p + 1) / 4 yields a square root candidate. + // + // The sequence of 14 multiplications and 381 squarings is derived from the + // following addition chain generated with github.com/mmcloughlin/addchain v0.4.0. + // + // _10 = 2*1 + // _11 = 1 + _10 + // _110 = 2*_11 + // _111 = 1 + _110 + // _111000 = _111 << 3 + // _111111 = _111 + _111000 + // _1111110 = 2*_111111 + // _1111111 = 1 + _1111110 + // x12 = _1111110 << 5 + _111111 + // x24 = x12 << 12 + x12 + // x31 = x24 << 7 + _1111111 + // x32 = 2*x31 + 1 + // x63 = x32 << 31 + x31 + // x126 = x63 << 63 + x63 + // x252 = x126 << 126 + x126 + // x255 = x252 << 3 + _111 + // return ((x255 << 33 + x32) << 64 + 1) << 30 + // + var t0 = new(fiat.P384Element) + var t1 = new(fiat.P384Element) + var t2 = new(fiat.P384Element) + + z.Square(x) + z.Mul(x, z) + z.Square(z) + t0.Mul(x, z) + z.Square(t0) + for s := 1; s < 3; s++ { + z.Square(z) + } + t1.Mul(t0, z) + t2.Square(t1) + z.Mul(x, t2) + for s := 0; s < 5; s++ { + t2.Square(t2) + } + t1.Mul(t1, t2) + t2.Square(t1) + for s := 1; s < 12; s++ { + t2.Square(t2) + } + t1.Mul(t1, t2) + for s := 0; s < 7; s++ { + t1.Square(t1) + } + t1.Mul(z, t1) + z.Square(t1) + z.Mul(x, z) + t2.Square(z) + for s := 1; s < 31; s++ { + t2.Square(t2) + } + t1.Mul(t1, t2) + t2.Square(t1) + for s := 1; s < 63; s++ { + t2.Square(t2) + } + t1.Mul(t1, t2) + t2.Square(t1) + for s := 1; s < 126; s++ { + t2.Square(t2) + } + t1.Mul(t1, t2) + for s := 0; s < 3; s++ { + t1.Square(t1) + } + t0.Mul(t0, t1) + for s := 0; s < 33; s++ { + t0.Square(t0) + } + z.Mul(z, t0) + for s := 0; s < 64; s++ { + z.Square(z) + } + z.Mul(x, z) + for s := 0; s < 30; s++ { + z.Square(z) + } +} diff --git a/src/crypto/elliptic/internal/nistec/p521.go b/src/crypto/elliptic/internal/nistec/p521.go index 731af4758f..9420894004 100644 --- a/src/crypto/elliptic/internal/nistec/p521.go +++ b/src/crypto/elliptic/internal/nistec/p521.go @@ -82,30 +82,53 @@ func (p *P521Point) SetBytes(b []byte) (*P521Point, error) { p.z.One() return p, nil - // Compressed form - case len(b) == 1+p521ElementLength && b[0] == 0: - return nil, errors.New("unimplemented") // TODO(filippo) + // Compressed form. + case len(b) == 1+p521ElementLength && (b[0] == 2 || b[0] == 3): + x, err := new(fiat.P521Element).SetBytes(b[1:]) + if err != nil { + return nil, err + } + + // y² = x³ - 3x + b + y := p521Polynomial(new(fiat.P521Element), x) + if !p521Sqrt(y, y) { + return nil, errors.New("invalid P521 compressed point encoding") + } + + // Select the positive or negative root, as indicated by the least + // significant bit, based on the encoding type byte. + otherRoot := new(fiat.P521Element) + otherRoot.Sub(otherRoot, y) + cond := y.Bytes()[p521ElementLength-1]&1 ^ b[0]&1 + y.Select(otherRoot, y, int(cond)) + + p.x.Set(x) + p.y.Set(y) + p.z.One() + return p, nil default: return nil, errors.New("invalid P521 point encoding") } } -func p521CheckOnCurve(x, y *fiat.P521Element) error { - // x³ - 3x + b. - x3 := new(fiat.P521Element).Square(x) - x3.Mul(x3, x) +// p521Polynomial sets y2 to x³ - 3x + b, and returns y2. +func p521Polynomial(y2, x *fiat.P521Element) *fiat.P521Element { + y2.Square(x) + y2.Mul(y2, x) threeX := new(fiat.P521Element).Add(x, x) threeX.Add(threeX, x) - x3.Sub(x3, threeX) - x3.Add(x3, p521B) + y2.Sub(y2, threeX) + return y2.Add(y2, p521B) +} +func p521CheckOnCurve(x, y *fiat.P521Element) error { // y² = x³ - 3x + b - y2 := new(fiat.P521Element).Square(y) - - if x3.Equal(y2) != 1 { + rhs := p521Polynomial(new(fiat.P521Element), x) + lhs := new(fiat.P521Element).Square(y) + if rhs.Equal(lhs) != 1 { return errors.New("P521 point not on curve") } return nil @@ -117,22 +140,49 @@ func p521CheckOnCurve(x, y *fiat.P521Element) error { func (p *P521Point) Bytes() []byte { // This function is outlined to make the allocations inline in the caller // rather than happen on the heap. - var out [133]byte + var out [1 + 2*p521ElementLength]byte return p.bytes(&out) } -func (p *P521Point) bytes(out *[133]byte) []byte { +func (p *P521Point) bytes(out *[1 + 2*p521ElementLength]byte) []byte { if p.z.IsZero() == 1 { return append(out[:0], 0) } zinv := new(fiat.P521Element).Invert(p.z) - xx := new(fiat.P521Element).Mul(p.x, zinv) - yy := new(fiat.P521Element).Mul(p.y, zinv) + x := new(fiat.P521Element).Mul(p.x, zinv) + y := new(fiat.P521Element).Mul(p.y, zinv) buf := append(out[:0], 4) - buf = append(buf, xx.Bytes()...) - buf = append(buf, yy.Bytes()...) + buf = append(buf, x.Bytes()...) + buf = append(buf, y.Bytes()...) + return buf +} + +// BytesCompressed returns the compressed or infinity encoding of p, as +// specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the +// point at infinity is shorter than all other encodings. +func (p *P521Point) BytesCompressed() []byte { + // This function is outlined to make the allocations inline in the caller + // rather than happen on the heap. + var out [1 + p521ElementLength]byte + return p.bytesCompressed(&out) +} + +func (p *P521Point) bytesCompressed(out *[1 + p521ElementLength]byte) []byte { + if p.z.IsZero() == 1 { + return append(out[:0], 0) + } + + zinv := new(fiat.P521Element).Invert(p.z) + x := new(fiat.P521Element).Mul(p.x, zinv) + y := new(fiat.P521Element).Mul(p.y, zinv) + + // Encode the sign of the y coordinate (indicated by the least significant + // bit) as the encoding type (2 or 3). + buf := append(out[:0], 2) + buf[0] |= y.Bytes()[p521ElementLength-1] & 1 + buf = append(buf, x.Bytes()...) return buf } @@ -363,3 +413,32 @@ func (p *P521Point) ScalarBaseMult(scalar []byte) (*P521Point, error) { return p, nil } + +// p521Sqrt sets e to a square root of x. If x is not a square, p521Sqrt returns +// false and e is unchanged. e and x can overlap. +func p521Sqrt(e, x *fiat.P521Element) (isSquare bool) { + candidate := new(fiat.P521Element) + p521SqrtCandidate(candidate, x) + square := new(fiat.P521Element).Square(candidate) + if square.Equal(x) != 1 { + return false + } + e.Set(candidate) + return true +} + +// p521SqrtCandidate sets z to a square root candidate for x. z and x must not overlap. +func p521SqrtCandidate(z, x *fiat.P521Element) { + // Since p = 3 mod 4, exponentiation by (p + 1) / 4 yields a square root candidate. + // + // The sequence of 0 multiplications and 519 squarings is derived from the + // following addition chain generated with github.com/mmcloughlin/addchain v0.4.0. + // + // return 1 << 519 + // + + z.Square(x) + for s := 1; s < 519; s++ { + z.Square(z) + } +} diff --git a/src/crypto/elliptic/nistec.go b/src/crypto/elliptic/nistec.go index 989da66638..58c9c5c07c 100644 --- a/src/crypto/elliptic/nistec.go +++ b/src/crypto/elliptic/nistec.go @@ -163,14 +163,13 @@ func (curve *nistCurve[Point]) pointFromAffine(x, y *big.Int) (p Point, err erro func (curve *nistCurve[Point]) pointToAffine(p Point) (x, y *big.Int) { out := p.Bytes() if len(out) == 1 && out[0] == 0 { - // This is the correct encoding of the point at infinity, which - // Unmarshal does not support. See Issue 37294. + // This is the encoding of the point at infinity, which the affine + // coordinates API represents as (0, 0) by convention. return new(big.Int), new(big.Int) } - x, y = Unmarshal(curve, out) - if x == nil { - panic("crypto/elliptic: internal error: Unmarshal rejected a valid point encoding") - } + byteLen := (curve.params.BitSize + 7) / 8 + x = new(big.Int).SetBytes(out[1 : 1+byteLen]) + y = new(big.Int).SetBytes(out[1+byteLen:]) return x, y } @@ -268,6 +267,35 @@ func (curve *nistCurve[Point]) CombinedMult(Px, Py *big.Int, s1, s2 []byte) (x, return curve.pointToAffine(p.Add(p, q)) } +func (curve *nistCurve[Point]) Unmarshal(data []byte) (x, y *big.Int) { + if len(data) == 0 || data[0] != 4 { + return nil, nil + } + // Use SetBytes to check that data encodes a valid point. + _, err := curve.newPoint().SetBytes(data) + if err != nil { + return nil, nil + } + // We don't use pointToAffine because it involves an expensive field + // inversion to convert from Jacobian to affine coordinates, which we + // already have. + byteLen := (curve.params.BitSize + 7) / 8 + x = new(big.Int).SetBytes(data[1 : 1+byteLen]) + y = new(big.Int).SetBytes(data[1+byteLen:]) + return x, y +} + +func (curve *nistCurve[Point]) UnmarshalCompressed(data []byte) (x, y *big.Int) { + if len(data) == 0 || (data[0] != 2 && data[0] != 3) { + return nil, nil + } + p, err := curve.newPoint().SetBytes(data) + if err != nil { + return nil, nil + } + return curve.pointToAffine(p) +} + func bigFromDecimal(s string) *big.Int { b, ok := new(big.Int).SetString(s, 10) if !ok {