]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: don't evaluate side effects of range over array
authorKeith Randall <khr@golang.org>
Wed, 19 Mar 2025 17:17:22 +0000 (10:17 -0700)
committerKeith Randall <khr@golang.org>
Mon, 21 Apr 2025 22:50:43 +0000 (15:50 -0700)
If the thing we're ranging over is an array or ptr to array, and
it doesn't have a function call or channel receive in it, then we
shouldn't evaluate it.

Typecheck the ranged-over value as a constant in that case.
That makes the unified exporter replace the range expression
with a constant int.

Change-Id: I0d4ea081de70d20cf6d1fa8d25ef6cb021975554
Reviewed-on: https://go-review.googlesource.com/c/go/+/659317
Reviewed-by: Junyang Shao <shaojunyang@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Robert Griesemer <gri@google.com>
src/cmd/compile/internal/ssa/nilcheck.go
src/cmd/compile/internal/types2/range.go
src/cmd/compile/internal/walk/range.go
src/go/types/range.go
test/codegen/issue52635.go

index c69cd8c32ed39405f6562ce96542349582f0fc6a..9d43ec19919825b61ee041dd649eadc901ba632a 100644 (file)
@@ -217,12 +217,12 @@ func nilcheckelim2(f *Func) {
                                        f.Warnl(v.Pos, "removed nil check")
                                }
                                // For bug 33724, policy is that we might choose to bump an existing position
-                               // off the faulting load/store in favor of the one from the nil check.
+                               // off the faulting load in favor of the one from the nil check.
 
                                // Iteration order means that first nilcheck in the chain wins, others
                                // are bumped into the ordinary statement preservation algorithm.
                                u := b.Values[unnecessary.get(v.Args[0].ID)]
