From 298defcb54b88c4a5cbdf493b3b66a448fa53f0e Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Fri, 26 Jul 2024 22:35:50 +0200 Subject: [PATCH] crypto/internal/nistec: use Booth multiplication in purego P-256 Brings ScalarMult from 71 adds/op + 259 doubles/op to 58 adds/op + 263 doubles/op and ScalarBaseMult from 64 adds/op to 42 adds/op, matching the assembly scalar multiplication algorithm. Change-Id: I6603b52d1c3b2c25ace471bd36044149f6e8cfab Reviewed-on: https://go-review.googlesource.com/c/go/+/627937 Reviewed-by: Daniel McCarney LUCI-TryBot-Result: Go LUCI Reviewed-by: Russ Cox Auto-Submit: Filippo Valsorda Reviewed-by: Dmitri Shuralyov --- src/crypto/internal/nistec/p256.go | 228 ++++++++++++++++++++++------- 1 file changed, 172 insertions(+), 56 deletions(-) diff --git a/src/crypto/internal/nistec/p256.go b/src/crypto/internal/nistec/p256.go index f2dfbbb1ee..16a43a5ced 100644 --- a/src/crypto/internal/nistec/p256.go +++ b/src/crypto/internal/nistec/p256.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Code generated by generate.go. DO NOT EDIT. - //go:build (!amd64 && !arm64 && !ppc64le && !s390x) || purego package nistec @@ -12,6 +10,8 @@ import ( "crypto/internal/nistec/fiat" "crypto/subtle" "errors" + "internal/byteorder" + "math/bits" "sync" ) @@ -322,91 +322,194 @@ func (q *P256Point) Select(p1, p2 *P256Point, cond int) *P256Point { return q } -// A p256Table holds the first 15 multiples of a point at offset -1, so [1]P -// is at table[0], [15]P is at table[14], and [0]P is implicitly the identity +// p256OrdElement is a P-256 scalar field element in [0, ord(G)-1] in the +// Montgomery domain (with R 2²⁵⁶) as four uint64 limbs in little-endian order. +type p256OrdElement [4]uint64 + +// p256OrdReduce ensures s is in the range [0, ord(G)-1]. 
+func p256OrdReduce(s *p256OrdElement) { + // Since 2 * ord(G) > 2²⁵⁶, we can just conditionally subtract ord(G), + // keeping the result if it doesn't underflow. + t0, b := bits.Sub64(s[0], 0xf3b9cac2fc632551, 0) + t1, b := bits.Sub64(s[1], 0xbce6faada7179e84, b) + t2, b := bits.Sub64(s[2], 0xffffffffffffffff, b) + t3, b := bits.Sub64(s[3], 0xffffffff00000000, b) + tMask := b - 1 // zero if subtraction underflowed + s[0] ^= (t0 ^ s[0]) & tMask + s[1] ^= (t1 ^ s[1]) & tMask + s[2] ^= (t2 ^ s[2]) & tMask + s[3] ^= (t3 ^ s[3]) & tMask +} + +func p256OrdLittleToBig(b *[32]byte, l *p256OrdElement) { + byteorder.BePutUint64(b[24:], l[0]) + byteorder.BePutUint64(b[16:], l[1]) + byteorder.BePutUint64(b[8:], l[2]) + byteorder.BePutUint64(b[:], l[3]) +} + +func p256OrdBigToLittle(l *p256OrdElement, b *[32]byte) { + l[0] = byteorder.BeUint64(b[24:]) + l[1] = byteorder.BeUint64(b[16:]) + l[2] = byteorder.BeUint64(b[8:]) + l[3] = byteorder.BeUint64(b[:]) +} + +// p256OrdRsh returns the 64 least significant bits of x >> n. n must be lower +// than 256. The value of n leaks through timing side-channels. +func p256OrdRsh(x *p256OrdElement, n int) uint64 { + i := n / 64 + n = n % 64 + res := x[i] >> n + // Shift in the more significant limb, if present. + if i := i + 1; i < len(x) { + res |= x[i] << (64 - n) + } + return res +} + +// A p256Table holds the first 16 multiples of a point at offset -1, so [1]P +// is at table[0], [16]P is at table[15], and [0]P is implicitly the identity // point. -type p256Table [15]*P256Point +type p256Table [16]*P256Point // Select selects the n-th multiple of the table base point into p. It works in -// constant time by iterating over every entry of the table. n must be in [0, 15]. +// constant time by iterating over every entry of the table. n must be in [0, 16]. +// If n is 0, p is implicitly set to the identity point. 
func (table *p256Table) Select(p *P256Point, n uint8) { - if n >= 16 { + if n > 16 { panic("nistec: internal error: p256Table called with out-of-bounds value") } p.Set(NewP256Point()) - for i := uint8(1); i < 16; i++ { + for i := uint8(1); i <= 16; i++ { cond := subtle.ConstantTimeByteEq(i, n) p.Select(table[i-1], p, cond) } } +func boothW5(in uint64) (uint8, int) { + s := ^((in >> 5) - 1) + d := (1 << 6) - in - 1 + d = (d & s) | (in & (^s)) + d = (d >> 1) + (d & 1) + return uint8(d), int(s & 1) +} + // ScalarMult sets p = scalar * q, and returns p. func (p *P256Point) ScalarMult(q *P256Point, scalar []byte) (*P256Point, error) { + if len(scalar) != p256ElementLength { + return nil, errors.New("invalid scalar length") + } + s := new(p256OrdElement) + p256OrdBigToLittle(s, (*[32]byte)(scalar)) + p256OrdReduce(s) + // Compute a p256Table for the base point q. The explicit NewP256Point // calls get inlined, letting the allocations live on the stack. - var table = p256Table{NewP256Point(), NewP256Point(), NewP256Point(), + var table = p256Table{ + NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point(), NewP256Point()} table[0].Set(q) - for i := 1; i < 15; i += 2 { + for i := 1; i < 16; i += 2 { table[i].Double(table[i/2]) - table[i+1].Add(table[i], q) + if i+1 < 16 { + table[i+1].Add(table[i], q) + } } - // Instead of doing the classic double-and-add chain, we do it with a - // four-bit window: we double four times, and then add [0-15]P. - t := NewP256Point() - p.Set(NewP256Point()) - for i, byte := range scalar { - // No need to double on the first iteration, as p is the identity at - // this point, and [N]∞ = ∞. - if i != 0 { - p.Double(p) - p.Double(p) - p.Double(p) - p.Double(p) - } + // Start scanning the window from the most significant bits. 
We move by + // 5 bits at a time and need to finish at -1, so -1 + 5 * 51 = 254. + index := 254 - windowValue := byte >> 4 - table.Select(t, windowValue) - p.Add(p, t) + sel, sign := boothW5(p256OrdRsh(s, index)) + // sign is always zero because the boothW5 input here is at + // most two bits long, so the top bit is never set. + _ = sign + + table.Select(p, sel) + + t := NewP256Point() + for index >= 4 { + index -= 5 p.Double(p) p.Double(p) p.Double(p) p.Double(p) + p.Double(p) + + if index >= 0 { + sel, sign = boothW5(p256OrdRsh(s, index) & 0b111111) + } else { + // Booth encoding considers a virtual zero bit at index -1, + // so we shift left the least significant limb. + wvalue := (s[0] << 1) & 0b111111 + sel, sign = boothW5(wvalue) + } - windowValue = byte & 0b1111 - table.Select(t, windowValue) + table.Select(t, sel) + t.Negate(sign) p.Add(p, t) } return p, nil } -var p256GeneratorTable *[p256ElementLength * 2]p256Table +// Negate sets p to -p if cond is 1, and leaves p unchanged if cond is 0. +func (p *P256Point) Negate(cond int) *P256Point { + negY := new(fiat.P256Element) + negY.Sub(negY, p.y) + p.y.Select(negY, p.y, cond) + return p +} + +type p256TableFive [32]*P256Point + +func (table *p256TableFive) Select(p *P256Point, n uint8) { + if n > 32 { + panic("nistec: internal error: p256TableFive called with out-of-bounds value") + } + p.Set(NewP256Point()) + for i := uint8(1); i <= 32; i++ { + cond := subtle.ConstantTimeByteEq(i, n) + p.Select(table[i-1], p, cond) + } +} + +var _p256GeneratorTable *[43]p256TableFive var p256GeneratorTableOnce sync.Once -// generatorTable returns a sequence of p256Tables. The first table contains +// p256GeneratorTable returns a sequence of p256Tables. The first table contains // multiples of G. Each successive table is the previous table doubled four // times. 
-func (p *P256Point) generatorTable() *[p256ElementLength * 2]p256Table { +func p256GeneratorTable() *[43]p256TableFive { p256GeneratorTableOnce.Do(func() { - p256GeneratorTable = new([p256ElementLength * 2]p256Table) + _p256GeneratorTable = new([43]p256TableFive) base := NewP256Point().SetGenerator() - for i := 0; i < p256ElementLength*2; i++ { - p256GeneratorTable[i][0] = NewP256Point().Set(base) - for j := 1; j < 15; j++ { - p256GeneratorTable[i][j] = NewP256Point().Add(p256GeneratorTable[i][j-1], base) + for i := 0; i < 43; i++ { + _p256GeneratorTable[i][0] = NewP256Point().Set(base) + for j := 1; j < 32; j++ { + _p256GeneratorTable[i][j] = NewP256Point().Add(_p256GeneratorTable[i][j-1], base) } base.Double(base) base.Double(base) base.Double(base) base.Double(base) + base.Double(base) + base.Double(base) } }) - return p256GeneratorTable + return _p256GeneratorTable +} + +func boothW6(in uint64) (uint8, int) { + s := ^((in >> 6) - 1) + d := (1 << 7) - in - 1 + d = (d & s) | (in & (^s)) + d = (d >> 1) + (d & 1) + return uint8(d), int(s & 1) } // ScalarBaseMult sets p = scalar * B, where B is the canonical generator, and @@ -415,27 +518,40 @@ func (p *P256Point) ScalarBaseMult(scalar []byte) (*P256Point, error) { if len(scalar) != p256ElementLength { return nil, errors.New("invalid scalar length") } - tables := p.generatorTable() - - // This is also a scalar multiplication with a four-bit window like in - // ScalarMult, but in this case the doublings are precomputed. The value - // [windowValue]G added at iteration k would normally get doubled - // (totIterations-k)×4 times, but with a larger precomputation we can - // instead add [2^((totIterations-k)×4)][windowValue]G and avoid the - // doublings between iterations. + s := new(p256OrdElement) + p256OrdBigToLittle(s, (*[32]byte)(scalar)) + p256OrdReduce(s) + tables := p256GeneratorTable() + + // Start scanning the window from the most significant bits. 
We move by + // 6 bits at a time and need to finish at -1, so -1 + 6 * 42 = 251. + index := 251 + + sel, sign := boothW6(p256OrdRsh(s, index)) + // sign is always zero because the boothW6 input here is at + // most five bits long, so the top bit is never set. + _ = sign + + table := &tables[(index+1)/6] + table.Select(p, sel) + t := NewP256Point() - p.Set(NewP256Point()) - tableIndex := len(tables) - 1 - for _, byte := range scalar { - windowValue := byte >> 4 - tables[tableIndex].Select(t, windowValue) - p.Add(p, t) - tableIndex-- + for index >= 5 { + index -= 6 + + if index >= 0 { + sel, sign = boothW6(p256OrdRsh(s, index) & 0b1111111) + } else { + // Booth encoding considers a virtual zero bit at index -1, + // so we shift left the least significant limb. + wvalue := (s[0] << 1) & 0b1111111 + sel, sign = boothW6(wvalue) + } - windowValue = byte & 0b1111 - tables[tableIndex].Select(t, windowValue) + table := &tables[(index+1)/6] + table.Select(t, sel) + t.Negate(sign) p.Add(p, t) - tableIndex-- } return p, nil -- 2.48.1