Cypherpunks repositories - gostls13.git/commitdiff
crypto/internal/fips140/aes: optimize ctrBlocks8Asm on amd64
author: Boris Nagaev <bnagaev@gmail.com>
Wed, 26 Nov 2025 08:26:49 +0000 (08:26 +0000)
committer: Gopher Robot <gobot@golang.org>
Wed, 26 Nov 2025 18:11:50 +0000 (10:11 -0800)
Implement overflow-aware optimization in ctrBlocks8Asm: make a fast branch
in case when there is no overflow. One branch per 8 blocks is faster than
7 increments in general purpose registers and transfers from them to XMM.

Added AES-192 and AES-256 modes to the AES-CTR benchmark.

Added a correctness test in ctr_test.go for the overflow optimization.

This improves performance, especially in AES-128 mode.

goos: windows
goarch: amd64
pkg: crypto/cipher
cpu: AMD Ryzen 7 5800H with Radeon Graphics
 │     B/s      │     B/s       vs base
AESCTR/128/50-16   1.377Gi ± 0%   1.384Gi ± 0%   +0.51% (p=0.028 n=20)
AESCTR/128/1K-16   6.164Gi ± 0%   6.892Gi ± 1%  +11.81% (p=0.000 n=20)
AESCTR/128/8K-16   7.372Gi ± 0%   8.768Gi ± 1%  +18.95% (p=0.000 n=20)
AESCTR/192/50-16   1.289Gi ± 0%   1.279Gi ± 0%   -0.75% (p=0.001 n=20)
AESCTR/192/1K-16   5.734Gi ± 0%   6.011Gi ± 0%   +4.83% (p=0.000 n=20)
AESCTR/192/8K-16   6.889Gi ± 1%   7.437Gi ± 0%   +7.96% (p=0.000 n=20)
AESCTR/256/50-16   1.170Gi ± 0%   1.163Gi ± 0%   -0.54% (p=0.005 n=20)
AESCTR/256/1K-16   5.235Gi ± 0%   5.391Gi ± 0%   +2.98% (p=0.000 n=20)
AESCTR/256/8K-16   6.361Gi ± 0%   6.676Gi ± 0%   +4.94% (p=0.000 n=20)
geomean            3.681Gi        3.882Gi        +5.46%

The slight slowdown on 50-byte workloads is unrelated to this change,
because such workloads never use ctrBlocks8Asm.

Updates #76061

Change-Id: Idfd628ac8bb282d9c73c6adf048eb12274a41379
GitHub-Last-Rev: 5aadd39351806fbbf5201e07511aac05bdcb0529
GitHub-Pull-Request: golang/go#76059
Reviewed-on: https://go-review.googlesource.com/c/go/+/714361
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: AHMAD ابو وليد <mizommz@gmail.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Auto-Submit: Filippo Valsorda <filippo@golang.org>

src/crypto/cipher/benchmark_test.go
src/crypto/cipher/ctr_aes_test.go
src/crypto/internal/fips140/aes/_asm/ctr/ctr_amd64_asm.go
src/crypto/internal/fips140/aes/ctr_amd64.s

index 181d08c9b1469952a5e780ef665f092e0cc09456..1a5b1b1ddd552d51410b79b8d0a1476122580d8f 100644 (file)
@@ -65,12 +65,12 @@ func BenchmarkAESGCM(b *testing.B) {
        }
 }
 
