]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: restore return-in-loop loopvar optimization
authorDavid Chase <drchase@google.com>
Tue, 28 Feb 2023 22:36:38 +0000 (17:36 -0500)
committerDavid Chase <drchase@google.com>
Thu, 16 Mar 2023 18:46:51 +0000 (18:46 +0000)
but this time, correctly.
children of Returns can have For/Range loops in them,
and those must be visited.

Includes test to verify that the optimization occurs,
and also that the problematic case that broke the original
optimization is now correctly handled.

Change-Id: If5a94fd51c862d4bfb318fec78456b7b202f3fcd
Reviewed-on: https://go-review.googlesource.com/c/go/+/472355
Run-TryBot: David Chase <drchase@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
src/cmd/compile/internal/loopvar/loopvar.go
src/cmd/compile/internal/loopvar/loopvar_test.go
src/cmd/compile/internal/loopvar/testdata/opt.go [new file with mode: 0644]

index 0ecb70570f7476144e9b560c035a47653fb94e42..c92b9d61ea360b72b1c6eded321b2a4eb950a4c6 100644 (file)
@@ -50,6 +50,10 @@ func ForCapture(fn *ir.Func) []*ir.Name {
                // will be transformed.
                possiblyLeaked := make(map[*ir.Name]bool)
 
+               // these enable an optimization of "escape" under return statements
+               loopDepth := 0
+               returnInLoopDepth := 0
+
                // noteMayLeak is called for candidate variables in for range/3-clause, and
                // adds them (mapped to false) to possiblyLeaked.
                noteMayLeak := func(x ir.Node) {
@@ -95,6 +99,11 @@ func ForCapture(fn *ir.Func) []*ir.Name {
                scanChildrenThenTransform = func(n ir.Node) bool {
                        switch x := n.(type) {
                        case *ir.ClosureExpr:
+                               if returnInLoopDepth >= loopDepth {
+                                       // This expression is a child of a return, which escapes all loops above
+                                       // the return, but not those between this expression and the return.
+                                       break
+                               }
                                for _, cv := range x.Func.ClosureVars {
                                        v := cv.Canonical()
                                        if _, ok := possiblyLeaked[v]; ok {
@@ -103,6 +112,11 @@ func ForCapture(fn *ir.Func) []*ir.Name {
                                }
 
                        case *ir.AddrExpr:
+                               if returnInLoopDepth >= loopDepth {
+                                       // This expression is a child of a return, which escapes all loops above
+                                       // the return, but not those between this expression and the return.
+                                       break
+                               }
                                // Explicitly note address-taken so that return-statements can be excluded
                                y := ir.OuterValue(x.X)
                                if y.Op() != ir.ONAME {
@@ -119,6 +133,11 @@ func ForCapture(fn *ir.Func) []*ir.Name {
                                        }
                                }
 
+                       case *ir.ReturnStmt:
+                               savedRILD := returnInLoopDepth
+                               returnInLoopDepth = loopDepth
+                               defer func() { returnInLoopDepth = savedRILD }()
+
                        case *ir.RangeStmt:
                                if !(x.Def && x.DistinctVars) {
                                        // range loop must define its iteration variables AND have distinctVars.
@@ -127,7 +146,9 @@ func ForCapture(fn *ir.Func) []*ir.Name {
                                }
                                noteMayLeak(x.Key)
                                noteMayLeak(x.Value)
+                               loopDepth++
                                ir.DoChildren(n, scanChildrenThenTransform)
+                               loopDepth--
                                x.Key = maybeReplaceVar(x.Key, x)
                                x.Value = maybeReplaceVar(x.Value, x)
                                x.DistinctVars = false
@@ -138,7 +159,9 @@ func ForCapture(fn *ir.Func) []*ir.Name {
                                        break
                                }
                                forAllDefInInit(x, noteMayLeak)
+                               loopDepth++
                                ir.DoChildren(n, scanChildrenThenTransform)
+                               loopDepth--
                                var leaked []*ir.Name
                                // Collect the leaking variables for the much-more-complex transformation.
                                forAllDefInInit(x, func(z ir.Node) {
index 6f4e73bb271deed3d38225bae580959187a548cf..729c240ef549e49cfed002d525bf5567af52a73e 100644 (file)
@@ -206,3 +206,41 @@ func TestLoopVarHashes(t *testing.T) {
                t.Errorf("Did not see expected value of m run")
        }
 }
+
+func TestLoopVarOpt(t *testing.T) {
+       switch runtime.GOOS {
+       case "linux", "darwin":
+       default:
+               t.Skipf("Slow test, usually avoid it, os=%s not linux or darwin", runtime.GOOS)
+       }
+       switch runtime.GOARCH {
+       case "amd64", "arm64":
+       default:
+               t.Skipf("Slow test, usually avoid it, arch=%s not amd64 or arm64", runtime.GOARCH)
+       }
+
+       testenv.MustHaveGoBuild(t)
+       gocmd := testenv.GoToolPath(t)
+
+       cmd := testenv.Command(t, gocmd, "run", "-gcflags=-d=loopvar=2", "opt.go")
+       cmd.Dir = filepath.Join("testdata")
+
+       b, err := cmd.CombinedOutput()
+       m := string(b)
+
+       t.Logf(m)
+
+       yCount := strings.Count(m, "opt.go:16:6: transformed loop variable private escapes (loop inlined into ./opt.go:30)")
+       nCount := strings.Count(m, "shared")
+
+       if yCount != 1 {
+               t.Errorf("yCount=%d != 1", yCount)
+       }
+       if nCount > 0 {
+               t.Errorf("nCount=%d > 0", nCount)
+       }
+       if err != nil {
+               t.Errorf("err=%v != nil", err)
+       }
+
+}
diff --git a/src/cmd/compile/internal/loopvar/testdata/opt.go b/src/cmd/compile/internal/loopvar/testdata/opt.go
new file mode 100644 (file)
index 0000000..1bcd736
--- /dev/null
@@ -0,0 +1,42 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+       "fmt"
+       "os"
+)
+
+var is []func() int
+
+func inline(j, k int) []*int {
+       var a []*int
+       for private := j; private < k; private++ {
+               a = append(a, &private)
+       }
+       return a
+
+}
+
+//go:noinline
+func notinline(j, k int) ([]*int, *int) {
+       for shared := j; shared < k; shared++ {
+               if shared == k/2 {
+                       // want the call inlined, want "private" in that inline to be transformed,
+                       // (believe it ends up on init node of the return).
+                       // but do not want "shared" transformed,
+                       return inline(j, k), &shared
+               }
+       }
+       return nil, &j
+}
+
+func main() {
+       a, p := notinline(2, 9)
+       fmt.Printf("a[0]=%d,*p=%d\n", *a[0], *p)
+       if *a[0] != 2 {
+               os.Exit(1)
+       }
+}