]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: change testing.B.Loop keep alive semantic
authorJunyang Shao <shaojunyang@google.com>
Thu, 30 Oct 2025 19:14:57 +0000 (19:14 +0000)
committerJunyang Shao <shaojunyang@google.com>
Fri, 21 Nov 2025 20:49:20 +0000 (12:49 -0800)
This CL implements this initial design of testing.B.Loop's keep variable
alive semantic:
https://github.com/golang/go/issues/61515#issuecomment-2407963248.

Fixes #73137.

Change-Id: I8060470dbcb0dda0819334f3615cc391ff0f6501
Reviewed-on: https://go-review.googlesource.com/c/go/+/716660
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: David Chase <drchase@google.com>
src/cmd/compile/internal/bloop/bloop.go [new file with mode: 0644]
src/cmd/compile/internal/escape/call.go
src/cmd/compile/internal/gc/main.go
src/cmd/compile/internal/inline/interleaved/interleaved.go
src/cmd/compile/internal/ir/expr.go
src/cmd/compile/internal/typecheck/_builtin/runtime.go
src/cmd/compile/internal/typecheck/builtin.go
src/testing/benchmark.go
test/bloop.go [new file with mode: 0644]
test/inline_testingbloop.go [deleted file]

diff --git a/src/cmd/compile/internal/bloop/bloop.go b/src/cmd/compile/internal/bloop/bloop.go
new file mode 100644 (file)
index 0000000..1e7f915
--- /dev/null
@@ -0,0 +1,313 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package bloop
+
+// This file contains support routines for keeping
+// statements alive
+// in such loops (example):
+//
+//     for b.Loop() {
+//             var a, b int
+//             a = 5
+//             b = 6
+//             f(a, b)
+//     }
+//
+// The results of a, b and f(a, b) will be kept alive.
+//
+// Formally, the lhs (if they are [ir.Name]-s) of
+// [ir.AssignStmt], [ir.AssignListStmt],
+// [ir.AssignOpStmt], and the results of [ir.CallExpr]
+// or its args if it doesn't return a value will be kept
+// alive.
+//
+// The keep alive logic is implemented with as wrapping a
+// runtime.KeepAlive around the Name.
+//
+// TODO: currently this is implemented with KeepAlive
+// because it will prevent DSE and DCE which is probably
+// what we want right now. And KeepAlive takes an ssa
+// value instead of a symbol, which is easier to manage.
+// But since KeepAlive's context was mainly in the runtime
+// and GC, should we implement a new intrinsic that lowers
+// to OpVarLive? Peeling out the symbols is a bit tricky
+// and also VarLive seems to assume that there exists a
+// VarDef on the same symbol that dominates it.
+
+import (
+       "cmd/compile/internal/base"
+       "cmd/compile/internal/ir"
+       "cmd/compile/internal/reflectdata"
+       "cmd/compile/internal/typecheck"
+       "cmd/compile/internal/types"
+       "fmt"
+)
+
+// getNameFromNode tries to iteratively peel down the node to
+// get the name.
+func getNameFromNode(n ir.Node) *ir.Name {
+       var ret *ir.Name
+       if n.Op() == ir.ONAME {
+               ret = n.(*ir.Name)
+       } else {
+               // avoid infinite recursion on circular referencing nodes.
+               seen := map[ir.Node]bool{n: true}
+               var findName func(ir.Node) bool
+               findName = func(a ir.Node) bool {
+                       if a.Op() == ir.ONAME {
+                               ret = a.(*ir.Name)
+                               return true
+                       }
+                       if !seen[a] {
+                               seen[a] = true
+                               return ir.DoChildren(a, findName)
+                       }
+                       return false
+               }
+               ir.DoChildren(n, findName)
+       }
+       return ret
+}
+
+// keepAliveAt returns a statement that is either curNode, or a
+// block containing curNode followed by a call to runtime.keepAlive for each
+// ONAME in ns. These calls ensure that names in ns will be live until
+// after curNode's execution.
+func keepAliveAt(ns []*ir.Name, curNode ir.Node) ir.Node {
+       if len(ns) == 0 {
+               return curNode
+       }
+
+       pos := curNode.Pos()
+       calls := []ir.Node{curNode}
+       for _, n := range ns {
+               if n == nil {
+                       continue
+               }
+               if n.Sym() == nil {
+                       continue
+               }
+               if n.Sym().IsBlank() {
+                       continue
+               }
+               arg := ir.NewConvExpr(pos, ir.OCONV, types.Types[types.TINTER], n)
+               if !n.Type().IsInterface() {
+                       srcRType0 := reflectdata.TypePtrAt(pos, n.Type())
+                       arg.TypeWord = srcRType0
+                       arg.SrcRType = srcRType0
+               }
+               callExpr := typecheck.Call(pos,
+                       typecheck.LookupRuntime("KeepAlive"),
+                       []ir.Node{arg}, false).(*ir.CallExpr)
+               callExpr.IsCompilerVarLive = true
+               callExpr.NoInline = true
+               calls = append(calls, callExpr)
+       }
+
+       return ir.NewBlockStmt(pos, calls)
+}
+
+func debugName(name *ir.Name, line string) {
+       if base.Flag.LowerM > 0 {
+               if name.Linksym() != nil {
+                       fmt.Printf("%v: %s will be kept alive\n", line, name.Linksym().Name)
+               } else {
+                       fmt.Printf("%v: expr will be kept alive\n", line)
+               }
+       }
+}
+
+// preserveStmt transforms stmt so that any names defined/assigned within it
+// are used after stmt's execution, preventing their dead code elimination
+// and dead store elimination. The return value is the transformed statement.
+func preserveStmt(curFn *ir.Func, stmt ir.Node) (ret ir.Node) {
+       ret = stmt
+       switch n := stmt.(type) {
+       case *ir.AssignStmt:
+               // Peel down struct and slice indexing to get the names
+               name := getNameFromNode(n.X)
+               if name != nil {
+                       debugName(name, ir.Line(stmt))
+                       ret = keepAliveAt([]*ir.Name{name}, n)
+               }
+       case *ir.AssignListStmt:
+               names := []*ir.Name{}
+               for _, lhs := range n.Lhs {
+                       name := getNameFromNode(lhs)
+                       if name != nil {
+                               debugName(name, ir.Line(stmt))
+                               names = append(names, name)
+                       }
+               }
+               ret = keepAliveAt(names, n)
+       case *ir.AssignOpStmt:
+               name := getNameFromNode(n.X)
+               if name != nil {
+                       debugName(name, ir.Line(stmt))
+                       ret = keepAliveAt([]*ir.Name{name}, n)
+               }
+       case *ir.CallExpr:
+               names := []*ir.Name{}
+               curNode := stmt
+               if n.Fun != nil && n.Fun.Type() != nil && n.Fun.Type().NumResults() != 0 {
+                       // This function's results are not assigned, assign them to
+                       // auto tmps and then keepAliveAt these autos.
+                       // Note: markStmt assumes the context that it's called - this CallExpr is
+                       // not within another OAS2, which is guaranteed by the case above.
+                       results := n.Fun.Type().Results()
+                       lhs := make([]ir.Node, len(results))
+                       for i, res := range results {
+                               tmp := typecheck.TempAt(n.Pos(), curFn, res.Type)
+                               lhs[i] = tmp
+                               names = append(names, tmp)
+                       }
+
+                       // Create an assignment statement.
+                       assign := typecheck.AssignExpr(
+                               ir.NewAssignListStmt(n.Pos(), ir.OAS2, lhs,
+                                       []ir.Node{n})).(*ir.AssignListStmt)
+                       assign.Def = true
+                       curNode = assign
+                       plural := ""
+                       if len(results) > 1 {
+                               plural = "s"
+                       }
+                       if base.Flag.LowerM > 0 {
+                               fmt.Printf("%v: function result%s will be kept alive\n", ir.Line(stmt), plural)
+                       }
+               } else {
+                       // This function probably doesn't return anything, keep its args alive.
+                       argTmps := []ir.Node{}
+                       for i, a := range n.Args {
+                               if name := getNameFromNode(a); name != nil {
+                                       // If they are name, keep them alive directly.
+                                       debugName(name, ir.Line(stmt))
+                                       names = append(names, name)
+                               } else if a.Op() == ir.OSLICELIT {
+                                       // variadic args are encoded as slice literal.
+                                       s := a.(*ir.CompLitExpr)
+                                       ns := []*ir.Name{}
+                                       for i, n := range s.List {
+                                               if name := getNameFromNode(n); name != nil {
+                                                       debugName(name, ir.Line(a))
+                                                       ns = append(ns, name)
+                                               } else {
+                                                       // We need a temporary to save this arg.
+                                                       tmp := typecheck.TempAt(n.Pos(), curFn, n.Type())
+                                                       argTmps = append(argTmps, typecheck.AssignExpr(ir.NewAssignStmt(n.Pos(), tmp, n)))
+                                                       names = append(names, tmp)
+                                                       s.List[i] = tmp
+                                                       if base.Flag.LowerM > 0 {
+                                                               fmt.Printf("%v: function arg will be kept alive\n", ir.Line(n))
+                                                       }
+                                               }
+                                       }
+                                       names = append(names, ns...)
+                               } else {
+                                       // expressions, we need to assign them to temps and change the original arg to reference
+                                       // them.
+                                       tmp := typecheck.TempAt(n.Pos(), curFn, a.Type())
+                                       argTmps = append(argTmps, typecheck.AssignExpr(ir.NewAssignStmt(n.Pos(), tmp, a)))
+                                       names = append(names, tmp)
+                                       n.Args[i] = tmp
+                                       if base.Flag.LowerM > 0 {
+                                               fmt.Printf("%v: function arg will be kept alive\n", ir.Line(stmt))
+                                       }
+                               }
+                       }
+                       if len(argTmps) > 0 {
+                               argTmps = append(argTmps, n)
+                               curNode = ir.NewBlockStmt(n.Pos(), argTmps)
+                       }
+               }
+               ret = keepAliveAt(names, curNode)
+       }
+       return
+}
+
+func preserveStmts(curFn *ir.Func, list ir.Nodes) {
+       for i := range list {
+               list[i] = preserveStmt(curFn, list[i])
+       }
+}
+
+// isTestingBLoop returns true if it matches the node as a
+// testing.(*B).Loop. See issue #61515.
+func isTestingBLoop(t ir.Node) bool {
+       if t.Op() != ir.OFOR {
+               return false
+       }
+       nFor, ok := t.(*ir.ForStmt)
+       if !ok || nFor.Cond == nil || nFor.Cond.Op() != ir.OCALLFUNC {
+               return false
+       }
+       n, ok := nFor.Cond.(*ir.CallExpr)
+       if !ok || n.Fun == nil || n.Fun.Op() != ir.OMETHEXPR {
+               return false
+       }
+       name := ir.MethodExprName(n.Fun)
+       if name == nil {
+               return false
+       }
+       if fSym := name.Sym(); fSym != nil && name.Class == ir.PFUNC && fSym.Pkg != nil &&
+               fSym.Name == "(*B).Loop" && fSym.Pkg.Path == "testing" {
+               // Attempting to match a function call to testing.(*B).Loop
+               return true
+       }
+       return false
+}
+
+type editor struct {
+       inBloop bool
+       curFn   *ir.Func
+}
+
+func (e editor) edit(n ir.Node) ir.Node {
+       e.inBloop = isTestingBLoop(n) || e.inBloop
+       // It's in bloop, mark the stmts with bodies.
+       ir.EditChildren(n, e.edit)
+       if e.inBloop {
+               switch n := n.(type) {
+               case *ir.ForStmt:
+                       preserveStmts(e.curFn, n.Body)
+               case *ir.IfStmt:
+                       preserveStmts(e.curFn, n.Body)
+                       preserveStmts(e.curFn, n.Else)
+               case *ir.BlockStmt:
+                       preserveStmts(e.curFn, n.List)
+               case *ir.CaseClause:
+                       preserveStmts(e.curFn, n.List)
+                       preserveStmts(e.curFn, n.Body)
+               case *ir.CommClause:
+                       preserveStmts(e.curFn, n.Body)
+               }
+       }
+       return n
+}
+
+// BloopWalk performs a walk on all functions in the package
+// if it imports testing and wrap the results of all qualified
+// statements in a runtime.KeepAlive intrinsic call. See package
+// doc for more details.
+//
+//     for b.Loop() {...}
+//
+// loop's body.
+func BloopWalk(pkg *ir.Package) {
+       hasTesting := false
+       for _, i := range pkg.Imports {
+               if i.Path == "testing" {
+                       hasTesting = true
+                       break
+               }
+       }
+       if !hasTesting {
+               return
+       }
+       for _, fn := range pkg.Funcs {
+               e := editor{false, fn}
+               ir.EditChildren(fn, e.edit)
+       }
+}
index f9351de975fdabdca240c515dce38d9ad73b705a..f9d34630346024ec1b89b03adffd9e7128198a21 100644 (file)
@@ -45,6 +45,20 @@ func (e *escape) call(ks []hole, call ir.Node) {
                        fn = ir.StaticCalleeName(v)
                }
 
