]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/vet: detect misuse of atomic.Add*
authorRodrigo Rafael Monti Kochenburger <divoxx@gmail.com>
Wed, 30 Jan 2013 15:57:11 +0000 (07:57 -0800)
committerRuss Cox <rsc@golang.org>
Wed, 30 Jan 2013 15:57:11 +0000 (07:57 -0800)
Re-assigning the return value of an atomic operation to the same variable being operated is a common mistake:

x = atomic.AddUint64(&x, 1)

Add this check to go vet.

Fixes #4065.

R=dvyukov, golang-dev, remyoudompheng, rsc
CC=golang-dev
https://golang.org/cl/7097048

src/cmd/vet/Makefile
src/cmd/vet/atomic.go [new file with mode: 0644]
src/cmd/vet/main.go

index d90b5f9d54b2afb1705be42a97cbe29f01ab7857..2be9f66426f000092ab1658a8d6b6a8cd97bf4c3 100644 (file)
@@ -4,4 +4,4 @@
 
 test testshort:
        go build
-       ../../../test/errchk ./vet -printfuncs='Warn:1,Warnf:1' print.go rangeloop.go
+       ../../../test/errchk ./vet -printfuncs='Warn:1,Warnf:1' print.go rangeloop.go atomic.go
diff --git a/src/cmd/vet/atomic.go b/src/cmd/vet/atomic.go
new file mode 100644 (file)
index 0000000..7a76e9b
--- /dev/null
@@ -0,0 +1,90 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+       "go/ast"
+       "go/token"
+       "sync/atomic"
+)
+
+// checkAtomicAssignment walks the assignment statement checking for comomon
+// mistaken usage of atomic package, such as: x = atomic.AddUint64(&x, 1)
+func (f *File) checkAtomicAssignment(n *ast.AssignStmt) {
+       if !*vetAtomic && !*vetAll {
+               return
+       }
+
+       if len(n.Lhs) != len(n.Rhs) {
+               return
+       }
+
+       for i, right := range n.Rhs {
+               call, ok := right.(*ast.CallExpr)
+               if !ok {
+                       continue
+               }
+               sel, ok := call.Fun.(*ast.SelectorExpr)
+               if !ok {
+                       continue
+               }
+               pkg, ok := sel.X.(*ast.Ident)
+               if !ok || pkg.Name != "atomic" {
+                       continue
+               }
+
+               switch sel.Sel.Name {
+               case "AddInt32", "AddInt64", "AddUint32", "AddUint64", "AddUintptr":
+                       f.checkAtomicAddAssignment(n.Lhs[i], call)
+               }
+       }
+}
+
+// checkAtomicAddAssignment walks the atomic.Add* method calls checking for assigning the return value
+// to the same variable being used in the operation
+func (f *File) checkAtomicAddAssignment(left ast.Expr, call *ast.CallExpr) {
+       arg := call.Args[0]
+       broken := false
+
+       if uarg, ok := arg.(*ast.UnaryExpr); ok && uarg.Op == token.AND {
+               broken = f.gofmt(left) == f.gofmt(uarg.X)
+       } else if star, ok := left.(*ast.StarExpr); ok {
+               broken = f.gofmt(star.X) == f.gofmt(arg)
+       }
+
+       if broken {
+               f.Warn(left.Pos(), "direct assignment to atomic value")
+       }
+}
+
+type Counter uint64
+
+func BadAtomicAssignmentUsedInTests() {
+       x := uint64(1)
+       x = atomic.AddUint64(&x, 1)        // ERROR "direct assignment to atomic value"
+       _, x = 10, atomic.AddUint64(&x, 1) // ERROR "direct assignment to atomic value"
+       x, _ = atomic.AddUint64(&x, 1), 10 // ERROR "direct assignment to atomic value"
+
+       y := &x
+       *y = atomic.AddUint64(y, 1) // ERROR "direct assignment to atomic value"
+
+       var su struct{ Counter uint64 }
+       su.Counter = atomic.AddUint64(&su.Counter, 1) // ERROR "direct assignment to atomic value"
+       z1 := atomic.AddUint64(&su.Counter, 1)
+       _ = z1 // Avoid err "z declared and not used"
+
+       var sp struct{ Counter *uint64 }
+       *sp.Counter = atomic.AddUint64(sp.Counter, 1) // ERROR "direct assignment to atomic value"
+       z2 := atomic.AddUint64(sp.Counter, 1)
+       _ = z2 // Avoid err "z declared and not used"
+
+       au := []uint64{10, 20}
+       au[0] = atomic.AddUint64(&au[0], 1) // ERROR "direct assignment to atomic value"
+       au[1] = atomic.AddUint64(&au[0], 1)
+
+       ap := []*uint64{&au[0], &au[1]}
+       *ap[0] = atomic.AddUint64(ap[0], 1) // ERROR "direct assignment to atomic value"
+       *ap[1] = atomic.AddUint64(ap[0], 1)
+}
index ec751972cfac68ce8de4f22d23cbc01eec10d0f4..bfab52626813cb788e916c370217444b13f2f2e2 100644 (file)
@@ -12,6 +12,7 @@ import (
        "fmt"
        "go/ast"
        "go/parser"
+       "go/printer"
        "go/token"
        "io"
        "os"
@@ -31,6 +32,7 @@ var (
        vetStructTags      = flag.Bool("structtags", false, "check that struct field tags have canonical format")
        vetUntaggedLiteral = flag.Bool("composites", false, "check that composite literals used type-tagged elements")
        vetRangeLoops      = flag.Bool("rangeloops", false, "check that range loop variables are used correctly")
+       vetAtomic          = flag.Bool("atomic", false, "check for common mistaken usages of the sync/atomic package")
 )
 
 // setExit sets the value for os.Exit when it is called, later.  It
@@ -188,6 +190,8 @@ func (f *File) walkFile(name string, file *ast.File) {
 // Visit implements the ast.Visitor interface.
 func (f *File) Visit(node ast.Node) ast.Visitor {
        switch n := node.(type) {
+       case *ast.AssignStmt:
+               f.walkAssignStmt(n)
        case *ast.CallExpr:
                f.walkCallExpr(n)
        case *ast.CompositeLit:
@@ -204,6 +208,11 @@ func (f *File) Visit(node ast.Node) ast.Visitor {
        return f
 }
 
+// walkCall walks an assignment statement
+func (f *File) walkAssignStmt(stmt *ast.AssignStmt) {
+       f.checkAtomicAssignment(stmt)
+}
+
 // walkCall walks a call expression.
 func (f *File) walkCall(call *ast.CallExpr, name string) {
        f.checkFmtPrintfCall(call, name)
@@ -259,3 +268,10 @@ func (f *File) walkInterfaceType(t *ast.InterfaceType) {
 func (f *File) walkRangeStmt(n *ast.RangeStmt) {
        checkRangeLoop(f, n)
 }
+
+// goFmt returns a string representation of the expression
+func (f *File) gofmt(x ast.Expr) string {
+       f.b.Reset()
+       printer.Fprint(&f.b, f.fset, x)
+       return f.b.String()
+}