]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/internal/nistec: use Booth multiplication in purego P-256
authorFilippo Valsorda <filippo@golang.org>
Fri, 26 Jul 2024 20:35:50 +0000 (22:35 +0200)
committerGopher Robot <gobot@golang.org>
Tue, 19 Nov 2024 22:30:05 +0000 (22:30 +0000)
Brings ScalarMult from 71 adds/op + 259 doubles/op to 58 adds/op + 263
doubles/op and ScalarBaseMult from 64 adds/op to 42 adds/op, matching
the assembly scalar multiplication algorithm.

Change-Id: I6603b52d1c3b2c25ace471bd36044149f6e8cfab
Reviewed-on: https://go-review.googlesource.com/c/go/+/627937
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Russ Cox <rsc@golang.org>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
src/crypto/internal/nistec/p256.go

index f2dfbbb1ee9b43cd92621804aca3c51e487e92f8..16a43a5ced8f70f14b3a7ce2735d69200fde8102 100644 (file)
@@ -2,8 +2,6 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// Code generated by generate.go. DO NOT EDIT.
-
 //go:build (!amd64 && !arm64 && !ppc64le && !s390x) || purego
 
 package nistec
@@ -12,6 +10,8 @@ import (
        "crypto/internal/nistec/fiat"
        "crypto/subtle"
        "errors"
+       "internal/byteorder"
+       "math/bits"
        "sync"
 )
 
@@ -322,91 +322,194 @@ func (q *P256Point) Select(p1, p2 *P256Point, cond int) *P256Point {
        return q
 }
 
-// A p256Table holds the first 15 multiples of a point at offset -1, so [1]P
-// is at table[0], [15]P is at table[14], and [0]P is implicitly the identity
+// p256OrdElement is a P-256 scalar field element in [0, ord(G)-1] in the
+// Montgomery domain (with R 2²⁵⁶) as four uint64 limbs in little-endian order.
+type p256OrdElement [4]uint64
+
+// p256OrdReduce ensures s is in the range [0, ord(G)-1].
+func p256OrdReduce(s *p256OrdElement) {
+       // Since 2 * ord(G) > 2²⁵⁶, we can just conditionally subtract ord(G),
+       // keeping the result if it doesn't underflow.
+       t0, b := bits.Sub64(s[0], 0xf3b9cac2fc632551, 0)
+       t1, b := bits.Sub64(s[1], 0xbce6faada7179e84, b)
+       t2, b := bits.Sub64(s[2], 0xffffffffffffffff, b)
+       t3, b := bits.Sub64(s[3], 0xffffffff00000000, b)
+       tMask := b - 1 // zero if subtraction underflowed
+       s[0] ^= (t0 ^ s[0]) & tMask
+       s[1] ^= (t1 ^ s[1]) & tMask
+       s[2] ^= (t2 ^ s[2]) & tMask
+       s[3] ^= (t3 ^ s[3]) & tMask
+}
+
+func p256OrdLittleToBig(b *[32]byte, l *p256OrdElement) {
+       byteorder.BePutUint64(b[24:], l[0])
+       byteorder.BePutUint64(b[16:], l[1])
+       byteorder.BePutUint64(b[8:], l[2])
+       byteorder.BePutUint64(b[:], l[3])
+}
+
+func p256OrdBigToLittle(l *p256OrdElement, b *[32]byte) {
+       l[0] = byteorder.BeUint64(b[24:])
+       l[1] = byteorder.BeUint64(b[16:])
+       l[2] = byteorder.BeUint64(b[8:])
+       l[3] = byteorder.BeUint64(b[:])
+}
+
+// p256OrdRsh returns the 64 least significant bits of x >> n. n must be lower
+// than 256. The value of n leaks through timing side-channels.
+func p256OrdRsh(x *p256OrdElement, n int) uint64 {
+       i := n / 64
+       n = n % 64
+       res := x[i] >> n
+       // Shift in the more significant limb, if present.
+       if i := i + 1; i < len(x) {
+               res |= x[i] << (64 - n)
+       }
+       return res
+}
+
+// A p256Table holds the first 16 multiples of a point at offset -1, so [1]P
+// is at table[0], [16]P is at table[15], and [0]P is implicitly the identity
 // point.
-type p256Table [15]*P256Point
+type p256Table [16]*P256Point
 
 // Select selects the n-th multiple of the table base point into p. It works in