+               // argumentParam handles escape analysis of assigning a call
+               // argument to its corresponding parameter.
+               argumentParam := func(param *types.Field, arg ir.Node) {
+                       e.rewriteArgument(arg, call, fn)
+                       argument(e.tagHole(ks, fn, param), arg)
+               }
+
+               if call.IsCompilerVarLive {
+                       // Don't escape compiler-inserted KeepAlive.
+                       argumentParam = func(param *types.Field, arg ir.Node) {
+                               argument(e.discardHole(), arg)
+                       }
+               }
+
                fntype := call.Fun.Type()
                if fn != nil {
                        fntype = fn.Type()
@@ -77,13 +91,6 @@ func (e *escape) call(ks []hole, call ir.Node) {
                        recvArg = call.Fun.(*ir.SelectorExpr).X
                }
 
-               // argumentParam handles escape analysis of assigning a call
-               // argument to its corresponding parameter.
-               argumentParam := func(param *types.Field, arg ir.Node) {
-                       e.rewriteArgument(arg, call, fn)
-                       argument(e.tagHole(ks, fn, param), arg)
-               }
-
                // internal/abi.EscapeNonString forces its argument to be on
                // the heap, if it contains a non-string pointer.
                // This is used in hash/maphash.Comparable, where we cannot
index ef6a5d6017c9ec081a0e928ea1057019734b105b..780b08a8722f3d6c32d22d42df8e81a62e2ec01c 100644 (file)
@@ -8,6 +8,7 @@ import (
        "bufio"
        "bytes"
        "cmd/compile/internal/base"
+       "cmd/compile/internal/bloop"
        "cmd/compile/internal/coverage"
        "cmd/compile/internal/deadlocals"
        "cmd/compile/internal/dwarfgen"
@@ -234,6 +235,9 @@ func Main(archInit func(*ssagen.ArchInfo)) {
                }
        }
 
+       // Apply bloop markings.
+       bloop.BloopWalk(typecheck.Target)
+
        // Interleaved devirtualization and inlining.
        base.Timer.Start("fe", "devirtualize-and-inline")
        interleaved.DevirtualizeAndInlinePackage(typecheck.Target, profile)
index c83bbdb718df5633793a71b501ee4f711788da46..80a0cb97df16965aa446b5df417cbfdd2d703bd3 100644 (file)
@@ -254,28 +254,6 @@ func (s *inlClosureState) mark(n ir.Node) ir.Node {
                return n // already visited n.X before wrapping
        }
 
-       if isTestingBLoop(n) {
-               // No inlining nor devirtualization performed on b.Loop body
-               if base.Flag.LowerM > 0 {
-                       fmt.Printf("%v: skip inlining within testing.B.loop for %v\n", ir.Line(n), n)
-               }
-               // We still want to explore inlining opportunities in other parts of ForStmt.
-               nFor, _ := n.(*ir.ForStmt)
-               nForInit := nFor.Init()
-               for i, x := range nForInit {
-                       if x != nil {
-                               nForInit[i] = s.mark(x)
-                       }
-               }
-               if nFor.Cond != nil {
-                       nFor.Cond = s.mark(nFor.Cond)
-               }
-               if nFor.Post != nil {
-                       nFor.Post = s.mark(nFor.Post)
-               }
-               return n
-       }
-
        if p != nil {
                n = p.X // in this case p was copied in from a (marked) inlined function, this is a new unvisited node.
        }
@@ -371,29 +349,3 @@ func match(n ir.Node) bool {
        }
        return false
 }
