]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: reorganize prove pass domain relation table
authorkhr@golang.org <khr@golang.org>
Thu, 11 Jul 2024 19:10:10 +0000 (12:10 -0700)
committerKeith Randall <khr@golang.org>
Wed, 7 Aug 2024 16:08:03 +0000 (16:08 +0000)
Move some code from when we learn that we take a branch, to when
we learn that a boolean is true or false. It is more consistent
this way (and may lead to a few more cases where we can derive
useful relations).

Change-Id: Iea7b2d6740e10c9c71c4b1546881f501da81cd21
Reviewed-on: https://go-review.googlesource.com/c/go/+/599098
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
src/cmd/compile/internal/ssa/prove.go
test/prove.go

index c8d2ab7a6f6893db16b6ea02c4aac6a3d0c4977c..0bf4e32729dca71529039a97f7e77260a80b745b 100644 (file)
@@ -629,6 +629,67 @@ func (ft *factsTable) newLimit(v *Value, newLim limit) bool {
                }
        }
 
+       // If this is new known constant for a boolean value,
+       // extract relation between its args. For example, if
+       // We learn v is false, and v is defined as a<b, then we learn a>=b.
+       if v.Type.IsBoolean() {
+               // If we reach here, is is because we have a more restrictive
+               // value for v than the default. The only two such values
+               // are constant true or constant false.
+               if lim.min != lim.max {
+                       v.Block.Func.Fatalf("boolean not constant %v", v)
+               }
+               isTrue := lim.min == 1
+               if dr, ok := domainRelationTable[v.Op]; ok && v.Op != OpIsInBounds && v.Op != OpIsSliceInBounds {
+                       d := dr.d
+                       r := dr.r
+                       if d == signed && ft.isNonNegative(v.Args[0]) && ft.isNonNegative(v.Args[1]) {
+                               d |= unsigned
+                       }
+                       if !isTrue {
+                               r ^= (lt | gt | eq)
+                       }
+                       // TODO: v.Block is wrong?
+                       addRestrictions(v.Block, ft, d, v.Args[0], v.Args[1], r)
+               }
+               switch v.Op {
+               case OpIsNonNil:
+                       if isTrue {
+                               ft.pointerNonNil(v.Args[0])
+                       } else {
+                               ft.pointerNil(v.Args[0])
+                       }
+               case OpIsInBounds, OpIsSliceInBounds:
+                       // 0 <= a0 < a1 (or 0 <= a0 <= a1)
+                       r := lt
+                       if v.Op == OpIsSliceInBounds {
+                               r |= eq
+                       }
+                       if isTrue {
+                               // On the positive branch, we learn:
+                               //   signed: 0 <= a0 < a1 (or 0 <= a0 <= a1)
+                               //   unsigned:    a0 < a1 (or a0 <= a1)
+                               ft.setNonNegative(v.Args[0])
+                               ft.update(v.Block, v.Args[0], v.Args[1], signed, r)
+                               ft.update(v.Block, v.Args[0], v.Args[1], unsigned, r)
+                       } else {
+                               // On the negative branch, we learn (0 > a0 ||
+                               // a0 >= a1). In the unsigned domain, this is
+                               // simply a0 >= a1 (which is the reverse of the
+                               // positive branch, so nothing surprising).
+                               // But in the signed domain, we can't express the ||
+                               // condition, so check if a0 is non-negative instead,
+                               // to be able to learn something.
+                               r ^= (lt | gt | eq) // >= (index) or > (slice)
+                               if ft.isNonNegative(v.Args[0]) {
+                                       ft.update(v.Block, v.Args[0], v.Args[1], signed, r)
+                               }
+                               ft.update(v.Block, v.Args[0], v.Args[1], unsigned, r)
+                               // TODO: v.Block is wrong here
+                       }
+               }
+       }
+
        return true
 }
 
