]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/vet: in rangeloop check, inspect for loop variables too
authorAlan Donovan <adonovan@google.com>
Wed, 21 Dec 2016 15:58:12 +0000 (10:58 -0500)
committerAlan Donovan <adonovan@google.com>
Mon, 2 Oct 2017 17:45:00 +0000 (17:45 +0000)
+ Test.

Change-Id: I42eaea1c716217f7945c008ff4bde6de14df5687
Reviewed-on: https://go-review.googlesource.com/34619
Reviewed-by: Robert Griesemer <gri@golang.org>
src/cmd/vet/main.go
src/cmd/vet/rangeloop.go
src/cmd/vet/testdata/rangeloop.go

index ffe988b9fcef4cf28acf553858e20fb36201c51f..f0309cba944c9db8ce16bbf5590adee6ec5ca84d 100644 (file)
@@ -136,6 +136,7 @@ var (
        callExpr      *ast.CallExpr
        compositeLit  *ast.CompositeLit
        exprStmt      *ast.ExprStmt
+       forStmt       *ast.ForStmt
        funcDecl      *ast.FuncDecl
        funcLit       *ast.FuncLit
        genDecl       *ast.GenDecl
@@ -495,6 +496,8 @@ func (f *File) Visit(node ast.Node) ast.Visitor {
                key = compositeLit
        case *ast.ExprStmt:
                key = exprStmt
+       case *ast.ForStmt:
+               key = forStmt
        case *ast.FuncDecl:
                key = funcDecl
        case *ast.FuncLit:
index e085e21a23e060449552ed71ef665e4d0d86c861..53a41364dfe1b84e79d2d0bdfa4ac25694e7b29b 100644 (file)
@@ -25,27 +25,55 @@ import "go/ast"
 
 func init() {
        register("rangeloops",
-               "check that range loop variables are used correctly",
-               checkRangeLoop,
-               rangeStmt)
+               "check that loop variables are used correctly",
+               checkLoop,
+               rangeStmt, forStmt)
 }
 
-// checkRangeLoop walks the body of the provided range statement, checking if
+// checkLoop walks the body of the provided loop statement, checking whether
 // its index or value variables are used unsafely inside goroutines or deferred
 // function literals.
-func checkRangeLoop(f *File, node ast.Node) {
-       n := node.(*ast.RangeStmt)
-       key, _ := n.Key.(*ast.Ident)
-       val, _ := n.Value.(*ast.Ident)
-       if key == nil && val == nil {
+func checkLoop(f *File, node ast.Node) {
+       // Find the variables updated by the loop statement.
+       var vars []*ast.Ident
+       addVar := func(expr ast.Expr) {
+               if id, ok := expr.(*ast.Ident); ok {
+                       vars = append(vars, id)
+               }
+       }
+       var body *ast.BlockStmt
+       switch n := node.(type) {
+       case *ast.RangeStmt:
+               body = n.Body
+               addVar(n.Key)
+               addVar(n.Value)
+       case *ast.ForStmt:
+               body = n.Body
+               switch post := n.Post.(type) {
+               case *ast.AssignStmt:
+                       // e.g. for p = head; p != nil; p = p.next
+                       for _, lhs := range post.Lhs {
+                               addVar(lhs)
+                       }
+               case *ast.IncDecStmt:
+                       // e.g. for i := 0; i < n; i++
+                       addVar(post.X)
+               }
+       }
+       if vars == nil {
                return
        }
-       sl := n.Body.List
-       if len(sl) == 0 {
+
+       // Inspect a go or defer statement
+       // if it's the last one in the loop body.
+       // (We give up if there are following statements,
+       // because it's hard to prove go isn't followed by wait,
+       // or defer by return.)
+       if len(body.List) == 0 {
                return
        }
        var last *ast.CallExpr
-       switch s := sl[len(sl)-1].(type) {
+       switch s := body.List[len(body.List)-1].(type) {
        case *ast.GoStmt:
                last = s.Call
        case *ast.DeferStmt:
@@ -63,11 +91,14 @@ func checkRangeLoop(f *File, node ast.Node) {
                        return true
                }
                if f.pkg.types[id].Type == nil {
-                       // Not referring to a variable
+                       // Not referring to a variable (e.g. struct field name)
                        return true
                }
-               if key != nil && id.Obj == key.Obj || val != nil && id.Obj == val.Obj {
-                       f.Bad(id.Pos(), "range variable", id.Name, "captured by func literal")
+               for _, v := range vars {
+                       if v.Obj == id.Obj {
+                               f.Badf(id.Pos(), "loop variable %s captured by func literal",
+                                       id.Name)
+                       }
                }
                return true
        })
index 66223aad7170b8944e2fd1628b4c88deef509f5b..cd3b4cbc452231cda898b0a17bd9d4afdd774dc3 100644 (file)
@@ -10,24 +10,24 @@ func RangeLoopTests() {
        var s []int
        for i, v := range s {
                go func() {
-                       println(i) // ERROR "range variable i captured by func literal"
-                       println(v) // ERROR "range variable v captured by func literal"
+                       println(i) // ERROR "loop variable i captured by func literal"
+                       println(v) // ERROR "loop variable v captured by func literal"
                }()
        }
        for i, v := range s {
                defer func() {
-                       println(i) // ERROR "range variable i captured by func literal"
-                       println(v) // ERROR "range variable v captured by func literal"
+                       println(i) // ERROR "loop variable i captured by func literal"
+                       println(v) // ERROR "loop variable v captured by func literal"
                }()
        }
        for i := range s {
                go func() {
-                       println(i) // ERROR "range variable i captured by func literal"
+                       println(i) // ERROR "loop variable i captured by func literal"
                }()
        }
        for _, v := range s {
                go func() {
-                       println(v) // ERROR "range variable v captured by func literal"
+                       println(v) // ERROR "loop variable v captured by func literal"
                }()
        }
        for i, v := range s {
@@ -53,7 +53,7 @@ func RangeLoopTests() {
        var f int
        for x[0], f = range s {
                go func() {
-                       _ = f // ERROR "range variable f captured by func literal"
+                       _ = f // ERROR "loop variable f captured by func literal"
                }()
        }
        type T struct {
@@ -62,7 +62,29 @@ func RangeLoopTests() {
        for _, v := range s {
                go func() {
                        _ = T{v: 1}
-                       _ = []int{v: 1} // ERROR "range variable v captured by func literal"
+                       _ = []int{v: 1} // ERROR "loop variable v captured by func literal"
+               }()
+       }
+
+       // ordinary for-loops
+       for i := 0; i < 10; i++ {
+               go func() {
+                       print(i) // ERROR "loop variable i captured by func literal"
+               }()
+       }
+       for i, j := 0, 1; i < 100; i, j = j, i+j {
+               go func() {
+                       print(j) // ERROR "loop variable j captured by func literal"
+               }()
+       }
+       type cons struct {
+               car int
+               cdr *cons
+       }
+       var head *cons
+       for p := head; p != nil; p = p.next {
+               go func() {
+                       print(p.car) // ERROR "loop variable p captured by func literal"
                }()
        }
 }