From ea93e6885847b50bf4e6d3f263843f9c4e8d15f8 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Sat, 1 May 2021 01:28:16 -0400 Subject: [PATCH] crypto/elliptic: make P-521 scalar multiplication constant time MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Like for P-224, we do the constant time selects to hide the point-at-infinity special cases of addition, but not the P = Q case, which presumably doesn't happen in normal operations. Runtime increases by about 50%, as expected, since on average we were able to skip half the additions, and the additions reasonably amounted to half the runtime. Still, the Fiat code is so much faster than big.Int that we're still more than three time faster overall than pre-CL 315271. name old time/op new time/op delta pkg:crypto/elliptic goos:darwin goarch:arm64 ScalarBaseMult/P521-8 4.18ms ± 3% 1.35ms ± 1% -67.64% (p=0.000 n=10+10) ScalarMult/P521-8 4.17ms ± 2% 1.36ms ± 1% -67.45% (p=0.000 n=10+10) pkg:crypto/ecdsa goos:darwin goarch:arm64 Sign/P521-8 4.23ms ± 1% 1.44ms ± 1% -66.02% (p=0.000 n=9+10) Verify/P521-8 8.31ms ± 2% 2.73ms ± 2% -67.08% (p=0.000 n=9+9) GenerateKey/P521-8 4.15ms ± 2% 1.35ms ± 2% -67.41% (p=0.000 n=10+10) Updates #40171 Change-Id: I782f2b7f33dd60af9b3b75e46d920d4cb47f719f Reviewed-on: https://go-review.googlesource.com/c/go/+/315274 Run-TryBot: Filippo Valsorda TryBot-Result: Go Bot Trust: Filippo Valsorda Trust: Katie Hockman Reviewed-by: Katie Hockman --- src/crypto/elliptic/p521.go | 77 ++++++++++++++++++++----------------- 1 file changed, 41 insertions(+), 36 deletions(-) diff --git a/src/crypto/elliptic/p521.go b/src/crypto/elliptic/p521.go index ac9de63702..ce74e0539c 100644 --- a/src/crypto/elliptic/p521.go +++ b/src/crypto/elliptic/p521.go @@ -125,18 +125,8 @@ func (curve p521Curve) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) { // addJacobian sets q = p1 + p2, and returns q. The points may overlap. func (q *p512Point) addJacobian(p1, p2 *p512Point) *p512Point { // https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#addition-add-2007-bl - if p1.z.IsZero() == 1 { - q.x.Set(p2.x) - q.y.Set(p2.y) - q.z.Set(p2.z) - return q - } - if p2.z.IsZero() == 1 { - q.x.Set(p1.x) - q.y.Set(p1.y) - q.z.Set(p1.z) - return q - } + z1IsZero := p1.z.IsZero() + z2IsZero := p2.z.IsZero() z1z1 := new(fiat.P521Element).Square(p1.z) z2z2 := new(fiat.P521Element).Square(p2.z) @@ -155,31 +145,41 @@ func (q *p512Point) addJacobian(p1, p2 *p512Point) *p512Point { s2.Mul(s2, z1z1) r := new(fiat.P521Element).Sub(s2, s1) yEqual := r.IsZero() == 1 - if xEqual && yEqual { + if xEqual && yEqual && z1IsZero == 0 && z2IsZero == 0 { return q.doubleJacobian(p1) } r.Add(r, r) v := new(fiat.P521Element).Mul(u1, i) - q.x.Set(r) - q.x.Square(q.x) - q.x.Sub(q.x, j) - q.x.Sub(q.x, v) - q.x.Sub(q.x, v) + x := new(fiat.P521Element).Set(r) + x.Square(x) + x.Sub(x, j) + x.Sub(x, v) + x.Sub(x, v) - q.y.Set(r) - v.Sub(v, q.x) - q.y.Mul(q.y, v) + y := new(fiat.P521Element).Set(r) + v.Sub(v, x) + y.Mul(y, v) s1.Mul(s1, j) s1.Add(s1, s1) - q.y.Sub(q.y, s1) - - q.z.Add(p1.z, p2.z) - q.z.Square(q.z) - q.z.Sub(q.z, z1z1) - q.z.Sub(q.z, z2z2) - q.z.Mul(q.z, h) - + y.Sub(y, s1) + + z := new(fiat.P521Element).Add(p1.z, p2.z) + z.Square(z) + z.Sub(z, z1z1) + z.Sub(z, z2z2) + z.Mul(z, h) + + x.Select(p2.x, x, z1IsZero) + x.Select(p1.x, x, z2IsZero) + y.Select(p2.y, y, z1IsZero) + y.Select(p1.y, y, z2IsZero) + z.Select(p2.z, z, z1IsZero) + z.Select(p1.z, z, z2IsZero) + + q.x.Set(x) + q.y.Set(y) + q.z.Set(z) return q } @@ -228,21 +228,26 @@ func (q *p512Point) doubleJacobian(p *p512Point) *p512Point { return q } -func (curve p521Curve) ScalarMult(Bx, By *big.Int, k []byte) (*big.Int, *big.Int) { +func (curve p521Curve) ScalarMult(Bx, By *big.Int, scalar []byte) (*big.Int, *big.Int) { B := curve.jacobianFromAffine(Bx, By) - p := &p512Point{ + p, t := &p512Point{ + x: new(fiat.P521Element), + y: new(fiat.P521Element), + z: new(fiat.P521Element), + }, &p512Point{ x: new(fiat.P521Element), y: new(fiat.P521Element), z: new(fiat.P521Element), } - for _, byte := range k { + for _, byte := range scalar { for bitNum := 0; bitNum < 8; bitNum++ { p.doubleJacobian(p) - if byte&0x80 == 0x80 { - p.addJacobian(B, p) - } - byte <<= 1 + bit := (byte >> (7 - bitNum)) & 1 + t.addJacobian(p, B) + p.x.Select(t.x, p.x, int(bit)) + p.y.Select(t.y, p.y, int(bit)) + p.z.Select(t.z, p.z, int(bit)) } } -- 2.50.0