]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: extend prove pass to handle constant comparisons
authorKeith Randall <khr@golang.org>
Wed, 23 Mar 2016 17:20:44 +0000 (10:20 -0700)
committerKeith Randall <khr@golang.org>
Thu, 31 Mar 2016 21:16:23 +0000 (21:16 +0000)
Find comparisons to constants and propagate that information
down the dominator tree.  Use it to resolve other constant
comparisons on the same variable.

So if we know x >= 7, then a x > 4 condition must return true.

This change allows us to use "_ = b[7]" hints to eliminate bounds checks.

Fixes #14900

Change-Id: Idbf230bd5b7da43de3ecb48706e21cf01bf812f7
Reviewed-on: https://go-review.googlesource.com/21008
Reviewed-by: Alexandru Moșoi <alexandru@mosoi.ro>
src/cmd/compile/internal/ssa/prove.go
src/cmd/compile/internal/ssa/value.go
src/encoding/binary/binary.go
test/prove.go

index f09a3c5e04230358e615782699c08e66ab12b6e1..6054541c3b35011881b205d0e5ae4fb381f6f1d5 100644 (file)
@@ -4,6 +4,8 @@
 
 package ssa
 
+import "math"
+
 type branch int
 
 const (
@@ -66,30 +68,109 @@ type fact struct {
        r relation
 }
 
+// a limit records known upper and lower bounds for a value.
+type limit struct {
+       min, max   int64  // min <= value <= max, signed
+       umin, umax uint64 // umin <= value <= umax, unsigned
+}
+
+var noLimit = limit{math.MinInt64, math.MaxInt64, 0, math.MaxUint64}
+
+// a limitFact is a limit known for a particular value.
+type limitFact struct {
+       vid   ID
+       limit limit
+}
+
 // factsTable keeps track of relations between pairs of values.
 type factsTable struct {
        facts map[pair]relation // current known set of relation
        stack []fact            // previous sets of relations
+
+       // known lower and upper bounds on individual values.
+       limits     map[ID]limit
+       limitStack []limitFact // previous entries
 }
 
 // checkpointFact is an invalid value used for checkpointing
 // and restoring factsTable.
 var checkpointFact = fact{}
+var checkpointBound = limitFact{}
 
 func newFactsTable() *factsTable {
        ft := &factsTable{}
        ft.facts = make(map[pair]relation)
        ft.stack = make([]fact, 4)
+       ft.limits = make(map[ID]limit)
+       ft.limitStack = make([]limitFact, 4)
        return ft
 }
 
 // get returns the known possible relations between v and w.
 // If v and w are not in the map it returns lt|eq|gt, i.e. any order.
 func (ft *factsTable) get(v, w *Value, d domain) relation {
+       if v.isGenericIntConst() || w.isGenericIntConst() {
+               reversed := false
+               if v.isGenericIntConst() {
+                       v, w = w, v
+                       reversed = true
+               }
+               r := lt | eq | gt
+               lim, ok := ft.limits[v.ID]
+               if !ok {
+                       return r
+               }
+               c := w.AuxInt
+               switch d {
+               case signed:
+                       switch {
+                       case c < lim.min:
+                               r = gt
+                       case c > lim.max:
+                               r = lt
+                       case c == lim.min && c == lim.max:
+                               r = eq
+                       case c == lim.min:
+                               r = gt | eq
+                       case c == lim.max:
+                               r = lt | eq
+                       }
+               case unsigned:
+                       // TODO: also use signed data if lim.min >= 0?
+                       var uc uint64
+                       switch w.Op {
+                       case OpConst64:
+                               uc = uint64(c)
+                       case OpConst32:
+                               uc = uint64(uint32(c))
+                       case OpConst16:
+                               uc = uint64(uint16(c))
+                       case OpConst8:
+                               uc = uint64(uint8(c))
+                       }
+                       switch {
+                       case uc < lim.umin:
+                               r = gt
+                       case uc > lim.umax:
+                               r = lt
+                       case uc == lim.umin && uc == lim.umax:
+                               r = eq
+                       case uc == lim.umin:
+                               r = gt | eq
+                       case uc == lim.umax:
+                               r = lt | eq
+                       }
+               }
+               if reversed {
+                       return reverseBits[r]
+               }
+               return r
+       }
+
        reversed := false
        if lessByID(w, v) {
                v, w = w, v
-               reversed = true
+               reversed = !reversed
        }
 
        p := pair{v, w, d}
@@ -120,12 +201,106 @@ func (ft *factsTable) update(v, w *Value, d domain, r relation) {
        oldR := ft.get(v, w, d)
        ft.stack = append(ft.stack, fact{p, oldR})
        ft.facts[p] = oldR & r
+
+       // Extract bounds when comparing against constants
+       if v.isGenericIntConst() {
+               v, w = w, v
+               r = reverseBits[r]
+       }
+       if v != nil && w.isGenericIntConst() {
+               c := w.AuxInt
+               // Note: all the +1/-1 below could overflow/underflow. Either will
+               // still generate correct results, it will just lead to imprecision.
+               // In fact if there is overflow/underflow, the corresponding
+               // code is unreachable because the known range is outside the range
+               // of the value's type.
+               old, ok := ft.limits[v.ID]
+               if !ok {
+                       old = noLimit
+               }
+               lim := old
+               // Update lim with the new information we know.
+               switch d {
+               case signed:
+                       switch r {
+                       case lt:
+                               if c-1 < lim.max {
+                                       lim.max = c - 1
+                               }
+                       case lt | eq:
+                               if c < lim.max {
+                                       lim.max = c
+                               }
+                       case gt | eq:
+                               if c > lim.min {
+                                       lim.min = c
+                               }
+                       case gt:
+                               if c+1 > lim.min {
+                                       lim.min = c + 1
+                               }
+                       case lt | gt:
+                               if c == lim.min {
+                                       lim.min++
+                               }
+                               if c == lim.max {
+                                       lim.max--
+                               }
+                       case eq:
+                               lim.min = c
+                               lim.max = c
+                       }
+               case unsigned:
+                       var uc uint64
+                       switch w.Op {
+                       case OpConst64:
+                               uc = uint64(c)
+                       case OpConst32:
+                               uc = uint64(uint32(c))
+                       case OpConst16:
+                               uc = uint64(uint16(c))
+                       case OpConst8:
+                               uc = uint64(uint8(c))
+                       }
+                       switch r {
+                       case lt:
+                               if uc-1 < lim.umax {
+                                       lim.umax = uc - 1
+                               }
+                       case lt | eq:
+                               if uc < lim.umax {
+                                       lim.umax = uc
+                               }
+                       case gt | eq:
+                               if uc > lim.umin {
+                                       lim.umin = uc
+                               }
+                       case gt:
+                               if uc+1 > lim.umin {
+                                       lim.umin = uc + 1
+                               }
+                       case lt | gt:
+                               if uc == lim.umin {
+                                       lim.umin++
+                               }
+                               if uc == lim.umax {
+                                       lim.umax--
+                               }
+                       case eq:
+                               lim.umin = uc
+                               lim.umax = uc
+                       }
+               }
+               ft.limitStack = append(ft.limitStack, limitFact{v.ID, old})
+               ft.limits[v.ID] = lim
+       }
 }
 
 // checkpoint saves the current state of known relations.
 // Called when descending on a branch.
 func (ft *factsTable) checkpoint() {
        ft.stack = append(ft.stack, checkpointFact)
+       ft.limitStack = append(ft.limitStack, checkpointBound)
 }
 
 // restore restores known relation to the state just
@@ -144,6 +319,18 @@ func (ft *factsTable) restore() {
                        ft.facts[old.p] = old.r
                }
        }
+       for {
+               old := ft.limitStack[len(ft.limitStack)-1]
+               ft.limitStack = ft.limitStack[:len(ft.limitStack)-1]
+               if old.vid == 0 { // checkpointBound
+                       break
+               }
+               if old.limit == noLimit {
+                       delete(ft.limits, old.vid)
+               } else {
+                       ft.limits[old.vid] = old.limit
+               }
+       }
 }
 
 func lessByID(v, w *Value) bool {
@@ -421,6 +608,7 @@ func simplifyBlock(ft *factsTable, b *Block) branch {
        // to the upper bound than this is proven. Most useful in cases such as:
        // if len(a) <= 1 { return }
        // do something with a[1]
+       // TODO: use constant bounds to do isNonNegative.
        if (c.Op == OpIsInBounds || c.Op == OpIsSliceInBounds) && isNonNegative(c.Args[0]) {
                m := ft.get(a0, a1, signed)
                if m != 0 && tr.r&m == m {
index baa351169e172e83835a30e417eb5821f8c4da68..fd4eb64db17ce40e864151df0f6a7477673e7e98 100644 (file)
@@ -218,6 +218,11 @@ func (v *Value) Unimplementedf(msg string, args ...interface{}) {
        v.Block.Func.Config.Unimplementedf(v.Line, msg, args...)
 }
 
+// isGenericIntConst returns whether v is a generic integer constant.
+func (v *Value) isGenericIntConst() bool {
+       return v != nil && (v.Op == OpConst64 || v.Op == OpConst32 || v.Op == OpConst16 || v.Op == OpConst8)
+}
+
 // ExternSymbol is an aux value that encodes a variable's
 // constant offset from the static base pointer.
 type ExternSymbol struct {
index 225ecd7d7a5d4b8210e4ff6482291207262d2bbd..ada5768695681be6d46a52d0284d2601d57bdb1a 100644 (file)
@@ -49,23 +49,23 @@ var BigEndian bigEndian
 type littleEndian struct{}
 
 func (littleEndian) Uint16(b []byte) uint16 {
-       b = b[:2:len(b)] // bounds check hint to compiler; see golang.org/issue/14808
+       _ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
        return uint16(b[0]) | uint16(b[1])<<8
 }
 
 func (littleEndian) PutUint16(b []byte, v uint16) {
-       b = b[:2:len(b)] // early bounds check to guarantee safety of writes below
+       _ = b[1] // early bounds check to guarantee safety of writes below
        b[0] = byte(v)
        b[1] = byte(v >> 8)
 }
 
 func (littleEndian) Uint32(b []byte) uint32 {
-       b = b[:4:len(b)] // bounds check hint to compiler; see golang.org/issue/14808
+       _ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
        return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
 }
 
 func (littleEndian) PutUint32(b []byte, v uint32) {
-       b = b[:4:len(b)] // early bounds check to guarantee safety of writes below
+       _ = b[3] // early bounds check to guarantee safety of writes below
        b[0] = byte(v)
        b[1] = byte(v >> 8)
        b[2] = byte(v >> 16)
@@ -73,13 +73,13 @@ func (littleEndian) PutUint32(b []byte, v uint32) {
 }
 
 func (littleEndian) Uint64(b []byte) uint64 {
-       b = b[:8:len(b)] // bounds check hint to compiler; see golang.org/issue/14808
+       _ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
        return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
                uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
 }
 
 func (littleEndian) PutUint64(b []byte, v uint64) {
-       b = b[:8:len(b)] // early bounds check to guarantee safety of writes below
+       _ = b[7] // early bounds check to guarantee safety of writes below
        b[0] = byte(v)
        b[1] = byte(v >> 8)
        b[2] = byte(v >> 16)
@@ -97,23 +97,23 @@ func (littleEndian) GoString() string { return "binary.LittleEndian" }
 type bigEndian struct{}
 
 func (bigEndian) Uint16(b []byte) uint16 {
-       b = b[:2:len(b)] // bounds check hint to compiler; see golang.org/issue/14808
+       _ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
        return uint16(b[1]) | uint16(b[0])<<8
 }
 
 func (bigEndian) PutUint16(b []byte, v uint16) {
-       b = b[:2:len(b)] // early bounds check to guarantee safety of writes below
+       _ = b[1] // early bounds check to guarantee safety of writes below
        b[0] = byte(v >> 8)
        b[1] = byte(v)
 }
 
 func (bigEndian) Uint32(b []byte) uint32 {
-       b = b[:4:len(b)] // bounds check hint to compiler; see golang.org/issue/14808
+       _ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
        return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
 }
 
 func (bigEndian) PutUint32(b []byte, v uint32) {
-       b = b[:4:len(b)] // early bounds check to guarantee safety of writes below
+       _ = b[3] // early bounds check to guarantee safety of writes below
        b[0] = byte(v >> 24)
        b[1] = byte(v >> 16)
        b[2] = byte(v >> 8)
@@ -121,13 +121,13 @@ func (bigEndian) PutUint32(b []byte, v uint32) {
 }
 
 func (bigEndian) Uint64(b []byte) uint64 {
-       b = b[:8:len(b)] // bounds check hint to compiler; see golang.org/issue/14808
+       _ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
        return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
                uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
 }
 
 func (bigEndian) PutUint64(b []byte, v uint64) {
-       b = b[:8:len(b)] // early bounds check to guarantee safety of writes below
+       _ = b[7] // early bounds check to guarantee safety of writes below
        b[0] = byte(v >> 56)
        b[1] = byte(v >> 48)
        b[2] = byte(v >> 40)
index e5e5b544cf3f4d7a74d23e18c9c0df6ca67451f2..fc2908eb033a165050a412651bdb678e7d27a792 100644 (file)
@@ -3,12 +3,14 @@
 
 package main
 
+import "math"
+
 func f0(a []int) int {
        a[0] = 1
        a[0] = 1 // ERROR "Proved boolean IsInBounds$"
        a[6] = 1
        a[6] = 1 // ERROR "Proved boolean IsInBounds$"
-       a[5] = 1
+       a[5] = 1 // ERROR "Proved IsInBounds$"
        a[5] = 1 // ERROR "Proved boolean IsInBounds$"
        return 13
 }
@@ -17,15 +19,25 @@ func f1(a []int) int {
        if len(a) <= 5 {
                return 18
        }
-       a[0] = 1
+       a[0] = 1 // ERROR "Proved non-negative bounds IsInBounds$"
        a[0] = 1 // ERROR "Proved boolean IsInBounds$"
        a[6] = 1
        a[6] = 1 // ERROR "Proved boolean IsInBounds$"
-       a[5] = 1 // ERROR "Proved non-negative bounds IsInBounds$"
+       a[5] = 1 // ERROR "Proved IsInBounds$"
        a[5] = 1 // ERROR "Proved boolean IsInBounds$"
        return 26
 }
 
+func f1b(a []int, i int, j uint) int {
+       if i >= 0 && i < len(a) { // TODO: handle this case
+               return a[i]
+       }
+       if j < uint(len(a)) {
+               return a[j] // ERROR "Proved IsInBounds"
+       }
+       return 0
+}
+
 func f2(a []int) int {
        for i := range a {
                a[i] = i
@@ -245,6 +257,164 @@ func f12(a []int, b int) {
        useSlice(a[:b])
 }
 
+func f13a(a, b, c int, x bool) int {
+       if a > 12 {
+               if x {
+                       if a < 12 { // ERROR "Disproved Less64$"
+                               return 1
+                       }
+               }
+               if x {
+                       if a <= 12 { // ERROR "Disproved Leq64$"
+                               return 2
+                       }
+               }
+               if x {
+                       if a == 12 { // ERROR "Disproved Eq64$"
+                               return 3
+                       }
+               }
+               if x {
+                       if a >= 12 { // ERROR "Proved Geq64$"
+                               return 4
+                       }
+               }
+               if x {
+                       if a > 12 { // ERROR "Proved boolean Greater64$"
+                               return 5
+                       }
+               }
+               return 6
+       }
+       return 0
+}
+
+func f13b(a int, x bool) int {
+       if a == -9 {
+               if x {
+                       if a < -9 { // ERROR "Disproved Less64$"
+                               return 7
+                       }
+               }
+               if x {
+                       if a <= -9 { // ERROR "Proved Leq64$"
+                               return 8
+                       }
+               }
+               if x {
+                       if a == -9 { // ERROR "Proved boolean Eq64$"
+                               return 9
+                       }
+               }
+               if x {
+                       if a >= -9 { // ERROR "Proved Geq64$"
+                               return 10
+                       }
+               }
+               if x {
+                       if a > -9 { // ERROR "Disproved Greater64$"
+                               return 11
+                       }
+               }
+               return 12
+       }
+       return 0
+}
+
+func f13c(a int, x bool) int {
+       if a < 90 {
+               if x {
+                       if a < 90 { // ERROR "Proved boolean Less64$"
+                               return 13
+                       }
+               }
+               if x {
+                       if a <= 90 { // ERROR "Proved Leq64$"
+                               return 14
+                       }
+               }
+               if x {
+                       if a == 90 { // ERROR "Disproved Eq64$"
+                               return 15
+                       }
+               }
+               if x {
+                       if a >= 90 { // ERROR "Disproved Geq64$"
+                               return 16
+                       }
+               }
+               if x {
+                       if a > 90 { // ERROR "Disproved Greater64$"
+                               return 17
+                       }
+               }
+               return 18
+       }
+       return 0
+}
+
+func f13d(a int) int {
+       if a < 5 {
+               if a < 9 { // ERROR "Proved Less64$"
+                       return 1
+               }
+       }
+       return 0
+}
+
+func f13e(a int) int {
+       if a > 9 {
+               if a > 5 { // ERROR "Proved Greater64$"
+                       return 1
+               }
+       }
+       return 0
+}
+
+func f13f(a int64) int64 {
+       if a > math.MaxInt64 {
+               // Unreachable, but prove doesn't know that.
+               if a == 0 {
+                       return 1
+               }
+       }
+       return 0
+}
+
+func f13g(a int) int {
+       if a < 3 {
+               return 5
+       }
+       if a > 3 {
+               return 6
+       }
+       if a == 3 { // ERROR "Proved Eq64$"
+               return 7
+       }
+       return 8
+}
+
+func f13h(a int) int {
+       if a < 3 {
+               if a > 1 {
+                       if a == 2 { // ERROR "Proved Eq64$"
+                               return 5
+                       }
+               }
+       }
+       return 0
+}
+
+func f13i(a uint) int {
+       if a == 0 {
+               return 1
+       }
+       if a > 0 { // ERROR "Proved Greater64U$"
+               return 2
+       }
+       return 3
+}
+
 //go:noinline
 func useInt(a int) {
 }