From 55d961b202de263d6fba73a60d01d142f1d66dc6 Mon Sep 17 00:00:00 2001
From: Austin Clements
Date: Thu, 12 Jun 2025 15:33:41 -0400
Subject: [PATCH] runtime: save AVX2 and AVX-512 state on asynchronous
 preemption

Based on CL 669415 by shaojunyang@google.com.

This is a cherry-pick of CL 680900 from the dev.simd branch.

Change-Id: I574f15c3b18a7179a1573aaf567caf18d8602ef1
Reviewed-on: https://go-review.googlesource.com/c/go/+/693397
Reviewed-by: Cherry Mui
Auto-Submit: Austin Clements
LUCI-TryBot-Result: Go LUCI
---
 src/runtime/cpuflags.go      |   1 +
 src/runtime/mkpreempt.go     |  74 ++++++++++++++--
 src/runtime/preempt_amd64.go |  40 +++++----
 src/runtime/preempt_amd64.s  | 166 ++++++++++++++++++++++++++++-------
 4 files changed, 227 insertions(+), 54 deletions(-)

diff --git a/src/runtime/cpuflags.go b/src/runtime/cpuflags.go
index bd1cb328d3..6452364b68 100644
--- a/src/runtime/cpuflags.go
+++ b/src/runtime/cpuflags.go
@@ -13,6 +13,7 @@ import (
 const (
 	offsetX86HasAVX    = unsafe.Offsetof(cpu.X86.HasAVX)
 	offsetX86HasAVX2   = unsafe.Offsetof(cpu.X86.HasAVX2)
+	offsetX86HasAVX512 = unsafe.Offsetof(cpu.X86.HasAVX512) // F+CD+BW+DQ+VL
 	offsetX86HasERMS   = unsafe.Offsetof(cpu.X86.HasERMS)
 	offsetX86HasRDTSCP = unsafe.Offsetof(cpu.X86.HasRDTSCP)
 
diff --git a/src/runtime/mkpreempt.go b/src/runtime/mkpreempt.go
index 8fce40ccd8..2bd2ef07fa 100644
--- a/src/runtime/mkpreempt.go
+++ b/src/runtime/mkpreempt.go
@@ -285,7 +285,7 @@ func gen386(g *gen) {
 func genAMD64(g *gen) {
 	const xReg = "AX" // *xRegState
 
-	p := g.p
+	p, label := g.p, g.label
 
 	// Assign stack offsets.
 	var l = layout{sp: "SP"}
@@ -297,15 +297,33 @@ func genAMD64(g *gen) {
 			l.add("MOVQ", reg, 8)
 		}
 	}
-	lXRegs := layout{sp: xReg} // Non-GP registers
-	for _, reg := range regNamesAMD64 {
-		if strings.HasPrefix(reg, "X") {
-			lXRegs.add("MOVUPS", reg, 16)
+	// Create layouts for X, Y, and Z registers.
+	const (
+		numXRegs = 16
+		numZRegs = 16 // TODO: If we start using upper registers, change to 32
+		numKRegs = 8
+	)
+	lZRegs := layout{sp: xReg} // Non-GP registers
+	lXRegs, lYRegs := lZRegs, lZRegs
+	for i := range numZRegs {
+		lZRegs.add("VMOVDQU64", fmt.Sprintf("Z%d", i), 512/8)
+		if i < numXRegs {
+			// Use SSE-only instructions for X registers.
+			lXRegs.add("MOVUPS", fmt.Sprintf("X%d", i), 128/8)
+			lYRegs.add("VMOVDQU", fmt.Sprintf("Y%d", i), 256/8)
 		}
 	}
-	writeXRegs(g.goarch, &lXRegs)
-
-	// TODO: MXCSR register?
+	for i := range numKRegs {
+		lZRegs.add("KMOVQ", fmt.Sprintf("K%d", i), 8)
+	}
+	// The Z layout is the most general, so we line up the others with that one.
+	// We don't have to do this, but it results in a nice Go type. If we split
+	// this into multiple types, we probably should stop doing this.
+	for i := range lXRegs.regs {
+		lXRegs.regs[i].pos = lZRegs.regs[i].pos
+		lYRegs.regs[i].pos = lZRegs.regs[i].pos
+	}
+	writeXRegs(g.goarch, &lZRegs)
 
 	p("PUSHQ BP")
 	p("MOVQ SP, BP")
@@ -333,16 +351,56 @@ func genAMD64(g *gen) {
 	p("MOVQ g_m(R14), %s", xReg)
 	p("MOVQ m_p(%s), %s", xReg, xReg)
 	p("LEAQ (p_xRegs+xRegPerP_scratch)(%s), %s", xReg, xReg)
+
+	// Which registers do we need to save?
+	p("#ifdef GOEXPERIMENT_simd")
+	p("CMPB internal∕cpu·X86+const_offsetX86HasAVX512(SB), $1")
+	p("JE saveAVX512")
+	p("CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1")
+	p("JE saveAVX2")
+	p("#endif")
+
+	// No features. Assume only SSE.
+	label("saveSSE:")
 	lXRegs.save(g)
+	p("JMP preempt")
+	label("saveAVX2:")
+	lYRegs.save(g)
+	p("JMP preempt")
+
+	label("saveAVX512:")
+	lZRegs.save(g)
+	p("JMP preempt")
+
+	label("preempt:")
 	p("CALL ·asyncPreempt2(SB)")
 
 	p("// Restore non-GPs from *p.xRegs.cache")
 	p("MOVQ g_m(R14), %s", xReg)
 	p("MOVQ m_p(%s), %s", xReg, xReg)
 	p("MOVQ (p_xRegs+xRegPerP_cache)(%s), %s", xReg, xReg)
+
+	p("#ifdef GOEXPERIMENT_simd")
+	p("CMPB internal∕cpu·X86+const_offsetX86HasAVX512(SB), $1")
+	p("JE restoreAVX512")
+	p("CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1")
+	p("JE restoreAVX2")
+	p("#endif")
+
+	label("restoreSSE:")
 	lXRegs.restore(g)
+	p("JMP restoreGPs")
+
+	label("restoreAVX2:")
+	lYRegs.restore(g)
+	p("JMP restoreGPs")
+
+	label("restoreAVX512:")
+	lZRegs.restore(g)
+	p("JMP restoreGPs")
+	label("restoreGPs:")
 	p("// Restore GPs")
 	l.restore(g)
 	p("ADJSP $%d", -l.stack)
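
The generator change above hinges on one layout decision: every vector register gets a ZMM-sized 64-byte slot, and the X and Y layouts are lined up with the Z layout rather than packed. A standalone sketch of the resulting offset arithmetic (illustrative only, not part of the CL):

package main

import "fmt"

func main() {
	// Each of the 16 vector registers gets a 64-byte (ZMM-sized) slot;
	// SSE and AVX2 saves simply use the low 16 or 32 bytes of that slot.
	off := 0
	for i := 0; i < 16; i++ {
		fmt.Printf("X%d/Y%d/Z%d slot at %4d\n", i, i, i, off)
		off += 512 / 8 // 64 bytes, the width of a Z register
	}
	// The 8 AVX-512 opmask registers follow, 8 bytes each.
	for i := 0; i < 8; i++ {
		fmt.Printf("K%d slot at %4d\n", i, off)
		off += 8
	}
	fmt.Printf("total save area: %d bytes\n", off) // 1088
}

Aligning all three layouts wastes 48 or 32 bytes per slot when only SSE or AVX2 state is live, but it means a single generated Go type and a single set of assembly offsets describe every save path, as the comment in the generator notes.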
+ label("saveSSE:") lXRegs.save(g) + p("JMP preempt") + label("saveAVX2:") + lYRegs.save(g) + p("JMP preempt") + + label("saveAVX512:") + lZRegs.save(g) + p("JMP preempt") + + label("preempt:") p("CALL ·asyncPreempt2(SB)") p("// Restore non-GPs from *p.xRegs.cache") p("MOVQ g_m(R14), %s", xReg) p("MOVQ m_p(%s), %s", xReg, xReg) p("MOVQ (p_xRegs+xRegPerP_cache)(%s), %s", xReg, xReg) + + p("#ifdef GOEXPERIMENT_simd") + p("CMPB internal∕cpu·X86+const_offsetX86HasAVX512(SB), $1") + p("JE restoreAVX512") + p("CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1") + p("JE restoreAVX2") + p("#endif") + + label("restoreSSE:") lXRegs.restore(g) + p("JMP restoreGPs") + + label("restoreAVX2:") + lYRegs.restore(g) + p("JMP restoreGPs") + + label("restoreAVX512:") + lZRegs.restore(g) + p("JMP restoreGPs") + label("restoreGPs:") p("// Restore GPs") l.restore(g) p("ADJSP $%d", -l.stack) diff --git a/src/runtime/preempt_amd64.go b/src/runtime/preempt_amd64.go index 4a3ced79df..88c0ddd34a 100644 --- a/src/runtime/preempt_amd64.go +++ b/src/runtime/preempt_amd64.go @@ -3,20 +3,28 @@ package runtime type xRegs struct { - X0 [16]byte - X1 [16]byte - X2 [16]byte - X3 [16]byte - X4 [16]byte - X5 [16]byte - X6 [16]byte - X7 [16]byte - X8 [16]byte - X9 [16]byte - X10 [16]byte - X11 [16]byte - X12 [16]byte - X13 [16]byte - X14 [16]byte - X15 [16]byte + Z0 [64]byte + Z1 [64]byte + Z2 [64]byte + Z3 [64]byte + Z4 [64]byte + Z5 [64]byte + Z6 [64]byte + Z7 [64]byte + Z8 [64]byte + Z9 [64]byte + Z10 [64]byte + Z11 [64]byte + Z12 [64]byte + Z13 [64]byte + Z14 [64]byte + Z15 [64]byte + K0 uint64 + K1 uint64 + K2 uint64 + K3 uint64 + K4 uint64 + K5 uint64 + K6 uint64 + K7 uint64 } diff --git a/src/runtime/preempt_amd64.s b/src/runtime/preempt_amd64.s index 0a33ce7f3e..c35de7f3b7 100644 --- a/src/runtime/preempt_amd64.s +++ b/src/runtime/preempt_amd64.s @@ -36,43 +36,149 @@ TEXT ·asyncPreempt(SB),NOSPLIT|NOFRAME,$0-0 MOVQ g_m(R14), AX MOVQ m_p(AX), AX LEAQ (p_xRegs+xRegPerP_scratch)(AX), AX + #ifdef GOEXPERIMENT_simd + CMPB internal∕cpu·X86+const_offsetX86HasAVX512(SB), $1 + JE saveAVX512 + CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1 + JE saveAVX2 + #endif +saveSSE: MOVUPS X0, 0(AX) - MOVUPS X1, 16(AX) - MOVUPS X2, 32(AX) - MOVUPS X3, 48(AX) - MOVUPS X4, 64(AX) - MOVUPS X5, 80(AX) - MOVUPS X6, 96(AX) - MOVUPS X7, 112(AX) - MOVUPS X8, 128(AX) - MOVUPS X9, 144(AX) - MOVUPS X10, 160(AX) - MOVUPS X11, 176(AX) - MOVUPS X12, 192(AX) - MOVUPS X13, 208(AX) - MOVUPS X14, 224(AX) - MOVUPS X15, 240(AX) + MOVUPS X1, 64(AX) + MOVUPS X2, 128(AX) + MOVUPS X3, 192(AX) + MOVUPS X4, 256(AX) + MOVUPS X5, 320(AX) + MOVUPS X6, 384(AX) + MOVUPS X7, 448(AX) + MOVUPS X8, 512(AX) + MOVUPS X9, 576(AX) + MOVUPS X10, 640(AX) + MOVUPS X11, 704(AX) + MOVUPS X12, 768(AX) + MOVUPS X13, 832(AX) + MOVUPS X14, 896(AX) + MOVUPS X15, 960(AX) + JMP preempt +saveAVX2: + VMOVDQU Y0, 0(AX) + VMOVDQU Y1, 64(AX) + VMOVDQU Y2, 128(AX) + VMOVDQU Y3, 192(AX) + VMOVDQU Y4, 256(AX) + VMOVDQU Y5, 320(AX) + VMOVDQU Y6, 384(AX) + VMOVDQU Y7, 448(AX) + VMOVDQU Y8, 512(AX) + VMOVDQU Y9, 576(AX) + VMOVDQU Y10, 640(AX) + VMOVDQU Y11, 704(AX) + VMOVDQU Y12, 768(AX) + VMOVDQU Y13, 832(AX) + VMOVDQU Y14, 896(AX) + VMOVDQU Y15, 960(AX) + JMP preempt +saveAVX512: + VMOVDQU64 Z0, 0(AX) + VMOVDQU64 Z1, 64(AX) + VMOVDQU64 Z2, 128(AX) + VMOVDQU64 Z3, 192(AX) + VMOVDQU64 Z4, 256(AX) + VMOVDQU64 Z5, 320(AX) + VMOVDQU64 Z6, 384(AX) + VMOVDQU64 Z7, 448(AX) + VMOVDQU64 Z8, 512(AX) + VMOVDQU64 Z9, 576(AX) + VMOVDQU64 Z10, 640(AX) + VMOVDQU64 Z11, 704(AX) + VMOVDQU64 Z12, 768(AX) + 
diff --git a/src/runtime/preempt_amd64.s b/src/runtime/preempt_amd64.s
index 0a33ce7f3e..c35de7f3b7 100644
--- a/src/runtime/preempt_amd64.s
+++ b/src/runtime/preempt_amd64.s
@@ -36,43 +36,149 @@ TEXT ·asyncPreempt(SB),NOSPLIT|NOFRAME,$0-0
 	MOVQ g_m(R14), AX
 	MOVQ m_p(AX), AX
 	LEAQ (p_xRegs+xRegPerP_scratch)(AX), AX
+	#ifdef GOEXPERIMENT_simd
+	CMPB internal∕cpu·X86+const_offsetX86HasAVX512(SB), $1
+	JE saveAVX512
+	CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1
+	JE saveAVX2
+	#endif
+saveSSE:
 	MOVUPS X0, 0(AX)
-	MOVUPS X1, 16(AX)
-	MOVUPS X2, 32(AX)
-	MOVUPS X3, 48(AX)
-	MOVUPS X4, 64(AX)
-	MOVUPS X5, 80(AX)
-	MOVUPS X6, 96(AX)
-	MOVUPS X7, 112(AX)
-	MOVUPS X8, 128(AX)
-	MOVUPS X9, 144(AX)
-	MOVUPS X10, 160(AX)
-	MOVUPS X11, 176(AX)
-	MOVUPS X12, 192(AX)
-	MOVUPS X13, 208(AX)
-	MOVUPS X14, 224(AX)
-	MOVUPS X15, 240(AX)
+	MOVUPS X1, 64(AX)
+	MOVUPS X2, 128(AX)
+	MOVUPS X3, 192(AX)
+	MOVUPS X4, 256(AX)
+	MOVUPS X5, 320(AX)
+	MOVUPS X6, 384(AX)
+	MOVUPS X7, 448(AX)
+	MOVUPS X8, 512(AX)
+	MOVUPS X9, 576(AX)
+	MOVUPS X10, 640(AX)
+	MOVUPS X11, 704(AX)
+	MOVUPS X12, 768(AX)
+	MOVUPS X13, 832(AX)
+	MOVUPS X14, 896(AX)
+	MOVUPS X15, 960(AX)
+	JMP preempt
+saveAVX2:
+	VMOVDQU Y0, 0(AX)
+	VMOVDQU Y1, 64(AX)
+	VMOVDQU Y2, 128(AX)
+	VMOVDQU Y3, 192(AX)
+	VMOVDQU Y4, 256(AX)
+	VMOVDQU Y5, 320(AX)
+	VMOVDQU Y6, 384(AX)
+	VMOVDQU Y7, 448(AX)
+	VMOVDQU Y8, 512(AX)
+	VMOVDQU Y9, 576(AX)
+	VMOVDQU Y10, 640(AX)
+	VMOVDQU Y11, 704(AX)
+	VMOVDQU Y12, 768(AX)
+	VMOVDQU Y13, 832(AX)
+	VMOVDQU Y14, 896(AX)
+	VMOVDQU Y15, 960(AX)
+	JMP preempt
+saveAVX512:
+	VMOVDQU64 Z0, 0(AX)
+	VMOVDQU64 Z1, 64(AX)
+	VMOVDQU64 Z2, 128(AX)
+	VMOVDQU64 Z3, 192(AX)
+	VMOVDQU64 Z4, 256(AX)
+	VMOVDQU64 Z5, 320(AX)
+	VMOVDQU64 Z6, 384(AX)
+	VMOVDQU64 Z7, 448(AX)
+	VMOVDQU64 Z8, 512(AX)
+	VMOVDQU64 Z9, 576(AX)
+	VMOVDQU64 Z10, 640(AX)
+	VMOVDQU64 Z11, 704(AX)
+	VMOVDQU64 Z12, 768(AX)
+	VMOVDQU64 Z13, 832(AX)
+	VMOVDQU64 Z14, 896(AX)
+	VMOVDQU64 Z15, 960(AX)
+	KMOVQ K0, 1024(AX)
+	KMOVQ K1, 1032(AX)
+	KMOVQ K2, 1040(AX)
+	KMOVQ K3, 1048(AX)
+	KMOVQ K4, 1056(AX)
+	KMOVQ K5, 1064(AX)
+	KMOVQ K6, 1072(AX)
+	KMOVQ K7, 1080(AX)
+	JMP preempt
+preempt:
 	CALL ·asyncPreempt2(SB)
 	// Restore non-GPs from *p.xRegs.cache
 	MOVQ g_m(R14), AX
 	MOVQ m_p(AX), AX
 	MOVQ (p_xRegs+xRegPerP_cache)(AX), AX
-	MOVUPS 240(AX), X15
-	MOVUPS 224(AX), X14
-	MOVUPS 208(AX), X13
-	MOVUPS 192(AX), X12
-	MOVUPS 176(AX), X11
-	MOVUPS 160(AX), X10
-	MOVUPS 144(AX), X9
-	MOVUPS 128(AX), X8
-	MOVUPS 112(AX), X7
-	MOVUPS 96(AX), X6
-	MOVUPS 80(AX), X5
-	MOVUPS 64(AX), X4
-	MOVUPS 48(AX), X3
-	MOVUPS 32(AX), X2
-	MOVUPS 16(AX), X1
+	#ifdef GOEXPERIMENT_simd
+	CMPB internal∕cpu·X86+const_offsetX86HasAVX512(SB), $1
+	JE restoreAVX512
+	CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1
+	JE restoreAVX2
+	#endif
+restoreSSE:
+	MOVUPS 960(AX), X15
+	MOVUPS 896(AX), X14
+	MOVUPS 832(AX), X13
+	MOVUPS 768(AX), X12
+	MOVUPS 704(AX), X11
+	MOVUPS 640(AX), X10
+	MOVUPS 576(AX), X9
+	MOVUPS 512(AX), X8
+	MOVUPS 448(AX), X7
+	MOVUPS 384(AX), X6
+	MOVUPS 320(AX), X5
+	MOVUPS 256(AX), X4
+	MOVUPS 192(AX), X3
+	MOVUPS 128(AX), X2
+	MOVUPS 64(AX), X1
 	MOVUPS 0(AX), X0
+	JMP restoreGPs
+restoreAVX2:
+	VMOVDQU 960(AX), Y15
+	VMOVDQU 896(AX), Y14
+	VMOVDQU 832(AX), Y13
+	VMOVDQU 768(AX), Y12
+	VMOVDQU 704(AX), Y11
+	VMOVDQU 640(AX), Y10
+	VMOVDQU 576(AX), Y9
+	VMOVDQU 512(AX), Y8
+	VMOVDQU 448(AX), Y7
+	VMOVDQU 384(AX), Y6
+	VMOVDQU 320(AX), Y5
+	VMOVDQU 256(AX), Y4
+	VMOVDQU 192(AX), Y3
+	VMOVDQU 128(AX), Y2
+	VMOVDQU 64(AX), Y1
+	VMOVDQU 0(AX), Y0
+	JMP restoreGPs
+restoreAVX512:
+	KMOVQ 1080(AX), K7
+	KMOVQ 1072(AX), K6
+	KMOVQ 1064(AX), K5
+	KMOVQ 1056(AX), K4
+	KMOVQ 1048(AX), K3
+	KMOVQ 1040(AX), K2
+	KMOVQ 1032(AX), K1
+	KMOVQ 1024(AX), K0
+	VMOVDQU64 960(AX), Z15
+	VMOVDQU64 896(AX), Z14
+	VMOVDQU64 832(AX), Z13
+	VMOVDQU64 768(AX), Z12
+	VMOVDQU64 704(AX), Z11
+	VMOVDQU64 640(AX), Z10
+	VMOVDQU64 576(AX), Z9
+	VMOVDQU64 512(AX), Z8
+	VMOVDQU64 448(AX), Z7
+	VMOVDQU64 384(AX), Z6
+	VMOVDQU64 320(AX), Z5
+	VMOVDQU64 256(AX), Z4
+	VMOVDQU64 192(AX), Z3
+	VMOVDQU64 128(AX), Z2
+	VMOVDQU64 64(AX), Z1
+	VMOVDQU64 0(AX), Z0
+	JMP restoreGPs
+restoreGPs:
 	// Restore GPs
 	MOVQ 104(SP), R15
 	MOVQ 96(SP), R14
-- 
2.51.0
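
The generated assembly above reduces to a three-way dispatch on CPU features, performed identically on the save and restore sides so both always pick the same path. Roughly the same logic in plain Go, sketched with golang.org/x/sys/cpu since internal/cpu is not importable outside the runtime (the runtime's single HasAVX512 bit stands for F+CD+BW+DQ+VL, per the cpuflags.go comment):

package main

import (
	"fmt"

	"golang.org/x/sys/cpu"
)

// widestSave names the save/restore path asyncPreempt would take on this
// CPU, mirroring the CMPB/JE ladder emitted under GOEXPERIMENT_simd.
func widestSave() string {
	hasAVX512 := cpu.X86.HasAVX512F && cpu.X86.HasAVX512CD &&
		cpu.X86.HasAVX512BW && cpu.X86.HasAVX512DQ && cpu.X86.HasAVX512VL
	switch {
	case hasAVX512:
		return "saveAVX512" // Z0-Z15 plus K0-K7
	case cpu.X86.HasAVX2:
		return "saveAVX2" // Y0-Y15
	default:
		return "saveSSE" // X0-X15 only
	}
}

func main() { fmt.Println(widestSave()) }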