]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/internal/nistec: port cleanups and docs from p256_asm.go to purego
authorFilippo Valsorda <filippo@golang.org>
Sat, 27 Jul 2024 21:40:15 +0000 (23:40 +0200)
committerGopher Robot <gobot@golang.org>
Tue, 19 Nov 2024 22:30:19 +0000 (22:30 +0000)
Change-Id: Ieaad0692f4301cc301a0dd2eadca2f2f9e96bff0
Reviewed-on: https://go-review.googlesource.com/c/go/+/627942
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Russ Cox <rsc@golang.org>
src/crypto/internal/nistec/p256.go

index 611bfbac73aa7a90dfb4f1567f0854e0d7858d9f..853764f48a9cdeb8d46b4aa426409879a36240c1 100644 (file)
@@ -18,14 +18,14 @@ import (
        "unsafe"
 )
 
-// p256ElementLength is the length of an element of the base or scalar field,
-// which have the same bytes length for all NIST P curves.
-const p256ElementLength = 32
-
-// P256Point is a P256 point. The zero value is NOT valid.
+// P256Point is a P-256 point. The zero value is NOT valid.
 type P256Point struct {
-       // The point is represented in projective coordinates (X:Y:Z),
-       // where x = X/Z and y = Y/Z.
+       // The point is represented in projective coordinates (X:Y:Z), where x = X/Z
+       // and y = Y/Z. Infinity is (0:1:0).
+       //
+       // fiat.P256Element is a base field element in [0, P-1] in the Montgomery
+       // domain (with R 2²⁵⁶ and P 2²⁵⁶ - 2²²⁴ + 2¹⁹² + 2⁹⁶ - 1) as four limbs in
+       // little-endian order value.
        x, y, z fiat.P256Element
 }
 
@@ -52,6 +52,10 @@ func (p *P256Point) Set(q *P256Point) *P256Point {
        return p
 }
 
+const p256ElementLength = 32
+const p256UncompressedLength = 1 + 2*p256ElementLength
+const p256CompressedLength = 1 + p256ElementLength
+
 // SetBytes sets p to the compressed, uncompressed, or infinity value encoded in
 // b, as specified in SEC 1, Version 2.0, Section 2.3.4. If the point is not on
 // the curve, it returns nil and an error, and the receiver is unchanged.