-                               if !u.Pos.SameFileAndLine(v.Pos) {
+                               if !u.Type.IsMemory() && !u.Pos.SameFileAndLine(v.Pos) {
                                        if u.Pos.IsStmt() == src.PosIsStmt {
                                                pendingLines.add(u.Pos)
                                        }
index 86626ceaa8b38e6c079e7d2bbf659300b052cf19..ecda53d14bad45f768d5de715fc3fc05db6e63e1 100644 (file)
@@ -8,6 +8,7 @@ package types2
 
 import (
        "cmd/compile/internal/syntax"
+       "go/constant"
        "internal/buildcfg"
        . "internal/types/errors"
 )
@@ -23,8 +24,32 @@ import (
 func (check *Checker) rangeStmt(inner stmtContext, rangeStmt *syntax.ForStmt, noNewVarPos poser, sKey, sValue, sExtra, rangeVar syntax.Expr, isDef bool) {
        // check expression to iterate over
        var x operand
+
+       // From the spec:
+       //   The range expression x is evaluated before beginning the loop,
+       //   with one exception: if at most one iteration variable is present
+       //   and x or len(x) is constant, the range expression is not evaluated.
+       // So we have to be careful not to evaluate the arg in the
+       // described situation.
+
+       check.hasCallOrRecv = false
        check.expr(nil, &x, rangeVar)
 
+       if isTypes2 && x.mode != invalid && sValue == nil && !check.hasCallOrRecv {
+               if t, ok := arrayPtrDeref(under(x.typ)).(*Array); ok {
+                       // Override type of rangeVar to be a constant
+                       // (and thus side-effects will not be computed
+                       // by the backend).
+                       check.record(&operand{
+                               mode: constant_,
+                               expr: rangeVar,
+                               typ:  Typ[Int],
+                               val:  constant.MakeInt64(t.len),
+                               id:   x.id,
+                       })
+               }
+       }
+
        // determine key/value types
        var key, val Type
        if x.mode != invalid {
index ede9f2182d494fbdf7f531f27294f22939246318..a1e5442a69eb336fe8ab02907f24968931551518 100644 (file)
@@ -5,6 +5,7 @@
 package walk
 
 import (
+       "go/constant"
        "internal/buildcfg"
        "unicode/utf8"
 
@@ -80,6 +81,10 @@ func walkRange(nrange *ir.RangeStmt) ir.Node {
                base.Fatalf("walkRange")
 
        case types.IsInt[k]:
+               if nn := arrayRangeClear(nrange, v1, v2, a); nn != nil {
+                       base.Pos = lno
+                       return nn
+               }
                hv1 := typecheck.TempAt(base.Pos, ir.CurFunc, t)
                hn := typecheck.TempAt(base.Pos, ir.CurFunc, t)
 
@@ -519,13 +524,33 @@ func arrayRangeClear(loop *ir.RangeStmt, v1, v2, a ir.Node) ir.Node {
        }
        lhs := stmt.X.(*ir.IndexExpr)
        x := lhs.X
-       if a.Type().IsPtr() && a.Type().Elem().IsArray() {
-               if s, ok := x.(*ir.StarExpr); ok && s.Op() == ir.ODEREF {
-                       x = s.X
+
+       // Get constant number of iterations for int and array cases.
+       n := int64(-1)
+       if ir.IsConst(a, constant.Int) {
+               n = ir.Int64Val(a)
+       } else if a.Type().IsArray() {
+               n = a.Type().NumElem()
+       } else if a.Type().IsPtr() && a.Type().Elem().IsArray() {
+               n = a.Type().Elem().NumElem()
+       }
+
+       if n >= 0 {
+               // Int/Array case.
+               if !x.Type().IsArray() {
+                       return nil
+               }
+               if x.Type().NumElem() != n {
+                       return nil
+               }
+       } else {
+               // Slice case.
+               if !ir.SameSafeExpr(x, a) {
+                       return nil
                }
        }
 
-       if !ir.SameSafeExpr(x, a) || !ir.SameSafeExpr(lhs.Index, v1) {
+       if !ir.SameSafeExpr(lhs.Index, v1) {
                return nil
        }
 
@@ -533,7 +558,7 @@ func arrayRangeClear(loop *ir.RangeStmt, v1, v2, a ir.Node) ir.Node {
                return nil
        }
 
-       return arrayClear(stmt.Pos(), a, loop)
+       return arrayClear(stmt.Pos(), x, loop)
 }
 
 // arrayClear constructs a call to runtime.memclr for fast zeroing of slices and arrays.
index 5c80463aba0db3b25d275aa53b46ebfe304bd93d..91149c1426fe8c59663131f1290b2553664a8dac 100644 (file)
@@ -11,6 +11,7 @@ package types
 
 import (
        "go/ast"
+       "go/constant"
        "internal/buildcfg"
        . "internal/types/errors"
 )
@@ -26,8 +27,32 @@ import (
 func (check *Checker) rangeStmt(inner stmtContext, rangeStmt *ast.RangeStmt, noNewVarPos positioner, sKey, sValue, sExtra, rangeVar ast.Expr, isDef bool) {
        // check expression to iterate over
        var x operand
+
+       // From the spec:
+       //   The range expression x is evaluated before beginning the loop,
+       //   with one exception: if at most one iteration variable is present
+       //   and x or len(x) is constant, the range expression is not evaluated.
+       // So we have to be careful not to evaluate the arg in the
+       // described situation.
+
+       check.hasCallOrRecv = false
        check.expr(nil, &x, rangeVar)
 
+       if isTypes2 && x.mode != invalid && sValue == nil && !check.hasCallOrRecv {
+               if t, ok := arrayPtrDeref(under(x.typ)).(*Array); ok {
+                       // Override type of rangeVar to be a constant
+                       // (and thus side-effects will not be computed
+                       // by the backend).
+                       check.record(&operand{
+                               mode: constant_,
+                               expr: rangeVar,
+                               typ:  Typ[Int],
+                               val:  constant.MakeInt64(t.len),
+                               id:   x.id,
+                       })
+               }
+       }
+
        // determine key/value types
        var key, val Type
        if x.mode != invalid {
index 9b08cade36f98fc0bafb7832924a2094b8cc225e..9ee63f0fbeccdd6f070dd5d1c822cb8342ac52d7 100644 (file)
@@ -12,6 +12,7 @@ package codegen
 type T struct {
        a *[10]int
        b [10]int
+       s []int
 }
 
 func (t *T) f() {
@@ -38,4 +39,15 @@ func (t *T) f() {
        for i := range *t.a {
                (*t.a)[i] = 0
        }
+
+       // amd64:-".*runtime.memclrNoHeapPointers"
+       // amd64:"DUFFZERO"
+       for i := range t.b {
+               t.b[i] = 0
+       }
+
+       // amd64:".*runtime.memclrNoHeapPointers"
+       for i := range t.s {
+               t.s[i] = 0
+       }
 }