From 58cfc2a5f63f60d092844034adcfa589fb878e02 Mon Sep 17 00:00:00 2001 From: Junyang Shao Date: Wed, 20 Aug 2025 18:42:52 +0000 Subject: [PATCH] [dev.simd] cmd/compile, simd: add VPSADBW This new API is given the name SumAbsDiff, a slightly longer form of its canonical abbreviation SAD (Sum of Absolute Differences). Some other instructions have similar semantics, but theirs are much more specific and complex: MPSADBW, VDBPSADBW. Given this, they should be given more specific names. Change-Id: Ied9144440f82919c3c2d45ae4ce5b961ae91a020 Reviewed-on: https://go-review.googlesource.com/c/go/+/697776 Reviewed-by: David Chase LUCI-TryBot-Result: Go LUCI --- src/cmd/compile/internal/amd64/simdssa.go | 3 + .../compile/internal/ssa/_gen/simdAMD64.rules | 3 + .../compile/internal/ssa/_gen/simdAMD64ops.go | 3 + .../internal/ssa/_gen/simdgenericOps.go | 3 + src/cmd/compile/internal/ssa/opGen.go | 63 +++++++++++++++++++ src/cmd/compile/internal/ssa/rewriteAMD64.go | 9 +++ .../compile/internal/ssagen/simdintrinsics.go | 3 + .../_gen/simdgen/ops/MLOps/categories.yaml | 6 ++ src/simd/_gen/simdgen/ops/MLOps/go.yaml | 12 +++- src/simd/ops_amd64.go | 23 +++++++ 10 files changed, 127 insertions(+), 1 deletion(-) diff --git a/src/cmd/compile/internal/amd64/simdssa.go b/src/cmd/compile/internal/amd64/simdssa.go index 03617d4a5d..5fc85457cf 100644 --- a/src/cmd/compile/internal/amd64/simdssa.go +++ b/src/cmd/compile/internal/amd64/simdssa.go @@ -368,6 +368,9 @@ func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool { ssa.OpAMD64VPSUBUSW128, ssa.OpAMD64VPSUBUSW256, ssa.OpAMD64VPSUBUSW512, + ssa.OpAMD64VPSADBW128, + ssa.OpAMD64VPSADBW256, + ssa.OpAMD64VPSADBW512, ssa.OpAMD64VPXOR128, ssa.OpAMD64VPXOR256, ssa.OpAMD64VPXORD512, diff --git a/src/cmd/compile/internal/ssa/_gen/simdAMD64.rules b/src/cmd/compile/internal/ssa/_gen/simdAMD64.rules index d5be221c0e..d7bab7b050 100644 --- a/src/cmd/compile/internal/ssa/_gen/simdAMD64.rules +++ 
b/src/cmd/compile/internal/ssa/_gen/simdAMD64.rules @@ -1048,6 +1048,9 @@ (SubSaturatedUint16x8 ...) => (VPSUBUSW128 ...) (SubSaturatedUint16x16 ...) => (VPSUBUSW256 ...) (SubSaturatedUint16x32 ...) => (VPSUBUSW512 ...) +(SumAbsDiffUint8x16 ...) => (VPSADBW128 ...) +(SumAbsDiffUint8x32 ...) => (VPSADBW256 ...) +(SumAbsDiffUint8x64 ...) => (VPSADBW512 ...) (TruncFloat32x4 x) => (VROUNDPS128 [3] x) (TruncFloat32x8 x) => (VROUNDPS256 [3] x) (TruncFloat64x2 x) => (VROUNDPD128 [3] x) diff --git a/src/cmd/compile/internal/ssa/_gen/simdAMD64ops.go b/src/cmd/compile/internal/ssa/_gen/simdAMD64ops.go index 171ae59e32..7782b43cf5 100644 --- a/src/cmd/compile/internal/ssa/_gen/simdAMD64ops.go +++ b/src/cmd/compile/internal/ssa/_gen/simdAMD64ops.go @@ -652,6 +652,9 @@ func simdAMD64Ops(v11, v21, v2k, vkv, v2kv, v2kk, v31, v3kv, vgpv, vgp, vfpv, vf {name: "VPRORVQMasked128", argLength: 3, reg: w2kw, asm: "VPRORVQ", commutative: false, typ: "Vec128", resultInArg0: false}, {name: "VPRORVQMasked256", argLength: 3, reg: w2kw, asm: "VPRORVQ", commutative: false, typ: "Vec256", resultInArg0: false}, {name: "VPRORVQMasked512", argLength: 3, reg: w2kw, asm: "VPRORVQ", commutative: false, typ: "Vec512", resultInArg0: false}, + {name: "VPSADBW128", argLength: 2, reg: v21, asm: "VPSADBW", commutative: false, typ: "Vec128", resultInArg0: false}, + {name: "VPSADBW256", argLength: 2, reg: v21, asm: "VPSADBW", commutative: false, typ: "Vec256", resultInArg0: false}, + {name: "VPSADBW512", argLength: 2, reg: w21, asm: "VPSADBW", commutative: false, typ: "Vec512", resultInArg0: false}, {name: "VPSHLDVD128", argLength: 3, reg: w31, asm: "VPSHLDVD", commutative: false, typ: "Vec128", resultInArg0: true}, {name: "VPSHLDVD256", argLength: 3, reg: w31, asm: "VPSHLDVD", commutative: false, typ: "Vec256", resultInArg0: true}, {name: "VPSHLDVD512", argLength: 3, reg: w31, asm: "VPSHLDVD", commutative: false, typ: "Vec512", resultInArg0: true}, diff --git 
a/src/cmd/compile/internal/ssa/_gen/simdgenericOps.go b/src/cmd/compile/internal/ssa/_gen/simdgenericOps.go index 4f9877aa03..4844d8fc0c 100644 --- a/src/cmd/compile/internal/ssa/_gen/simdgenericOps.go +++ b/src/cmd/compile/internal/ssa/_gen/simdgenericOps.go @@ -894,6 +894,9 @@ func simdGenericOps() []opData { {name: "SubUint64x2", argLength: 2, commutative: false}, {name: "SubUint64x4", argLength: 2, commutative: false}, {name: "SubUint64x8", argLength: 2, commutative: false}, + {name: "SumAbsDiffUint8x16", argLength: 2, commutative: false}, + {name: "SumAbsDiffUint8x32", argLength: 2, commutative: false}, + {name: "SumAbsDiffUint8x64", argLength: 2, commutative: false}, {name: "TruncFloat32x4", argLength: 1, commutative: false}, {name: "TruncFloat32x8", argLength: 1, commutative: false}, {name: "TruncFloat64x2", argLength: 1, commutative: false}, diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go index 8375b3f8a6..c5402c6f17 100644 --- a/src/cmd/compile/internal/ssa/opGen.go +++ b/src/cmd/compile/internal/ssa/opGen.go @@ -1875,6 +1875,9 @@ const ( OpAMD64VPRORVQMasked128 OpAMD64VPRORVQMasked256 OpAMD64VPRORVQMasked512 + OpAMD64VPSADBW128 + OpAMD64VPSADBW256 + OpAMD64VPSADBW512 OpAMD64VPSHLDVD128 OpAMD64VPSHLDVD256 OpAMD64VPSHLDVD512 @@ -5544,6 +5547,9 @@ const ( OpSubUint64x2 OpSubUint64x4 OpSubUint64x8 + OpSumAbsDiffUint8x16 + OpSumAbsDiffUint8x32 + OpSumAbsDiffUint8x64 OpTruncFloat32x4 OpTruncFloat32x8 OpTruncFloat64x2 @@ -28457,6 +28463,48 @@ var opcodeTable = [...]opInfo{ }, }, }, + { + name: "VPSADBW128", + argLen: 2, + asm: x86.AVPSADBW, + reg: regInfo{ + inputs: []inputInfo{ + {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + {1, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + }, + outputs: []outputInfo{ + {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + }, + }, + }, + { + name: "VPSADBW256", + argLen: 2, + asm: x86.AVPSADBW, + reg: regInfo{ + inputs: 
[]inputInfo{ + {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + {1, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + }, + outputs: []outputInfo{ + {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + }, + }, + }, + { + name: "VPSADBW512", + argLen: 2, + asm: x86.AVPSADBW, + reg: regInfo{ + inputs: []inputInfo{ + {0, 281472829161472}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X27 X28 X29 X30 X31 + {1, 281472829161472}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X27 X28 X29 X30 X31 + }, + outputs: []outputInfo{ + {0, 281472829161472}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X27 X28 X29 X30 X31 + }, + }, + }, { name: "VPSHLDVD128", argLen: 3, @@ -67898,6 +67946,21 @@ var opcodeTable = [...]opInfo{ argLen: 2, generic: true, }, + { + name: "SumAbsDiffUint8x16", + argLen: 2, + generic: true, + }, + { + name: "SumAbsDiffUint8x32", + argLen: 2, + generic: true, + }, + { + name: "SumAbsDiffUint8x64", + argLen: 2, + generic: true, + }, { name: "TruncFloat32x4", argLen: 1, diff --git a/src/cmd/compile/internal/ssa/rewriteAMD64.go b/src/cmd/compile/internal/ssa/rewriteAMD64.go index 924fc2ecf6..70c773bc1c 100644 --- a/src/cmd/compile/internal/ssa/rewriteAMD64.go +++ b/src/cmd/compile/internal/ssa/rewriteAMD64.go @@ -4123,6 +4123,15 @@ func rewriteValueAMD64(v *Value) bool { case OpSubUint8x64: v.Op = OpAMD64VPSUBB512 return true + case OpSumAbsDiffUint8x16: + v.Op = OpAMD64VPSADBW128 + return true + case OpSumAbsDiffUint8x32: + v.Op = OpAMD64VPSADBW256 + return true + case OpSumAbsDiffUint8x64: + v.Op = OpAMD64VPSADBW512 + return true case OpTailCall: v.Op = OpAMD64CALLtail return true diff --git a/src/cmd/compile/internal/ssagen/simdintrinsics.go b/src/cmd/compile/internal/ssagen/simdintrinsics.go index 0fd330779e..676cfa9032 100644 --- 
a/src/cmd/compile/internal/ssagen/simdintrinsics.go +++ b/src/cmd/compile/internal/ssagen/simdintrinsics.go @@ -1024,6 +1024,9 @@ func simdIntrinsics(addF func(pkg, fn string, b intrinsicBuilder, archFamilies . addF(simdPackage, "Uint16x8.SubSaturated", opLen2(ssa.OpSubSaturatedUint16x8, types.TypeVec128), sys.AMD64) addF(simdPackage, "Uint16x16.SubSaturated", opLen2(ssa.OpSubSaturatedUint16x16, types.TypeVec256), sys.AMD64) addF(simdPackage, "Uint16x32.SubSaturated", opLen2(ssa.OpSubSaturatedUint16x32, types.TypeVec512), sys.AMD64) + addF(simdPackage, "Uint8x16.SumAbsDiff", opLen2(ssa.OpSumAbsDiffUint8x16, types.TypeVec128), sys.AMD64) + addF(simdPackage, "Uint8x32.SumAbsDiff", opLen2(ssa.OpSumAbsDiffUint8x32, types.TypeVec256), sys.AMD64) + addF(simdPackage, "Uint8x64.SumAbsDiff", opLen2(ssa.OpSumAbsDiffUint8x64, types.TypeVec512), sys.AMD64) addF(simdPackage, "Float32x4.Trunc", opLen1(ssa.OpTruncFloat32x4, types.TypeVec128), sys.AMD64) addF(simdPackage, "Float32x8.Trunc", opLen1(ssa.OpTruncFloat32x8, types.TypeVec256), sys.AMD64) addF(simdPackage, "Float64x2.Trunc", opLen1(ssa.OpTruncFloat64x2, types.TypeVec128), sys.AMD64) diff --git a/src/simd/_gen/simdgen/ops/MLOps/categories.yaml b/src/simd/_gen/simdgen/ops/MLOps/categories.yaml index 97381e1e34..8e1ffeb131 100644 --- a/src/simd/_gen/simdgen/ops/MLOps/categories.yaml +++ b/src/simd/_gen/simdgen/ops/MLOps/categories.yaml @@ -45,3 +45,9 @@ commutative: false documentation: !string |- // NAME performs a fused (x * y) + z for odd-indexed elements, and (x * y) - z for even-indexed elements. +- go: SumAbsDiff + commutative: false + documentation: !string |- + // NAME sums the absolute distance of the two input vectors, each adjacent 8 bytes as a group. The output sum will + // be a vector of word-sized elements whose each 4*n-th element contains the sum of the n-th input group. The other elements in the result vector are zeroed. + // This method could be seen as the norm of the L1 distance of each adjacent 8-byte vector group of the two input vectors. 
diff --git a/src/simd/_gen/simdgen/ops/MLOps/go.yaml b/src/simd/_gen/simdgen/ops/MLOps/go.yaml index f6b6f135b8..5c2009dcf8 100644 --- a/src/simd/_gen/simdgen/ops/MLOps/go.yaml +++ b/src/simd/_gen/simdgen/ops/MLOps/go.yaml @@ -110,4 +110,14 @@ - *fma_op - *fma_op out: - - *fma_op \ No newline at end of file + - *fma_op +- go: SumAbsDiff + asm: "VPSADBW" + in: + - go: $t + base: uint + - go: $t + base: uint + out: + - go: $t2 + base: uint \ No newline at end of file diff --git a/src/simd/ops_amd64.go b/src/simd/ops_amd64.go index 019f9df1ed..4cfebb3a77 100644 --- a/src/simd/ops_amd64.go +++ b/src/simd/ops_amd64.go @@ -5842,6 +5842,29 @@ func (x Uint16x16) SubSaturated(y Uint16x16) Uint16x16 // Asm: VPSUBUSW, CPU Feature: AVX512 func (x Uint16x32) SubSaturated(y Uint16x32) Uint16x32 +/* SumAbsDiff */ + +// SumAbsDiff sums the absolute distance of the two input vectors, each adjacent 8 bytes as a group. The output sum will +// be a vector of word-sized elements whose each 4*n-th element contains the sum of the n-th input group. The other elements in the result vector are zeroed. +// This method could be seen as the norm of the L1 distance of each adjacent 8-byte vector group of the two input vectors. +// +// Asm: VPSADBW, CPU Feature: AVX +func (x Uint8x16) SumAbsDiff(y Uint8x16) Uint16x8 + +// SumAbsDiff sums the absolute distance of the two input vectors, each adjacent 8 bytes as a group. The output sum will +// be a vector of word-sized elements whose each 4*n-th element contains the sum of the n-th input group. The other elements in the result vector are zeroed. +// This method could be seen as the norm of the L1 distance of each adjacent 8-byte vector group of the two input vectors. +// +// Asm: VPSADBW, CPU Feature: AVX2 +func (x Uint8x32) SumAbsDiff(y Uint8x32) Uint16x16 + +// SumAbsDiff sums the absolute distance of the two input vectors, each adjacent 8 bytes as a group. The output sum will +// be a vector of word-sized elements whose each 4*n-th element contains the sum of the n-th input group. The other elements in the result vector are zeroed. 
+// This method could be seen as the norm of the L1 distance of each adjacent 8-byte vector group of the two input vectors. +// +// Asm: VPSADBW, CPU Feature: AVX512 +func (x Uint8x64) SumAbsDiff(y Uint8x64) Uint16x32 + /* Trunc */ // Trunc truncates elements towards zero. -- 2.52.0