Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: add additional arm64 bit field rules
author khr@golang.org <khr@golang.org>
Sun, 4 Aug 2024 12:41:38 +0000 (05:41 -0700)
committer Keith Randall <khr@golang.org>
Mon, 12 Aug 2024 21:03:55 +0000 (21:03 +0000)
Get rid of TODO in prove pass.
We currently avoid marking shifts by constant amounts as bounded, where
bounded means we don't have to worry about shift amounts that are <0 or
>= the bit width. We do this because it changes the order in which rules
are applied during lowering, which causes some codegen tests to fail.

Add some new rules that ensure we get the right final instruction
sequence regardless of rule application order. Then we can remove this special case.

Change-Id: I4e962d4f09992b42ab47e123de5ded3b8b8fb205
Reviewed-on: https://go-review.googlesource.com/c/go/+/602935
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
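
As a minimal illustration of the prove-pass change (this snippet is not part of
the CL; the function name is invented): a shift whose amount is a compile-time
constant is trivially in range, so prove can now mark it bounded itself instead
of special-casing constants and leaving them to the lowering rules.

// Hypothetical example, not from the CL.
// The shift amount 3 is a compile-time constant, so it is known to be
// non-negative and less than the 64-bit width; after this change the
// prove pass sets the shift's bounded flag (see shiftIsBounded) rather
// than skipping constant shift amounts.
func shiftByConst(x int64) int64 {
	return x >> 3
}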
src/cmd/compile/internal/ssa/_gen/ARM64.rules
src/cmd/compile/internal/ssa/prove.go
src/cmd/compile/internal/ssa/rewriteARM64.go
test/codegen/comparisons.go

diff --git a/src/cmd/compile/internal/ssa/_gen/ARM64.rules b/src/cmd/compile/internal/ssa/_gen/ARM64.rules
index 12badbdcb65bbcbdfcc38073d6f6149726ff0102..6c71e231b694c4124014ca33c96ecf4f31ea4ed9 100644
 (SRAconst [sc] (SBFIZ [bfc] x)) && sc >= bfc.getARM64BFlsb()
        && sc < bfc.getARM64BFlsb()+bfc.getARM64BFwidth()
        => (SBFX [armBFAuxInt(sc-bfc.getARM64BFlsb(), bfc.getARM64BFlsb()+bfc.getARM64BFwidth()-sc)] x)
+(SBFX [bfc] s:(SLLconst [sc] x))
+       && s.Uses == 1
+       && sc <= bfc.getARM64BFlsb()
+       => (SBFX [armBFAuxInt(bfc.getARM64BFlsb() - sc, bfc.getARM64BFwidth())] x)
+(SBFX [bfc] s:(SLLconst [sc] x))
+       && s.Uses == 1
+       && sc > bfc.getARM64BFlsb()
+       => (SBFIZ [armBFAuxInt(sc - bfc.getARM64BFlsb(), bfc.getARM64BFwidth() - (sc-bfc.getARM64BFlsb()))] x)
 
 // ubfiz
 // (x << lc) >> rc
        (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), arm64BFWidth(c, 0)))] x)
 (UBFX [bfc] (ANDconst [c] x)) && isARM64BFMask(0, c, 0) && bfc.getARM64BFlsb() + bfc.getARM64BFwidth() <= arm64BFWidth(c, 0) =>
        (UBFX [bfc] x)
-// merge ubfx and zerso-extension into ubfx
+// merge ubfx and zero-extension into ubfx
 (MOVWUreg (UBFX [bfc] x)) && bfc.getARM64BFwidth() <= 32 => (UBFX [bfc] x)
 (MOVHUreg (UBFX [bfc] x)) && bfc.getARM64BFwidth() <= 16 => (UBFX [bfc] x)
 (MOVBUreg (UBFX [bfc] x)) && bfc.getARM64BFwidth() <=  8 => (UBFX [bfc] x)
 
+// Extracting bits from across a zero-extension boundary.
+(UBFX [bfc] e:(MOVWUreg x))
+       && e.Uses == 1
+       && bfc.getARM64BFlsb() < 32
+       => (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 32-bfc.getARM64BFlsb()))] x)
+(UBFX [bfc] e:(MOVHUreg x))
+       && e.Uses == 1
+       && bfc.getARM64BFlsb() < 16
+       => (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 16-bfc.getARM64BFlsb()))] x)
+(UBFX [bfc] e:(MOVBUreg x))
+       && e.Uses == 1
+       && bfc.getARM64BFlsb() < 8
+       => (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 8-bfc.getARM64BFlsb()))] x)
+
 // ubfiz/ubfx combinations: merge shifts into bitfield ops
 (SRLconst [sc] (UBFX [bfc] x)) && sc < bfc.getARM64BFwidth()
        => (UBFX [armBFAuxInt(bfc.getARM64BFlsb()+sc, bfc.getARM64BFwidth()-sc)] x)
 (OR (UBFIZ [bfc] x) (ANDconst [ac] y))
        && ac == ^((1<<uint(bfc.getARM64BFwidth())-1) << uint(bfc.getARM64BFlsb()))
        => (BFI [bfc] y x)
+(ORshiftLL [s] (ANDconst [xc] x) (ANDconst [yc] y))
+       && xc == ^(yc << s)    // opposite masks
+       && yc & (yc+1) == 0    // power of 2 minus 1
+       && yc > 0              // not 0, not all 64 bits (there are better rewrites in that case)
+       && s+log64(yc+1) <= 64 // shifted mask doesn't overflow
+       => (BFI [armBFAuxInt(s, log64(yc+1))] x y)
 (ORshiftRL [rc] (ANDconst [ac] x) (SLLconst [lc] y))
        && lc > rc && ac == ^((1<<uint(64-lc)-1) << uint64(lc-rc))
        => (BFI [armBFAuxInt(lc-rc, 64-lc)] x y)
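
For illustration only (these functions are not from the CL, and whether a given
expression reaches these exact SSA shapes depends on how earlier rules fire):
the new ORshiftLL rule targets the usual bit-field-insert idiom of clearing a
field and OR-ing in a shifted, masked value, and the new UBFX rules let a field
be extracted through a uint32/uint16/uint8 zero-extension.

// Hypothetical examples of source patterns the new rules aim at.

// Insert the low 8 bits of y into bits 8..15 of x. If this reaches the
// backend as (ORshiftLL [8] (ANDconst [^(0xff<<8)] x) (ANDconst [0xff] y)),
// the new rule turns it into a single BFI.
func insertField(x, y uint64) uint64 {
	return (x &^ (0xff << 8)) | ((y & 0xff) << 8)
}

// Extract bits 4..13 of the low 32 bits of x. The zero-extension
// (MOVWUreg) underneath the UBFX can now be dropped, with the width
// clamped so the extracted field stays below bit 32.
func extractField(x uint64) uint64 {
	return (uint64(uint32(x)) >> 4) & 0x3ff
}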
diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go
index 9bc2fdc90c1566960b121e6a17e998110194b291..8acd38aa69f14c084ed9b0d8464b373e13f00def 100644
@@ -2008,15 +2008,6 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) {
                        lim := ft.limits[by.ID]
                        bits := 8 * v.Args[0].Type.Size()
                        if lim.umax < uint64(bits) || (lim.max < bits && ft.isNonNegative(by)) {
-                               if by.isGenericIntConst() {
-                                       // TODO: get rid of this block.
-                                       // Currently this causes lowering to happen
-                                       // in different orders, which causes some
-                                       // problems for codegen tests for arm64
-                                       // where rule application order results
-                                       // in different final instructions.
-                                       break
-                               }
                                v.AuxInt = 1 // see shiftIsBounded
                                if b.Func.pass.debug > 0 && !by.isGenericIntConst() {
                                        b.Func.Warnl(v.Pos, "Proved %v bounded", v.Op)
diff --git a/src/cmd/compile/internal/ssa/rewriteARM64.go b/src/cmd/compile/internal/ssa/rewriteARM64.go
index 44b171d6055bbd2c1bd4e62b6e2328593c1991c3..ff13ae3da85a35d5ab948d6495f7967ac3cbdcf9 100644
@@ -352,6 +352,8 @@ func rewriteValueARM64(v *Value) bool {
                return rewriteValueARM64_OpARM64RORW(v)
        case OpARM64SBCSflags:
                return rewriteValueARM64_OpARM64SBCSflags(v)
+       case OpARM64SBFX:
+               return rewriteValueARM64_OpARM64SBFX(v)
        case OpARM64SLL:
                return rewriteValueARM64_OpARM64SLL(v)
        case OpARM64SLLconst:
@@ -15203,6 +15205,29 @@ func rewriteValueARM64_OpARM64ORshiftLL(v *Value) bool {
                v.AddArg2(x2, x)
                return true
        }
+       // match: (ORshiftLL [s] (ANDconst [xc] x) (ANDconst [yc] y))
+       // cond: xc == ^(yc << s) && yc & (yc+1) == 0 && yc > 0 && s+log64(yc+1) <= 64
+       // result: (BFI [armBFAuxInt(s, log64(yc+1))] x y)
+       for {
+               s := auxIntToInt64(v.AuxInt)
+               if v_0.Op != OpARM64ANDconst {
+                       break
+               }
+               xc := auxIntToInt64(v_0.AuxInt)
+               x := v_0.Args[0]
+               if v_1.Op != OpARM64ANDconst {
+                       break
+               }
+               yc := auxIntToInt64(v_1.AuxInt)
+               y := v_1.Args[0]
+               if !(xc == ^(yc<<s) && yc&(yc+1) == 0 && yc > 0 && s+log64(yc+1) <= 64) {
+                       break
+               }
+               v.reset(OpARM64BFI)
+               v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(s, log64(yc+1)))
+               v.AddArg2(x, y)
+               return true
+       }
        // match: (ORshiftLL [sc] (UBFX [bfc] x) (SRLconst [sc] y))
        // cond: sc == bfc.getARM64BFwidth()
        // result: (BFXIL [bfc] y x)
@@ -15546,6 +15571,48 @@ func rewriteValueARM64_OpARM64SBCSflags(v *Value) bool {
        }
        return false
 }
+func rewriteValueARM64_OpARM64SBFX(v *Value) bool {
+       v_0 := v.Args[0]
+       // match: (SBFX [bfc] s:(SLLconst [sc] x))
+       // cond: s.Uses == 1 && sc <= bfc.getARM64BFlsb()
+       // result: (SBFX [armBFAuxInt(bfc.getARM64BFlsb() - sc, bfc.getARM64BFwidth())] x)
+       for {
+               bfc := auxIntToArm64BitField(v.AuxInt)
+               s := v_0
+               if s.Op != OpARM64SLLconst {
+                       break
+               }
+               sc := auxIntToInt64(s.AuxInt)
+               x := s.Args[0]
+               if !(s.Uses == 1 && sc <= bfc.getARM64BFlsb()) {
+                       break
+               }
+               v.reset(OpARM64SBFX)
+               v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(bfc.getARM64BFlsb()-sc, bfc.getARM64BFwidth()))
+               v.AddArg(x)
+               return true
+       }
+       // match: (SBFX [bfc] s:(SLLconst [sc] x))
+       // cond: s.Uses == 1 && sc > bfc.getARM64BFlsb()
+       // result: (SBFIZ [armBFAuxInt(sc - bfc.getARM64BFlsb(), bfc.getARM64BFwidth() - (sc-bfc.getARM64BFlsb()))] x)
+       for {
+               bfc := auxIntToArm64BitField(v.AuxInt)
+               s := v_0
+               if s.Op != OpARM64SLLconst {
+                       break
+               }
+               sc := auxIntToInt64(s.AuxInt)
+               x := s.Args[0]
+               if !(s.Uses == 1 && sc > bfc.getARM64BFlsb()) {
+                       break
+               }
+               v.reset(OpARM64SBFIZ)
+               v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(sc-bfc.getARM64BFlsb(), bfc.getARM64BFwidth()-(sc-bfc.getARM64BFlsb())))
+               v.AddArg(x)
+               return true
+       }
+       return false
+}
 func rewriteValueARM64_OpARM64SLL(v *Value) bool {
        v_1 := v.Args[1]
        v_0 := v.Args[0]
@@ -16946,6 +17013,60 @@ func rewriteValueARM64_OpARM64UBFX(v *Value) bool {
                v.AddArg(x)
                return true
        }
+       // match: (UBFX [bfc] e:(MOVWUreg x))
+       // cond: e.Uses == 1 && bfc.getARM64BFlsb() < 32
+       // result: (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 32-bfc.getARM64BFlsb()))] x)
+       for {
+               bfc := auxIntToArm64BitField(v.AuxInt)
+               e := v_0
+               if e.Op != OpARM64MOVWUreg {
+                       break
+               }
+               x := e.Args[0]
+               if !(e.Uses == 1 && bfc.getARM64BFlsb() < 32) {
+                       break
+               }
+               v.reset(OpARM64UBFX)
+               v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 32-bfc.getARM64BFlsb())))
+               v.AddArg(x)
+               return true
+       }
+       // match: (UBFX [bfc] e:(MOVHUreg x))
+       // cond: e.Uses == 1 && bfc.getARM64BFlsb() < 16
+       // result: (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 16-bfc.getARM64BFlsb()))] x)
+       for {
+               bfc := auxIntToArm64BitField(v.AuxInt)
+               e := v_0
+               if e.Op != OpARM64MOVHUreg {
+                       break
+               }
+               x := e.Args[0]
+               if !(e.Uses == 1 && bfc.getARM64BFlsb() < 16) {
+                       break
+               }
+               v.reset(OpARM64UBFX)
+               v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 16-bfc.getARM64BFlsb())))
+               v.AddArg(x)
+               return true
+       }
+       // match: (UBFX [bfc] e:(MOVBUreg x))
+       // cond: e.Uses == 1 && bfc.getARM64BFlsb() < 8
+       // result: (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 8-bfc.getARM64BFlsb()))] x)
+       for {
+               bfc := auxIntToArm64BitField(v.AuxInt)
+               e := v_0
+               if e.Op != OpARM64MOVBUreg {
+                       break
+               }
+               x := e.Args[0]
+               if !(e.Uses == 1 && bfc.getARM64BFlsb() < 8) {
+                       break
+               }
+               v.reset(OpARM64UBFX)
+               v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 8-bfc.getARM64BFlsb())))
+               v.AddArg(x)
+               return true
+       }
        // match: (UBFX [bfc] (SRLconst [sc] x))
        // cond: sc+bfc.getARM64BFwidth()+bfc.getARM64BFlsb() < 64
        // result: (UBFX [armBFAuxInt(bfc.getARM64BFlsb()+sc, bfc.getARM64BFwidth())] x)
diff --git a/test/codegen/comparisons.go b/test/codegen/comparisons.go
index 909cf954b14d3833de331022bb62b2bbeae14509..5fbb31c00c8cd87b15bff18bbfe76ac44ea02dbe 100644
@@ -233,9 +233,9 @@ func CmpToZero(a, b, d int32, e, f int64, deOptC0, deOptC1 bool) int32 {
        // arm:`AND`,-`TST`
        // 386:`ANDL`
        c6 := a&d >= 0
-       // arm64:`TST\sR[0-9]+<<3,\sR[0-9]+`
+       // For arm64, could be TST+BGE or AND+TBZ
        c7 := e&(f<<3) < 0
-       // arm64:`CMN\sR[0-9]+<<3,\sR[0-9]+`
+       // For arm64, could be CMN+BPL or ADD+TBZ
        c8 := e+(f<<3) < 0
        // arm64:`TST\sR[0-9],\sR[0-9]+`
        c9 := e&(-19) < 0