@@ -1119,8 +1180,8 @@ var (
        // For example:
        //      OpLess8:   {signed, lt},
        //      v1 = (OpLess8 v2 v3).
-       // If v1 branch is taken then we learn that the rangeMask
-       // can be at most lt.
+       // If we learn that v1 is true, then we can deduce that v2<v3
+       // in the signed domain.
        domainRelationTable = map[Op]struct {
                d domain
                r relation
@@ -1156,12 +1217,6 @@ var (
                OpLeq32U: {unsigned, lt | eq},
                OpLeq64:  {signed, lt | eq},
                OpLeq64U: {unsigned, lt | eq},
-
-               // For these ops, the negative branch is different: we can only
-               // prove signed/GE (signed/GT) if we can prove that arg0 is non-negative.
-               // See the special case in addBranchRestrictions.
-               OpIsInBounds:      {signed | unsigned, lt},      // 0 <= arg0 < arg1
-               OpIsSliceInBounds: {signed | unsigned, lt | eq}, // 0 <= arg0 <= arg1
        }
 )
 
@@ -1830,56 +1885,6 @@ func addBranchRestrictions(ft *factsTable, b *Block, br branch) {
        default:
                panic("unknown branch")
        }
-       if tr, has := domainRelationTable[c.Op]; has {
-               // When we branched from parent we learned a new set of
-               // restrictions. Update the factsTable accordingly.
-               d := tr.d
-               if d == signed && ft.isNonNegative(c.Args[0]) && ft.isNonNegative(c.Args[1]) {
-                       d |= unsigned
-               }
-               switch c.Op {
-               case OpIsInBounds, OpIsSliceInBounds:
-                       // 0 <= a0 < a1 (or 0 <= a0 <= a1)
-                       //
-                       // On the positive branch, we learn:
-                       //   signed: 0 <= a0 < a1 (or 0 <= a0 <= a1)
-                       //   unsigned:    a0 < a1 (or a0 <= a1)
-                       //
-                       // On the negative branch, we learn (0 > a0 ||
-                       // a0 >= a1). In the unsigned domain, this is
-                       // simply a0 >= a1 (which is the reverse of the
-                       // positive branch, so nothing surprising).
-                       // But in the signed domain, we can't express the ||
-                       // condition, so check if a0 is non-negative instead,
-                       // to be able to learn something.
-                       switch br {
-                       case negative:
-                               d = unsigned
-                               if ft.isNonNegative(c.Args[0]) {
-                                       d |= signed
-                               }
-                               addRestrictions(b, ft, d, c.Args[0], c.Args[1], tr.r^(lt|gt|eq))
-                       case positive:
-                               ft.setNonNegative(c.Args[0])
-                               addRestrictions(b, ft, d, c.Args[0], c.Args[1], tr.r)
-                       }
-               default:
-                       switch br {
-                       case negative:
-                               addRestrictions(b, ft, d, c.Args[0], c.Args[1], tr.r^(lt|gt|eq))
-                       case positive:
-                               addRestrictions(b, ft, d, c.Args[0], c.Args[1], tr.r)
-                       }
-               }
-       }
-       if c.Op == OpIsNonNil {
-               switch br {
-               case positive:
-                       ft.pointerNonNil(c.Args[0])
-               case negative:
-                       ft.pointerNil(c.Args[0])
-               }
-       }
 }
 
 // addRestrictions updates restrictions from the immediate
index b85ee5fe0d4d79f71e327bdf9c5ebf4e8e508ae3..32096eafff7ca565187806f30ae248cc725d6fe9 100644 (file)
@@ -1181,6 +1181,18 @@ func f21(a, b *int) int {
        return 0
 }
 
+func f22(b bool, x, y int) int {
+       b2 := x < y
+       if b == b2 {
+               if b {
+                       if x >= y { // ERROR "Disproved Leq64$"
+                               return 1
+                       }
+               }
+       }
+       return 0
+}
+
 //go:noinline
 func useInt(a int) {
 }