From: Filippo Valsorda Date: Sat, 1 May 2021 05:28:16 +0000 (-0400) Subject: crypto/elliptic: make P-521 scalar multiplication constant time X-Git-Tag: go1.17beta1~194 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=ea93e6885847b50bf4e6d3f263843f9c4e8d15f8;p=gostls13.git crypto/elliptic: make P-521 scalar multiplication constant time Like for P-224, we do the constant time selects to hide the point-at-infinity special cases of addition, but not the P = Q case, which presumably doesn't happen in normal operations. Runtime increases by about 50%, as expected, since on average we were able to skip half the additions, and the additions reasonably amounted to half the runtime. Still, the Fiat code is so much faster than big.Int that we're still more than three time faster overall than pre-CL 315271. name old time/op new time/op delta pkg:crypto/elliptic goos:darwin goarch:arm64 ScalarBaseMult/P521-8 4.18ms ± 3% 1.35ms ± 1% -67.64% (p=0.000 n=10+10) ScalarMult/P521-8 4.17ms ± 2% 1.36ms ± 1% -67.45% (p=0.000 n=10+10) pkg:crypto/ecdsa goos:darwin goarch:arm64 Sign/P521-8 4.23ms ± 1% 1.44ms ± 1% -66.02% (p=0.000 n=9+10) Verify/P521-8 8.31ms ± 2% 2.73ms ± 2% -67.08% (p=0.000 n=9+9) GenerateKey/P521-8 4.15ms ± 2% 1.35ms ± 2% -67.41% (p=0.000 n=10+10) Updates #40171 Change-Id: I782f2b7f33dd60af9b3b75e46d920d4cb47f719f Reviewed-on: https://go-review.googlesource.com/c/go/+/315274 Run-TryBot: Filippo Valsorda TryBot-Result: Go Bot Trust: Filippo Valsorda Trust: Katie Hockman Reviewed-by: Katie Hockman --- diff --git a/src/crypto/elliptic/p521.go b/src/crypto/elliptic/p521.go index ac9de63702..ce74e0539c 100644 --- a/src/crypto/elliptic/p521.go +++ b/src/crypto/elliptic/p521.go @@ -125,18 +125,8 @@ func (curve p521Curve) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) { // addJacobian sets q = p1 + p2, and returns q. The points may overlap. func (q *p512Point) addJacobian(p1, p2 *p512Point) *p512Point { // https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#addition-add-2007-bl - if p1.z.IsZero() == 1 { - q.x.Set(p2.x) - q.y.Set(p2.y) - q.z.Set(p2.z) - return q - } - if p2.z.IsZero() == 1 { - q.x.Set(p1.x) - q.y.Set(p1.y) - q.z.Set(p1.z) - return q - } + z1IsZero := p1.z.IsZero() + z2IsZero := p2.z.IsZero() z1z1 := new(fiat.P521Element).Square(p1.z) z2z2 := new(fiat.P521Element).Square(p2.z) @@ -155,31 +145,41 @@ func (q *p512Point) addJacobian(p1, p2 *p512Point) *p512Point { s2.Mul(s2, z1z1) r := new(fiat.P521Element).Sub(s2, s1) yEqual := r.IsZero() == 1 - if xEqual && yEqual { + if xEqual && yEqual && z1IsZero == 0 && z2IsZero == 0 { return q.doubleJacobian(p1) } r.Add(r, r) v := new(fiat.P521Element).Mul(u1, i) - q.x.Set(r) - q.x.Square(q.x) - q.x.Sub(q.x, j) - q.x.Sub(q.x, v) - q.x.Sub(q.x, v) + x := new(fiat.P521Element).Set(r) + x.Square(x) + x.Sub(x, j) + x.Sub(x, v) + x.Sub(x, v) - q.y.Set(r) - v.Sub(v, q.x) - q.y.Mul(q.y, v) + y := new(fiat.P521Element).Set(r) + v.Sub(v, x) + y.Mul(y, v) s1.Mul(s1, j) s1.Add(s1, s1) - q.y.Sub(q.y, s1) - - q.z.Add(p1.z, p2.z) - q.z.Square(q.z) - q.z.Sub(q.z, z1z1) - q.z.Sub(q.z, z2z2) - q.z.Mul(q.z, h) - + y.Sub(y, s1) + + z := new(fiat.P521Element).Add(p1.z, p2.z) + z.Square(z) + z.Sub(z, z1z1) + z.Sub(z, z2z2) + z.Mul(z, h) + + x.Select(p2.x, x, z1IsZero) + x.Select(p1.x, x, z2IsZero) + y.Select(p2.y, y, z1IsZero) + y.Select(p1.y, y, z2IsZero) + z.Select(p2.z, z, z1IsZero) + z.Select(p1.z, z, z2IsZero) + + q.x.Set(x) + q.y.Set(y) + q.z.Set(z) return q } @@ -228,21 +228,26 @@ func (q *p512Point) doubleJacobian(p *p512Point) *p512Point { return q } -func (curve p521Curve) ScalarMult(Bx, By *big.Int, k []byte) (*big.Int, *big.Int) { +func (curve p521Curve) ScalarMult(Bx, By *big.Int, scalar []byte) (*big.Int, *big.Int) { B := curve.jacobianFromAffine(Bx, By) - p := &p512Point{ + p, t := &p512Point{ + x: new(fiat.P521Element), + y: new(fiat.P521Element), + z: new(fiat.P521Element), + }, &p512Point{ x: new(fiat.P521Element), y: new(fiat.P521Element), z: new(fiat.P521Element), } - for _, byte := range k { + for _, byte := range scalar { for bitNum := 0; bitNum < 8; bitNum++ { p.doubleJacobian(p) - if byte&0x80 == 0x80 { - p.addJacobian(B, p) - } - byte <<= 1 + bit := (byte >> (7 - bitNum)) & 1 + t.addJacobian(p, B) + p.x.Select(t.x, p.x, int(bit)) + p.y.Select(t.y, p.y, int(bit)) + p.z.Select(t.z, p.z, int(bit)) } }