]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: learn transitive proofs for safe unsigned adds
authorJorropo <jorropo.pgm@gmail.com>
Fri, 4 Jul 2025 07:21:03 +0000 (09:21 +0200)
committerGopher Robot <gobot@golang.org>
Thu, 24 Jul 2025 20:48:55 +0000 (13:48 -0700)
I've split this into it's own CL to make git bisect more effective.

Change-Id: Iaab5f0bd2ad51e86ced8c6b8fbd371eb75eeef14
Reviewed-on: https://go-review.googlesource.com/c/go/+/685815
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Reviewed-by: Keith Randall <khr@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Michael Knyszek <mknyszek@google.com>
Reviewed-by: Mark Freeman <mark@golang.org>
src/cmd/compile/internal/ssa/prove.go
test/prove.go

index 93bd525c387b96fa72b56f18e5978748f1d15cd7..8fe1bb7050b1f5b15273fe0372229ec8bdb00d4e 100644 (file)
@@ -5,6 +5,7 @@
 package ssa
 
 import (
+       "cmd/compile/internal/types"
        "cmd/internal/src"
        "fmt"
        "math"
@@ -2132,6 +2133,21 @@ func addRestrictions(parent *Block, ft *factsTable, t domain, v, w *Value, r rel
        }
 }
 
+func unsignedAddOverflows(a, b uint64, t *types.Type) bool {
+       switch t.Size() {
+       case 8:
+               return a+b < a
+       case 4:
+               return a+b > math.MaxUint32
+       case 2:
+               return a+b > math.MaxUint16
+       case 1:
+               return a+b > math.MaxUint8
+       default:
+               panic("unreachable")
+       }
+}
+
 func addLocalFacts(ft *factsTable, b *Block) {
        // Propagate constant ranges among values in this block.
        // We do this before the second loop so that we have the
@@ -2151,6 +2167,21 @@ func addLocalFacts(ft *factsTable, b *Block) {
                // FIXME(go.dev/issue/68857): this loop only set up limits properly when b.Values is in topological order.
                // flowLimit can also depend on limits given by this loop which right now is not handled.
                switch v.Op {
+               case OpAdd64, OpAdd32, OpAdd16, OpAdd8:
+                       x := ft.limits[v.Args[0].ID]
+                       y := ft.limits[v.Args[1].ID]
+                       if !unsignedAddOverflows(x.umax, y.umax, v.Type) {
+                               r := gt
+                               if !x.nonzero() {
+                                       r |= eq
+                               }
+                               ft.update(b, v, v.Args[1], unsigned, r)
+                               r = gt
+                               if !y.nonzero() {
+                                       r |= eq
+                               }
+                               ft.update(b, v, v.Args[0], unsigned, r)
+                       }
                case OpAnd64, OpAnd32, OpAnd16, OpAnd8:
                        ft.update(b, v, v.Args[0], unsigned, lt|eq)
                        ft.update(b, v, v.Args[1], unsigned, lt|eq)
index faf0b79237e091dc842ccbb05f8167754fb9087e..e843edcbf0a29fc1d1a3b3bf77f8f4c929d861ec 100644 (file)
@@ -2041,6 +2041,69 @@ func cvtBoolToUint8BCE(b bool, a [2]int64) int64 {
        return a[c] // ERROR "Proved IsInBounds$"
 }
 
+func transitiveProofsThroughNonOverflowingUnsignedAdd(x, y, z uint64) {
+       x &= 1<<63 - 1
+       y &= 1<<63 - 1
+
+       a := x + y
+       if a > z {
+               return
+       }
+
+       if x > z { // ERROR "Disproved Less64U$"
+               return
+       }
+       if y > z { // ERROR "Disproved Less64U$"
+               return
+       }
+       if a == x {
+               return
+       }
+       if a == y {
+               return
+       }
+
+       x |= 1
+       y |= 1
+       a = x + y
+       if a == x { // ERROR "Disproved Eq64$"
+               return
+       }
+       if a == y { // ERROR "Disproved Eq64$"
+               return
+       }
+}
+
+func transitiveProofsThroughOverflowingUnsignedAdd(x, y, z uint64) {
+       a := x + y
+       if a > z {
+               return
+       }
+
+       if x > z {
+               return
+       }
+       if y > z {
+               return
+       }
+       if a == x {
+               return
+       }
+       if a == y {
+               return
+       }
+
+       x |= 1
+       y |= 1
+       a = x + y
+       if a == x {
+               return
+       }
+       if a == y {
+               return
+       }
+}
+
 //go:noinline
 func useInt(a int) {
 }