]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: add opt branchelim to rewrite some CondSelect into math
authorJorropo <jorropo.pgm@gmail.com>
Wed, 2 Jul 2025 23:35:51 +0000 (01:35 +0200)
committerJorropo <jorropo.pgm@gmail.com>
Thu, 24 Jul 2025 21:42:10 +0000 (14:42 -0700)
This allows something like:
  if y { x++ }

To be compiled to:
  MOVBLZX BX, CX
  ADDQ CX, AX

Instead of:
  LEAQ    1(AX), CX
  MOVBLZX BL, DX
  TESTQ   DX, DX
  CMOVQNE CX, AX

While ./make.bash uniqued per LOC, there is 100 additions and 75 substractions.

See benchmark here: https://go.dev/play/p/DJf5COjwhd_s

Either it's a performance no-op or it is faster:

  goos: linux
  goarch: amd64
  cpu: AMD Ryzen 5 3600 6-Core Processor
                                          │ /tmp/old.logs │            /tmp/new.logs             │
                                          │    sec/op     │    sec/op     vs base                │
  CmovInlineConditionAddLatency-12           0.5443n ± 5%   0.5339n ± 3%   -1.90% (p=0.004 n=10)
  CmovInlineConditionAddThroughputBy6-12      1.492n ± 1%    1.494n ± 1%        ~ (p=0.955 n=10)
  CmovInlineConditionSubLatency-12           0.5419n ± 3%   0.5282n ± 3%   -2.52% (p=0.019 n=10)
  CmovInlineConditionSubThroughputBy6-12      1.587n ± 1%    1.584n ± 2%        ~ (p=0.492 n=10)
  CmovOutlineConditionAddLatency-12          0.5223n ± 1%   0.2639n ± 4%  -49.47% (p=0.000 n=10)
  CmovOutlineConditionAddThroughputBy6-12     1.159n ± 1%    1.097n ± 2%   -5.35% (p=0.000 n=10)
  CmovOutlineConditionSubLatency-12          0.5271n ± 3%   0.2654n ± 2%  -49.66% (p=0.000 n=10)
  CmovOutlineConditionSubThroughputBy6-12     1.053n ± 1%    1.050n ± 1%        ~ (p=1.000 n=10)
  geomean

There are other benefits not tested by this benchmark:
- the math form is usually a couple bytes shorter (ICACHE)
- the math form is usually 0~2 uops shorter (UCACHE)
- the math form has usually less register pressure*
- the math form can sometimes be optimized further

*regalloc rarely find how it can use less registers

As far as pass ordering goes there are many possible options,
I've decided to reorder branchelim before late opt since:
- unlike running exclusively the CondSelect rules after branchelim,
  some extra optimizations might trigger on the adds or subs.
- I don't want to maintain a second generic.rules file of only the stuff,
  that can trigger after branchelim.
- rerunning all of opt a third time increase compilation time for little gains.

By elimination moving branchelim seems fine.

Change-Id: I869adf57e4d109948ee157cfc47144445146bafd
Reviewed-on: https://go-review.googlesource.com/c/go/+/685676
Reviewed-by: Keith Randall <khr@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Keith Randall <khr@google.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
src/cmd/compile/internal/ssa/_gen/generic.rules
src/cmd/compile/internal/ssa/compile.go
src/cmd/compile/internal/ssa/rewritegeneric.go
test/codegen/condmove.go

index b178a1add6de29eb2181106a311dd7ac0fd0a099..89657bdabb732a0ce0ef1be87a4e38fcebaf7f85 100644 (file)
 (Neq16 (Const16 <t> [c]) (Add16 (Const16 <t> [d]) x)) => (Neq16 (Const16 <t> [c-d]) x)
 (Neq8  (Const8  <t> [c]) (Add8  (Const8  <t> [d]) x)) => (Neq8  (Const8  <t> [c-d]) x)
 
+(CondSelect x _ (ConstBool [true ])) => x
+(CondSelect _ y (ConstBool [false])) => y
+
 // signed integer range: ( c <= x && x (<|<=) d ) -> ( unsigned(x-c) (<|<=) unsigned(d-c) )
 (AndB (Leq64 (Const64 [c]) x) ((Less|Leq)64 x (Const64 [d]))) && d >= c => ((Less|Leq)64U (Sub64 <x.Type> x (Const64 <x.Type> [c])) (Const64 <x.Type> [d-c]))
 (AndB (Leq32 (Const32 [c]) x) ((Less|Leq)32 x (Const32 [d]))) && d >= c => ((Less|Leq)32U (Sub32 <x.Type> x (Const32 <x.Type> [c])) (Const32 <x.Type> [d-c]))
   && clobber(sbts)
   && clobber(key)
 => (StaticLECall {f} [argsize] dict_ (StringMake <typ.String> ptr len) mem)
