From: Keith Randall Date: Wed, 23 Mar 2016 17:20:44 +0000 (-0700) Subject: cmd/compile: extend prove pass to handle constant comparisons X-Git-Tag: go1.7beta1~976 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=47c9e139aed3e6d24ee980828835fb03229272ff;p=gostls13.git cmd/compile: extend prove pass to handle constant comparisons Find comparisons to constants and propagate that information down the dominator tree. Use it to resolve other constant comparisons on the same variable. So if we know x >= 7, then a x > 4 condition must return true. This change allows us to use "_ = b[7]" hints to eliminate bounds checks. Fixes #14900 Change-Id: Idbf230bd5b7da43de3ecb48706e21cf01bf812f7 Reviewed-on: https://go-review.googlesource.com/21008 Reviewed-by: Alexandru Moșoi --- diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go index f09a3c5e04..6054541c3b 100644 --- a/src/cmd/compile/internal/ssa/prove.go +++ b/src/cmd/compile/internal/ssa/prove.go @@ -4,6 +4,8 @@ package ssa +import "math" + type branch int const ( @@ -66,30 +68,109 @@ type fact struct { r relation } +// a limit records known upper and lower bounds for a value. +type limit struct { + min, max int64 // min <= value <= max, signed + umin, umax uint64 // umin <= value <= umax, unsigned +} + +var noLimit = limit{math.MinInt64, math.MaxInt64, 0, math.MaxUint64} + +// a limitFact is a limit known for a particular value. +type limitFact struct { + vid ID + limit limit +} + // factsTable keeps track of relations between pairs of values. type factsTable struct { facts map[pair]relation // current known set of relation stack []fact // previous sets of relations + + // known lower and upper bounds on individual values. + limits map[ID]limit + limitStack []limitFact // previous entries } // checkpointFact is an invalid value used for checkpointing // and restoring factsTable. var checkpointFact = fact{} +var checkpointBound = limitFact{} func newFactsTable() *factsTable { ft := &factsTable{} ft.facts = make(map[pair]relation) ft.stack = make([]fact, 4) + ft.limits = make(map[ID]limit) + ft.limitStack = make([]limitFact, 4) return ft } // get returns the known possible relations between v and w. // If v and w are not in the map it returns lt|eq|gt, i.e. any order. func (ft *factsTable) get(v, w *Value, d domain) relation { + if v.isGenericIntConst() || w.isGenericIntConst() { + reversed := false + if v.isGenericIntConst() { + v, w = w, v + reversed = true + } + r := lt | eq | gt + lim, ok := ft.limits[v.ID] + if !ok { + return r + } + c := w.AuxInt + switch d { + case signed: + switch { + case c < lim.min: + r = gt + case c > lim.max: + r = lt + case c == lim.min && c == lim.max: + r = eq + case c == lim.min: + r = gt | eq + case c == lim.max: + r = lt | eq + } + case unsigned: + // TODO: also use signed data if lim.min >= 0? + var uc uint64 + switch w.Op { + case OpConst64: + uc = uint64(c) + case OpConst32: + uc = uint64(uint32(c)) + case OpConst16: + uc = uint64(uint16(c)) + case OpConst8: + uc = uint64(uint8(c)) + } + switch { + case uc < lim.umin: + r = gt + case uc > lim.umax: + r = lt + case uc == lim.umin && uc == lim.umax: + r = eq + case uc == lim.umin: + r = gt | eq + case uc == lim.umax: + r = lt | eq + } + } + if reversed { + return reverseBits[r] + } + return r + } + reversed := false if lessByID(w, v) { v, w = w, v - reversed = true + reversed = !reversed } p := pair{v, w, d} @@ -120,12 +201,106 @@ func (ft *factsTable) update(v, w *Value, d domain, r relation) { oldR := ft.get(v, w, d) ft.stack = append(ft.stack, fact{p, oldR}) ft.facts[p] = oldR & r + + // Extract bounds when comparing against constants + if v.isGenericIntConst() { + v, w = w, v + r = reverseBits[r] + } + if v != nil && w.isGenericIntConst() { + c := w.AuxInt + // Note: all the +1/-1 below could overflow/underflow. Either will + // still generate correct results, it will just lead to imprecision. + // In fact if there is overflow/underflow, the corresponding + // code is unreachable because the known range is outside the range + // of the value's type. + old, ok := ft.limits[v.ID] + if !ok { + old = noLimit + } + lim := old + // Update lim with the new information we know. + switch d { + case signed: + switch r { + case lt: + if c-1 < lim.max { + lim.max = c - 1 + } + case lt | eq: + if c < lim.max { + lim.max = c + } + case gt | eq: + if c > lim.min { + lim.min = c + } + case gt: + if c+1 > lim.min { + lim.min = c + 1 + } + case lt | gt: + if c == lim.min { + lim.min++ + } + if c == lim.max { + lim.max-- + } + case eq: + lim.min = c + lim.max = c + } + case unsigned: + var uc uint64 + switch w.Op { + case OpConst64: + uc = uint64(c) + case OpConst32: + uc = uint64(uint32(c)) + case OpConst16: + uc = uint64(uint16(c)) + case OpConst8: + uc = uint64(uint8(c)) + } + switch r { + case lt: + if uc-1 < lim.umax { + lim.umax = uc - 1 + } + case lt | eq: + if uc < lim.umax { + lim.umax = uc + } + case gt | eq: + if uc > lim.umin { + lim.umin = uc + } + case gt: + if uc+1 > lim.umin { + lim.umin = uc + 1 + } + case lt | gt: + if uc == lim.umin { + lim.umin++ + } + if uc == lim.umax { + lim.umax-- + } + case eq: + lim.umin = uc + lim.umax = uc + } + } + ft.limitStack = append(ft.limitStack, limitFact{v.ID, old}) + ft.limits[v.ID] = lim + } } // checkpoint saves the current state of known relations. // Called when descending on a branch. func (ft *factsTable) checkpoint() { ft.stack = append(ft.stack, checkpointFact) + ft.limitStack = append(ft.limitStack, checkpointBound) } // restore restores known relation to the state just @@ -144,6 +319,18 @@ func (ft *factsTable) restore() { ft.facts[old.p] = old.r } } + for { + old := ft.limitStack[len(ft.limitStack)-1] + ft.limitStack = ft.limitStack[:len(ft.limitStack)-1] + if old.vid == 0 { // checkpointBound + break + } + if old.limit == noLimit { + delete(ft.limits, old.vid) + } else { + ft.limits[old.vid] = old.limit + } + } } func lessByID(v, w *Value) bool { @@ -421,6 +608,7 @@ func simplifyBlock(ft *factsTable, b *Block) branch { // to the upper bound than this is proven. Most useful in cases such as: // if len(a) <= 1 { return } // do something with a[1] + // TODO: use constant bounds to do isNonNegative. if (c.Op == OpIsInBounds || c.Op == OpIsSliceInBounds) && isNonNegative(c.Args[0]) { m := ft.get(a0, a1, signed) if m != 0 && tr.r&m == m { diff --git a/src/cmd/compile/internal/ssa/value.go b/src/cmd/compile/internal/ssa/value.go index baa351169e..fd4eb64db1 100644 --- a/src/cmd/compile/internal/ssa/value.go +++ b/src/cmd/compile/internal/ssa/value.go @@ -218,6 +218,11 @@ func (v *Value) Unimplementedf(msg string, args ...interface{}) { v.Block.Func.Config.Unimplementedf(v.Line, msg, args...) } +// isGenericIntConst returns whether v is a generic integer constant. +func (v *Value) isGenericIntConst() bool { + return v != nil && (v.Op == OpConst64 || v.Op == OpConst32 || v.Op == OpConst16 || v.Op == OpConst8) +} + // ExternSymbol is an aux value that encodes a variable's // constant offset from the static base pointer. type ExternSymbol struct { diff --git a/src/encoding/binary/binary.go b/src/encoding/binary/binary.go index 225ecd7d7a..ada5768695 100644 --- a/src/encoding/binary/binary.go +++ b/src/encoding/binary/binary.go @@ -49,23 +49,23 @@ var BigEndian bigEndian type littleEndian struct{} func (littleEndian) Uint16(b []byte) uint16 { - b = b[:2:len(b)] // bounds check hint to compiler; see golang.org/issue/14808 + _ = b[1] // bounds check hint to compiler; see golang.org/issue/14808 return uint16(b[0]) | uint16(b[1])<<8 } func (littleEndian) PutUint16(b []byte, v uint16) { - b = b[:2:len(b)] // early bounds check to guarantee safety of writes below + _ = b[1] // early bounds check to guarantee safety of writes below b[0] = byte(v) b[1] = byte(v >> 8) } func (littleEndian) Uint32(b []byte) uint32 { - b = b[:4:len(b)] // bounds check hint to compiler; see golang.org/issue/14808 + _ = b[3] // bounds check hint to compiler; see golang.org/issue/14808 return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 } func (littleEndian) PutUint32(b []byte, v uint32) { - b = b[:4:len(b)] // early bounds check to guarantee safety of writes below + _ = b[3] // early bounds check to guarantee safety of writes below b[0] = byte(v) b[1] = byte(v >> 8) b[2] = byte(v >> 16) @@ -73,13 +73,13 @@ func (littleEndian) PutUint32(b []byte, v uint32) { } func (littleEndian) Uint64(b []byte) uint64 { - b = b[:8:len(b)] // bounds check hint to compiler; see golang.org/issue/14808 + _ = b[7] // bounds check hint to compiler; see golang.org/issue/14808 return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56 } func (littleEndian) PutUint64(b []byte, v uint64) { - b = b[:8:len(b)] // early bounds check to guarantee safety of writes below + _ = b[7] // early bounds check to guarantee safety of writes below b[0] = byte(v) b[1] = byte(v >> 8) b[2] = byte(v >> 16) @@ -97,23 +97,23 @@ func (littleEndian) GoString() string { return "binary.LittleEndian" } type bigEndian struct{} func (bigEndian) Uint16(b []byte) uint16 { - b = b[:2:len(b)] // bounds check hint to compiler; see golang.org/issue/14808 + _ = b[1] // bounds check hint to compiler; see golang.org/issue/14808 return uint16(b[1]) | uint16(b[0])<<8 } func (bigEndian) PutUint16(b []byte, v uint16) { - b = b[:2:len(b)] // early bounds check to guarantee safety of writes below + _ = b[1] // early bounds check to guarantee safety of writes below b[0] = byte(v >> 8) b[1] = byte(v) } func (bigEndian) Uint32(b []byte) uint32 { - b = b[:4:len(b)] // bounds check hint to compiler; see golang.org/issue/14808 + _ = b[3] // bounds check hint to compiler; see golang.org/issue/14808 return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24 } func (bigEndian) PutUint32(b []byte, v uint32) { - b = b[:4:len(b)] // early bounds check to guarantee safety of writes below + _ = b[3] // early bounds check to guarantee safety of writes below b[0] = byte(v >> 24) b[1] = byte(v >> 16) b[2] = byte(v >> 8) @@ -121,13 +121,13 @@ func (bigEndian) PutUint32(b []byte, v uint32) { } func (bigEndian) Uint64(b []byte) uint64 { - b = b[:8:len(b)] // bounds check hint to compiler; see golang.org/issue/14808 + _ = b[7] // bounds check hint to compiler; see golang.org/issue/14808 return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 | uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56 } func (bigEndian) PutUint64(b []byte, v uint64) { - b = b[:8:len(b)] // early bounds check to guarantee safety of writes below + _ = b[7] // early bounds check to guarantee safety of writes below b[0] = byte(v >> 56) b[1] = byte(v >> 48) b[2] = byte(v >> 40) diff --git a/test/prove.go b/test/prove.go index e5e5b544cf..fc2908eb03 100644 --- a/test/prove.go +++ b/test/prove.go @@ -3,12 +3,14 @@ package main +import "math" + func f0(a []int) int { a[0] = 1 a[0] = 1 // ERROR "Proved boolean IsInBounds$" a[6] = 1 a[6] = 1 // ERROR "Proved boolean IsInBounds$" - a[5] = 1 + a[5] = 1 // ERROR "Proved IsInBounds$" a[5] = 1 // ERROR "Proved boolean IsInBounds$" return 13 } @@ -17,15 +19,25 @@ func f1(a []int) int { if len(a) <= 5 { return 18 } - a[0] = 1 + a[0] = 1 // ERROR "Proved non-negative bounds IsInBounds$" a[0] = 1 // ERROR "Proved boolean IsInBounds$" a[6] = 1 a[6] = 1 // ERROR "Proved boolean IsInBounds$" - a[5] = 1 // ERROR "Proved non-negative bounds IsInBounds$" + a[5] = 1 // ERROR "Proved IsInBounds$" a[5] = 1 // ERROR "Proved boolean IsInBounds$" return 26 } +func f1b(a []int, i int, j uint) int { + if i >= 0 && i < len(a) { // TODO: handle this case + return a[i] + } + if j < uint(len(a)) { + return a[j] // ERROR "Proved IsInBounds" + } + return 0 +} + func f2(a []int) int { for i := range a { a[i] = i @@ -245,6 +257,164 @@ func f12(a []int, b int) { useSlice(a[:b]) } +func f13a(a, b, c int, x bool) int { + if a > 12 { + if x { + if a < 12 { // ERROR "Disproved Less64$" + return 1 + } + } + if x { + if a <= 12 { // ERROR "Disproved Leq64$" + return 2 + } + } + if x { + if a == 12 { // ERROR "Disproved Eq64$" + return 3 + } + } + if x { + if a >= 12 { // ERROR "Proved Geq64$" + return 4 + } + } + if x { + if a > 12 { // ERROR "Proved boolean Greater64$" + return 5 + } + } + return 6 + } + return 0 +} + +func f13b(a int, x bool) int { + if a == -9 { + if x { + if a < -9 { // ERROR "Disproved Less64$" + return 7 + } + } + if x { + if a <= -9 { // ERROR "Proved Leq64$" + return 8 + } + } + if x { + if a == -9 { // ERROR "Proved boolean Eq64$" + return 9 + } + } + if x { + if a >= -9 { // ERROR "Proved Geq64$" + return 10 + } + } + if x { + if a > -9 { // ERROR "Disproved Greater64$" + return 11 + } + } + return 12 + } + return 0 +} + +func f13c(a int, x bool) int { + if a < 90 { + if x { + if a < 90 { // ERROR "Proved boolean Less64$" + return 13 + } + } + if x { + if a <= 90 { // ERROR "Proved Leq64$" + return 14 + } + } + if x { + if a == 90 { // ERROR "Disproved Eq64$" + return 15 + } + } + if x { + if a >= 90 { // ERROR "Disproved Geq64$" + return 16 + } + } + if x { + if a > 90 { // ERROR "Disproved Greater64$" + return 17 + } + } + return 18 + } + return 0 +} + +func f13d(a int) int { + if a < 5 { + if a < 9 { // ERROR "Proved Less64$" + return 1 + } + } + return 0 +} + +func f13e(a int) int { + if a > 9 { + if a > 5 { // ERROR "Proved Greater64$" + return 1 + } + } + return 0 +} + +func f13f(a int64) int64 { + if a > math.MaxInt64 { + // Unreachable, but prove doesn't know that. + if a == 0 { + return 1 + } + } + return 0 +} + +func f13g(a int) int { + if a < 3 { + return 5 + } + if a > 3 { + return 6 + } + if a == 3 { // ERROR "Proved Eq64$" + return 7 + } + return 8 +} + +func f13h(a int) int { + if a < 3 { + if a > 1 { + if a == 2 { // ERROR "Proved Eq64$" + return 5 + } + } + } + return 0 +} + +func f13i(a uint) int { + if a == 0 { + return 1 + } + if a > 0 { // ERROR "Proved Greater64U$" + return 2 + } + return 3 +} + //go:noinline func useInt(a int) { }