-// constant time by iterating over every entry of the table. n must be in [0, 15].
+// constant time by iterating over every entry of the table. n must be in [0, 16].
+// If n is 0, p is implicitly set to the identity point.
 func (table *p256Table) Select(p *P256Point, n uint8) {
-       if n >= 16 {
+       if n > 16 {
                panic("nistec: internal error: p256Table called with out-of-bounds value")
        }
        p.Set(NewP256Point())
-       for i := uint8(1); i < 16; i++ {
+       for i := uint8(1); i <= 16; i++ {
                cond := subtle.ConstantTimeByteEq(i, n)
                p.Select(table[i-1], p, cond)
        }
 }
 
+func boothW5(in uint64) (uint8, int) {
+       s := ^((in >> 5) - 1)
+       d := (1 << 6) - in - 1
+       d = (d & s) | (in & (^s))
+       d = (d >> 1) + (d & 1)
+       return uint8(d), int(s & 1)
+}
+
 // ScalarMult sets p = scalar * q, and returns p.
 func (p *P256Point) ScalarMult(q *P256Point, scalar []byte) (*P256Point, error) {
+       if len(scalar) != p256ElementLength {
+               return nil, errors.New("invalid scalar length")
+       }
+       s := new(p256OrdElement)
+       p256OrdBigToLittle(s, (*[32]byte)(scalar))
+       p256OrdReduce(s)
+
        // Compute a p256Table for the base point q. The explicit NewP256Point
        // calls get inlined, letting the allocations live on the stack.
-       var table = p256Table{NewP256Point(), NewP256Point(), NewP256Point(),
+       var table = p256Table{
+               NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point(),
                NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point(),
                NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point(),
                NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point()}
        table[0].Set(q)
-       for i := 1; i < 15; i += 2 {
+       for i := 1; i < 16; i += 2 {
                table[i].Double(table[i/2])
-               table[i+1].Add(table[i], q)
+               if i+1 < 16 {
+                       table[i+1].Add(table[i], q)
+               }
        }
 
-       // Instead of doing the classic double-and-add chain, we do it with a
-       // four-bit window: we double four times, and then add [0-15]P.
-       t := NewP256Point()
-       p.Set(NewP256Point())
-       for i, byte := range scalar {
-               // No need to double on the first iteration, as p is the identity at
-               // this point, and [N]∞ = ∞.
-               if i != 0 {
-                       p.Double(p)
-                       p.Double(p)
-                       p.Double(p)
-                       p.Double(p)
-               }
+       // Start scanning the window from the most significant bits. We move by
+       // 5 bits at a time and need to finish at -1, so -1 + 5 * 51 = 254.
+       index := 254
 
-               windowValue := byte >> 4
-               table.Select(t, windowValue)
-               p.Add(p, t)
+       sel, sign := boothW5(p256OrdRsh(s, index))
+       // sign is always zero because the boothW5 input here is at
+       // most two bits long, so the top bit is never set.
+       _ = sign
+
+       table.Select(p, sel)
+
+       t := NewP256Point()
+       for index >= 4 {
+               index -= 5
 
                p.Double(p)
                p.Double(p)
                p.Double(p)
                p.Double(p)
+               p.Double(p)
+
+               if index >= 0 {
+                       sel, sign = boothW5(p256OrdRsh(s, index) & 0b111111)
+               } else {
+                       // Booth encoding considers a virtual zero bit at index -1,
+                       // so we shift left the least significant limb.
+                       wvalue := (s[0] << 1) & 0b111111
+                       sel, sign = boothW5(wvalue)
+               }
 
-               windowValue = byte & 0b1111
-               table.Select(t, windowValue)
+               table.Select(t, sel)
+               t.Negate(sign)
                p.Add(p, t)
        }
 
        return p, nil
 }
 
-var p256GeneratorTable *[p256ElementLength * 2]p256Table
+// TODO
+func (p *P256Point) Negate(cond int) *P256Point {
+       negY := new(fiat.P256Element)
+       negY.Sub(negY, p.y)
+       p.y.Select(negY, p.y, cond)
+       return p
+}
+
+type p256TableFive [32]*P256Point
+
+func (table *p256TableFive) Select(p *P256Point, n uint8) {
+       if n > 32 {
+               panic("nistec: internal error: p256TableFive called with out-of-bounds value")
+       }
+       p.Set(NewP256Point())
+       for i := uint8(1); i <= 32; i++ {
+               cond := subtle.ConstantTimeByteEq(i, n)
+               p.Select(table[i-1], p, cond)
+       }
+}
+
+var _p256GeneratorTable *[43]p256TableFive
 var p256GeneratorTableOnce sync.Once
 
-// generatorTable returns a sequence of p256Tables. The first table contains
+// p256GeneratorTable returns a sequence of p256Tables. The first table contains
 // multiples of G. Each successive table is the previous table doubled four
 // times.
-func (p *P256Point) generatorTable() *[p256ElementLength * 2]p256Table {
+func p256GeneratorTable() *[43]p256TableFive {
        p256GeneratorTableOnce.Do(func() {
-               p256GeneratorTable = new([p256ElementLength * 2]p256Table)
+               _p256GeneratorTable = new([43]p256TableFive)
                base := NewP256Point().SetGenerator()
-               for i := 0; i < p256ElementLength*2; i++ {
-                       p256GeneratorTable[i][0] = NewP256Point().Set(base)
-                       for j := 1; j < 15; j++ {
-                               p256GeneratorTable[i][j] = NewP256Point().Add(p256GeneratorTable[i][j-1], base)
+               for i := 0; i < 43; i++ {
+                       _p256GeneratorTable[i][0] = NewP256Point().Set(base)
+                       for j := 1; j < 32; j++ {
+                               _p256GeneratorTable[i][j] = NewP256Point().Add(_p256GeneratorTable[i][j-1], base)
                        }
                        base.Double(base)
                        base.Double(base)
                        base.Double(base)
                        base.Double(base)
+                       base.Double(base)
+                       base.Double(base)
                }
        })
-       return p256GeneratorTable
+       return _p256GeneratorTable
+}
+
+func boothW6(in uint64) (uint8, int) {
+       s := ^((in >> 6) - 1)
+       d := (1 << 7) - in - 1
+       d = (d & s) | (in & (^s))
+       d = (d >> 1) + (d & 1)
+       return uint8(d), int(s & 1)
 }
 
 // ScalarBaseMult sets p = scalar * B, where B is the canonical generator, and
@@ -415,27 +518,40 @@ func (p *P256Point) ScalarBaseMult(scalar []byte) (*P256Point, error) {
        if len(scalar) != p256ElementLength {
                return nil, errors.New("invalid scalar length")
        }
-       tables := p.generatorTable()
-
-       // This is also a scalar multiplication with a four-bit window like in
-       // ScalarMult, but in this case the doublings are precomputed. The value
-       // [windowValue]G added at iteration k would normally get doubled
-       // (totIterations-k)×4 times, but with a larger precomputation we can
-       // instead add [2^((totIterations-k)×4)][windowValue]G and avoid the
-       // doublings between iterations.
+       s := new(p256OrdElement)
+       p256OrdBigToLittle(s, (*[32]byte)(scalar))
+       p256OrdReduce(s)
+       tables := p256GeneratorTable()
+
+       // Start scanning the window from the most significant bits. We move by
+       // 6 bits at a time and need to finish at -1, so -1 + 6 * 42 = 251.
+       index := 251
+
+       sel, sign := boothW6(p256OrdRsh(s, index))
+       // sign is always zero because the boothW6 input here is at
+       // most five bits long, so the top bit is never set.
+       _ = sign
+
+       table := &tables[(index+1)/6]
+       table.Select(p, sel)
+
        t := NewP256Point()
-       p.Set(NewP256Point())
-       tableIndex := len(tables) - 1
-       for _, byte := range scalar {
-               windowValue := byte >> 4
-               tables[tableIndex].Select(t, windowValue)
-               p.Add(p, t)
-               tableIndex--
+       for index >= 5 {
+               index -= 6
+
+               if index >= 0 {
+                       sel, sign = boothW6(p256OrdRsh(s, index) & 0b1111111)
+               } else {
+                       // Booth encoding considers a virtual zero bit at index -1,
+                       // so we shift left the least significant limb.
+                       wvalue := (s[0] << 1) & 0b1111111
+                       sel, sign = boothW6(wvalue)
+               }
 
-               windowValue = byte & 0b1111
-               tables[tableIndex].Select(t, windowValue)
+               table := &tables[(index+1)/6]
+               table.Select(t, sel)
+               t.Negate(sign)
                p.Add(p, t)
-               tableIndex--
        }
 
        return p, nil