+
+// Transform some CondSelect into math operations.
+// if b { x++ } => x += b // but not on arm64 because it has CSINC
+(CondSelect (Add8 <t> x (Const8 [1])) x bool) && config.arch != "arm64" => (Add8 x (CvtBoolToUint8 <t> bool))
+(CondSelect (Add(64|32|16) <t> x (Const(64|32|16) [1])) x bool) && config.arch != "arm64" => (Add(64|32|16) x (ZeroExt8to(64|32|16) <t> (CvtBoolToUint8 <types.Types[types.TUINT8]> bool)))
+
+// if b { x-- } => x -= b
+(CondSelect (Add8 <t> x (Const8 [-1])) x bool) => (Sub8 x (CvtBoolToUint8 <t> bool))
+(CondSelect (Add(64|32|16) <t> x (Const(64|32|16) [-1])) x bool) => (Sub(64|32|16) x (ZeroExt8to(64|32|16) <t> (CvtBoolToUint8 <types.Types[types.TUINT8]> bool)))
index e9500a24ed8a608e1f44e79401f6e968f5b4d3f1..1f47362583353e69fa95e82aa98f91fcdb0cd7ce 100644 (file)
@@ -473,11 +473,11 @@ var passes = [...]pass{
        {name: "expand calls", fn: expandCalls, required: true},
        {name: "decompose builtin", fn: postExpandCallsDecompose, required: true},
        {name: "softfloat", fn: softfloat, required: true},
+       {name: "branchelim", fn: branchelim},
        {name: "late opt", fn: opt, required: true}, // TODO: split required rules and optimizing rules
        {name: "dead auto elim", fn: elimDeadAutosGeneric},
        {name: "sccp", fn: sccp},
        {name: "generic deadcode", fn: deadcode, required: true}, // remove dead stores, which otherwise mess up store chain
-       {name: "branchelim", fn: branchelim},
        {name: "late fuse", fn: fuseLate},
        {name: "check bce", fn: checkbce},
        {name: "dse", fn: dse},
@@ -583,6 +583,10 @@ var passOrder = [...]constraint{
        {"late fuse", "memcombine"},
        // memcombine is a arch-independent pass.
        {"memcombine", "lower"},
+       // late opt transform some CondSelects into math.
+       {"branchelim", "late opt"},
+       // ranchelim is an arch-independent pass.
+       {"branchelim", "lower"},
 }
 
 func init() {
index bfbd3c8522ed241f91b4d2489546bc89c486b905..a8c3373e409fe5d4c7b514534bd6d300663295b0 100644 (file)
@@ -56,6 +56,8 @@ func rewriteValuegeneric(v *Value) bool {
                return rewriteValuegeneric_OpCom64(v)
        case OpCom8:
                return rewriteValuegeneric_OpCom8(v)
+       case OpCondSelect:
+               return rewriteValuegeneric_OpCondSelect(v)
        case OpConstInterface:
                return rewriteValuegeneric_OpConstInterface(v)
        case OpConstSlice:
@@ -5694,6 +5696,254 @@ func rewriteValuegeneric_OpCom8(v *Value) bool {
        }
        return false
 }
+func rewriteValuegeneric_OpCondSelect(v *Value) bool {
+       v_2 := v.Args[2]
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       b := v.Block
+       config := b.Func.Config
+       // match: (CondSelect x _ (ConstBool [true ]))
+       // result: x
+       for {
+               x := v_0
+               if v_2.Op != OpConstBool || auxIntToBool(v_2.AuxInt) != true {
+                       break
+               }
+               v.copyOf(x)
+               return true
+       }
+       // match: (CondSelect _ y (ConstBool [false]))
+       // result: y
+       for {
+               y := v_1
+               if v_2.Op != OpConstBool || auxIntToBool(v_2.AuxInt) != false {
+                       break
+               }
+               v.copyOf(y)
+               return true
+       }
+       // match: (CondSelect (Add8 <t> x (Const8 [1])) x bool)
+       // cond: config.arch != "arm64"
+       // result: (Add8 x (CvtBoolToUint8 <t> bool))
+       for {
+               if v_0.Op != OpAdd8 {
+                       break
+               }
+               t := v_0.Type
+               _ = v_0.Args[1]
+               v_0_0 := v_0.Args[0]
+               v_0_1 := v_0.Args[1]
+               for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
+                       x := v_0_0
+                       if v_0_1.Op != OpConst8 || auxIntToInt8(v_0_1.AuxInt) != 1 || x != v_1 {
+                               continue
+                       }
+                       bool := v_2
+                       if !(config.arch != "arm64") {
+                               continue
+                       }
+                       v.reset(OpAdd8)
+                       v0 := b.NewValue0(v.Pos, OpCvtBoolToUint8, t)
+                       v0.AddArg(bool)
+                       v.AddArg2(x, v0)
+                       return true
+               }
+               break
+       }
+       // match: (CondSelect (Add64 <t> x (Const64 [1])) x bool)
+       // cond: config.arch != "arm64"
+       // result: (Add64 x (ZeroExt8to64 <t> (CvtBoolToUint8 <types.Types[types.TUINT8]> bool)))
+       for {
+               if v_0.Op != OpAdd64 {
+                       break
+               }
+               t := v_0.Type
+               _ = v_0.Args[1]
+               v_0_0 := v_0.Args[0]
+               v_0_1 := v_0.Args[1]
+               for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
+                       x := v_0_0
+                       if v_0_1.Op != OpConst64 || auxIntToInt64(v_0_1.AuxInt) != 1 || x != v_1 {
+                               continue
+                       }
+                       bool := v_2
+                       if !(config.arch != "arm64") {
+                               continue
+                       }
+                       v.reset(OpAdd64)
+                       v0 := b.NewValue0(v.Pos, OpZeroExt8to64, t)
+                       v1 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8])
+                       v1.AddArg(bool)
+                       v0.AddArg(v1)
+                       v.AddArg2(x, v0)
+                       return true
+               }
+               break
+       }
+       // match: (CondSelect (Add32 <t> x (Const32 [1])) x bool)
+       // cond: config.arch != "arm64"
+       // result: (Add32 x (ZeroExt8to32 <t> (CvtBoolToUint8 <types.Types[types.TUINT8]> bool)))
+       for {
+               if v_0.Op != OpAdd32 {
+                       break
+               }
+               t := v_0.Type
+               _ = v_0.Args[1]
+               v_0_0 := v_0.Args[0]
+               v_0_1 := v_0.Args[1]
+               for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
+                       x := v_0_0
+                       if v_0_1.Op != OpConst32 || auxIntToInt32(v_0_1.AuxInt) != 1 || x != v_1 {
+                               continue
+                       }
+                       bool := v_2
+                       if !(config.arch != "arm64") {
+                               continue
+                       }
+                       v.reset(OpAdd32)
+                       v0 := b.NewValue0(v.Pos, OpZeroExt8to32, t)
+                       v1 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8])
+                       v1.AddArg(bool)
+                       v0.AddArg(v1)
+                       v.AddArg2(x, v0)
+                       return true
+               }
+               break
+       }
+       // match: (CondSelect (Add16 <t> x (Const16 [1])) x bool)
+       // cond: config.arch != "arm64"
+       // result: (Add16 x (ZeroExt8to16 <t> (CvtBoolToUint8 <types.Types[types.TUINT8]> bool)))
+       for {
+               if v_0.Op != OpAdd16 {
+                       break
+               }
+               t := v_0.Type
+               _ = v_0.Args[1]
+               v_0_0 := v_0.Args[0]
+               v_0_1 := v_0.Args[1]
+               for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
+                       x := v_0_0
+                       if v_0_1.Op != OpConst16 || auxIntToInt16(v_0_1.AuxInt) != 1 || x != v_1 {
+                               continue
+                       }
+                       bool := v_2
+                       if !(config.arch != "arm64") {
+                               continue
+                       }
+                       v.reset(OpAdd16)
+                       v0 := b.NewValue0(v.Pos, OpZeroExt8to16, t)
+                       v1 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8])
+                       v1.AddArg(bool)
+                       v0.AddArg(v1)
+                       v.AddArg2(x, v0)
+                       return true
+               }
+               break
+       }
+       // match: (CondSelect (Add8 <t> x (Const8 [-1])) x bool)
+       // result: (Sub8 x (CvtBoolToUint8 <t> bool))
+       for {
+               if v_0.Op != OpAdd8 {
+                       break
+               }
+               t := v_0.Type
+               _ = v_0.Args[1]
+               v_0_0 := v_0.Args[0]
+               v_0_1 := v_0.Args[1]
+               for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
+                       x := v_0_0
+                       if v_0_1.Op != OpConst8 || auxIntToInt8(v_0_1.AuxInt) != -1 || x != v_1 {
+                               continue
+                       }
+                       bool := v_2
+                       v.reset(OpSub8)
+                       v0 := b.NewValue0(v.Pos, OpCvtBoolToUint8, t)
+                       v0.AddArg(bool)
+                       v.AddArg2(x, v0)
+                       return true
+               }
+               break
+       }
+       // match: (CondSelect (Add64 <t> x (Const64 [-1])) x bool)
+       // result: (Sub64 x (ZeroExt8to64 <t> (CvtBoolToUint8 <types.Types[types.TUINT8]> bool)))
+       for {
+               if v_0.Op != OpAdd64 {
+                       break
+               }
+               t := v_0.Type
+               _ = v_0.Args[1]
+               v_0_0 := v_0.Args[0]
+               v_0_1 := v_0.Args[1]
+               for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
+                       x := v_0_0
+                       if v_0_1.Op != OpConst64 || auxIntToInt64(v_0_1.AuxInt) != -1 || x != v_1 {
+                               continue
+                       }
+                       bool := v_2
+                       v.reset(OpSub64)
+                       v0 := b.NewValue0(v.Pos, OpZeroExt8to64, t)
+                       v1 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8])
+                       v1.AddArg(bool)
+                       v0.AddArg(v1)
+                       v.AddArg2(x, v0)
+                       return true
+               }
+               break
+       }
+       // match: (CondSelect (Add32 <t> x (Const32 [-1])) x bool)
+       // result: (Sub32 x (ZeroExt8to32 <t> (CvtBoolToUint8 <types.Types[types.TUINT8]> bool)))
+       for {
+               if v_0.Op != OpAdd32 {
+                       break
+               }
+               t := v_0.Type
+               _ = v_0.Args[1]
+               v_0_0 := v_0.Args[0]
+               v_0_1 := v_0.Args[1]
+               for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
+                       x := v_0_0
+                       if v_0_1.Op != OpConst32 || auxIntToInt32(v_0_1.AuxInt) != -1 || x != v_1 {
+                               continue
+                       }
+                       bool := v_2
+                       v.reset(OpSub32)
+                       v0 := b.NewValue0(v.Pos, OpZeroExt8to32, t)
+                       v1 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8])
+                       v1.AddArg(bool)
+                       v0.AddArg(v1)
+                       v.AddArg2(x, v0)
+                       return true
+               }
+               break
+       }
+       // match: (CondSelect (Add16 <t> x (Const16 [-1])) x bool)
+       // result: (Sub16 x (ZeroExt8to16 <t> (CvtBoolToUint8 <types.Types[types.TUINT8]> bool)))
+       for {
+               if v_0.Op != OpAdd16 {
+                       break
+               }
+               t := v_0.Type
+               _ = v_0.Args[1]
+               v_0_0 := v_0.Args[0]
+               v_0_1 := v_0.Args[1]
+               for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
+                       x := v_0_0
+                       if v_0_1.Op != OpConst16 || auxIntToInt16(v_0_1.AuxInt) != -1 || x != v_1 {
+                               continue
+                       }
+                       bool := v_2
+                       v.reset(OpSub16)
+                       v0 := b.NewValue0(v.Pos, OpZeroExt8to16, t)
+                       v1 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8])
+                       v1.AddArg(bool)
+                       v0.AddArg(v1)
+                       v.AddArg2(x, v0)
+                       return true
+               }
+               break
+       }
+       return false
+}
 func rewriteValuegeneric_OpConstInterface(v *Value) bool {
        b := v.Block
        typ := &b.Func.Config.Types
index 1058910307ce4b44b61d09600ed725c30f436322..95a9d2cd2325861c7c8e0ae6b067e405349f969a 100644 (file)
@@ -106,7 +106,7 @@ func cmovfloatint2(x, y float64) float64 {
        for r >= y {
                rfr, rexp := frexp(r)
                if rfr < yfr {
-                       rexp = rexp - 1
+                       rexp = rexp - 42
                }
                // amd64:"CMOVQHI"
                // arm64:"CSEL\tMI"
@@ -205,7 +205,7 @@ func cmovinvert6(x, y uint64) uint64 {
 
 func cmovload(a []int, i int, b bool) int {
        if b {
-               i++
+               i += 42
        }
        // See issue 26306
        // amd64:-"CMOVQNE"
@@ -214,7 +214,7 @@ func cmovload(a []int, i int, b bool) int {
 
 func cmovstore(a []int, i int, b bool) {
        if b {
-               i++
+               i += 42
        }
        // amd64:"CMOVQNE"
        a[i] = 7
@@ -451,3 +451,25 @@ func cmovzeroreg1(a, b int) int {
        // ppc64x:"ISEL\t[$]2, R0, R[0-9]+, R[0-9]+"
        return x
 }
+
+func cmovmathadd(a uint, b bool) uint {
+       if b {
+               a++
+       }
+       // amd64:"ADDQ", -"CMOV"
+       // arm64:"CSINC", -"CSEL"
+       // ppc64x:"ADD", -"ISEL"
+       // wasm:"Add", "-Select"
+       return a
+}
+
+func cmovmathsub(a uint, b bool) uint {
+       if b {
+               a--
+       }
+       // amd64:"SUBQ", -"CMOV"
+       // arm64:"SUB", -"CSEL"
+       // ppc64x:"SUB", -"ISEL"
+       // wasm:"Sub", "-Select"
+       return a
+}