@@ -63,7 +67,7 @@ func (p *P256Point) SetBytes(b []byte) (*P256Point, error) {
                return p.Set(NewP256Point()), nil
 
        // Uncompressed form.
-       case len(b) == 1+2*p256ElementLength && b[0] == 4:
+       case len(b) == p256UncompressedLength && b[0] == 4:
                x, err := new(fiat.P256Element).SetBytes(b[1 : 1+p256ElementLength])
                if err != nil {
                        return nil, err
@@ -81,7 +85,7 @@ func (p *P256Point) SetBytes(b []byte) (*P256Point, error) {
                return p, nil
 
        // Compressed form.
-       case len(b) == 1+p256ElementLength && (b[0] == 2 || b[0] == 3):
+       case len(b) == p256CompressedLength && (b[0] == 2 || b[0] == 3):
                x, err := new(fiat.P256Element).SetBytes(b[1:])
                if err != nil {
                        return nil, err
@@ -148,11 +152,13 @@ func p256CheckOnCurve(x, y *fiat.P256Element) error {
 func (p *P256Point) Bytes() []byte {
        // This function is outlined to make the allocations inline in the caller
        // rather than happen on the heap.
-       var out [1 + 2*p256ElementLength]byte
+       var out [p256UncompressedLength]byte
        return p.bytes(&out)
 }
 
-func (p *P256Point) bytes(out *[1 + 2*p256ElementLength]byte) []byte {
+func (p *P256Point) bytes(out *[p256UncompressedLength]byte) []byte {
+       // The SEC 1 representation of the point at infinity is a single zero byte,
+       // and only infinity has z = 0.
        if p.z.IsZero() == 1 {
                return append(out[:0], 0)
        }
@@ -193,11 +199,11 @@ func (p *P256Point) bytesX(out *[p256ElementLength]byte) ([]byte, error) {
 func (p *P256Point) BytesCompressed() []byte {
        // This function is outlined to make the allocations inline in the caller
        // rather than happen on the heap.
-       var out [1 + p256ElementLength]byte
+       var out [p256CompressedLength]byte
        return p.bytesCompressed(&out)
 }
 
-func (p *P256Point) bytesCompressed(out *[1 + p256ElementLength]byte) []byte {
+func (p *P256Point) bytesCompressed(out *[p256CompressedLength]byte) []byte {
        if p.z.IsZero() == 1 {
                return append(out[:0], 0)
        }
@@ -315,11 +321,23 @@ func (q *P256Point) Double(p *P256Point) *P256Point {
        return q
 }
 
+// p256AffinePoint is a point in affine coordinates (x, y). x and y are still
+// Montgomery domain elements. The point can't be the point at infinity.
 type p256AffinePoint struct {
        x, y fiat.P256Element
 }
 
-func (q *P256Point) AddAffine(p1 *P256Point, p2 *p256AffinePoint, cond int) *P256Point {
+func (p *p256AffinePoint) Projective() *P256Point {
+       pp := &P256Point{x: p.x, y: p.y}
+       pp.z.One()
+       return pp
+}
+
+// AddAffine sets q = p1 + p2, if infinity == 0, and to p1 if infinity == 1.
+// p2 can't be the point at infinity as it can't be represented in affine
+// coordinates, instead callers can set p2 to an arbitrary point and set
+// infinity to 1.
+func (q *P256Point) AddAffine(p1 *P256Point, p2 *p256AffinePoint, infinity int) *P256Point {
        // Complete mixed addition formula for a = -3 from "Complete addition
        // formulas for prime order elliptic curves"
        // (https://eprint.iacr.org/2015/1060), Algorithm 5.
@@ -361,9 +379,9 @@ func (q *P256Point) AddAffine(p1 *P256Point, p2 *p256AffinePoint, cond int) *P25
        t1.Mul(t3, t0)                                  // t1 ← t3 · t0
        z3.Add(z3, t1)                                  // Z3 ← Z3 + t1
 
-       q.x.Select(x3, &p1.x, cond)
-       q.y.Select(y3, &p1.y, cond)
-       q.z.Select(z3, &p1.z, cond)
+       q.x.Select(&p1.x, x3, infinity)
+       q.y.Select(&p1.y, y3, infinity)
+       q.z.Select(&p1.z, z3, infinity)
        return q
 }
 
@@ -379,10 +397,20 @@ func (q *P256Point) Select(p1, p2 *P256Point, cond int) *P256Point {
 // Montgomery domain (with R 2²⁵⁶) as four uint64 limbs in little-endian order.
 type p256OrdElement [4]uint64
 
-// p256OrdReduce ensures s is in the range [0, ord(G)-1].
-func p256OrdReduce(s *p256OrdElement) {
-       // Since 2 * ord(G) > 2²⁵⁶, we can just conditionally subtract ord(G),
-       // keeping the result if it doesn't underflow.
+// SetBytes sets s to the big-endian value of x, reducing it as necessary.
+func (s *p256OrdElement) SetBytes(x []byte) (*p256OrdElement, error) {
+       if len(x) != 32 {
+               return nil, errors.New("invalid scalar length")
+       }
+
+       s[0] = byteorder.BeUint64(x[24:])
+       s[1] = byteorder.BeUint64(x[16:])
+       s[2] = byteorder.BeUint64(x[8:])
+       s[3] = byteorder.BeUint64(x[:])
+
+       // Ensure s is in the range [0, ord(G)-1]. Since 2 * ord(G) > 2²⁵⁶, we can
+       // just conditionally subtract ord(G), keeping the result if it doesn't
+       // underflow.
        t0, b := bits.Sub64(s[0], 0xf3b9cac2fc632551, 0)
        t1, b := bits.Sub64(s[1], 0xbce6faada7179e84, b)
        t2, b := bits.Sub64(s[2], 0xffffffffffffffff, b)
@@ -392,43 +420,39 @@ func p256OrdReduce(s *p256OrdElement) {
        s[1] ^= (t1 ^ s[1]) & tMask
        s[2] ^= (t2 ^ s[2]) & tMask
        s[3] ^= (t3 ^ s[3]) & tMask
-}
 
-func p256OrdLittleToBig(b *[32]byte, l *p256OrdElement) {
-       byteorder.BePutUint64(b[24:], l[0])
-       byteorder.BePutUint64(b[16:], l[1])
-       byteorder.BePutUint64(b[8:], l[2])
-       byteorder.BePutUint64(b[:], l[3])
+       return s, nil
 }
 
-func p256OrdBigToLittle(l *p256OrdElement, b *[32]byte) {
-       l[0] = byteorder.BeUint64(b[24:])
-       l[1] = byteorder.BeUint64(b[16:])
-       l[2] = byteorder.BeUint64(b[8:])
-       l[3] = byteorder.BeUint64(b[:])
+func (s *p256OrdElement) Bytes() []byte {
+       var out [32]byte
+       byteorder.BePutUint64(out[24:], s[0])
+       byteorder.BePutUint64(out[16:], s[1])
+       byteorder.BePutUint64(out[8:], s[2])
+       byteorder.BePutUint64(out[:], s[3])
+       return out[:]
 }
 
-// p256OrdRsh returns the 64 least significant bits of x >> n. n must be lower
+// Rsh returns the 64 least significant bits of x >> n. n must be lower
 // than 256. The value of n leaks through timing side-channels.
-func p256OrdRsh(x *p256OrdElement, n int) uint64 {
+func (s *p256OrdElement) Rsh(n int) uint64 {
        i := n / 64
        n = n % 64
-       res := x[i] >> n
+       res := s[i] >> n
        // Shift in the more significant limb, if present.
-       if i := i + 1; i < len(x) {
-               res |= x[i] << (64 - n)
+       if i := i + 1; i < len(s) {
+               res |= s[i] << (64 - n)
        }
        return res
 }
 
-// A p256Table holds the first 16 multiples of a point at offset -1, so [1]P
-// is at table[0], [16]P is at table[15], and [0]P is implicitly the identity
-// point.
+// p256Table is a table of the first 16 multiples of a point. Points are stored
+// at an index offset of -1 so [8]P is at index 7, P is at 0, and [16]P is at 15.
+// [0]P is the point at infinity and it's not stored.
 type p256Table [16]P256Point
 
 // Select selects the n-th multiple of the table base point into p. It works in
-// constant time by iterating over every entry of the table. n must be in [0, 16].
-// If n is 0, p is implicitly set to the identity point.
+// constant time. n must be in [0, 16]. If n is 0, p is set to the identity point.
 func (table *p256Table) Select(p *P256Point, n uint8) {
        if n > 16 {
                panic("nistec: internal error: p256Table called with out-of-bounds value")
@@ -440,6 +464,18 @@ func (table *p256Table) Select(p *P256Point, n uint8) {
        }
 }
 
+// Compute populates the table to the first 16 multiples of q.
+func (table *p256Table) Compute(q *P256Point) *p256Table {
+       table[0].Set(q)
+       for i := 1; i < 16; i += 2 {
+               table[i].Double(&table[i/2])
+               if i+1 < 16 {
+                       table[i+1].Add(&table[i], q)
+               }
+       }
+       return table
+}
+
 func boothW5(in uint64) (uint8, int) {
        s := ^((in >> 5) - 1)
        d := (1 << 6) - in - 1
@@ -448,33 +484,27 @@ func boothW5(in uint64) (uint8, int) {
        return uint8(d), int(s & 1)
 }
 
-// ScalarMult sets p = scalar * q, and returns p.
+// ScalarMult sets r = scalar * q, where scalar is a 32-byte big endian value,
+// and returns r. If scalar is not 32 bytes long, ScalarMult returns an error
+// and the receiver is unchanged.
 func (p *P256Point) ScalarMult(q *P256Point, scalar []byte) (*P256Point, error) {
-       if len(scalar) != p256ElementLength {
-               return nil, errors.New("invalid scalar length")
-       }
-       s := new(p256OrdElement)
-       p256OrdBigToLittle(s, (*[32]byte)(scalar))
-       p256OrdReduce(s)
-
-       var table p256Table
-       table[0].Set(q)
-       for i := 1; i < 16; i += 2 {
-               table[i].Double(&table[i/2])
-               if i+1 < 16 {
-                       table[i+1].Add(&table[i], q)
-               }
+       s, err := new(p256OrdElement).SetBytes(scalar)
+       if err != nil {
+               return nil, err
        }
 
        // Start scanning the window from the most significant bits. We move by
        // 5 bits at a time and need to finish at -1, so -1 + 5 * 51 = 254.
        index := 254
 
-       sel, sign := boothW5(p256OrdRsh(s, index))
+       sel, sign := boothW5(s.Rsh(index))
        // sign is always zero because the boothW5 input here is at
        // most two bits long, so the top bit is never set.
        _ = sign
 
+       // Neither Select nor Add have exceptions for the point at infinity /
+       // selector zero, so we don't need to check for it here or in the loop.
+       table := new(p256Table).Compute(q)
        table.Select(p, sel)
 
        t := NewP256Point()
@@ -488,7 +518,7 @@ func (p *P256Point) ScalarMult(q *P256Point, scalar []byte) (*P256Point, error)
                p.Double(p)
 
                if index >= 0 {
-                       sel, sign = boothW5(p256OrdRsh(s, index) & 0b111111)
+                       sel, sign = boothW5(s.Rsh(index) & 0b111111)
                } else {
                        // Booth encoding considers a virtual zero bit at index -1,
                        // so we shift left the least significant limb.
@@ -504,7 +534,7 @@ func (p *P256Point) ScalarMult(q *P256Point, scalar []byte) (*P256Point, error)
        return p, nil
 }
 
-// TODO
+// Negate sets p to -p, if cond == 1, and to p if cond == 0.
 func (p *P256Point) Negate(cond int) *P256Point {
        negY := new(fiat.P256Element)
        negY.Sub(negY, &p.y)
@@ -512,11 +542,16 @@ func (p *P256Point) Negate(cond int) *P256Point {
        return p
 }
 
+// p256AffineTable is a table of the first 32 multiples of a point. Points are
+// stored at an index offset of -1 like in p256Table, and [0]P is not stored.
 type p256AffineTable [32]p256AffinePoint
 
+// Select selects the n-th multiple of the table base point into p. It works in
+// constant time. n can be in [0, 32], but (unlike p256Table.Select) if n is 0,
+// p is set to an undefined value.
 func (table *p256AffineTable) Select(p *p256AffinePoint, n uint8) {
        if n > 32 {
-               panic("nistec: internal error: p256TableFive called with out-of-bounds value")
+               panic("nistec: internal error: p256AffineTable.Select called with out-of-bounds value")
        }
        for i := uint8(1); i <= 32; i++ {
                cond := subtle.ConstantTimeByteEq(i, n)
@@ -558,23 +593,24 @@ func boothW6(in uint64) (uint8, int) {
        return uint8(d), int(s & 1)
 }
 
-// ScalarBaseMult sets p = scalar * B, where B is the canonical generator, and
-// returns p.
+// ScalarBaseMult sets p = scalar * generator, where scalar is a 32-byte big
+// endian value, and returns r. If scalar is not 32 bytes long, ScalarBaseMult
+// returns an error and the receiver is unchanged.
 func (p *P256Point) ScalarBaseMult(scalar []byte) (*P256Point, error) {
-       if len(scalar) != p256ElementLength {
-               return nil, errors.New("invalid scalar length")
-       }
-       s := new(p256OrdElement)
-       p256OrdBigToLittle(s, (*[32]byte)(scalar))
-       p256OrdReduce(s)
+       // This function works like ScalarMult above, but the table is fixed and
+       // "pre-doubled" for each iteration, so instead of doubling we move to the
+       // next table at each iteration.
 
-       p.Set(NewP256Point())
+       s, err := new(p256OrdElement).SetBytes(scalar)
+       if err != nil {
+               return nil, err
+       }
 
        // Start scanning the window from the most significant bits. We move by
        // 6 bits at a time and need to finish at -1, so -1 + 6 * 42 = 251.
        index := 251
 
-       sel, sign := boothW6(p256OrdRsh(s, index))
+       sel, sign := boothW6(s.Rsh(index))
        // sign is always zero because the boothW6 input here is at
        // most five bits long, so the top bit is never set.
        _ = sign
@@ -582,16 +618,19 @@ func (p *P256Point) ScalarBaseMult(scalar []byte) (*P256Point, error) {
        t := &p256AffinePoint{}
        table := &p256GeneratorTables[(index+1)/6]
        table.Select(t, sel)
-       selIsNotZero := subtle.ConstantTimeByteEq(sel, 0) ^ 1
-       p.x.Select(&t.x, &p.x, selIsNotZero)
-       p.y.Select(&t.y, &p.y, selIsNotZero)
-       p.z.Select(new(fiat.P256Element).One(), &p.z, selIsNotZero)
+
+       // Select's output is undefined if the selector is zero, when it should be
+       // the point at infinity (because infinity can't be represented in affine
+       // coordinates). Here we conditionally set p to the infinity if sel is zero.
+       // In the loop, that's handled by AddAffine.
+       selIsZero := subtle.ConstantTimeByteEq(sel, 0)
+       p.Select(NewP256Point(), t.Projective(), selIsZero)
 
        for index >= 5 {
                index -= 6
 
                if index >= 0 {
-                       sel, sign = boothW6(p256OrdRsh(s, index) & 0b1111111)
+                       sel, sign = boothW6(s.Rsh(index) & 0b1111111)
                } else {
                        // Booth encoding considers a virtual zero bit at index -1,
                        // so we shift left the least significant limb.
@@ -599,18 +638,17 @@ func (p *P256Point) ScalarBaseMult(scalar []byte) (*P256Point, error) {
                        sel, sign = boothW6(wvalue)
                }
 
-               selIsNotZero := subtle.ConstantTimeByteEq(sel, 0) ^ 1
-
                table := &p256GeneratorTables[(index+1)/6]
                table.Select(t, sel)
                t.Negate(sign)
-               p.AddAffine(p, t, selIsNotZero)
+               selIsZero := subtle.ConstantTimeByteEq(sel, 0)
+               p.AddAffine(p, t, selIsZero)
        }
 
        return p, nil
 }
 
-// TODO
+// Negate sets p to -p, if cond == 1, and to p if cond == 0.
 func (p *p256AffinePoint) Negate(cond int) *p256AffinePoint {
        negY := new(fiat.P256Element)
        negY.Sub(negY, &p.y)
@@ -621,18 +659,8 @@ func (p *p256AffinePoint) Negate(cond int) *p256AffinePoint {
 // p256Sqrt sets e to a square root of x. If x is not a square, p256Sqrt returns
 // false and e is unchanged. e and x can overlap.
 func p256Sqrt(e, x *fiat.P256Element) (isSquare bool) {
-       candidate := new(fiat.P256Element)
-       p256SqrtCandidate(candidate, x)
-       square := new(fiat.P256Element).Square(candidate)
-       if square.Equal(x) != 1 {
-               return false
-       }
-       e.Set(candidate)
-       return true
-}
+       t0, t1 := new(fiat.P256Element), new(fiat.P256Element)
 
-// p256SqrtCandidate sets z to a square root candidate for x. z and x must not overlap.
-func p256SqrtCandidate(z, x *fiat.P256Element) {
        // Since p = 3 mod 4, exponentiation by (p + 1) / 4 yields a square root candidate.
        //
        // The sequence of 7 multiplications and 253 squarings is derived from the
@@ -648,39 +676,35 @@ func p256SqrtCandidate(z, x *fiat.P256Element) {
        //      x32       = x16 << 16 + x16
        //      return      ((x32 << 32 + 1) << 96 + 1) << 94
        //
-       var t0 = new(fiat.P256Element)
-
-       z.Square(x)
-       z.Mul(x, z)
-       t0.Square(z)
-       for s := 1; s < 2; s++ {
-               t0.Square(t0)
-       }
-       z.Mul(z, t0)
-       t0.Square(z)
-       for s := 1; s < 4; s++ {
-               t0.Square(t0)
-       }
-       z.Mul(z, t0)
-       t0.Square(z)
-       for s := 1; s < 8; s++ {
-               t0.Square(t0)
-       }
-       z.Mul(z, t0)
-       t0.Square(z)
-       for s := 1; s < 16; s++ {
-               t0.Square(t0)
-       }
-       z.Mul(z, t0)
-       for s := 0; s < 32; s++ {
-               z.Square(z)
-       }
-       z.Mul(x, z)
-       for s := 0; s < 96; s++ {
-               z.Square(z)
+       p256Square(t0, x, 1)
+       t0.Mul(x, t0)
+       p256Square(t1, t0, 2)
+       t0.Mul(t0, t1)
+       p256Square(t1, t0, 4)
+       t0.Mul(t0, t1)
+       p256Square(t1, t0, 8)
+       t0.Mul(t0, t1)
+       p256Square(t1, t0, 16)
+       t0.Mul(t0, t1)
+       p256Square(t0, t0, 32)
+       t0.Mul(x, t0)
+       p256Square(t0, t0, 96)
+       t0.Mul(x, t0)
+       p256Square(t0, t0, 94)
+
+       // Check if the candidate t0 is indeed a square root of x.
+       t1.Square(t0)
+       if t1.Equal(x) != 1 {
+               return false
        }
-       z.Mul(x, z)
-       for s := 0; s < 94; s++ {
-               z.Square(z)
+       e.Set(t0)
+       return true
+}
+
+// p256Square sets e to the square of x, repeated n times > 1.
+func p256Square(e, x *fiat.P256Element, n int) {
+       e.Square(x)
+       for i := 1; i < n; i++ {
+               e.Square(e)
        }
 }