From 93f1aac246ac93dbd929b0e5e49bcc65f1dd3b57 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 21 Dec 2016 10:58:12 -0500 Subject: [PATCH] cmd/vet: in rangeloop check, inspect for loop variables too + Test. Change-Id: I42eaea1c716217f7945c008ff4bde6de14df5687 Reviewed-on: https://go-review.googlesource.com/34619 Reviewed-by: Robert Griesemer --- src/cmd/vet/main.go | 3 ++ src/cmd/vet/rangeloop.go | 61 +++++++++++++++++++++++-------- src/cmd/vet/testdata/rangeloop.go | 38 +++++++++++++++---- 3 files changed, 79 insertions(+), 23 deletions(-) diff --git a/src/cmd/vet/main.go b/src/cmd/vet/main.go index ffe988b9fc..f0309cba94 100644 --- a/src/cmd/vet/main.go +++ b/src/cmd/vet/main.go @@ -136,6 +136,7 @@ var ( callExpr *ast.CallExpr compositeLit *ast.CompositeLit exprStmt *ast.ExprStmt + forStmt *ast.ForStmt funcDecl *ast.FuncDecl funcLit *ast.FuncLit genDecl *ast.GenDecl @@ -495,6 +496,8 @@ func (f *File) Visit(node ast.Node) ast.Visitor { key = compositeLit case *ast.ExprStmt: key = exprStmt + case *ast.ForStmt: + key = forStmt case *ast.FuncDecl: key = funcDecl case *ast.FuncLit: diff --git a/src/cmd/vet/rangeloop.go b/src/cmd/vet/rangeloop.go index e085e21a23..53a41364df 100644 --- a/src/cmd/vet/rangeloop.go +++ b/src/cmd/vet/rangeloop.go @@ -25,27 +25,55 @@ import "go/ast" func init() { register("rangeloops", - "check that range loop variables are used correctly", - checkRangeLoop, - rangeStmt) + "check that loop variables are used correctly", + checkLoop, + rangeStmt, forStmt) } -// checkRangeLoop walks the body of the provided range statement, checking if +// checkLoop walks the body of the provided loop statement, checking whether // its index or value variables are used unsafely inside goroutines or deferred // function literals. -func checkRangeLoop(f *File, node ast.Node) { - n := node.(*ast.RangeStmt) - key, _ := n.Key.(*ast.Ident) - val, _ := n.Value.(*ast.Ident) - if key == nil && val == nil { +func checkLoop(f *File, node ast.Node) { + // Find the variables updated by the loop statement. + var vars []*ast.Ident + addVar := func(expr ast.Expr) { + if id, ok := expr.(*ast.Ident); ok { + vars = append(vars, id) + } + } + var body *ast.BlockStmt + switch n := node.(type) { + case *ast.RangeStmt: + body = n.Body + addVar(n.Key) + addVar(n.Value) + case *ast.ForStmt: + body = n.Body + switch post := n.Post.(type) { + case *ast.AssignStmt: + // e.g. for p = head; p != nil; p = p.next + for _, lhs := range post.Lhs { + addVar(lhs) + } + case *ast.IncDecStmt: + // e.g. for i := 0; i < n; i++ + addVar(post.X) + } + } + if vars == nil { return } - sl := n.Body.List - if len(sl) == 0 { + + // Inspect a go or defer statement + // if it's the last one in the loop body. + // (We give up if there are following statements, + // because it's hard to prove go isn't followed by wait, + // or defer by return.) + if len(body.List) == 0 { return } var last *ast.CallExpr - switch s := sl[len(sl)-1].(type) { + switch s := body.List[len(body.List)-1].(type) { case *ast.GoStmt: last = s.Call case *ast.DeferStmt: @@ -63,11 +91,14 @@ func checkRangeLoop(f *File, node ast.Node) { return true } if f.pkg.types[id].Type == nil { - // Not referring to a variable + // Not referring to a variable (e.g. struct field name) return true } - if key != nil && id.Obj == key.Obj || val != nil && id.Obj == val.Obj { - f.Bad(id.Pos(), "range variable", id.Name, "captured by func literal") + for _, v := range vars { + if v.Obj == id.Obj { + f.Badf(id.Pos(), "loop variable %s captured by func literal", + id.Name) + } } return true }) diff --git a/src/cmd/vet/testdata/rangeloop.go b/src/cmd/vet/testdata/rangeloop.go index 66223aad71..cd3b4cbc45 100644 --- a/src/cmd/vet/testdata/rangeloop.go +++ b/src/cmd/vet/testdata/rangeloop.go @@ -10,24 +10,24 @@ func RangeLoopTests() { var s []int for i, v := range s { go func() { - println(i) // ERROR "range variable i captured by func literal" - println(v) // ERROR "range variable v captured by func literal" + println(i) // ERROR "loop variable i captured by func literal" + println(v) // ERROR "loop variable v captured by func literal" }() } for i, v := range s { defer func() { - println(i) // ERROR "range variable i captured by func literal" - println(v) // ERROR "range variable v captured by func literal" + println(i) // ERROR "loop variable i captured by func literal" + println(v) // ERROR "loop variable v captured by func literal" }() } for i := range s { go func() { - println(i) // ERROR "range variable i captured by func literal" + println(i) // ERROR "loop variable i captured by func literal" }() } for _, v := range s { go func() { - println(v) // ERROR "range variable v captured by func literal" + println(v) // ERROR "loop variable v captured by func literal" }() } for i, v := range s { @@ -53,7 +53,7 @@ func RangeLoopTests() { var f int for x[0], f = range s { go func() { - _ = f // ERROR "range variable f captured by func literal" + _ = f // ERROR "loop variable f captured by func literal" }() } type T struct { @@ -62,7 +62,29 @@ func RangeLoopTests() { for _, v := range s { go func() { _ = T{v: 1} - _ = []int{v: 1} // ERROR "range variable v captured by func literal" + _ = []int{v: 1} // ERROR "loop variable v captured by func literal" + }() + } + + // ordinary for-loops + for i := 0; i < 10; i++ { + go func() { + print(i) // ERROR "loop variable i captured by func literal" + }() + } + for i, j := 0, 1; i < 100; i, j = j, i+j { + go func() { + print(j) // ERROR "loop variable j captured by func literal" + }() + } + type cons struct { + car int + cdr *cons + } + var head *cons + for p := head; p != nil; p = p.next { + go func() { + print(p.car) // ERROR "loop variable p captured by func literal" }() } } -- 2.48.1