]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile/internal/ssa: model right shift more precisely
authorRuss Cox <rsc@golang.org>
Wed, 29 Oct 2025 11:27:38 +0000 (07:27 -0400)
committerGopher Robot <gobot@golang.org>
Thu, 30 Oct 2025 16:17:59 +0000 (09:17 -0700)
Prove currently checks for 0 sign bit extraction (x>>63) at the
end of the pass, but it is more general and more useful
(and not really more work) to model right shift during
value range tracking. This handles sign bit extraction (both 0 and -1)
but also makes the value ranges available for proving bounds checks.

'go build -a -gcflags=-d=ssa/prove/debug=1 std'
finds 105 new things to prove.
https://gist.github.com/rsc/8ac41176e53ed9c2f1a664fc668e8336

For example, the compiler now recognizes that this code in
strconv does not need to check the second shift for being ≥ 64.

msb := xHi >> 63
retMantissa := xHi >> (msb + 38)

nor does this code in regexp:

return b < utf8.RuneSelf && specialBytes[b%16]&(1<<(b/16)) != 0

This code in math no longer has a bounds check on the first index:

if 0 <= n && n <= 308 {
return pow10postab32[uint(n)/32] * pow10tab[uint(n)%32]
}

The diff shows one "lost" proof in ycbcr.go but it's not really lost:
the expression was folded to a constant instead, and that only shows
up with debug=2. A diff of that output is at
https://gist.github.com/rsc/9139ed46c6019ae007f5a1ba4bb3250f

Change-Id: I84087311e0a303f00e2820d957a6f8b29ee22519
Reviewed-on: https://go-review.googlesource.com/c/go/+/716140
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Russ Cox <rsc@golang.org>
Reviewed-by: David Chase <drchase@google.com>
src/cmd/compile/internal/ssa/prove.go
test/prove.go
test/prove_constant_folding.go

index 086e5b3a8f5fc9144ce5f2fae0d210a2ba29fc95..4919d6ad3702866cea0d33391091f1382a9ae9ed 100644 (file)
@@ -12,6 +12,7 @@ import (
        "math"
        "math/bits"
        "slices"
+       "strings"
 )
 
 type branch int
