From: Filippo Valsorda Date: Sat, 25 Feb 2023 18:09:11 +0000 (+0100) Subject: crypto/internal/bigmod: switch to saturated limbs X-Git-Tag: go1.21rc1~262 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=7d96475daa09a89840c55dcc38335589dafccb33;p=gostls13.git crypto/internal/bigmod: switch to saturated limbs Turns out that unsaturated limbs being more performant for Montgomery multiplication was true in portable C89, but is now a misconception. With add-with-carry instructions, it's possible to run the carry chain across the limbs, instead of needing the limb-by-limb product to fit in two words. Switch to saturated limbs, and import the same Montgomery loop as math/big, along with its assembly for some architectures. Since here we know the sizes we care about, we can drop most of the assembly scaffolding. For amd64, ported to avo, too. We recover all the Go 1.20 performance loss on private key operations on both Intel Xeon and AMD EPYC, with even a 10% improvement over Go 1.19 (which used variable-time math/big) for some operations. goos: linux goarch: amd64 pkg: crypto/rsa cpu: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz │ go1.19.txt │ go1.20.txt │ new.txt │ │ sec/op │ sec/op vs base │ sec/op vs base │ DecryptPKCS1v15/2048-4 1.175m ± 0% 1.515m ± 0% +28.95% 1.132m ± 0% -3.59% DecryptPKCS1v15/3072-4 3.428m ± 1% 4.516m ± 0% +31.75% 3.198m ± 0% -6.69% DecryptPKCS1v15/4096-4 7.405m ± 0% 10.092m ± 0% +36.29% 6.446m ± 0% -12.95% EncryptPKCS1v15/2048-4 7.426µ ± 0% 170.829µ ± 0% +2200.57% 131.874µ ± 0% +1675.97% DecryptOAEP/2048-4 1.175m ± 0% 1.524m ± 0% +29.68% 1.137m ± 0% -3.26% EncryptOAEP/2048-4 9.609µ ± 0% 173.008µ ± 0% +1700.48% 132.344µ ± 0% +1277.29% SignPKCS1v15/2048-4 1.181m ± 0% 1.563m ± 0% +32.34% 1.177m ± 0% -0.37% VerifyPKCS1v15/2048-4 6.452µ ± 0% 170.092µ ± 0% +2536.06% 131.225µ ± 0% +1933.70% SignPSS/2048-4 1.184m ± 0% 1.574m ± 0% +32.88% 1.175m ± 0% -0.84% VerifyPSS/2048-4 9.151µ ± 1% 172.909µ ± 0% +1789.50% 132.391µ ± 0% +1346.74% │ go1.19.txt │ go1.20.txt │ new.txt │ │ B/op │ B/op vs base │ B/op vs base │ DecryptPKCS1v15/2048-4 24266.5 ± 0% 640.0 ± 0% -97.36% 640.0 ± 0% -97.36% DecryptPKCS1v15/3072-4 45.465Ki ± 0% 3.375Ki ± 0% -92.58% 4.688Ki ± 0% -89.69% DecryptPKCS1v15/4096-4 61.080Ki ± 0% 4.625Ki ± 0% -92.43% 6.250Ki ± 0% -89.77% EncryptPKCS1v15/2048-4 3.138Ki ± 0% 1.146Ki ± 0% -63.49% 1.082Ki ± 0% -65.52% DecryptOAEP/2048-4 24500.0 ± 0% 872.0 ± 0% -96.44% 872.0 ± 0% -96.44% EncryptOAEP/2048-4 3.610Ki ± 0% 1.371Ki ± 0% -62.02% 1.308Ki ± 0% -63.78% SignPKCS1v15/2048-4 26933.0 ± 0% 896.0 ± 0% -96.67% 896.0 ± 0% -96.67% VerifyPKCS1v15/2048-4 3209.0 ± 0% 912.0 ± 0% -71.58% 848.0 ± 0% -73.57% SignPSS/2048-4 26.940Ki ± 0% 1.266Ki ± 0% -95.30% 1.266Ki ± 0% -95.30% VerifyPSS/2048-4 3.337Ki ± 0% 1.094Ki ± 0% -67.22% 1.031Ki ± 0% -69.10% │ go1.19.txt │ go1.20.txt │ new.txt │ │ allocs/op │ allocs/op vs base │ allocs/op vs base │ DecryptPKCS1v15/2048-4 97.000 ± 0% 4.000 ± 0% -95.88% 4.000 ± 0% -95.88% DecryptPKCS1v15/3072-4 107.00 ± 0% 10.00 ± 0% -90.65% 12.00 ± 0% -88.79% DecryptPKCS1v15/4096-4 113.00 ± 0% 10.00 ± 0% -91.15% 12.00 ± 0% -89.38% EncryptPKCS1v15/2048-4 7.000 ± 0% 7.000 ± 0% ~ 7.000 ± 0% ~ DecryptOAEP/2048-4 103.00 ± 0% 10.00 ± 0% -90.29% 10.00 ± 0% -90.29% EncryptOAEP/2048-4 14.00 ± 0% 13.00 ± 0% -7.14% 13.00 ± 0% -7.14% SignPKCS1v15/2048-4 102.000 ± 0% 5.000 ± 0% -95.10% 5.000 ± 0% -95.10% VerifyPKCS1v15/2048-4 7.000 ± 0% 6.000 ± 0% -14.29% 6.000 ± 0% -14.29% SignPSS/2048-4 108.00 ± 0% 10.00 ± 0% -90.74% 10.00 ± 0% -90.74% VerifyPSS/2048-4 12.00 ± 0% 11.00 ± 0% -8.33% 11.00 ± 0% -8.33% goos: linux goarch: amd64 pkg: crypto/rsa cpu: AMD EPYC 7R13 Processor │ go1.19a.txt │ go1.20a.txt │ newa.txt │ │ sec/op │ sec/op vs base │ sec/op vs base │ DecryptPKCS1v15/2048-4 970.0µ ± 0% 1667.6µ ± 0% +71.92% 951.6µ ± 0% -1.90% DecryptPKCS1v15/3072-4 2.949m ± 0% 5.124m ± 0% +73.75% 2.675m ± 0% -9.29% DecryptPKCS1v15/4096-4 6.350m ± 0% 11.660m ± 0% +83.62% 5.746m ± 0% -9.51% EncryptPKCS1v15/2048-4 6.605µ ± 1% 183.807µ ± 0% +2683.05% 123.720µ ± 0% +1773.27% DecryptOAEP/2048-4 973.8µ ± 0% 1670.8µ ± 0% +71.57% 951.8µ ± 0% -2.27% EncryptOAEP/2048-4 8.444µ ± 1% 185.889µ ± 0% +2101.56% 124.142µ ± 0% +1370.27% SignPKCS1v15/2048-4 976.8µ ± 0% 1725.5µ ± 0% +76.65% 979.6µ ± 0% +0.28% VerifyPKCS1v15/2048-4 5.713µ ± 0% 182.983µ ± 0% +3103.19% 122.737µ ± 0% +2048.56% SignPSS/2048-4 980.3µ ± 0% 1729.5µ ± 0% +76.42% 985.7µ ± 3% +0.55% VerifyPSS/2048-4 8.168µ ± 1% 185.312µ ± 0% +2168.76% 123.772µ ± 0% +1415.33% Fixes #59463 Fixes #59442 Updates #57752 Change-Id: I311a9c1f4f5288e47e53ca14f615a443f3132734 Reviewed-on: https://go-review.googlesource.com/c/go/+/471259 Reviewed-by: Matthew Dempsky Run-TryBot: Filippo Valsorda Auto-Submit: Filippo Valsorda Reviewed-by: Roland Shoemaker TryBot-Result: Gopher Robot --- diff --git a/src/crypto/internal/bigmod/_asm/nat_amd64_asm.go b/src/crypto/internal/bigmod/_asm/nat_amd64_asm.go index 5690f04d1e..bf64565d5c 100644 --- a/src/crypto/internal/bigmod/_asm/nat_amd64_asm.go +++ b/src/crypto/internal/bigmod/_asm/nat_amd64_asm.go @@ -1,131 +1,113 @@ -// Copyright 2022 The Go Authors. All rights reserved. +// Copyright 2023 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package main import ( + "strconv" + . "github.com/mmcloughlin/avo/build" . "github.com/mmcloughlin/avo/operand" . "github.com/mmcloughlin/avo/reg" ) -//go:generate go run . -out ../nat_amd64.s -stubs ../nat_amd64.go -pkg bigmod +//go:generate go run . -out ../nat_amd64.s -pkg bigmod func main() { Package("crypto/internal/bigmod") - ConstraintExpr("amd64,gc,!purego") - - Implement("montgomeryLoop") - Pragma("noescape") - - size := Load(Param("d").Len(), GP64()) - d := Mem{Base: Load(Param("d").Base(), GP64())} - b := Mem{Base: Load(Param("b").Base(), GP64())} - m := Mem{Base: Load(Param("m").Base(), GP64())} - m0inv := Load(Param("m0inv"), GP64()) - - overflow := zero() - i := zero() - Label("outerLoop") - - ai := Load(Param("a").Base(), GP64()) - MOVQ(Mem{Base: ai}.Idx(i, 8), ai) - - z := uint128{GP64(), GP64()} - mul64(z, b, ai) - add64(z, d) - f := GP64() - MOVQ(m0inv, f) - IMULQ(z.lo, f) - _MASK(f) - addMul64(z, m, f) - carry := shiftBy63(z) - - j := zero() - INCQ(j) - JMP(LabelRef("innerLoopCondition")) - Label("innerLoop") - - // z = d[j] + a[i] * b[j] + f * m[j] + carry - z = uint128{GP64(), GP64()} - mul64(z, b.Idx(j, 8), ai) - addMul64(z, m.Idx(j, 8), f) - add64(z, d.Idx(j, 8)) - add64(z, carry) - // d[j-1] = z_lo & _MASK - storeMasked(z.lo, d.Idx(j, 8).Offset(-8)) - // carry = z_hi<<1 | z_lo>>_W - MOVQ(shiftBy63(z), carry) - - INCQ(j) - Label("innerLoopCondition") - CMPQ(size, j) - JGT(LabelRef("innerLoop")) - - ADDQ(carry, overflow) - storeMasked(overflow, d.Idx(size, 8).Offset(-8)) - SHRQ(Imm(63), overflow) - - INCQ(i) - CMPQ(size, i) - JGT(LabelRef("outerLoop")) - - Store(overflow, ReturnIndex(0)) - RET() - Generate() -} + ConstraintExpr("!purego") -// zero zeroes a new register and returns it. -func zero() Register { - r := GP64() - XORQ(r, r) - return r -} - -// _MASK masks out the top bit of r. -func _MASK(r Register) { - BTRQ(Imm(63), r) -} - -type uint128 struct { - hi, lo GPVirtual -} + addMulVVW(1024) + addMulVVW(1536) + addMulVVW(2048) -// storeMasked stores _MASK(src) in dst. It doesn't modify src. -func storeMasked(src, dst Op) { - out := GP64() - MOVQ(src, out) - _MASK(out) - MOVQ(out, dst) -} - -// shiftBy63 returns z >> 63. It reuses z.lo. -func shiftBy63(z uint128) Register { - SHRQ(Imm(63), z.hi, z.lo) - result := z.lo - z.hi, z.lo = nil, nil - return result -} - -// add64 sets r to r + a. -func add64(r uint128, a Op) { - ADDQ(a, r.lo) - ADCQ(Imm(0), r.hi) + Generate() } -// mul64 sets r to a * b. -func mul64(r uint128, a, b Op) { - MOVQ(a, RAX) - MULQ(b) // RDX, RAX = RAX * b - MOVQ(RAX, r.lo) - MOVQ(RDX, r.hi) -} +func addMulVVW(bits int) { + if bits%64 != 0 { + panic("bit size unsupported") + } + + Implement("addMulVVW" + strconv.Itoa(bits)) + + CMPB(Mem{Symbol: Symbol{Name: "·supportADX"}, Base: StaticBase}, Imm(1)) + JEQ(LabelRef("adx")) + + z := Mem{Base: Load(Param("z"), GP64())} + x := Mem{Base: Load(Param("x"), GP64())} + y := Load(Param("y"), GP64()) + + carry := GP64() + XORQ(carry, carry) // zero out carry + + for i := 0; i < bits/64; i++ { + Comment("Iteration " + strconv.Itoa(i)) + hi, lo := RDX, RAX // implicit MULQ inputs and outputs + MOVQ(x.Offset(i*8), lo) + MULQ(y) + ADDQ(z.Offset(i*8), lo) + ADCQ(Imm(0), hi) + ADDQ(carry, lo) + ADCQ(Imm(0), hi) + MOVQ(hi, carry) + MOVQ(lo, z.Offset(i*8)) + } + + Store(carry, ReturnIndex(0)) + RET() -// addMul64 sets r to r + a * b. -func addMul64(r uint128, a, b Op) { - MOVQ(a, RAX) - MULQ(b) // RDX, RAX = RAX * b - ADDQ(RAX, r.lo) - ADCQ(RDX, r.hi) + Label("adx") + + // The ADX strategy implements the following function, where c1 and c2 are + // the overflow and the carry flag respectively. + // + // func addMulVVW(z, x []uint, y uint) (carry uint) { + // var c1, c2 uint + // for i := range z { + // hi, lo := bits.Mul(x[i], y) + // lo, c1 = bits.Add(lo, z[i], c1) + // z[i], c2 = bits.Add(lo, carry, c2) + // carry = hi + // } + // return carry + c1 + c2 + // } + // + // The loop is fully unrolled and the hi / carry registers are alternated + // instead of introducing a MOV. + + z = Mem{Base: Load(Param("z"), GP64())} + x = Mem{Base: Load(Param("x"), GP64())} + Load(Param("y"), RDX) // implicit source of MULXQ + + carry = GP64() + XORQ(carry, carry) // zero out carry + z0 := GP64() + XORQ(z0, z0) // unset flags and zero out z0 + + for i := 0; i < bits/64; i++ { + hi, lo := GP64(), GP64() + + Comment("Iteration " + strconv.Itoa(i)) + MULXQ(x.Offset(i*8), lo, hi) + ADCXQ(carry, lo) + ADOXQ(z.Offset(i*8), lo) + MOVQ(lo, z.Offset(i*8)) + + i++ + + Comment("Iteration " + strconv.Itoa(i)) + MULXQ(x.Offset(i*8), lo, carry) + ADCXQ(hi, lo) + ADOXQ(z.Offset(i*8), lo) + MOVQ(lo, z.Offset(i*8)) + } + + Comment("Add back carry flags and return") + ADCXQ(z0, carry) + ADOXQ(z0, carry) + + Store(carry, ReturnIndex(0)) + RET() } diff --git a/src/crypto/internal/bigmod/nat.go b/src/crypto/internal/bigmod/nat.go index 804316f504..3cad382b53 100644 --- a/src/crypto/internal/bigmod/nat.go +++ b/src/crypto/internal/bigmod/nat.go @@ -5,16 +5,17 @@ package bigmod import ( + "encoding/binary" "errors" "math/big" "math/bits" ) const ( - // _W is the number of bits we use for our limbs. - _W = bits.UintSize - 1 - // _MASK selects _W bits from a full machine word. - _MASK = (1 << _W) - 1 + // _W is the size in bits of our limbs. + _W = bits.UintSize + // _S is the size in bytes of our limbs. + _S = _W / 8 ) // choice represents a constant-time boolean. The value of choice is always @@ -27,15 +28,8 @@ func not(c choice) choice { return 1 ^ c } const yes = choice(1) const no = choice(0) -// ctSelect returns x if on == 1, and y if on == 0. The execution time of this -// function does not depend on its inputs. If on is any value besides 1 or 0, -// the result is undefined. -func ctSelect(on choice, x, y uint) uint { - // When on == 1, mask is 0b111..., otherwise mask is 0b000... - mask := -uint(on) - // When mask is all zeros, we just have y, otherwise, y cancels with itself. - return y ^ (mask & (y ^ x)) -} +// ctMask is all 1s if on is yes, and all 0s otherwise. +func ctMask(on choice) uint { return -uint(on) } // ctEq returns 1 if x == y, and 0 otherwise. The execution time of this // function does not depend on its inputs. @@ -60,13 +54,7 @@ func ctGeq(x, y uint) choice { // Operations on this number are allowed to leak this length, but will not leak // any information about the values contained in those limbs. type Nat struct { - // limbs is a little-endian representation in base 2^W with - // W = bits.UintSize - 1. The top bit is always unset between operations. - // - // The top bit is left unset to optimize Montgomery multiplication, in the - // inner loop of exponentiation. Using fully saturated limbs would leave us - // working with 129-bit numbers on 64-bit platforms, wasting a lot of space, - // and thus time. + // limbs is little-endian in base 2^W with W = bits.UintSize. limbs []uint } @@ -128,25 +116,10 @@ func (x *Nat) set(y *Nat) *Nat { // The announced length of x is set based on the actual bit size of the input, // ignoring leading zeroes. func (x *Nat) setBig(n *big.Int) *Nat { - requiredLimbs := (n.BitLen() + _W - 1) / _W - x.reset(requiredLimbs) - - outI := 0 - shift := 0 limbs := n.Bits() + x.reset(len(limbs)) for i := range limbs { - xi := uint(limbs[i]) - x.limbs[outI] |= (xi << shift) & _MASK - outI++ - if outI == requiredLimbs { - return x - } - x.limbs[outI] = xi >> (_W - shift) - shift++ // this assumes bits.UintSize - _W = 1 - if shift == _W { - shift = 0 - outI++ - } + x.limbs[i] = uint(limbs[i]) } return x } @@ -156,24 +129,20 @@ func (x *Nat) setBig(n *big.Int) *Nat { // // x must have the same size as m and it must be reduced modulo m. func (x *Nat) Bytes(m *Modulus) []byte { - bytes := make([]byte, m.Size()) - shift := 0 - outI := len(bytes) - 1 + i := m.Size() + bytes := make([]byte, i) for _, limb := range x.limbs { - remainingBits := _W - for remainingBits >= 8 { - bytes[outI] |= byte(limb) << shift - consumed := 8 - shift - limb >>= consumed - remainingBits -= consumed - shift = 0 - outI-- - if outI < 0 { - return bytes + for j := 0; j < _S; j++ { + i-- + if i < 0 { + if limb == 0 { + break + } + panic("bigmod: modulus is smaller than nat") } + bytes[i] = byte(limb) + limb >>= 8 } - bytes[outI] = byte(limb) - shift = remainingBits } return bytes } @@ -192,9 +161,9 @@ func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) { return x, nil } -// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes. SetOverflowingBytes -// returns an error if b has a longer bit length than m, but reduces overflowing -// values up to 2^⌈log2(m)⌉ - 1. +// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes. +// SetOverflowingBytes returns an error if b has a longer bit length than m, but +// reduces overflowing values up to 2^⌈log2(m)⌉ - 1. // // The output will be resized to the size of m and overwritten. func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) { @@ -203,33 +172,35 @@ func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) { } leading := _W - bitLen(x.limbs[len(x.limbs)-1]) if leading < m.leading { - return nil, errors.New("input overflows the modulus") + return nil, errors.New("input overflows the modulus size") } - x.sub(x.cmpGeq(m.nat), m.nat) + x.maybeSubtractModulus(no, m) return x, nil } +// bigEndianUint returns the contents of buf interpreted as a +// big-endian encoded uint value. +func bigEndianUint(buf []byte) uint { + if _W == 64 { + return uint(binary.BigEndian.Uint64(buf)) + } + return uint(binary.BigEndian.Uint32(buf)) +} + func (x *Nat) setBytes(b []byte, m *Modulus) error { - outI := 0 - shift := 0 x.resetFor(m) - for i := len(b) - 1; i >= 0; i-- { - bi := b[i] - x.limbs[outI] |= uint(bi) << shift - shift += 8 - if shift >= _W { - shift -= _W - x.limbs[outI] &= _MASK - overflow := bi >> (8 - shift) - outI++ - if outI >= len(x.limbs) { - if overflow > 0 || i > 0 { - return errors.New("input overflows the modulus") - } - break - } - x.limbs[outI] = uint(overflow) - } + i, k := len(b), 0 + for k < len(x.limbs) && i >= _S { + x.limbs[k] = bigEndianUint(b[i-_S : i]) + i -= _S + k++ + } + for s := 0; s < _W && k < len(x.limbs) && i > 0; s += 8 { + x.limbs[k] |= uint(b[i-1]) << s + i-- + } + if i > 0 { + return errors.New("input overflows the modulus size") } return nil } @@ -274,7 +245,7 @@ func (x *Nat) cmpGeq(y *Nat) choice { var c uint for i := 0; i < size; i++ { - c = (xLimbs[i] - yLimbs[i] - c) >> _W + _, c = bits.Sub(xLimbs[i], yLimbs[i], c) } // If there was a carry, then subtracting y underflowed, so // x is not greater than or equal to y. @@ -290,44 +261,39 @@ func (x *Nat) assign(on choice, y *Nat) *Nat { xLimbs := x.limbs[:size] yLimbs := y.limbs[:size] + mask := ctMask(on) for i := 0; i < size; i++ { - xLimbs[i] = ctSelect(on, yLimbs[i], xLimbs[i]) + xLimbs[i] ^= mask & (xLimbs[i] ^ yLimbs[i]) } return x } -// add computes x += y if on == 1, and does nothing otherwise. It returns the -// carry of the addition regardless of on. +// add computes x += y and returns the carry. // // Both operands must have the same announced length. -func (x *Nat) add(on choice, y *Nat) (c uint) { +func (x *Nat) add(y *Nat) (c uint) { // Eliminate bounds checks in the loop. size := len(x.limbs) xLimbs := x.limbs[:size] yLimbs := y.limbs[:size] for i := 0; i < size; i++ { - res := xLimbs[i] + yLimbs[i] + c - xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i]) - c = res >> _W + xLimbs[i], c = bits.Add(xLimbs[i], yLimbs[i], c) } return } -// sub computes x -= y if on == 1, and does nothing otherwise. It returns the -// borrow of the subtraction regardless of on. +// sub computes x -= y. It returns the borrow of the subtraction. // // Both operands must have the same announced length. -func (x *Nat) sub(on choice, y *Nat) (c uint) { +func (x *Nat) sub(y *Nat) (c uint) { // Eliminate bounds checks in the loop. size := len(x.limbs) xLimbs := x.limbs[:size] yLimbs := y.limbs[:size] for i := 0; i < size; i++ { - res := xLimbs[i] - yLimbs[i] - c - xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i]) - c = res >> _W + xLimbs[i], c = bits.Sub(xLimbs[i], yLimbs[i], c) } return } @@ -371,19 +337,20 @@ func minusInverseModW(x uint) uint { // Every iteration of this loop doubles the least-significant bits of // correct inverse in y. The first three bits are already correct (1⁻¹ = 1, // 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough - // for 61 bits (and wastes only one iteration for 31 bits). + // for 64 bits (and wastes only one iteration for 32 bits). // // See https://crypto.stackexchange.com/a/47496. y := x for i := 0; i < 5; i++ { y = y * (2 - x*y) } - return (1 << _W) - (y & _MASK) + return -y } // NewModulusFromBig creates a new Modulus from a [big.Int]. // -// The Int must be odd. The number of significant bits must be leakable. +// The Int must be odd. The number of significant bits (and nothing else) is +// leaked through timing side-channels. func NewModulusFromBig(n *big.Int) *Modulus { m := &Modulus{} m.nat = NewNat().setBig(n) @@ -424,7 +391,7 @@ func (m *Modulus) Nat() *Nat { // shiftIn calculates x = x << _W + y mod m. // -// This assumes that x is already reduced mod m, and that y < 2^_W. +// This assumes that x is already reduced mod m. func (x *Nat) shiftIn(y uint, m *Modulus) *Nat { d := NewNat().resetFor(m) @@ -440,25 +407,21 @@ func (x *Nat) shiftIn(y uint, m *Modulus) *Nat { // // To do the reduction, each iteration computes both 2x + b and 2x + b - m. // The next iteration (and finally the return line) will use either result - // based on whether the subtraction underflowed. + // based on whether 2x + b overflows m. needSubtraction := no for i := _W - 1; i >= 0; i-- { carry := (y >> i) & 1 var borrow uint + mask := ctMask(needSubtraction) for i := 0; i < size; i++ { - l := ctSelect(needSubtraction, dLimbs[i], xLimbs[i]) - - res := l<<1 + carry - xLimbs[i] = res & _MASK - carry = res >> _W - - res = xLimbs[i] - mLimbs[i] - borrow - dLimbs[i] = res & _MASK - borrow = res >> _W + l := xLimbs[i] ^ (mask & (xLimbs[i] ^ dLimbs[i])) + xLimbs[i], carry = bits.Add(l, l, carry) + dLimbs[i], borrow = bits.Sub(xLimbs[i], mLimbs[i], borrow) } - // See Add for how carry (aka overflow), borrow (aka underflow), and - // needSubtraction relate. - needSubtraction = ctEq(carry, borrow) + // Like in maybeSubtractModulus, we need the subtraction if either it + // didn't underflow (meaning 2x + b > m) or if computing 2x + b + // overflowed (meaning 2x + b > 2^_W*n > m). + needSubtraction = not(choice(borrow)) | choice(carry) } return x.assign(needSubtraction, d) } @@ -494,11 +457,11 @@ func (out *Nat) Mod(x *Nat, m *Modulus) *Nat { return out } -// ExpandFor ensures out has the right size to work with operations modulo m. +// ExpandFor ensures x has the right size to work with operations modulo m. // -// The announced size of out must be smaller than or equal to that of m. -func (out *Nat) ExpandFor(m *Modulus) *Nat { - return out.expand(len(m.nat.limbs)) +// The announced size of x must be smaller than or equal to that of m. +func (x *Nat) ExpandFor(m *Modulus) *Nat { + return x.expand(len(m.nat.limbs)) } // resetFor ensures out has the right size to work with operations modulo m. @@ -508,14 +471,34 @@ func (out *Nat) resetFor(m *Modulus) *Nat { return out.reset(len(m.nat.limbs)) } +// maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes. +// +// It can be used to reduce modulo m a value up to 2m - 1, which is a common +// range for results computed by higher level operations. +// +// always is usually a carry that indicates that the operation that produced x +// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m. +// +// x and m operands must have the same announced length. +func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) { + t := NewNat().set(x) + underflow := t.sub(m.nat) + // We keep the result if x - m didn't underflow (meaning x >= m) + // or if always was set. + keep := not(choice(underflow)) | choice(always) + x.assign(keep, t) +} + // Sub computes x = x - y mod m. // // The length of both operands must be the same as the modulus. Both operands // must already be reduced modulo m. func (x *Nat) Sub(y *Nat, m *Modulus) *Nat { - underflow := x.sub(yes, y) + underflow := x.sub(y) // If the subtraction underflowed, add m. - x.add(choice(underflow), m.nat) + t := NewNat().set(x) + t.add(m.nat) + x.assign(choice(underflow), t) return x } @@ -524,34 +507,8 @@ func (x *Nat) Sub(y *Nat, m *Modulus) *Nat { // The length of both operands must be the same as the modulus. Both operands // must already be reduced modulo m. func (x *Nat) Add(y *Nat, m *Modulus) *Nat { - overflow := x.add(yes, y) - underflow := not(x.cmpGeq(m.nat)) // x < m - - // Three cases are possible: - // - // - overflow = 0, underflow = 0 - // - // In this case, addition fits in our limbs, but we can still subtract away - // m without an underflow, so we need to perform the subtraction to reduce - // our result. - // - // - overflow = 0, underflow = 1 - // - // The addition fits in our limbs, but we can't subtract m without - // underflowing. The result is already reduced. - // - // - overflow = 1, underflow = 1 - // - // The addition does not fit in our limbs, and the subtraction's borrow - // would cancel out with the addition's carry. We need to subtract m to - // reduce our result. - // - // The overflow = 1, underflow = 0 case is not possible, because y is at - // most m - 1, and if adding m - 1 overflows, then subtracting m must - // necessarily underflow. - needSubtraction := ctEq(overflow, uint(underflow)) - - x.sub(needSubtraction, m.nat) + overflow := x.add(y) + x.maybeSubtractModulus(choice(overflow), m) return x } @@ -582,65 +539,146 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat { return x.montgomeryMul(t0, t1, m) } -// montgomeryMul calculates d = a * b / R mod m, with R = 2^(_W * n) and -// n = len(m.nat.limbs), using the Montgomery Multiplication technique. +// montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and +// n = len(m.nat.limbs), also known as a Montgomery multiplication. // -// All inputs should be the same length, not aliasing d, and already -// reduced modulo m. d will be resized to the size of m and overwritten. -func (d *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat { - d.resetFor(m) - if len(a.limbs) != len(m.nat.limbs) || len(b.limbs) != len(m.nat.limbs) { - panic("bigmod: invalid montgomeryMul input") - } +// All inputs should be the same length and already reduced modulo m. +// x will be resized to the size of m and overwritten. +func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat { + n := len(m.nat.limbs) + mLimbs := m.nat.limbs[:n] + aLimbs := a.limbs[:n] + bLimbs := b.limbs[:n] + + switch n { + default: + // Attempt to use a stack-allocated backing array. + T := make([]uint, 0, preallocLimbs*2) + if cap(T) < n*2 { + T = make([]uint, 0, n*2) + } + T = T[:n*2] + + // This loop implements Word-by-Word Montgomery Multiplication, as + // described in Algorithm 4 (Fig. 3) of "Efficient Software + // Implementations of Modular Exponentiation" by Shay Gueron + // [https://eprint.iacr.org/2011/239.pdf]. + var c uint + for i := 0; i < n; i++ { + _ = T[n+i] // bounds check elimination hint + + // Step 1 (T = a × b) is computed as a large pen-and-paper column + // multiplication of two numbers with n base-2^_W digits. If we just + // wanted to produce 2n-wide T, we would do + // + // for i := 0; i < n; i++ { + // d := bLimbs[i] + // T[n+i] = addMulVVW(T[i:n+i], aLimbs, d) + // } + // + // where d is a digit of the multiplier, T[i:n+i] is the shifted + // position of the product of that digit, and T[n+i] is the final carry. + // Note that T[i] isn't modified after processing the i-th digit. + // + // Instead of running two loops, one for Step 1 and one for Steps 2–6, + // the result of Step 1 is computed during the next loop. This is + // possible because each iteration only uses T[i] in Step 2 and then + // discards it in Step 6. + d := bLimbs[i] + c1 := addMulVVW(T[i:n+i], aLimbs, d) + + // Step 6 is replaced by shifting the virtual window we operate + // over: T of the algorithm is T[i:] for us. That means that T1 in + // Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv. + Y := T[i] * m.m0inv + + // Step 4 and 5 add Y × m to T, which as mentioned above is stored + // at T[i:]. The two carries (from a × d and Y × m) are added up in + // the next word T[n+i], and the carry bit from that addition is + // brought forward to the next iteration. + c2 := addMulVVW(T[i:n+i], mLimbs, Y) + T[n+i], c = bits.Add(c1, c2, c) + } - // See https://bearssl.org/bigint.html#montgomery-reduction-and-multiplication - // for a description of the algorithm implemented mostly in montgomeryLoop. - // See Add for how overflow, underflow, and needSubtraction relate. - overflow := montgomeryLoop(d.limbs, a.limbs, b.limbs, m.nat.limbs, m.m0inv) - underflow := not(d.cmpGeq(m.nat)) // d < m - needSubtraction := ctEq(overflow, uint(underflow)) - d.sub(needSubtraction, m.nat) + // Finally for Step 7 we copy the final T window into x, and subtract m + // if necessary (which as explained in maybeSubtractModulus can be the + // case both if x >= m, or if x overflowed). + // + // The paper suggests in Section 4 that we can do an "Almost Montgomery + // Multiplication" by subtracting only in the overflow case, but the + // cost is very similar since the constant time subtraction tells us if + // x >= m as a side effect, and taking care of the broken invariant is + // highly undesirable (see https://go.dev/issue/13907). + copy(x.reset(n).limbs, T[n:]) + x.maybeSubtractModulus(choice(c), m) + + // The following specialized cases follow the exact same algorithm, but + // optimized for the sizes most used in RSA. addMulVVW is implemented in + // assembly with loop unrolling depending on the architecture and bounds + // checks are removed by the compiler thanks to the constant size. + case 1024 / _W: + const n = 1024 / _W // compiler hint + T := make([]uint, n*2) + var c uint + for i := 0; i < n; i++ { + d := bLimbs[i] + c1 := addMulVVW1024(&T[i], &aLimbs[0], d) + Y := T[i] * m.m0inv + c2 := addMulVVW1024(&T[i], &mLimbs[0], Y) + T[n+i], c = bits.Add(c1, c2, c) + } + copy(x.reset(n).limbs, T[n:]) + x.maybeSubtractModulus(choice(c), m) + + case 1536 / _W: + const n = 1536 / _W // compiler hint + T := make([]uint, n*2) + var c uint + for i := 0; i < n; i++ { + d := bLimbs[i] + c1 := addMulVVW1536(&T[i], &aLimbs[0], d) + Y := T[i] * m.m0inv + c2 := addMulVVW1536(&T[i], &mLimbs[0], Y) + T[n+i], c = bits.Add(c1, c2, c) + } + copy(x.reset(n).limbs, T[n:]) + x.maybeSubtractModulus(choice(c), m) + + case 2048 / _W: + const n = 2048 / _W // compiler hint + T := make([]uint, n*2) + var c uint + for i := 0; i < n; i++ { + d := bLimbs[i] + c1 := addMulVVW2048(&T[i], &aLimbs[0], d) + Y := T[i] * m.m0inv + c2 := addMulVVW2048(&T[i], &mLimbs[0], Y) + T[n+i], c = bits.Add(c1, c2, c) + } + copy(x.reset(n).limbs, T[n:]) + x.maybeSubtractModulus(choice(c), m) + } - return d + return x } -func montgomeryLoopGeneric(d, a, b, m []uint, m0inv uint) (overflow uint) { - // Eliminate bounds checks in the loop. - size := len(d) - a = a[:size] - b = b[:size] - m = m[:size] - - for _, ai := range a { - // This is an unrolled iteration of the loop below with j = 0. - hi, lo := bits.Mul(ai, b[0]) - z_lo, c := bits.Add(d[0], lo, 0) - f := (z_lo * m0inv) & _MASK // (d[0] + a[i] * b[0]) * m0inv - z_hi, _ := bits.Add(0, hi, c) - hi, lo = bits.Mul(f, m[0]) - z_lo, c = bits.Add(z_lo, lo, 0) - z_hi, _ = bits.Add(z_hi, hi, c) - carry := z_hi<<1 | z_lo>>_W - - for j := 1; j < size; j++ { - // z = d[j] + a[i] * b[j] + f * m[j] + carry <= 2^(2W+1) - 2^(W+1) + 2^W - hi, lo := bits.Mul(ai, b[j]) - z_lo, c := bits.Add(d[j], lo, 0) - z_hi, _ := bits.Add(0, hi, c) - hi, lo = bits.Mul(f, m[j]) - z_lo, c = bits.Add(z_lo, lo, 0) - z_hi, _ = bits.Add(z_hi, hi, c) - z_lo, c = bits.Add(z_lo, carry, 0) - z_hi, _ = bits.Add(z_hi, 0, c) - d[j-1] = z_lo & _MASK - carry = z_hi<<1 | z_lo>>_W // carry <= 2^(W+1) - 2 - } - - z := overflow + carry // z <= 2^(W+1) - 1 - d[size-1] = z & _MASK - overflow = z >> _W // overflow <= 1 +// addMulVVW multiplies the multi-word value x by the single-word value y, +// adding the result to the multi-word value z and returning the final carry. +// It can be thought of as one row of a pen-and-paper column multiplication. +func addMulVVW(z, x []uint, y uint) (carry uint) { + _ = x[len(z)-1] // bounds check elimination hint + for i := range z { + hi, lo := bits.Mul(x[i], y) + lo, c := bits.Add(lo, z[i], 0) + // We use bits.Add with zero to get an add-with-carry instruction that + // absorbs the carry from the previous bits.Add. + hi, _ = bits.Add(hi, 0, c) + lo, c = bits.Add(lo, carry, 0) + hi, _ = bits.Add(hi, 0, c) + carry = hi + z[i] = lo } - return + return carry } // Mul calculates x *= y mod m. @@ -661,7 +699,8 @@ func (x *Nat) Mul(y *Nat, m *Modulus) *Nat { func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { // We use a 4 bit window. For our RSA workload, 4 bit windows are faster // than 2 bit windows, but use an extra 12 nats worth of scratch space. - // Using bit sizes that don't divide 8 are more complex to implement. + // Using bit sizes that don't divide 8 are more complex to implement, but + // are likely to be more efficient if necessary. table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1) // newNat calls are unrolled so they are allocated on the stack. @@ -681,7 +720,8 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { t1 := NewNat().ExpandFor(m) for _, b := range e { for _, j := range []int{4, 0} { - // Square four times. + // Square four times. Optimization note: this can be implemented + // more efficiently than with generic Montgomery multiplication. t1.montgomeryMul(out, out, m) out.montgomeryMul(t1, t1, m) t1.montgomeryMul(out, out, m) diff --git a/src/crypto/internal/bigmod/nat_386.s b/src/crypto/internal/bigmod/nat_386.s new file mode 100644 index 0000000000..0637d271e8 --- /dev/null +++ b/src/crypto/internal/bigmod/nat_386.s @@ -0,0 +1,47 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !purego + +#include "textflag.h" + +// func addMulVVW1024(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1024(SB), $0-16 + MOVL $32, BX + JMP addMulVVWx(SB) + +// func addMulVVW1536(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1536(SB), $0-16 + MOVL $48, BX + JMP addMulVVWx(SB) + +// func addMulVVW2048(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW2048(SB), $0-16 + MOVL $64, BX + JMP addMulVVWx(SB) + +TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0 + MOVL z+0(FP), DI + MOVL x+4(FP), SI + MOVL y+8(FP), BP + LEAL (DI)(BX*4), DI + LEAL (SI)(BX*4), SI + NEGL BX // i = -n + MOVL $0, CX // c = 0 + JMP E6 + +L6: MOVL (SI)(BX*4), AX + MULL BP + ADDL CX, AX + ADCL $0, DX + ADDL AX, (DI)(BX*4) + ADCL $0, DX + MOVL DX, CX + ADDL $1, BX // i++ + +E6: CMPL BX, $0 // i < 0 + JL L6 + + MOVL CX, c+12(FP) + RET diff --git a/src/crypto/internal/bigmod/nat_amd64.go b/src/crypto/internal/bigmod/nat_amd64.go deleted file mode 100644 index e94778245d..0000000000 --- a/src/crypto/internal/bigmod/nat_amd64.go +++ /dev/null @@ -1,8 +0,0 @@ -// Code generated by command: go run nat_amd64_asm.go -out ../nat_amd64.s -stubs ../nat_amd64.go -pkg bigmod. DO NOT EDIT. - -//go:build amd64 && gc && !purego - -package bigmod - -//go:noescape -func montgomeryLoop(d []uint, a []uint, b []uint, m []uint, m0inv uint) uint diff --git a/src/crypto/internal/bigmod/nat_amd64.s b/src/crypto/internal/bigmod/nat_amd64.s index 12b7629984..ab94344e10 100644 --- a/src/crypto/internal/bigmod/nat_amd64.s +++ b/src/crypto/internal/bigmod/nat_amd64.s @@ -1,68 +1,1230 @@ -// Code generated by command: go run nat_amd64_asm.go -out ../nat_amd64.s -stubs ../nat_amd64.go -pkg bigmod. DO NOT EDIT. - -//go:build amd64 && gc && !purego - -// func montgomeryLoop(d []uint, a []uint, b []uint, m []uint, m0inv uint) uint -TEXT ·montgomeryLoop(SB), $8-112 - MOVQ d_len+8(FP), CX - MOVQ d_base+0(FP), BX - MOVQ b_base+48(FP), SI - MOVQ m_base+72(FP), DI - MOVQ m0inv+96(FP), R8 - XORQ R9, R9 - XORQ R10, R10 - -outerLoop: - MOVQ a_base+24(FP), R11 - MOVQ (R11)(R10*8), R11 - MOVQ (SI), AX - MULQ R11 - MOVQ AX, R13 - MOVQ DX, R12 - ADDQ (BX), R13 - ADCQ $0x00, R12 - MOVQ R8, R14 - IMULQ R13, R14 - BTRQ $0x3f, R14 - MOVQ (DI), AX - MULQ R14 - ADDQ AX, R13 - ADCQ DX, R12 - SHRQ $0x3f, R12, R13 - XORQ R12, R12 - INCQ R12 - JMP innerLoopCondition - -innerLoop: - MOVQ (SI)(R12*8), AX - MULQ R11 - MOVQ AX, BP - MOVQ DX, R15 - MOVQ (DI)(R12*8), AX - MULQ R14 - ADDQ AX, BP - ADCQ DX, R15 - ADDQ (BX)(R12*8), BP - ADCQ $0x00, R15 - ADDQ R13, BP - ADCQ $0x00, R15 - MOVQ BP, AX - BTRQ $0x3f, AX - MOVQ AX, -8(BX)(R12*8) - SHRQ $0x3f, R15, BP - MOVQ BP, R13 - INCQ R12 - -innerLoopCondition: - CMPQ CX, R12 - JGT innerLoop - ADDQ R13, R9 - MOVQ R9, AX - BTRQ $0x3f, AX - MOVQ AX, -8(BX)(CX*8) - SHRQ $0x3f, R9 - INCQ R10 - CMPQ CX, R10 - JGT outerLoop - MOVQ R9, ret+104(FP) +// Code generated by command: go run nat_amd64_asm.go -out ../nat_amd64.s -pkg bigmod. DO NOT EDIT. + +//go:build !purego + +// func addMulVVW1024(z *uint, x *uint, y uint) (c uint) +// Requires: ADX, BMI2 +TEXT ·addMulVVW1024(SB), $0-32 + CMPB ·supportADX+0(SB), $0x01 + JEQ adx + MOVQ z+0(FP), CX + MOVQ x+8(FP), BX + MOVQ y+16(FP), SI + XORQ DI, DI + + // Iteration 0 + MOVQ (BX), AX + MULQ SI + ADDQ (CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, (CX) + + // Iteration 1 + MOVQ 8(BX), AX + MULQ SI + ADDQ 8(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 8(CX) + + // Iteration 2 + MOVQ 16(BX), AX + MULQ SI + ADDQ 16(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 16(CX) + + // Iteration 3 + MOVQ 24(BX), AX + MULQ SI + ADDQ 24(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 24(CX) + + // Iteration 4 + MOVQ 32(BX), AX + MULQ SI + ADDQ 32(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 32(CX) + + // Iteration 5 + MOVQ 40(BX), AX + MULQ SI + ADDQ 40(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 40(CX) + + // Iteration 6 + MOVQ 48(BX), AX + MULQ SI + ADDQ 48(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 48(CX) + + // Iteration 7 + MOVQ 56(BX), AX + MULQ SI + ADDQ 56(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 56(CX) + + // Iteration 8 + MOVQ 64(BX), AX + MULQ SI + ADDQ 64(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 64(CX) + + // Iteration 9 + MOVQ 72(BX), AX + MULQ SI + ADDQ 72(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 72(CX) + + // Iteration 10 + MOVQ 80(BX), AX + MULQ SI + ADDQ 80(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 80(CX) + + // Iteration 11 + MOVQ 88(BX), AX + MULQ SI + ADDQ 88(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 88(CX) + + // Iteration 12 + MOVQ 96(BX), AX + MULQ SI + ADDQ 96(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 96(CX) + + // Iteration 13 + MOVQ 104(BX), AX + MULQ SI + ADDQ 104(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 104(CX) + + // Iteration 14 + MOVQ 112(BX), AX + MULQ SI + ADDQ 112(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 112(CX) + + // Iteration 15 + MOVQ 120(BX), AX + MULQ SI + ADDQ 120(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 120(CX) + MOVQ DI, c+24(FP) + RET + +adx: + MOVQ z+0(FP), AX + MOVQ x+8(FP), CX + MOVQ y+16(FP), DX + XORQ BX, BX + XORQ SI, SI + + // Iteration 0 + MULXQ (CX), R8, DI + ADCXQ BX, R8 + ADOXQ (AX), R8 + MOVQ R8, (AX) + + // Iteration 1 + MULXQ 8(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 8(AX), R8 + MOVQ R8, 8(AX) + + // Iteration 2 + MULXQ 16(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 16(AX), R8 + MOVQ R8, 16(AX) + + // Iteration 3 + MULXQ 24(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 24(AX), R8 + MOVQ R8, 24(AX) + + // Iteration 4 + MULXQ 32(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 32(AX), R8 + MOVQ R8, 32(AX) + + // Iteration 5 + MULXQ 40(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 40(AX), R8 + MOVQ R8, 40(AX) + + // Iteration 6 + MULXQ 48(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 48(AX), R8 + MOVQ R8, 48(AX) + + // Iteration 7 + MULXQ 56(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 56(AX), R8 + MOVQ R8, 56(AX) + + // Iteration 8 + MULXQ 64(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 64(AX), R8 + MOVQ R8, 64(AX) + + // Iteration 9 + MULXQ 72(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 72(AX), R8 + MOVQ R8, 72(AX) + + // Iteration 10 + MULXQ 80(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 80(AX), R8 + MOVQ R8, 80(AX) + + // Iteration 11 + MULXQ 88(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 88(AX), R8 + MOVQ R8, 88(AX) + + // Iteration 12 + MULXQ 96(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 96(AX), R8 + MOVQ R8, 96(AX) + + // Iteration 13 + MULXQ 104(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 104(AX), R8 + MOVQ R8, 104(AX) + + // Iteration 14 + MULXQ 112(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 112(AX), R8 + MOVQ R8, 112(AX) + + // Iteration 15 + MULXQ 120(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 120(AX), R8 + MOVQ R8, 120(AX) + + // Add back carry flags and return + ADCXQ SI, BX + ADOXQ SI, BX + MOVQ BX, c+24(FP) + RET + +// func addMulVVW1536(z *uint, x *uint, y uint) (c uint) +// Requires: ADX, BMI2 +TEXT ·addMulVVW1536(SB), $0-32 + CMPB ·supportADX+0(SB), $0x01 + JEQ adx + MOVQ z+0(FP), CX + MOVQ x+8(FP), BX + MOVQ y+16(FP), SI + XORQ DI, DI + + // Iteration 0 + MOVQ (BX), AX + MULQ SI + ADDQ (CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, (CX) + + // Iteration 1 + MOVQ 8(BX), AX + MULQ SI + ADDQ 8(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 8(CX) + + // Iteration 2 + MOVQ 16(BX), AX + MULQ SI + ADDQ 16(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 16(CX) + + // Iteration 3 + MOVQ 24(BX), AX + MULQ SI + ADDQ 24(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 24(CX) + + // Iteration 4 + MOVQ 32(BX), AX + MULQ SI + ADDQ 32(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 32(CX) + + // Iteration 5 + MOVQ 40(BX), AX + MULQ SI + ADDQ 40(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 40(CX) + + // Iteration 6 + MOVQ 48(BX), AX + MULQ SI + ADDQ 48(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 48(CX) + + // Iteration 7 + MOVQ 56(BX), AX + MULQ SI + ADDQ 56(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 56(CX) + + // Iteration 8 + MOVQ 64(BX), AX + MULQ SI + ADDQ 64(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 64(CX) + + // Iteration 9 + MOVQ 72(BX), AX + MULQ SI + ADDQ 72(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 72(CX) + + // Iteration 10 + MOVQ 80(BX), AX + MULQ SI + ADDQ 80(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 80(CX) + + // Iteration 11 + MOVQ 88(BX), AX + MULQ SI + ADDQ 88(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 88(CX) + + // Iteration 12 + MOVQ 96(BX), AX + MULQ SI + ADDQ 96(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 96(CX) + + // Iteration 13 + MOVQ 104(BX), AX + MULQ SI + ADDQ 104(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 104(CX) + + // Iteration 14 + MOVQ 112(BX), AX + MULQ SI + ADDQ 112(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 112(CX) + + // Iteration 15 + MOVQ 120(BX), AX + MULQ SI + ADDQ 120(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 120(CX) + + // Iteration 16 + MOVQ 128(BX), AX + MULQ SI + ADDQ 128(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 128(CX) + + // Iteration 17 + MOVQ 136(BX), AX + MULQ SI + ADDQ 136(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 136(CX) + + // Iteration 18 + MOVQ 144(BX), AX + MULQ SI + ADDQ 144(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 144(CX) + + // Iteration 19 + MOVQ 152(BX), AX + MULQ SI + ADDQ 152(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 152(CX) + + // Iteration 20 + MOVQ 160(BX), AX + MULQ SI + ADDQ 160(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 160(CX) + + // Iteration 21 + MOVQ 168(BX), AX + MULQ SI + ADDQ 168(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 168(CX) + + // Iteration 22 + MOVQ 176(BX), AX + MULQ SI + ADDQ 176(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 176(CX) + + // Iteration 23 + MOVQ 184(BX), AX + MULQ SI + ADDQ 184(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 184(CX) + MOVQ DI, c+24(FP) + RET + +adx: + MOVQ z+0(FP), AX + MOVQ x+8(FP), CX + MOVQ y+16(FP), DX + XORQ BX, BX + XORQ SI, SI + + // Iteration 0 + MULXQ (CX), R8, DI + ADCXQ BX, R8 + ADOXQ (AX), R8 + MOVQ R8, (AX) + + // Iteration 1 + MULXQ 8(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 8(AX), R8 + MOVQ R8, 8(AX) + + // Iteration 2 + MULXQ 16(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 16(AX), R8 + MOVQ R8, 16(AX) + + // Iteration 3 + MULXQ 24(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 24(AX), R8 + MOVQ R8, 24(AX) + + // Iteration 4 + MULXQ 32(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 32(AX), R8 + MOVQ R8, 32(AX) + + // Iteration 5 + MULXQ 40(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 40(AX), R8 + MOVQ R8, 40(AX) + + // Iteration 6 + MULXQ 48(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 48(AX), R8 + MOVQ R8, 48(AX) + + // Iteration 7 + MULXQ 56(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 56(AX), R8 + MOVQ R8, 56(AX) + + // Iteration 8 + MULXQ 64(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 64(AX), R8 + MOVQ R8, 64(AX) + + // Iteration 9 + MULXQ 72(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 72(AX), R8 + MOVQ R8, 72(AX) + + // Iteration 10 + MULXQ 80(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 80(AX), R8 + MOVQ R8, 80(AX) + + // Iteration 11 + MULXQ 88(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 88(AX), R8 + MOVQ R8, 88(AX) + + // Iteration 12 + MULXQ 96(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 96(AX), R8 + MOVQ R8, 96(AX) + + // Iteration 13 + MULXQ 104(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 104(AX), R8 + MOVQ R8, 104(AX) + + // Iteration 14 + MULXQ 112(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 112(AX), R8 + MOVQ R8, 112(AX) + + // Iteration 15 + MULXQ 120(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 120(AX), R8 + MOVQ R8, 120(AX) + + // Iteration 16 + MULXQ 128(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 128(AX), R8 + MOVQ R8, 128(AX) + + // Iteration 17 + MULXQ 136(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 136(AX), R8 + MOVQ R8, 136(AX) + + // Iteration 18 + MULXQ 144(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 144(AX), R8 + MOVQ R8, 144(AX) + + // Iteration 19 + MULXQ 152(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 152(AX), R8 + MOVQ R8, 152(AX) + + // Iteration 20 + MULXQ 160(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 160(AX), R8 + MOVQ R8, 160(AX) + + // Iteration 21 + MULXQ 168(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 168(AX), R8 + MOVQ R8, 168(AX) + + // Iteration 22 + MULXQ 176(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 176(AX), R8 + MOVQ R8, 176(AX) + + // Iteration 23 + MULXQ 184(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 184(AX), R8 + MOVQ R8, 184(AX) + + // Add back carry flags and return + ADCXQ SI, BX + ADOXQ SI, BX + MOVQ BX, c+24(FP) + RET + +// func addMulVVW2048(z *uint, x *uint, y uint) (c uint) +// Requires: ADX, BMI2 +TEXT ·addMulVVW2048(SB), $0-32 + CMPB ·supportADX+0(SB), $0x01 + JEQ adx + MOVQ z+0(FP), CX + MOVQ x+8(FP), BX + MOVQ y+16(FP), SI + XORQ DI, DI + + // Iteration 0 + MOVQ (BX), AX + MULQ SI + ADDQ (CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, (CX) + + // Iteration 1 + MOVQ 8(BX), AX + MULQ SI + ADDQ 8(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 8(CX) + + // Iteration 2 + MOVQ 16(BX), AX + MULQ SI + ADDQ 16(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 16(CX) + + // Iteration 3 + MOVQ 24(BX), AX + MULQ SI + ADDQ 24(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 24(CX) + + // Iteration 4 + MOVQ 32(BX), AX + MULQ SI + ADDQ 32(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 32(CX) + + // Iteration 5 + MOVQ 40(BX), AX + MULQ SI + ADDQ 40(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 40(CX) + + // Iteration 6 + MOVQ 48(BX), AX + MULQ SI + ADDQ 48(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 48(CX) + + // Iteration 7 + MOVQ 56(BX), AX + MULQ SI + ADDQ 56(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 56(CX) + + // Iteration 8 + MOVQ 64(BX), AX + MULQ SI + ADDQ 64(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 64(CX) + + // Iteration 9 + MOVQ 72(BX), AX + MULQ SI + ADDQ 72(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 72(CX) + + // Iteration 10 + MOVQ 80(BX), AX + MULQ SI + ADDQ 80(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 80(CX) + + // Iteration 11 + MOVQ 88(BX), AX + MULQ SI + ADDQ 88(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 88(CX) + + // Iteration 12 + MOVQ 96(BX), AX + MULQ SI + ADDQ 96(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 96(CX) + + // Iteration 13 + MOVQ 104(BX), AX + MULQ SI + ADDQ 104(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 104(CX) + + // Iteration 14 + MOVQ 112(BX), AX + MULQ SI + ADDQ 112(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 112(CX) + + // Iteration 15 + MOVQ 120(BX), AX + MULQ SI + ADDQ 120(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 120(CX) + + // Iteration 16 + MOVQ 128(BX), AX + MULQ SI + ADDQ 128(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 128(CX) + + // Iteration 17 + MOVQ 136(BX), AX + MULQ SI + ADDQ 136(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 136(CX) + + // Iteration 18 + MOVQ 144(BX), AX + MULQ SI + ADDQ 144(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 144(CX) + + // Iteration 19 + MOVQ 152(BX), AX + MULQ SI + ADDQ 152(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 152(CX) + + // Iteration 20 + MOVQ 160(BX), AX + MULQ SI + ADDQ 160(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 160(CX) + + // Iteration 21 + MOVQ 168(BX), AX + MULQ SI + ADDQ 168(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 168(CX) + + // Iteration 22 + MOVQ 176(BX), AX + MULQ SI + ADDQ 176(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 176(CX) + + // Iteration 23 + MOVQ 184(BX), AX + MULQ SI + ADDQ 184(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 184(CX) + + // Iteration 24 + MOVQ 192(BX), AX + MULQ SI + ADDQ 192(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 192(CX) + + // Iteration 25 + MOVQ 200(BX), AX + MULQ SI + ADDQ 200(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 200(CX) + + // Iteration 26 + MOVQ 208(BX), AX + MULQ SI + ADDQ 208(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 208(CX) + + // Iteration 27 + MOVQ 216(BX), AX + MULQ SI + ADDQ 216(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 216(CX) + + // Iteration 28 + MOVQ 224(BX), AX + MULQ SI + ADDQ 224(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 224(CX) + + // Iteration 29 + MOVQ 232(BX), AX + MULQ SI + ADDQ 232(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 232(CX) + + // Iteration 30 + MOVQ 240(BX), AX + MULQ SI + ADDQ 240(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 240(CX) + + // Iteration 31 + MOVQ 248(BX), AX + MULQ SI + ADDQ 248(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 248(CX) + MOVQ DI, c+24(FP) + RET + +adx: + MOVQ z+0(FP), AX + MOVQ x+8(FP), CX + MOVQ y+16(FP), DX + XORQ BX, BX + XORQ SI, SI + + // Iteration 0 + MULXQ (CX), R8, DI + ADCXQ BX, R8 + ADOXQ (AX), R8 + MOVQ R8, (AX) + + // Iteration 1 + MULXQ 8(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 8(AX), R8 + MOVQ R8, 8(AX) + + // Iteration 2 + MULXQ 16(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 16(AX), R8 + MOVQ R8, 16(AX) + + // Iteration 3 + MULXQ 24(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 24(AX), R8 + MOVQ R8, 24(AX) + + // Iteration 4 + MULXQ 32(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 32(AX), R8 + MOVQ R8, 32(AX) + + // Iteration 5 + MULXQ 40(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 40(AX), R8 + MOVQ R8, 40(AX) + + // Iteration 6 + MULXQ 48(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 48(AX), R8 + MOVQ R8, 48(AX) + + // Iteration 7 + MULXQ 56(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 56(AX), R8 + MOVQ R8, 56(AX) + + // Iteration 8 + MULXQ 64(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 64(AX), R8 + MOVQ R8, 64(AX) + + // Iteration 9 + MULXQ 72(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 72(AX), R8 + MOVQ R8, 72(AX) + + // Iteration 10 + MULXQ 80(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 80(AX), R8 + MOVQ R8, 80(AX) + + // Iteration 11 + MULXQ 88(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 88(AX), R8 + MOVQ R8, 88(AX) + + // Iteration 12 + MULXQ 96(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 96(AX), R8 + MOVQ R8, 96(AX) + + // Iteration 13 + MULXQ 104(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 104(AX), R8 + MOVQ R8, 104(AX) + + // Iteration 14 + MULXQ 112(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 112(AX), R8 + MOVQ R8, 112(AX) + + // Iteration 15 + MULXQ 120(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 120(AX), R8 + MOVQ R8, 120(AX) + + // Iteration 16 + MULXQ 128(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 128(AX), R8 + MOVQ R8, 128(AX) + + // Iteration 17 + MULXQ 136(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 136(AX), R8 + MOVQ R8, 136(AX) + + // Iteration 18 + MULXQ 144(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 144(AX), R8 + MOVQ R8, 144(AX) + + // Iteration 19 + MULXQ 152(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 152(AX), R8 + MOVQ R8, 152(AX) + + // Iteration 20 + MULXQ 160(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 160(AX), R8 + MOVQ R8, 160(AX) + + // Iteration 21 + MULXQ 168(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 168(AX), R8 + MOVQ R8, 168(AX) + + // Iteration 22 + MULXQ 176(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 176(AX), R8 + MOVQ R8, 176(AX) + + // Iteration 23 + MULXQ 184(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 184(AX), R8 + MOVQ R8, 184(AX) + + // Iteration 24 + MULXQ 192(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 192(AX), R8 + MOVQ R8, 192(AX) + + // Iteration 25 + MULXQ 200(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 200(AX), R8 + MOVQ R8, 200(AX) + + // Iteration 26 + MULXQ 208(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 208(AX), R8 + MOVQ R8, 208(AX) + + // Iteration 27 + MULXQ 216(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 216(AX), R8 + MOVQ R8, 216(AX) + + // Iteration 28 + MULXQ 224(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 224(AX), R8 + MOVQ R8, 224(AX) + + // Iteration 29 + MULXQ 232(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 232(AX), R8 + MOVQ R8, 232(AX) + + // Iteration 30 + MULXQ 240(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 240(AX), R8 + MOVQ R8, 240(AX) + + // Iteration 31 + MULXQ 248(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 248(AX), R8 + MOVQ R8, 248(AX) + + // Add back carry flags and return + ADCXQ SI, BX + ADOXQ SI, BX + MOVQ BX, c+24(FP) RET diff --git a/src/crypto/internal/bigmod/nat_arm.s b/src/crypto/internal/bigmod/nat_arm.s new file mode 100644 index 0000000000..c7397b89c5 --- /dev/null +++ b/src/crypto/internal/bigmod/nat_arm.s @@ -0,0 +1,47 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !purego + +#include "textflag.h" + +// func addMulVVW1024(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1024(SB), $0-16 + MOVW $32, R5 + JMP addMulVVWx(SB) + +// func addMulVVW1536(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1536(SB), $0-16 + MOVW $48, R5 + JMP addMulVVWx(SB) + +// func addMulVVW2048(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW2048(SB), $0-16 + MOVW $64, R5 + JMP addMulVVWx(SB) + +TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0 + MOVW $0, R0 + MOVW z+0(FP), R1 + MOVW x+4(FP), R2 + MOVW y+8(FP), R3 + ADD R5<<2, R1, R5 + MOVW $0, R4 + B E9 + +L9: MOVW.P 4(R2), R6 + MULLU R6, R3, (R7, R6) + ADD.S R4, R6 + ADC R0, R7 + MOVW 0(R1), R4 + ADD.S R4, R6 + ADC R0, R7 + MOVW.P R6, 4(R1) + MOVW R7, R4 + +E9: TEQ R1, R5 + BNE L9 + + MOVW R4, c+12(FP) + RET diff --git a/src/crypto/internal/bigmod/nat_arm64.s b/src/crypto/internal/bigmod/nat_arm64.s new file mode 100644 index 0000000000..ba1e6118cc --- /dev/null +++ b/src/crypto/internal/bigmod/nat_arm64.s @@ -0,0 +1,69 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !purego + +#include "textflag.h" + +// func addMulVVW1024(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1024(SB), $0-32 + MOVD $16, R0 + JMP addMulVVWx(SB) + +// func addMulVVW1536(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1536(SB), $0-32 + MOVD $24, R0 + JMP addMulVVWx(SB) + +// func addMulVVW2048(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW2048(SB), $0-32 + MOVD $32, R0 + JMP addMulVVWx(SB) + +TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0 + MOVD z+0(FP), R1 + MOVD x+8(FP), R2 + MOVD y+16(FP), R3 + MOVD $0, R4 + +// The main loop of this code operates on a block of 4 words every iteration +// performing [R4:R12:R11:R10:R9] = R4 + R3 * [R8:R7:R6:R5] + [R12:R11:R10:R9] +// where R4 is carried from the previous iteration, R8:R7:R6:R5 hold the next +// 4 words of x, R3 is y and R12:R11:R10:R9 are part of the result z. +loop: + CBZ R0, done + + LDP.P 16(R2), (R5, R6) + LDP.P 16(R2), (R7, R8) + + LDP (R1), (R9, R10) + ADDS R4, R9 + MUL R6, R3, R14 + ADCS R14, R10 + MUL R7, R3, R15 + LDP 16(R1), (R11, R12) + ADCS R15, R11 + MUL R8, R3, R16 + ADCS R16, R12 + UMULH R8, R3, R20 + ADC $0, R20 + + MUL R5, R3, R13 + ADDS R13, R9 + UMULH R5, R3, R17 + ADCS R17, R10 + UMULH R6, R3, R21 + STP.P (R9, R10), 16(R1) + ADCS R21, R11 + UMULH R7, R3, R19 + ADCS R19, R12 + STP.P (R11, R12), 16(R1) + ADC $0, R20, R4 + + SUB $4, R0 + B loop + +done: + MOVD R4, c+24(FP) + RET diff --git a/src/crypto/internal/bigmod/nat_asm.go b/src/crypto/internal/bigmod/nat_asm.go new file mode 100644 index 0000000000..5eb91e1c6c --- /dev/null +++ b/src/crypto/internal/bigmod/nat_asm.go @@ -0,0 +1,28 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !purego && (386 || amd64 || arm || arm64 || ppc64 || ppc64le || s390x) + +package bigmod + +import "internal/cpu" + +// amd64 assembly uses ADCX/ADOX/MULX if ADX is available to run two carry +// chains in the flags in parallel across the whole operation, and aggressively +// unrolls loops. arm64 processes four words at a time. +// +// It's unclear why the assembly for all other architectures, as well as for +// amd64 without ADX, perform better than the compiler output. +// TODO(filippo): file cmd/compile performance issue. + +var supportADX = cpu.X86.HasADX && cpu.X86.HasBMI2 + +//go:noescape +func addMulVVW1024(z, x *uint, y uint) (c uint) + +//go:noescape +func addMulVVW1536(z, x *uint, y uint) (c uint) + +//go:noescape +func addMulVVW2048(z, x *uint, y uint) (c uint) diff --git a/src/crypto/internal/bigmod/nat_noasm.go b/src/crypto/internal/bigmod/nat_noasm.go index 870b44519d..eff12536f9 100644 --- a/src/crypto/internal/bigmod/nat_noasm.go +++ b/src/crypto/internal/bigmod/nat_noasm.go @@ -1,11 +1,21 @@ -// Copyright 2022 The Go Authors. All rights reserved. +// Copyright 2023 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !amd64 || !gc || purego +//go:build purego || !(386 || amd64 || arm || arm64 || ppc64 || ppc64le || s390x) package bigmod -func montgomeryLoop(d, a, b, m []uint, m0inv uint) uint { - return montgomeryLoopGeneric(d, a, b, m, m0inv) +import "unsafe" + +func addMulVVW1024(z, x *uint, y uint) (c uint) { + return addMulVVW(unsafe.Slice(z, 1024/_W), unsafe.Slice(x, 1024/_W), y) +} + +func addMulVVW1536(z, x *uint, y uint) (c uint) { + return addMulVVW(unsafe.Slice(z, 1536/_W), unsafe.Slice(x, 1536/_W), y) +} + +func addMulVVW2048(z, x *uint, y uint) (c uint) { + return addMulVVW(unsafe.Slice(z, 2048/_W), unsafe.Slice(x, 2048/_W), y) } diff --git a/src/crypto/internal/bigmod/nat_ppc64x.s b/src/crypto/internal/bigmod/nat_ppc64x.s new file mode 100644 index 0000000000..974f4f945e --- /dev/null +++ b/src/crypto/internal/bigmod/nat_ppc64x.s @@ -0,0 +1,51 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !purego && (ppc64 || ppc64le) + +#include "textflag.h" + +// func addMulVVW1024(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1024(SB), $0-32 + MOVD $16, R22 // R22 = z_len + JMP addMulVVWx(SB) + +// func addMulVVW1536(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1536(SB), $0-32 + MOVD $24, R22 // R22 = z_len + JMP addMulVVWx(SB) + +// func addMulVVW2048(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW2048(SB), $0-32 + MOVD $32, R22 // R22 = z_len + JMP addMulVVWx(SB) + +TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0 + MOVD z+0(FP), R10 // R10 = z[] + MOVD x+8(FP), R8 // R8 = x[] + MOVD y+16(FP), R9 // R9 = y + + MOVD R0, R3 // R3 will be the index register + CMP R0, R22 + MOVD R0, R4 // R4 = c = 0 + MOVD R22, CTR // Initialize loop counter + BEQ done + PCALIGN $16 + +loop: + MOVD (R8)(R3), R20 // Load x[i] + MOVD (R10)(R3), R21 // Load z[i] + MULLD R9, R20, R6 // R6 = Low-order(x[i]*y) + MULHDU R9, R20, R7 // R7 = High-order(x[i]*y) + ADDC R21, R6 // R6 = z0 + ADDZE R7 // R7 = z1 + ADDC R4, R6 // R6 = z0 + c + 0 + ADDZE R7, R4 // c += z1 + MOVD R6, (R10)(R3) // Store z[i] + ADD $8, R3 + BC 16, 0, loop // bdnz + +done: + MOVD R4, c+24(FP) + RET diff --git a/src/crypto/internal/bigmod/nat_s390x.s b/src/crypto/internal/bigmod/nat_s390x.s new file mode 100644 index 0000000000..0c07a0c8a6 --- /dev/null +++ b/src/crypto/internal/bigmod/nat_s390x.s @@ -0,0 +1,85 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !purego + +#include "textflag.h" + +// func addMulVVW1024(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1024(SB), $0-32 + MOVD $16, R5 + JMP addMulVVWx(SB) + +// func addMulVVW1536(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1536(SB), $0-32 + MOVD $24, R5 + JMP addMulVVWx(SB) + +// func addMulVVW2048(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW2048(SB), $0-32 + MOVD $32, R5 + JMP addMulVVWx(SB) + +TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0 + MOVD z+0(FP), R2 + MOVD x+8(FP), R8 + MOVD y+16(FP), R9 + + MOVD $0, R1 // i*8 = 0 + MOVD $0, R7 // i = 0 + MOVD $0, R0 // make sure it's zero + MOVD $0, R4 // c = 0 + + MOVD R5, R12 + AND $-2, R12 + CMPBGE R5, $2, A6 + BR E6 + +A6: + MOVD (R8)(R1*1), R6 + MULHDU R9, R6 + MOVD (R2)(R1*1), R10 + ADDC R10, R11 // add to low order bits + ADDE R0, R6 + ADDC R4, R11 + ADDE R0, R6 + MOVD R6, R4 + MOVD R11, (R2)(R1*1) + + MOVD (8)(R8)(R1*1), R6 + MULHDU R9, R6 + MOVD (8)(R2)(R1*1), R10 + ADDC R10, R11 // add to low order bits + ADDE R0, R6 + ADDC R4, R11 + ADDE R0, R6 + MOVD R6, R4 + MOVD R11, (8)(R2)(R1*1) + + ADD $16, R1 // i*8 + 8 + ADD $2, R7 // i++ + + CMPBLT R7, R12, A6 + BR E6 + +L6: + // TODO: drop unused single-step loop. + MOVD (R8)(R1*1), R6 + MULHDU R9, R6 + MOVD (R2)(R1*1), R10 + ADDC R10, R11 // add to low order bits + ADDE R0, R6 + ADDC R4, R11 + ADDE R0, R6 + MOVD R6, R4 + MOVD R11, (R2)(R1*1) + + ADD $8, R1 // i*8 + 8 + ADD $1, R7 // i++ + +E6: + CMPBLT R7, R5, L6 // i < n + + MOVD R4, c+24(FP) + RET diff --git a/src/crypto/internal/bigmod/nat_test.go b/src/crypto/internal/bigmod/nat_test.go index 4593a2e493..cc5ffe7bb7 100644 --- a/src/crypto/internal/bigmod/nat_test.go +++ b/src/crypto/internal/bigmod/nat_test.go @@ -5,14 +5,24 @@ package bigmod import ( + "fmt" "math/big" "math/bits" "math/rand" "reflect" + "strings" "testing" "testing/quick" ) +func (n *Nat) String() string { + var limbs []string + for i := range n.limbs { + limbs = append(limbs, fmt.Sprintf("%016X", n.limbs[len(n.limbs)-1-i])) + } + return "{" + strings.Join(limbs, " ") + "}" +} + // Generate generates an even nat. It's used by testing/quick to produce random // *nat values for quick.Check invocations. func (*Nat) Generate(r *rand.Rand, size int) reflect.Value { @@ -54,21 +64,23 @@ func TestModSubThenAddIdentity(t *testing.T) { } } -func testMontgomeryRoundtrip(a *Nat) bool { - one := &Nat{make([]uint, len(a.limbs))} - one.limbs[0] = 1 - aPlusOne := new(big.Int).SetBytes(natBytes(a)) - aPlusOne.Add(aPlusOne, big.NewInt(1)) - m := NewModulusFromBig(aPlusOne) - monty := new(Nat).set(a) - monty.montgomeryRepresentation(m) - aAgain := new(Nat).set(monty) - aAgain.montgomeryMul(monty, one, m) - return a.Equal(aAgain) == 1 -} - func TestMontgomeryRoundtrip(t *testing.T) { - err := quick.Check(testMontgomeryRoundtrip, &quick.Config{}) + err := quick.Check(func(a *Nat) bool { + one := &Nat{make([]uint, len(a.limbs))} + one.limbs[0] = 1 + aPlusOne := new(big.Int).SetBytes(natBytes(a)) + aPlusOne.Add(aPlusOne, big.NewInt(1)) + m := NewModulusFromBig(aPlusOne) + monty := new(Nat).set(a) + monty.montgomeryRepresentation(m) + aAgain := new(Nat).set(monty) + aAgain.montgomeryMul(monty, one, m) + if a.Equal(aAgain) != 1 { + t.Errorf("%v != %v", a, aAgain) + return false + } + return true + }, &quick.Config{}) if err != nil { t.Error(err) } @@ -84,30 +96,30 @@ func TestShiftIn(t *testing.T) { }{{ m: []byte{13}, x: []byte{0}, - y: 0x7FFF_FFFF_FFFF_FFFF, - expected: []byte{7}, + y: 0xFFFF_FFFF_FFFF_FFFF, + expected: []byte{2}, }, { m: []byte{13}, x: []byte{7}, - y: 0x7FFF_FFFF_FFFF_FFFF, - expected: []byte{11}, + y: 0xFFFF_FFFF_FFFF_FFFF, + expected: []byte{10}, }, { m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, x: make([]byte, 9), - y: 0x7FFF_FFFF_FFFF_FFFF, - expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + y: 0xFFFF_FFFF_FFFF_FFFF, + expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, }, { m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, - x: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + x: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, y: 0, - expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08}, + expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06}, }} for i, tt := range examples { m := modulusFromBytes(tt.m) got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m) - if got.Equal(natFromBytes(tt.expected).ExpandFor(m)) != 1 { - t.Errorf("%d: got %x, expected %x", i, got, tt.expected) + if exp := natFromBytes(tt.expected).ExpandFor(m); got.Equal(exp) != 1 { + t.Errorf("%d: got %v, expected %v", i, got, exp) } } } @@ -186,7 +198,7 @@ func TestSetBytes(t *testing.T) { continue } if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes { - t.Errorf("%d: got %x, expected %x", i, got, expected) + t.Errorf("%d: got %v, expected %v", i, got, expected) } } @@ -228,7 +240,7 @@ func TestExpand(t *testing.T) { for i, tt := range examples { got := (&Nat{tt.in}).expand(tt.n) if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 { - t.Errorf("%d: got %x, expected %x", i, got, tt.out) + t.Errorf("%d: got %v, expected %v", i, got, tt.out) } } } @@ -287,11 +299,40 @@ func TestExp(t *testing.T) { } } +// TestMulReductions tests that Mul reduces results equal or slightly greater +// than the modulus. Some Montgomery algorithms don't and need extra care to +// return correct results. See https://go.dev/issue/13907. +func TestMulReductions(t *testing.T) { + // Two short but multi-limb primes. + a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10) + b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10) + n := new(big.Int).Mul(a, b) + + N := NewModulusFromBig(n) + A := NewNat().setBig(a).ExpandFor(N) + B := NewNat().setBig(b).ExpandFor(N) + + if A.Mul(B, N).IsZero() != 1 { + t.Error("a * b mod (a * b) != 0") + } + + i := new(big.Int).ModInverse(a, b) + N = NewModulusFromBig(b) + A = NewNat().setBig(a).ExpandFor(N) + I := NewNat().setBig(i).ExpandFor(N) + one := NewNat().setBig(big.NewInt(1)).ExpandFor(N) + + if A.Mul(I, N).Equal(one) != 1 { + t.Error("a * inv(a) mod b != 1") + } +} + func natBytes(n *Nat) []byte { return n.Bytes(maxModulus(uint(len(n.limbs)))) } func natFromBytes(b []byte) *Nat { + // Must not use Nat.SetBytes as it's used in TestSetBytes. bb := new(big.Int).SetBytes(b) return NewNat().setBig(bb) } @@ -316,7 +357,7 @@ func makeBenchmarkModulus() *Modulus { func makeBenchmarkValue() *Nat { x := make([]uint, 32) for i := 0; i < 32; i++ { - x[i] = _MASK - 1 + x[i]-- } return &Nat{limbs: x} }