From 47c9e139aed3e6d24ee980828835fb03229272ff Mon Sep 17 00:00:00 2001 From: Keith Randall Date: Wed, 23 Mar 2016 10:20:44 -0700 Subject: [PATCH] cmd/compile: extend prove pass to handle constant comparisons MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Find comparisons to constants and propagate that information down the dominator tree. Use it to resolve other constant comparisons on the same variable. So if we know x >= 7, then a x > 4 condition must return true. This change allows us to use "_ = b[7]" hints to eliminate bounds checks. Fixes #14900 Change-Id: Idbf230bd5b7da43de3ecb48706e21cf01bf812f7 Reviewed-on: https://go-review.googlesource.com/21008 Reviewed-by: Alexandru Moșoi --- src/cmd/compile/internal/ssa/prove.go | 190 +++++++++++++++++++++++++- src/cmd/compile/internal/ssa/value.go | 5 + src/encoding/binary/binary.go | 24 ++-- test/prove.go | 176 +++++++++++++++++++++++- 4 files changed, 379 insertions(+), 16 deletions(-) diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go index f09a3c5e04..6054541c3b 100644 --- a/src/cmd/compile/internal/ssa/prove.go +++ b/src/cmd/compile/internal/ssa/prove.go @@ -4,6 +4,8 @@ package ssa +import "math" + type branch int const ( @@ -66,30 +68,109 @@ type fact struct { r relation } +// a limit records known upper and lower bounds for a value. +type limit struct { + min, max int64 // min <= value <= max, signed + umin, umax uint64 // umin <= value <= umax, unsigned +} + +var noLimit = limit{math.MinInt64, math.MaxInt64, 0, math.MaxUint64} + +// a limitFact is a limit known for a particular value. +type limitFact struct { + vid ID + limit limit +} + // factsTable keeps track of relations between pairs of values. type factsTable struct { facts map[pair]relation // current known set of relation stack []fact // previous sets of relations + + // known lower and upper bounds on individual values. + limits map[ID]limit + limitStack []limitFact // previous entries } // checkpointFact is an invalid value used for checkpointing // and restoring factsTable. var checkpointFact = fact{} +var checkpointBound = limitFact{} func newFactsTable() *factsTable { ft := &factsTable{} ft.facts = make(map[pair]relation) ft.stack = make([]fact, 4) + ft.limits = make(map[ID]limit) + ft.limitStack = make([]limitFact, 4) return ft } // get returns the known possible relations between v and w. // If v and w are not in the map it returns lt|eq|gt, i.e. any order. func (ft *factsTable) get(v, w *Value, d domain) relation { + if v.isGenericIntConst() || w.isGenericIntConst() { + reversed := false + if v.isGenericIntConst() { + v, w = w, v + reversed = true + } + r := lt | eq | gt + lim, ok := ft.limits[v.ID] + if !ok { + return r + } + c := w.AuxInt + switch d { + case signed: + switch { + case c < lim.min: + r = gt + case c > lim.max: + r = lt + case c == lim.min && c == lim.max: + r = eq + case c == lim.min: + r = gt | eq + case c == lim.max: + r = lt | eq + } + case unsigned: + // TODO: also use signed data if lim.min >= 0? + var uc uint64 + switch w.Op { + case OpConst64: + uc = uint64(c) + case OpConst32: + uc = uint64(uint32(c)) + case OpConst16: + uc = uint64(uint16(c)) + case OpConst8: + uc = uint64(uint8(c)) + } + switch { + case uc < lim.umin: + r = gt + case uc > lim.umax: + r = lt + case uc == lim.umin && uc == lim.umax: + r = eq + case uc == lim.umin: + r = gt | eq + case uc == lim.umax: + r = lt | eq + } + } + if reversed { + return reverseBits[r] + } + return r + } + reversed := false if lessByID(w, v) { v, w = w, v - reversed = true + reversed = !reversed } p := pair{v, w, d} @@ -120,12 +201,106 @@ func (ft *factsTable) update(v, w *Value, d domain, r relation) { oldR := ft.get(v, w, d) ft.stack = append(ft.stack, fact{p, oldR}) ft.facts[p] = oldR & r + + // Extract bounds when comparing against constants + if v.isGenericIntConst() { + v, w = w, v + r = reverseBits[r] + } + if v != nil && w.isGenericIntConst() { + c := w.AuxInt + // Note: all the +1/-1 below could overflow/underflow. Either will + // still generate correct results, it will just lead to imprecision. + // In fact if there is overflow/underflow, the corresponding + // code is unreachable because the known range is outside the range + // of the value's type. + old, ok := ft.limits[v.ID] + if !ok { + old = noLimit + } + lim := old + // Update lim with the new information we know. + switch d { + case signed: + switch r { + case lt: + if c-1 < lim.max { + lim.max = c - 1 + } + case lt | eq: + if c < lim.max { + lim.max = c + } + case gt | eq: + if c > lim.min { + lim.min = c + } + case gt: + if c+1 > lim.min { + lim.min = c + 1 + } + case lt | gt: + if c == lim.min { + lim.min++ + } + if c == lim.max { + lim.max-- + } + case eq: + lim.min = c + lim.max = c + } + case unsigned: + var uc uint64 + switch w.Op { + case OpConst64: + uc = uint64(c) + case OpConst32: + uc = uint64(uint32(c)) + case OpConst16: + uc = uint64(uint16(c)) + case OpConst8: + uc = uint64(uint8(c)) + } + switch r { + case lt: + if uc-1 < lim.umax { + lim.umax = uc - 1 + } + case lt | eq: + if uc < lim.umax { + lim.umax = uc + } + case gt | eq: + if uc > lim.umin { + lim.umin = uc + } + case gt: + if uc+1 > lim.umin { + lim.umin = uc + 1 + } + case lt | gt: + if uc == lim.umin { + lim.umin++ + } + if uc == lim.umax { + lim.umax-- + } + case eq: + lim.umin = uc + lim.umax = uc + } + } + ft.limitStack = append(ft.limitStack, limitFact{v.ID, old}) + ft.limits[v.ID] = lim + } } // checkpoint saves the current state of known relations. // Called when descending on a branch. func (ft *factsTable) checkpoint() { ft.stack = append(ft.stack, checkpointFact) + ft.limitStack = append(ft.limitStack, checkpointBound) } // restore restores known relation to the state just @@ -144,6 +319,18 @@ func (ft *factsTable) restore() { ft.facts[old.p] = old.r } } + for { + old := ft.limitStack[len(ft.limitStack)-1] + ft.limitStack = ft.limitStack[:len(ft.limitStack)-1] + if old.vid == 0 { // checkpointBound + break + } + if old.limit == noLimit { + delete(ft.limits, old.vid) + } else { + ft.limits[old.vid] = old.limit + } + } } func lessByID(v, w *Value) bool { @@ -421,6 +608,7 @@ func simplifyBlock(ft *factsTable, b *Block) branch { // to the upper bound than this is proven. Most useful in cases such as: // if len(a) <= 1 { return } // do something with a[1] + // TODO: use constant bounds to do isNonNegative. if (c.Op == OpIsInBounds || c.Op == OpIsSliceInBounds) && isNonNegative(c.Args[0]) { m := ft.get(a0, a1, signed) if m != 0 && tr.r&m == m { diff --git a/src/cmd/compile/internal/ssa/value.go b/src/cmd/compile/internal/ssa/value.go index baa351169e..fd4eb64db1 100644 --- a/src/cmd/compile/internal/ssa/value.go +++ b/src/cmd/compile/internal/ssa/value.go @@ -218,6 +218,11 @@ func (v *Value) Unimplementedf(msg string, args ...interface{}) { v.Block.Func.Config.Unimplementedf(v.Line, msg, args...) } +// isGenericIntConst returns whether v is a generic integer constant. +func (v *Value) isGenericIntConst() bool { + return v != nil && (v.Op == OpConst64 || v.Op == OpConst32 || v.Op == OpConst16 || v.Op == OpConst8) +} + // ExternSymbol is an aux value that encodes a variable's // constant offset from the static base pointer. type ExternSymbol struct { diff --git a/src/encoding/binary/binary.go b/src/encoding/binary/binary.go index 225ecd7d7a..ada5768695 100644 --- a/src/encoding/binary/binary.go +++ b/src/encoding/binary/binary.go @@ -49,23 +49,23 @@ var BigEndian bigEndian type littleEndian struct{} func (littleEndian) Uint16(b []byte) uint16 { - b = b[:2:len(b)] // bounds check hint to compiler; see golang.org/issue/14808 + _ = b[1] // bounds check hint to compiler; see golang.org/issue/14808 return uint16(b[0]) | uint16(b[1])<<8 } func (littleEndian) PutUint16(b []byte, v uint16) { - b = b[:2:len(b)] // early bounds check to guarantee safety of writes below + _ = b[1] // early bounds check to guarantee safety of writes below b[0] = byte(v) b[1] = byte(v >> 8) } func (littleEndian) Uint32(b []byte) uint32 { - b = b[:4:len(b)] // bounds check hint to compiler; see golang.org/issue/14808 + _ = b[3] // bounds check hint to compiler; see golang.org/issue/14808 return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 } func (littleEndian) PutUint32(b []byte, v uint32) { - b = b[:4:len(b)] // early bounds check to guarantee safety of writes below + _ = b[3] // early bounds check to guarantee safety of writes below b[0] = byte(v) b[1] = byte(v >> 8) b[2] = byte(v >> 16) @@ -73,13 +73,13 @@ func (littleEndian) PutUint32(b []byte, v uint32) { } func (littleEndian) Uint64(b []byte) uint64 { - b = b[:8:len(b)] // bounds check hint to compiler; see golang.org/issue/14808 + _ = b[7] // bounds check hint to compiler; see golang.org/issue/14808 return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56 } func (littleEndian) PutUint64(b []byte, v uint64) { - b = b[:8:len(b)] // early bounds check to guarantee safety of writes below + _ = b[7] // early bounds check to guarantee safety of writes below b[0] = byte(v) b[1] = byte(v >> 8) b[2] = byte(v >> 16) @@ -97,23 +97,23 @@ func (littleEndian) GoString() string { return "binary.LittleEndian" } type bigEndian struct{} func (bigEndian) Uint16(b []byte) uint16 { - b = b[:2:len(b)] // bounds check hint to compiler; see golang.org/issue/14808 + _ = b[1] // bounds check hint to compiler; see golang.org/issue/14808 return uint16(b[1]) | uint16(b[0])<<8 } func (bigEndian) PutUint16(b []byte, v uint16) { - b = b[:2:len(b)] // early bounds check to guarantee safety of writes below + _ = b[1] // early bounds check to guarantee safety of writes below b[0] = byte(v >> 8) b[1] = byte(v) } func (bigEndian) Uint32(b []byte) uint32 { - b = b[:4:len(b)] // bounds check hint to compiler; see golang.org/issue/14808 + _ = b[3] // bounds check hint to compiler; see golang.org/issue/14808 return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24 } func (bigEndian) PutUint32(b []byte, v uint32) { - b = b[:4:len(b)] // early bounds check to guarantee safety of writes below + _ = b[3] // early bounds check to guarantee safety of writes below b[0] = byte(v >> 24) b[1] = byte(v >> 16) b[2] = byte(v >> 8) @@ -121,13 +121,13 @@ func (bigEndian) PutUint32(b []byte, v uint32) { } func (bigEndian) Uint64(b []byte) uint64 { - b = b[:8:len(b)] // bounds check hint to compiler; see golang.org/issue/14808 + _ = b[7] // bounds check hint to compiler; see golang.org/issue/14808 return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 | uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56 } func (bigEndian) PutUint64(b []byte, v uint64) { - b = b[:8:len(b)] // early bounds check to guarantee safety of writes below + _ = b[7] // early bounds check to guarantee safety of writes below b[0] = byte(v >> 56) b[1] = byte(v >> 48) b[2] = byte(v >> 40) diff --git a/test/prove.go b/test/prove.go index e5e5b544cf..fc2908eb03 100644 --- a/test/prove.go +++ b/test/prove.go @@ -3,12 +3,14 @@ package main +import "math" + func f0(a []int) int { a[0] = 1 a[0] = 1 // ERROR "Proved boolean IsInBounds$" a[6] = 1 a[6] = 1 // ERROR "Proved boolean IsInBounds$" - a[5] = 1 + a[5] = 1 // ERROR "Proved IsInBounds$" a[5] = 1 // ERROR "Proved boolean IsInBounds$" return 13 } @@ -17,15 +19,25 @@ func f1(a []int) int { if len(a) <= 5 { return 18 } - a[0] = 1 + a[0] = 1 // ERROR "Proved non-negative bounds IsInBounds$" a[0] = 1 // ERROR "Proved boolean IsInBounds$" a[6] = 1 a[6] = 1 // ERROR "Proved boolean IsInBounds$" - a[5] = 1 // ERROR "Proved non-negative bounds IsInBounds$" + a[5] = 1 // ERROR "Proved IsInBounds$" a[5] = 1 // ERROR "Proved boolean IsInBounds$" return 26 } +func f1b(a []int, i int, j uint) int { + if i >= 0 && i < len(a) { // TODO: handle this case + return a[i] + } + if j < uint(len(a)) { + return a[j] // ERROR "Proved IsInBounds" + } + return 0 +} + func f2(a []int) int { for i := range a { a[i] = i @@ -245,6 +257,164 @@ func f12(a []int, b int) { useSlice(a[:b]) } +func f13a(a, b, c int, x bool) int { + if a > 12 { + if x { + if a < 12 { // ERROR "Disproved Less64$" + return 1 + } + } + if x { + if a <= 12 { // ERROR "Disproved Leq64$" + return 2 + } + } + if x { + if a == 12 { // ERROR "Disproved Eq64$" + return 3 + } + } + if x { + if a >= 12 { // ERROR "Proved Geq64$" + return 4 + } + } + if x { + if a > 12 { // ERROR "Proved boolean Greater64$" + return 5 + } + } + return 6 + } + return 0 +} + +func f13b(a int, x bool) int { + if a == -9 { + if x { + if a < -9 { // ERROR "Disproved Less64$" + return 7 + } + } + if x { + if a <= -9 { // ERROR "Proved Leq64$" + return 8 + } + } + if x { + if a == -9 { // ERROR "Proved boolean Eq64$" + return 9 + } + } + if x { + if a >= -9 { // ERROR "Proved Geq64$" + return 10 + } + } + if x { + if a > -9 { // ERROR "Disproved Greater64$" + return 11 + } + } + return 12 + } + return 0 +} + +func f13c(a int, x bool) int { + if a < 90 { + if x { + if a < 90 { // ERROR "Proved boolean Less64$" + return 13 + } + } + if x { + if a <= 90 { // ERROR "Proved Leq64$" + return 14 + } + } + if x { + if a == 90 { // ERROR "Disproved Eq64$" + return 15 + } + } + if x { + if a >= 90 { // ERROR "Disproved Geq64$" + return 16 + } + } + if x { + if a > 90 { // ERROR "Disproved Greater64$" + return 17 + } + } + return 18 + } + return 0 +} + +func f13d(a int) int { + if a < 5 { + if a < 9 { // ERROR "Proved Less64$" + return 1 + } + } + return 0 +} + +func f13e(a int) int { + if a > 9 { + if a > 5 { // ERROR "Proved Greater64$" + return 1 + } + } + return 0 +} + +func f13f(a int64) int64 { + if a > math.MaxInt64 { + // Unreachable, but prove doesn't know that. + if a == 0 { + return 1 + } + } + return 0 +} + +func f13g(a int) int { + if a < 3 { + return 5 + } + if a > 3 { + return 6 + } + if a == 3 { // ERROR "Proved Eq64$" + return 7 + } + return 8 +} + +func f13h(a int) int { + if a < 3 { + if a > 1 { + if a == 2 { // ERROR "Proved Eq64$" + return 5 + } + } + } + return 0 +} + +func f13i(a uint) int { + if a == 0 { + return 1 + } + if a > 0 { // ERROR "Proved Greater64U$" + return 2 + } + return 3 +} + //go:noinline func useInt(a int) { } -- 2.48.1