]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: derive bounds on signed %N for N a power of 2
authorKeith Randall <khr@golang.org>
Fri, 25 Apr 2025 01:37:49 +0000 (18:37 -0700)
committerKeith Randall <khr@golang.org>
Mon, 19 May 2025 22:21:54 +0000 (15:21 -0700)
-N+1 <= x % N <= N-1

This is useful for cases like:

func setBit(b []byte, i int) {
    b[i/8] |= 1<<(i%8)
}

The shift does not need protection against larger-than-7 cases.
(It does still need protection against <0 cases.)

Change-Id: Idf83101386af538548bfeb6e2928cea855610ce2
Reviewed-on: https://go-review.googlesource.com/c/go/+/672995
Reviewed-by: Jorropo <jorropo.pgm@gmail.com>
Reviewed-by: David Chase <drchase@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
src/cmd/compile/internal/ssa/prove.go
test/codegen/shift.go

index 94f23a84aa0dbdddf0fda8b992e38b99d1e0e2b1..5617edb21f15ed101a913e26ba4651eb5d5b263f 100644 (file)
@@ -1757,7 +1757,9 @@ func (ft *factsTable) flowLimit(v *Value) bool {
        case OpSub64, OpSub32, OpSub16, OpSub8:
                a := ft.limits[v.Args[0].ID]
                b := ft.limits[v.Args[1].ID]
-               return ft.newLimit(v, a.sub(b, uint(v.Type.Size())*8))
+               sub := ft.newLimit(v, a.sub(b, uint(v.Type.Size())*8))
+               mod := ft.detectSignedMod(v)
+               return sub || mod
        case OpNeg64, OpNeg32, OpNeg16, OpNeg8:
                a := ft.limits[v.Args[0].ID]
                bitsize := uint(v.Type.Size()) * 8
@@ -1913,6 +1915,122 @@ func (ft *factsTable) flowLimit(v *Value) bool {
        return false
 }
 
+// See if we can get any facts because v is the result of signed mod by a constant.
+// The mod operation has already been rewritten, so we have to try and reconstruct it.
+//   x % d
+// is rewritten as
+//   x - (x / d) * d
+// furthermore, the divide itself gets rewritten. If d is a power of 2 (d == 1<<k), we do
+//   (x / d) * d = ((x + adj) >> k) << k
+//               = (x + adj) & (-1<<k)
+// with adj being an adjustment in case x is negative (see below).
+// if d is not a power of 2, we do
+//   x / d = ... TODO ...
+func (ft *factsTable) detectSignedMod(v *Value) bool {
+       if ft.detectSignedModByPowerOfTwo(v) {
+               return true
+       }
+       // TODO: non-powers-of-2
+       return false
+}
+func (ft *factsTable) detectSignedModByPowerOfTwo(v *Value) bool {
+       // We're looking for:
+       //
+       //   x % d ==
+       //   x - (x / d) * d
+       //
+       // which for d a power of 2, d == 1<<k, is done as
+       //
+       //   x - ((x + (x>>(w-1))>>>(w-k)) & (-1<<k))
+       //
+       // w = bit width of x.
+       // (>> = signed shift, >>> = unsigned shift).
+       // See ./_gen/generic.rules, search for "Signed divide by power of 2".
+
+       var w int64
+       var addOp, andOp, constOp, sshiftOp, ushiftOp Op
+       switch v.Op {
+       case OpSub64:
+               w = 64
+               addOp = OpAdd64
+               andOp = OpAnd64
+               constOp = OpConst64
+               sshiftOp = OpRsh64x64
+               ushiftOp = OpRsh64Ux64
+       case OpSub32:
+               w = 32
+               addOp = OpAdd32
+               andOp = OpAnd32
+               constOp = OpConst32
+               sshiftOp = OpRsh32x64
+               ushiftOp = OpRsh32Ux64
+       case OpSub16:
+               w = 16
+               addOp = OpAdd16
+               andOp = OpAnd16
+               constOp = OpConst16
+               sshiftOp = OpRsh16x64
+               ushiftOp = OpRsh16Ux64
+       case OpSub8:
+               w = 8
+               addOp = OpAdd8
+               andOp = OpAnd8
+               constOp = OpConst8
+               sshiftOp = OpRsh8x64
+               ushiftOp = OpRsh8Ux64
+       default:
+               return false
+       }
+
+       x := v.Args[0]
+       and := v.Args[1]
+       if and.Op != andOp {
+               return false
+       }
+       var add, mask *Value
+       if and.Args[0].Op == addOp && and.Args[1].Op == constOp {
+               add = and.Args[0]
+               mask = and.Args[1]
+       } else if and.Args[1].Op == addOp && and.Args[0].Op == constOp {
+               add = and.Args[1]
+               mask = and.Args[0]
+       } else {
+               return false
+       }
+       var ushift *Value
+       if add.Args[0] == x {
+               ushift = add.Args[1]
+       } else if add.Args[1] == x {
+               ushift = add.Args[0]
+       } else {
+               return false
+       }
+       if ushift.Op != ushiftOp {
+               return false
+       }
+       if ushift.Args[1].Op != OpConst64 {
+               return false
+       }
+       k := w - ushift.Args[1].AuxInt // Now we know k!
+       d := int64(1) << k             // divisor
+       sshift := ushift.Args[0]
+       if sshift.Op != sshiftOp {
+               return false
+       }
+       if sshift.Args[0] != x {
+               return false
+       }
+       if sshift.Args[1].Op != OpConst64 || sshift.Args[1].AuxInt != w-1 {
+               return false
+       }
+       if mask.AuxInt != -d {
+               return false
+       }
+
+       // All looks ok. x % d is at most +/- d-1.
+       return ft.signedMinMax(v, -d+1, d-1)
+}
+
 // getBranch returns the range restrictions added by p
 // when reaching b. p is the immediate dominator of b.
 func getBranch(sdom SparseTree, p *Block, b *Block) branch {
index 98d621d35204c62f6e3a54cd22b4fe64f0e34b34..56e8d354e6ba1b73d549ad72dcdb2f6141970818 100644 (file)
@@ -656,3 +656,12 @@ func rsh64to8(v int64) int8 {
        }
        return x
 }
+
+// We don't need to worry about shifting
+// more than the type size.
+// (There is still a negative shift test, but
+// no shift-too-big test.)
+func signedModShift(i int) int64 {
+       // arm64:-"CMP",-"CSEL"
+       return 1 << (i % 64)
+}