]> Cypherpunks repositories - gostls13.git/commitdiff
vet: add range variable misuse detection
authorAndrew Gerrand <adg@golang.org>
Tue, 18 Sep 2012 21:19:31 +0000 (14:19 -0700)
committerAndrew Gerrand <adg@golang.org>
Tue, 18 Sep 2012 21:19:31 +0000 (14:19 -0700)
R=fullung, r, remyoudompheng, minux.ma, gri, rsc
CC=golang-dev
https://golang.org/cl/6494075

src/cmd/vet/Makefile
src/cmd/vet/main.go
src/cmd/vet/rangeloop.go [new file with mode: 0644]

index 2a35d1ae37fc6d9c2b356699d94fe08108034157..d90b5f9d54b2afb1705be42a97cbe29f01ab7857 100644 (file)
@@ -4,4 +4,4 @@
 
 test testshort:
        go build
-       ../../../test/errchk ./vet -printfuncs='Warn:1,Warnf:1' print.go
+       ../../../test/errchk ./vet -printfuncs='Warn:1,Warnf:1' print.go rangeloop.go
index d2a7c6e55b10083e9bb820370196ba4f8ec6d746..76a4896bfa0ed96902d71dad108977cd264c65fb 100644 (file)
@@ -30,6 +30,7 @@ var (
        vetPrintf          = flag.Bool("printf", false, "check printf-like invocations")
        vetStructTags      = flag.Bool("structtags", false, "check that struct field tags have canonical format")
        vetUntaggedLiteral = flag.Bool("composites", false, "check that composite literals used type-tagged elements")
+       vetRangeLoops      = flag.Bool("rangeloops", false, "check that range loop variables are used correctly")
 )
 
 // setExit sets the value for os.Exit when it is called, later.  It
@@ -60,7 +61,7 @@ func main() {
        flag.Parse()
 
        // If a check is named explicitly, turn off the 'all' flag.
-       if *vetMethods || *vetPrintf || *vetStructTags || *vetUntaggedLiteral {
+       if *vetMethods || *vetPrintf || *vetStructTags || *vetUntaggedLiteral || *vetRangeLoops {
                *vetAll = false
        }
 
@@ -197,6 +198,8 @@ func (f *File) Visit(node ast.Node) ast.Visitor {
                f.walkMethodDecl(n)
        case *ast.InterfaceType:
                f.walkInterfaceType(n)
+       case *ast.RangeStmt:
+               f.walkRangeStmt(n)
        }
        return f
 }
@@ -206,6 +209,16 @@ func (f *File) walkCall(call *ast.CallExpr, name string) {
        f.checkFmtPrintfCall(call, name)
 }
 
+// walkCallExpr walks a call expression.
+func (f *File) walkCallExpr(call *ast.CallExpr) {
+       switch x := call.Fun.(type) {
+       case *ast.Ident:
+               f.walkCall(call, x.Name)
+       case *ast.SelectorExpr:
+               f.walkCall(call, x.Sel.Name)
+       }
+}
+
 // walkCompositeLit walks a composite literal.
 func (f *File) walkCompositeLit(c *ast.CompositeLit) {
        f.checkUntaggedLiteral(c)
@@ -242,12 +255,7 @@ func (f *File) walkInterfaceType(t *ast.InterfaceType) {
        }
 }
 
-// walkCallExpr walks a call expression.
-func (f *File) walkCallExpr(call *ast.CallExpr) {
-       switch x := call.Fun.(type) {
-       case *ast.Ident:
-               f.walkCall(call, x.Name)
-       case *ast.SelectorExpr:
-               f.walkCall(call, x.Sel.Name)
-       }
+// walkRangeStmt walks a range statment.
+func (f *File) walkRangeStmt(n *ast.RangeStmt) {
+       checkRangeLoop(f, n)
 }
diff --git a/src/cmd/vet/rangeloop.go b/src/cmd/vet/rangeloop.go
new file mode 100644 (file)
index 0000000..2fdb0b6
--- /dev/null
@@ -0,0 +1,104 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+This file contains the code to check range loop variables bound inside function
+literals that are deferred or launched in new goroutines. We only check
+instances where the defer or go statement is the last statement in the loop
+body, as otherwise we would need whole program analysis.
+
+For example:
+
+       for i, v := range s {
+               go func() {
+                       println(i, v) // not what you might expect
+               }()
+       }
+
+See: http://golang.org/doc/go_faq.html#closures_and_goroutines
+*/
+
+package main
+
+import "go/ast"
+
+// checkRangeLoop walks the body of the provided range statement, checking if
+// its index or value variables are used unsafely inside goroutines or deferred
+// function literals.
+func checkRangeLoop(f *File, n *ast.RangeStmt) {
+       if !*vetRangeLoops && !*vetAll {
+               return
+       }
+       key, _ := n.Key.(*ast.Ident)
+       val, _ := n.Value.(*ast.Ident)
+       if key == nil && val == nil {
+               return
+       }
+       sl := n.Body.List
+       if len(sl) == 0 {
+               return
+       }
+       var last *ast.CallExpr
+       switch s := sl[len(sl)-1].(type) {
+       case *ast.GoStmt:
+               last = s.Call
+       case *ast.DeferStmt:
+               last = s.Call
+       default:
+               return
+       }
+       lit, ok := last.Fun.(*ast.FuncLit)
+       if !ok {
+               return
+       }
+       ast.Inspect(lit.Body, func(n ast.Node) bool {
+               if n, ok := n.(*ast.Ident); ok && n.Obj != nil && (n.Obj == key.Obj || n.Obj == val.Obj) {
+                       f.Warn(n.Pos(), "range variable", n.Name, "enclosed by function")
+               }
+               return true
+       })
+}
+
+func BadRangeLoopsUsedInTests() {
+       var s []int
+       for i, v := range s {
+               go func() {
+                       println(i) // ERROR "range variable i enclosed by function"
+                       println(v) // ERROR "range variable v enclosed by function"
+               }()
+       }
+       for i, v := range s {
+               defer func() {
+                       println(i) // ERROR "range variable i enclosed by function"
+                       println(v) // ERROR "range variable v enclosed by function"
+               }()
+       }
+       for i := range s {
+               go func() {
+                       println(i) // ERROR "range variable i enclosed by function"
+               }()
+       }
+       for _, v := range s {
+               go func() {
+                       println(v) // ERROR "range variable v enclosed by function"
+               }()
+       }
+       for i, v := range s {
+               go func() {
+                       println(i, v)
+               }()
+               println("unfortunately, we don't catch the error above because of this statement")
+       }
+       for i, v := range s {
+               go func(i, v int) {
+                       println(i, v)
+               }(i, v)
+       }
+       for i, v := range s {
+               i, v := i, v
+               go func() {
+                       println(i, v)
+               }()
+       }
+}