]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: add flag to FOR/RANGE to preserve loop semantics across inlines
authorDavid Chase <drchase@google.com>
Wed, 25 Jan 2023 22:08:16 +0000 (17:08 -0500)
committerDavid Chase <drchase@google.com>
Mon, 6 Mar 2023 18:34:53 +0000 (18:34 +0000)
This modifies the loopvar change to be tied to the
package if it is specified that way, and preserves
the change across inlining.

Down the road, this will be triggered (and flow correctly)
if the changed semantics are tied to Go version specified
in go.mod (or rather, for the compiler, by the specified
version for compilation).

Includes tests.

Change-Id: If54e8b6dd23273b86be5ba47838c90d38af9bd1a
Reviewed-on: https://go-review.googlesource.com/c/go/+/463595
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: David Chase <drchase@google.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
src/cmd/compile/internal/escape/stmt.go
src/cmd/compile/internal/ir/fmt.go
src/cmd/compile/internal/ir/stmt.go
src/cmd/compile/internal/loopvar/loopvar.go
src/cmd/compile/internal/loopvar/loopvar_test.go
src/cmd/compile/internal/noder/reader.go
src/cmd/compile/internal/noder/writer.go
src/cmd/compile/internal/reflectdata/alg.go
src/cmd/compile/internal/ssagen/ssa.go
src/cmd/compile/internal/walk/complit.go
src/cmd/compile/internal/walk/range.go

