Cypherpunks repositories - gostls13.git/commitdiff
[dev.simd] cmd/compile: sample peephole optimization for SIMD broadcast
author David Chase <drchase@google.com>
Thu, 14 Aug 2025 21:31:09 +0000 (17:31 -0400)
committer David Chase <drchase@google.com>
Sat, 23 Aug 2025 04:19:38 +0000 (21:19 -0700)
After tinkering and a rewrite, this also optimizes some instances
of SetElem(0).

Change-Id: Ibba2d50a56b68ccf9de517ef24ca52b64c6c5b2c
Reviewed-on: https://go-review.googlesource.com/c/go/+/696376
Reviewed-by: Junyang Shao <shaojunyang@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/cmd/compile/internal/amd64/ssa.go
src/cmd/compile/internal/ssa/_gen/AMD64.rules
src/cmd/compile/internal/ssa/_gen/AMD64Ops.go
src/cmd/compile/internal/ssa/_gen/rulegen.go
src/cmd/compile/internal/ssa/opGen.go
src/cmd/compile/internal/ssa/rewriteAMD64.go
src/simd/internal/simd_test/simd_test.go
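Editorial illustration (not part of the CL): the user-visible pattern this
peephole targets, written against the experimental simd package exercised by
the test changes below. The import path and API are taken from those tests and
exist only on the dev.simd branch; this is a sketch, not an authoritative
example.

	package main

	import (
		"fmt"
		"simd" // experimental package on the dev.simd branch; path assumed from this CL's tests
	)

	func main() {
		// Broadcasting a small scalar is the case this CL improves: the byte is
		// now moved into a vector register with a plain VMOVQ instead of being
		// inserted into a zeroed vector with VPINSRB before the VPBROADCASTB.
		s := make([]int8, 32)
		simd.BroadcastInt8x32(-123).StoreSlice(s)
		fmt.Println(s)
	}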

src/cmd/compile/internal/amd64/ssa.go
index ec4eaaed03cbd1f533df27e656ca02ec57cbec5c..58a0f9cc8140b72b3d73130d353d4dffc06088dd 100644 (file)
@@ -1711,8 +1711,26 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
        // SIMD ops
        case ssa.OpAMD64VZEROUPPER, ssa.OpAMD64VZEROALL:
                s.Prog(v.Op.Asm())
-       case ssa.OpAMD64Zero128, ssa.OpAMD64Zero256, ssa.OpAMD64Zero512:
-               // zero-width, no instruction generated
+
+       case ssa.OpAMD64Zero128, ssa.OpAMD64Zero256, ssa.OpAMD64Zero512: // no code emitted
+
+       case ssa.OpAMD64VMOVSSf2v, ssa.OpAMD64VMOVSDf2v:
+               // These are for initializing the least 32/64 bits of a SIMD register from a "float".
+               p := s.Prog(v.Op.Asm())
+               p.From.Type = obj.TYPE_REG
+               p.From.Reg = v.Args[0].Reg()
+               p.AddRestSourceReg(x86.REG_X15)
+               p.To.Type = obj.TYPE_REG
+               p.To.Reg = simdReg(v)
+
+       case ssa.OpAMD64VMOVD, ssa.OpAMD64VMOVQ:
+               // These are for initializing the least 32/64 bits of a SIMD register from an "int".
+               p := s.Prog(v.Op.Asm())
+               p.From.Type = obj.TYPE_REG
+               p.From.Reg = v.Args[0].Reg()
+               p.To.Type = obj.TYPE_REG
+               p.To.Reg = simdReg(v)
+
        case ssa.OpAMD64VMOVDQUload128, ssa.OpAMD64VMOVDQUload256, ssa.OpAMD64VMOVDQUload512, ssa.OpAMD64KMOVQload:
                p := s.Prog(v.Op.Asm())
                p.From.Type = obj.TYPE_MEM
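A hedged aside on the two new float cases above: the Go internal ABI on amd64
keeps X15 as a fixed all-zero vector register (the Zero128/256/512 ops in this
CL are pinned to it), and AddRestSourceReg(x86.REG_X15) passes it as the extra
VEX source of VMOVSS/VMOVSD. The register form of those instructions takes its
low 32/64 bits from the scalar source and merges the remaining bits from that
extra source, so merging from X15 produces the scalar zero-extended to 128
bits. A minimal Go model of that semantics (illustrative only, not compiler
code):

	package main

	import "fmt"

	// vec128 models a 128-bit vector register as two 64-bit lanes.
	type vec128 struct{ lo, hi uint64 }

	// zeroX15 stands in for the ABI's fixed all-zero register X15.
	var zeroX15 vec128

	// vmovsdMerge models the register form of VMOVSD: the low 64 bits come from
	// the scalar source and the high 64 bits are merged from the other vector
	// source. Merging from the all-zero X15 therefore zero-extends the scalar.
	func vmovsdMerge(scalarBits uint64, merge vec128) vec128 {
		return vec128{lo: scalarBits, hi: merge.hi}
	}

	func main() {
		fmt.Println(vmovsdMerge(0x3ff0000000000000, zeroX15)) // {4607182418800017408 0}
	}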
src/cmd/compile/internal/ssa/_gen/AMD64.rules
index 913ddbf5596b94244cbb7858130178f23f002fd7..0c7c7ced4375f1d1b82bf68af43fcd613ce0d0e6 100644 (file)
 (VPANDQ512 x (VPMOVMToVec32x16 k)) => (VMOVDQU32Masked512 x k)
 (VPANDQ512 x (VPMOVMToVec16x32 k)) => (VMOVDQU16Masked512 x k)
 (VPANDQ512 x (VPMOVMToVec8x64 k)) => (VMOVDQU8Masked512 x k)
+
+// Inserting a 32/64-bit float or int into an all-zero vector is just MOVS[SD] / MOV[DQ]
+(VPINSRQ128 [0] (Zero128 <t>) y) && y.Type.IsFloat() => (VMOVSDf2v <types.TypeVec128> y)
+(VPINSRD128 [0] (Zero128 <t>) y) && y.Type.IsFloat() => (VMOVSSf2v <types.TypeVec128> y)
+(VPINSRQ128 [0] (Zero128 <t>) y) && !y.Type.IsFloat() => (VMOVQ <types.TypeVec128> y)
+(VPINSRD128 [0] (Zero128 <t>) y) && !y.Type.IsFloat() => (VMOVD <types.TypeVec128> y)
+
+// These rewrites can skip zero-extending the 8/16-bit inputs because they are
+// only used as the input to a broadcast; the potentially "bad" bits are ignored
+(VPBROADCASTB(128|256|512) x:(VPINSRB128 [0] (Zero128    <t>) y)) && x.Uses == 1 =>
+       (VPBROADCASTB(128|256|512) (VMOVQ <types.TypeVec128> y))
+(VPBROADCASTW(128|256|512) x:(VPINSRW128 [0] (Zero128    <t>) y)) && x.Uses == 1 =>
+       (VPBROADCASTW(128|256|512)   (VMOVQ <types.TypeVec128> y))
+
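A hedged illustration of the comment above (not part of the CL): a broadcast of
an 8- or 16-bit lane reads only the low 8/16 bits of the inserted scalar, so
leftover high bits in the source register cannot affect the result, which is
why the zero-extension can be dropped when the insert's only use is the
broadcast. A small Go model of that argument:

	package main

	import "fmt"

	// broadcastB8 models VPBROADCASTB over 8 byte lanes: only the low byte of x
	// matters, so garbage above bit 7 is irrelevant to the broadcast result.
	func broadcastB8(x uint64) [8]uint8 {
		b := uint8(x) // the broadcast reads only the low byte
		var out [8]uint8
		for i := range out {
			out[i] = b
		}
		return out
	}

	func main() {
		// Identical results with and without the high bits cleared.
		fmt.Println(broadcastB8(0x123456789abcdef0))
		fmt.Println(broadcastB8(0xf0))
	}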
src/cmd/compile/internal/ssa/_gen/AMD64Ops.go
index 12be7cae4185d9d5b32a90db149256fb7bf9902c..03f38db640074b5485d939c421616d1ec74f1cd5 100644 (file)
@@ -226,6 +226,8 @@ func init() {
                vgp   = regInfo{inputs: vonly, outputs: gponly}
                vfpv  = regInfo{inputs: []regMask{vz, fp}, outputs: vonly}
                vfpkv = regInfo{inputs: []regMask{vz, fp, mask}, outputs: vonly}
+               fpv   = regInfo{inputs: []regMask{fp}, outputs: vonly}
+               gpv   = regInfo{inputs: []regMask{gp}, outputs: vonly}
 
                w11   = regInfo{inputs: wzonly, outputs: wonly}
                w21   = regInfo{inputs: []regMask{wz, wz}, outputs: wonly}
@@ -1382,6 +1384,11 @@ func init() {
                {name: "Zero256", argLength: 0, reg: x15only, zeroWidth: true, fixedReg: true},
                {name: "Zero512", argLength: 0, reg: x15only, zeroWidth: true, fixedReg: true},
 
+               {name: "VMOVSDf2v", argLength: 1, reg: fpv, asm: "VMOVSD"},
+               {name: "VMOVSSf2v", argLength: 1, reg: fpv, asm: "VMOVSS"},
+               {name: "VMOVQ", argLength: 1, reg: gpv, asm: "VMOVQ"},
+               {name: "VMOVD", argLength: 1, reg: gpv, asm: "VMOVD"},
+
                {name: "VZEROUPPER", argLength: 0, asm: "VZEROUPPER"},
                {name: "VZEROALL", argLength: 0, asm: "VZEROALL"},
 
src/cmd/compile/internal/ssa/_gen/rulegen.go
index d4ca1aef22279f4cecd74bb761ed795e55179007..b16f9567bac87d87cec1e1584eeeb50268d3a997 100644 (file)
@@ -875,7 +875,7 @@ func declReserved(name, value string) *Declare {
        if !reservedNames[name] {
                panic(fmt.Sprintf("declReserved call does not use a reserved name: %q", name))
        }
-       return &Declare{name, exprf(value)}
+       return &Declare{name, exprf("%s", value)}
 }
 
 // breakf constructs a simple "if cond { break }" statement, using exprf for its
@@ -902,7 +902,7 @@ func genBlockRewrite(rule Rule, arch arch, data blockData) *RuleRewrite {
                        if vname == "" {
                                vname = fmt.Sprintf("v_%v", i)
                        }
-                       rr.add(declf(rr.Loc, vname, cname))
+                       rr.add(declf(rr.Loc, vname, "%s", cname))
                        p, op := genMatch0(rr, arch, expr, vname, nil, false) // TODO: pass non-nil cnt?
                        if op != "" {
                                check := fmt.Sprintf("%s.Op == %s", cname, op)
@@ -917,7 +917,7 @@ func genBlockRewrite(rule Rule, arch arch, data blockData) *RuleRewrite {
                        }
                        pos[i] = p
                } else {
-                       rr.add(declf(rr.Loc, arg, cname))
+                       rr.add(declf(rr.Loc, arg, "%s", cname))
                        pos[i] = arg + ".Pos"
                }
        }
src/cmd/compile/internal/ssa/opGen.go
index 76b0f84f3576affa44cfaabe69bff9f59dbd4233..7f6e9a0282cb28c83240e81a49a5d5cdabad5c62 100644 (file)
@@ -1214,6 +1214,10 @@ const (
        OpAMD64Zero128
        OpAMD64Zero256
        OpAMD64Zero512
+       OpAMD64VMOVSDf2v
+       OpAMD64VMOVSSf2v
+       OpAMD64VMOVQ
+       OpAMD64VMOVD
        OpAMD64VZEROUPPER
        OpAMD64VZEROALL
        OpAMD64KMOVQload
@@ -18869,6 +18873,58 @@ var opcodeTable = [...]opInfo{
                        },
                },
        },
+       {
+               name:   "VMOVSDf2v",
+               argLen: 1,
+               asm:    x86.AVMOVSD,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14
+                       },
+                       outputs: []outputInfo{
+                               {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14
+                       },
+               },
+       },
+       {
+               name:   "VMOVSSf2v",
+               argLen: 1,
+               asm:    x86.AVMOVSS,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14
+                       },
+                       outputs: []outputInfo{
+                               {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14
+                       },
+               },
+       },
+       {
+               name:   "VMOVQ",
+               argLen: 1,
+               asm:    x86.AVMOVQ,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 49135}, // AX CX DX BX BP SI DI R8 R9 R10 R11 R12 R13 R15
+                       },
+                       outputs: []outputInfo{
+                               {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14
+                       },
+               },
+       },
+       {
+               name:   "VMOVD",
+               argLen: 1,
+               asm:    x86.AVMOVD,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 49135}, // AX CX DX BX BP SI DI R8 R9 R10 R11 R12 R13 R15
+                       },
+                       outputs: []outputInfo{
+                               {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14
+                       },
+               },
+       },
        {
                name:   "VZEROUPPER",
                argLen: 0,
src/cmd/compile/internal/ssa/rewriteAMD64.go
index 77ae32519a45072f3e113ae68bec3e38271f653c..469417536fddefe4045e24acfcb14b19a8f2a3f8 100644 (file)
@@ -517,6 +517,22 @@ func rewriteValueAMD64(v *Value) bool {
                return rewriteValueAMD64_OpAMD64VMOVDQU8Masked512(v)
        case OpAMD64VPANDQ512:
                return rewriteValueAMD64_OpAMD64VPANDQ512(v)
+       case OpAMD64VPBROADCASTB128:
+               return rewriteValueAMD64_OpAMD64VPBROADCASTB128(v)
+       case OpAMD64VPBROADCASTB256:
+               return rewriteValueAMD64_OpAMD64VPBROADCASTB256(v)
+       case OpAMD64VPBROADCASTB512:
+               return rewriteValueAMD64_OpAMD64VPBROADCASTB512(v)
+       case OpAMD64VPBROADCASTW128:
+               return rewriteValueAMD64_OpAMD64VPBROADCASTW128(v)
+       case OpAMD64VPBROADCASTW256:
+               return rewriteValueAMD64_OpAMD64VPBROADCASTW256(v)
+       case OpAMD64VPBROADCASTW512:
+               return rewriteValueAMD64_OpAMD64VPBROADCASTW512(v)
+       case OpAMD64VPINSRD128:
+               return rewriteValueAMD64_OpAMD64VPINSRD128(v)
+       case OpAMD64VPINSRQ128:
+               return rewriteValueAMD64_OpAMD64VPINSRQ128(v)
        case OpAMD64VPMOVVec16x16ToM:
                return rewriteValueAMD64_OpAMD64VPMOVVec16x16ToM(v)
        case OpAMD64VPMOVVec16x32ToM:
@@ -28848,6 +28864,242 @@ func rewriteValueAMD64_OpAMD64VPANDQ512(v *Value) bool {
        }
        return false
 }
+func rewriteValueAMD64_OpAMD64VPBROADCASTB128(v *Value) bool {
+       v_0 := v.Args[0]
+       b := v.Block
+       // match: (VPBROADCASTB128 x:(VPINSRB128 [0] (Zero128 <t>) y))
+       // cond: x.Uses == 1
+       // result: (VPBROADCASTB128 (VMOVQ <types.TypeVec128> y))
+       for {
+               x := v_0
+               if x.Op != OpAMD64VPINSRB128 || auxIntToUint8(x.AuxInt) != 0 {
+                       break
+               }
+               y := x.Args[1]
+               x_0 := x.Args[0]
+               if x_0.Op != OpAMD64Zero128 {
+                       break
+               }
+               if !(x.Uses == 1) {
+                       break
+               }
+               v.reset(OpAMD64VPBROADCASTB128)
+               v0 := b.NewValue0(v.Pos, OpAMD64VMOVQ, types.TypeVec128)
+               v0.AddArg(y)
+               v.AddArg(v0)
+               return true
+       }
+       return false
+}
+func rewriteValueAMD64_OpAMD64VPBROADCASTB256(v *Value) bool {
+       v_0 := v.Args[0]
+       b := v.Block
+       // match: (VPBROADCASTB256 x:(VPINSRB128 [0] (Zero128 <t>) y))
+       // cond: x.Uses == 1
+       // result: (VPBROADCASTB256 (VMOVQ <types.TypeVec128> y))
+       for {
+               x := v_0
+               if x.Op != OpAMD64VPINSRB128 || auxIntToUint8(x.AuxInt) != 0 {
+                       break
+               }
+               y := x.Args[1]
+               x_0 := x.Args[0]
+               if x_0.Op != OpAMD64Zero128 {
+                       break
+               }
+               if !(x.Uses == 1) {
+                       break
+               }
+               v.reset(OpAMD64VPBROADCASTB256)
+               v0 := b.NewValue0(v.Pos, OpAMD64VMOVQ, types.TypeVec128)
+               v0.AddArg(y)
+               v.AddArg(v0)
+               return true
+       }
+       return false
+}
+func rewriteValueAMD64_OpAMD64VPBROADCASTB512(v *Value) bool {
+       v_0 := v.Args[0]
+       b := v.Block
+       // match: (VPBROADCASTB512 x:(VPINSRB128 [0] (Zero128 <t>) y))
+       // cond: x.Uses == 1
+       // result: (VPBROADCASTB512 (VMOVQ <types.TypeVec128> y))
+       for {
+               x := v_0
+               if x.Op != OpAMD64VPINSRB128 || auxIntToUint8(x.AuxInt) != 0 {
+                       break
+               }
+               y := x.Args[1]
+               x_0 := x.Args[0]
+               if x_0.Op != OpAMD64Zero128 {
+                       break
+               }
+               if !(x.Uses == 1) {
+                       break
+               }
+               v.reset(OpAMD64VPBROADCASTB512)
+               v0 := b.NewValue0(v.Pos, OpAMD64VMOVQ, types.TypeVec128)
+               v0.AddArg(y)
+               v.AddArg(v0)
+               return true
+       }
+       return false
+}
+func rewriteValueAMD64_OpAMD64VPBROADCASTW128(v *Value) bool {
+       v_0 := v.Args[0]
+       b := v.Block
+       // match: (VPBROADCASTW128 x:(VPINSRW128 [0] (Zero128 <t>) y))
+       // cond: x.Uses == 1
+       // result: (VPBROADCASTW128 (VMOVQ <types.TypeVec128> y))
+       for {
+               x := v_0
+               if x.Op != OpAMD64VPINSRW128 || auxIntToUint8(x.AuxInt) != 0 {
+                       break
+               }
+               y := x.Args[1]
+               x_0 := x.Args[0]
+               if x_0.Op != OpAMD64Zero128 {
+                       break
+               }
+               if !(x.Uses == 1) {
+                       break
+               }
+               v.reset(OpAMD64VPBROADCASTW128)
+               v0 := b.NewValue0(v.Pos, OpAMD64VMOVQ, types.TypeVec128)
+               v0.AddArg(y)
+               v.AddArg(v0)
+               return true
+       }
+       return false
+}
+func rewriteValueAMD64_OpAMD64VPBROADCASTW256(v *Value) bool {
+       v_0 := v.Args[0]
+       b := v.Block
+       // match: (VPBROADCASTW256 x:(VPINSRW128 [0] (Zero128 <t>) y))
+       // cond: x.Uses == 1
+       // result: (VPBROADCASTW256 (VMOVQ <types.TypeVec128> y))
+       for {
+               x := v_0
+               if x.Op != OpAMD64VPINSRW128 || auxIntToUint8(x.AuxInt) != 0 {
+                       break
+               }
+               y := x.Args[1]
+               x_0 := x.Args[0]
+               if x_0.Op != OpAMD64Zero128 {
+                       break
+               }
+               if !(x.Uses == 1) {
+                       break
+               }
+               v.reset(OpAMD64VPBROADCASTW256)
+               v0 := b.NewValue0(v.Pos, OpAMD64VMOVQ, types.TypeVec128)
+               v0.AddArg(y)
+               v.AddArg(v0)
+               return true
+       }
+       return false
+}
+func rewriteValueAMD64_OpAMD64VPBROADCASTW512(v *Value) bool {
+       v_0 := v.Args[0]
+       b := v.Block
+       // match: (VPBROADCASTW512 x:(VPINSRW128 [0] (Zero128 <t>) y))
+       // cond: x.Uses == 1
+       // result: (VPBROADCASTW512 (VMOVQ <types.TypeVec128> y))
+       for {
+               x := v_0
+               if x.Op != OpAMD64VPINSRW128 || auxIntToUint8(x.AuxInt) != 0 {
+                       break
+               }
+               y := x.Args[1]
+               x_0 := x.Args[0]
+               if x_0.Op != OpAMD64Zero128 {
+                       break
+               }
+               if !(x.Uses == 1) {
+                       break
+               }
+               v.reset(OpAMD64VPBROADCASTW512)
+               v0 := b.NewValue0(v.Pos, OpAMD64VMOVQ, types.TypeVec128)
+               v0.AddArg(y)
+               v.AddArg(v0)
+               return true
+       }
+       return false
+}
+func rewriteValueAMD64_OpAMD64VPINSRD128(v *Value) bool {
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (VPINSRD128 [0] (Zero128 <t>) y)
+       // cond: y.Type.IsFloat()
+       // result: (VMOVSSf2v <types.TypeVec128> y)
+       for {
+               if auxIntToUint8(v.AuxInt) != 0 || v_0.Op != OpAMD64Zero128 {
+                       break
+               }
+               y := v_1
+               if !(y.Type.IsFloat()) {
+                       break
+               }
+               v.reset(OpAMD64VMOVSSf2v)
+               v.Type = types.TypeVec128
+               v.AddArg(y)
+               return true
+       }
+       // match: (VPINSRD128 [0] (Zero128 <t>) y)
+       // cond: !y.Type.IsFloat()
+       // result: (VMOVD <types.TypeVec128> y)
+       for {
+               if auxIntToUint8(v.AuxInt) != 0 || v_0.Op != OpAMD64Zero128 {
+                       break
+               }
+               y := v_1
+               if !(!y.Type.IsFloat()) {
+                       break
+               }
+               v.reset(OpAMD64VMOVD)
+               v.Type = types.TypeVec128
+               v.AddArg(y)
+               return true
+       }
+       return false
+}
+func rewriteValueAMD64_OpAMD64VPINSRQ128(v *Value) bool {
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (VPINSRQ128 [0] (Zero128 <t>) y)
+       // cond: y.Type.IsFloat()
+       // result: (VMOVSDf2v <types.TypeVec128> y)
+       for {
+               if auxIntToUint8(v.AuxInt) != 0 || v_0.Op != OpAMD64Zero128 {
+                       break
+               }
+               y := v_1
+               if !(y.Type.IsFloat()) {
+                       break
+               }
+               v.reset(OpAMD64VMOVSDf2v)
+               v.Type = types.TypeVec128
+               v.AddArg(y)
+               return true
+       }
+       // match: (VPINSRQ128 [0] (Zero128 <t>) y)
+       // cond: !y.Type.IsFloat()
+       // result: (VMOVQ <types.TypeVec128> y)
+       for {
+               if auxIntToUint8(v.AuxInt) != 0 || v_0.Op != OpAMD64Zero128 {
+                       break
+               }
+               y := v_1
+               if !(!y.Type.IsFloat()) {
+                       break
+               }
+               v.reset(OpAMD64VMOVQ)
+               v.Type = types.TypeVec128
+               v.AddArg(y)
+               return true
+       }
+       return false
+}
 func rewriteValueAMD64_OpAMD64VPMOVVec16x16ToM(v *Value) bool {
        v_0 := v.Args[0]
        // match: (VPMOVVec16x16ToM (VPMOVMToVec16x16 x))
src/simd/internal/simd_test/simd_test.go
index 38065cb841b151388717e7176d73da098e07e3ba..3dcb5c6a2746a724a7ac097c2ae9ed046bf23c6f 100644 (file)
@@ -458,6 +458,22 @@ func TestBroadcastUint64x2(t *testing.T) {
        checkSlices(t, s, []uint64{123456789, 123456789})
 }
 
+func TestBroadcastUint16x8(t *testing.T) {
+       s := make([]uint16, 8, 8)
+       simd.BroadcastUint16x8(12345).StoreSlice(s)
+       checkSlices(t, s, []uint16{12345, 12345, 12345, 12345})
+}
+
+func TestBroadcastInt8x32(t *testing.T) {
+       s := make([]int8, 32, 32)
+       simd.BroadcastInt8x32(-123).StoreSlice(s)
+       checkSlices(t, s, []int8{-123, -123, -123, -123, -123, -123, -123, -123,
+               -123, -123, -123, -123, -123, -123, -123, -123,
+               -123, -123, -123, -123, -123, -123, -123, -123,
+               -123, -123, -123, -123, -123, -123, -123, -123,
+       })
+}
+
 func TestMaskOpt512(t *testing.T) {
        if !simd.HasAVX512() {
                t.Skip("Test requires HasAVX512, not available on this hardware")