]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: handle boolean and pointer relations
authorkhr@golang.org <khr@golang.org>
Sun, 7 Jul 2024 21:58:47 +0000 (14:58 -0700)
committerKeith Randall <khr@golang.org>
Wed, 7 Aug 2024 16:07:55 +0000 (16:07 +0000)
The constant lattice for these types is pretty simple.
We no longer need the old-style facts table, as the ordering
table now has all that information.

Change-Id: If0e118c27a4de8e9bfd727b78942185c2eb50c4b
Reviewed-on: https://go-review.googlesource.com/c/go/+/599097
Reviewed-by: David Chase <drchase@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
src/cmd/compile/internal/ssa/prove.go
test/fuse.go
test/prove.go

index 6091950be8bf339c6b19f6261b7998e7443c6249..c8d2ab7a6f6893db16b6ea02c4aac6a3d0c4977c 100644 (file)
@@ -128,9 +128,18 @@ type fact struct {
 }
 
 // a limit records known upper and lower bounds for a value.
+//
+// If we have min>max or umin>umax, then this limit is
+// called "unsatisfiable". When we encounter such a limit, we
+// know that any code for which that limit applies is unreachable.
+// We don't particularly care how unsatisfiable limits propagate,
+// including becoming satisfiable, because any optimization
+// decisions based on those limits only apply to unreachable code.
 type limit struct {
        min, max   int64  // min <= value <= max, signed
        umin, umax uint64 // umin <= value <= umax, unsigned
+       // For booleans, we use 0==false, 1==true for both ranges
+       // For pointers, we use 0,0,0,0 for nil and minInt64,maxInt64,1,maxUint64 for nonnil
 }
 
 func (l limit) String() string {
@@ -359,8 +368,9 @@ type ordering struct {
        next *ordering // linked list of all known orderings for v.
        // Note: v is implicit here, determined by which linked list it is in.
        w *Value
-       d domain   // one of signed or unsigned
+       d domain
        r relation // one of ==,!=,<,<=,>,>=
+       // if d is boolean or pointer, r can only be ==, !=
 }
 
 // factsTable keeps track of relations between pairs of values.
@@ -379,9 +389,6 @@ type factsTable struct {
        unsat      bool // true if facts contains a contradiction
        unsatDepth int  // number of unsat checkpoints
 
-       facts map[pair]relation // current known set of relation
-       stack []fact            // previous sets of relations
-
        // order* is a couple of partial order sets that record information
        // about relations between SSA values in the signed and unsigned
        // domain.
@@ -423,8 +430,6 @@ func newFactsTable(f *Func) *factsTable {
        ft.orderS.SetUnsigned(false)
        ft.orderU.SetUnsigned(true)
        ft.orderings = make(map[ID]*ordering)
-       ft.facts = make(map[pair]relation)
-       ft.stack = make([]fact, 4)
        ft.limits = f.Cache.allocLimitSlice(f.NumValues())
        for _, b := range f.Blocks {
                for _, v := range b.Values {
@@ -471,6 +476,21 @@ func (ft *factsTable) unsignedMinMax(v *Value, min, max uint64) bool {
        return ft.newLimit(v, limit{min: math.MinInt64, max: math.MaxInt64, umin: min, umax: max})
 }
 
+func (ft *factsTable) booleanFalse(v *Value) bool {
+       return ft.newLimit(v, limit{min: 0, max: 0, umin: 0, umax: 0})
+}
+func (ft *factsTable) booleanTrue(v *Value) bool {
+       return ft.newLimit(v, limit{min: 1, max: 1, umin: 1, umax: 1})
+}
+func (ft *factsTable) pointerNil(v *Value) bool {
+       return ft.newLimit(v, limit{min: 0, max: 0, umin: 0, umax: 0})
+}
+func (ft *factsTable) pointerNonNil(v *Value) bool {
+       l := noLimit
+       l.umin = 1
+       return ft.newLimit(v, l)
+}
+
 // newLimit adds new limiting information for v.
 // Returns true if the new limit added any new information.
 func (ft *factsTable) newLimit(v *Value, newLim limit) bool {
@@ -574,6 +594,38 @@ func (ft *factsTable) newLimit(v *Value, newLim limit) bool {
                                        }
                                }
                        }
+               case boolean:
+                       switch o.r {
+                       case eq:
+                               if lim.min == 0 && lim.max == 0 { // constant false
+                                       ft.booleanFalse(o.w)
+                               }
+                               if lim.min == 1 && lim.max == 1 { // constant true
+                                       ft.booleanTrue(o.w)
+                               }
+                       case lt | gt:
+                               if lim.min == 0 && lim.max == 0 { // constant false
+                                       ft.booleanTrue(o.w)
+                               }
+                               if lim.min == 1 && lim.max == 1 { // constant true
+                                       ft.booleanFalse(o.w)
+                               }
+                       }
+               case pointer:
+                       switch o.r {
+                       case eq:
+                               if lim.umax == 0 { // nil
+                                       ft.pointerNil(o.w)
+                               }
+                               if lim.umin > 0 { // non-nil
+                                       ft.pointerNonNil(o.w)
+                               }
+                       case lt | gt:
+                               if lim.umax == 0 { // nil
+                                       ft.pointerNonNil(o.w)
+                               }
+                               // note: not equal to non-nil doesn't tell us anything.
+                       }
                }
        }
 
@@ -647,122 +699,163 @@ func (ft *factsTable) update(parent *Block, v, w *Value, d domain, r relation) {
                        ft.unsat = true
                        return
                }
-       } else {
-               if lessByID(w, v) {
-                       v, w = w, v
-                       r = reverseBits[r]
-               }
-
-               p := pair{v, w, d}
-               oldR, ok := ft.facts[p]
-               if !ok {
-                       if v == w {
-                               oldR = eq
-                       } else {
-                               oldR = lt | eq | gt
-                       }
-               }
-               // No changes compared to information already in facts table.
-               if oldR == r {
-                       return
-               }
-               ft.stack = append(ft.stack, fact{p, oldR})
-               ft.facts[p] = oldR & r
-               // If this relation is not satisfiable, mark it and exit right away
-               if oldR&r == 0 {
-                       if parent.Func.pass.debug > 2 {
-                               parent.Func.Warnl(parent.Pos, "unsat %s %s %s", v, w, r)
+       }
+       if d == boolean || d == pointer {
+               for o := ft.orderings[v.ID]; o != nil; o = o.next {
+                       if o.d == d && o.w == w {
+                               // We already know a relationship between v and w.
+                               // Either it is a duplicate, or it is a contradiction,
+                               // as we only allow eq and lt|gt for these domains,
+                               if o.r != r {
+                                       ft.unsat = true
+                               }
+                               return
                        }
-                       ft.unsat = true
-                       return
                }
+               // TODO: this does not do transitive equality.
+               // We could use a poset like above, but somewhat degenerate (==,!= only).
+               ft.addOrdering(v, w, d, r)
+               ft.addOrdering(w, v, d, r) // note: reverseBits unnecessary for eq and lt|gt.
        }
 
        // Extract new constant limits based on the comparison.
-       if d == signed || d == unsigned {
-               vLimit := ft.limits[v.ID]
-               wLimit := ft.limits[w.ID]
-               // Note: all the +1/-1 below could overflow/underflow. Either will
-               // still generate correct results, it will just lead to imprecision.
-               // In fact if there is overflow/underflow, the corresponding
-               // code is unreachable because the known range is outside the range
-               // of the value's type.
-               switch d {
-               case signed:
-                       switch r {
-                       case eq: // v == w
-                               ft.signedMinMax(v, wLimit.min, wLimit.max)
-                               ft.signedMinMax(w, vLimit.min, vLimit.max)
-                       case lt: // v < w
-                               ft.signedMax(v, wLimit.max-1)
-                               ft.signedMin(w, vLimit.min+1)
-                       case lt | eq: // v <= w
-                               ft.signedMax(v, wLimit.max)
-                               ft.signedMin(w, vLimit.min)
-                       case gt: // v > w
-                               ft.signedMin(v, wLimit.min+1)
-                               ft.signedMax(w, vLimit.max-1)
-                       case gt | eq: // v >= w
-                               ft.signedMin(v, wLimit.min)
-                               ft.signedMax(w, vLimit.max)
-                       case lt | gt: // v != w
-                               if vLimit.min == vLimit.max { // v is a constant
-                                       c := vLimit.min
-                                       if wLimit.min == c {
-                                               ft.signedMin(w, c+1)
-                                       }
-                                       if wLimit.max == c {
-                                               ft.signedMax(w, c-1)
-                                       }
+       vLimit := ft.limits[v.ID]
+       wLimit := ft.limits[w.ID]
+       // Note: all the +1/-1 below could overflow/underflow. Either will
+       // still generate correct results, it will just lead to imprecision.
+       // In fact if there is overflow/underflow, the corresponding
+       // code is unreachable because the known range is outside the range
+       // of the value's type.
+       switch d {
+       case signed:
+               switch r {
+               case eq: // v == w
+                       ft.signedMinMax(v, wLimit.min, wLimit.max)
+                       ft.signedMinMax(w, vLimit.min, vLimit.max)
+               case lt: // v < w
+                       ft.signedMax(v, wLimit.max-1)
+                       ft.signedMin(w, vLimit.min+1)
+               case lt | eq: // v <= w
+                       ft.signedMax(v, wLimit.max)
+                       ft.signedMin(w, vLimit.min)
+               case gt: // v > w
+                       ft.signedMin(v, wLimit.min+1)
+                       ft.signedMax(w, vLimit.max-1)
+               case gt | eq: // v >= w
+                       ft.signedMin(v, wLimit.min)
+                       ft.signedMax(w, vLimit.max)
+               case lt | gt: // v != w
+                       if vLimit.min == vLimit.max { // v is a constant
+                               c := vLimit.min
+                               if wLimit.min == c {
+                                       ft.signedMin(w, c+1)
                                }
-                               if wLimit.min == wLimit.max { // w is a constant
-                                       c := wLimit.min
-                                       if vLimit.min == c {
-                                               ft.signedMin(v, c+1)
-                                       }
-                                       if vLimit.max == c {
-                                               ft.signedMax(v, c-1)
-                                       }
+                               if wLimit.max == c {
+                                       ft.signedMax(w, c-1)
                                }
                        }
-               case unsigned:
-                       switch r {
-                       case eq: // v == w
-                               ft.unsignedMinMax(v, wLimit.umin, wLimit.umax)
-                               ft.unsignedMinMax(w, vLimit.umin, vLimit.umax)
-                       case lt: // v < w
-                               ft.unsignedMax(v, wLimit.umax-1)
-                               ft.unsignedMin(w, vLimit.umin+1)
-                       case lt | eq: // v <= w
-                               ft.unsignedMax(v, wLimit.umax)
-                               ft.unsignedMin(w, vLimit.umin)
-                       case gt: // v > w
-                               ft.unsignedMin(v, wLimit.umin+1)
-                               ft.unsignedMax(w, vLimit.umax-1)
-                       case gt | eq: // v >= w
-                               ft.unsignedMin(v, wLimit.umin)
-                               ft.unsignedMax(w, vLimit.umax)
-                       case lt | gt: // v != w
-                               if vLimit.umin == vLimit.umax { // v is a constant
-                                       c := vLimit.umin
-                                       if wLimit.umin == c {
-                                               ft.unsignedMin(w, c+1)
-                                       }
-                                       if wLimit.umax == c {
-                                               ft.unsignedMax(w, c-1)
-                                       }
+                       if wLimit.min == wLimit.max { // w is a constant
+                               c := wLimit.min
+                               if vLimit.min == c {
+                                       ft.signedMin(v, c+1)
                                }
-                               if wLimit.umin == wLimit.umax { // w is a constant
-                                       c := wLimit.umin
-                                       if vLimit.umin == c {
-                                               ft.unsignedMin(v, c+1)
-                                       }
-                                       if vLimit.umax == c {
-                                               ft.unsignedMax(v, c-1)
-                                       }
+                               if vLimit.max == c {
+                                       ft.signedMax(v, c-1)
+                               }
+                       }
+               }
+       case unsigned:
+               switch r {
+               case eq: // v == w
+                       ft.unsignedMinMax(v, wLimit.umin, wLimit.umax)
+                       ft.unsignedMinMax(w, vLimit.umin, vLimit.umax)
+               case lt: // v < w
+                       ft.unsignedMax(v, wLimit.umax-1)
+                       ft.unsignedMin(w, vLimit.umin+1)
+               case lt | eq: // v <= w
+                       ft.unsignedMax(v, wLimit.umax)
+                       ft.unsignedMin(w, vLimit.umin)
+               case gt: // v > w
+                       ft.unsignedMin(v, wLimit.umin+1)
+                       ft.unsignedMax(w, vLimit.umax-1)
+               case gt | eq: // v >= w
+                       ft.unsignedMin(v, wLimit.umin)
+                       ft.unsignedMax(w, vLimit.umax)
+               case lt | gt: // v != w
+                       if vLimit.umin == vLimit.umax { // v is a constant
+                               c := vLimit.umin
+                               if wLimit.umin == c {
+                                       ft.unsignedMin(w, c+1)
+                               }
+                               if wLimit.umax == c {
+                                       ft.unsignedMax(w, c-1)
+                               }
+                       }
+                       if wLimit.umin == wLimit.umax { // w is a constant
+                               c := wLimit.umin
+                               if vLimit.umin == c {
+                                       ft.unsignedMin(v, c+1)
+                               }
+                               if vLimit.umax == c {
+                                       ft.unsignedMax(v, c-1)
                                }
                        }
                }
+       case boolean:
+               switch r {
+               case eq: // v == w
+                       if vLimit.min == 1 { // v is true
+                               ft.booleanTrue(w)
+                       }
+                       if vLimit.max == 0 { // v is false
+                               ft.booleanFalse(w)
+                       }
+                       if wLimit.min == 1 { // w is true
+                               ft.booleanTrue(v)
+                       }
+                       if wLimit.max == 0 { // w is false
+                               ft.booleanFalse(v)
+                       }
+               case lt | gt: // v != w
+                       if vLimit.min == 1 { // v is true
+                               ft.booleanFalse(w)
+                       }
+                       if vLimit.max == 0 { // v is false
+                               ft.booleanTrue(w)
+                       }
+                       if wLimit.min == 1 { // w is true
+                               ft.booleanFalse(v)
+                       }
+                       if wLimit.max == 0 { // w is false
+                               ft.booleanTrue(v)
+                       }
+               }
+       case pointer:
+               switch r {
+               case eq: // v == w
+                       if vLimit.umax == 0 { // v is nil
+                               ft.pointerNil(w)
+                       }
+                       if vLimit.umin > 0 { // v is non-nil
+                               ft.pointerNonNil(w)
+                       }
+                       if wLimit.umax == 0 { // w is nil
+                               ft.pointerNil(v)
+                       }
+                       if wLimit.umin > 0 { // w is non-nil
+                               ft.pointerNonNil(v)
+                       }
+               case lt | gt: // v != w
+                       if vLimit.umax == 0 { // v is nil
+                               ft.pointerNonNil(w)
+                       }
+                       if wLimit.umax == 0 { // w is nil
+                               ft.pointerNonNil(v)
+                       }
+                       // Note: the other direction doesn't work.
+                       // Being not equal to a non-nil pointer doesn't
+                       // make you (necessarily) a nil pointer.
+               }
        }
 
        // Derived facts below here are only about numbers.
@@ -970,7 +1063,6 @@ func (ft *factsTable) checkpoint() {
        if ft.unsat {
                ft.unsatDepth++
        }
-       ft.stack = append(ft.stack, checkpointFact)
        ft.limitStack = append(ft.limitStack, checkpointBound)
        ft.orderS.Checkpoint()
        ft.orderU.Checkpoint()
@@ -986,18 +1078,6 @@ func (ft *factsTable) restore() {
        } else {
                ft.unsat = false
        }
-       for {
-               old := ft.stack[len(ft.stack)-1]
-               ft.stack = ft.stack[:len(ft.stack)-1]
-               if old == checkpointFact {
-                       break
-               }
-               if old.r == lt|eq|gt {
-                       delete(ft.facts, old.p)
-               } else {
-                       ft.facts[old.p] = old.r
-               }
-       }
        for {
                old := ft.limitStack[len(ft.limitStack)-1]
                ft.limitStack = ft.limitStack[:len(ft.limitStack)-1]
@@ -1050,12 +1130,14 @@ var (
                OpEq32:  {signed | unsigned, eq},
                OpEq64:  {signed | unsigned, eq},
                OpEqPtr: {pointer, eq},
+               OpEqB:   {boolean, eq},
 
                OpNeq8:   {signed | unsigned, lt | gt},
                OpNeq16:  {signed | unsigned, lt | gt},
                OpNeq32:  {signed | unsigned, lt | gt},
                OpNeq64:  {signed | unsigned, lt | gt},
                OpNeqPtr: {pointer, lt | gt},
+               OpNeqB:   {boolean, lt | gt},
 
                OpLess8:   {signed, lt},
                OpLess8U:  {unsigned, lt},
@@ -1407,8 +1489,28 @@ func prove(f *Func) {
 // flowLimit, below, which computes additional constraints based on
 // ranges of opcode arguments).
 func initLimit(v *Value) limit {
+       if v.Type.IsBoolean() {
+               switch v.Op {
+               case OpConstBool:
+                       b := v.AuxInt
+                       return limit{min: b, max: b, umin: uint64(b), umax: uint64(b)}
+               default:
+                       return limit{min: 0, max: 1, umin: 0, umax: 1}
+               }
+       }
+       if v.Type.IsPtrShaped() { // These are the types that EqPtr/NeqPtr operate on, except uintptr.
+               switch v.Op {
+               case OpConstNil:
+                       return limit{min: 0, max: 0, umin: 0, umax: 0}
+               case OpAddr, OpLocalAddr: // TODO: others?
+                       l := noLimit
+                       l.umin = 1
+                       return l
+               default:
+                       return noLimit
+               }
+       }
        if !v.Type.IsInteger() {
-               // TODO: boolean?
                return noLimit
        }
 
@@ -1700,9 +1802,9 @@ func addBranchRestrictions(ft *factsTable, b *Block, br branch) {
        c := b.Controls[0]
        switch {
        case br == negative:
-               addRestrictions(b, ft, boolean, nil, c, eq)
+               ft.booleanFalse(c)
        case br == positive:
-               addRestrictions(b, ft, boolean, nil, c, lt|gt)
+               ft.booleanTrue(c)
        case br >= jumpTable0:
                idx := br - jumpTable0
                val := int64(idx)
@@ -1769,7 +1871,14 @@ func addBranchRestrictions(ft *factsTable, b *Block, br branch) {
                                addRestrictions(b, ft, d, c.Args[0], c.Args[1], tr.r)
                        }
                }
-
+       }
+       if c.Op == OpIsNonNil {
+               switch br {
+               case positive:
+                       ft.pointerNonNil(c.Args[0])
+               case negative:
+                       ft.pointerNil(c.Args[0])
+               }
        }
 }
 
@@ -1984,7 +2093,7 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) {
                // Helps in cases where we reuse a value after branching on its equality.
                for i, arg := range v.Args {
                        switch arg.Op {
-                       case OpConst64, OpConst32, OpConst16, OpConst8:
+                       case OpConst64, OpConst32, OpConst16, OpConst8, OpConstBool, OpConstNil:
                                continue
                        }
                        lim := ft.limits[arg.ID]
index e9205dcc2355fb8f6dbb064c7547dfde4dc4c91e..9366b218587f76b90056bc7ec9cc621f23844a88 100644 (file)
@@ -148,11 +148,11 @@ func fEqInterEqInter(a interface{}, f float64) bool {
 }
 
 func fEqInterNeqInter(a interface{}, f float64) bool {
-       return a == nil && f > Cf2 || a != nil && f < -Cf2
+       return a == nil && f > Cf2 || a != nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil"
 }
 
 func fNeqInterEqInter(a interface{}, f float64) bool {
-       return a != nil && f > Cf2 || a == nil && f < -Cf2
+       return a != nil && f > Cf2 || a == nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil"
 }
 
 func fNeqInterNeqInter(a interface{}, f float64) bool {
@@ -164,11 +164,11 @@ func fEqSliceEqSlice(a []int, f float64) bool {
 }
 
 func fEqSliceNeqSlice(a []int, f float64) bool {
-       return a == nil && f > Cf2 || a != nil && f < -Cf2
+       return a == nil && f > Cf2 || a != nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil"
 }
 
 func fNeqSliceEqSlice(a []int, f float64) bool {
-       return a != nil && f > Cf2 || a == nil && f < -Cf2
+       return a != nil && f > Cf2 || a == nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil"
 }
 
 func fNeqSliceNeqSlice(a []int, f float64) bool {
index 6cb30c6ce16ceed8285f31dc0b81ab344e3cac00..b85ee5fe0d4d79f71e327bdf9c5ebf4e8e508ae3 100644 (file)
@@ -1159,6 +1159,28 @@ func issue66826b(a [31]byte, i int) {
        _ = a[3*i] // ERROR "Proved IsInBounds"
 }
 
+func f20(a, b bool) int {
+       if a == b {
+               if a {
+                       if b { // ERROR "Proved Arg"
+                               return 1
+                       }
+               }
+       }
+       return 0
+}
+
+func f21(a, b *int) int {
+       if a == b {
+               if a != nil {
+                       if b != nil { // ERROR "Proved IsNonNil"
+                               return 1
+                       }
+               }
+       }
+       return 0
+}
+
 //go:noinline
 func useInt(a int) {
 }