index 98cd2d53a65da87d3a01b4abed7ff9b277664a00..5ae78e35fc3b4de037edfa25d4b4e8a84ff82797 100644 (file)
@@ -80,6 +80,7 @@ func (e *escape) stmt(n ir.Node) {
 
        case ir.OFOR:
                n := n.(*ir.ForStmt)
+               base.Assert(!n.DistinctVars) // Should all be rewritten before escape analysis
                e.loopDepth++
                e.discard(n.Cond)
                e.stmt(n.Post)
@@ -89,6 +90,7 @@ func (e *escape) stmt(n ir.Node) {
        case ir.ORANGE:
                // for Key, Value = range X { Body }
                n := n.(*ir.RangeStmt)
+               base.Assert(!n.DistinctVars) // Should all be rewritten before escape analysis
 
                // X is evaluated outside the loop.
                tmp := e.newLoc(nil, false)
index ccd295d7e1b9cc66ad8acf2c0eeb4261c59e7c1f..83f45234254e3a95cd6f89c011b9ef900e75bf59 100644 (file)
@@ -417,6 +417,9 @@ func stmtFmt(n Node, s fmt.State) {
                }
 
                fmt.Fprint(s, "for")
+               if n.DistinctVars {
+                       fmt.Fprint(s, " /* distinct */")
+               }
                if simpleinit {
                        fmt.Fprintf(s, " %v;", n.Init()[0])
                } else if n.Post != nil {
@@ -451,6 +454,9 @@ func stmtFmt(n Node, s fmt.State) {
                        fmt.Fprint(s, " =")
                }
                fmt.Fprintf(s, " range %v { %v }", n.X, n.Body)
+               if n.DistinctVars {
+                       fmt.Fprint(s, " /* distinct vars */")
+               }
 
        case OSELECT:
                n := n.(*SelectStmt)
index dd3908e6653c91f7030c1a3911f6e5c6d5aed4ca..e6f0757ba2c1ef00f3c3c314695f8c60aa1f981e 100644 (file)
@@ -216,14 +216,15 @@ func NewCommStmt(pos src.XPos, comm Node, body []Node) *CommClause {
 // A ForStmt is a non-range for loop: for Init; Cond; Post { Body }
 type ForStmt struct {
        miniStmt
-       Label    *types.Sym
-       Cond     Node
-       Post     Node
-       Body     Nodes
-       HasBreak bool
+       Label        *types.Sym
+       Cond         Node
+       Post         Node
+       Body         Nodes
+       HasBreak     bool
+       DistinctVars bool
 }
 
-func NewForStmt(pos src.XPos, init Node, cond, post Node, body []Node) *ForStmt {
+func NewForStmt(pos src.XPos, init Node, cond, post Node, body []Node, distinctVars bool) *ForStmt {
        n := &ForStmt{Cond: cond, Post: post}
        n.pos = pos
        n.op = OFOR
@@ -231,6 +232,7 @@ func NewForStmt(pos src.XPos, init Node, cond, post Node, body []Node) *ForStmt
                n.init = []Node{init}
        }
        n.Body = body
+       n.DistinctVars = distinctVars
        return n
 }
 
@@ -341,15 +343,16 @@ func (n *LabelStmt) Sym() *types.Sym { return n.Label }
 // A RangeStmt is a range loop: for Key, Value = range X { Body }
 type RangeStmt struct {
        miniStmt
-       Label    *types.Sym
-       Def      bool
-       X        Node
-       RType    Node `mknode:"-"` // see reflectdata/helpers.go
-       Key      Node
-       Value    Node
-       Body     Nodes
-       HasBreak bool
-       Prealloc *Name
+       Label        *types.Sym
+       Def          bool
+       X            Node
+       RType        Node `mknode:"-"` // see reflectdata/helpers.go
+       Key          Node
+       Value        Node
+       Body         Nodes
+       HasBreak     bool
+       DistinctVars bool
+       Prealloc     *Name
 
        // When desugaring the RangeStmt during walk, the assignments to Key
        // and Value may require OCONVIFACE operations. If so, these fields
@@ -360,11 +363,12 @@ type RangeStmt struct {
        ValueSrcRType Node `mknode:"-"`
 }
 
-func NewRangeStmt(pos src.XPos, key, value, x Node, body []Node) *RangeStmt {
+func NewRangeStmt(pos src.XPos, key, value, x Node, body []Node, distinctVars bool) *RangeStmt {
        n := &RangeStmt{X: x, Key: key, Value: value}
        n.pos = pos
        n.op = ORANGE
        n.Body = body
+       n.DistinctVars = distinctVars
        return n
 }
 
index bc288657abc1ff7225561ab8ea47189db444c738..0ecb70570f7476144e9b560c035a47653fb94e42 100644 (file)
@@ -21,10 +21,14 @@ import (
 // subject to this change, that may (once transformed) be heap allocated in the
 // process. (This allows checking after escape analysis to call out any such
 // variables, in case it causes allocation/performance problems).
-
-// For this code, the meaningful debug and hash flag settings
 //
-// base.Debug.LoopVar <= 0 => do not transform
+// The decision to transform loops is normally encoded in the For/Range loop node
+// field DistinctVars but is also dependent on base.LoopVarHash, and some values
+// of base.Debug.LoopVar (which is set per-package).  Decisions encoded in DistinctVars
+// are preserved across inlining, so if package a calls b.F and loops in b.F are
+// transformed, then they are always transformed, whether b.F is inlined or not.
+//
+// Per-package, the debug flag settings that affect this transformer:
 //
 // base.LoopVarHash != nil => use hash setting to govern transformation.
 // note that LoopVarHash != nil sets base.Debug.LoopVar to 1 (unless it is >= 11, for testing/debugging).
@@ -32,13 +36,7 @@ import (
 // base.Debug.LoopVar == 11 => transform ALL loops ignoring syntactic/potential escape. Do not log, can be in addition to GOEXPERIMENT.
 //
 // The effect of GOEXPERIMENT=loopvar is to change the default value (0) of base.Debug.LoopVar to 1 for all packages.
-
 func ForCapture(fn *ir.Func) []*ir.Name {
-       if base.Debug.LoopVar <= 0 { // code in base:flags.go ensures >= 1 if loopvarhash != ""
-               // TODO remove this when the transformation is made sensitive to inlining; this is least-risk for 1.21
-               return nil
-       }
-
        // if a loop variable is transformed it is appended to this slice for later logging
        var transformed []*ir.Name
 
@@ -86,7 +84,7 @@ func ForCapture(fn *ir.Func) []*ir.Name {
                }
 
                // scanChildrenThenTransform processes node x to:
-               //  1. if x is a for/range, note declared iteration variables possiblyLeaked (PL)
+               //  1. if x is a for/range w/ DistinctVars, note declared iteration variables possiblyLeaked (PL)
                //  2. search all of x's children for syntactically escaping references to v in PL,
                //     meaning either address-of-v or v-captured-by-a-closure
                //  3. for all v in PL that had a syntactically escaping reference, transform the declaration
@@ -122,7 +120,9 @@ func ForCapture(fn *ir.Func) []*ir.Name {
                                }
 
                        case *ir.RangeStmt:
-                               if !x.Def {
+                               if !(x.Def && x.DistinctVars) {
+                                       // range loop must define its iteration variables AND have distinctVars.
+                                       x.DistinctVars = false
                                        break
                                }
                                noteMayLeak(x.Key)
@@ -130,9 +130,13 @@ func ForCapture(fn *ir.Func) []*ir.Name {
                                ir.DoChildren(n, scanChildrenThenTransform)
                                x.Key = maybeReplaceVar(x.Key, x)
                                x.Value = maybeReplaceVar(x.Value, x)
+                               x.DistinctVars = false
                                return false
 
                        case *ir.ForStmt:
+                               if !x.DistinctVars {
+                                       break
+                               }
                                forAllDefInInit(x, noteMayLeak)
                                ir.DoChildren(n, scanChildrenThenTransform)
                                var leaked []*ir.Name
@@ -335,6 +339,7 @@ func ForCapture(fn *ir.Func) []*ir.Name {
                                        // (11) post' = {}
                                        x.Post = nil
                                }
+                               x.DistinctVars = false
 
                                return false
                        }
index 6ff4fc9e22f859e3bca8a17fe93927d4ebc76093..6f4e73bb271deed3d38225bae580959187a548cf 100644 (file)
@@ -102,6 +102,63 @@ func TestLoopVar(t *testing.T) {
        }
 }
 
+func TestLoopVarInlines(t *testing.T) {
+       switch runtime.GOOS {
+       case "linux", "darwin":
+       default:
+               t.Skipf("Slow test, usually avoid it, os=%s not linux or darwin", runtime.GOOS)
+       }
+       switch runtime.GOARCH {
+       case "amd64", "arm64":
+       default:
+               t.Skipf("Slow test, usually avoid it, arch=%s not amd64 or arm64", runtime.GOARCH)
+       }
+
+       testenv.MustHaveGoBuild(t)
+       gocmd := testenv.GoToolPath(t)
+       tmpdir := t.TempDir()
+
+       root := "cmd/compile/internal/loopvar/testdata/inlines"
+
+       f := func(pkg string) string {
+               // This disables the loopvar change, except for the specified package.
+               // The effect should follow the package, even though everything (except "c")
+               // is inlined.
+               cmd := testenv.Command(t, gocmd, "run", "-gcflags="+pkg+"=-d=loopvar=1", root)
+               cmd.Env = append(cmd.Env, "GOEXPERIMENT=noloopvar", "HOME="+tmpdir)
+               cmd.Dir = filepath.Join("testdata", "inlines")
+
+               b, e := cmd.CombinedOutput()
+               if e != nil {
+                       t.Error(e)
+               }
+               return string(b)
+       }
+
+       a := f(root + "/a")
+       b := f(root + "/b")
+       c := f(root + "/c")
+       m := f(root)
+
+       t.Logf(a)
+       t.Logf(b)
+       t.Logf(c)
+       t.Logf(m)
+
+       if !strings.Contains(a, "f, af, bf, abf, cf sums = 100, 45, 100, 100, 100") {
+               t.Errorf("Did not see expected value of a")
+       }
+       if !strings.Contains(b, "f, af, bf, abf, cf sums = 100, 100, 45, 45, 100") {
+               t.Errorf("Did not see expected value of b")
+       }
+       if !strings.Contains(c, "f, af, bf, abf, cf sums = 100, 100, 100, 100, 45") {
+               t.Errorf("Did not see expected value of c")
+       }
+       if !strings.Contains(m, "f, af, bf, abf, cf sums = 45, 100, 100, 100, 100") {
+               t.Errorf("Did not see expected value of m")
+       }
+}
+
 func TestLoopVarHashes(t *testing.T) {
        switch runtime.GOOS {
        case "linux", "darwin":
@@ -148,5 +205,4 @@ func TestLoopVarHashes(t *testing.T) {
        if !strings.Contains(m, ", 100, 100, 100, 100") {
                t.Errorf("Did not see expected value of m run")
        }
-
 }
index 64173312acfe7b6e2f04bcdfab03246af58685c8..6098c92ac9cf3c12127929a67a76d3cf5f93ea98 100644 (file)
@@ -1857,7 +1857,7 @@ func (r *reader) forStmt(label *types.Sym) ir.Node {
 
        if r.Bool() {
                pos := r.pos()
-               rang := ir.NewRangeStmt(pos, nil, nil, nil, nil)
+               rang := ir.NewRangeStmt(pos, nil, nil, nil, nil, false)
                rang.Label = label
 
                names, lhs := r.assignList()
@@ -1881,6 +1881,7 @@ func (r *reader) forStmt(label *types.Sym) ir.Node {
                }
 
                rang.Body = r.blockStmt()
+               rang.DistinctVars = r.Bool()
                r.closeAnotherScope()
 
                return rang
@@ -1891,9 +1892,10 @@ func (r *reader) forStmt(label *types.Sym) ir.Node {
        cond := r.optExpr()
        post := r.stmt()
        body := r.blockStmt()
+       dv := r.Bool()
        r.closeAnotherScope()
 
-       stmt := ir.NewForStmt(pos, init, cond, post, body)
+       stmt := ir.NewForStmt(pos, init, cond, post, body, dv)
        stmt.Label = label
        return stmt
 }
index 5dd8d1de2df1b510490493943899e25fdec40d50..97862938eea2379c5ee1833186a384be102d0de9 100644 (file)
@@ -1445,6 +1445,7 @@ func (w *writer) forStmt(stmt *syntax.ForStmt) {
        }
 
        w.blockStmt(stmt.Body)
+       w.Bool(base.Debug.LoopVar > 0)
        w.closeAnotherScope()
 }
 
index a2ba1a2bbec3b3648df7398906a55c9e777967cf..10240b2f1fb120833fd183c2c392a41403b35c71 100644 (file)
@@ -166,7 +166,7 @@ func hashFunc(t *types.Type) *ir.Func {
                init := ir.NewAssignStmt(base.Pos, ni, ir.NewInt(base.Pos, 0))
                cond := ir.NewBinaryExpr(base.Pos, ir.OLT, ni, ir.NewInt(base.Pos, t.NumElem()))
                post := ir.NewAssignStmt(base.Pos, ni, ir.NewBinaryExpr(base.Pos, ir.OADD, ni, ir.NewInt(base.Pos, 1)))
-               loop := ir.NewForStmt(base.Pos, nil, cond, post, nil)
+               loop := ir.NewForStmt(base.Pos, nil, cond, post, nil, false)
                loop.PtrInit().Append(init)
 
                // h = hashel(&p[i], h)
@@ -442,7 +442,7 @@ func eqFunc(t *types.Type) *ir.Func {
                                i := typecheck.Temp(types.Types[types.TINT])
                                init := ir.NewAssignStmt(base.Pos, i, ir.NewInt(base.Pos, 0))
                                cond := ir.NewBinaryExpr(base.Pos, ir.OLT, i, ir.NewInt(base.Pos, iterateTo))
-                               loop := ir.NewForStmt(base.Pos, nil, cond, nil, nil)
+                               loop := ir.NewForStmt(base.Pos, nil, cond, nil, nil, false)
                                loop.PtrInit().Append(init)
 
                                // if eq(p[i+0], q[i+0]) && eq(p[i+1], q[i+1]) && ... && eq(p[i+unroll-1], q[i+unroll-1]) {
index 6831da69081cd731fd00b3308fd291fc75f56844..48d5e34f46bb86554275dabe4d3a5327df337388 100644 (file)
@@ -1796,6 +1796,7 @@ func (s *state) stmt(n ir.Node) {
                // OFOR: for Ninit; Left; Right { Nbody }
                // cond (Left); body (Nbody); incr (Right)
                n := n.(*ir.ForStmt)
+               base.Assert(!n.DistinctVars) // Should all be rewritten before escape analysis
                bCond := s.f.NewBlock(ssa.BlockPlain)
                bBody := s.f.NewBlock(ssa.BlockPlain)
                bIncr := s.f.NewBlock(ssa.BlockPlain)
index 0a8ce65a1690ccb82b53d8383ecdfe170073eb16..6330530aa4ba42162d0a1621eb6c638a30dcba36 100644 (file)
@@ -484,7 +484,7 @@ func maplit(n *ir.CompLitExpr, m ir.Node, init *ir.Nodes) {
                body = typecheck.Stmt(body)
                body = orderStmtInPlace(body, map[string][]*ir.Name{})
 
-               loop := ir.NewForStmt(base.Pos, nil, cond, incr, nil)
+               loop := ir.NewForStmt(base.Pos, nil, cond, incr, nil, false)
                loop.Body = []ir.Node{body}
                loop.SetInit([]ir.Node{zero})
 
index ae2d9c250bb9f4a2823e5ea8754c283f71a9d9ff..e20ffc2a619b2b5d1a360a5c052f9c3d05148e99 100644 (file)
@@ -38,11 +38,12 @@ func cheapComputableIndex(width int64) bool {
 // Node n may also be modified in place, and may also be
 // the returned node.
 func walkRange(nrange *ir.RangeStmt) ir.Node {
+       base.Assert(!nrange.DistinctVars) // Should all be rewritten before escape analysis
        if isMapClear(nrange) {
                return mapRangeClear(nrange)
        }
 
-       nfor := ir.NewForStmt(nrange.Pos(), nil, nil, nil, nil)
+       nfor := ir.NewForStmt(nrange.Pos(), nil, nil, nil, nil, nrange.DistinctVars)
        nfor.SetInit(nrange.Init())
        nfor.Label = nrange.Label