cmd/compile: optimize SIMD IsNaN.Or(IsNaN)
author Cherry Mui <cherryyz@google.com>
Fri, 2 Jan 2026 19:02:07 +0000 (14:02 -0500)
committer Cherry Mui <cherryyz@google.com>
Fri, 2 Jan 2026 20:16:34 +0000 (12:16 -0800)
IsNaN's underlying instruction, VCMPPS (or VCMPPD), takes two
inputs and computes whether either of them is NaN. Optimize the
Or pattern to generate the two-operand form.

This implements the optimization mentioned in CL 733660.

Change-Id: I13943b377ee384864c913eed320763f333a03e41
Reviewed-on: https://go-review.googlesource.com/c/go/+/733680
Reviewed-by: David Chase <drchase@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

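For reference, a minimal sketch of the source-level pattern this rewrite
targets (it mirrors the added codegen test below; the eitherNaN name is
illustrative, and GOEXPERIMENT=simd plus the simd/archsimd package are
assumed):

    // eitherNaN reports, lane by lane, whether x or y is NaN.
    // Before this CL it compiled to two VCMPPD $3 (x vs. x, y vs. y)
    // joined by a VPOR; after this CL it is a single VCMPPD $3, x, y.
    func eitherNaN(x, y archsimd.Float64x4) archsimd.Mask64x4 {
            return x.IsNaN().Or(y.IsNaN())
    }
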
src/cmd/compile/internal/ssa/_gen/AMD64.rules
src/cmd/compile/internal/ssa/rewriteAMD64.go
src/simd/archsimd/internal/simd_test/compare_test.go
test/codegen/simd.go

index 9c541868548d40443e8ad33074f9fe1dce43bb00..9cd23c6286308b75482c2c5821ec3f5e380c27cb 100644 (file)
 (EQ (VPTEST x:(VPAND(D|Q)512 j k) y) yes no) && x == y && x.Uses == 2 => (EQ (VPTEST j k) yes no)
 (EQ (VPTEST x:(VPANDN(128|256) j k) y) yes no) && x == y && x.Uses == 2 => (ULT (VPTEST k j) yes no) // AndNot has swapped its operand order
 (EQ (VPTEST x:(VPANDN(D|Q)512 j k) y) yes no) && x == y && x.Uses == 2 => (ULT (VPTEST k j) yes no) // AndNot has swapped its operand order
+
+// optimize x.IsNaN().Or(y.IsNaN())
+(VPOR128 (VCMPP(S|D)128 [3] x x) (VCMPP(S|D)128 [3] y y)) => (VCMPP(S|D)128 [3] x y)
+(VPOR256 (VCMPP(S|D)256 [3] x x) (VCMPP(S|D)256 [3] y y)) => (VCMPP(S|D)256 [3] x y)
+(VPORD512 (VPMOVMToVec32x16 (VCMPPS512 [3] x x)) (VPMOVMToVec32x16 (VCMPPS512 [3] y y))) =>
+       (VPMOVMToVec32x16 (VCMPPS512 [3] x y))
+(VPORD512 (VPMOVMToVec64x8  (VCMPPD512 [3] x x)) (VPMOVMToVec64x8  (VCMPPD512 [3] y y))) =>
+       (VPMOVMToVec64x8  (VCMPPD512 [3] x y))
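
The rules rely on VCMPP(S|D)'s UNORD predicate: imm8 = 3 selects the
unordered comparison, which is true exactly when at least one operand is
NaN, so comparing x with itself tests x alone and comparing x with y tests
both at once. The 512-bit rules additionally match through VPMOVMToVec*,
because AVX-512 compares write a mask register that is converted to a
vector before the VPORD. A scalar Go analogue of the predicate (the unord
name is illustrative; it matches the want32/want64 reference functions in
the added test):

    // unord mirrors one lane of VCMPP* with imm8 = 3: the comparison
    // is unordered iff either operand is NaN, since NaN != NaN.
    func unord(x, y float64) bool {
            return x != x || y != y
    }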
index 0b2bb74ce489c91306ee1183a732cdb8bd91a553..3eb2a6278b4149f1bef99f53b3e96ba921d487c6 100644 (file)
@@ -1382,6 +1382,10 @@ func rewriteValueAMD64(v *Value) bool {
                return rewriteValueAMD64_OpAMD64VPOPCNTQMasked256(v)
        case OpAMD64VPOPCNTQMasked512:
                return rewriteValueAMD64_OpAMD64VPOPCNTQMasked512(v)
+       case OpAMD64VPOR128:
+               return rewriteValueAMD64_OpAMD64VPOR128(v)
+       case OpAMD64VPOR256:
+               return rewriteValueAMD64_OpAMD64VPOR256(v)
        case OpAMD64VPORD512:
                return rewriteValueAMD64_OpAMD64VPORD512(v)
        case OpAMD64VPORDMasked128:
@@ -56768,9 +56772,173 @@ func rewriteValueAMD64_OpAMD64VPOPCNTQMasked512(v *Value) bool {
        }
        return false
 }
+func rewriteValueAMD64_OpAMD64VPOR128(v *Value) bool {
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (VPOR128 (VCMPPS128 [3] x x) (VCMPPS128 [3] y y))
+       // result: (VCMPPS128 [3] x y)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       if v_0.Op != OpAMD64VCMPPS128 || auxIntToUint8(v_0.AuxInt) != 3 {
+                               continue
+                       }
+                       x := v_0.Args[1]
+                       if x != v_0.Args[0] || v_1.Op != OpAMD64VCMPPS128 || auxIntToUint8(v_1.AuxInt) != 3 {
+                               continue
+                       }
+                       y := v_1.Args[1]
+                       if y != v_1.Args[0] {
+                               continue
+                       }
+                       v.reset(OpAMD64VCMPPS128)
+                       v.AuxInt = uint8ToAuxInt(3)
+                       v.AddArg2(x, y)
+                       return true
+               }
+               break
+       }
+       // match: (VPOR128 (VCMPPD128 [3] x x) (VCMPPD128 [3] y y))
+       // result: (VCMPPD128 [3] x y)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       if v_0.Op != OpAMD64VCMPPD128 || auxIntToUint8(v_0.AuxInt) != 3 {
+                               continue
+                       }
+                       x := v_0.Args[1]
+                       if x != v_0.Args[0] || v_1.Op != OpAMD64VCMPPD128 || auxIntToUint8(v_1.AuxInt) != 3 {
+                               continue
+                       }
+                       y := v_1.Args[1]
+                       if y != v_1.Args[0] {
+                               continue
+                       }
+                       v.reset(OpAMD64VCMPPD128)
+                       v.AuxInt = uint8ToAuxInt(3)
+                       v.AddArg2(x, y)
+                       return true
+               }
+               break
+       }
+       return false
+}
+func rewriteValueAMD64_OpAMD64VPOR256(v *Value) bool {
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (VPOR256 (VCMPPS256 [3] x x) (VCMPPS256 [3] y y))
+       // result: (VCMPPS256 [3] x y)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       if v_0.Op != OpAMD64VCMPPS256 || auxIntToUint8(v_0.AuxInt) != 3 {
+                               continue
+                       }
+                       x := v_0.Args[1]
+                       if x != v_0.Args[0] || v_1.Op != OpAMD64VCMPPS256 || auxIntToUint8(v_1.AuxInt) != 3 {
+                               continue
+                       }
+                       y := v_1.Args[1]
+                       if y != v_1.Args[0] {
+                               continue
+                       }
+                       v.reset(OpAMD64VCMPPS256)
+                       v.AuxInt = uint8ToAuxInt(3)
+                       v.AddArg2(x, y)
+                       return true
+               }
+               break
+       }
+       // match: (VPOR256 (VCMPPD256 [3] x x) (VCMPPD256 [3] y y))
+       // result: (VCMPPD256 [3] x y)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       if v_0.Op != OpAMD64VCMPPD256 || auxIntToUint8(v_0.AuxInt) != 3 {
+                               continue
+                       }
+                       x := v_0.Args[1]
+                       if x != v_0.Args[0] || v_1.Op != OpAMD64VCMPPD256 || auxIntToUint8(v_1.AuxInt) != 3 {
+                               continue
+                       }
+                       y := v_1.Args[1]
+                       if y != v_1.Args[0] {
+                               continue
+                       }
+                       v.reset(OpAMD64VCMPPD256)
+                       v.AuxInt = uint8ToAuxInt(3)
+                       v.AddArg2(x, y)
+                       return true
+               }
+               break
+       }
+       return false
+}
 func rewriteValueAMD64_OpAMD64VPORD512(v *Value) bool {
        v_1 := v.Args[1]
        v_0 := v.Args[0]
+       b := v.Block
+       typ := &b.Func.Config.Types
+       // match: (VPORD512 (VPMOVMToVec32x16 (VCMPPS512 [3] x x)) (VPMOVMToVec32x16 (VCMPPS512 [3] y y)))
+       // result: (VPMOVMToVec32x16 (VCMPPS512 [3] x y))
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       if v_0.Op != OpAMD64VPMOVMToVec32x16 {
+                               continue
+                       }
+                       v_0_0 := v_0.Args[0]
+                       if v_0_0.Op != OpAMD64VCMPPS512 || auxIntToUint8(v_0_0.AuxInt) != 3 {
+                               continue
+                       }
+                       x := v_0_0.Args[1]
+                       if x != v_0_0.Args[0] || v_1.Op != OpAMD64VPMOVMToVec32x16 {
+                               continue
+                       }
+                       v_1_0 := v_1.Args[0]
+                       if v_1_0.Op != OpAMD64VCMPPS512 || auxIntToUint8(v_1_0.AuxInt) != 3 {
+                               continue
+                       }
+                       y := v_1_0.Args[1]
+                       if y != v_1_0.Args[0] {
+                               continue
+                       }
+                       v.reset(OpAMD64VPMOVMToVec32x16)
+                       v0 := b.NewValue0(v.Pos, OpAMD64VCMPPS512, typ.Mask)
+                       v0.AuxInt = uint8ToAuxInt(3)
+                       v0.AddArg2(x, y)
+                       v.AddArg(v0)
+                       return true
+               }
+               break
+       }
+       // match: (VPORD512 (VPMOVMToVec64x8 (VCMPPD512 [3] x x)) (VPMOVMToVec64x8 (VCMPPD512 [3] y y)))
+       // result: (VPMOVMToVec64x8 (VCMPPD512 [3] x y))
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       if v_0.Op != OpAMD64VPMOVMToVec64x8 {
+                               continue
+                       }
+                       v_0_0 := v_0.Args[0]
+                       if v_0_0.Op != OpAMD64VCMPPD512 || auxIntToUint8(v_0_0.AuxInt) != 3 {
+                               continue
+                       }
+                       x := v_0_0.Args[1]
+                       if x != v_0_0.Args[0] || v_1.Op != OpAMD64VPMOVMToVec64x8 {
+                               continue
+                       }
+                       v_1_0 := v_1.Args[0]
+                       if v_1_0.Op != OpAMD64VCMPPD512 || auxIntToUint8(v_1_0.AuxInt) != 3 {
+                               continue
+                       }
+                       y := v_1_0.Args[1]
+                       if y != v_1_0.Args[0] {
+                               continue
+                       }
+                       v.reset(OpAMD64VPMOVMToVec64x8)
+                       v0 := b.NewValue0(v.Pos, OpAMD64VCMPPD512, typ.Mask)
+                       v0.AuxInt = uint8ToAuxInt(3)
+                       v0.AddArg2(x, y)
+                       v.AddArg(v0)
+                       return true
+               }
+               break
+       }
        // match: (VPORD512 x l:(VMOVDQUload512 {sym} [off] ptr mem))
        // cond: canMergeLoad(v, l) && clobber(l)
        // result: (VPORD512load {sym} [off] x ptr mem)
index e678676be0103c62f551b89a5da65d891a5b37da..ea8514ac939756d35fdafbb3696a1582282876e1 100644 (file)
@@ -309,4 +309,38 @@ func TestIsNaN(t *testing.T) {
                testFloat32x16UnaryCompare(t, archsimd.Float32x16.IsNaN, isNaNSlice[float32])
                testFloat64x8UnaryCompare(t, archsimd.Float64x8.IsNaN, isNaNSlice[float64])
        }
+
+       // Test x.IsNaN().Or(y.IsNaN()), which is optimized to VCMPP(S|D) $3, x, y.
+       want32 := mapCompare(func(x, y float32) bool { return x != x || y != y })
+       want64 := mapCompare(func(x, y float64) bool { return x != x || y != y })
+       testFloat32x4Compare(t,
+               func(x, y archsimd.Float32x4) archsimd.Mask32x4 {
+                       return x.IsNaN().Or(y.IsNaN())
+               }, want32)
+       testFloat64x2Compare(t,
+               func(x, y archsimd.Float64x2) archsimd.Mask64x2 {
+                       return x.IsNaN().Or(y.IsNaN())
+               }, want64)
+
+       if archsimd.X86.AVX2() {
+               testFloat32x8Compare(t,
+                       func(x, y archsimd.Float32x8) archsimd.Mask32x8 {
+                               return x.IsNaN().Or(y.IsNaN())
+                       }, want32)
+               testFloat64x4Compare(t,
+                       func(x, y archsimd.Float64x4) archsimd.Mask64x4 {
+                               return x.IsNaN().Or(y.IsNaN())
+                       }, want64)
+       }
+
+       if archsimd.X86.AVX512() {
+               testFloat32x16Compare(t,
+                       func(x, y archsimd.Float32x16) archsimd.Mask32x16 {
+                               return x.IsNaN().Or(y.IsNaN())
+                       }, want32)
+               testFloat64x8Compare(t,
+                       func(x, y archsimd.Float64x8) archsimd.Mask64x8 {
+                               return x.IsNaN().Or(y.IsNaN())
+                       }, want64)
+       }
 }
index 8f3a1a9f46e7b5b928d940d777089ca7df9ef2f6..04e01944def1fa18748fa3d70cc80c4a6085843a 100644 (file)
@@ -6,11 +6,14 @@
 
 // These tests check code generation of simd peephole optimizations.
 
-//go:build goexperiment.simd
+//go:build goexperiment.simd && amd64
 
 package codegen
 
-import "simd/archsimd"
+import (
+       "math"
+       "simd/archsimd"
+)
 
 func vptest1() bool {
        v1 := archsimd.LoadUint64x2Slice([]uint64{0, 1})
@@ -77,3 +80,27 @@ func simdMaskedMerge() archsimd.Int16x16 {
        mask := archsimd.Mask16x16FromBits(5)
        return x.Add(y).Merge(x, mask) // amd64:`VPBLENDVB\s.*$`
 }
+
+var nan = math.NaN()
+var floats64s = []float64{0, 1, 2, nan, 4, nan, 6, 7, 8, 9, 10, 11, nan, 13, 14, 15}
+var sinkInt64s = make([]int64, 100)
+
+func simdIsNaN() {
+       x := archsimd.LoadFloat64x4Slice(floats64s)
+       y := archsimd.LoadFloat64x4Slice(floats64s[4:])
+       a := x.IsNaN()
+       b := y.IsNaN()
+       // amd64:"VCMPPD [$]3," -"VPOR"
+       c := a.Or(b)
+       c.ToInt64x4().StoreSlice(sinkInt64s)
+}
+
+func simdIsNaN512() {
+       x := archsimd.LoadFloat64x8Slice(floats64s)
+       y := archsimd.LoadFloat64x8Slice(floats64s[8:])
+       a := x.IsNaN()
+       b := y.IsNaN()
+       // amd64:"VCMPPD [$]3," -"VPOR"
+       c := a.Or(b)
+       c.ToInt64x8().StoreSlice(sinkInt64s)
+}