-
-// isTestingBLoop returns true if it matches the node as a
-// testing.(*B).Loop. See issue #61515.
-func isTestingBLoop(t ir.Node) bool {
-       if t.Op() != ir.OFOR {
-               return false
-       }
-       nFor, ok := t.(*ir.ForStmt)
-       if !ok || nFor.Cond == nil || nFor.Cond.Op() != ir.OCALLFUNC {
-               return false
-       }
-       n, ok := nFor.Cond.(*ir.CallExpr)
-       if !ok || n.Fun == nil || n.Fun.Op() != ir.OMETHEXPR {
-               return false
-       }
-       name := ir.MethodExprName(n.Fun)
-       if name == nil {
-               return false
-       }
-       if fSym := name.Sym(); fSym != nil && name.Class == ir.PFUNC && fSym.Pkg != nil &&
-               fSym.Name == "(*B).Loop" && fSym.Pkg.Path == "testing" {
-               // Attempting to match a function call to testing.(*B).Loop
-               return true
-       }
-       return false
-}
index 6f198f0021695d539904e84cdcf47641553085b0..ca2838d1e31f6f87a45ce60e43a4765c6a6a3c9a 100644 (file)
@@ -193,6 +193,9 @@ type CallExpr struct {
        GoDefer   bool // whether this call is part of a go or defer statement
        NoInline  bool // whether this call must not be inlined
        UseBuf    bool // use stack buffer for backing store (OAPPEND only)
+       // whether it's a runtime.KeepAlive call the compiler generates to
+       // keep a variable alive. See #73137.
+       IsCompilerVarLive bool
 }
 
 func NewCallExpr(pos src.XPos, op Op, fun Node, args []Node) *CallExpr {
index 7988ebf5b93d1385bdb9a52d9587930eac91d8c7..784e7950fd26889f8b30654b03033d4be1d9ba56 100644 (file)
@@ -303,3 +303,6 @@ var loong64HasLSX bool
 var riscv64HasZbb bool
 
 func asanregisterglobals(unsafe.Pointer, uintptr)
+
+// used by testing.B.Loop
+func KeepAlive(interface{})
index ee892856dd93df06c8bc8f216a3f1824449f94e4..9cf55422907f2c94801a14ecb95941a0d0158621 100644 (file)
@@ -249,6 +249,7 @@ var runtimeDecls = [...]struct {
        {"loong64HasLSX", varTag, 6},
        {"riscv64HasZbb", varTag, 6},
        {"asanregisterglobals", funcTag, 136},
+       {"KeepAlive", funcTag, 11},
 }
 
 func runtimeTypes() []*types.Type {
index 437c2ec741b18245b3d4645cbc8747711dbc01ab..fbd82beb8474b4f55a4563f1da603547dea2b542 100644 (file)
@@ -483,12 +483,14 @@ func (b *B) loopSlowPath() bool {
 // the timer so cleanup code is not measured.
 //
 // Within the body of a "for b.Loop() { ... }" loop, arguments to and
-// results from function calls within the loop are kept alive, preventing
-// the compiler from fully optimizing away the loop body. Currently, this is
-// implemented by disabling inlining of functions called in a b.Loop loop.
-// This applies only to calls syntactically between the curly braces of the loop,
-// and the loop condition must be written exactly as "b.Loop()". Optimizations
-// are performed as usual in any functions called by the loop.
+// results from function calls and assignment receivers within the loop are kept
+// alive, preventing the compiler from fully optimizing away the loop body.
+// Currently, this is implemented as a compiler transformation that wraps such
+// variables with a runtime.KeepAlive intrinsic call. The compiler can recursively
+// walk the body of for, if statments, the cases of switch, select statments
+// and bracket-braced blocks. This applies only to statements syntactically between
+// the curly braces of the loop, and the loop condition must be written exactly
+// as "b.Loop()".
 //
 // After Loop returns false, b.N contains the total number of iterations that
 // ran, so the benchmark may use b.N to compute other average metrics.
diff --git a/test/bloop.go b/test/bloop.go
new file mode 100644 (file)
index 0000000..0d2dcba
--- /dev/null
@@ -0,0 +1,51 @@
+// errorcheck -0 -m
+
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Test keeping statements results in testing.B.Loop alive.
+// See issue #61515, #73137.
+
+package foo
+
+import "testing"
+
+func caninline(x int) int { // ERROR "can inline caninline"
+       return x
+}
+
+var something int
+
+func caninlineNoRet(x int) { // ERROR "can inline caninlineNoRet"
+       something = x
+}
+
+func caninlineVariadic(x ...int) { // ERROR "can inline caninlineVariadic" "x does not escape"
+       something = x[0]
+}
+
+func test(b *testing.B, localsink, cond int) { // ERROR "leaking param: b"
+       for i := 0; i < b.N; i++ {
+               caninline(1) // ERROR "inlining call to caninline"
+       }
+       for b.Loop() { // ERROR "inlining call to testing\.\(\*B\)\.Loop"
+               caninline(1)                 // ERROR "inlining call to caninline" "function result will be kept alive" ".* does not escape"
+               caninlineNoRet(1)            // ERROR "inlining call to caninlineNoRet" "function arg will be kept alive" ".* does not escape"
+               caninlineVariadic(1)         // ERROR "inlining call to caninlineVariadic" "function arg will be kept alive" ".* does not escape"
+               caninlineVariadic(localsink) // ERROR "inlining call to caninlineVariadic" "localsink will be kept alive" ".* does not escape"
+               localsink = caninline(1)     // ERROR "inlining call to caninline" "localsink will be kept alive" ".* does not escape"
+               localsink += 5               // ERROR "localsink will be kept alive" ".* does not escape"
+               localsink, cond = 1, 2       // ERROR "localsink will be kept alive" "cond will be kept alive" ".* does not escape"
+               if cond > 0 {
+                       caninline(1) // ERROR "inlining call to caninline" "function result will be kept alive" ".* does not escape"
+               }
+               switch cond {
+               case 2:
+                       caninline(1) // ERROR "inlining call to caninline" "function result will be kept alive" ".* does not escape"
+               }
+               {
+                       caninline(1) // ERROR "inlining call to caninline" "function result will be kept alive" ".* does not escape"
+               }
+       }
+}
diff --git a/test/inline_testingbloop.go b/test/inline_testingbloop.go
deleted file mode 100644 (file)
index 702a652..0000000
+++ /dev/null
@@ -1,37 +0,0 @@
-// errorcheck -0 -m
-
-// Copyright 2024 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Test no inlining of function calls in testing.B.Loop.
-// See issue #61515.
-
-package foo
-
-import "testing"
-
-func caninline(x int) int { // ERROR "can inline caninline"
-       return x
-}
-
-func test(b *testing.B) { // ERROR "leaking param: b"
-       for i := 0; i < b.N; i++ {
-               caninline(1) // ERROR "inlining call to caninline"
-       }
-       for b.Loop() { // ERROR "skip inlining within testing.B.loop" "inlining call to testing\.\(\*B\)\.Loop"
-               caninline(1)
-       }
-       for i := 0; i < b.N; i++ {
-               caninline(1) // ERROR "inlining call to caninline"
-       }
-       for b.Loop() { // ERROR "skip inlining within testing.B.loop" "inlining call to testing\.\(\*B\)\.Loop"
-               caninline(1)
-       }
-       for i := 0; i < b.N; i++ {
-               caninline(1) // ERROR "inlining call to caninline"
-       }
-       for b.Loop() { // ERROR "skip inlining within testing.B.loop" "inlining call to testing\.\(\*B\)\.Loop"
-               caninline(1)
-       }
-}