]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: propagate constant ranges through multiplies and shifts
authorkhr@golang.org <khr@golang.org>
Thu, 27 Jun 2024 22:53:24 +0000 (15:53 -0700)
committerKeith Randall <khr@golang.org>
Wed, 7 Aug 2024 16:07:42 +0000 (16:07 +0000)
Fixes #40704
Fixes #66826

Change-Id: Ia9c356e29b2ed6f2e3bc6e5eb9304cd4dccb4263
Reviewed-on: https://go-review.googlesource.com/c/go/+/599256
Reviewed-by: Michael Knyszek <mknyszek@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: David Chase <drchase@google.com>
src/cmd/compile/internal/ssa/prove.go
test/prove.go

index df3566985fa25256a41228bc068ef4588067860b..6091950be8bf339c6b19f6261b7998e7443c6249 100644 (file)
@@ -8,6 +8,7 @@ import (
        "cmd/internal/src"
        "fmt"
        "math"
+       "math/bits"
 )
 
 type branch int
@@ -311,6 +312,40 @@ func (l limit) sub(l2 limit, b uint) limit {
        return r
 }
 
+// same as add but for multiplication.
+func (l limit) mul(l2 limit, b uint) limit {
+       r := noLimit
+       umaxhi, umaxlo := bits.Mul64(l.umax, l2.umax)
+       if umaxhi == 0 && fitsInBitsU(umaxlo, b) {
+               r.umax = umaxlo
+               r.umin = l.umin * l2.umin
+               // Note: if the code containing this multiply is
+               // unreachable, then we may have umin>umax, and this
+               // multiply may overflow.  But that's ok for
+               // unreachable code. If this code is reachable, we
+               // know umin<=umax, so this multiply will not overflow
+               // because the max multiply didn't.
+       }
+       // Signed is harder, so don't bother. The only useful
+       // case is when we know both multiplicands are nonnegative,
+       // but that case is handled above because we would have then
+       // previously propagated signed info to the unsigned domain,
+       // and will propagate it back after the multiply.
+       return r
+}
+
+// Similar to add, but compute 1 << l if it fits without overflow in b bits.
+func (l limit) exp2(b uint) limit {
+       r := noLimit
+       if l.umax < uint64(b) {
+               r.umin = 1 << l.umin
+               r.umax = 1 << l.umax
+               // Same as above in mul, signed<->unsigned propagation
+               // will handle the signed case for us.
+       }
+       return r
+}
+
 var noLimit = limit{math.MinInt64, math.MaxInt64, 0, math.MaxUint64}
 
 // a limitFact is a limit known for a particular value.
@@ -1548,6 +1583,39 @@ func (ft *factsTable) flowLimit(v *Value) bool {
                a := ft.limits[v.Args[0].ID]
                b := ft.limits[v.Args[1].ID]
                return ft.newLimit(v, a.sub(b, 8))
+       case OpMul64:
+               a := ft.limits[v.Args[0].ID]
+               b := ft.limits[v.Args[1].ID]
+               return ft.newLimit(v, a.mul(b, 64))
+       case OpMul32:
+               a := ft.limits[v.Args[0].ID]
+               b := ft.limits[v.Args[1].ID]
+               return ft.newLimit(v, a.mul(b, 32))
+       case OpMul16:
+               a := ft.limits[v.Args[0].ID]
+               b := ft.limits[v.Args[1].ID]
+               return ft.newLimit(v, a.mul(b, 16))
+       case OpMul8:
+               a := ft.limits[v.Args[0].ID]
+               b := ft.limits[v.Args[1].ID]
+               return ft.newLimit(v, a.mul(b, 8))
+       case OpLsh64x64, OpLsh64x32, OpLsh64x16, OpLsh64x8:
+               a := ft.limits[v.Args[0].ID]
+               b := ft.limits[v.Args[1].ID]
+               return ft.newLimit(v, a.mul(b.exp2(64), 64))
+       case OpLsh32x64, OpLsh32x32, OpLsh32x16, OpLsh32x8:
+               a := ft.limits[v.Args[0].ID]
+               b := ft.limits[v.Args[1].ID]
+               return ft.newLimit(v, a.mul(b.exp2(32), 32))
+       case OpLsh16x64, OpLsh16x32, OpLsh16x16, OpLsh16x8:
+               a := ft.limits[v.Args[0].ID]
+               b := ft.limits[v.Args[1].ID]
+               return ft.newLimit(v, a.mul(b.exp2(16), 16))
+       case OpLsh8x64, OpLsh8x32, OpLsh8x16, OpLsh8x8:
+               a := ft.limits[v.Args[0].ID]
+               b := ft.limits[v.Args[1].ID]
+               return ft.newLimit(v, a.mul(b.exp2(8), 8))
+
        case OpPhi:
                // Compute the union of all the input phis.
                // Often this will convey no information, because the block
@@ -1834,7 +1902,7 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) {
                        OpRsh64x8, OpRsh64x16, OpRsh64x32, OpRsh64x64:
                        // Check whether, for a >> b, we know that a is non-negative
                        // and b is all of a's bits except the MSB. If so, a is shifted to zero.
-                       bits := 8 * v.Type.Size()
+                       bits := 8 * v.Args[0].Type.Size()
                        if v.Args[1].isGenericIntConst() && v.Args[1].AuxInt >= bits-1 && ft.isNonNegative(v.Args[0]) {
                                if b.Func.pass.debug > 0 {
                                        b.Func.Warnl(v.Pos, "Proved %v shifts to zero", v.Op)
index 70466ce2c5132c3bf6cac60bc6c3f6aca767762b..6cb30c6ce16ceed8285f31dc0b81ab344e3cac00 100644 (file)
@@ -1147,6 +1147,18 @@ func inequalityPropagation(a [1]int, i, j uint) int {
        return 0
 }
 
+func issue66826a(a [21]byte) {
+       for i := 0; i <= 10; i++ { // ERROR "Induction variable: limits \[0,10\], increment 1$"
+               _ = a[2*i] // ERROR "Proved IsInBounds"
+       }
+}
+func issue66826b(a [31]byte, i int) {
+       if i < 0 || i > 10 {
+               return
+       }
+       _ = a[3*i] // ERROR "Proved IsInBounds"
+}
+
 //go:noinline
 func useInt(a int) {
 }