From: Michael Munday Date: Tue, 26 Aug 2025 20:17:36 +0000 (+0100) Subject: cmd/compile: optimize comparisons with single bit difference X-Git-Tag: go1.26rc1~287 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=0a569528ea355099af864f7612c3fa1973df30e4;p=gostls13.git cmd/compile: optimize comparisons with single bit difference Optimize comparisons with constants that only differ by 1 bit (i.e. a power of 2). For example: x == 4 || x == 6 -> x|2 == 6 x != 1 && x != 5 -> x|4 != 5 Change-Id: Ic61719e5118446d21cf15652d9da22f7d95b2a15 Reviewed-on: https://go-review.googlesource.com/c/go/+/719420 LUCI-TryBot-Result: Go LUCI Reviewed-by: Junyang Shao Auto-Submit: Keith Randall Reviewed-by: Keith Randall Reviewed-by: Keith Randall --- diff --git a/src/cmd/compile/internal/ssa/_gen/generic.rules b/src/cmd/compile/internal/ssa/_gen/generic.rules index 7e3aba1e5e..6efead03ad 100644 --- a/src/cmd/compile/internal/ssa/_gen/generic.rules +++ b/src/cmd/compile/internal/ssa/_gen/generic.rules @@ -337,6 +337,12 @@ (OrB ((Less|Leq)16U (Const16 [c]) x) (Leq16U x (Const16 [d]))) && uint16(c) >= uint16(d+1) && uint16(d+1) > uint16(d) => ((Less|Leq)16U (Const16 [c-d-1]) (Sub16 x (Const16 [d+1]))) (OrB ((Less|Leq)8U (Const8 [c]) x) (Leq8U x (Const8 [d]))) && uint8(c) >= uint8(d+1) && uint8(d+1) > uint8(d) => ((Less|Leq)8U (Const8 [c-d-1]) (Sub8 x (Const8 [d+1]))) +// single bit difference: ( x != c && x != d ) -> ( x|(c^d) != c ) +(AndB (Neq(64|32|16|8) x cv:(Const(64|32|16|8) [c])) (Neq(64|32|16|8) x (Const(64|32|16|8) [d]))) && c|d == c && oneBit(c^d) => (Neq(64|32|16|8) (Or(64|32|16|8) x (Const(64|32|16|8) [c^d])) cv) + +// single bit difference: ( x == c || x == d ) -> ( x|(c^d) == c ) +(OrB (Eq(64|32|16|8) x cv:(Const(64|32|16|8) [c])) (Eq(64|32|16|8) x (Const(64|32|16|8) [d]))) && c|d == c && oneBit(c^d) => (Eq(64|32|16|8) (Or(64|32|16|8) x (Const(64|32|16|8) [c^d])) cv) + // NaN check: ( x != x || x (>|>=|<|<=) c ) -> ( !(c (>=|>|<=|<) x) ) (OrB (Neq64F x x) 
((Less|Leq)64F x y:(Const64F [c]))) => (Not ((Leq|Less)64F y x)) (OrB (Neq64F x x) ((Less|Leq)64F y:(Const64F [c]) x)) => (Not ((Leq|Less)64F x y)) diff --git a/src/cmd/compile/internal/ssa/fuse.go b/src/cmd/compile/internal/ssa/fuse.go index 0cee91b532..e95064c1df 100644 --- a/src/cmd/compile/internal/ssa/fuse.go +++ b/src/cmd/compile/internal/ssa/fuse.go @@ -10,7 +10,9 @@ import ( ) -// fuseEarly runs fuse(f, fuseTypePlain|fuseTypeIntInRange|fuseTypeNanCheck). -func fuseEarly(f *Func) { fuse(f, fuseTypePlain|fuseTypeIntInRange|fuseTypeNanCheck) } +// fuseEarly runs fuse(f, fuseTypePlain|fuseTypeIntInRange|fuseTypeSingleBitDifference|fuseTypeNanCheck). +func fuseEarly(f *Func) { + fuse(f, fuseTypePlain|fuseTypeIntInRange|fuseTypeSingleBitDifference|fuseTypeNanCheck) +} // fuseLate runs fuse(f, fuseTypePlain|fuseTypeIf|fuseTypeBranchRedirect). func fuseLate(f *Func) { fuse(f, fuseTypePlain|fuseTypeIf|fuseTypeBranchRedirect) } @@ -21,6 +23,7 @@ const ( fuseTypePlain fuseType = 1 << iota fuseTypeIf fuseTypeIntInRange + fuseTypeSingleBitDifference fuseTypeNanCheck fuseTypeBranchRedirect fuseTypeShortCircuit @@ -41,6 +44,9 @@ func fuse(f *Func, typ fuseType) { if typ&fuseTypeIntInRange != 0 { changed = fuseIntInRange(b) || changed } + if typ&fuseTypeSingleBitDifference != 0 { + changed = fuseSingleBitDifference(b) || changed + } if typ&fuseTypeNanCheck != 0 { changed = fuseNanCheck(b) || changed } diff --git a/src/cmd/compile/internal/ssa/fuse_comparisons.go b/src/cmd/compile/internal/ssa/fuse_comparisons.go index b6eb8fcb90..898c034485 100644 --- a/src/cmd/compile/internal/ssa/fuse_comparisons.go +++ b/src/cmd/compile/internal/ssa/fuse_comparisons.go @@ -19,6 +19,14 @@ func fuseNanCheck(b *Block) bool { return fuseComparisons(b, canOptNanCheck) } +// fuseSingleBitDifference replaces the short-circuit operators between equality checks with +// constants that only differ by a single bit. For example, it would convert +// `if x == 4 || x == 6 { ... }` into `if (x == 4) | (x == 6) { ... }`. 
Rewrite rules can +// then optimize these using a bitwise operation, in this case generating `if x|2 == 6 { ... }`. +func fuseSingleBitDifference(b *Block) bool { + return fuseComparisons(b, canOptSingleBitDifference) +} + // fuseComparisons looks for control graphs that match this pattern: // // p - predecessor @@ -229,3 +237,40 @@ func canOptNanCheck(x, y *Value, op Op) bool { } return false } + +// canOptSingleBitDifference returns true if x op y matches either: +// +// v == c || v == d +// v != c && v != d +// +// Where c and d are constant values that differ by a single bit. +func canOptSingleBitDifference(x, y *Value, op Op) bool { + if x.Op != y.Op { + return false + } + switch x.Op { + case OpEq64, OpEq32, OpEq16, OpEq8: + if op != OpOrB { + return false + } + case OpNeq64, OpNeq32, OpNeq16, OpNeq8: + if op != OpAndB { + return false + } + default: + return false + } + + xi := getConstIntArgIndex(x) + if xi < 0 { + return false + } + yi := getConstIntArgIndex(y) + if yi < 0 { + return false + } + if x.Args[xi^1] != y.Args[yi^1] { + return false + } + return oneBit(x.Args[xi].AuxInt ^ y.Args[yi].AuxInt) +} diff --git a/src/cmd/compile/internal/ssa/rewritegeneric.go b/src/cmd/compile/internal/ssa/rewritegeneric.go index fd5139c0bb..2428f17947 100644 --- a/src/cmd/compile/internal/ssa/rewritegeneric.go +++ b/src/cmd/compile/internal/ssa/rewritegeneric.go @@ -5332,6 +5332,182 @@ func rewriteValuegeneric_OpAndB(v *Value) bool { } break } + // match: (AndB (Neq64 x cv:(Const64 [c])) (Neq64 x (Const64 [d]))) + // cond: c|d == c && oneBit(c^d) + // result: (Neq64 (Or64 x (Const64 [c^d])) cv) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq64 { + continue + } + _ = v_0.Args[1] + v_0_0 := v_0.Args[0] + v_0_1 := v_0.Args[1] + for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 { + x := v_0_0 + cv := v_0_1 + if cv.Op != OpConst64 { + continue + } + c := auxIntToInt64(cv.AuxInt) + if v_1.Op != OpNeq64 { + continue + } 
+ _ = v_1.Args[1] + v_1_0 := v_1.Args[0] + v_1_1 := v_1.Args[1] + for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 { + if x != v_1_0 || v_1_1.Op != OpConst64 { + continue + } + d := auxIntToInt64(v_1_1.AuxInt) + if !(c|d == c && oneBit(c^d)) { + continue + } + v.reset(OpNeq64) + v0 := b.NewValue0(v.Pos, OpOr64, x.Type) + v1 := b.NewValue0(v.Pos, OpConst64, x.Type) + v1.AuxInt = int64ToAuxInt(c ^ d) + v0.AddArg2(x, v1) + v.AddArg2(v0, cv) + return true + } + } + } + break + } + // match: (AndB (Neq32 x cv:(Const32 [c])) (Neq32 x (Const32 [d]))) + // cond: c|d == c && oneBit(c^d) + // result: (Neq32 (Or32 x (Const32 [c^d])) cv) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq32 { + continue + } + _ = v_0.Args[1] + v_0_0 := v_0.Args[0] + v_0_1 := v_0.Args[1] + for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 { + x := v_0_0 + cv := v_0_1 + if cv.Op != OpConst32 { + continue + } + c := auxIntToInt32(cv.AuxInt) + if v_1.Op != OpNeq32 { + continue + } + _ = v_1.Args[1] + v_1_0 := v_1.Args[0] + v_1_1 := v_1.Args[1] + for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 { + if x != v_1_0 || v_1_1.Op != OpConst32 { + continue + } + d := auxIntToInt32(v_1_1.AuxInt) + if !(c|d == c && oneBit(c^d)) { + continue + } + v.reset(OpNeq32) + v0 := b.NewValue0(v.Pos, OpOr32, x.Type) + v1 := b.NewValue0(v.Pos, OpConst32, x.Type) + v1.AuxInt = int32ToAuxInt(c ^ d) + v0.AddArg2(x, v1) + v.AddArg2(v0, cv) + return true + } + } + } + break + } + // match: (AndB (Neq16 x cv:(Const16 [c])) (Neq16 x (Const16 [d]))) + // cond: c|d == c && oneBit(c^d) + // result: (Neq16 (Or16 x (Const16 [c^d])) cv) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq16 { + continue + } + _ = v_0.Args[1] + v_0_0 := v_0.Args[0] + v_0_1 := v_0.Args[1] + for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 { + x := v_0_0 + cv := v_0_1 + if cv.Op != OpConst16 { + continue + } + c := 
auxIntToInt16(cv.AuxInt) + if v_1.Op != OpNeq16 { + continue + } + _ = v_1.Args[1] + v_1_0 := v_1.Args[0] + v_1_1 := v_1.Args[1] + for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 { + if x != v_1_0 || v_1_1.Op != OpConst16 { + continue + } + d := auxIntToInt16(v_1_1.AuxInt) + if !(c|d == c && oneBit(c^d)) { + continue + } + v.reset(OpNeq16) + v0 := b.NewValue0(v.Pos, OpOr16, x.Type) + v1 := b.NewValue0(v.Pos, OpConst16, x.Type) + v1.AuxInt = int16ToAuxInt(c ^ d) + v0.AddArg2(x, v1) + v.AddArg2(v0, cv) + return true + } + } + } + break + } + // match: (AndB (Neq8 x cv:(Const8 [c])) (Neq8 x (Const8 [d]))) + // cond: c|d == c && oneBit(c^d) + // result: (Neq8 (Or8 x (Const8 [c^d])) cv) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq8 { + continue + } + _ = v_0.Args[1] + v_0_0 := v_0.Args[0] + v_0_1 := v_0.Args[1] + for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 { + x := v_0_0 + cv := v_0_1 + if cv.Op != OpConst8 { + continue + } + c := auxIntToInt8(cv.AuxInt) + if v_1.Op != OpNeq8 { + continue + } + _ = v_1.Args[1] + v_1_0 := v_1.Args[0] + v_1_1 := v_1.Args[1] + for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 { + if x != v_1_0 || v_1_1.Op != OpConst8 { + continue + } + d := auxIntToInt8(v_1_1.AuxInt) + if !(c|d == c && oneBit(c^d)) { + continue + } + v.reset(OpNeq8) + v0 := b.NewValue0(v.Pos, OpOr8, x.Type) + v1 := b.NewValue0(v.Pos, OpConst8, x.Type) + v1.AuxInt = int8ToAuxInt(c ^ d) + v0.AddArg2(x, v1) + v.AddArg2(v0, cv) + return true + } + } + } + break + } return false } func rewriteValuegeneric_OpArraySelect(v *Value) bool { @@ -23242,6 +23418,182 @@ func rewriteValuegeneric_OpOrB(v *Value) bool { } break } + // match: (OrB (Eq64 x cv:(Const64 [c])) (Eq64 x (Const64 [d]))) + // cond: c|d == c && oneBit(c^d) + // result: (Eq64 (Or64 x (Const64 [c^d])) cv) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpEq64 { + continue + } + _ = 
v_0.Args[1] + v_0_0 := v_0.Args[0] + v_0_1 := v_0.Args[1] + for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 { + x := v_0_0 + cv := v_0_1 + if cv.Op != OpConst64 { + continue + } + c := auxIntToInt64(cv.AuxInt) + if v_1.Op != OpEq64 { + continue + } + _ = v_1.Args[1] + v_1_0 := v_1.Args[0] + v_1_1 := v_1.Args[1] + for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 { + if x != v_1_0 || v_1_1.Op != OpConst64 { + continue + } + d := auxIntToInt64(v_1_1.AuxInt) + if !(c|d == c && oneBit(c^d)) { + continue + } + v.reset(OpEq64) + v0 := b.NewValue0(v.Pos, OpOr64, x.Type) + v1 := b.NewValue0(v.Pos, OpConst64, x.Type) + v1.AuxInt = int64ToAuxInt(c ^ d) + v0.AddArg2(x, v1) + v.AddArg2(v0, cv) + return true + } + } + } + break + } + // match: (OrB (Eq32 x cv:(Const32 [c])) (Eq32 x (Const32 [d]))) + // cond: c|d == c && oneBit(c^d) + // result: (Eq32 (Or32 x (Const32 [c^d])) cv) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpEq32 { + continue + } + _ = v_0.Args[1] + v_0_0 := v_0.Args[0] + v_0_1 := v_0.Args[1] + for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 { + x := v_0_0 + cv := v_0_1 + if cv.Op != OpConst32 { + continue + } + c := auxIntToInt32(cv.AuxInt) + if v_1.Op != OpEq32 { + continue + } + _ = v_1.Args[1] + v_1_0 := v_1.Args[0] + v_1_1 := v_1.Args[1] + for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 { + if x != v_1_0 || v_1_1.Op != OpConst32 { + continue + } + d := auxIntToInt32(v_1_1.AuxInt) + if !(c|d == c && oneBit(c^d)) { + continue + } + v.reset(OpEq32) + v0 := b.NewValue0(v.Pos, OpOr32, x.Type) + v1 := b.NewValue0(v.Pos, OpConst32, x.Type) + v1.AuxInt = int32ToAuxInt(c ^ d) + v0.AddArg2(x, v1) + v.AddArg2(v0, cv) + return true + } + } + } + break + } + // match: (OrB (Eq16 x cv:(Const16 [c])) (Eq16 x (Const16 [d]))) + // cond: c|d == c && oneBit(c^d) + // result: (Eq16 (Or16 x (Const16 [c^d])) cv) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { 
+ if v_0.Op != OpEq16 { + continue + } + _ = v_0.Args[1] + v_0_0 := v_0.Args[0] + v_0_1 := v_0.Args[1] + for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 { + x := v_0_0 + cv := v_0_1 + if cv.Op != OpConst16 { + continue + } + c := auxIntToInt16(cv.AuxInt) + if v_1.Op != OpEq16 { + continue + } + _ = v_1.Args[1] + v_1_0 := v_1.Args[0] + v_1_1 := v_1.Args[1] + for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 { + if x != v_1_0 || v_1_1.Op != OpConst16 { + continue + } + d := auxIntToInt16(v_1_1.AuxInt) + if !(c|d == c && oneBit(c^d)) { + continue + } + v.reset(OpEq16) + v0 := b.NewValue0(v.Pos, OpOr16, x.Type) + v1 := b.NewValue0(v.Pos, OpConst16, x.Type) + v1.AuxInt = int16ToAuxInt(c ^ d) + v0.AddArg2(x, v1) + v.AddArg2(v0, cv) + return true + } + } + } + break + } + // match: (OrB (Eq8 x cv:(Const8 [c])) (Eq8 x (Const8 [d]))) + // cond: c|d == c && oneBit(c^d) + // result: (Eq8 (Or8 x (Const8 [c^d])) cv) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpEq8 { + continue + } + _ = v_0.Args[1] + v_0_0 := v_0.Args[0] + v_0_1 := v_0.Args[1] + for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 { + x := v_0_0 + cv := v_0_1 + if cv.Op != OpConst8 { + continue + } + c := auxIntToInt8(cv.AuxInt) + if v_1.Op != OpEq8 { + continue + } + _ = v_1.Args[1] + v_1_0 := v_1.Args[0] + v_1_1 := v_1.Args[1] + for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 { + if x != v_1_0 || v_1_1.Op != OpConst8 { + continue + } + d := auxIntToInt8(v_1_1.AuxInt) + if !(c|d == c && oneBit(c^d)) { + continue + } + v.reset(OpEq8) + v0 := b.NewValue0(v.Pos, OpOr8, x.Type) + v1 := b.NewValue0(v.Pos, OpConst8, x.Type) + v1.AuxInt = int8ToAuxInt(c ^ d) + v0.AddArg2(x, v1) + v.AddArg2(v0, cv) + return true + } + } + } + break + } // match: (OrB (Neq64F x x) (Less64F x y:(Const64F [c]))) // result: (Not (Leq64F y x)) for { diff --git a/test/codegen/fuse.go b/test/codegen/fuse.go index 4fbb03bef8..e5a28549dc 100644 --- 
a/test/codegen/fuse.go +++ b/test/codegen/fuse.go @@ -198,6 +198,126 @@ func ui4d(c <-chan uint8) { } } +// ------------------------------------ // +// single bit difference (conjunction) // +// ------------------------------------ // + +func sisbc64(c <-chan int64) { + // amd64: "ORQ [$]2," + // riscv64: "ORI [$]2," + for x := <-c; x != 4 && x != 6; x = <-c { + } +} + +func sisbc32(c <-chan int32) { + // amd64: "ORL [$]4," + // riscv64: "ORI [$]4," + for x := <-c; x != -1 && x != -5; x = <-c { + } +} + +func sisbc16(c <-chan int16) { + // amd64: "ORL [$]32," + // riscv64: "ORI [$]32," + for x := <-c; x != 16 && x != 48; x = <-c { + } +} + +func sisbc8(c <-chan int8) { + // amd64: "ORL [$]16," + // riscv64: "ORI [$]16," + for x := <-c; x != -15 && x != -31; x = <-c { + } +} + +func uisbc64(c <-chan uint64) { + // amd64: "ORQ [$]4," + // riscv64: "ORI [$]4," + for x := <-c; x != 1 && x != 5; x = <-c { + } +} + +func uisbc32(c <-chan uint32) { + // amd64: "ORL [$]4," + // riscv64: "ORI [$]4," + for x := <-c; x != 2 && x != 6; x = <-c { + } +} + +func uisbc16(c <-chan uint16) { + // amd64: "ORL [$]32," + // riscv64: "ORI [$]32," + for x := <-c; x != 16 && x != 48; x = <-c { + } +} + +func uisbc8(c <-chan uint8) { + // amd64: "ORL [$]64," + // riscv64: "ORI [$]64," + for x := <-c; x != 64 && x != 0; x = <-c { + } +} + +// ------------------------------------ // +// single bit difference (disjunction) // +// ------------------------------------ // + +func sisbd64(c <-chan int64) { + // amd64: "ORQ [$]2," + // riscv64: "ORI [$]2," + for x := <-c; x == 4 || x == 6; x = <-c { + } +} + +func sisbd32(c <-chan int32) { + // amd64: "ORL [$]4," + // riscv64: "ORI [$]4," + for x := <-c; x == -1 || x == -5; x = <-c { + } +} + +func sisbd16(c <-chan int16) { + // amd64: "ORL [$]32," + // riscv64: "ORI [$]32," + for x := <-c; x == 16 || x == 48; x = <-c { + } +} + +func sisbd8(c <-chan int8) { + // amd64: "ORL [$]16," + // riscv64: "ORI [$]16," + for x := <-c; x == -15 || x == -31; 
x = <-c { + } +} + +func uisbd64(c <-chan uint64) { + // amd64: "ORQ [$]4," + // riscv64: "ORI [$]4," + for x := <-c; x == 1 || x == 5; x = <-c { + } +} + +func uisbd32(c <-chan uint32) { + // amd64: "ORL [$]4," + // riscv64: "ORI [$]4," + for x := <-c; x == 2 || x == 6; x = <-c { + } +} + +func uisbd16(c <-chan uint16) { + // amd64: "ORL [$]32," + // riscv64: "ORI [$]32," + for x := <-c; x == 16 || x == 48; x = <-c { + } +} + +func uisbd8(c <-chan uint8) { + // amd64: "ORL [$]64," + // riscv64: "ORI [$]64," + for x := <-c; x == 64 || x == 0; x = <-c { + } +} + // -------------------------------------// // merge NaN checks // // ------------------------------------ //