-func benchmarkAESStream(b *testing.B, mode func(cipher.Block, []byte) cipher.Stream, buf []byte) {
+func benchmarkAESStream(b *testing.B, mode func(cipher.Block, []byte) cipher.Stream, buf []byte, keySize int) {
        b.SetBytes(int64(len(buf)))
 
-       var key [16]byte
+       key := make([]byte, keySize)
        var iv [16]byte
-       aes, _ := aes.NewCipher(key[:])
+       aes, _ := aes.NewCipher(key)
        stream := mode(aes, iv[:])
 
        b.ResetTimer()
@@ -87,15 +87,20 @@ const almost1K = 1024 - 5
 const almost8K = 8*1024 - 5
 
 func BenchmarkAESCTR(b *testing.B) {
-       b.Run("50", func(b *testing.B) {
-               benchmarkAESStream(b, cipher.NewCTR, make([]byte, 50))
-       })
-       b.Run("1K", func(b *testing.B) {
-               benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost1K))
-       })
-       b.Run("8K", func(b *testing.B) {
-               benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost8K))
-       })
+       for _, keyBits := range []int{128, 192, 256} {
+               keySize := keyBits / 8
+               b.Run(strconv.Itoa(keyBits), func(b *testing.B) {
+                       b.Run("50", func(b *testing.B) {
+                               benchmarkAESStream(b, cipher.NewCTR, make([]byte, 50), keySize)
+                       })
+                       b.Run("1K", func(b *testing.B) {
+                               benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost1K), keySize)
+                       })
+                       b.Run("8K", func(b *testing.B) {
+                               benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost8K), keySize)
+                       })
+               })
+       }
 }
 
 func BenchmarkAESCBCEncrypt1K(b *testing.B) {
index 9b7d30e2164422da741a141a84b838d6f36d5323..1d8ae78674ebe2294aab961b8e0ccdd33bf073f6 100644 (file)
@@ -17,6 +17,7 @@ import (
        "crypto/internal/boring"
        "crypto/internal/cryptotest"
        fipsaes "crypto/internal/fips140/aes"
+       "encoding/binary"
        "encoding/hex"
        "fmt"
        "math/rand"
@@ -117,6 +118,60 @@ func makeTestingCiphers(aesBlock cipher.Block, iv []byte) (genericCtr, multibloc
        return cipher.NewCTR(wrap(aesBlock), iv), cipher.NewCTR(aesBlock, iv)
 }
 
+// TestCTR_AES_blocks8FastPathMatchesGeneric ensures the overflow-aware branch
+// produces identical keystreams to the generic counter walker across
+// representative IVs, including near-overflow cases.
+func TestCTR_AES_blocks8FastPathMatchesGeneric(t *testing.T) {
+       key := make([]byte, aes.BlockSize)
+       block, err := aes.NewCipher(key)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if _, ok := block.(*fipsaes.Block); !ok {
+               t.Skip("requires crypto/internal/fips140/aes")
+       }
+
+       keystream := make([]byte, 8*aes.BlockSize)
+
+       testCases := []struct {
+               name string
+               hi   uint64
+               lo   uint64
+       }{
+               {"Zero", 0, 0},
+               {"NearOverflowMinus7", 1, ^uint64(0) - 7},
+               {"NearOverflowMinus6", 2, ^uint64(0) - 6},
+               {"Overflow", 0, ^uint64(0)},
+       }
+
+       for _, tc := range testCases {
+               t.Run(tc.name, func(t *testing.T) {
+                       var iv [aes.BlockSize]byte
+                       binary.BigEndian.PutUint64(iv[0:8], tc.hi)
+                       binary.BigEndian.PutUint64(iv[8:], tc.lo)
+
+                       generic, multiblock := makeTestingCiphers(block, iv[:])
+
+                       genericOut := make([]byte, len(keystream))
+                       multiblockOut := make([]byte, len(keystream))
+
+                       generic.XORKeyStream(genericOut, keystream)
+                       multiblock.XORKeyStream(multiblockOut, keystream)
+
+                       if !bytes.Equal(multiblockOut, genericOut) {
+                               t.Fatalf("mismatch for iv %#x:%#x\n"+
+                                       "asm keystream: %x\n"+
+                                       "gen keystream: %x\n"+
+                                       "asm counters: %x\n"+
+                                       "gen counters: %x",
+                                       tc.hi, tc.lo, multiblockOut, genericOut,
+                                       extractCounters(block, multiblockOut),
+                                       extractCounters(block, genericOut))
+                       }
+               })
+       }
+}
+
 func randBytes(t *testing.T, r *rand.Rand, count int) []byte {
        t.Helper()
        buf := make([]byte, count)
@@ -297,3 +352,12 @@ func TestCTR_AES_multiblock_XORKeyStreamAt(t *testing.T) {
                })
        }
 }
