From 72735094660a475a69050b7368c56b25346f5406 Mon Sep 17 00:00:00 2001 From: "khr@golang.org" Date: Sun, 4 Aug 2024 05:41:38 -0700 Subject: [PATCH] cmd/compile: add additional arm64 bit field rules Get rid of TODO in prove pass. We currently avoid marking shifts of constants as bounded, where bounded means we don't have to worry about <0 or >=bitwidth shifts. We do this because it causes different rule applications during lowering which cause some codegen tests to fail. Add some new rules which ensure that we get the right final instruction sequence regardless of the ordering. Then we can remove this special case. Change-Id: I4e962d4f09992b42ab47e123de5ded3b8b8fb205 Reviewed-on: https://go-review.googlesource.com/c/go/+/602935 LUCI-TryBot-Result: Go LUCI Reviewed-by: David Chase Reviewed-by: Michael Knyszek --- src/cmd/compile/internal/ssa/_gen/ARM64.rules | 30 ++++- src/cmd/compile/internal/ssa/prove.go | 9 -- src/cmd/compile/internal/ssa/rewriteARM64.go | 121 ++++++++++++++++++ test/codegen/comparisons.go | 4 +- 4 files changed, 152 insertions(+), 12 deletions(-) diff --git a/src/cmd/compile/internal/ssa/_gen/ARM64.rules b/src/cmd/compile/internal/ssa/_gen/ARM64.rules index 12badbdcb6..6c71e231b6 100644 --- a/src/cmd/compile/internal/ssa/_gen/ARM64.rules +++ b/src/cmd/compile/internal/ssa/_gen/ARM64.rules @@ -1789,6 +1789,14 @@ (SRAconst [sc] (SBFIZ [bfc] x)) && sc >= bfc.getARM64BFlsb() && sc < bfc.getARM64BFlsb()+bfc.getARM64BFwidth() => (SBFX [armBFAuxInt(sc-bfc.getARM64BFlsb(), bfc.getARM64BFlsb()+bfc.getARM64BFwidth()-sc)] x) +(SBFX [bfc] s:(SLLconst [sc] x)) + && s.Uses == 1 + && sc <= bfc.getARM64BFlsb() + => (SBFX [armBFAuxInt(bfc.getARM64BFlsb() - sc, bfc.getARM64BFwidth())] x) +(SBFX [bfc] s:(SLLconst [sc] x)) + && s.Uses == 1 + && sc > bfc.getARM64BFlsb() + => (SBFIZ [armBFAuxInt(sc - bfc.getARM64BFlsb(), bfc.getARM64BFwidth() - (sc-bfc.getARM64BFlsb()))] x) // ubfiz // (x << lc) >> rc @@ -1833,11 +1841,25 @@ (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), arm64BFWidth(c, 0)))] x) (UBFX [bfc] (ANDconst [c] x)) && isARM64BFMask(0, c, 0) && bfc.getARM64BFlsb() + bfc.getARM64BFwidth() <= arm64BFWidth(c, 0) => (UBFX [bfc] x) -// merge ubfx and zerso-extension into ubfx +// merge ubfx and zero-extension into ubfx (MOVWUreg (UBFX [bfc] x)) && bfc.getARM64BFwidth() <= 32 => (UBFX [bfc] x) (MOVHUreg (UBFX [bfc] x)) && bfc.getARM64BFwidth() <= 16 => (UBFX [bfc] x) (MOVBUreg (UBFX [bfc] x)) && bfc.getARM64BFwidth() <= 8 => (UBFX [bfc] x) +// Extracting bits from across a zero-extension boundary. +(UBFX [bfc] e:(MOVWUreg x)) + && e.Uses == 1 + && bfc.getARM64BFlsb() < 32 + => (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 32-bfc.getARM64BFlsb()))] x) +(UBFX [bfc] e:(MOVHUreg x)) + && e.Uses == 1 + && bfc.getARM64BFlsb() < 16 + => (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 16-bfc.getARM64BFlsb()))] x) +(UBFX [bfc] e:(MOVBUreg x)) + && e.Uses == 1 + && bfc.getARM64BFlsb() < 8 + => (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 8-bfc.getARM64BFlsb()))] x) + // ubfiz/ubfx combinations: merge shifts into bitfield ops (SRLconst [sc] (UBFX [bfc] x)) && sc < bfc.getARM64BFwidth() => (UBFX [armBFAuxInt(bfc.getARM64BFlsb()+sc, bfc.getARM64BFwidth()-sc)] x) @@ -1868,6 +1890,12 @@ (OR (UBFIZ [bfc] x) (ANDconst [ac] y)) && ac == ^((1< (BFI [bfc] y x) +(ORshiftLL [s] (ANDconst [xc] x) (ANDconst [yc] y)) + && xc == ^(yc << s) // opposite masks + && yc & (yc+1) == 0 // power of 2 minus 1 + && yc > 0 // not 0, not all 64 bits (there are better rewrites in that case) + && s+log64(yc+1) <= 64 // shifted mask doesn't overflow + => (BFI [armBFAuxInt(s, log64(yc+1))] x y) (ORshiftRL [rc] (ANDconst [ac] x) (SLLconst [lc] y)) && lc > rc && ac == ^((1< (BFI [armBFAuxInt(lc-rc, 64-lc)] x y) diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go index 9bc2fdc90c..8acd38aa69 100644 --- a/src/cmd/compile/internal/ssa/prove.go +++ b/src/cmd/compile/internal/ssa/prove.go @@ -2008,15 +2008,6 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) { lim := ft.limits[by.ID] bits := 8 * v.Args[0].Type.Size() if lim.umax < uint64(bits) || (lim.max < bits && ft.isNonNegative(by)) { - if by.isGenericIntConst() { - // TODO: get rid of this block. - // Currently this causes lowering to happen - // in different orders, which causes some - // problems for codegen tests for arm64 - // where rule application order results - // in different final instructions. - break - } v.AuxInt = 1 // see shiftIsBounded if b.Func.pass.debug > 0 && !by.isGenericIntConst() { b.Func.Warnl(v.Pos, "Proved %v bounded", v.Op) diff --git a/src/cmd/compile/internal/ssa/rewriteARM64.go b/src/cmd/compile/internal/ssa/rewriteARM64.go index 44b171d605..ff13ae3da8 100644 --- a/src/cmd/compile/internal/ssa/rewriteARM64.go +++ b/src/cmd/compile/internal/ssa/rewriteARM64.go @@ -352,6 +352,8 @@ func rewriteValueARM64(v *Value) bool { return rewriteValueARM64_OpARM64RORW(v) case OpARM64SBCSflags: return rewriteValueARM64_OpARM64SBCSflags(v) + case OpARM64SBFX: + return rewriteValueARM64_OpARM64SBFX(v) case OpARM64SLL: return rewriteValueARM64_OpARM64SLL(v) case OpARM64SLLconst: @@ -15203,6 +15205,29 @@ func rewriteValueARM64_OpARM64ORshiftLL(v *Value) bool { v.AddArg2(x2, x) return true } + // match: (ORshiftLL [s] (ANDconst [xc] x) (ANDconst [yc] y)) + // cond: xc == ^(yc << s) && yc & (yc+1) == 0 && yc > 0 && s+log64(yc+1) <= 64 + // result: (BFI [armBFAuxInt(s, log64(yc+1))] x y) + for { + s := auxIntToInt64(v.AuxInt) + if v_0.Op != OpARM64ANDconst { + break + } + xc := auxIntToInt64(v_0.AuxInt) + x := v_0.Args[0] + if v_1.Op != OpARM64ANDconst { + break + } + yc := auxIntToInt64(v_1.AuxInt) + y := v_1.Args[0] + if !(xc == ^(yc< 0 && s+log64(yc+1) <= 64) { + break + } + v.reset(OpARM64BFI) + v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(s, log64(yc+1))) + v.AddArg2(x, y) + return true + } // match: (ORshiftLL [sc] (UBFX [bfc] x) (SRLconst [sc] y)) // cond: sc == bfc.getARM64BFwidth() // result: (BFXIL [bfc] y x) @@ -15546,6 +15571,48 @@ func rewriteValueARM64_OpARM64SBCSflags(v *Value) bool { } return false } +func rewriteValueARM64_OpARM64SBFX(v *Value) bool { + v_0 := v.Args[0] + // match: (SBFX [bfc] s:(SLLconst [sc] x)) + // cond: s.Uses == 1 && sc <= bfc.getARM64BFlsb() + // result: (SBFX [armBFAuxInt(bfc.getARM64BFlsb() - sc, bfc.getARM64BFwidth())] x) + for { + bfc := auxIntToArm64BitField(v.AuxInt) + s := v_0 + if s.Op != OpARM64SLLconst { + break + } + sc := auxIntToInt64(s.AuxInt) + x := s.Args[0] + if !(s.Uses == 1 && sc <= bfc.getARM64BFlsb()) { + break + } + v.reset(OpARM64SBFX) + v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(bfc.getARM64BFlsb()-sc, bfc.getARM64BFwidth())) + v.AddArg(x) + return true + } + // match: (SBFX [bfc] s:(SLLconst [sc] x)) + // cond: s.Uses == 1 && sc > bfc.getARM64BFlsb() + // result: (SBFIZ [armBFAuxInt(sc - bfc.getARM64BFlsb(), bfc.getARM64BFwidth() - (sc-bfc.getARM64BFlsb()))] x) + for { + bfc := auxIntToArm64BitField(v.AuxInt) + s := v_0 + if s.Op != OpARM64SLLconst { + break + } + sc := auxIntToInt64(s.AuxInt) + x := s.Args[0] + if !(s.Uses == 1 && sc > bfc.getARM64BFlsb()) { + break + } + v.reset(OpARM64SBFIZ) + v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(sc-bfc.getARM64BFlsb(), bfc.getARM64BFwidth()-(sc-bfc.getARM64BFlsb()))) + v.AddArg(x) + return true + } + return false +} func rewriteValueARM64_OpARM64SLL(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] @@ -16946,6 +17013,60 @@ func rewriteValueARM64_OpARM64UBFX(v *Value) bool { v.AddArg(x) return true } + // match: (UBFX [bfc] e:(MOVWUreg x)) + // cond: e.Uses == 1 && bfc.getARM64BFlsb() < 32 + // result: (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 32-bfc.getARM64BFlsb()))] x) + for { + bfc := auxIntToArm64BitField(v.AuxInt) + e := v_0 + if e.Op != OpARM64MOVWUreg { + break + } + x := e.Args[0] + if !(e.Uses == 1 && bfc.getARM64BFlsb() < 32) { + break + } + v.reset(OpARM64UBFX) + v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 32-bfc.getARM64BFlsb()))) + v.AddArg(x) + return true + } + // match: (UBFX [bfc] e:(MOVHUreg x)) + // cond: e.Uses == 1 && bfc.getARM64BFlsb() < 16 + // result: (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 16-bfc.getARM64BFlsb()))] x) + for { + bfc := auxIntToArm64BitField(v.AuxInt) + e := v_0 + if e.Op != OpARM64MOVHUreg { + break + } + x := e.Args[0] + if !(e.Uses == 1 && bfc.getARM64BFlsb() < 16) { + break + } + v.reset(OpARM64UBFX) + v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 16-bfc.getARM64BFlsb()))) + v.AddArg(x) + return true + } + // match: (UBFX [bfc] e:(MOVBUreg x)) + // cond: e.Uses == 1 && bfc.getARM64BFlsb() < 8 + // result: (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 8-bfc.getARM64BFlsb()))] x) + for { + bfc := auxIntToArm64BitField(v.AuxInt) + e := v_0 + if e.Op != OpARM64MOVBUreg { + break + } + x := e.Args[0] + if !(e.Uses == 1 && bfc.getARM64BFlsb() < 8) { + break + } + v.reset(OpARM64UBFX) + v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 8-bfc.getARM64BFlsb()))) + v.AddArg(x) + return true + } // match: (UBFX [bfc] (SRLconst [sc] x)) // cond: sc+bfc.getARM64BFwidth()+bfc.getARM64BFlsb() < 64 // result: (UBFX [armBFAuxInt(bfc.getARM64BFlsb()+sc, bfc.getARM64BFwidth())] x) diff --git a/test/codegen/comparisons.go b/test/codegen/comparisons.go index 909cf954b1..5fbb31c00c 100644 --- a/test/codegen/comparisons.go +++ b/test/codegen/comparisons.go @@ -233,9 +233,9 @@ func CmpToZero(a, b, d int32, e, f int64, deOptC0, deOptC1 bool) int32 { // arm:`AND`,-`TST` // 386:`ANDL` c6 := a&d >= 0 - // arm64:`TST\sR[0-9]+<<3,\sR[0-9]+` + // For arm64, could be TST+BGE or AND+TBZ c7 := e&(f<<3) < 0 - // arm64:`CMN\sR[0-9]+<<3,\sR[0-9]+` + // For arm64, could be CMN+BPL or ADD+TBZ c8 := e+(f<<3) < 0 // arm64:`TST\sR[0-9],\sR[0-9]+` c9 := e&(-19) < 0 -- 2.48.1