@@ -132,7 +133,7 @@ type limit struct {
 }
 
 func (l limit) String() string {
-       return fmt.Sprintf("sm,SM,um,UM=%d,%d,%d,%d", l.min, l.max, l.umin, l.umax)
+       return fmt.Sprintf("sm,SM=%d,%d um,UM=%d,%d", l.min, l.max, l.umin, l.umax)
 }
 
 func (l limit) intersect(l2 limit) limit {
@@ -1965,6 +1966,30 @@ func (ft *factsTable) flowLimit(v *Value) bool {
                b := ft.limits[v.Args[1].ID]
                bitsize := uint(v.Type.Size()) * 8
                return ft.newLimit(v, a.mul(b.exp2(bitsize), bitsize))
+       case OpRsh64x64, OpRsh64x32, OpRsh64x16, OpRsh64x8,
+               OpRsh32x64, OpRsh32x32, OpRsh32x16, OpRsh32x8,
+               OpRsh16x64, OpRsh16x32, OpRsh16x16, OpRsh16x8,
+               OpRsh8x64, OpRsh8x32, OpRsh8x16, OpRsh8x8:
+               a := ft.limits[v.Args[0].ID]
+               b := ft.limits[v.Args[1].ID]
+               if b.min >= 0 {
+                       // Shift of negative makes a value closer to 0 (greater),
+                       // so if a.min is negative, v.min is a.min>>b.min instead of a.min>>b.max,
+                       // and similarly if a.max is negative, v.max is a.max>>b.max.
+                       // Easier to compute min and max of both than to write sign logic.
+                       vmin := min(a.min>>b.min, a.min>>b.max)
+                       vmax := max(a.max>>b.min, a.max>>b.max)
+                       return ft.signedMinMax(v, vmin, vmax)
+               }
+       case OpRsh64Ux64, OpRsh64Ux32, OpRsh64Ux16, OpRsh64Ux8,
+               OpRsh32Ux64, OpRsh32Ux32, OpRsh32Ux16, OpRsh32Ux8,
+               OpRsh16Ux64, OpRsh16Ux32, OpRsh16Ux16, OpRsh16Ux8,
+               OpRsh8Ux64, OpRsh8Ux32, OpRsh8Ux16, OpRsh8Ux8:
+               a := ft.limits[v.Args[0].ID]
+               b := ft.limits[v.Args[1].ID]
+               if b.min >= 0 {
+                       return ft.unsignedMinMax(v, a.umin>>b.max, a.umax>>b.min)
+               }
        case OpDiv64, OpDiv32, OpDiv16, OpDiv8:
                a := ft.limits[v.Args[0].ID]
                b := ft.limits[v.Args[1].ID]
@@ -2621,6 +2646,17 @@ var bytesizeToAnd = [...]Op{
 func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) {
        for iv, v := range b.Values {
                switch v.Op {
+               case OpStaticLECall:
+                       if b.Func.pass.debug > 0 && len(v.Args) == 2 {
+                               fn := auxToCall(v.Aux).Fn
+                               if fn != nil && strings.Contains(fn.String(), "prove") {
+                                       // Print bounds of any argument to single-arg function with "prove" in name,
+                                       // for debugging and especially for test/prove.go.
+                                       // (v.Args[1] is mem).
+                                       x := v.Args[0]
+                                       b.Func.Warnl(v.Pos, "Proved %v (%v)", ft.limits[x.ID], x)
+                               }
+                       }
                case OpSlicemask:
                        // Replace OpSlicemask operations in b with constants where possible.
                        cap := v.Args[0]
@@ -2670,21 +2706,8 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) {
                case OpRsh8x8, OpRsh8x16, OpRsh8x32, OpRsh8x64,
                        OpRsh16x8, OpRsh16x16, OpRsh16x32, OpRsh16x64,
                        OpRsh32x8, OpRsh32x16, OpRsh32x32, OpRsh32x64,
-                       OpRsh64x8, OpRsh64x16, OpRsh64x32, OpRsh64x64:
-                       // Check whether, for a >> b, we know that a is non-negative
-                       // and b is all of a's bits except the MSB. If so, a is shifted to zero.
-                       bits := 8 * v.Args[0].Type.Size()
-                       if v.Args[1].isGenericIntConst() && v.Args[1].AuxInt >= bits-1 && ft.isNonNegative(v.Args[0]) {
-                               if b.Func.pass.debug > 0 {
-                                       b.Func.Warnl(v.Pos, "Proved %v shifts to zero", v.Op)
-                               }
-                               v.reset(bytesizeToConst[bits/8])
-                               v.AuxInt = 0
-                               break // Be sure not to fallthrough - this is no longer OpRsh.
-                       }
-                       // If the Rsh hasn't been replaced with 0, still check if it is bounded.
-                       fallthrough
-               case OpLsh8x8, OpLsh8x16, OpLsh8x32, OpLsh8x64,
+                       OpRsh64x8, OpRsh64x16, OpRsh64x32, OpRsh64x64,
+                       OpLsh8x8, OpLsh8x16, OpLsh8x32, OpLsh8x64,
                        OpLsh16x8, OpLsh16x16, OpLsh16x32, OpLsh16x64,
                        OpLsh32x8, OpLsh32x16, OpLsh32x32, OpLsh32x64,
                        OpLsh64x8, OpLsh64x16, OpLsh64x32, OpLsh64x64,
index db32d1beb0d48fe2e14b5821dc9722bf1e01ed7e..365e8ba006ee4af184d67ad76502b051accf592a 100644 (file)
@@ -971,40 +971,6 @@ func negIndex2(n int) {
        useSlice(c)
 }
 
-// Check that prove is zeroing these right shifts of positive ints by bit-width - 1.
-// e.g (Rsh64x64 <t> n (Const64 <typ.UInt64> [63])) && ft.isNonNegative(n) -> 0
-func sh64(n int64) int64 {
-       if n < 0 {
-               return n
-       }
-       return n >> 63 // ERROR "Proved Rsh64x64 shifts to zero"
-}
-
-func sh32(n int32) int32 {
-       if n < 0 {
-               return n
-       }
-       return n >> 31 // ERROR "Proved Rsh32x64 shifts to zero"
-}
-
-func sh32x64(n int32) int32 {
-       if n < 0 {
-               return n
-       }
-       return n >> uint64(31) // ERROR "Proved Rsh32x64 shifts to zero"
-}
-
-func sh16(n int16) int16 {
-       if n < 0 {
-               return n
-       }
-       return n >> 15 // ERROR "Proved Rsh16x64 shifts to zero"
-}
-
-func sh64noopt(n int64) int64 {
-       return n >> 63 // not optimized; n could be negative
-}
-
 // These cases are division of a positive signed integer by a power of 2.
 // The opt pass doesnt have sufficient information to see that n is positive.
 // So, instead, opt rewrites the division with a less-than-optimal replacement.
@@ -2584,6 +2550,103 @@ func swapbound(v []int) {
        }
 }
 
+func rightshift(v *[256]int) int {
+       for i := range 1024 { // ERROR "Induction"
+               if v[i/32] == 0 { // ERROR "Proved Div64 is unsigned" "Proved IsInBounds"
+                       return i
+               }
+       }
+       for i := range 1024 { // ERROR "Induction"
+               if v[i>>2] == 0 { // ERROR "Proved IsInBounds"
+                       return i
+               }
+       }
+       return -1
+}
+
+func rightShiftBounds(v, s int) {
+       // The ignored "Proved" messages on the shift itself are about whether s >= 0 or s < 32 or 64.
+       // We care about the bounds for x printed on the prove(x) lines.
+
+       if -8 <= v && v <= -2 && 1 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               prove(x) // ERROR "Proved sm,SM=-4,-1 "
+       }
+       if -80 <= v && v <= -20 && 1 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               prove(x) // ERROR "Proved sm,SM=-40,-3 "
+       }
+       if -8 <= v && v <= 10 && 1 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               prove(x) // ERROR "Proved sm,SM=-4,5 "
+       }
+       if 2 <= v && v <= 10 && 1 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               prove(x) // ERROR "Proved sm,SM=0,5 "
+       }
+
+       if -8 <= v && v <= -2 && 0 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               prove(x) // ERROR "Proved sm,SM=-8,-1 "
+       }
+       if -80 <= v && v <= -20 && 0 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               prove(x) // ERROR "Proved sm,SM=-80,-3 "
+       }
+       if -8 <= v && v <= 10 && 0 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               prove(x) // ERROR "Proved sm,SM=-8,10 "
+       }
+       if 2 <= v && v <= 10 && 0 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               prove(x) // ERROR "Proved sm,SM=0,10 "
+       }
+
+       if -8 <= v && v <= -2 && -1 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               prove(x) // ERROR "Proved sm,SM=-8,-1 "
+       }
+       if -80 <= v && v <= -20 && -1 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               prove(x) // ERROR "Proved sm,SM=-80,-3 "
+       }
+       if -8 <= v && v <= 10 && -1 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               prove(x) // ERROR "Proved sm,SM=-8,10 "
+       }
+       if 2 <= v && v <= 10 && -1 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               prove(x) // ERROR "Proved sm,SM=0,10 "
+       }
+}
+
+func unsignedRightShiftBounds(v uint, s int) {
+       if 2 <= v && v <= 10 && -1 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               proveu(x) // ERROR "Proved sm,SM=0,10 "
+       }
+       if 2 <= v && v <= 10 && 0 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               proveu(x) // ERROR "Proved sm,SM=0,10 "
+       }
+       if 2 <= v && v <= 10 && 1 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               proveu(x) // ERROR "Proved sm,SM=0,5 "
+       }
+       if 20 <= v && v <= 100 && 1 <= s && s <= 3 {
+               x := v>>s // ERROR "Proved"
+               proveu(x) // ERROR "Proved sm,SM=2,50 "
+       }
+}
+
+//go:noinline
+func prove(x int) {
+}
+
+//go:noinline
+func proveu(x uint) {
+}
+
 //go:noinline
 func useInt(a int) {
 }