+
+func extractCounters(block cipher.Block, keystream []byte) []byte {
+       blockSize := block.BlockSize()
+       res := make([]byte, len(keystream))
+       for i := 0; i < len(keystream); i += blockSize {
+               block.Decrypt(res[i:i+blockSize], keystream[i:i+blockSize])
+       }
+       return res
+}
index 775d4a8acc59698afce6841db073218d3b6b74ef..e3dbdf66d70e1c42241a6fc9dcc527576c7d0d20 100644 (file)
@@ -40,19 +40,79 @@ func ctrBlocks(numBlocks int) {
        bswap := XMM()
        MOVOU(bswapMask(), bswap)
 
-       blocks := make([]VecVirtual, 0, numBlocks)
+       blocks := make([]VecVirtual, numBlocks)
+
+       // For the 8-block case we optimize counter generation. We build the first
+       // counter as usual, then check whether the remaining seven increments will
+       // overflow. When they do not (the common case) we keep the work entirely in
+       // XMM registers to avoid expensive general-purpose -> XMM moves. Otherwise
+       // we fall back to the traditional scalar path.
+       if numBlocks == 8 {
+               for i := range blocks {
+                       blocks[i] = XMM()
+               }
 
-       // Lay out counter block plaintext.
-       for i := 0; i < numBlocks; i++ {
-               x := XMM()
-               blocks = append(blocks, x)
-
-               MOVQ(ivlo, x)
-               PINSRQ(Imm(1), ivhi, x)
-               PSHUFB(bswap, x)
-               if i < numBlocks-1 {
-                       ADDQ(Imm(1), ivlo)
-                       ADCQ(Imm(0), ivhi)
+               base := XMM()
+               tmp := GP64()
+               addVec := XMM()
+
+               MOVQ(ivlo, blocks[0])
+               PINSRQ(Imm(1), ivhi, blocks[0])
+               MOVAPS(blocks[0], base)
+               PSHUFB(bswap, blocks[0])
+
+               // Check whether any of these eight counters will overflow.
+               MOVQ(ivlo, tmp)
+               ADDQ(Imm(uint64(numBlocks-1)), tmp)
+               slowLabel := fmt.Sprintf("ctr%d_slow", numBlocks)
+               doneLabel := fmt.Sprintf("ctr%d_done", numBlocks)
+               JC(LabelRef(slowLabel))
+
+               // Fast branch: create an XMM increment vector containing the value 1.
+               // Adding it to the base counter yields each subsequent counter.
+               XORQ(tmp, tmp)
+               INCQ(tmp)
+               PXOR(addVec, addVec)
+               PINSRQ(Imm(0), tmp, addVec)
+
+               for i := 1; i < numBlocks; i++ {
+                       PADDQ(addVec, base)
+                       MOVAPS(base, blocks[i])
+               }
+               JMP(LabelRef(doneLabel))
+
+               Label(slowLabel)
+               ADDQ(Imm(1), ivlo)
+               ADCQ(Imm(0), ivhi)
+               for i := 1; i < numBlocks; i++ {
+                       MOVQ(ivlo, blocks[i])
+                       PINSRQ(Imm(1), ivhi, blocks[i])
+                       if i < numBlocks-1 {
+                               ADDQ(Imm(1), ivlo)
+                               ADCQ(Imm(0), ivhi)
+                       }
+               }
+
+               Label(doneLabel)
+
+               // Convert little-endian counters to big-endian after the branch since
+               // both paths share the same shuffle sequence.
+               for i := 1; i < numBlocks; i++ {
+                       PSHUFB(bswap, blocks[i])
+               }
+       } else {
+               // Lay out counter block plaintext.
+               for i := 0; i < numBlocks; i++ {
+                       x := XMM()
+                       blocks[i] = x
+
+                       MOVQ(ivlo, x)
+                       PINSRQ(Imm(1), ivhi, x)
+                       PSHUFB(bswap, x)
+                       if i < numBlocks-1 {
+                               ADDQ(Imm(1), ivlo)
+                               ADCQ(Imm(0), ivhi)
+                       }
                }
        }
 
index e6710834dd27e63256e725a7cb54910d167adc7e..deef3e7705a5b32614ebe79646a06b13abed21a9 100644 (file)
@@ -286,41 +286,68 @@ TEXT ·ctrBlocks8Asm(SB), $0-48
        MOVOU  bswapMask<>+0(SB), X0
        MOVQ   SI, X1
        PINSRQ $0x01, DI, X1
+       MOVAPS X1, X8
        PSHUFB X0, X1
+       MOVQ   SI, R8
+       ADDQ   $0x07, R8
+       JC     ctr8_slow
+       XORQ   R8, R8
+       INCQ   R8
+       PXOR   X9, X9
+       PINSRQ $0x00, R8, X9
+       PADDQ  X9, X8
+       MOVAPS X8, X2
+       PADDQ  X9, X8
+       MOVAPS X8, X3
+       PADDQ  X9, X8
+       MOVAPS X8, X4
+       PADDQ  X9, X8
+       MOVAPS X8, X5
+       PADDQ  X9, X8
+       MOVAPS X8, X6
+       PADDQ  X9, X8
+       MOVAPS X8, X7
+       PADDQ  X9, X8
+       MOVAPS X8, X8
+       JMP    ctr8_done
+
+ctr8_slow:
        ADDQ   $0x01, SI
        ADCQ   $0x00, DI
        MOVQ   SI, X2
        PINSRQ $0x01, DI, X2
-       PSHUFB X0, X2
        ADDQ   $0x01, SI
        ADCQ   $0x00, DI
        MOVQ   SI, X3
        PINSRQ $0x01, DI, X3
-       PSHUFB X0, X3
        ADDQ   $0x01, SI
        ADCQ   $0x00, DI
        MOVQ   SI, X4
        PINSRQ $0x01, DI, X4
-       PSHUFB X0, X4
        ADDQ   $0x01, SI
        ADCQ   $0x00, DI
        MOVQ   SI, X5
        PINSRQ $0x01, DI, X5
-       PSHUFB X0, X5
        ADDQ   $0x01, SI
        ADCQ   $0x00, DI
        MOVQ   SI, X6
        PINSRQ $0x01, DI, X6
-       PSHUFB X0, X6
        ADDQ   $0x01, SI
        ADCQ   $0x00, DI
        MOVQ   SI, X7
        PINSRQ $0x01, DI, X7
-       PSHUFB X0, X7
        ADDQ   $0x01, SI
        ADCQ   $0x00, DI
        MOVQ   SI, X8
        PINSRQ $0x01, DI, X8
+
+ctr8_done:
+       PSHUFB X0, X2
+       PSHUFB X0, X3
+       PSHUFB X0, X4
+       PSHUFB X0, X5
+       PSHUFB X0, X6
+       PSHUFB X0, X7
        PSHUFB X0, X8
        MOVUPS (CX), X0
        PXOR   X0, X1