From: Alexandru Moșoi Date: Fri, 19 Feb 2016 11:14:42 +0000 (+0100) Subject: [dev.ssa] cmd/compile/internal/ssa: remove proven redundant controls. X-Git-Tag: go1.7beta1~1623^2^2~13 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=bdea1d58cfc55a5156c8df392cfc3133589389db;p=gostls13.git [dev.ssa] cmd/compile/internal/ssa: remove proven redundant controls. * It does very simple bounds checking elimination. E.g. removes the second check in for i := range a { a[i]++; a[i++]; } * Improves on the following redundant expression: return a6 || (a6 || (a6 || a4)) || (a6 || (a4 || a6 || (false || a6))) * Linear in the number of block edges. I patched in CL 12960 that does bounds, nil and constant propagation to make sure this CL is not just redundant. Size of pkg/tool/linux_amd64/* (excluding compile which is affected by this change): With IsInBounds and IsSliceInBounds -this -12960 92285080 +this -12960 91947416 -this +12960 91978976 +this +12960 91923088 Gain is ~110% of 12960. Without IsInBounds and IsSliceInBounds (older run) -this -12960 95515512 +this -12960 95492536 -this +12960 95216920 +this +12960 95204440 Shaves 22k on its own. * Can we handle IsInBounds better with this? In for i := range a { a[i]++; } the bounds checking at a[i] is not eliminated. Change-Id: I98957427399145fb33693173fd4d5a8d71c7cc20 Reviewed-on: https://go-review.googlesource.com/19710 Reviewed-by: David Chase Reviewed-by: Keith Randall Run-TryBot: Alexandru Moșoi TryBot-Result: Gobot Gobot --- diff --git a/src/cmd/compile/internal/ssa/compile.go b/src/cmd/compile/internal/ssa/compile.go index 23dab9e273..5e68ea004e 100644 --- a/src/cmd/compile/internal/ssa/compile.go +++ b/src/cmd/compile/internal/ssa/compile.go @@ -165,6 +165,7 @@ var passes = [...]pass{ {name: "opt deadcode", fn: deadcode}, // remove any blocks orphaned during opt {name: "generic cse", fn: cse}, {name: "nilcheckelim", fn: nilcheckelim}, + {name: "prove", fn: prove}, {name: "generic deadcode", fn: deadcode}, {name: "fuse", fn: fuse}, {name: "dse", fn: dse}, @@ -193,6 +194,10 @@ type constraint struct { } var passOrder = [...]constraint{ + // prove reliese on common-subexpression elimination for maximum benefits. + {"generic cse", "prove"}, + // deadcode after prove to eliminate all new dead blocks. + {"prove", "generic deadcode"}, // common-subexpression before dead-store elim, so that we recognize // when two address expressions are the same. {"generic cse", "dse"}, diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go new file mode 100644 index 0000000000..f0f4649896 --- /dev/null +++ b/src/cmd/compile/internal/ssa/prove.go @@ -0,0 +1,359 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssa + +// rangeMask represents the possible relations between a pair of variables. +type rangeMask uint + +const ( + lt rangeMask = 1 << iota + eq + gt +) + +// typeMask represents the universe of a variable pair in which +// a set of relations is known. +// For example, information learned for unsigned pairs cannot +// be transfered to signed pairs because the same bit representation +// can mean something else. +type typeMask uint + +const ( + signed typeMask = 1 << iota + unsigned + pointer +) + +type typeRange struct { + t typeMask + r rangeMask +} + +type control struct { + tm typeMask + a0, a1 ID +} + +var ( + reverseBits = [...]rangeMask{0, 4, 2, 6, 1, 5, 3, 7} + + // maps what we learn when the positive branch is taken. + // For example: + // OpLess8: {signed, lt}, + // v1 = (OpLess8 v2 v3). + // If v1 branch is taken than we learn that the rangeMaks + // can be at most lt. + typeRangeTable = map[Op]typeRange{ + OpEq8: {signed | unsigned, eq}, + OpEq16: {signed | unsigned, eq}, + OpEq32: {signed | unsigned, eq}, + OpEq64: {signed | unsigned, eq}, + OpEqPtr: {pointer, eq}, + + OpNeq8: {signed | unsigned, lt | gt}, + OpNeq16: {signed | unsigned, lt | gt}, + OpNeq32: {signed | unsigned, lt | gt}, + OpNeq64: {signed | unsigned, lt | gt}, + OpNeqPtr: {pointer, lt | gt}, + + OpLess8: {signed, lt}, + OpLess8U: {unsigned, lt}, + OpLess16: {signed, lt}, + OpLess16U: {unsigned, lt}, + OpLess32: {signed, lt}, + OpLess32U: {unsigned, lt}, + OpLess64: {signed, lt}, + OpLess64U: {unsigned, lt}, + + OpLeq8: {signed, lt | eq}, + OpLeq8U: {unsigned, lt | eq}, + OpLeq16: {signed, lt | eq}, + OpLeq16U: {unsigned, lt | eq}, + OpLeq32: {signed, lt | eq}, + OpLeq32U: {unsigned, lt | eq}, + OpLeq64: {signed, lt | eq}, + OpLeq64U: {unsigned, lt | eq}, + + OpGeq8: {signed, eq | gt}, + OpGeq8U: {unsigned, eq | gt}, + OpGeq16: {signed, eq | gt}, + OpGeq16U: {unsigned, eq | gt}, + OpGeq32: {signed, eq | gt}, + OpGeq32U: {unsigned, eq | gt}, + OpGeq64: {signed, eq | gt}, + OpGeq64U: {unsigned, eq | gt}, + + OpGreater8: {signed, gt}, + OpGreater8U: {unsigned, gt}, + OpGreater16: {signed, gt}, + OpGreater16U: {unsigned, gt}, + OpGreater32: {signed, gt}, + OpGreater32U: {unsigned, gt}, + OpGreater64: {signed, gt}, + OpGreater64U: {unsigned, gt}, + + // TODO: OpIsInBounds actually test 0 <= a < b. This means + // that the positive branch learns signed/LT and unsigned/LT + // but the negative branch only learns unsigned/GE. + OpIsInBounds: {unsigned, lt}, + OpIsSliceInBounds: {unsigned, lt | eq}, + } +) + +// prove removes redundant BlockIf controls that can be inferred in a straight line. +// +// By far, the most common redundant control are generated by bounds checking. +// For example for the code: +// +// a[i] = 4 +// foo(a[i]) +// +// The compiler will generate the following code: +// +// if i >= len(a) { +// panic("not in bounds") +// } +// a[i] = 4 +// if i >= len(a) { +// panic("not in bounds") +// } +// foo(a[i]) +// +// The second comparison i >= len(a) is clearly redundant because if the +// else branch of the first comparison is executed, we already know that i < len(a). +// The code for the second panic can be removed. +func prove(f *Func) { + idom := dominators(f) + sdom := newSparseTree(f, idom) + domTree := make([][]*Block, f.NumBlocks()) + + // Create a block ID -> [dominees] mapping + for _, b := range f.Blocks { + if dom := idom[b.ID]; dom != nil { + domTree[dom.ID] = append(domTree[dom.ID], b) + } + } + + // current node state + type walkState int + const ( + descend walkState = iota + simplify + ) + // work maintains the DFS stack. + type bp struct { + block *Block // current handled block + state walkState // what's to do + saved []typeRange // save previous map entries modified by node + } + work := make([]bp, 0, 256) + work = append(work, bp{ + block: f.Entry, + state: descend, + }) + + // mask keep tracks of restrictions for each pair of values in + // the dominators for the current node. + // Invariant: a0.ID <= a1.ID + // For example {unsigned, a0, a1} -> eq|gt means that from + // predecessors we know that a0 must be greater or equal to + // a1. + mask := make(map[control]rangeMask) + + // DFS on the dominator tree. + for len(work) > 0 { + node := work[len(work)-1] + work = work[:len(work)-1] + + switch node.state { + case descend: + parent := idom[node.block.ID] + tr := getRestrict(sdom, parent, node.block) + saved := updateRestrictions(mask, parent, tr) + + work = append(work, bp{ + block: node.block, + state: simplify, + saved: saved, + }) + + for _, s := range domTree[node.block.ID] { + work = append(work, bp{ + block: s, + state: descend, + }) + } + + case simplify: + simplifyBlock(mask, node.block) + restoreRestrictions(mask, idom[node.block.ID], node.saved) + } + } +} + +// getRestrict returns the range restrictions added by p +// when reaching b. p is the immediate dominator or b. +func getRestrict(sdom sparseTree, p *Block, b *Block) typeRange { + if p == nil || p.Kind != BlockIf { + return typeRange{} + } + tr, has := typeRangeTable[p.Control.Op] + if !has { + return typeRange{} + } + // If p and p.Succs[0] are dominators it means that every path + // from entry to b passes through p and p.Succs[0]. We care that + // no path from entry to b passes through p.Succs[1]. If p.Succs[0] + // has one predecessor then (apart from the degenerate case), + // there is no path from entry that can reach b through p.Succs[1]. + // TODO: how about p->yes->b->yes, i.e. a loop in yes. + if sdom.isAncestorEq(p.Succs[0], b) && len(p.Succs[0].Preds) == 1 { + return tr + } else if sdom.isAncestorEq(p.Succs[1], b) && len(p.Succs[1].Preds) == 1 { + tr.r = (lt | eq | gt) ^ tr.r + return tr + } + return typeRange{} +} + +// updateRestrictions updates restrictions from the previous block (p) based on tr. +// normally tr was calculated with getRestrict. +func updateRestrictions(mask map[control]rangeMask, p *Block, tr typeRange) []typeRange { + if tr.t == 0 { + return nil + } + + // p modifies the restrictions for (a0, a1). + // save and return the previous state. + a0 := p.Control.Args[0] + a1 := p.Control.Args[1] + if a0.ID > a1.ID { + tr.r = reverseBits[tr.r] + a0, a1 = a1, a0 + } + + saved := make([]typeRange, 0, 2) + for t := typeMask(1); t <= tr.t; t <<= 1 { + if t&tr.t == 0 { + continue + } + + i := control{t, a0.ID, a1.ID} + oldRange, ok := mask[i] + if !ok { + if a1 != a0 { + oldRange = lt | eq | gt + } else { // sometimes happens after cse + oldRange = eq + } + } + // if i was not already in the map we save the full range + // so that when we restore it we properly keep track of it. + saved = append(saved, typeRange{t, oldRange}) + // mask[i] contains the possible relations between a0 and a1. + // When we branched from parent we learned that the possible + // relations cannot be more than tr.r. We compute the new set of + // relations as the intersection betwee the old and the new set. + mask[i] = oldRange & tr.r + } + return saved +} + +func restoreRestrictions(mask map[control]rangeMask, p *Block, saved []typeRange) { + if p == nil || p.Kind != BlockIf || len(saved) == 0 { + return + } + + a0 := p.Control.Args[0].ID + a1 := p.Control.Args[1].ID + if a0 > a1 { + a0, a1 = a1, a0 + } + + for _, tr := range saved { + i := control{tr.t, a0, a1} + if tr.r != lt|eq|gt { + mask[i] = tr.r + } else { + delete(mask, i) + } + } +} + +// simplifyBlock simplifies block known the restrictions in mask. +func simplifyBlock(mask map[control]rangeMask, b *Block) { + if b.Kind != BlockIf { + return + } + + tr, has := typeRangeTable[b.Control.Op] + if !has { + return + } + + succ := -1 + a0 := b.Control.Args[0].ID + a1 := b.Control.Args[1].ID + if a0 > a1 { + tr.r = reverseBits[tr.r] + a0, a1 = a1, a0 + } + + for t := typeMask(1); t <= tr.t; t <<= 1 { + if t&tr.t == 0 { + continue + } + + // tr.r represents in which case the positive branch is taken. + // m.r represents which cases are possible because of previous relations. + // If the set of possible relations m.r is included in the set of relations + // need to take the positive branch (or negative) then that branch will + // always be taken. + // For shortcut, if m.r == 0 then this block is dead code. + i := control{t, a0, a1} + m := mask[i] + if m != 0 && tr.r&m == m { + if b.Func.pass.debug > 0 { + b.Func.Config.Warnl(int(b.Line), "Proved %s", b.Control.Op) + } + b.Logf("proved positive branch of %s, block %s in %s\n", b.Control, b, b.Func.Name) + succ = 0 + break + } + if m != 0 && ((lt|eq|gt)^tr.r)&m == m { + if b.Func.pass.debug > 0 { + b.Func.Config.Warnl(int(b.Line), "Disproved %s", b.Control.Op) + } + b.Logf("proved negative branch of %s, block %s in %s\n", b.Control, b, b.Func.Name) + succ = 1 + break + } + } + + if succ == -1 { + // HACK: If the first argument of IsInBounds or IsSliceInBounds + // is a constant and we already know that constant is smaller (or equal) + // to the upper bound than this is proven. Most useful in cases such as: + // if len(a) <= 1 { return } + // do something with a[1] + c := b.Control + if (c.Op == OpIsInBounds || c.Op == OpIsSliceInBounds) && + c.Args[0].Op == OpConst64 && c.Args[0].AuxInt >= 0 { + m := mask[control{signed, a0, a1}] + if m != 0 && tr.r&m == m { + if b.Func.pass.debug > 0 { + b.Func.Config.Warnl(int(b.Line), "Proved constant %s", c.Op) + } + succ = 0 + } + } + } + + if succ != -1 { + b.Kind = BlockFirst + b.Control = nil + b.Succs[0], b.Succs[1] = b.Succs[succ], b.Succs[1-succ] + } +} diff --git a/test/prove.go b/test/prove.go new file mode 100644 index 0000000000..0f5b8ce87f --- /dev/null +++ b/test/prove.go @@ -0,0 +1,207 @@ +// +build amd64 +// errorcheck -0 -d=ssa/prove/debug=3 + +package main + +func f0(a []int) int { + a[0] = 1 + a[0] = 1 // ERROR "Proved IsInBounds$" + a[6] = 1 + a[6] = 1 // ERROR "Proved IsInBounds$" + a[5] = 1 + a[5] = 1 // ERROR "Proved IsInBounds$" + return 13 +} + +func f1(a []int) int { + if len(a) <= 5 { + return 18 + } + a[0] = 1 + a[0] = 1 // ERROR "Proved IsInBounds$" + a[6] = 1 + a[6] = 1 // ERROR "Proved IsInBounds$" + a[5] = 1 // ERROR "Proved constant IsInBounds$" + a[5] = 1 // ERROR "Proved IsInBounds$" + return 26 +} + +func f2(a []int) int { + for i := range a { + a[i] = i + a[i] = i // ERROR "Proved IsInBounds$" + } + return 34 +} + +func f3(a []uint) int { + for i := uint(0); i < uint(len(a)); i++ { + a[i] = i // ERROR "Proved IsInBounds$" + } + return 41 +} + +func f4a(a, b, c int) int { + if a < b { + if a == b { // ERROR "Disproved Eq64$" + return 47 + } + if a > b { // ERROR "Disproved Greater64$" + return 50 + } + if a < b { // ERROR "Proved Less64$" + return 53 + } + if a == b { // ERROR "Disproved Eq64$" + return 56 + } + if a > b { + return 59 + } + return 61 + } + return 63 +} + +func f4b(a, b, c int) int { + if a <= b { + if a >= b { + if a == b { // ERROR "Proved Eq64$" + return 70 + } + return 75 + } + return 77 + } + return 79 +} + +func f4c(a, b, c int) int { + if a <= b { + if a >= b { + if a != b { // ERROR "Disproved Neq64$" + return 73 + } + return 75 + } + return 77 + } + return 79 +} + +func f4d(a, b, c int) int { + if a < b { + if a < c { + if a < b { // ERROR "Proved Less64$" + if a < c { // ERROR "Proved Less64$" + return 87 + } + return 89 + } + return 91 + } + return 93 + } + return 95 +} + +func f4e(a, b, c int) int { + if a < b { + if b > a { // ERROR "Proved Greater64$" + return 101 + } + return 103 + } + return 105 +} + +func f4f(a, b, c int) int { + if a <= b { + if b > a { + if b == a { // ERROR "Disproved Eq64$" + return 112 + } + return 114 + } + if b >= a { // ERROR "Proved Geq64$" + if b == a { // ERROR "Proved Eq64$" + return 118 + } + return 120 + } + return 122 + } + return 124 +} + +func f5(a, b uint) int { + if a == b { + if a <= b { // ERROR "Proved Leq64U$" + return 130 + } + return 132 + } + return 134 +} + +// These comparisons are compile time constants. +func f6a(a uint8) int { + if a < a { // ERROR "Disproved Less8U$" + return 140 + } + return 151 +} + +func f6b(a uint8) int { + if a < a { // ERROR "Disproved Less8U$" + return 140 + } + return 151 +} + +func f6x(a uint8) int { + if a > a { // ERROR "Disproved Greater8U$" + return 143 + } + return 151 +} + +func f6d(a uint8) int { + if a <= a { // ERROR "Proved Leq8U$" + return 146 + } + return 151 +} + +func f6e(a uint8) int { + if a >= a { // ERROR "Proved Geq8U$" + return 149 + } + return 151 +} + +func f7(a []int, b int) int { + if b < len(a) { + a[b] = 3 + if b < len(a) { // ERROR "Proved Less64$" + a[b] = 5 // ERROR "Proved IsInBounds$" + } + } + return 161 +} + +func f8(a, b uint) int { + if a == b { + return 166 + } + if a > b { + return 169 + } + if a < b { // ERROR "Proved Less64U$" + return 172 + } + return 174 +} + +func main() { +}