From 896f293a252ad5784a80f42f26b944eabf93eaa6 Mon Sep 17 00:00:00 2001 From: Junyang Shao Date: Mon, 17 Nov 2025 23:19:56 +0000 Subject: [PATCH] [dev.simd] cmd/compile, simd: change DotProductQuadruple and add peepholes This CL addressed some API change decisions in the API audit. Instead of exposing the Intel format, we hide the add part of the instructions under the peephole, and rename the API as DotProdQuadruple Change-Id: I471c0a755174bc15dd83bdc0f757d6356b92d835 Reviewed-on: https://go-review.googlesource.com/c/go/+/721420 Reviewed-by: David Chase LUCI-TryBot-Result: Go LUCI --- src/cmd/compile/internal/amd64/simdssa.go | 36 +- src/cmd/compile/internal/ssa/_gen/AMD64.rules | 8 + .../compile/internal/ssa/_gen/simdAMD64.rules | 40 +- .../internal/ssa/_gen/simdgenericOps.go | 12 +- src/cmd/compile/internal/ssa/opGen.go | 72 ++-- src/cmd/compile/internal/ssa/rewriteAMD64.go | 350 +++++++++++++----- src/cmd/compile/internal/ssagen/intrinsics.go | 14 +- .../compile/internal/ssagen/simdintrinsics.go | 12 +- src/simd/_gen/simdgen/gen_simdIntrinsics.go | 2 +- src/simd/_gen/simdgen/gen_simdTypes.go | 4 +- .../_gen/simdgen/ops/MLOps/categories.yaml | 10 +- src/simd/_gen/simdgen/ops/MLOps/go.yaml | 8 +- src/simd/internal/simd_test/simd_test.go | 34 ++ src/simd/ops_amd64.go | 74 ++-- 14 files changed, 441 insertions(+), 235 deletions(-) diff --git a/src/cmd/compile/internal/amd64/simdssa.go b/src/cmd/compile/internal/amd64/simdssa.go index 82ec733cc0..3f8ce17972 100644 --- a/src/cmd/compile/internal/amd64/simdssa.go +++ b/src/cmd/compile/internal/amd64/simdssa.go @@ -1274,12 +1274,6 @@ func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool { case ssa.OpAMD64VPDPWSSDMasked128, ssa.OpAMD64VPDPWSSDMasked256, ssa.OpAMD64VPDPWSSDMasked512, - ssa.OpAMD64VPDPBUSDMasked128, - ssa.OpAMD64VPDPBUSDMasked256, - ssa.OpAMD64VPDPBUSDMasked512, - ssa.OpAMD64VPDPBUSDSMasked128, - ssa.OpAMD64VPDPBUSDSMasked256, - ssa.OpAMD64VPDPBUSDSMasked512, ssa.OpAMD64VADDPSMasked128Merging, ssa.OpAMD64VADDPSMasked256Merging, ssa.OpAMD64VADDPSMasked512Merging, @@ -1343,6 +1337,12 @@ func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool { ssa.OpAMD64VPMADDUBSWMasked128Merging, ssa.OpAMD64VPMADDUBSWMasked256Merging, ssa.OpAMD64VPMADDUBSWMasked512Merging, + ssa.OpAMD64VPDPBUSDMasked128, + ssa.OpAMD64VPDPBUSDMasked256, + ssa.OpAMD64VPDPBUSDMasked512, + ssa.OpAMD64VPDPBUSDSMasked128, + ssa.OpAMD64VPDPBUSDSMasked256, + ssa.OpAMD64VPDPBUSDSMasked512, ssa.OpAMD64VGF2P8MULBMasked128Merging, ssa.OpAMD64VGF2P8MULBMasked256Merging, ssa.OpAMD64VGF2P8MULBMasked512Merging, @@ -2543,18 +2543,6 @@ func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool { ssa.OpAMD64VPDPWSSDMasked256load, ssa.OpAMD64VPDPWSSDMasked512, ssa.OpAMD64VPDPWSSDMasked512load, - ssa.OpAMD64VPDPBUSDMasked128, - ssa.OpAMD64VPDPBUSDMasked128load, - ssa.OpAMD64VPDPBUSDMasked256, - ssa.OpAMD64VPDPBUSDMasked256load, - ssa.OpAMD64VPDPBUSDMasked512, - ssa.OpAMD64VPDPBUSDMasked512load, - ssa.OpAMD64VPDPBUSDSMasked128, - ssa.OpAMD64VPDPBUSDSMasked128load, - ssa.OpAMD64VPDPBUSDSMasked256, - ssa.OpAMD64VPDPBUSDSMasked256load, - ssa.OpAMD64VPDPBUSDSMasked512, - ssa.OpAMD64VPDPBUSDSMasked512load, ssa.OpAMD64VADDPSMasked128, ssa.OpAMD64VADDPSMasked128load, ssa.OpAMD64VADDPSMasked256, @@ -2821,6 +2809,18 @@ func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool { ssa.OpAMD64VPMADDUBSWMasked128, ssa.OpAMD64VPMADDUBSWMasked256, ssa.OpAMD64VPMADDUBSWMasked512, + ssa.OpAMD64VPDPBUSDMasked128, + ssa.OpAMD64VPDPBUSDMasked128load, + ssa.OpAMD64VPDPBUSDMasked256, + ssa.OpAMD64VPDPBUSDMasked256load, + ssa.OpAMD64VPDPBUSDMasked512, + ssa.OpAMD64VPDPBUSDMasked512load, + ssa.OpAMD64VPDPBUSDSMasked128, + ssa.OpAMD64VPDPBUSDSMasked128load, + ssa.OpAMD64VPDPBUSDSMasked256, + ssa.OpAMD64VPDPBUSDSMasked256load, + ssa.OpAMD64VPDPBUSDSMasked512, + ssa.OpAMD64VPDPBUSDSMasked512load, ssa.OpAMD64VEXPANDPSMasked128, ssa.OpAMD64VEXPANDPSMasked256, ssa.OpAMD64VEXPANDPSMasked512, diff --git a/src/cmd/compile/internal/ssa/_gen/AMD64.rules b/src/cmd/compile/internal/ssa/_gen/AMD64.rules index 38ca44f7eb..353d272179 100644 --- a/src/cmd/compile/internal/ssa/_gen/AMD64.rules +++ b/src/cmd/compile/internal/ssa/_gen/AMD64.rules @@ -1817,3 +1817,11 @@ (EQ (VPTEST x:(VPAND(D|Q)512 j k) y) yes no) && x == y && x.Uses == 2 => (EQ (VPTEST j k) yes no) (EQ (VPTEST x:(VPANDN(128|256) j k) y) yes no) && x == y && x.Uses == 2 => (ULT (VPTEST k j) yes no) // AndNot has swapped its operand order (EQ (VPTEST x:(VPANDN(D|Q)512 j k) y) yes no) && x == y && x.Uses == 2 => (ULT (VPTEST k j) yes no) // AndNot has swapped its operand order + +// DotProductQuadruple optimizations +(VPADDD128 (VPDPBUSD128 (Zero128 ) x y) z) => (VPDPBUSD128 z x y) +(VPADDD256 (VPDPBUSD256 (Zero256 ) x y) z) => (VPDPBUSD256 z x y) +(VPADDD512 (VPDPBUSD512 (Zero512 ) x y) z) => (VPDPBUSD512 z x y) +(VPADDD128 (VPDPBUSDS128 (Zero128 ) x y) z) => (VPDPBUSDS128 z x y) +(VPADDD256 (VPDPBUSDS256 (Zero256 ) x y) z) => (VPDPBUSDS256 z x y) +(VPADDD512 (VPDPBUSDS512 (Zero512 ) x y) z) => (VPDPBUSDS512 z x y) \ No newline at end of file diff --git a/src/cmd/compile/internal/ssa/_gen/simdAMD64.rules b/src/cmd/compile/internal/ssa/_gen/simdAMD64.rules index 5169bf24d9..5a9a1c0bc7 100644 --- a/src/cmd/compile/internal/ssa/_gen/simdAMD64.rules +++ b/src/cmd/compile/internal/ssa/_gen/simdAMD64.rules @@ -56,12 +56,6 @@ (AddUint64x2 ...) => (VPADDQ128 ...) (AddUint64x4 ...) => (VPADDQ256 ...) (AddUint64x8 ...) => (VPADDQ512 ...) -(AddDotProductQuadrupleInt32x4 ...) => (VPDPBUSD128 ...) -(AddDotProductQuadrupleInt32x8 ...) => (VPDPBUSD256 ...) -(AddDotProductQuadrupleInt32x16 ...) => (VPDPBUSD512 ...) -(AddDotProductQuadrupleSaturatedInt32x4 ...) => (VPDPBUSDS128 ...) -(AddDotProductQuadrupleSaturatedInt32x8 ...) => (VPDPBUSDS256 ...) -(AddDotProductQuadrupleSaturatedInt32x16 ...) => (VPDPBUSDS512 ...) (AddPairsFloat32x4 ...) => (VHADDPS128 ...) (AddPairsFloat32x8 ...) => (VHADDPS256 ...) (AddPairsFloat64x2 ...) => (VHADDPD128 ...) @@ -363,6 +357,12 @@ (DotProductPairsSaturatedUint8x16 ...) => (VPMADDUBSW128 ...) (DotProductPairsSaturatedUint8x32 ...) => (VPMADDUBSW256 ...) (DotProductPairsSaturatedUint8x64 ...) => (VPMADDUBSW512 ...) +(DotProductQuadrupleInt32x4 ...) => (VPDPBUSD128 ...) +(DotProductQuadrupleInt32x8 ...) => (VPDPBUSD256 ...) +(DotProductQuadrupleInt32x16 ...) => (VPDPBUSD512 ...) +(DotProductQuadrupleSaturatedInt32x4 ...) => (VPDPBUSDS128 ...) +(DotProductQuadrupleSaturatedInt32x8 ...) => (VPDPBUSDS256 ...) +(DotProductQuadrupleSaturatedInt32x16 ...) => (VPDPBUSDS512 ...) (EqualFloat32x4 x y) => (VCMPPS128 [0] x y) (EqualFloat32x8 x y) => (VCMPPS256 [0] x y) (EqualFloat32x16 x y) => (VPMOVMToVec32x16 (VCMPPS512 [0] x y)) @@ -1348,12 +1348,6 @@ (VMOVDQU64Masked128 (VPABSQ128 x) mask) => (VPABSQMasked128 x mask) (VMOVDQU64Masked256 (VPABSQ256 x) mask) => (VPABSQMasked256 x mask) (VMOVDQU64Masked512 (VPABSQ512 x) mask) => (VPABSQMasked512 x mask) -(VMOVDQU32Masked128 (VPDPBUSD128 x y z) mask) => (VPDPBUSDMasked128 x y z mask) -(VMOVDQU32Masked256 (VPDPBUSD256 x y z) mask) => (VPDPBUSDMasked256 x y z mask) -(VMOVDQU32Masked512 (VPDPBUSD512 x y z) mask) => (VPDPBUSDMasked512 x y z mask) -(VMOVDQU32Masked128 (VPDPBUSDS128 x y z) mask) => (VPDPBUSDSMasked128 x y z mask) -(VMOVDQU32Masked256 (VPDPBUSDS256 x y z) mask) => (VPDPBUSDSMasked256 x y z mask) -(VMOVDQU32Masked512 (VPDPBUSDS512 x y z) mask) => (VPDPBUSDSMasked512 x y z mask) (VMOVDQU32Masked128 (VADDPS128 x y) mask) => (VADDPSMasked128 x y mask) (VMOVDQU32Masked256 (VADDPS256 x y) mask) => (VADDPSMasked256 x y mask) (VMOVDQU32Masked512 (VADDPS512 x y) mask) => (VADDPSMasked512 x y mask) @@ -1540,6 +1534,12 @@ (VMOVDQU16Masked128 (VPMADDUBSW128 x y) mask) => (VPMADDUBSWMasked128 x y mask) (VMOVDQU16Masked256 (VPMADDUBSW256 x y) mask) => (VPMADDUBSWMasked256 x y mask) (VMOVDQU16Masked512 (VPMADDUBSW512 x y) mask) => (VPMADDUBSWMasked512 x y mask) +(VMOVDQU32Masked128 (VPDPBUSD128 x y z) mask) => (VPDPBUSDMasked128 x y z mask) +(VMOVDQU32Masked256 (VPDPBUSD256 x y z) mask) => (VPDPBUSDMasked256 x y z mask) +(VMOVDQU32Masked512 (VPDPBUSD512 x y z) mask) => (VPDPBUSDMasked512 x y z mask) +(VMOVDQU32Masked128 (VPDPBUSDS128 x y z) mask) => (VPDPBUSDSMasked128 x y z mask) +(VMOVDQU32Masked256 (VPDPBUSDS256 x y z) mask) => (VPDPBUSDSMasked256 x y z mask) +(VMOVDQU32Masked512 (VPDPBUSDS512 x y z) mask) => (VPDPBUSDSMasked512 x y z mask) (VMOVDQU8Masked128 (VGF2P8AFFINEINVQB128 [a] x y) mask) => (VGF2P8AFFINEINVQBMasked128 [a] x y mask) (VMOVDQU8Masked256 (VGF2P8AFFINEINVQB256 [a] x y) mask) => (VGF2P8AFFINEINVQBMasked256 [a] x y mask) (VMOVDQU8Masked512 (VGF2P8AFFINEINVQB512 [a] x y) mask) => (VGF2P8AFFINEINVQBMasked512 [a] x y mask) @@ -2358,14 +2358,6 @@ (VPDPWSSDMasked128 x y l:(VMOVDQUload128 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VPDPWSSDMasked128load {sym} [off] x y ptr mask mem) (VPDPWSSDMasked256 x y l:(VMOVDQUload256 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VPDPWSSDMasked256load {sym} [off] x y ptr mask mem) (VPDPWSSDMasked512 x y l:(VMOVDQUload512 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VPDPWSSDMasked512load {sym} [off] x y ptr mask mem) -(VPDPBUSD512 x y l:(VMOVDQUload512 {sym} [off] ptr mem)) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSD512load {sym} [off] x y ptr mem) -(VPDPBUSDMasked128 x y l:(VMOVDQUload128 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSDMasked128load {sym} [off] x y ptr mask mem) -(VPDPBUSDMasked256 x y l:(VMOVDQUload256 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSDMasked256load {sym} [off] x y ptr mask mem) -(VPDPBUSDMasked512 x y l:(VMOVDQUload512 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSDMasked512load {sym} [off] x y ptr mask mem) -(VPDPBUSDS512 x y l:(VMOVDQUload512 {sym} [off] ptr mem)) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSDS512load {sym} [off] x y ptr mem) -(VPDPBUSDSMasked128 x y l:(VMOVDQUload128 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSDSMasked128load {sym} [off] x y ptr mask mem) -(VPDPBUSDSMasked256 x y l:(VMOVDQUload256 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSDSMasked256load {sym} [off] x y ptr mask mem) -(VPDPBUSDSMasked512 x y l:(VMOVDQUload512 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSDSMasked512load {sym} [off] x y ptr mask mem) (VADDPSMasked128 x l:(VMOVDQUload128 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VADDPSMasked128load {sym} [off] x ptr mask mem) (VADDPSMasked256 x l:(VMOVDQUload256 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VADDPSMasked256load {sym} [off] x ptr mask mem) (VADDPSMasked512 x l:(VMOVDQUload512 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VADDPSMasked512load {sym} [off] x ptr mask mem) @@ -2444,6 +2436,14 @@ (VDIVPDMasked128 x l:(VMOVDQUload128 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VDIVPDMasked128load {sym} [off] x ptr mask mem) (VDIVPDMasked256 x l:(VMOVDQUload256 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VDIVPDMasked256load {sym} [off] x ptr mask mem) (VDIVPDMasked512 x l:(VMOVDQUload512 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VDIVPDMasked512load {sym} [off] x ptr mask mem) +(VPDPBUSD512 x y l:(VMOVDQUload512 {sym} [off] ptr mem)) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSD512load {sym} [off] x y ptr mem) +(VPDPBUSDMasked128 x y l:(VMOVDQUload128 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSDMasked128load {sym} [off] x y ptr mask mem) +(VPDPBUSDMasked256 x y l:(VMOVDQUload256 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSDMasked256load {sym} [off] x y ptr mask mem) +(VPDPBUSDMasked512 x y l:(VMOVDQUload512 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSDMasked512load {sym} [off] x y ptr mask mem) +(VPDPBUSDS512 x y l:(VMOVDQUload512 {sym} [off] ptr mem)) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSDS512load {sym} [off] x y ptr mem) +(VPDPBUSDSMasked128 x y l:(VMOVDQUload128 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSDSMasked128load {sym} [off] x y ptr mask mem) +(VPDPBUSDSMasked256 x y l:(VMOVDQUload256 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSDSMasked256load {sym} [off] x y ptr mask mem) +(VPDPBUSDSMasked512 x y l:(VMOVDQUload512 {sym} [off] ptr mem) mask) && canMergeLoad(v, l) && clobber(l) => (VPDPBUSDSMasked512load {sym} [off] x y ptr mask mem) (VPCMPEQD512 x l:(VMOVDQUload512 {sym} [off] ptr mem)) && canMergeLoad(v, l) && clobber(l) => (VPCMPEQD512load {sym} [off] x ptr mem) (VPCMPEQQ512 x l:(VMOVDQUload512 {sym} [off] ptr mem)) && canMergeLoad(v, l) && clobber(l) => (VPCMPEQQ512load {sym} [off] x ptr mem) (VCMPPS512 [c] x l:(VMOVDQUload512 {sym} [off] ptr mem)) && canMergeLoad(v, l) && clobber(l) => (VCMPPS512load {sym} [makeValAndOff(int32(int8(c)),off)] x ptr mem) diff --git a/src/cmd/compile/internal/ssa/_gen/simdgenericOps.go b/src/cmd/compile/internal/ssa/_gen/simdgenericOps.go index dca366f0f9..6a79fa3856 100644 --- a/src/cmd/compile/internal/ssa/_gen/simdgenericOps.go +++ b/src/cmd/compile/internal/ssa/_gen/simdgenericOps.go @@ -29,12 +29,6 @@ func simdGenericOps() []opData { {name: "AbsInt64x2", argLength: 1, commutative: false}, {name: "AbsInt64x4", argLength: 1, commutative: false}, {name: "AbsInt64x8", argLength: 1, commutative: false}, - {name: "AddDotProductQuadrupleInt32x4", argLength: 3, commutative: false}, - {name: "AddDotProductQuadrupleInt32x8", argLength: 3, commutative: false}, - {name: "AddDotProductQuadrupleInt32x16", argLength: 3, commutative: false}, - {name: "AddDotProductQuadrupleSaturatedInt32x4", argLength: 3, commutative: false}, - {name: "AddDotProductQuadrupleSaturatedInt32x8", argLength: 3, commutative: false}, - {name: "AddDotProductQuadrupleSaturatedInt32x16", argLength: 3, commutative: false}, {name: "AddFloat32x4", argLength: 2, commutative: true}, {name: "AddFloat32x8", argLength: 2, commutative: true}, {name: "AddFloat32x16", argLength: 2, commutative: true}, @@ -351,6 +345,12 @@ func simdGenericOps() []opData { {name: "DotProductPairsSaturatedUint8x16", argLength: 2, commutative: false}, {name: "DotProductPairsSaturatedUint8x32", argLength: 2, commutative: false}, {name: "DotProductPairsSaturatedUint8x64", argLength: 2, commutative: false}, + {name: "DotProductQuadrupleInt32x4", argLength: 3, commutative: false}, + {name: "DotProductQuadrupleInt32x8", argLength: 3, commutative: false}, + {name: "DotProductQuadrupleInt32x16", argLength: 3, commutative: false}, + {name: "DotProductQuadrupleSaturatedInt32x4", argLength: 3, commutative: false}, + {name: "DotProductQuadrupleSaturatedInt32x8", argLength: 3, commutative: false}, + {name: "DotProductQuadrupleSaturatedInt32x16", argLength: 3, commutative: false}, {name: "EqualFloat32x4", argLength: 2, commutative: true}, {name: "EqualFloat32x8", argLength: 2, commutative: true}, {name: "EqualFloat32x16", argLength: 2, commutative: true}, diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go index d0482743d1..9c5d79fa56 100644 --- a/src/cmd/compile/internal/ssa/opGen.go +++ b/src/cmd/compile/internal/ssa/opGen.go @@ -5977,12 +5977,6 @@ const ( OpAbsInt64x2 OpAbsInt64x4 OpAbsInt64x8 - OpAddDotProductQuadrupleInt32x4 - OpAddDotProductQuadrupleInt32x8 - OpAddDotProductQuadrupleInt32x16 - OpAddDotProductQuadrupleSaturatedInt32x4 - OpAddDotProductQuadrupleSaturatedInt32x8 - OpAddDotProductQuadrupleSaturatedInt32x16 OpAddFloat32x4 OpAddFloat32x8 OpAddFloat32x16 @@ -6299,6 +6293,12 @@ const ( OpDotProductPairsSaturatedUint8x16 OpDotProductPairsSaturatedUint8x32 OpDotProductPairsSaturatedUint8x64 + OpDotProductQuadrupleInt32x4 + OpDotProductQuadrupleInt32x8 + OpDotProductQuadrupleInt32x16 + OpDotProductQuadrupleSaturatedInt32x4 + OpDotProductQuadrupleSaturatedInt32x8 + OpDotProductQuadrupleSaturatedInt32x16 OpEqualFloat32x4 OpEqualFloat32x8 OpEqualFloat32x16 @@ -85911,36 +85911,6 @@ var opcodeTable = [...]opInfo{ argLen: 1, generic: true, }, - { - name: "AddDotProductQuadrupleInt32x4", - argLen: 3, - generic: true, - }, - { - name: "AddDotProductQuadrupleInt32x8", - argLen: 3, - generic: true, - }, - { - name: "AddDotProductQuadrupleInt32x16", - argLen: 3, - generic: true, - }, - { - name: "AddDotProductQuadrupleSaturatedInt32x4", - argLen: 3, - generic: true, - }, - { - name: "AddDotProductQuadrupleSaturatedInt32x8", - argLen: 3, - generic: true, - }, - { - name: "AddDotProductQuadrupleSaturatedInt32x16", - argLen: 3, - generic: true, - }, { name: "AddFloat32x4", argLen: 2, @@ -87593,6 +87563,36 @@ var opcodeTable = [...]opInfo{ argLen: 2, generic: true, }, + { + name: "DotProductQuadrupleInt32x4", + argLen: 3, + generic: true, + }, + { + name: "DotProductQuadrupleInt32x8", + argLen: 3, + generic: true, + }, + { + name: "DotProductQuadrupleInt32x16", + argLen: 3, + generic: true, + }, + { + name: "DotProductQuadrupleSaturatedInt32x4", + argLen: 3, + generic: true, + }, + { + name: "DotProductQuadrupleSaturatedInt32x8", + argLen: 3, + generic: true, + }, + { + name: "DotProductQuadrupleSaturatedInt32x16", + argLen: 3, + generic: true, + }, { name: "EqualFloat32x4", argLen: 2, diff --git a/src/cmd/compile/internal/ssa/rewriteAMD64.go b/src/cmd/compile/internal/ssa/rewriteAMD64.go index 5f564000d9..76e524d524 100644 --- a/src/cmd/compile/internal/ssa/rewriteAMD64.go +++ b/src/cmd/compile/internal/ssa/rewriteAMD64.go @@ -850,6 +850,10 @@ func rewriteValueAMD64(v *Value) bool { return rewriteValueAMD64_OpAMD64VPACKUSDWMasked256(v) case OpAMD64VPACKUSDWMasked512: return rewriteValueAMD64_OpAMD64VPACKUSDWMasked512(v) + case OpAMD64VPADDD128: + return rewriteValueAMD64_OpAMD64VPADDD128(v) + case OpAMD64VPADDD256: + return rewriteValueAMD64_OpAMD64VPADDD256(v) case OpAMD64VPADDD512: return rewriteValueAMD64_OpAMD64VPADDD512(v) case OpAMD64VPADDDMasked128: @@ -1916,24 +1920,6 @@ func rewriteValueAMD64(v *Value) bool { case OpAdd8: v.Op = OpAMD64ADDL return true - case OpAddDotProductQuadrupleInt32x16: - v.Op = OpAMD64VPDPBUSD512 - return true - case OpAddDotProductQuadrupleInt32x4: - v.Op = OpAMD64VPDPBUSD128 - return true - case OpAddDotProductQuadrupleInt32x8: - v.Op = OpAMD64VPDPBUSD256 - return true - case OpAddDotProductQuadrupleSaturatedInt32x16: - v.Op = OpAMD64VPDPBUSDS512 - return true - case OpAddDotProductQuadrupleSaturatedInt32x4: - v.Op = OpAMD64VPDPBUSDS128 - return true - case OpAddDotProductQuadrupleSaturatedInt32x8: - v.Op = OpAMD64VPDPBUSDS256 - return true case OpAddFloat32x16: v.Op = OpAMD64VADDPS512 return true @@ -3123,6 +3109,24 @@ func rewriteValueAMD64(v *Value) bool { case OpDotProductPairsSaturatedUint8x64: v.Op = OpAMD64VPMADDUBSW512 return true + case OpDotProductQuadrupleInt32x16: + v.Op = OpAMD64VPDPBUSD512 + return true + case OpDotProductQuadrupleInt32x4: + v.Op = OpAMD64VPDPBUSD128 + return true + case OpDotProductQuadrupleInt32x8: + v.Op = OpAMD64VPDPBUSD256 + return true + case OpDotProductQuadrupleSaturatedInt32x16: + v.Op = OpAMD64VPDPBUSDS512 + return true + case OpDotProductQuadrupleSaturatedInt32x4: + v.Op = OpAMD64VPDPBUSDS128 + return true + case OpDotProductQuadrupleSaturatedInt32x8: + v.Op = OpAMD64VPDPBUSDS256 + return true case OpEq16: return rewriteValueAMD64_OpEq16(v) case OpEq32: @@ -32793,34 +32797,6 @@ func rewriteValueAMD64_OpAMD64VMOVDQU32Masked128(v *Value) bool { v.AddArg2(x, mask) return true } - // match: (VMOVDQU32Masked128 (VPDPBUSD128 x y z) mask) - // result: (VPDPBUSDMasked128 x y z mask) - for { - if v_0.Op != OpAMD64VPDPBUSD128 { - break - } - z := v_0.Args[2] - x := v_0.Args[0] - y := v_0.Args[1] - mask := v_1 - v.reset(OpAMD64VPDPBUSDMasked128) - v.AddArg4(x, y, z, mask) - return true - } - // match: (VMOVDQU32Masked128 (VPDPBUSDS128 x y z) mask) - // result: (VPDPBUSDSMasked128 x y z mask) - for { - if v_0.Op != OpAMD64VPDPBUSDS128 { - break - } - z := v_0.Args[2] - x := v_0.Args[0] - y := v_0.Args[1] - mask := v_1 - v.reset(OpAMD64VPDPBUSDSMasked128) - v.AddArg4(x, y, z, mask) - return true - } // match: (VMOVDQU32Masked128 (VADDPS128 x y) mask) // result: (VADDPSMasked128 x y mask) for { @@ -33058,6 +33034,34 @@ func rewriteValueAMD64_OpAMD64VMOVDQU32Masked128(v *Value) bool { v.AddArg3(x, y, mask) return true } + // match: (VMOVDQU32Masked128 (VPDPBUSD128 x y z) mask) + // result: (VPDPBUSDMasked128 x y z mask) + for { + if v_0.Op != OpAMD64VPDPBUSD128 { + break + } + z := v_0.Args[2] + x := v_0.Args[0] + y := v_0.Args[1] + mask := v_1 + v.reset(OpAMD64VPDPBUSDMasked128) + v.AddArg4(x, y, z, mask) + return true + } + // match: (VMOVDQU32Masked128 (VPDPBUSDS128 x y z) mask) + // result: (VPDPBUSDSMasked128 x y z mask) + for { + if v_0.Op != OpAMD64VPDPBUSDS128 { + break + } + z := v_0.Args[2] + x := v_0.Args[0] + y := v_0.Args[1] + mask := v_1 + v.reset(OpAMD64VPDPBUSDSMasked128) + v.AddArg4(x, y, z, mask) + return true + } // match: (VMOVDQU32Masked128 (VPLZCNTD128 x) mask) // result: (VPLZCNTDMasked128 x mask) for { @@ -33556,34 +33560,6 @@ func rewriteValueAMD64_OpAMD64VMOVDQU32Masked256(v *Value) bool { v.AddArg2(x, mask) return true } - // match: (VMOVDQU32Masked256 (VPDPBUSD256 x y z) mask) - // result: (VPDPBUSDMasked256 x y z mask) - for { - if v_0.Op != OpAMD64VPDPBUSD256 { - break - } - z := v_0.Args[2] - x := v_0.Args[0] - y := v_0.Args[1] - mask := v_1 - v.reset(OpAMD64VPDPBUSDMasked256) - v.AddArg4(x, y, z, mask) - return true - } - // match: (VMOVDQU32Masked256 (VPDPBUSDS256 x y z) mask) - // result: (VPDPBUSDSMasked256 x y z mask) - for { - if v_0.Op != OpAMD64VPDPBUSDS256 { - break - } - z := v_0.Args[2] - x := v_0.Args[0] - y := v_0.Args[1] - mask := v_1 - v.reset(OpAMD64VPDPBUSDSMasked256) - v.AddArg4(x, y, z, mask) - return true - } // match: (VMOVDQU32Masked256 (VADDPS256 x y) mask) // result: (VADDPSMasked256 x y mask) for { @@ -33857,6 +33833,34 @@ func rewriteValueAMD64_OpAMD64VMOVDQU32Masked256(v *Value) bool { v.AddArg3(x, y, mask) return true } + // match: (VMOVDQU32Masked256 (VPDPBUSD256 x y z) mask) + // result: (VPDPBUSDMasked256 x y z mask) + for { + if v_0.Op != OpAMD64VPDPBUSD256 { + break + } + z := v_0.Args[2] + x := v_0.Args[0] + y := v_0.Args[1] + mask := v_1 + v.reset(OpAMD64VPDPBUSDMasked256) + v.AddArg4(x, y, z, mask) + return true + } + // match: (VMOVDQU32Masked256 (VPDPBUSDS256 x y z) mask) + // result: (VPDPBUSDSMasked256 x y z mask) + for { + if v_0.Op != OpAMD64VPDPBUSDS256 { + break + } + z := v_0.Args[2] + x := v_0.Args[0] + y := v_0.Args[1] + mask := v_1 + v.reset(OpAMD64VPDPBUSDSMasked256) + v.AddArg4(x, y, z, mask) + return true + } // match: (VMOVDQU32Masked256 (VPLZCNTD256 x) mask) // result: (VPLZCNTDMasked256 x mask) for { @@ -34381,34 +34385,6 @@ func rewriteValueAMD64_OpAMD64VMOVDQU32Masked512(v *Value) bool { v.AddArg2(x, mask) return true } - // match: (VMOVDQU32Masked512 (VPDPBUSD512 x y z) mask) - // result: (VPDPBUSDMasked512 x y z mask) - for { - if v_0.Op != OpAMD64VPDPBUSD512 { - break - } - z := v_0.Args[2] - x := v_0.Args[0] - y := v_0.Args[1] - mask := v_1 - v.reset(OpAMD64VPDPBUSDMasked512) - v.AddArg4(x, y, z, mask) - return true - } - // match: (VMOVDQU32Masked512 (VPDPBUSDS512 x y z) mask) - // result: (VPDPBUSDSMasked512 x y z mask) - for { - if v_0.Op != OpAMD64VPDPBUSDS512 { - break - } - z := v_0.Args[2] - x := v_0.Args[0] - y := v_0.Args[1] - mask := v_1 - v.reset(OpAMD64VPDPBUSDSMasked512) - v.AddArg4(x, y, z, mask) - return true - } // match: (VMOVDQU32Masked512 (VADDPS512 x y) mask) // result: (VADDPSMasked512 x y mask) for { @@ -34636,6 +34612,34 @@ func rewriteValueAMD64_OpAMD64VMOVDQU32Masked512(v *Value) bool { v.AddArg3(x, y, mask) return true } + // match: (VMOVDQU32Masked512 (VPDPBUSD512 x y z) mask) + // result: (VPDPBUSDMasked512 x y z mask) + for { + if v_0.Op != OpAMD64VPDPBUSD512 { + break + } + z := v_0.Args[2] + x := v_0.Args[0] + y := v_0.Args[1] + mask := v_1 + v.reset(OpAMD64VPDPBUSDMasked512) + v.AddArg4(x, y, z, mask) + return true + } + // match: (VMOVDQU32Masked512 (VPDPBUSDS512 x y z) mask) + // result: (VPDPBUSDSMasked512 x y z mask) + for { + if v_0.Op != OpAMD64VPDPBUSDS512 { + break + } + z := v_0.Args[2] + x := v_0.Args[0] + y := v_0.Args[1] + mask := v_1 + v.reset(OpAMD64VPDPBUSDSMasked512) + v.AddArg4(x, y, z, mask) + return true + } // match: (VMOVDQU32Masked512 (VPLZCNTD512 x) mask) // result: (VPLZCNTDMasked512 x mask) for { @@ -39616,9 +39620,151 @@ func rewriteValueAMD64_OpAMD64VPACKUSDWMasked512(v *Value) bool { } return false } +func rewriteValueAMD64_OpAMD64VPADDD128(v *Value) bool { + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (VPADDD128 (VPDPBUSD128 (Zero128 ) x y) z) + // result: (VPDPBUSD128 z x y) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpAMD64VPDPBUSD128 { + continue + } + y := v_0.Args[2] + v_0_0 := v_0.Args[0] + if v_0_0.Op != OpAMD64Zero128 { + continue + } + t := v_0_0.Type + x := v_0.Args[1] + z := v_1 + v.reset(OpAMD64VPDPBUSD128) + v.Type = t + v.AddArg3(z, x, y) + return true + } + break + } + // match: (VPADDD128 (VPDPBUSDS128 (Zero128 ) x y) z) + // result: (VPDPBUSDS128 z x y) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpAMD64VPDPBUSDS128 { + continue + } + y := v_0.Args[2] + v_0_0 := v_0.Args[0] + if v_0_0.Op != OpAMD64Zero128 { + continue + } + t := v_0_0.Type + x := v_0.Args[1] + z := v_1 + v.reset(OpAMD64VPDPBUSDS128) + v.Type = t + v.AddArg3(z, x, y) + return true + } + break + } + return false +} +func rewriteValueAMD64_OpAMD64VPADDD256(v *Value) bool { + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (VPADDD256 (VPDPBUSD256 (Zero256 ) x y) z) + // result: (VPDPBUSD256 z x y) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpAMD64VPDPBUSD256 { + continue + } + y := v_0.Args[2] + v_0_0 := v_0.Args[0] + if v_0_0.Op != OpAMD64Zero256 { + continue + } + t := v_0_0.Type + x := v_0.Args[1] + z := v_1 + v.reset(OpAMD64VPDPBUSD256) + v.Type = t + v.AddArg3(z, x, y) + return true + } + break + } + // match: (VPADDD256 (VPDPBUSDS256 (Zero256 ) x y) z) + // result: (VPDPBUSDS256 z x y) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpAMD64VPDPBUSDS256 { + continue + } + y := v_0.Args[2] + v_0_0 := v_0.Args[0] + if v_0_0.Op != OpAMD64Zero256 { + continue + } + t := v_0_0.Type + x := v_0.Args[1] + z := v_1 + v.reset(OpAMD64VPDPBUSDS256) + v.Type = t + v.AddArg3(z, x, y) + return true + } + break + } + return false +} func rewriteValueAMD64_OpAMD64VPADDD512(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] + // match: (VPADDD512 (VPDPBUSD512 (Zero512 ) x y) z) + // result: (VPDPBUSD512 z x y) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpAMD64VPDPBUSD512 { + continue + } + y := v_0.Args[2] + v_0_0 := v_0.Args[0] + if v_0_0.Op != OpAMD64Zero512 { + continue + } + t := v_0_0.Type + x := v_0.Args[1] + z := v_1 + v.reset(OpAMD64VPDPBUSD512) + v.Type = t + v.AddArg3(z, x, y) + return true + } + break + } + // match: (VPADDD512 (VPDPBUSDS512 (Zero512 ) x y) z) + // result: (VPDPBUSDS512 z x y) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpAMD64VPDPBUSDS512 { + continue + } + y := v_0.Args[2] + v_0_0 := v_0.Args[0] + if v_0_0.Op != OpAMD64Zero512 { + continue + } + t := v_0_0.Type + x := v_0.Args[1] + z := v_1 + v.reset(OpAMD64VPDPBUSDS512) + v.Type = t + v.AddArg3(z, x, y) + return true + } + break + } // match: (VPADDD512 x l:(VMOVDQUload512 {sym} [off] ptr mem)) // cond: canMergeLoad(v, l) && clobber(l) // result: (VPADDD512load {sym} [off] x ptr mem) diff --git a/src/cmd/compile/internal/ssagen/intrinsics.go b/src/cmd/compile/internal/ssagen/intrinsics.go index a20529258a..e346b00a1b 100644 --- a/src/cmd/compile/internal/ssagen/intrinsics.go +++ b/src/cmd/compile/internal/ssagen/intrinsics.go @@ -1869,9 +1869,19 @@ func opLen3(op ssa.Op, t *types.Type) func(s *state, n *ir.CallExpr, args []*ssa } } -func opLen3_31(op ssa.Op, t *types.Type) func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { +var ssaVecBySize = map[int64]*types.Type{ + 16: types.TypeVec128, + 32: types.TypeVec256, + 64: types.TypeVec512, +} + +func opLen3_31Zero3(op ssa.Op, t *types.Type) func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { return func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { - return s.newValue3(op, t, args[2], args[1], args[0]) + if t, ok := ssaVecBySize[args[1].Type.Size()]; !ok { + panic("unknown simd vector size") + } else { + return s.newValue3(op, t, s.newValue0(ssa.OpZeroSIMD, t), args[1], args[0]) + } } } diff --git a/src/cmd/compile/internal/ssagen/simdintrinsics.go b/src/cmd/compile/internal/ssagen/simdintrinsics.go index 492f581781..818b3544ae 100644 --- a/src/cmd/compile/internal/ssagen/simdintrinsics.go +++ b/src/cmd/compile/internal/ssagen/simdintrinsics.go @@ -68,12 +68,6 @@ func simdIntrinsics(addF func(pkg, fn string, b intrinsicBuilder, archFamilies . addF(simdPackage, "Uint64x2.Add", opLen2(ssa.OpAddUint64x2, types.TypeVec128), sys.AMD64) addF(simdPackage, "Uint64x4.Add", opLen2(ssa.OpAddUint64x4, types.TypeVec256), sys.AMD64) addF(simdPackage, "Uint64x8.Add", opLen2(ssa.OpAddUint64x8, types.TypeVec512), sys.AMD64) - addF(simdPackage, "Int8x16.AddDotProductQuadruple", opLen3_31(ssa.OpAddDotProductQuadrupleInt32x4, types.TypeVec128), sys.AMD64) - addF(simdPackage, "Int8x32.AddDotProductQuadruple", opLen3_31(ssa.OpAddDotProductQuadrupleInt32x8, types.TypeVec256), sys.AMD64) - addF(simdPackage, "Int8x64.AddDotProductQuadruple", opLen3_31(ssa.OpAddDotProductQuadrupleInt32x16, types.TypeVec512), sys.AMD64) - addF(simdPackage, "Int8x16.AddDotProductQuadrupleSaturated", opLen3_31(ssa.OpAddDotProductQuadrupleSaturatedInt32x4, types.TypeVec128), sys.AMD64) - addF(simdPackage, "Int8x32.AddDotProductQuadrupleSaturated", opLen3_31(ssa.OpAddDotProductQuadrupleSaturatedInt32x8, types.TypeVec256), sys.AMD64) - addF(simdPackage, "Int8x64.AddDotProductQuadrupleSaturated", opLen3_31(ssa.OpAddDotProductQuadrupleSaturatedInt32x16, types.TypeVec512), sys.AMD64) addF(simdPackage, "Float32x4.AddPairs", opLen2(ssa.OpAddPairsFloat32x4, types.TypeVec128), sys.AMD64) addF(simdPackage, "Float32x8.AddPairs", opLen2(ssa.OpAddPairsFloat32x8, types.TypeVec256), sys.AMD64) addF(simdPackage, "Float64x2.AddPairs", opLen2(ssa.OpAddPairsFloat64x2, types.TypeVec128), sys.AMD64) @@ -375,6 +369,12 @@ func simdIntrinsics(addF func(pkg, fn string, b intrinsicBuilder, archFamilies . addF(simdPackage, "Uint8x16.DotProductPairsSaturated", opLen2(ssa.OpDotProductPairsSaturatedUint8x16, types.TypeVec128), sys.AMD64) addF(simdPackage, "Uint8x32.DotProductPairsSaturated", opLen2(ssa.OpDotProductPairsSaturatedUint8x32, types.TypeVec256), sys.AMD64) addF(simdPackage, "Uint8x64.DotProductPairsSaturated", opLen2(ssa.OpDotProductPairsSaturatedUint8x64, types.TypeVec512), sys.AMD64) + addF(simdPackage, "Int8x16.DotProductQuadruple", opLen3_31Zero3(ssa.OpDotProductQuadrupleInt32x4, types.TypeVec128), sys.AMD64) + addF(simdPackage, "Int8x32.DotProductQuadruple", opLen3_31Zero3(ssa.OpDotProductQuadrupleInt32x8, types.TypeVec256), sys.AMD64) + addF(simdPackage, "Int8x64.DotProductQuadruple", opLen3_31Zero3(ssa.OpDotProductQuadrupleInt32x16, types.TypeVec512), sys.AMD64) + addF(simdPackage, "Int8x16.DotProductQuadrupleSaturated", opLen3_31Zero3(ssa.OpDotProductQuadrupleSaturatedInt32x4, types.TypeVec128), sys.AMD64) + addF(simdPackage, "Int8x32.DotProductQuadrupleSaturated", opLen3_31Zero3(ssa.OpDotProductQuadrupleSaturatedInt32x8, types.TypeVec256), sys.AMD64) + addF(simdPackage, "Int8x64.DotProductQuadrupleSaturated", opLen3_31Zero3(ssa.OpDotProductQuadrupleSaturatedInt32x16, types.TypeVec512), sys.AMD64) addF(simdPackage, "Int8x16.Equal", opLen2(ssa.OpEqualInt8x16, types.TypeVec128), sys.AMD64) addF(simdPackage, "Int8x32.Equal", opLen2(ssa.OpEqualInt8x32, types.TypeVec256), sys.AMD64) addF(simdPackage, "Int8x64.Equal", opLen2(ssa.OpEqualInt8x64, types.TypeVec512), sys.AMD64) diff --git a/src/simd/_gen/simdgen/gen_simdIntrinsics.go b/src/simd/_gen/simdgen/gen_simdIntrinsics.go index 8827ce07c1..b963fb9abb 100644 --- a/src/simd/_gen/simdgen/gen_simdIntrinsics.go +++ b/src/simd/_gen/simdgen/gen_simdIntrinsics.go @@ -42,7 +42,7 @@ func simdIntrinsics(addF func(pkg, fn string, b intrinsicBuilder, archFamilies . {{end}} {{define "op3_231Type1"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen3_231(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64) {{end}} -{{define "op3_31"}} addF(simdPackage, "{{(index .In 2).Go}}.{{.Go}}", opLen3_31(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64) +{{define "op3_31Zero3"}} addF(simdPackage, "{{(index .In 2).Go}}.{{.Go}}", opLen3_31Zero3(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64) {{end}} {{define "op4"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen4(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64) {{end}} diff --git a/src/simd/_gen/simdgen/gen_simdTypes.go b/src/simd/_gen/simdgen/gen_simdTypes.go index b33c51b1ab..23b363d38a 100644 --- a/src/simd/_gen/simdgen/gen_simdTypes.go +++ b/src/simd/_gen/simdgen/gen_simdTypes.go @@ -257,11 +257,11 @@ func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}) {{.GoType}} func ({{.Op0NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op2NameAndType "z"}}) {{.GoType}} {{end}} -{{define "op3_31"}} +{{define "op3_31Zero3"}} {{if .Documentation}}{{.Documentation}} //{{end}} // Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}} -func ({{.Op2NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op0NameAndType "z"}}) {{.GoType}} +func ({{.Op2NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}) {{.GoType}} {{end}} {{define "op3_21"}} diff --git a/src/simd/_gen/simdgen/ops/MLOps/categories.yaml b/src/simd/_gen/simdgen/ops/MLOps/categories.yaml index 0317b42c6a..2b1da7adaf 100644 --- a/src/simd/_gen/simdgen/ops/MLOps/categories.yaml +++ b/src/simd/_gen/simdgen/ops/MLOps/categories.yaml @@ -15,14 +15,16 @@ # commutative: true # # documentation: !string |- # // NAME multiplies all elements and broadcasts the sum. -- go: AddDotProductQuadruple +- go: DotProductQuadruple commutative: false documentation: !string |- - // NAME performs dot products on groups of 4 elements of x and y and then adds z. -- go: AddDotProductQuadrupleSaturated + // NAME performs dot products on groups of 4 elements of x and y. + // NAME(x, y).Add(z) will be optimized to the full form of the underlying instruction. +- go: DotProductQuadrupleSaturated commutative: false documentation: !string |- - // NAME multiplies performs dot products on groups of 4 elements of x and y and then adds z. + // NAME multiplies performs dot products on groups of 4 elements of x and y. + // NAME(x, y).Add(z) will be optimized to the full form of the underlying instruction. - go: AddDotProductPairs commutative: false noTypes: "true" diff --git a/src/simd/_gen/simdgen/ops/MLOps/go.yaml b/src/simd/_gen/simdgen/ops/MLOps/go.yaml index 162c47ea0e..4a1195b52d 100644 --- a/src/simd/_gen/simdgen/ops/MLOps/go.yaml +++ b/src/simd/_gen/simdgen/ops/MLOps/go.yaml @@ -33,9 +33,9 @@ # const: 127 # out: # - *dpb_src -- go: AddDotProductQuadruple +- go: DotProductQuadruple asm: "VPDPBUSD" - operandOrder: "31" # switch operand 3 and 1 + operandOrder: "31Zero3" # switch operand 3 and 1, and make 3 always 0 in: - &qdpa_acc go: $t_acc @@ -51,9 +51,9 @@ overwriteElementBits: 8 out: - *qdpa_acc -- go: AddDotProductQuadrupleSaturated +- go: DotProductQuadrupleSaturated asm: "VPDPBUSDS" - operandOrder: "31" # switch operand 3 and 1 + operandOrder: "31Zero3" # switch operand 3 and 1, and make 3 always 0 in: - *qdpa_acc - *qdpa_src1 diff --git a/src/simd/internal/simd_test/simd_test.go b/src/simd/internal/simd_test/simd_test.go index a15925dbfa..2d7793ef05 100644 --- a/src/simd/internal/simd_test/simd_test.go +++ b/src/simd/internal/simd_test/simd_test.go @@ -1127,3 +1127,37 @@ func TestMaskedMerge(t *testing.T) { } } } + +func TestDotProductQuadruple(t *testing.T) { + if !simd.X86.AVXVNNI() { + t.Skip("Test requires X86.AVXVNNI, not available on this hardware") + return + } + xd := make([]int8, 16) + yd := make([]uint8, 16) + zd := make([]int32, 4) + wanted1 := make([]int32, 4) + wanted2 := make([]int32, 4) + res1 := make([]int32, 4) + res2 := make([]int32, 4) + for i := range 4 { + xd[i] = 5 + yd[i] = 6 + zd[i] = 3 + wanted1[i] = 30 + wanted2[i] = 30 + } + x := simd.LoadInt8x16Slice(xd) + y := simd.LoadUint8x16Slice(yd) + z := simd.LoadInt32x4Slice(zd) + x.DotProductQuadruple(y).StoreSlice(res1) + x.DotProductQuadruple(y).Add(z).StoreSlice(res1) + for i := range 4 { + if res1[i] != wanted1[i] { + t.Errorf("got %d wanted %d", res1[i], wanted1[i]) + } + if res2[i] != wanted2[i] { + t.Errorf("got %d wanted %d", res2[i], wanted2[i]) + } + } +} diff --git a/src/simd/ops_amd64.go b/src/simd/ops_amd64.go index 29c9387d78..e06d1f652e 100644 --- a/src/simd/ops_amd64.go +++ b/src/simd/ops_amd64.go @@ -346,40 +346,6 @@ func (x Uint64x4) Add(y Uint64x4) Uint64x4 // Asm: VPADDQ, CPU Feature: AVX512 func (x Uint64x8) Add(y Uint64x8) Uint64x8 -/* AddDotProductQuadruple */ - -// AddDotProductQuadruple performs dot products on groups of 4 elements of x and y and then adds z. -// -// Asm: VPDPBUSD, CPU Feature: AVXVNNI -func (x Int8x16) AddDotProductQuadruple(y Uint8x16, z Int32x4) Int32x4 - -// AddDotProductQuadruple performs dot products on groups of 4 elements of x and y and then adds z. -// -// Asm: VPDPBUSD, CPU Feature: AVXVNNI -func (x Int8x32) AddDotProductQuadruple(y Uint8x32, z Int32x8) Int32x8 - -// AddDotProductQuadruple performs dot products on groups of 4 elements of x and y and then adds z. -// -// Asm: VPDPBUSD, CPU Feature: AVX512VNNI -func (x Int8x64) AddDotProductQuadruple(y Uint8x64, z Int32x16) Int32x16 - -/* AddDotProductQuadrupleSaturated */ - -// AddDotProductQuadrupleSaturated multiplies performs dot products on groups of 4 elements of x and y and then adds z. -// -// Asm: VPDPBUSDS, CPU Feature: AVXVNNI -func (x Int8x16) AddDotProductQuadrupleSaturated(y Uint8x16, z Int32x4) Int32x4 - -// AddDotProductQuadrupleSaturated multiplies performs dot products on groups of 4 elements of x and y and then adds z. -// -// Asm: VPDPBUSDS, CPU Feature: AVXVNNI -func (x Int8x32) AddDotProductQuadrupleSaturated(y Uint8x32, z Int32x8) Int32x8 - -// AddDotProductQuadrupleSaturated multiplies performs dot products on groups of 4 elements of x and y and then adds z. -// -// Asm: VPDPBUSDS, CPU Feature: AVX512VNNI -func (x Int8x64) AddDotProductQuadrupleSaturated(y Uint8x64, z Int32x16) Int32x16 - /* AddPairs */ // AddPairs horizontally adds adjacent pairs of elements. @@ -2228,6 +2194,46 @@ func (x Uint8x32) DotProductPairsSaturated(y Int8x32) Int16x16 // Asm: VPMADDUBSW, CPU Feature: AVX512 func (x Uint8x64) DotProductPairsSaturated(y Int8x64) Int16x32 +/* DotProductQuadruple */ + +// DotProductQuadruple performs dot products on groups of 4 elements of x and y. +// DotProductQuadruple(x, y).Add(z) will be optimized to the full form of the underlying instruction. +// +// Asm: VPDPBUSD, CPU Feature: AVXVNNI +func (x Int8x16) DotProductQuadruple(y Uint8x16) Int32x4 + +// DotProductQuadruple performs dot products on groups of 4 elements of x and y. +// DotProductQuadruple(x, y).Add(z) will be optimized to the full form of the underlying instruction. +// +// Asm: VPDPBUSD, CPU Feature: AVXVNNI +func (x Int8x32) DotProductQuadruple(y Uint8x32) Int32x8 + +// DotProductQuadruple performs dot products on groups of 4 elements of x and y. +// DotProductQuadruple(x, y).Add(z) will be optimized to the full form of the underlying instruction. +// +// Asm: VPDPBUSD, CPU Feature: AVX512VNNI +func (x Int8x64) DotProductQuadruple(y Uint8x64) Int32x16 + +/* DotProductQuadrupleSaturated */ + +// DotProductQuadrupleSaturated multiplies performs dot products on groups of 4 elements of x and y. +// DotProductQuadrupleSaturated(x, y).Add(z) will be optimized to the full form of the underlying instruction. +// +// Asm: VPDPBUSDS, CPU Feature: AVXVNNI +func (x Int8x16) DotProductQuadrupleSaturated(y Uint8x16) Int32x4 + +// DotProductQuadrupleSaturated multiplies performs dot products on groups of 4 elements of x and y. +// DotProductQuadrupleSaturated(x, y).Add(z) will be optimized to the full form of the underlying instruction. +// +// Asm: VPDPBUSDS, CPU Feature: AVXVNNI +func (x Int8x32) DotProductQuadrupleSaturated(y Uint8x32) Int32x8 + +// DotProductQuadrupleSaturated multiplies performs dot products on groups of 4 elements of x and y. +// DotProductQuadrupleSaturated(x, y).Add(z) will be optimized to the full form of the underlying instruction. +// +// Asm: VPDPBUSDS, CPU Feature: AVX512VNNI +func (x Int8x64) DotProductQuadrupleSaturated(y Uint8x64) Int32x16 + /* Equal */ // Equal compares for equality. -- 2.52.0