]> Cypherpunks repositories - gostls13.git/commitdiff
runtime: save AVX2 and AVX-512 state on asynchronous preemption
authorAustin Clements <austin@google.com>
Thu, 12 Jun 2025 19:33:41 +0000 (15:33 -0400)
committerGopher Robot <gobot@golang.org>
Tue, 5 Aug 2025 21:00:15 +0000 (14:00 -0700)
Based on CL 669415 by shaojunyang@google.com.

This is a cherry-pick of CL 680900 from the dev.simd branch.

Change-Id: I574f15c3b18a7179a1573aaf567caf18d8602ef1
Reviewed-on: https://go-review.googlesource.com/c/go/+/693397
Reviewed-by: Cherry Mui <cherryyz@google.com>
Auto-Submit: Austin Clements <austin@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/runtime/cpuflags.go
src/runtime/mkpreempt.go
src/runtime/preempt_amd64.go
src/runtime/preempt_amd64.s

index bd1cb328d37b87e2f1097758888924a63cf59342..6452364b68ec3281dbed47c280865f2cbc9c40bb 100644 (file)
@@ -13,6 +13,7 @@ import (
 const (
        offsetX86HasAVX    = unsafe.Offsetof(cpu.X86.HasAVX)
        offsetX86HasAVX2   = unsafe.Offsetof(cpu.X86.HasAVX2)
+       offsetX86HasAVX512 = unsafe.Offsetof(cpu.X86.HasAVX512) // F+CD+BW+DQ+VL
        offsetX86HasERMS   = unsafe.Offsetof(cpu.X86.HasERMS)
        offsetX86HasRDTSCP = unsafe.Offsetof(cpu.X86.HasRDTSCP)
 
index 8fce40ccd8d04c96e7c6c37cf4f9368a6f0dff16..2bd2ef07fa82923332152641d58f9b9618972bcf 100644 (file)
@@ -285,7 +285,7 @@ func gen386(g *gen) {
 func genAMD64(g *gen) {
        const xReg = "AX" // *xRegState
 
-       p := g.p
+       p, label := g.p, g.label
 
        // Assign stack offsets.
        var l = layout{sp: "SP"}
@@ -297,15 +297,33 @@ func genAMD64(g *gen) {
                        l.add("MOVQ", reg, 8)
                }
        }
-       lXRegs := layout{sp: xReg} // Non-GP registers
-       for _, reg := range regNamesAMD64 {
-               if strings.HasPrefix(reg, "X") {
-                       lXRegs.add("MOVUPS", reg, 16)
+       // Create layouts for X, Y, and Z registers.
+       const (
+               numXRegs = 16
+               numZRegs = 16 // TODO: If we start using upper registers, change to 32
+               numKRegs = 8
+       )
+       lZRegs := layout{sp: xReg} // Non-GP registers
+       lXRegs, lYRegs := lZRegs, lZRegs
+       for i := range numZRegs {
+               lZRegs.add("VMOVDQU64", fmt.Sprintf("Z%d", i), 512/8)
+               if i < numXRegs {
+                       // Use SSE-only instructions for X registers.
+                       lXRegs.add("MOVUPS", fmt.Sprintf("X%d", i), 128/8)
+                       lYRegs.add("VMOVDQU", fmt.Sprintf("Y%d", i), 256/8)
                }
        }
-       writeXRegs(g.goarch, &lXRegs)
-
-       // TODO: MXCSR register?
+       for i := range numKRegs {
+               lZRegs.add("KMOVQ", fmt.Sprintf("K%d", i), 8)
+       }
+       // The Z layout is the most general, so we line up the others with that one.
+       // We don't have to do this, but it results in a nice Go type. If we split
+       // this into multiple types, we probably should stop doing this.
+       for i := range lXRegs.regs {
+               lXRegs.regs[i].pos = lZRegs.regs[i].pos
+               lYRegs.regs[i].pos = lZRegs.regs[i].pos
+       }
+       writeXRegs(g.goarch, &lZRegs)
 
        p("PUSHQ BP")
        p("MOVQ SP, BP")
@@ -333,16 +351,56 @@ func genAMD64(g *gen) {
        p("MOVQ g_m(R14), %s", xReg)
        p("MOVQ m_p(%s), %s", xReg, xReg)
        p("LEAQ (p_xRegs+xRegPerP_scratch)(%s), %s", xReg, xReg)
+
+       // Which registers do we need to save?
+       p("#ifdef GOEXPERIMENT_simd")
+       p("CMPB internal∕cpu·X86+const_offsetX86HasAVX512(SB), $1")
+       p("JE saveAVX512")
+       p("CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1")
+       p("JE saveAVX2")
+       p("#endif")
+
+       // No features. Assume only SSE.
+       label("saveSSE:")
        lXRegs.save(g)
+       p("JMP preempt")
 
+       label("saveAVX2:")
+       lYRegs.save(g)
+       p("JMP preempt")
+
+       label("saveAVX512:")
+       lZRegs.save(g)
+       p("JMP preempt")
+
+       label("preempt:")
        p("CALL ·asyncPreempt2(SB)")
 
        p("// Restore non-GPs from *p.xRegs.cache")
        p("MOVQ g_m(R14), %s", xReg)
        p("MOVQ m_p(%s), %s", xReg, xReg)
        p("MOVQ (p_xRegs+xRegPerP_cache)(%s), %s", xReg, xReg)
+
+       p("#ifdef GOEXPERIMENT_simd")
+       p("CMPB internal∕cpu·X86+const_offsetX86HasAVX512(SB), $1")
+       p("JE restoreAVX512")
+       p("CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1")
+       p("JE restoreAVX2")
+       p("#endif")
+
+       label("restoreSSE:")
        lXRegs.restore(g)
+       p("JMP restoreGPs")
+
+       label("restoreAVX2:")
+       lYRegs.restore(g)
+       p("JMP restoreGPs")
+
+       label("restoreAVX512:")
+       lZRegs.restore(g)
+       p("JMP restoreGPs")
 
+       label("restoreGPs:")
        p("// Restore GPs")
        l.restore(g)
        p("ADJSP $%d", -l.stack)
index 4a3ced79df3ec38a07f95932410f71fe42129e4f..88c0ddd34ade72034356e4d65e4e32a32a44785a 100644 (file)
@@ -3,20 +3,28 @@
 package runtime
 
 type xRegs struct {
-       X0  [16]byte
-       X1  [16]byte
-       X2  [16]byte
-       X3  [16]byte
-       X4  [16]byte
-       X5  [16]byte
-       X6  [16]byte
-       X7  [16]byte
-       X8  [16]byte
-       X9  [16]byte
-       X10 [16]byte
-       X11 [16]byte
-       X12 [16]byte
-       X13 [16]byte
-       X14 [16]byte
-       X15 [16]byte
+       Z0  [64]byte
+       Z1  [64]byte
+       Z2  [64]byte
+       Z3  [64]byte
+       Z4  [64]byte
+       Z5  [64]byte
+       Z6  [64]byte
+       Z7  [64]byte
+       Z8  [64]byte
+       Z9  [64]byte
+       Z10 [64]byte
+       Z11 [64]byte
+       Z12 [64]byte
+       Z13 [64]byte
+       Z14 [64]byte
+       Z15 [64]byte
+       K0  uint64
+       K1  uint64
+       K2  uint64
+       K3  uint64
+       K4  uint64
+       K5  uint64
+       K6  uint64
+       K7  uint64
 }
index 0a33ce7f3e72f5c8d935edb69e3d1cdc0eb3e5e6..c35de7f3b75726f5ce67b97c6cdb83072cd5c6c9 100644 (file)
@@ -36,43 +36,149 @@ TEXT ·asyncPreempt(SB),NOSPLIT|NOFRAME,$0-0
        MOVQ g_m(R14), AX
        MOVQ m_p(AX), AX
        LEAQ (p_xRegs+xRegPerP_scratch)(AX), AX
+       #ifdef GOEXPERIMENT_simd
+       CMPB internal∕cpu·X86+const_offsetX86HasAVX512(SB), $1
+       JE saveAVX512
+       CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1
+       JE saveAVX2
+       #endif
+saveSSE:
        MOVUPS X0, 0(AX)
-       MOVUPS X1, 16(AX)
-       MOVUPS X2, 32(AX)
-       MOVUPS X3, 48(AX)
-       MOVUPS X4, 64(AX)
-       MOVUPS X5, 80(AX)
-       MOVUPS X6, 96(AX)
-       MOVUPS X7, 112(AX)
-       MOVUPS X8, 128(AX)
-       MOVUPS X9, 144(AX)
-       MOVUPS X10, 160(AX)
-       MOVUPS X11, 176(AX)
-       MOVUPS X12, 192(AX)
-       MOVUPS X13, 208(AX)
-       MOVUPS X14, 224(AX)
-       MOVUPS X15, 240(AX)
+       MOVUPS X1, 64(AX)
+       MOVUPS X2, 128(AX)
+       MOVUPS X3, 192(AX)
+       MOVUPS X4, 256(AX)
+       MOVUPS X5, 320(AX)
+       MOVUPS X6, 384(AX)
+       MOVUPS X7, 448(AX)
+       MOVUPS X8, 512(AX)
+       MOVUPS X9, 576(AX)
+       MOVUPS X10, 640(AX)
+       MOVUPS X11, 704(AX)
+       MOVUPS X12, 768(AX)
+       MOVUPS X13, 832(AX)
+       MOVUPS X14, 896(AX)
+       MOVUPS X15, 960(AX)
+       JMP preempt
+saveAVX2:
+       VMOVDQU Y0, 0(AX)
+       VMOVDQU Y1, 64(AX)
+       VMOVDQU Y2, 128(AX)
+       VMOVDQU Y3, 192(AX)
+       VMOVDQU Y4, 256(AX)
+       VMOVDQU Y5, 320(AX)
+       VMOVDQU Y6, 384(AX)
+       VMOVDQU Y7, 448(AX)
+       VMOVDQU Y8, 512(AX)
+       VMOVDQU Y9, 576(AX)
+       VMOVDQU Y10, 640(AX)
+       VMOVDQU Y11, 704(AX)
+       VMOVDQU Y12, 768(AX)
+       VMOVDQU Y13, 832(AX)
+       VMOVDQU Y14, 896(AX)
+       VMOVDQU Y15, 960(AX)
+       JMP preempt
+saveAVX512:
+       VMOVDQU64 Z0, 0(AX)
+       VMOVDQU64 Z1, 64(AX)
+       VMOVDQU64 Z2, 128(AX)
+       VMOVDQU64 Z3, 192(AX)
+       VMOVDQU64 Z4, 256(AX)
+       VMOVDQU64 Z5, 320(AX)
+       VMOVDQU64 Z6, 384(AX)
+       VMOVDQU64 Z7, 448(AX)
+       VMOVDQU64 Z8, 512(AX)
+       VMOVDQU64 Z9, 576(AX)
+       VMOVDQU64 Z10, 640(AX)
+       VMOVDQU64 Z11, 704(AX)
+       VMOVDQU64 Z12, 768(AX)
+       VMOVDQU64 Z13, 832(AX)
+       VMOVDQU64 Z14, 896(AX)
+       VMOVDQU64 Z15, 960(AX)
+       KMOVQ K0, 1024(AX)
+       KMOVQ K1, 1032(AX)
+       KMOVQ K2, 1040(AX)
+       KMOVQ K3, 1048(AX)
+       KMOVQ K4, 1056(AX)
+       KMOVQ K5, 1064(AX)
+       KMOVQ K6, 1072(AX)
+       KMOVQ K7, 1080(AX)
+       JMP preempt
+preempt:
        CALL ·asyncPreempt2(SB)
        // Restore non-GPs from *p.xRegs.cache
        MOVQ g_m(R14), AX
        MOVQ m_p(AX), AX
        MOVQ (p_xRegs+xRegPerP_cache)(AX), AX
-       MOVUPS 240(AX), X15
-       MOVUPS 224(AX), X14
-       MOVUPS 208(AX), X13
-       MOVUPS 192(AX), X12
-       MOVUPS 176(AX), X11
-       MOVUPS 160(AX), X10
-       MOVUPS 144(AX), X9
-       MOVUPS 128(AX), X8
-       MOVUPS 112(AX), X7
-       MOVUPS 96(AX), X6
-       MOVUPS 80(AX), X5
-       MOVUPS 64(AX), X4
-       MOVUPS 48(AX), X3
-       MOVUPS 32(AX), X2
-       MOVUPS 16(AX), X1
+       #ifdef GOEXPERIMENT_simd
+       CMPB internal∕cpu·X86+const_offsetX86HasAVX512(SB), $1
+       JE restoreAVX512
+       CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1
+       JE restoreAVX2
+       #endif
+restoreSSE:
+       MOVUPS 960(AX), X15
+       MOVUPS 896(AX), X14
+       MOVUPS 832(AX), X13
+       MOVUPS 768(AX), X12
+       MOVUPS 704(AX), X11
+       MOVUPS 640(AX), X10
+       MOVUPS 576(AX), X9
+       MOVUPS 512(AX), X8
+       MOVUPS 448(AX), X7
+       MOVUPS 384(AX), X6
+       MOVUPS 320(AX), X5
+       MOVUPS 256(AX), X4
+       MOVUPS 192(AX), X3
+       MOVUPS 128(AX), X2
+       MOVUPS 64(AX), X1
        MOVUPS 0(AX), X0
+       JMP restoreGPs
+restoreAVX2:
+       VMOVDQU 960(AX), Y15
+       VMOVDQU 896(AX), Y14
+       VMOVDQU 832(AX), Y13
+       VMOVDQU 768(AX), Y12
+       VMOVDQU 704(AX), Y11
+       VMOVDQU 640(AX), Y10
+       VMOVDQU 576(AX), Y9
+       VMOVDQU 512(AX), Y8
+       VMOVDQU 448(AX), Y7
+       VMOVDQU 384(AX), Y6
+       VMOVDQU 320(AX), Y5
+       VMOVDQU 256(AX), Y4
+       VMOVDQU 192(AX), Y3
+       VMOVDQU 128(AX), Y2
+       VMOVDQU 64(AX), Y1
+       VMOVDQU 0(AX), Y0
+       JMP restoreGPs
+restoreAVX512:
+       KMOVQ 1080(AX), K7
+       KMOVQ 1072(AX), K6
+       KMOVQ 1064(AX), K5
+       KMOVQ 1056(AX), K4
+       KMOVQ 1048(AX), K3
+       KMOVQ 1040(AX), K2
+       KMOVQ 1032(AX), K1
+       KMOVQ 1024(AX), K0
+       VMOVDQU64 960(AX), Z15
+       VMOVDQU64 896(AX), Z14
+       VMOVDQU64 832(AX), Z13
+       VMOVDQU64 768(AX), Z12
+       VMOVDQU64 704(AX), Z11
+       VMOVDQU64 640(AX), Z10
+       VMOVDQU64 576(AX), Z9
+       VMOVDQU64 512(AX), Z8
+       VMOVDQU64 448(AX), Z7
+       VMOVDQU64 384(AX), Z6
+       VMOVDQU64 320(AX), Z5
+       VMOVDQU64 256(AX), Z4
+       VMOVDQU64 192(AX), Z3
+       VMOVDQU64 128(AX), Z2
+       VMOVDQU64 64(AX), Z1
+       VMOVDQU64 0(AX), Z0
+       JMP restoreGPs
+restoreGPs:
        // Restore GPs
        MOVQ 104(SP), R15
        MOVQ 96(SP), R14