--- /dev/null
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssa
+
+// rangeMask represents the possible relations between a pair of variables.
+type rangeMask uint
+
+const (
+ lt rangeMask = 1 << iota
+ eq
+ gt
+)
+
+// typeMask represents the universe of a variable pair in which
+// a set of relations is known.
+// For example, information learned for unsigned pairs cannot
+// be transfered to signed pairs because the same bit representation
+// can mean something else.
+type typeMask uint
+
+const (
+ signed typeMask = 1 << iota
+ unsigned
+ pointer
+)
+
+type typeRange struct {
+ t typeMask
+ r rangeMask
+}
+
+type control struct {
+ tm typeMask
+ a0, a1 ID
+}
+
+var (
+ reverseBits = [...]rangeMask{0, 4, 2, 6, 1, 5, 3, 7}
+
+ // maps what we learn when the positive branch is taken.
+ // For example:
+ // OpLess8: {signed, lt},
+ // v1 = (OpLess8 v2 v3).
+ // If v1 branch is taken than we learn that the rangeMaks
+ // can be at most lt.
+ typeRangeTable = map[Op]typeRange{
+ OpEq8: {signed | unsigned, eq},
+ OpEq16: {signed | unsigned, eq},
+ OpEq32: {signed | unsigned, eq},
+ OpEq64: {signed | unsigned, eq},
+ OpEqPtr: {pointer, eq},
+
+ OpNeq8: {signed | unsigned, lt | gt},
+ OpNeq16: {signed | unsigned, lt | gt},
+ OpNeq32: {signed | unsigned, lt | gt},
+ OpNeq64: {signed | unsigned, lt | gt},
+ OpNeqPtr: {pointer, lt | gt},
+
+ OpLess8: {signed, lt},
+ OpLess8U: {unsigned, lt},
+ OpLess16: {signed, lt},
+ OpLess16U: {unsigned, lt},
+ OpLess32: {signed, lt},
+ OpLess32U: {unsigned, lt},
+ OpLess64: {signed, lt},
+ OpLess64U: {unsigned, lt},
+
+ OpLeq8: {signed, lt | eq},
+ OpLeq8U: {unsigned, lt | eq},
+ OpLeq16: {signed, lt | eq},
+ OpLeq16U: {unsigned, lt | eq},
+ OpLeq32: {signed, lt | eq},
+ OpLeq32U: {unsigned, lt | eq},
+ OpLeq64: {signed, lt | eq},
+ OpLeq64U: {unsigned, lt | eq},
+
+ OpGeq8: {signed, eq | gt},
+ OpGeq8U: {unsigned, eq | gt},
+ OpGeq16: {signed, eq | gt},
+ OpGeq16U: {unsigned, eq | gt},
+ OpGeq32: {signed, eq | gt},
+ OpGeq32U: {unsigned, eq | gt},
+ OpGeq64: {signed, eq | gt},
+ OpGeq64U: {unsigned, eq | gt},
+
+ OpGreater8: {signed, gt},
+ OpGreater8U: {unsigned, gt},
+ OpGreater16: {signed, gt},
+ OpGreater16U: {unsigned, gt},
+ OpGreater32: {signed, gt},
+ OpGreater32U: {unsigned, gt},
+ OpGreater64: {signed, gt},
+ OpGreater64U: {unsigned, gt},
+
+ // TODO: OpIsInBounds actually test 0 <= a < b. This means
+ // that the positive branch learns signed/LT and unsigned/LT
+ // but the negative branch only learns unsigned/GE.
+ OpIsInBounds: {unsigned, lt},
+ OpIsSliceInBounds: {unsigned, lt | eq},
+ }
+)
+
+// prove removes redundant BlockIf controls that can be inferred in a straight line.
+//
+// By far, the most common redundant control are generated by bounds checking.
+// For example for the code:
+//
+// a[i] = 4
+// foo(a[i])
+//
+// The compiler will generate the following code:
+//
+// if i >= len(a) {
+// panic("not in bounds")
+// }
+// a[i] = 4
+// if i >= len(a) {
+// panic("not in bounds")
+// }
+// foo(a[i])
+//
+// The second comparison i >= len(a) is clearly redundant because if the
+// else branch of the first comparison is executed, we already know that i < len(a).
+// The code for the second panic can be removed.
+func prove(f *Func) {
+ idom := dominators(f)
+ sdom := newSparseTree(f, idom)
+ domTree := make([][]*Block, f.NumBlocks())
+
+ // Create a block ID -> [dominees] mapping
+ for _, b := range f.Blocks {
+ if dom := idom[b.ID]; dom != nil {
+ domTree[dom.ID] = append(domTree[dom.ID], b)
+ }
+ }
+
+ // current node state
+ type walkState int
+ const (
+ descend walkState = iota
+ simplify
+ )
+ // work maintains the DFS stack.
+ type bp struct {
+ block *Block // current handled block
+ state walkState // what's to do
+ saved []typeRange // save previous map entries modified by node
+ }
+ work := make([]bp, 0, 256)
+ work = append(work, bp{
+ block: f.Entry,
+ state: descend,
+ })
+
+ // mask keep tracks of restrictions for each pair of values in
+ // the dominators for the current node.
+ // Invariant: a0.ID <= a1.ID
+ // For example {unsigned, a0, a1} -> eq|gt means that from
+ // predecessors we know that a0 must be greater or equal to
+ // a1.
+ mask := make(map[control]rangeMask)
+
+ // DFS on the dominator tree.
+ for len(work) > 0 {
+ node := work[len(work)-1]
+ work = work[:len(work)-1]
+
+ switch node.state {
+ case descend:
+ parent := idom[node.block.ID]
+ tr := getRestrict(sdom, parent, node.block)
+ saved := updateRestrictions(mask, parent, tr)
+
+ work = append(work, bp{
+ block: node.block,
+ state: simplify,
+ saved: saved,
+ })
+
+ for _, s := range domTree[node.block.ID] {
+ work = append(work, bp{
+ block: s,
+ state: descend,
+ })
+ }
+
+ case simplify:
+ simplifyBlock(mask, node.block)
+ restoreRestrictions(mask, idom[node.block.ID], node.saved)
+ }
+ }
+}
+
+// getRestrict returns the range restrictions added by p
+// when reaching b. p is the immediate dominator or b.
+func getRestrict(sdom sparseTree, p *Block, b *Block) typeRange {
+ if p == nil || p.Kind != BlockIf {
+ return typeRange{}
+ }
+ tr, has := typeRangeTable[p.Control.Op]
+ if !has {
+ return typeRange{}
+ }
+ // If p and p.Succs[0] are dominators it means that every path
+ // from entry to b passes through p and p.Succs[0]. We care that
+ // no path from entry to b passes through p.Succs[1]. If p.Succs[0]
+ // has one predecessor then (apart from the degenerate case),
+ // there is no path from entry that can reach b through p.Succs[1].
+ // TODO: how about p->yes->b->yes, i.e. a loop in yes.
+ if sdom.isAncestorEq(p.Succs[0], b) && len(p.Succs[0].Preds) == 1 {
+ return tr
+ } else if sdom.isAncestorEq(p.Succs[1], b) && len(p.Succs[1].Preds) == 1 {
+ tr.r = (lt | eq | gt) ^ tr.r
+ return tr
+ }
+ return typeRange{}
+}
+
+// updateRestrictions updates restrictions from the previous block (p) based on tr.
+// normally tr was calculated with getRestrict.
+func updateRestrictions(mask map[control]rangeMask, p *Block, tr typeRange) []typeRange {
+ if tr.t == 0 {
+ return nil
+ }
+
+ // p modifies the restrictions for (a0, a1).
+ // save and return the previous state.
+ a0 := p.Control.Args[0]
+ a1 := p.Control.Args[1]
+ if a0.ID > a1.ID {
+ tr.r = reverseBits[tr.r]
+ a0, a1 = a1, a0
+ }
+
+ saved := make([]typeRange, 0, 2)
+ for t := typeMask(1); t <= tr.t; t <<= 1 {
+ if t&tr.t == 0 {
+ continue
+ }
+
+ i := control{t, a0.ID, a1.ID}
+ oldRange, ok := mask[i]
+ if !ok {
+ if a1 != a0 {
+ oldRange = lt | eq | gt
+ } else { // sometimes happens after cse
+ oldRange = eq
+ }
+ }
+ // if i was not already in the map we save the full range
+ // so that when we restore it we properly keep track of it.
+ saved = append(saved, typeRange{t, oldRange})
+ // mask[i] contains the possible relations between a0 and a1.
+ // When we branched from parent we learned that the possible
+ // relations cannot be more than tr.r. We compute the new set of
+ // relations as the intersection betwee the old and the new set.
+ mask[i] = oldRange & tr.r
+ }
+ return saved
+}
+
+func restoreRestrictions(mask map[control]rangeMask, p *Block, saved []typeRange) {
+ if p == nil || p.Kind != BlockIf || len(saved) == 0 {
+ return
+ }
+
+ a0 := p.Control.Args[0].ID
+ a1 := p.Control.Args[1].ID
+ if a0 > a1 {
+ a0, a1 = a1, a0
+ }
+
+ for _, tr := range saved {
+ i := control{tr.t, a0, a1}
+ if tr.r != lt|eq|gt {
+ mask[i] = tr.r
+ } else {
+ delete(mask, i)
+ }
+ }
+}
+
+// simplifyBlock simplifies block known the restrictions in mask.
+func simplifyBlock(mask map[control]rangeMask, b *Block) {
+ if b.Kind != BlockIf {
+ return
+ }
+
+ tr, has := typeRangeTable[b.Control.Op]
+ if !has {
+ return
+ }
+
+ succ := -1
+ a0 := b.Control.Args[0].ID
+ a1 := b.Control.Args[1].ID
+ if a0 > a1 {
+ tr.r = reverseBits[tr.r]
+ a0, a1 = a1, a0
+ }
+
+ for t := typeMask(1); t <= tr.t; t <<= 1 {
+ if t&tr.t == 0 {
+ continue
+ }
+
+ // tr.r represents in which case the positive branch is taken.
+ // m.r represents which cases are possible because of previous relations.
+ // If the set of possible relations m.r is included in the set of relations
+ // need to take the positive branch (or negative) then that branch will
+ // always be taken.
+ // For shortcut, if m.r == 0 then this block is dead code.
+ i := control{t, a0, a1}
+ m := mask[i]
+ if m != 0 && tr.r&m == m {
+ if b.Func.pass.debug > 0 {
+ b.Func.Config.Warnl(int(b.Line), "Proved %s", b.Control.Op)
+ }
+ b.Logf("proved positive branch of %s, block %s in %s\n", b.Control, b, b.Func.Name)
+ succ = 0
+ break
+ }
+ if m != 0 && ((lt|eq|gt)^tr.r)&m == m {
+ if b.Func.pass.debug > 0 {
+ b.Func.Config.Warnl(int(b.Line), "Disproved %s", b.Control.Op)
+ }
+ b.Logf("proved negative branch of %s, block %s in %s\n", b.Control, b, b.Func.Name)
+ succ = 1
+ break
+ }
+ }
+
+ if succ == -1 {
+ // HACK: If the first argument of IsInBounds or IsSliceInBounds
+ // is a constant and we already know that constant is smaller (or equal)
+ // to the upper bound than this is proven. Most useful in cases such as:
+ // if len(a) <= 1 { return }
+ // do something with a[1]
+ c := b.Control
+ if (c.Op == OpIsInBounds || c.Op == OpIsSliceInBounds) &&
+ c.Args[0].Op == OpConst64 && c.Args[0].AuxInt >= 0 {
+ m := mask[control{signed, a0, a1}]
+ if m != 0 && tr.r&m == m {
+ if b.Func.pass.debug > 0 {
+ b.Func.Config.Warnl(int(b.Line), "Proved constant %s", c.Op)
+ }
+ succ = 0
+ }
+ }
+ }
+
+ if succ != -1 {
+ b.Kind = BlockFirst
+ b.Control = nil
+ b.Succs[0], b.Succs[1] = b.Succs[succ], b.Succs[1-succ]
+ }
+}
--- /dev/null
+// +build amd64
+// errorcheck -0 -d=ssa/prove/debug=3
+
+package main
+
+func f0(a []int) int {
+ a[0] = 1
+ a[0] = 1 // ERROR "Proved IsInBounds$"
+ a[6] = 1
+ a[6] = 1 // ERROR "Proved IsInBounds$"
+ a[5] = 1
+ a[5] = 1 // ERROR "Proved IsInBounds$"
+ return 13
+}
+
+func f1(a []int) int {
+ if len(a) <= 5 {
+ return 18
+ }
+ a[0] = 1
+ a[0] = 1 // ERROR "Proved IsInBounds$"
+ a[6] = 1
+ a[6] = 1 // ERROR "Proved IsInBounds$"
+ a[5] = 1 // ERROR "Proved constant IsInBounds$"
+ a[5] = 1 // ERROR "Proved IsInBounds$"
+ return 26
+}
+
+func f2(a []int) int {
+ for i := range a {
+ a[i] = i
+ a[i] = i // ERROR "Proved IsInBounds$"
+ }
+ return 34
+}
+
+func f3(a []uint) int {
+ for i := uint(0); i < uint(len(a)); i++ {
+ a[i] = i // ERROR "Proved IsInBounds$"
+ }
+ return 41
+}
+
+func f4a(a, b, c int) int {
+ if a < b {
+ if a == b { // ERROR "Disproved Eq64$"
+ return 47
+ }
+ if a > b { // ERROR "Disproved Greater64$"
+ return 50
+ }
+ if a < b { // ERROR "Proved Less64$"
+ return 53
+ }
+ if a == b { // ERROR "Disproved Eq64$"
+ return 56
+ }
+ if a > b {
+ return 59
+ }
+ return 61
+ }
+ return 63
+}
+
+func f4b(a, b, c int) int {
+ if a <= b {
+ if a >= b {
+ if a == b { // ERROR "Proved Eq64$"
+ return 70
+ }
+ return 75
+ }
+ return 77
+ }
+ return 79
+}
+
+func f4c(a, b, c int) int {
+ if a <= b {
+ if a >= b {
+ if a != b { // ERROR "Disproved Neq64$"
+ return 73
+ }
+ return 75
+ }
+ return 77
+ }
+ return 79
+}
+
+func f4d(a, b, c int) int {
+ if a < b {
+ if a < c {
+ if a < b { // ERROR "Proved Less64$"
+ if a < c { // ERROR "Proved Less64$"
+ return 87
+ }
+ return 89
+ }
+ return 91
+ }
+ return 93
+ }
+ return 95
+}
+
+func f4e(a, b, c int) int {
+ if a < b {
+ if b > a { // ERROR "Proved Greater64$"
+ return 101
+ }
+ return 103
+ }
+ return 105
+}
+
+func f4f(a, b, c int) int {
+ if a <= b {
+ if b > a {
+ if b == a { // ERROR "Disproved Eq64$"
+ return 112
+ }
+ return 114
+ }
+ if b >= a { // ERROR "Proved Geq64$"
+ if b == a { // ERROR "Proved Eq64$"
+ return 118
+ }
+ return 120
+ }
+ return 122
+ }
+ return 124
+}
+
+func f5(a, b uint) int {
+ if a == b {
+ if a <= b { // ERROR "Proved Leq64U$"
+ return 130
+ }
+ return 132
+ }
+ return 134
+}
+
+// These comparisons are compile time constants.
+func f6a(a uint8) int {
+ if a < a { // ERROR "Disproved Less8U$"
+ return 140
+ }
+ return 151
+}
+
+func f6b(a uint8) int {
+ if a < a { // ERROR "Disproved Less8U$"
+ return 140
+ }
+ return 151
+}
+
+func f6x(a uint8) int {
+ if a > a { // ERROR "Disproved Greater8U$"
+ return 143
+ }
+ return 151
+}
+
+func f6d(a uint8) int {
+ if a <= a { // ERROR "Proved Leq8U$"
+ return 146
+ }
+ return 151
+}
+
+func f6e(a uint8) int {
+ if a >= a { // ERROR "Proved Geq8U$"
+ return 149
+ }
+ return 151
+}
+
+func f7(a []int, b int) int {
+ if b < len(a) {
+ a[b] = 3
+ if b < len(a) { // ERROR "Proved Less64$"
+ a[b] = 5 // ERROR "Proved IsInBounds$"
+ }
+ }
+ return 161
+}
+
+func f8(a, b uint) int {
+ if a == b {
+ return 166
+ }
+ if a > b {
+ return 169
+ }
+ if a < b { // ERROR "Proved Less64U$"
+ return 172
+ }
+ return 174
+}
+
+func main() {
+}