]> Cypherpunks repositories - gostls13.git/commitdiff
[dev.ssa] cmd/compile/internal/ssa: remove proven redundant controls.
authorAlexandru Moșoi <mosoi@google.com>
Fri, 19 Feb 2016 11:14:42 +0000 (12:14 +0100)
committerAlexandru Moșoi <alexandru@mosoi.ro>
Sun, 28 Feb 2016 19:48:20 +0000 (19:48 +0000)
* It does very simple bounds checking elimination. E.g.
removes the second check in for i := range a { a[i]++; a[i++]; }
* Improves on the following redundant expression:
return a6 || (a6 || (a6 || a4)) || (a6 || (a4 || a6 || (false || a6)))
* Linear in the number of block edges.

I patched in CL 12960 that does bounds, nil and constant propagation
to make sure this CL is not just redundant. Size of pkg/tool/linux_amd64/*
(excluding compile which is affected by this change):

With IsInBounds and IsSliceInBounds
-this -12960 92285080
+this -12960 91947416
-this +12960 91978976
+this +12960 91923088

Gain is ~110% of 12960.

Without IsInBounds and IsSliceInBounds (older run)
-this -12960 95515512
+this -12960 95492536
-this +12960 95216920
+this +12960 95204440

Shaves 22k on its own.

* Can we handle IsInBounds better with this? In
for i := range a { a[i]++; } the bounds checking at a[i]
is not eliminated.

Change-Id: I98957427399145fb33693173fd4d5a8d71c7cc20
Reviewed-on: https://go-review.googlesource.com/19710
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Keith Randall <khr@golang.org>
Run-TryBot: Alexandru Moșoi <alexandru@mosoi.ro>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/cmd/compile/internal/ssa/compile.go
src/cmd/compile/internal/ssa/prove.go [new file with mode: 0644]
test/prove.go [new file with mode: 0644]

index 23dab9e2735589bb96021d594b4819de604617fb..5e68ea004e84914d3331e14eb3e6f6ae2902f3e9 100644 (file)
@@ -165,6 +165,7 @@ var passes = [...]pass{
        {name: "opt deadcode", fn: deadcode},             // remove any blocks orphaned during opt
        {name: "generic cse", fn: cse},
        {name: "nilcheckelim", fn: nilcheckelim},
+       {name: "prove", fn: prove},
        {name: "generic deadcode", fn: deadcode},
        {name: "fuse", fn: fuse},
        {name: "dse", fn: dse},
@@ -193,6 +194,10 @@ type constraint struct {
 }
 
 var passOrder = [...]constraint{
+       // prove reliese on common-subexpression elimination for maximum benefits.
+       {"generic cse", "prove"},
+       // deadcode after prove to eliminate all new dead blocks.
+       {"prove", "generic deadcode"},
        // common-subexpression before dead-store elim, so that we recognize
        // when two address expressions are the same.
        {"generic cse", "dse"},
diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go
new file mode 100644 (file)
index 0000000..f0f4649
--- /dev/null
@@ -0,0 +1,359 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssa
+
+// rangeMask represents the possible relations between a pair of variables.
+type rangeMask uint
+
+const (
+       lt rangeMask = 1 << iota
+       eq
+       gt
+)
+
+// typeMask represents the universe of a variable pair in which
+// a set of relations is known.
+// For example, information learned for unsigned pairs cannot
+// be transfered to signed pairs because the same bit representation
+// can mean something else.
+type typeMask uint
+
+const (
+       signed typeMask = 1 << iota
+       unsigned
+       pointer
+)
+
+type typeRange struct {
+       t typeMask
+       r rangeMask
+}
+
+type control struct {
+       tm     typeMask
+       a0, a1 ID
+}
+
+var (
+       reverseBits = [...]rangeMask{0, 4, 2, 6, 1, 5, 3, 7}
+
+       // maps what we learn when the positive branch is taken.
+       // For example:
+       //      OpLess8:   {signed, lt},
+       //      v1 = (OpLess8 v2 v3).
+       // If v1 branch is taken than we learn that the rangeMaks
+       // can be at most lt.
+       typeRangeTable = map[Op]typeRange{
+               OpEq8:   {signed | unsigned, eq},
+               OpEq16:  {signed | unsigned, eq},
+               OpEq32:  {signed | unsigned, eq},
+               OpEq64:  {signed | unsigned, eq},
+               OpEqPtr: {pointer, eq},
+
+               OpNeq8:   {signed | unsigned, lt | gt},
+               OpNeq16:  {signed | unsigned, lt | gt},
+               OpNeq32:  {signed | unsigned, lt | gt},
+               OpNeq64:  {signed | unsigned, lt | gt},
+               OpNeqPtr: {pointer, lt | gt},
+
+               OpLess8:   {signed, lt},
+               OpLess8U:  {unsigned, lt},
+               OpLess16:  {signed, lt},
+               OpLess16U: {unsigned, lt},
+               OpLess32:  {signed, lt},
+               OpLess32U: {unsigned, lt},
+               OpLess64:  {signed, lt},
+               OpLess64U: {unsigned, lt},
+
+               OpLeq8:   {signed, lt | eq},
+               OpLeq8U:  {unsigned, lt | eq},
+               OpLeq16:  {signed, lt | eq},
+               OpLeq16U: {unsigned, lt | eq},
+               OpLeq32:  {signed, lt | eq},
+               OpLeq32U: {unsigned, lt | eq},
+               OpLeq64:  {signed, lt | eq},
+               OpLeq64U: {unsigned, lt | eq},
+
+               OpGeq8:   {signed, eq | gt},
+               OpGeq8U:  {unsigned, eq | gt},
+               OpGeq16:  {signed, eq | gt},
+               OpGeq16U: {unsigned, eq | gt},
+               OpGeq32:  {signed, eq | gt},
+               OpGeq32U: {unsigned, eq | gt},
+               OpGeq64:  {signed, eq | gt},
+               OpGeq64U: {unsigned, eq | gt},
+
+               OpGreater8:   {signed, gt},
+               OpGreater8U:  {unsigned, gt},
+               OpGreater16:  {signed, gt},
+               OpGreater16U: {unsigned, gt},
+               OpGreater32:  {signed, gt},
+               OpGreater32U: {unsigned, gt},
+               OpGreater64:  {signed, gt},
+               OpGreater64U: {unsigned, gt},
+
+               // TODO: OpIsInBounds actually test 0 <= a < b. This means
+               // that the positive branch learns signed/LT and unsigned/LT
+               // but the negative branch only learns unsigned/GE.
+               OpIsInBounds:      {unsigned, lt},
+               OpIsSliceInBounds: {unsigned, lt | eq},
+       }
+)
+
+// prove removes redundant BlockIf controls that can be inferred in a straight line.
+//
+// By far, the most common redundant control are generated by bounds checking.
+// For example for the code:
+//
+//    a[i] = 4
+//    foo(a[i])
+//
+// The compiler will generate the following code:
+//
+//    if i >= len(a) {
+//        panic("not in bounds")
+//    }
+//    a[i] = 4
+//    if i >= len(a) {
+//        panic("not in bounds")
+//    }
+//    foo(a[i])
+//
+// The second comparison i >= len(a) is clearly redundant because if the
+// else branch of the first comparison is executed, we already know that i < len(a).
+// The code for the second panic can be removed.
+func prove(f *Func) {
+       idom := dominators(f)
+       sdom := newSparseTree(f, idom)
+       domTree := make([][]*Block, f.NumBlocks())
+
+       // Create a block ID -> [dominees] mapping
+       for _, b := range f.Blocks {
+               if dom := idom[b.ID]; dom != nil {
+                       domTree[dom.ID] = append(domTree[dom.ID], b)
+               }
+       }
+
+       // current node state
+       type walkState int
+       const (
+               descend walkState = iota
+               simplify
+       )
+       // work maintains the DFS stack.
+       type bp struct {
+               block *Block      // current handled block
+               state walkState   // what's to do
+               saved []typeRange // save previous map entries modified by node
+       }
+       work := make([]bp, 0, 256)
+       work = append(work, bp{
+               block: f.Entry,
+               state: descend,
+       })
+
+       // mask keep tracks of restrictions for each pair of values in
+       // the dominators for the current node.
+       // Invariant: a0.ID <= a1.ID
+       // For example {unsigned, a0, a1} -> eq|gt means that from
+       // predecessors we know that a0 must be greater or equal to
+       // a1.
+       mask := make(map[control]rangeMask)
+
+       // DFS on the dominator tree.
+       for len(work) > 0 {
+               node := work[len(work)-1]
+               work = work[:len(work)-1]
+
+               switch node.state {
+               case descend:
+                       parent := idom[node.block.ID]
+                       tr := getRestrict(sdom, parent, node.block)
+                       saved := updateRestrictions(mask, parent, tr)
+
+                       work = append(work, bp{
+                               block: node.block,
+                               state: simplify,
+                               saved: saved,
+                       })
+
+                       for _, s := range domTree[node.block.ID] {
+                               work = append(work, bp{
+                                       block: s,
+                                       state: descend,
+                               })
+                       }
+
+               case simplify:
+                       simplifyBlock(mask, node.block)
+                       restoreRestrictions(mask, idom[node.block.ID], node.saved)
+               }
+       }
+}
+
+// getRestrict returns the range restrictions added by p
+// when reaching b. p is the immediate dominator or b.
+func getRestrict(sdom sparseTree, p *Block, b *Block) typeRange {
+       if p == nil || p.Kind != BlockIf {
+               return typeRange{}
+       }
+       tr, has := typeRangeTable[p.Control.Op]
+       if !has {
+               return typeRange{}
+       }
+       // If p and p.Succs[0] are dominators it means that every path
+       // from entry to b passes through p and p.Succs[0]. We care that
+       // no path from entry to b passes through p.Succs[1]. If p.Succs[0]
+       // has one predecessor then (apart from the degenerate case),
+       // there is no path from entry that can reach b through p.Succs[1].
+       // TODO: how about p->yes->b->yes, i.e. a loop in yes.
+       if sdom.isAncestorEq(p.Succs[0], b) && len(p.Succs[0].Preds) == 1 {
+               return tr
+       } else if sdom.isAncestorEq(p.Succs[1], b) && len(p.Succs[1].Preds) == 1 {
+               tr.r = (lt | eq | gt) ^ tr.r
+               return tr
+       }
+       return typeRange{}
+}
+
+// updateRestrictions updates restrictions from the previous block (p) based on tr.
+// normally tr was calculated with getRestrict.
+func updateRestrictions(mask map[control]rangeMask, p *Block, tr typeRange) []typeRange {
+       if tr.t == 0 {
+               return nil
+       }
+
+       // p modifies the restrictions for (a0, a1).
+       // save and return the previous state.
+       a0 := p.Control.Args[0]
+       a1 := p.Control.Args[1]
+       if a0.ID > a1.ID {
+               tr.r = reverseBits[tr.r]
+               a0, a1 = a1, a0
+       }
+
+       saved := make([]typeRange, 0, 2)
+       for t := typeMask(1); t <= tr.t; t <<= 1 {
+               if t&tr.t == 0 {
+                       continue
+               }
+
+               i := control{t, a0.ID, a1.ID}
+               oldRange, ok := mask[i]
+               if !ok {
+                       if a1 != a0 {
+                               oldRange = lt | eq | gt
+                       } else { // sometimes happens after cse
+                               oldRange = eq
+                       }
+               }
+               // if i was not already in the map we save the full range
+               // so that when we restore it we properly keep track of it.
+               saved = append(saved, typeRange{t, oldRange})
+               // mask[i] contains the possible relations between a0 and a1.
+               // When we branched from parent we learned that the possible
+               // relations cannot be more than tr.r. We compute the new set of
+               // relations as the intersection betwee the old and the new set.
+               mask[i] = oldRange & tr.r
+       }
+       return saved
+}
+
+func restoreRestrictions(mask map[control]rangeMask, p *Block, saved []typeRange) {
+       if p == nil || p.Kind != BlockIf || len(saved) == 0 {
+               return
+       }
+
+       a0 := p.Control.Args[0].ID
+       a1 := p.Control.Args[1].ID
+       if a0 > a1 {
+               a0, a1 = a1, a0
+       }
+
+       for _, tr := range saved {
+               i := control{tr.t, a0, a1}
+               if tr.r != lt|eq|gt {
+                       mask[i] = tr.r
+               } else {
+                       delete(mask, i)
+               }
+       }
+}
+
+// simplifyBlock simplifies block known the restrictions in mask.
+func simplifyBlock(mask map[control]rangeMask, b *Block) {
+       if b.Kind != BlockIf {
+               return
+       }
+
+       tr, has := typeRangeTable[b.Control.Op]
+       if !has {
+               return
+       }
+
+       succ := -1
+       a0 := b.Control.Args[0].ID
+       a1 := b.Control.Args[1].ID
+       if a0 > a1 {
+               tr.r = reverseBits[tr.r]
+               a0, a1 = a1, a0
+       }
+
+       for t := typeMask(1); t <= tr.t; t <<= 1 {
+               if t&tr.t == 0 {
+                       continue
+               }
+
+               // tr.r represents in which case the positive branch is taken.
+               // m.r represents which cases are possible because of previous relations.
+               // If the set of possible relations m.r is included in the set of relations
+               // need to take the positive branch (or negative) then that branch will
+               // always be taken.
+               // For shortcut, if m.r == 0 then this block is dead code.
+               i := control{t, a0, a1}
+               m := mask[i]
+               if m != 0 && tr.r&m == m {
+                       if b.Func.pass.debug > 0 {
+                               b.Func.Config.Warnl(int(b.Line), "Proved %s", b.Control.Op)
+                       }
+                       b.Logf("proved positive branch of %s, block %s in %s\n", b.Control, b, b.Func.Name)
+                       succ = 0
+                       break
+               }
+               if m != 0 && ((lt|eq|gt)^tr.r)&m == m {
+                       if b.Func.pass.debug > 0 {
+                               b.Func.Config.Warnl(int(b.Line), "Disproved %s", b.Control.Op)
+                       }
+                       b.Logf("proved negative branch of %s, block %s in %s\n", b.Control, b, b.Func.Name)
+                       succ = 1
+                       break
+               }
+       }
+
+       if succ == -1 {
+               // HACK: If the first argument of IsInBounds or IsSliceInBounds
+               // is a constant and we already know that constant is smaller (or equal)
+               // to the upper bound than this is proven. Most useful in cases such as:
+               // if len(a) <= 1 { return }
+               // do something with a[1]
+               c := b.Control
+               if (c.Op == OpIsInBounds || c.Op == OpIsSliceInBounds) &&
+                       c.Args[0].Op == OpConst64 && c.Args[0].AuxInt >= 0 {
+                       m := mask[control{signed, a0, a1}]
+                       if m != 0 && tr.r&m == m {
+                               if b.Func.pass.debug > 0 {
+                                       b.Func.Config.Warnl(int(b.Line), "Proved constant %s", c.Op)
+                               }
+                               succ = 0
+                       }
+               }
+       }
+
+       if succ != -1 {
+               b.Kind = BlockFirst
+               b.Control = nil
+               b.Succs[0], b.Succs[1] = b.Succs[succ], b.Succs[1-succ]
+       }
+}
diff --git a/test/prove.go b/test/prove.go
new file mode 100644 (file)
index 0000000..0f5b8ce
--- /dev/null
@@ -0,0 +1,207 @@
+// +build amd64
+// errorcheck -0 -d=ssa/prove/debug=3
+
+package main
+
+func f0(a []int) int {
+       a[0] = 1
+       a[0] = 1 // ERROR "Proved IsInBounds$"
+       a[6] = 1
+       a[6] = 1 // ERROR "Proved IsInBounds$"
+       a[5] = 1
+       a[5] = 1 // ERROR "Proved IsInBounds$"
+       return 13
+}
+
+func f1(a []int) int {
+       if len(a) <= 5 {
+               return 18
+       }
+       a[0] = 1
+       a[0] = 1 // ERROR "Proved IsInBounds$"
+       a[6] = 1
+       a[6] = 1 // ERROR "Proved IsInBounds$"
+       a[5] = 1 // ERROR "Proved constant IsInBounds$"
+       a[5] = 1 // ERROR "Proved IsInBounds$"
+       return 26
+}
+
+func f2(a []int) int {
+       for i := range a {
+               a[i] = i
+               a[i] = i // ERROR "Proved IsInBounds$"
+       }
+       return 34
+}
+
+func f3(a []uint) int {
+       for i := uint(0); i < uint(len(a)); i++ {
+               a[i] = i // ERROR "Proved IsInBounds$"
+       }
+       return 41
+}
+
+func f4a(a, b, c int) int {
+       if a < b {
+               if a == b { // ERROR "Disproved Eq64$"
+                       return 47
+               }
+               if a > b { // ERROR "Disproved Greater64$"
+                       return 50
+               }
+               if a < b { // ERROR "Proved Less64$"
+                       return 53
+               }
+               if a == b { // ERROR "Disproved Eq64$"
+                       return 56
+               }
+               if a > b {
+                       return 59
+               }
+               return 61
+       }
+       return 63
+}
+
+func f4b(a, b, c int) int {
+       if a <= b {
+               if a >= b {
+                       if a == b { // ERROR "Proved Eq64$"
+                               return 70
+                       }
+                       return 75
+               }
+               return 77
+       }
+       return 79
+}
+
+func f4c(a, b, c int) int {
+       if a <= b {
+               if a >= b {
+                       if a != b { // ERROR "Disproved Neq64$"
+                               return 73
+                       }
+                       return 75
+               }
+               return 77
+       }
+       return 79
+}
+
+func f4d(a, b, c int) int {
+       if a < b {
+               if a < c {
+                       if a < b { // ERROR "Proved Less64$"
+                               if a < c { // ERROR "Proved Less64$"
+                                       return 87
+                               }
+                               return 89
+                       }
+                       return 91
+               }
+               return 93
+       }
+       return 95
+}
+
+func f4e(a, b, c int) int {
+       if a < b {
+               if b > a { // ERROR "Proved Greater64$"
+                       return 101
+               }
+               return 103
+       }
+       return 105
+}
+
+func f4f(a, b, c int) int {
+       if a <= b {
+               if b > a {
+                       if b == a { // ERROR "Disproved Eq64$"
+                               return 112
+                       }
+                       return 114
+               }
+               if b >= a { // ERROR "Proved Geq64$"
+                       if b == a { // ERROR "Proved Eq64$"
+                               return 118
+                       }
+                       return 120
+               }
+               return 122
+       }
+       return 124
+}
+
+func f5(a, b uint) int {
+       if a == b {
+               if a <= b { // ERROR "Proved Leq64U$"
+                       return 130
+               }
+               return 132
+       }
+       return 134
+}
+
+// These comparisons are compile time constants.
+func f6a(a uint8) int {
+       if a < a { // ERROR "Disproved Less8U$"
+               return 140
+       }
+       return 151
+}
+
+func f6b(a uint8) int {
+       if a < a { // ERROR "Disproved Less8U$"
+               return 140
+       }
+       return 151
+}
+
+func f6x(a uint8) int {
+       if a > a { // ERROR "Disproved Greater8U$"
+               return 143
+       }
+       return 151
+}
+
+func f6d(a uint8) int {
+       if a <= a { // ERROR "Proved Leq8U$"
+               return 146
+       }
+       return 151
+}
+
+func f6e(a uint8) int {
+       if a >= a { // ERROR "Proved Geq8U$"
+               return 149
+       }
+       return 151
+}
+
+func f7(a []int, b int) int {
+       if b < len(a) {
+               a[b] = 3
+               if b < len(a) { // ERROR "Proved Less64$"
+                       a[b] = 5 // ERROR "Proved IsInBounds$"
+               }
+       }
+       return 161
+}
+
+func f8(a, b uint) int {
+       if a == b {
+               return 166
+       }
+       if a > b {
+               return 169
+       }
+       if a < b { // ERROR "Proved Less64U$"
+               return 172
+       }
+       return 174
+}
+
+func main() {
+}