"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/inspector"
"golang.org/x/tools/go/types/typeutil"
+ "golang.org/x/tools/internal/analysisinternal"
)
const Doc = `check references to loop variables from within nested functions
-This analyzer checks for references to loop variables from within a
-function literal inside the loop body. It checks only instances where
-the function literal is called in a defer or go statement that is the
-last statement in the loop body, as otherwise we would need whole
-program analysis.
+This analyzer checks for references to loop variables from within a function
+literal inside the loop body. It checks for patterns where access to a loop
+variable is known to escape the current loop iteration:
+ 1. a call to go or defer at the end of the loop body
+ 2. a call to golang.org/x/sync/errgroup.Group.Go at the end of the loop body
+
+The analyzer only considers references in the last statement of the loop body
+as it is not deep enough to understand the effects of subsequent statements
+which might render the reference benign.
For example:
See: https://golang.org/doc/go_faq.html#closures_and_goroutines`
+// TODO(rfindley): enable support for checking parallel subtests, pending
+// investigation, adding:
+// 3. a call testing.T.Run where the subtest body invokes t.Parallel()
+
var Analyzer = &analysis.Analyzer{
Name: "loopclosure",
Doc: Doc,
}
inspect.Preorder(nodeFilter, func(n ast.Node) {
// Find the variables updated by the loop statement.
- var vars []*ast.Ident
+ var vars []types.Object
addVar := func(expr ast.Expr) {
- if id, ok := expr.(*ast.Ident); ok {
- vars = append(vars, id)
+ if id, _ := expr.(*ast.Ident); id != nil {
+ if obj := pass.TypesInfo.ObjectOf(id); obj != nil {
+ vars = append(vars, obj)
+ }
}
}
var body *ast.BlockStmt
return
}
- // Inspect a go or defer statement
- // if it's the last one in the loop body.
- // (We give up if there are following statements,
- // because it's hard to prove go isn't followed by wait,
- // or defer by return.)
- if len(body.List) == 0 {
- return
- }
- // The function invoked in the last return statement.
- var fun ast.Expr
- switch s := body.List[len(body.List)-1].(type) {
- case *ast.GoStmt:
- fun = s.Call.Fun
- case *ast.DeferStmt:
- fun = s.Call.Fun
- case *ast.ExprStmt: // check for errgroup.Group.Go()
- if call, ok := s.X.(*ast.CallExpr); ok {
- fun = goInvokes(pass.TypesInfo, call)
- }
- }
- lit, ok := fun.(*ast.FuncLit)
- if !ok {
- return
- }
- ast.Inspect(lit.Body, func(n ast.Node) bool {
- id, ok := n.(*ast.Ident)
- if !ok || id.Obj == nil {
- return true
+ // Inspect statements to find function literals that may be run outside of
+ // the current loop iteration.
+ //
+ // For go, defer, and errgroup.Group.Go, we ignore all but the last
+ // statement, because it's hard to prove go isn't followed by wait, or
+ // defer by return.
+ //
+ // We consider every t.Run statement in the loop body, because there is
+ // no such commonly used mechanism for synchronizing parallel subtests.
+ // It is of course theoretically possible to synchronize parallel subtests,
+ // though such a pattern is likely to be exceedingly rare as it would be
+ // fighting against the test runner.
+ lastStmt := len(body.List) - 1
+ for i, s := range body.List {
+ var fun ast.Expr // if non-nil, a function that escapes the loop iteration
+ switch s := s.(type) {
+ case *ast.GoStmt:
+ if i == lastStmt {
+ fun = s.Call.Fun
+ }
+
+ case *ast.DeferStmt:
+ if i == lastStmt {
+ fun = s.Call.Fun
+ }
+
+ case *ast.ExprStmt: // check for errgroup.Group.Go and testing.T.Run (with T.Parallel)
+ if call, ok := s.X.(*ast.CallExpr); ok {
+ if i == lastStmt {
+ fun = goInvoke(pass.TypesInfo, call)
+ }
+ if fun == nil && analysisinternal.LoopclosureParallelSubtests {
+ fun = parallelSubtest(pass.TypesInfo, call)
+ }
+ }
}
- if pass.TypesInfo.Types[id].Type == nil {
- // Not referring to a variable (e.g. struct field name)
- return true
+
+ lit, ok := fun.(*ast.FuncLit)
+ if !ok {
+ continue
}
- for _, v := range vars {
- if v.Obj == id.Obj {
- pass.ReportRangef(id, "loop variable %s captured by func literal",
- id.Name)
+
+ ast.Inspect(lit.Body, func(n ast.Node) bool {
+ id, ok := n.(*ast.Ident)
+ if !ok {
+ return true
}
- }
- return true
- })
+ obj := pass.TypesInfo.Uses[id]
+ if obj == nil {
+ return true
+ }
+ for _, v := range vars {
+ if v == obj {
+ pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name)
+ }
+ }
+ return true
+ })
+ }
})
return nil, nil
}
-// goInvokes returns a function expression that would be called asynchronously
+// goInvoke returns a function expression that would be called asynchronously
// (but not awaited) in another goroutine as a consequence of the call.
// For example, given the g.Go call below, it returns the function literal expression.
//
// g.Go(func() error { ... })
//
// Currently only "golang.org/x/sync/errgroup.Group()" is considered.
-func goInvokes(info *types.Info, call *ast.CallExpr) ast.Expr {
- f := typeutil.StaticCallee(info, call)
- // Note: Currently only supports: golang.org/x/sync/errgroup.Go.
- if f == nil || f.Name() != "Go" {
+func goInvoke(info *types.Info, call *ast.CallExpr) ast.Expr {
+ if !isMethodCall(info, call, "golang.org/x/sync/errgroup", "Group", "Go") {
return nil
}
- recv := f.Type().(*types.Signature).Recv()
- if recv == nil {
+ return call.Args[0]
+}
+
+// parallelSubtest returns a function expression that would be called
+// asynchronously via the go test runner, as t.Run has been invoked with a
+// function literal that calls t.Parallel.
+//
+// import "testing"
+//
+// func TestFoo(t *testing.T) {
+// tests := []int{0, 1, 2}
+// for i, t := range tests {
+// t.Run("subtest", func(t *testing.T) {
+// t.Parallel()
+// println(i, t)
+// })
+// }
+// }
+func parallelSubtest(info *types.Info, call *ast.CallExpr) ast.Expr {
+ if !isMethodCall(info, call, "testing", "T", "Run") {
return nil
}
- rtype, ok := recv.Type().(*types.Pointer)
+
+ lit, ok := call.Args[1].(*ast.FuncLit)
if !ok {
return nil
}
- named, ok := rtype.Elem().(*types.Named)
+
+ for _, stmt := range lit.Body.List {
+ exprStmt, ok := stmt.(*ast.ExprStmt)
+ if !ok {
+ continue
+ }
+ if isMethodCall(info, exprStmt.X, "testing", "T", "Parallel") {
+ return lit
+ }
+ }
+
+ return nil
+}
+
+// isMethodCall reports whether expr is a method call of
+// <pkgPath>.<typeName>.<method>.
+func isMethodCall(info *types.Info, expr ast.Expr, pkgPath, typeName, method string) bool {
+ call, ok := expr.(*ast.CallExpr)
if !ok {
- return nil
+ return false
}
- if named.Obj().Name() != "Group" {
- return nil
+
+ // Check that we are calling a method <method>
+ f := typeutil.StaticCallee(info, call)
+ if f == nil || f.Name() != method {
+ return false
+ }
+ recv := f.Type().(*types.Signature).Recv()
+ if recv == nil {
+ return false
+ }
+
+ // Check that the receiver is a <pkgPath>.<typeName> or
+ // *<pkgPath>.<typeName>.
+ rtype := recv.Type()
+ if ptr, ok := recv.Type().(*types.Pointer); ok {
+ rtype = ptr.Elem()
+ }
+ named, ok := rtype.(*types.Named)
+ if !ok {
+ return false
+ }
+ if named.Obj().Name() != typeName {
+ return false
}
pkg := f.Pkg()
if pkg == nil {
- return nil
+ return false
}
- if pkg.Path() != "golang.org/x/sync/errgroup" {
- return nil
+ if pkg.Path() != pkgPath {
+ return false
}
- return call.Args[0]
+
+ return true
}
argNum := firstArg
maxArgNum := firstArg
anyIndex := false
- anyW := false
for i, w := 0, 0; i < len(format); i += w {
w = 1
if format[i] != '%' {
pass.Reportf(call.Pos(), "%s does not support error-wrapping directive %%w", state.name)
return
}
- if anyW {
- pass.Reportf(call.Pos(), "%s call has more than one error-wrapping directive %%w", state.name)
- return
- }
- anyW = true
}
if len(state.argNums) > 0 {
// Continue with the next sequential argument.
s.scanNum()
ok := true
if s.nbytes == len(s.format) || s.nbytes == start || s.format[s.nbytes] != ']' {
- ok = false
- s.nbytes = strings.Index(s.format, "]")
+ ok = false // syntax error is either missing "]" or invalid index.
+ s.nbytes = strings.Index(s.format[start:], "]")
if s.nbytes < 0 {
s.pass.ReportRangef(s.call, "%s format %s is missing closing ]", s.name, s.format)
return false
}
+ s.nbytes = s.nbytes + start
}
arg32, err := strconv.ParseInt(s.format[start:s.nbytes], 10, 32)
if err != nil || !ok || arg32 <= 0 || arg32 > int64(len(s.call.Args)-s.firstArg) {
if reason != "" {
details = " (" + reason + ")"
}
- pass.ReportRangef(call, "%s format %s has arg %s of wrong type %s%s", state.name, state.format, analysisutil.Format(pass.Fset, arg), typeString, details)
+ pass.ReportRangef(call, "%s format %s has arg %s of wrong type %s%s, see also https://pkg.go.dev/fmt#hdr-Printing", state.name, state.format, analysisutil.Format(pass.Fset, arg), typeString, details)
return false
}
if v.typ&argString != 0 && v.verb != 'T' && !bytes.Contains(state.flags, []byte{'#'}) {