From: Jorropo Date: Thu, 3 Jul 2025 00:57:25 +0000 (+0200) Subject: cmd/compile: rewrite condselects into doublings and halvings X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=ce05ad448fe6ea3b9b33c0eab1143dcb40e3bbc3;p=gostls13.git cmd/compile: rewrite condselects into doublings and halvings For performance see CL 685676. This allows something like: if y { x *= 2 } To be compiled to: SHLXQ BX, AX, AX Instead of: MOVQ AX, CX SHLQ $1, CX MOVBLZX BL, DX TESTQ DX, DX CMOVQNE CX, AX While ./make.bash uniqued per LOC, there is 2 doublings and 4 halvings. Change-Id: Ic0727cbf429528a2dbf17cbfc3b0121db8387444 Reviewed-on: https://go-review.googlesource.com/c/go/+/685695 LUCI-TryBot-Result: Go LUCI Reviewed-by: Keith Randall Reviewed-by: Michael Knyszek Reviewed-by: Keith Randall --- diff --git a/src/cmd/compile/internal/ssa/_gen/generic.rules b/src/cmd/compile/internal/ssa/_gen/generic.rules index 89657bdabb..b98dfae2d5 100644 --- a/src/cmd/compile/internal/ssa/_gen/generic.rules +++ b/src/cmd/compile/internal/ssa/_gen/generic.rules @@ -2855,3 +2855,10 @@ // if b { x-- } => x -= b (CondSelect (Add8 x (Const8 [-1])) x bool) => (Sub8 x (CvtBoolToUint8 bool)) (CondSelect (Add(64|32|16) x (Const(64|32|16) [-1])) x bool) => (Sub(64|32|16) x (ZeroExt8to(64|32|16) (CvtBoolToUint8 bool))) + +// if b { x <<= 1 } => x <<= b +(CondSelect (Lsh(64|32|16|8)x64 x (Const64 [1])) x bool) => (Lsh(64|32|16|8)x8 [true] x (CvtBoolToUint8 bool)) + +// if b { x >>= 1 } => x >>= b +(CondSelect (Rsh(64|32|16|8)x64 x (Const64 [1])) x bool) => (Rsh(64|32|16|8)x8 [true] x (CvtBoolToUint8 bool)) +(CondSelect (Rsh(64|32|16|8)Ux64 x (Const64 [1])) x bool) => (Rsh(64|32|16|8)Ux8 [true] x (CvtBoolToUint8 bool)) diff --git a/src/cmd/compile/internal/ssa/rewritegeneric.go b/src/cmd/compile/internal/ssa/rewritegeneric.go index a8c3373e40..0b9f9c09f9 100644 --- a/src/cmd/compile/internal/ssa/rewritegeneric.go +++ b/src/cmd/compile/internal/ssa/rewritegeneric.go @@ -5942,6 +5942,246 @@ func rewriteValuegeneric_OpCondSelect(v *Value) bool { } break } + // match: (CondSelect (Lsh64x64 x (Const64 [1])) x bool) + // result: (Lsh64x8 [true] x (CvtBoolToUint8 bool)) + for { + if v_0.Op != OpLsh64x64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + v_0_1 := v_0.Args[1] + if v_0_1.Op != OpConst64 || auxIntToInt64(v_0_1.AuxInt) != 1 || x != v_1 { + break + } + bool := v_2 + v.reset(OpLsh64x8) + v.AuxInt = boolToAuxInt(true) + v0 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8]) + v0.AddArg(bool) + v.AddArg2(x, v0) + return true + } + // match: (CondSelect (Lsh32x64 x (Const64 [1])) x bool) + // result: (Lsh32x8 [true] x (CvtBoolToUint8 bool)) + for { + if v_0.Op != OpLsh32x64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + v_0_1 := v_0.Args[1] + if v_0_1.Op != OpConst64 || auxIntToInt64(v_0_1.AuxInt) != 1 || x != v_1 { + break + } + bool := v_2 + v.reset(OpLsh32x8) + v.AuxInt = boolToAuxInt(true) + v0 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8]) + v0.AddArg(bool) + v.AddArg2(x, v0) + return true + } + // match: (CondSelect (Lsh16x64 x (Const64 [1])) x bool) + // result: (Lsh16x8 [true] x (CvtBoolToUint8 bool)) + for { + if v_0.Op != OpLsh16x64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + v_0_1 := v_0.Args[1] + if v_0_1.Op != OpConst64 || auxIntToInt64(v_0_1.AuxInt) != 1 || x != v_1 { + break + } + bool := v_2 + v.reset(OpLsh16x8) + v.AuxInt = boolToAuxInt(true) + v0 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8]) + v0.AddArg(bool) + v.AddArg2(x, v0) + return true + } + // match: (CondSelect (Lsh8x64 x (Const64 [1])) x bool) + // result: (Lsh8x8 [true] x (CvtBoolToUint8 bool)) + for { + if v_0.Op != OpLsh8x64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + v_0_1 := v_0.Args[1] + if v_0_1.Op != OpConst64 || auxIntToInt64(v_0_1.AuxInt) != 1 || x != v_1 { + break + } + bool := v_2 + v.reset(OpLsh8x8) + v.AuxInt = boolToAuxInt(true) + v0 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8]) + v0.AddArg(bool) + v.AddArg2(x, v0) + return true + } + // match: (CondSelect (Rsh64x64 x (Const64 [1])) x bool) + // result: (Rsh64x8 [true] x (CvtBoolToUint8 bool)) + for { + if v_0.Op != OpRsh64x64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + v_0_1 := v_0.Args[1] + if v_0_1.Op != OpConst64 || auxIntToInt64(v_0_1.AuxInt) != 1 || x != v_1 { + break + } + bool := v_2 + v.reset(OpRsh64x8) + v.AuxInt = boolToAuxInt(true) + v0 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8]) + v0.AddArg(bool) + v.AddArg2(x, v0) + return true + } + // match: (CondSelect (Rsh32x64 x (Const64 [1])) x bool) + // result: (Rsh32x8 [true] x (CvtBoolToUint8 bool)) + for { + if v_0.Op != OpRsh32x64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + v_0_1 := v_0.Args[1] + if v_0_1.Op != OpConst64 || auxIntToInt64(v_0_1.AuxInt) != 1 || x != v_1 { + break + } + bool := v_2 + v.reset(OpRsh32x8) + v.AuxInt = boolToAuxInt(true) + v0 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8]) + v0.AddArg(bool) + v.AddArg2(x, v0) + return true + } + // match: (CondSelect (Rsh16x64 x (Const64 [1])) x bool) + // result: (Rsh16x8 [true] x (CvtBoolToUint8 bool)) + for { + if v_0.Op != OpRsh16x64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + v_0_1 := v_0.Args[1] + if v_0_1.Op != OpConst64 || auxIntToInt64(v_0_1.AuxInt) != 1 || x != v_1 { + break + } + bool := v_2 + v.reset(OpRsh16x8) + v.AuxInt = boolToAuxInt(true) + v0 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8]) + v0.AddArg(bool) + v.AddArg2(x, v0) + return true + } + // match: (CondSelect (Rsh8x64 x (Const64 [1])) x bool) + // result: (Rsh8x8 [true] x (CvtBoolToUint8 bool)) + for { + if v_0.Op != OpRsh8x64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + v_0_1 := v_0.Args[1] + if v_0_1.Op != OpConst64 || auxIntToInt64(v_0_1.AuxInt) != 1 || x != v_1 { + break + } + bool := v_2 + v.reset(OpRsh8x8) + v.AuxInt = boolToAuxInt(true) + v0 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8]) + v0.AddArg(bool) + v.AddArg2(x, v0) + return true + } + // match: (CondSelect (Rsh64Ux64 x (Const64 [1])) x bool) + // result: (Rsh64Ux8 [true] x (CvtBoolToUint8 bool)) + for { + if v_0.Op != OpRsh64Ux64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + v_0_1 := v_0.Args[1] + if v_0_1.Op != OpConst64 || auxIntToInt64(v_0_1.AuxInt) != 1 || x != v_1 { + break + } + bool := v_2 + v.reset(OpRsh64Ux8) + v.AuxInt = boolToAuxInt(true) + v0 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8]) + v0.AddArg(bool) + v.AddArg2(x, v0) + return true + } + // match: (CondSelect (Rsh32Ux64 x (Const64 [1])) x bool) + // result: (Rsh32Ux8 [true] x (CvtBoolToUint8 bool)) + for { + if v_0.Op != OpRsh32Ux64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + v_0_1 := v_0.Args[1] + if v_0_1.Op != OpConst64 || auxIntToInt64(v_0_1.AuxInt) != 1 || x != v_1 { + break + } + bool := v_2 + v.reset(OpRsh32Ux8) + v.AuxInt = boolToAuxInt(true) + v0 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8]) + v0.AddArg(bool) + v.AddArg2(x, v0) + return true + } + // match: (CondSelect (Rsh16Ux64 x (Const64 [1])) x bool) + // result: (Rsh16Ux8 [true] x (CvtBoolToUint8 bool)) + for { + if v_0.Op != OpRsh16Ux64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + v_0_1 := v_0.Args[1] + if v_0_1.Op != OpConst64 || auxIntToInt64(v_0_1.AuxInt) != 1 || x != v_1 { + break + } + bool := v_2 + v.reset(OpRsh16Ux8) + v.AuxInt = boolToAuxInt(true) + v0 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8]) + v0.AddArg(bool) + v.AddArg2(x, v0) + return true + } + // match: (CondSelect (Rsh8Ux64 x (Const64 [1])) x bool) + // result: (Rsh8Ux8 [true] x (CvtBoolToUint8 bool)) + for { + if v_0.Op != OpRsh8Ux64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + v_0_1 := v_0.Args[1] + if v_0_1.Op != OpConst64 || auxIntToInt64(v_0_1.AuxInt) != 1 || x != v_1 { + break + } + bool := v_2 + v.reset(OpRsh8Ux8) + v.AuxInt = boolToAuxInt(true) + v0 := b.NewValue0(v.Pos, OpCvtBoolToUint8, types.Types[types.TUINT8]) + v0.AddArg(bool) + v.AddArg2(x, v0) + return true + } return false } func rewriteValuegeneric_OpConstInterface(v *Value) bool { diff --git a/test/codegen/condmove.go b/test/codegen/condmove.go index 95a9d2cd23..5659972eed 100644 --- a/test/codegen/condmove.go +++ b/test/codegen/condmove.go @@ -473,3 +473,36 @@ func cmovmathsub(a uint, b bool) uint { // wasm:"Sub", "-Select" return a } + +func cmovmathdouble(a uint, b bool) uint { + if b { + a *= 2 + } + // amd64:"SHL", -"CMOV" + // amd64/v3:"SHL", -"CMOV", -"MOV" + // arm64:"LSL", -"CSEL" + // wasm:"Shl", "-Select" + return a +} + +func cmovmathhalvei(a int, b bool) int { + if b { + // For some reason on arm64 it attributes the ASR to inside this block rather than where the Phi node is. + // arm64:"ASR", -"CSEL" + a /= 2 + } + // arm64:-"CSEL" + // wasm:"Shr", "-Select" + return a +} + +func cmovmathhalveu(a uint, b bool) uint { + if b { + a /= 2 + } + // amd64:"SHR", -"CMOV" + // amd64/v3:"SHR", -"CMOV", -"MOV" + // arm64:"LSR", -"CSEL" + // wasm:"Shr", "-Select" + return a +}