index 1029c8e2d3a0d0cfc5a2cbf242677085104eece8..46764f9b9d94a13bf0ac8b4d35d924d6a5482b26 100644 (file)
@@ -20,14 +20,62 @@ func f0i(x int) int {
        return x + 1
 }
 
-func f0u(x uint) uint {
+func f0u(x uint) int {
        if x == 20 {
-               return x // ERROR "Proved.+is constant 20$"
+               return int(x) // ERROR "Proved.+is constant 20$"
        }
 
        if (x + 20) == 20 {
-               return x + 5 // ERROR "Proved.+is constant 0$" "Proved.+is constant 5$" "x\+d >=? w"
+               return int(x + 5) // ERROR "Proved.+is constant 0$" "Proved.+is constant 5$" "x\+d >=? w"
        }
 
-       return x + 1
+       if x < 1000 {
+               return int(x)>>31 // ERROR "Proved.+is constant 0$"
+       }
+       if x := int32(x); x < -1000 {
+               return int(x>>31) // ERROR "Proved.+is constant -1$"
+       }
+
+       return int(x) + 1
+}
+
+// Check that prove is zeroing these right shifts of positive ints by bit-width - 1.
+// e.g (Rsh64x64 <t> n (Const64 <typ.UInt64> [63])) && ft.isNonNegative(n) -> 0
+func sh64(n int64) int64 {
+       if n < 0 {
+               return n
+       }
+       return n >> 63 // ERROR "Proved .+ is constant 0$"
+}
+
+func sh32(n int32) int32 {
+       if n < 0 {
+               return n
+       }
+       return n >> 31 // ERROR "Proved .+ is constant 0$"
+}
+
+func sh32x64(n int32) int32 {
+       if n < 0 {
+               return n
+       }
+       return n >> uint64(31) // ERROR "Proved .+ is constant 0$"
+}
+
+func sh32x64n(n int32) int32 {
+       if n >= 0 {
+               return 0
+       }
+       return n >> 31// ERROR "Proved .+ is constant -1$"
+}
+
+func sh16(n int16) int16 {
+       if n < 0 {
+               return n
+       }
+       return n >> 15 // ERROR "Proved .+ is constant 0$"
+}
+
+func sh64noopt(n int64) int64 {
+       return n >> 63 // not optimized; n could be negative
 }