Implement an overflow-aware optimization in ctrBlocks8Asm: take a fast
path when the low 64 bits of the counter cannot overflow within the next
8 blocks. A single branch per 8 blocks is cheaper than 7 increments in
general-purpose registers plus the transfers from them to XMM registers.
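For illustration, a minimal Go sketch of the per-batch decision (a
hypothetical helper, not the actual implementation; the real logic lives
in the generated amd64 assembly):

    // counters8 sketches the choice ctrBlocks8Asm makes per 8-block batch,
    // assuming the 128-bit big-endian counter is kept as two uint64 halves.
    func counters8(hi, lo uint64) [8][2]uint64 {
        var out [8][2]uint64
        if lo <= ^uint64(0)-7 {
            // Fast path: none of the next 7 increments carries into hi,
            // so only lo changes (the assembly keeps this work in XMM).
            for i := range out {
                out[i] = [2]uint64{hi, lo + uint64(i)}
            }
            return out
        }
        // Slow path: a carry into hi is possible; do full 128-bit increments.
        for i := range out {
            out[i] = [2]uint64{hi, lo}
            lo++
            if lo == 0 {
                hi++
            }
        }
        return out
    }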
Add AES-192 and AES-256 key sizes to the AES-CTR benchmark.
Add a correctness test in ctr_test.go for the overflow optimization.
This improves performance, especially in AES-128 mode.
goos: windows
goarch: amd64
pkg: crypto/cipher
cpu: AMD Ryzen 7 5800H with Radeon Graphics
│ B/s │ B/s vs base
AESCTR/128/50-16 1.377Gi ± 0% 1.384Gi ± 0% +0.51% (p=0.028 n=20)
AESCTR/128/1K-16 6.164Gi ± 0% 6.892Gi ± 1% +11.81% (p=0.000 n=20)
AESCTR/128/8K-16 7.372Gi ± 0% 8.768Gi ± 1% +18.95% (p=0.000 n=20)
AESCTR/192/50-16 1.289Gi ± 0% 1.279Gi ± 0% -0.75% (p=0.001 n=20)
AESCTR/192/1K-16 5.734Gi ± 0% 6.011Gi ± 0% +4.83% (p=0.000 n=20)
AESCTR/192/8K-16 6.889Gi ± 1% 7.437Gi ± 0% +7.96% (p=0.000 n=20)
AESCTR/256/50-16 1.170Gi ± 0% 1.163Gi ± 0% -0.54% (p=0.005 n=20)
AESCTR/256/1K-16 5.235Gi ± 0% 5.391Gi ± 0% +2.98% (p=0.000 n=20)
AESCTR/256/8K-16 6.361Gi ± 0% 6.676Gi ± 0% +4.94% (p=0.000 n=20)
geomean 3.681Gi 3.882Gi +5.46%
The slight slowdown on 50-byte workloads is unrelated to this change,
because such workloads never use ctrBlocks8Asm.
Updates #76061
Change-Id: Idfd628ac8bb282d9c73c6adf048eb12274a41379
GitHub-Last-Rev: 5aadd39351806fbbf5201e07511aac05bdcb0529
GitHub-Pull-Request: golang/go#76059
Reviewed-on: https://go-review.googlesource.com/c/go/+/714361
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: AHMAD ابو وليد <mizommz@gmail.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
}
}
-func benchmarkAESStream(b *testing.B, mode func(cipher.Block, []byte) cipher.Stream, buf []byte) {
+func benchmarkAESStream(b *testing.B, mode func(cipher.Block, []byte) cipher.Stream, buf []byte, keySize int) {
b.SetBytes(int64(len(buf)))
- var key [16]byte
+ key := make([]byte, keySize)
var iv [16]byte
- aes, _ := aes.NewCipher(key[:])
+ aes, _ := aes.NewCipher(key)
stream := mode(aes, iv[:])
b.ResetTimer()
const almost8K = 8*1024 - 5
func BenchmarkAESCTR(b *testing.B) {
- b.Run("50", func(b *testing.B) {
- benchmarkAESStream(b, cipher.NewCTR, make([]byte, 50))
- })
- b.Run("1K", func(b *testing.B) {
- benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost1K))
- })
- b.Run("8K", func(b *testing.B) {
- benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost8K))
- })
+ for _, keyBits := range []int{128, 192, 256} {
+ keySize := keyBits / 8
+ b.Run(strconv.Itoa(keyBits), func(b *testing.B) {
+ b.Run("50", func(b *testing.B) {
+ benchmarkAESStream(b, cipher.NewCTR, make([]byte, 50), keySize)
+ })
+ b.Run("1K", func(b *testing.B) {
+ benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost1K), keySize)
+ })
+ b.Run("8K", func(b *testing.B) {
+ benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost8K), keySize)
+ })
+ })
+ }
}
func BenchmarkAESCBCEncrypt1K(b *testing.B) {
"crypto/internal/boring"
"crypto/internal/cryptotest"
fipsaes "crypto/internal/fips140/aes"
+ "encoding/binary"
"encoding/hex"
"fmt"
"math/rand"
return cipher.NewCTR(wrap(aesBlock), iv), cipher.NewCTR(aesBlock, iv)
}
+// TestCTR_AES_blocks8FastPathMatchesGeneric ensures the overflow-aware branch
+// produces identical keystreams to the generic counter walker across
+// representative IVs, including near-overflow cases.
+func TestCTR_AES_blocks8FastPathMatchesGeneric(t *testing.T) {
+ key := make([]byte, aes.BlockSize)
+ block, err := aes.NewCipher(key)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, ok := block.(*fipsaes.Block); !ok {
+ t.Skip("requires crypto/internal/fips140/aes")
+ }
+
+ keystream := make([]byte, 8*aes.BlockSize)
+
+ testCases := []struct {
+ name string
+ hi uint64
+ lo uint64
+ }{
+ {"Zero", 0, 0},
+ {"NearOverflowMinus7", 1, ^uint64(0) - 7},
+ {"NearOverflowMinus6", 2, ^uint64(0) - 6},
+ {"Overflow", 0, ^uint64(0)},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ var iv [aes.BlockSize]byte
+ binary.BigEndian.PutUint64(iv[0:8], tc.hi)
+ binary.BigEndian.PutUint64(iv[8:], tc.lo)
+
+ generic, multiblock := makeTestingCiphers(block, iv[:])
+
+ genericOut := make([]byte, len(keystream))
+ multiblockOut := make([]byte, len(keystream))
+
+ generic.XORKeyStream(genericOut, keystream)
+ multiblock.XORKeyStream(multiblockOut, keystream)
+
+ if !bytes.Equal(multiblockOut, genericOut) {
+ t.Fatalf("mismatch for iv %#x:%#x\n"+
+ "asm keystream: %x\n"+
+ "gen keystream: %x\n"+
+ "asm counters: %x\n"+
+ "gen counters: %x",
+ tc.hi, tc.lo, multiblockOut, genericOut,
+ extractCounters(block, multiblockOut),
+ extractCounters(block, genericOut))
+ }
+ })
+ }
+}
+
func randBytes(t *testing.T, r *rand.Rand, count int) []byte {
t.Helper()
buf := make([]byte, count)
})
}
}
+
+// extractCounters recovers the counter blocks from a raw keystream: in CTR
+// mode each keystream block is the encryption of its counter, so decrypting
+// a keystream block yields the counter it was derived from.
+func extractCounters(block cipher.Block, keystream []byte) []byte {
+ blockSize := block.BlockSize()
+ res := make([]byte, len(keystream))
+ for i := 0; i < len(keystream); i += blockSize {
+ block.Decrypt(res[i:i+blockSize], keystream[i:i+blockSize])
+ }
+ return res
+}
bswap := XMM()
MOVOU(bswapMask(), bswap)
- blocks := make([]VecVirtual, 0, numBlocks)
+ blocks := make([]VecVirtual, numBlocks)
+
+ // For the 8-block case we optimize counter generation. We build the first
+ // counter as usual, then check whether the remaining seven increments will
+ // overflow. When they do not (the common case) we keep the work entirely in
+ // XMM registers to avoid expensive general-purpose -> XMM moves. Otherwise
+ // we fall back to the traditional scalar path.
+ if numBlocks == 8 {
+ for i := range blocks {
+ blocks[i] = XMM()
+ }
- // Lay out counter block plaintext.
- for i := 0; i < numBlocks; i++ {
- x := XMM()
- blocks = append(blocks, x)
-
- MOVQ(ivlo, x)
- PINSRQ(Imm(1), ivhi, x)
- PSHUFB(bswap, x)
- if i < numBlocks-1 {
- ADDQ(Imm(1), ivlo)
- ADCQ(Imm(0), ivhi)
+ base := XMM()
+ tmp := GP64()
+ addVec := XMM()
+
+ MOVQ(ivlo, blocks[0])
+ PINSRQ(Imm(1), ivhi, blocks[0])
+ MOVAPS(blocks[0], base)
+ PSHUFB(bswap, blocks[0])
+
+ // Check whether any of the seven increments will carry out of the low 64 bits.
+ MOVQ(ivlo, tmp)
+ ADDQ(Imm(uint64(numBlocks-1)), tmp)
+ slowLabel := fmt.Sprintf("ctr%d_slow", numBlocks)
+ doneLabel := fmt.Sprintf("ctr%d_done", numBlocks)
+ JC(LabelRef(slowLabel))
+
+ // Fast branch: create an XMM increment vector containing the value 1.
+ // Adding it to the base counter yields each subsequent counter.
+ XORQ(tmp, tmp)
+ INCQ(tmp)
+ PXOR(addVec, addVec)
+ PINSRQ(Imm(0), tmp, addVec)
+
+ for i := 1; i < numBlocks; i++ {
+ PADDQ(addVec, base)
+ MOVAPS(base, blocks[i])
+ }
+ JMP(LabelRef(doneLabel))
+
+ Label(slowLabel)
+ ADDQ(Imm(1), ivlo)
+ ADCQ(Imm(0), ivhi)
+ for i := 1; i < numBlocks; i++ {
+ MOVQ(ivlo, blocks[i])
+ PINSRQ(Imm(1), ivhi, blocks[i])
+ if i < numBlocks-1 {
+ ADDQ(Imm(1), ivlo)
+ ADCQ(Imm(0), ivhi)
+ }
+ }
+
+ Label(doneLabel)
+
+ // Convert little-endian counters to big-endian after the branch since
+ // both paths share the same shuffle sequence.
+ for i := 1; i < numBlocks; i++ {
+ PSHUFB(bswap, blocks[i])
+ }
+ } else {
+ // Lay out counter block plaintext.
+ for i := 0; i < numBlocks; i++ {
+ x := XMM()
+ blocks[i] = x
+
+ MOVQ(ivlo, x)
+ PINSRQ(Imm(1), ivhi, x)
+ PSHUFB(bswap, x)
+ if i < numBlocks-1 {
+ ADDQ(Imm(1), ivlo)
+ ADCQ(Imm(0), ivhi)
+ }
}
}
MOVOU bswapMask<>+0(SB), X0
MOVQ SI, X1
PINSRQ $0x01, DI, X1
+ MOVAPS X1, X8
PSHUFB X0, X1
+ MOVQ SI, R8
+ ADDQ $0x07, R8
+ JC ctr8_slow
+ XORQ R8, R8
+ INCQ R8
+ PXOR X9, X9
+ PINSRQ $0x00, R8, X9
+ PADDQ X9, X8
+ MOVAPS X8, X2
+ PADDQ X9, X8
+ MOVAPS X8, X3
+ PADDQ X9, X8
+ MOVAPS X8, X4
+ PADDQ X9, X8
+ MOVAPS X8, X5
+ PADDQ X9, X8
+ MOVAPS X8, X6
+ PADDQ X9, X8
+ MOVAPS X8, X7
+ PADDQ X9, X8
+ MOVAPS X8, X8
+ JMP ctr8_done
+
+ctr8_slow:
ADDQ $0x01, SI
ADCQ $0x00, DI
MOVQ SI, X2
PINSRQ $0x01, DI, X2
- PSHUFB X0, X2
ADDQ $0x01, SI
ADCQ $0x00, DI
MOVQ SI, X3
PINSRQ $0x01, DI, X3
- PSHUFB X0, X3
ADDQ $0x01, SI
ADCQ $0x00, DI
MOVQ SI, X4
PINSRQ $0x01, DI, X4
- PSHUFB X0, X4
ADDQ $0x01, SI
ADCQ $0x00, DI
MOVQ SI, X5
PINSRQ $0x01, DI, X5
- PSHUFB X0, X5
ADDQ $0x01, SI
ADCQ $0x00, DI
MOVQ SI, X6
PINSRQ $0x01, DI, X6
- PSHUFB X0, X6
ADDQ $0x01, SI
ADCQ $0x00, DI
MOVQ SI, X7
PINSRQ $0x01, DI, X7
- PSHUFB X0, X7
ADDQ $0x01, SI
ADCQ $0x00, DI
MOVQ SI, X8
PINSRQ $0x01, DI, X8
+
+ctr8_done:
+ PSHUFB X0, X2
+ PSHUFB X0, X3
+ PSHUFB X0, X4
+ PSHUFB X0, X5
+ PSHUFB X0, X6
+ PSHUFB X0, X7
PSHUFB X0, X8
MOVUPS (CX), X0
PXOR X0, X1