From f5f42753ab7653fea7b3e4ae9f0c5cf72c8b6a47 Mon Sep 17 00:00:00 2001 From: Junyang Shao Date: Mon, 14 Jul 2025 17:23:19 +0000 Subject: [PATCH] [dev.simd] cmd/compile, simd: add VDPPS This CL is generated by CL 687915. Change-Id: I1a2fb031c086b2b23fd135c48f8494ba5122493a Reviewed-on: https://go-review.googlesource.com/c/go/+/687916 LUCI-TryBot-Result: Go LUCI Reviewed-by: David Chase --- src/cmd/compile/internal/amd64/simdssa.go | 4 +- .../compile/internal/ssa/_gen/simdAMD64.rules | 2 + .../compile/internal/ssa/_gen/simdAMD64ops.go | 2 + .../internal/ssa/_gen/simdgenericOps.go | 2 + src/cmd/compile/internal/ssa/opGen.go | 48 +++++++++++++++++++ src/cmd/compile/internal/ssa/rewriteAMD64.go | 32 +++++++++++++ .../compile/internal/ssagen/simdintrinsics.go | 2 + src/simd/ops_amd64.go | 10 ++++ src/simd/simd_wrapped_test.go | 4 ++ 9 files changed, 105 insertions(+), 1 deletion(-) diff --git a/src/cmd/compile/internal/amd64/simdssa.go b/src/cmd/compile/internal/amd64/simdssa.go index e2d0dd17c6..0ebb955acc 100644 --- a/src/cmd/compile/internal/amd64/simdssa.go +++ b/src/cmd/compile/internal/amd64/simdssa.go @@ -650,7 +650,9 @@ func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool { ssa.OpAMD64VPRORQMasked512: p = simdVkvImm8(s, v) - case ssa.OpAMD64VDPPD128, + case ssa.OpAMD64VDPPS128, + ssa.OpAMD64VDPPS256, + ssa.OpAMD64VDPPD128, ssa.OpAMD64VCMPPS128, ssa.OpAMD64VCMPPS256, ssa.OpAMD64VCMPPD128, diff --git a/src/cmd/compile/internal/ssa/_gen/simdAMD64.rules b/src/cmd/compile/internal/ssa/_gen/simdAMD64.rules index 6043edad70..0cbca8bf72 100644 --- a/src/cmd/compile/internal/ssa/_gen/simdAMD64.rules +++ b/src/cmd/compile/internal/ssa/_gen/simdAMD64.rules @@ -264,6 +264,8 @@ (DivMaskedFloat64x2 x y mask) => (VDIVPDMasked128 x y (VPMOVVec64x2ToM mask)) (DivMaskedFloat64x4 x y mask) => (VDIVPDMasked256 x y (VPMOVVec64x4ToM mask)) (DivMaskedFloat64x8 x y mask) => (VDIVPDMasked512 x y (VPMOVVec64x8ToM mask)) +(DotProdBroadcastFloat32x4 x y) => (VDPPS128 [127] x y) +(DotProdBroadcastFloat32x8 x y) => (VDPPS256 [127] x y) (DotProdBroadcastFloat64x2 x y) => (VDPPD128 [127] x y) (EqualFloat32x4 x y) => (VCMPPS128 [0] x y) (EqualFloat32x8 x y) => (VCMPPS256 [0] x y) diff --git a/src/cmd/compile/internal/ssa/_gen/simdAMD64ops.go b/src/cmd/compile/internal/ssa/_gen/simdAMD64ops.go index 3f777db5b7..6985daa04b 100644 --- a/src/cmd/compile/internal/ssa/_gen/simdAMD64ops.go +++ b/src/cmd/compile/internal/ssa/_gen/simdAMD64ops.go @@ -736,6 +736,7 @@ func simdAMD64Ops(v11, v21, v2k, vkv, v2kv, v2kk, v31, v3kv, vgpv, vgp, vfpv, vf {name: "VRNDSCALEPSMasked128", argLength: 2, reg: wkw, asm: "VRNDSCALEPS", aux: "Int8", commutative: false, typ: "Vec128", resultInArg0: false}, {name: "VREDUCEPS128", argLength: 1, reg: w11, asm: "VREDUCEPS", aux: "Int8", commutative: false, typ: "Vec128", resultInArg0: false}, {name: "VREDUCEPSMasked128", argLength: 2, reg: wkw, asm: "VREDUCEPS", aux: "Int8", commutative: false, typ: "Vec128", resultInArg0: false}, + {name: "VDPPS128", argLength: 2, reg: v21, asm: "VDPPS", aux: "Int8", commutative: true, typ: "Vec128", resultInArg0: false}, {name: "VCMPPS128", argLength: 2, reg: v21, asm: "VCMPPS", aux: "Int8", commutative: true, typ: "Vec128", resultInArg0: false}, {name: "VCMPPSMasked128", argLength: 3, reg: w2kk, asm: "VCMPPS", aux: "Int8", commutative: true, typ: "Mask", resultInArg0: false}, {name: "VROUNDPS256", argLength: 1, reg: v11, asm: "VROUNDPS", aux: "Int8", commutative: false, typ: "Vec256", resultInArg0: false}, @@ -743,6 +744,7 @@ func simdAMD64Ops(v11, v21, v2k, vkv, v2kv, v2kk, v31, v3kv, vgpv, vgp, vfpv, vf {name: "VRNDSCALEPSMasked256", argLength: 2, reg: wkw, asm: "VRNDSCALEPS", aux: "Int8", commutative: false, typ: "Vec256", resultInArg0: false}, {name: "VREDUCEPS256", argLength: 1, reg: w11, asm: "VREDUCEPS", aux: "Int8", commutative: false, typ: "Vec256", resultInArg0: false}, {name: "VREDUCEPSMasked256", argLength: 2, reg: wkw, asm: "VREDUCEPS", aux: "Int8", commutative: false, typ: "Vec256", resultInArg0: false}, + {name: "VDPPS256", argLength: 2, reg: v21, asm: "VDPPS", aux: "Int8", commutative: true, typ: "Vec256", resultInArg0: false}, {name: "VCMPPS256", argLength: 2, reg: v21, asm: "VCMPPS", aux: "Int8", commutative: true, typ: "Vec256", resultInArg0: false}, {name: "VCMPPSMasked256", argLength: 3, reg: w2kk, asm: "VCMPPS", aux: "Int8", commutative: true, typ: "Mask", resultInArg0: false}, {name: "VEXTRACTF128128", argLength: 1, reg: v11, asm: "VEXTRACTF128", aux: "Int8", commutative: false, typ: "Vec128", resultInArg0: false}, diff --git a/src/cmd/compile/internal/ssa/_gen/simdgenericOps.go b/src/cmd/compile/internal/ssa/_gen/simdgenericOps.go index 1180d32586..a1dfc1e7da 100644 --- a/src/cmd/compile/internal/ssa/_gen/simdgenericOps.go +++ b/src/cmd/compile/internal/ssa/_gen/simdgenericOps.go @@ -53,6 +53,7 @@ func simdGenericOps() []opData { {name: "CeilFloat32x4", argLength: 1, commutative: false}, {name: "DivFloat32x4", argLength: 2, commutative: false}, {name: "DivMaskedFloat32x4", argLength: 3, commutative: false}, + {name: "DotProdBroadcastFloat32x4", argLength: 2, commutative: true}, {name: "EqualFloat32x4", argLength: 2, commutative: true}, {name: "EqualMaskedFloat32x4", argLength: 3, commutative: true}, {name: "FloorFloat32x4", argLength: 1, commutative: false}, @@ -100,6 +101,7 @@ func simdGenericOps() []opData { {name: "CeilFloat32x8", argLength: 1, commutative: false}, {name: "DivFloat32x8", argLength: 2, commutative: false}, {name: "DivMaskedFloat32x8", argLength: 3, commutative: false}, + {name: "DotProdBroadcastFloat32x8", argLength: 2, commutative: true}, {name: "EqualFloat32x8", argLength: 2, commutative: true}, {name: "EqualMaskedFloat32x8", argLength: 3, commutative: true}, {name: "FloorFloat32x8", argLength: 1, commutative: false}, diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go index 9067023f3a..ba28c58b7e 100644 --- a/src/cmd/compile/internal/ssa/opGen.go +++ b/src/cmd/compile/internal/ssa/opGen.go @@ -1931,6 +1931,7 @@ const ( OpAMD64VRNDSCALEPSMasked128 OpAMD64VREDUCEPS128 OpAMD64VREDUCEPSMasked128 + OpAMD64VDPPS128 OpAMD64VCMPPS128 OpAMD64VCMPPSMasked128 OpAMD64VROUNDPS256 @@ -1938,6 +1939,7 @@ const ( OpAMD64VRNDSCALEPSMasked256 OpAMD64VREDUCEPS256 OpAMD64VREDUCEPSMasked256 + OpAMD64VDPPS256 OpAMD64VCMPPS256 OpAMD64VCMPPSMasked256 OpAMD64VEXTRACTF128128 @@ -4369,6 +4371,7 @@ const ( OpCeilFloat32x4 OpDivFloat32x4 OpDivMaskedFloat32x4 + OpDotProdBroadcastFloat32x4 OpEqualFloat32x4 OpEqualMaskedFloat32x4 OpFloorFloat32x4 @@ -4416,6 +4419,7 @@ const ( OpCeilFloat32x8 OpDivFloat32x8 OpDivMaskedFloat32x8 + OpDotProdBroadcastFloat32x8 OpEqualFloat32x8 OpEqualMaskedFloat32x8 OpFloorFloat32x8 @@ -29582,6 +29586,22 @@ var opcodeTable = [...]opInfo{ }, }, }, + { + name: "VDPPS128", + auxType: auxInt8, + argLen: 2, + commutative: true, + asm: x86.AVDPPS, + reg: regInfo{ + inputs: []inputInfo{ + {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + {1, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + }, + outputs: []outputInfo{ + {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + }, + }, + }, { name: "VCMPPS128", auxType: auxInt8, @@ -29687,6 +29707,22 @@ var opcodeTable = [...]opInfo{ }, }, }, + { + name: "VDPPS256", + auxType: auxInt8, + argLen: 2, + commutative: true, + asm: x86.AVDPPS, + reg: regInfo{ + inputs: []inputInfo{ + {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + {1, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + }, + outputs: []outputInfo{ + {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + }, + }, + }, { name: "VCMPPS256", auxType: auxInt8, @@ -59497,6 +59533,12 @@ var opcodeTable = [...]opInfo{ argLen: 3, generic: true, }, + { + name: "DotProdBroadcastFloat32x4", + argLen: 2, + commutative: true, + generic: true, + }, { name: "EqualFloat32x4", argLen: 2, @@ -59746,6 +59788,12 @@ var opcodeTable = [...]opInfo{ argLen: 3, generic: true, }, + { + name: "DotProdBroadcastFloat32x8", + argLen: 2, + commutative: true, + generic: true, + }, { name: "EqualFloat32x8", argLen: 2, diff --git a/src/cmd/compile/internal/ssa/rewriteAMD64.go b/src/cmd/compile/internal/ssa/rewriteAMD64.go index d78c9212cb..6d10b009bb 100644 --- a/src/cmd/compile/internal/ssa/rewriteAMD64.go +++ b/src/cmd/compile/internal/ssa/rewriteAMD64.go @@ -1407,6 +1407,10 @@ func rewriteValueAMD64(v *Value) bool { return rewriteValueAMD64_OpDivMaskedFloat64x4(v) case OpDivMaskedFloat64x8: return rewriteValueAMD64_OpDivMaskedFloat64x8(v) + case OpDotProdBroadcastFloat32x4: + return rewriteValueAMD64_OpDotProdBroadcastFloat32x4(v) + case OpDotProdBroadcastFloat32x8: + return rewriteValueAMD64_OpDotProdBroadcastFloat32x8(v) case OpDotProdBroadcastFloat64x2: return rewriteValueAMD64_OpDotProdBroadcastFloat64x2(v) case OpEq16: @@ -32312,6 +32316,34 @@ func rewriteValueAMD64_OpDivMaskedFloat64x8(v *Value) bool { return true } } +func rewriteValueAMD64_OpDotProdBroadcastFloat32x4(v *Value) bool { + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (DotProdBroadcastFloat32x4 x y) + // result: (VDPPS128 [127] x y) + for { + x := v_0 + y := v_1 + v.reset(OpAMD64VDPPS128) + v.AuxInt = int8ToAuxInt(127) + v.AddArg2(x, y) + return true + } +} +func rewriteValueAMD64_OpDotProdBroadcastFloat32x8(v *Value) bool { + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (DotProdBroadcastFloat32x8 x y) + // result: (VDPPS256 [127] x y) + for { + x := v_0 + y := v_1 + v.reset(OpAMD64VDPPS256) + v.AuxInt = int8ToAuxInt(127) + v.AddArg2(x, y) + return true + } +} func rewriteValueAMD64_OpDotProdBroadcastFloat64x2(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] diff --git a/src/cmd/compile/internal/ssagen/simdintrinsics.go b/src/cmd/compile/internal/ssagen/simdintrinsics.go index 085c0b8d99..58bc420fc4 100644 --- a/src/cmd/compile/internal/ssagen/simdintrinsics.go +++ b/src/cmd/compile/internal/ssagen/simdintrinsics.go @@ -275,6 +275,8 @@ func simdIntrinsics(addF func(pkg, fn string, b intrinsicBuilder, archFamilies . addF(simdPackage, "Float64x2.DivMasked", opLen3(ssa.OpDivMaskedFloat64x2, types.TypeVec128), sys.AMD64) addF(simdPackage, "Float64x4.DivMasked", opLen3(ssa.OpDivMaskedFloat64x4, types.TypeVec256), sys.AMD64) addF(simdPackage, "Float64x8.DivMasked", opLen3(ssa.OpDivMaskedFloat64x8, types.TypeVec512), sys.AMD64) + addF(simdPackage, "Float32x4.DotProdBroadcast", opLen2(ssa.OpDotProdBroadcastFloat32x4, types.TypeVec128), sys.AMD64) + addF(simdPackage, "Float32x8.DotProdBroadcast", opLen2(ssa.OpDotProdBroadcastFloat32x8, types.TypeVec256), sys.AMD64) addF(simdPackage, "Float64x2.DotProdBroadcast", opLen2(ssa.OpDotProdBroadcastFloat64x2, types.TypeVec128), sys.AMD64) addF(simdPackage, "Int8x16.Equal", opLen2(ssa.OpEqualInt8x16, types.TypeVec128), sys.AMD64) addF(simdPackage, "Int8x32.Equal", opLen2(ssa.OpEqualInt8x32, types.TypeVec256), sys.AMD64) diff --git a/src/simd/ops_amd64.go b/src/simd/ops_amd64.go index 2c17300ae4..7a8780e5cb 100644 --- a/src/simd/ops_amd64.go +++ b/src/simd/ops_amd64.go @@ -1502,6 +1502,16 @@ func (x Float64x8) DivMasked(y Float64x8, z Mask64x8) Float64x8 /* DotProdBroadcast */ +// DotProdBroadcast multiplies all elements and broadcasts the sum. +// +// Asm: VDPPS, CPU Feature: AVX +func (x Float32x4) DotProdBroadcast(y Float32x4) Float32x4 + +// DotProdBroadcast multiplies all elements and broadcasts the sum. +// +// Asm: VDPPS, CPU Feature: AVX +func (x Float32x8) DotProdBroadcast(y Float32x8) Float32x8 + // DotProdBroadcast multiplies all elements and broadcasts the sum. // // Asm: VDPPD, CPU Feature: AVX diff --git a/src/simd/simd_wrapped_test.go b/src/simd/simd_wrapped_test.go index 15e5c45097..6466684068 100644 --- a/src/simd/simd_wrapped_test.go +++ b/src/simd/simd_wrapped_test.go @@ -22,6 +22,8 @@ func testFloat32x4Binary(t *testing.T, v0 []float32, v1 []float32, want []float3 gotv = vec0.AddSub(vec1) case "Div": gotv = vec0.Div(vec1) + case "DotProdBroadcast": + gotv = vec0.DotProdBroadcast(vec1) case "Max": gotv = vec0.Max(vec1) case "Min": @@ -272,6 +274,8 @@ func testFloat32x8Binary(t *testing.T, v0 []float32, v1 []float32, want []float3 gotv = vec0.AddSub(vec1) case "Div": gotv = vec0.Div(vec1) + case "DotProdBroadcast": + gotv = vec0.DotProdBroadcast(vec1) case "Max": gotv = vec0.Max(vec1) case "Min": -- 2.52.0