]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/vet: -lostcancel: check for discarded result of context.WithCancel
authorAlan Donovan <adonovan@google.com>
Wed, 15 Jun 2016 03:50:39 +0000 (23:50 -0400)
committerAlan Donovan <adonovan@google.com>
Tue, 21 Jun 2016 14:58:33 +0000 (14:58 +0000)
The cfg subpackage builds a control flow graph of ast.Nodes.
The lostcancel module checks this graph to find paths, from a call to
WithCancel to a return statement, on which the cancel variable is
not used.  (A trivial case is simply assigning the cancel result to
the blank identifier.)

In a sample of 50,000 source files, the new check found 2068 blank
assignments and 118 return-before-cancel errors.  I manually inspected
20 of the latter and didn't find a single false positive among them.

Change-Id: I84cd49445f9f8d04908b04881eb1496a96611205
Reviewed-on: https://go-review.googlesource.com/24150
Reviewed-by: Robert Griesemer <gri@golang.org>
Reviewed-by: Rob Pike <r@golang.org>
src/cmd/vet/deadcode.go
src/cmd/vet/doc.go
src/cmd/vet/internal/cfg/builder.go [new file with mode: 0644]
src/cmd/vet/internal/cfg/cfg.go [new file with mode: 0644]
src/cmd/vet/internal/cfg/cfg_test.go [new file with mode: 0644]
src/cmd/vet/lostcancel.go [new file with mode: 0644]
src/cmd/vet/testdata/lostcancel.go [new file with mode: 0644]

index abede47a458b554ae36e53d22652fcfb285423e7..b1077aef38fdc4737a4dcec9586e4afd5db5c91f 100644 (file)
@@ -29,6 +29,8 @@ type deadState struct {
 }
 
 // checkUnreachable checks a function body for dead code.
+//
+// TODO(adonovan): use the new cfg package, which is more precise.
 func checkUnreachable(f *File, node ast.Node) {
        var body *ast.BlockStmt
        switch n := node.(type) {
index c697f3bc36269647a695ff44f78bb3be3f5d3cb9..bb8dcf171f136771398d5e0d006d2f9c9bb0d7ef 100644 (file)
@@ -91,6 +91,15 @@ Flag: -tests
 Mistakes involving tests including functions with incorrect names or signatures
 and example tests that document identifiers not in the package.
 
+Failure to call the cancellation function returned by context.WithCancel.
+
+Flag: -lostcancel
+
+The cancellation function returned by context.WithCancel, WithTimeout,
+and WithDeadline must be called or the new context will remain live
+until its parent context is cancelled.
+(The background context is never cancelled.)
+
 Methods
 
 Flag: -methods
diff --git a/src/cmd/vet/internal/cfg/builder.go b/src/cmd/vet/internal/cfg/builder.go
new file mode 100644 (file)
index 0000000..79c906b
--- /dev/null
@@ -0,0 +1,510 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cfg
+
+// This file implements the CFG construction pass.
+
+import (
+       "fmt"
+       "go/ast"
+       "go/token"
+)
+
+type builder struct {
+       cfg       *CFG
+       mayReturn func(*ast.CallExpr) bool
+       current   *Block
+       lblocks   map[*ast.Object]*lblock // labeled blocks
+       targets   *targets                // linked stack of branch targets
+}
+
+func (b *builder) stmt(_s ast.Stmt) {
+       // The label of the current statement.  If non-nil, its _goto
+       // target is always set; its _break and _continue are set only
+       // within the body of switch/typeswitch/select/for/range.
+       // It is effectively an additional default-nil parameter of stmt().
+       var label *lblock
+start:
+       switch s := _s.(type) {
+       case *ast.BadStmt,
+               *ast.SendStmt,
+               *ast.IncDecStmt,
+               *ast.GoStmt,
+               *ast.DeferStmt,
+               *ast.EmptyStmt,
+               *ast.AssignStmt:
+               // No effect on control flow.
+               b.add(s)
+
+       case *ast.ExprStmt:
+               b.add(s)
+               if call, ok := s.X.(*ast.CallExpr); ok && !b.mayReturn(call) {
+                       // Calls to panic, os.Exit, etc, never return.
+                       b.current = b.newUnreachableBlock("unreachable.call")
+               }
+
+       case *ast.DeclStmt:
+               // Treat each var ValueSpec as a separate statement.
+               d := s.Decl.(*ast.GenDecl)
+               if d.Tok == token.VAR {
+                       for _, spec := range d.Specs {
+                               if spec, ok := spec.(*ast.ValueSpec); ok {
+                                       b.add(spec)
+                               }
+                       }
+               }
+
+       case *ast.LabeledStmt:
+               label = b.labeledBlock(s.Label)
+               b.jump(label._goto)
+               b.current = label._goto
+               _s = s.Stmt
+               goto start // effectively: tailcall stmt(g, s.Stmt, label)
+
+       case *ast.ReturnStmt:
+               b.add(s)
+               b.current = b.newUnreachableBlock("unreachable.return")
+
+       case *ast.BranchStmt:
+               var block *Block
+               switch s.Tok {
+               case token.BREAK:
+                       if s.Label != nil {
+                               if lb := b.labeledBlock(s.Label); lb != nil {
+                                       block = lb._break
+                               }
+                       } else {
+                               for t := b.targets; t != nil && block == nil; t = t.tail {
+                                       block = t._break
+                               }
+                       }
+
+               case token.CONTINUE:
+                       if s.Label != nil {
+                               if lb := b.labeledBlock(s.Label); lb != nil {
+                                       block = lb._continue
+                               }
+                       } else {
+                               for t := b.targets; t != nil && block == nil; t = t.tail {
+                                       block = t._continue
+                               }
+                       }
+
+               case token.FALLTHROUGH:
+                       for t := b.targets; t != nil; t = t.tail {
+                               block = t._fallthrough
+                       }
+
+               case token.GOTO:
+                       block = b.labeledBlock(s.Label)._goto
+               }
+               if block == nil {
+                       block = b.newBlock("undefined.branch")
+               }
+               b.jump(block)
+               b.current = b.newUnreachableBlock("unreachable.branch")
+
+       case *ast.BlockStmt:
+               b.stmtList(s.List)
+
+       case *ast.IfStmt:
+               if s.Init != nil {
+                       b.stmt(s.Init)
+               }
+               then := b.newBlock("if.then")
+               done := b.newBlock("if.done")
+               _else := done
+               if s.Else != nil {
+                       _else = b.newBlock("if.else")
+               }
+               b.add(s.Cond)
+               b.ifelse(then, _else)
+               b.current = then
+               b.stmt(s.Body)
+               b.jump(done)
+
+               if s.Else != nil {
+                       b.current = _else
+                       b.stmt(s.Else)
+                       b.jump(done)
+               }
+
+               b.current = done
+
+       case *ast.SwitchStmt:
+               b.switchStmt(s, label)
+
+       case *ast.TypeSwitchStmt:
+               b.typeSwitchStmt(s, label)
+
+       case *ast.SelectStmt:
+               b.selectStmt(s, label)
+
+       case *ast.ForStmt:
+               b.forStmt(s, label)
+
+       case *ast.RangeStmt:
+               b.rangeStmt(s, label)
+
+       default:
+               panic(fmt.Sprintf("unexpected statement kind: %T", s))
+       }
+}
+
+func (b *builder) stmtList(list []ast.Stmt) {
+       for _, s := range list {
+               b.stmt(s)
+       }
+}
+
+func (b *builder) switchStmt(s *ast.SwitchStmt, label *lblock) {
+       if s.Init != nil {
+               b.stmt(s.Init)
+       }
+       if s.Tag != nil {
+               b.add(s.Tag)
+       }
+       done := b.newBlock("switch.done")
+       if label != nil {
+               label._break = done
+       }
+       // We pull the default case (if present) down to the end.
+       // But each fallthrough label must point to the next
+       // body block in source order, so we preallocate a
+       // body block (fallthru) for the next case.
+       // Unfortunately this makes for a confusing block order.
+       var defaultBody *[]ast.Stmt
+       var defaultFallthrough *Block
+       var fallthru, defaultBlock *Block
+       ncases := len(s.Body.List)
+       for i, clause := range s.Body.List {
+               body := fallthru
+               if body == nil {
+                       body = b.newBlock("switch.body") // first case only
+               }
+
+               // Preallocate body block for the next case.
+               fallthru = done
+               if i+1 < ncases {
+                       fallthru = b.newBlock("switch.body")
+               }
+
+               cc := clause.(*ast.CaseClause)
+               if cc.List == nil {
+                       // Default case.
+                       defaultBody = &cc.Body
+                       defaultFallthrough = fallthru
+                       defaultBlock = body
+                       continue
+               }
+
+               var nextCond *Block
+               for _, cond := range cc.List {
+                       nextCond = b.newBlock("switch.next")
+                       b.add(cond) // one half of the tag==cond condition
+                       b.ifelse(body, nextCond)
+                       b.current = nextCond
+               }
+               b.current = body
+               b.targets = &targets{
+                       tail:         b.targets,
+                       _break:       done,
+                       _fallthrough: fallthru,
+               }
+               b.stmtList(cc.Body)
+               b.targets = b.targets.tail
+               b.jump(done)
+               b.current = nextCond
+       }
+       if defaultBlock != nil {
+               b.jump(defaultBlock)
+               b.current = defaultBlock
+               b.targets = &targets{
+                       tail:         b.targets,
+                       _break:       done,
+                       _fallthrough: defaultFallthrough,
+               }
+               b.stmtList(*defaultBody)
+               b.targets = b.targets.tail
+       }
+       b.jump(done)
+       b.current = done
+}
+
+func (b *builder) typeSwitchStmt(s *ast.TypeSwitchStmt, label *lblock) {
+       if s.Init != nil {
+               b.stmt(s.Init)
+       }
+       if s.Assign != nil {
+               b.add(s.Assign)
+       }
+
+       done := b.newBlock("typeswitch.done")
+       if label != nil {
+               label._break = done
+       }
+       var default_ *ast.CaseClause
+       for _, clause := range s.Body.List {
+               cc := clause.(*ast.CaseClause)
+               if cc.List == nil {
+                       default_ = cc
+                       continue
+               }
+               body := b.newBlock("typeswitch.body")
+               var next *Block
+               for _, casetype := range cc.List {
+                       next = b.newBlock("typeswitch.next")
+                       // casetype is a type, so don't call b.add(casetype).
+                       // This block logically contains a type assertion,
+                       // x.(casetype), but it's unclear how to represent x.
+                       _ = casetype
+                       b.ifelse(body, next)
+                       b.current = next
+               }
+               b.current = body
+               b.typeCaseBody(cc, done)
+               b.current = next
+       }
+       if default_ != nil {
+               b.typeCaseBody(default_, done)
+       } else {
+               b.jump(done)
+       }
+       b.current = done
+}
+
+func (b *builder) typeCaseBody(cc *ast.CaseClause, done *Block) {
+       b.targets = &targets{
+               tail:   b.targets,
+               _break: done,
+       }
+       b.stmtList(cc.Body)
+       b.targets = b.targets.tail
+       b.jump(done)
+}
+
+func (b *builder) selectStmt(s *ast.SelectStmt, label *lblock) {
+       // First evaluate channel expressions.
+       // TODO(adonovan): fix: evaluate only channel exprs here.
+       for _, clause := range s.Body.List {
+               if comm := clause.(*ast.CommClause).Comm; comm != nil {
+                       b.stmt(comm)
+               }
+       }
+
+       done := b.newBlock("select.done")
+       if label != nil {
+               label._break = done
+       }
+
+       var defaultBody *[]ast.Stmt
+       for _, cc := range s.Body.List {
+               clause := cc.(*ast.CommClause)
+               if clause.Comm == nil {
+                       defaultBody = &clause.Body
+                       continue
+               }
+               body := b.newBlock("select.body")
+               next := b.newBlock("select.next")
+               b.ifelse(body, next)
+               b.current = body
+               b.targets = &targets{
+                       tail:   b.targets,
+                       _break: done,
+               }
+               switch comm := clause.Comm.(type) {
+               case *ast.ExprStmt: // <-ch
+                       // nop
+               case *ast.AssignStmt: // x := <-states[state].Chan
+                       b.add(comm.Lhs[0])
+               }
+               b.stmtList(clause.Body)
+               b.targets = b.targets.tail
+               b.jump(done)
+               b.current = next
+       }
+       if defaultBody != nil {
+               b.targets = &targets{
+                       tail:   b.targets,
+                       _break: done,
+               }
+               b.stmtList(*defaultBody)
+               b.targets = b.targets.tail
+               b.jump(done)
+       }
+       b.current = done
+}
+
+func (b *builder) forStmt(s *ast.ForStmt, label *lblock) {
+       //      ...init...
+       //      jump loop
+       // loop:
+       //      if cond goto body else done
+       // body:
+       //      ...body...
+       //      jump post
+       // post:                                 (target of continue)
+       //      ...post...
+       //      jump loop
+       // done:                                 (target of break)
+       if s.Init != nil {
+               b.stmt(s.Init)
+       }
+       body := b.newBlock("for.body")
+       done := b.newBlock("for.done") // target of 'break'
+       loop := body                   // target of back-edge
+       if s.Cond != nil {
+               loop = b.newBlock("for.loop")
+       }
+       cont := loop // target of 'continue'
+       if s.Post != nil {
+               cont = b.newBlock("for.post")
+       }
+       if label != nil {
+               label._break = done
+               label._continue = cont
+       }
+       b.jump(loop)
+       b.current = loop
+       if loop != body {
+               b.add(s.Cond)
+               b.ifelse(body, done)
+               b.current = body
+       }
+       b.targets = &targets{
+               tail:      b.targets,
+               _break:    done,
+               _continue: cont,
+       }
+       b.stmt(s.Body)
+       b.targets = b.targets.tail
+       b.jump(cont)
+
+       if s.Post != nil {
+               b.current = cont
+               b.stmt(s.Post)
+               b.jump(loop) // back-edge
+       }
+       b.current = done
+}
+
+func (b *builder) rangeStmt(s *ast.RangeStmt, label *lblock) {
+       b.add(s.X)
+
+       if s.Key != nil {
+               b.add(s.Key)
+       }
+       if s.Value != nil {
+               b.add(s.Value)
+       }
+
+       //      ...
+       // loop:                                   (target of continue)
+       //      if ... goto body else done
+       // body:
+       //      ...
+       //      jump loop
+       // done:                                   (target of break)
+
+       loop := b.newBlock("range.loop")
+       b.jump(loop)
+       b.current = loop
+
+       body := b.newBlock("range.body")
+       done := b.newBlock("range.done")
+       b.ifelse(body, done)
+       b.current = body
+
+       if label != nil {
+               label._break = done
+               label._continue = loop
+       }
+       b.targets = &targets{
+               tail:      b.targets,
+               _break:    done,
+               _continue: loop,
+       }
+       b.stmt(s.Body)
+       b.targets = b.targets.tail
+       b.jump(loop) // back-edge
+       b.current = done
+}
+
+// -------- helpers --------
+
+// Destinations associated with unlabeled for/switch/select stmts.
+// We push/pop one of these as we enter/leave each construct and for
+// each BranchStmt we scan for the innermost target of the right type.
+//
+type targets struct {
+       tail         *targets // rest of stack
+       _break       *Block
+       _continue    *Block
+       _fallthrough *Block
+}
+
+// Destinations associated with a labeled block.
+// We populate these as labels are encountered in forward gotos or
+// labeled statements.
+//
+type lblock struct {
+       _goto     *Block
+       _break    *Block
+       _continue *Block
+}
+
+// labeledBlock returns the branch target associated with the
+// specified label, creating it if needed.
+//
+func (b *builder) labeledBlock(label *ast.Ident) *lblock {
+       lb := b.lblocks[label.Obj]
+       if lb == nil {
+               lb = &lblock{_goto: b.newBlock(label.Name)}
+               if b.lblocks == nil {
+                       b.lblocks = make(map[*ast.Object]*lblock)
+               }
+               b.lblocks[label.Obj] = lb
+       }
+       return lb
+}
+
+// newBlock appends a new unconnected basic block to b.cfg's block
+// slice and returns it.
+// It does not automatically become the current block.
+// comment is an optional string for more readable debugging output.
+func (b *builder) newBlock(comment string) *Block {
+       g := b.cfg
+       block := &Block{
+               index:   int32(len(g.Blocks)),
+               comment: comment,
+       }
+       block.Succs = block.succs2[:0]
+       g.Blocks = append(g.Blocks, block)
+       return block
+}
+
+func (b *builder) newUnreachableBlock(comment string) *Block {
+       block := b.newBlock(comment)
+       block.unreachable = true
+       return block
+}
+
+func (b *builder) add(n ast.Node) {
+       b.current.Nodes = append(b.current.Nodes, n)
+}
+
+// jump adds an edge from the current block to the target block,
+// and sets b.current to nil.
+func (b *builder) jump(target *Block) {
+       b.current.Succs = append(b.current.Succs, target)
+       b.current = nil
+}
+
+// ifelse emits edges from the current block to the t and f blocks,
+// and sets b.current to nil.
+func (b *builder) ifelse(t, f *Block) {
+       b.current.Succs = append(b.current.Succs, t, f)
+       b.current = nil
+}
diff --git a/src/cmd/vet/internal/cfg/cfg.go b/src/cmd/vet/internal/cfg/cfg.go
new file mode 100644 (file)
index 0000000..e4d5bfe
--- /dev/null
@@ -0,0 +1,142 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This package constructs a simple control-flow graph (CFG) of the
+// statements and expressions within a single function.
+//
+// Use cfg.New to construct the CFG for a function body.
+//
+// The blocks of the CFG contain all the function's non-control
+// statements.  The CFG does not contain control statements such as If,
+// Switch, Select, and Branch, but does contain their subexpressions.
+// For example, this source code:
+//
+//     if x := f(); x != nil {
+//             T()
+//     } else {
+//             F()
+//     }
+//
+// produces this CFG:
+//
+//    1:  x := f()
+//        x != nil
+//        succs: 2, 3
+//    2:  T()
+//        succs: 4
+//    3:  F()
+//        succs: 4
+//    4:
+//
+// The CFG does contain Return statements; even implicit returns are
+// materialized (at the position of the function's closing brace).
+//
+// The CFG does not record conditions associated with conditional branch
+// edges, nor the short-circuit semantics of the && and || operators,
+// nor abnormal control flow caused by panic.  If you need this
+// information, use golang.org/x/tools/go/ssa instead.
+//
+package cfg
+
+// Although the vet tool has type information, it is often extremely
+// fragmentary, so for simplicity this package does not depend on
+// go/types.  Consequently control-flow conditions are ignored even
+// when constant, and "mayReturn" information must be provided by the
+// client.
+import (
+       "bytes"
+       "fmt"
+       "go/ast"
+       "go/format"
+       "go/token"
+)
+
+// A CFG represents the control-flow graph of a single function.
+//
+// The entry point is Blocks[0]; there may be multiple return blocks.
+type CFG struct {
+       Blocks []*Block // block[0] is entry; order otherwise undefined
+}
+
+// A Block represents a basic block: a list of statements and
+// expressions that are always evaluated sequentially.
+//
+// A block may have 0-2 successors: zero for a return block or a block
+// that calls a function such as panic that never returns; one for a
+// normal (jump) block; and two for a conditional (if) block.
+type Block struct {
+       Nodes []ast.Node // statements, expressions, and ValueSpecs
+       Succs []*Block   // successor nodes in the graph
+
+       comment     string    // for debugging
+       index       int32     // index within CFG.Blocks
+       unreachable bool      // is block of stmts following return/panic/for{}
+       succs2      [2]*Block // underlying array for Succs
+}
+
+// New returns a new control-flow graph for the specified function body,
+// which must be non-nil.
+//
+// The CFG builder calls mayReturn to determine whether a given function
+// call may return.  For example, calls to panic, os.Exit, and log.Fatal
+// do not return, so the builder can remove infeasible graph edges
+// following such calls.  The builder calls mayReturn only for a
+// CallExpr beneath an ExprStmt.
+func New(body *ast.BlockStmt, mayReturn func(*ast.CallExpr) bool) *CFG {
+       b := builder{
+               mayReturn: mayReturn,
+               cfg:       new(CFG),
+       }
+       b.current = b.newBlock("entry")
+       b.stmt(body)
+
+       // Does control fall off the end of the function's body?
+       // Make implicit return explicit.
+       if b.current != nil && !b.current.unreachable {
+               b.add(&ast.ReturnStmt{
+                       Return: body.End() - 1,
+               })
+       }
+
+       return b.cfg
+}
+
+func (b *Block) String() string {
+       return fmt.Sprintf("block %d (%s)", b.index, b.comment)
+}
+
+// Return returns the return statement at the end of this block if present, nil otherwise.
+func (b *Block) Return() (ret *ast.ReturnStmt) {
+       if len(b.Nodes) > 0 {
+               ret, _ = b.Nodes[len(b.Nodes)-1].(*ast.ReturnStmt)
+       }
+       return
+}
+
+// Format formats the control-flow graph for ease of debugging.
+func (g *CFG) Format(fset *token.FileSet) string {
+       var buf bytes.Buffer
+       for _, b := range g.Blocks {
+               fmt.Fprintf(&buf, ".%d: # %s\n", b.index, b.comment)
+               for _, n := range b.Nodes {
+                       fmt.Fprintf(&buf, "\t%s\n", formatNode(fset, n))
+               }
+               if len(b.Succs) > 0 {
+                       fmt.Fprintf(&buf, "\tsuccs:")
+                       for _, succ := range b.Succs {
+                               fmt.Fprintf(&buf, " %d", succ.index)
+                       }
+                       buf.WriteByte('\n')
+               }
+               buf.WriteByte('\n')
+       }
+       return buf.String()
+}
+
+func formatNode(fset *token.FileSet, n ast.Node) string {
+       var buf bytes.Buffer
+       format.Node(&buf, fset, n)
+       // Indent secondary lines by a tab.
+       return string(bytes.Replace(buf.Bytes(), []byte("\n"), []byte("\n\t"), -1))
+}
diff --git a/src/cmd/vet/internal/cfg/cfg_test.go b/src/cmd/vet/internal/cfg/cfg_test.go
new file mode 100644 (file)
index 0000000..5d98f13
--- /dev/null
@@ -0,0 +1,184 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cfg
+
+import (
+       "bytes"
+       "fmt"
+       "go/ast"
+       "go/parser"
+       "go/token"
+       "testing"
+)
+
+const src = `package main
+
+import "log"
+
+func f1() {
+       live()
+       return
+       dead()
+}
+
+func f2() {
+       for {
+               live()
+       }
+       dead()
+}
+
+func f3() {
+       if true { // even known values are ignored
+               return
+       }
+       for true { // even known values are ignored
+               live()
+       }
+       for {
+               live()
+       }
+       dead()
+}
+
+func f4(x int) {
+       switch x {
+       case 1:
+               live()
+               fallthrough
+       case 2:
+               live()
+               log.Fatal()
+       default:
+               panic("oops")
+       }
+       dead()
+}
+
+func f4(ch chan int) {
+       select {
+       case <-ch:
+               live()
+               return
+       default:
+               live()
+               panic("oops")
+       }
+       dead()
+}
+
+func f5(unknown bool) {
+       for {
+               if unknown {
+                       break
+               }
+               continue
+               dead()
+       }
+       live()
+}
+
+func f6(unknown bool) {
+outer:
+       for {
+               for {
+                       break outer
+                       dead()
+               }
+               dead()
+       }
+       live()
+}
+
+func f7() {
+       for {
+               break nosuchlabel
+               dead()
+       }
+       dead()
+}
+
+func f8() {
+       select{}
+       dead()
+}
+
+func f9(ch chan int) {
+       select {
+       case <-ch:
+               return
+       }
+       dead()
+}
+
+func f10(ch chan int) {
+       select {
+       case <-ch:
+               return
+               dead()
+       default:
+       }
+       live()
+}
+`
+
+func TestDeadCode(t *testing.T) {
+       // We'll use dead code detection to verify the CFG.
+
+       fset := token.NewFileSet()
+       f, err := parser.ParseFile(fset, "dummy.go", src, parser.Mode(0))
+       if err != nil {
+               t.Fatal(err)
+       }
+       for _, decl := range f.Decls {
+               if decl, ok := decl.(*ast.FuncDecl); ok {
+                       g := New(decl.Body, mayReturn)
+
+                       // Mark blocks reachable from entry.
+                       live := make(map[*Block]bool)
+                       var visit func(*Block)
+                       visit = func(b *Block) {
+                               if !live[b] {
+                                       live[b] = true
+                                       for _, succ := range b.Succs {
+                                               visit(succ)
+                                       }
+                               }
+                       }
+                       visit(g.Blocks[0])
+
+                       // Print statements in unreachable blocks
+                       // (in order determined by builder).
+                       var buf bytes.Buffer
+                       for _, b := range g.Blocks {
+                               if !live[b] {
+                                       for _, n := range b.Nodes {
+                                               fmt.Fprintf(&buf, "\t%s\n", formatNode(fset, n))
+                                       }
+                               }
+                       }
+
+                       // Check that the result contains "dead" at least once but not "live".
+                       if !bytes.Contains(buf.Bytes(), []byte("dead")) ||
+                               bytes.Contains(buf.Bytes(), []byte("live")) {
+                               t.Errorf("unexpected dead statements in function %s:\n%s",
+                                       decl.Name.Name,
+                                       &buf)
+                               t.Logf("control flow graph:\n%s", g.Format(fset))
+                       }
+               }
+       }
+}
+
+// A trivial mayReturn predicate that looks only at syntax, not types.
+func mayReturn(call *ast.CallExpr) bool {
+       switch fun := call.Fun.(type) {
+       case *ast.Ident:
+               return fun.Name != "panic"
+       case *ast.SelectorExpr:
+               return fun.Sel.Name != "Fatal"
+       }
+       return true
+}
diff --git a/src/cmd/vet/lostcancel.go b/src/cmd/vet/lostcancel.go
new file mode 100644 (file)
index 0000000..708b6f3
--- /dev/null
@@ -0,0 +1,296 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+       "cmd/vet/internal/cfg"
+       "fmt"
+       "go/ast"
+       "go/types"
+       "strconv"
+)
+
+func init() {
+       register("lostcancel",
+               "check for failure to call cancellation function returned by context.WithCancel",
+               checkLostCancel,
+               funcDecl, funcLit)
+}
+
+const debugLostCancel = false
+
+var contextPackage = "context"
+
+// checkLostCancel reports a failure to the call the cancel function
+// returned by context.WithCancel, either because the variable was
+// assigned to the blank identifier, or because there exists a
+// control-flow path from the call to a return statement and that path
+// does not "use" the cancel function.  Any reference to the variable
+// counts as a use, even within a nested function literal.
+//
+// checkLostCancel analyzes a single named or literal function.
+func checkLostCancel(f *File, node ast.Node) {
+       // Fast path: bypass check if file doesn't use context.WithCancel.
+       if !hasImport(f.file, contextPackage) {
+               return
+       }
+
+       // Maps each cancel variable to its defining ValueSpec/AssignStmt.
+       cancelvars := make(map[*types.Var]ast.Node)
+
+       // Find the set of cancel vars to analyze.
+       stack := make([]ast.Node, 0, 32)
+       ast.Inspect(node, func(n ast.Node) bool {
+               switch n.(type) {
+               case *ast.FuncLit:
+                       if len(stack) > 0 {
+                               return false // don't stray into nested functions
+                       }
+               case nil:
+                       stack = stack[:len(stack)-1] // pop
+                       return true
+               }
+               stack = append(stack, n) // push
+
+               // Look for [{AssignStmt,ValueSpec} CallExpr SelectorExpr]:
+               //
+               //   ctx, cancel    := context.WithCancel(...)
+               //   ctx, cancel     = context.WithCancel(...)
+               //   var ctx, cancel = context.WithCancel(...)
+               //
+               if isContextWithCancel(f, n) && isCall(stack[len(stack)-2]) {
+                       var id *ast.Ident // id of cancel var
+                       stmt := stack[len(stack)-3]
+                       switch stmt := stmt.(type) {
+                       case *ast.ValueSpec:
+                               if len(stmt.Names) > 1 {
+                                       id = stmt.Names[1]
+                               }
+                       case *ast.AssignStmt:
+                               if len(stmt.Lhs) > 1 {
+                                       id, _ = stmt.Lhs[1].(*ast.Ident)
+                               }
+                       }
+                       if id != nil {
+                               if id.Name == "_" {
+                                       f.Badf(id.Pos(), "the cancel function returned by context.%s should be called, not discarded, to avoid a context leak",
+                                               n.(*ast.SelectorExpr).Sel.Name)
+                               } else if v, ok := f.pkg.uses[id].(*types.Var); ok {
+                                       cancelvars[v] = stmt
+                               } else if v, ok := f.pkg.defs[id].(*types.Var); ok {
+                                       cancelvars[v] = stmt
+                               }
+                       }
+               }
+
+               return true
+       })
+
+       if len(cancelvars) == 0 {
+               return // no need to build CFG
+       }
+
+       // Tell the CFG builder which functions never return.
+       info := &types.Info{Uses: f.pkg.uses, Selections: f.pkg.selectors}
+       mayReturn := func(call *ast.CallExpr) bool {
+               name := callName(info, call)
+               return !noReturnFuncs[name]
+       }
+
+       // Build the CFG.
+       var g *cfg.CFG
+       switch node := node.(type) {
+       case *ast.FuncDecl:
+               g = cfg.New(node.Body, mayReturn)
+       case *ast.FuncLit:
+               g = cfg.New(node.Body, mayReturn)
+       }
+
+       // Print CFG.
+       if debugLostCancel {
+               fmt.Println(g.Format(f.fset))
+       }
+
+       // Examine the CFG for each variable in turn.
+       // (It would be more efficient to analyze all cancelvars in a
+       // single pass over the AST, but seldom is there more than one.)
+       for v, stmt := range cancelvars {
+               if ret := lostCancelPath(f, g, v, stmt); ret != nil {
+                       lineno := f.fset.Position(stmt.Pos()).Line
+                       f.Badf(stmt.Pos(), "the %s function is not used on all paths (possible context leak)", v.Name())
+                       f.Badf(ret.Pos(), "this return statement may be reached without using the %s var defined on line %d", v.Name(), lineno)
+               }
+       }
+}
+
+func isCall(n ast.Node) bool { _, ok := n.(*ast.CallExpr); return ok }
+
+func hasImport(f *ast.File, path string) bool {
+       for _, imp := range f.Imports {
+               v, _ := strconv.Unquote(imp.Path.Value)
+               if v == path {
+                       return true
+               }
+       }
+       return false
+}
+
+// isContextWithCancel reports whether n is one of the qualified identifiers
+// context.With{Cancel,Timeout,Deadline}.
+func isContextWithCancel(f *File, n ast.Node) bool {
+       if sel, ok := n.(*ast.SelectorExpr); ok {
+               switch sel.Sel.Name {
+               case "WithCancel", "WithTimeout", "WithDeadline":
+                       if x, ok := sel.X.(*ast.Ident); ok {
+                               if pkgname, ok := f.pkg.uses[x].(*types.PkgName); ok {
+                                       return pkgname.Imported().Path() == contextPackage
+                               }
+                               // Import failed, so we can't check package path.
+                               // Just check the local package name (heuristic).
+                               return x.Name == "context"
+                       }
+               }
+       }
+       return false
+}
+
+// lostCancelPath finds a path through the CFG, from stmt (which defines
+// the 'cancel' variable v) to a return statement, that doesn't "use" v.
+// If it finds one, it returns the return statement (which may be synthetic).
+func lostCancelPath(f *File, g *cfg.CFG, v *types.Var, stmt ast.Node) *ast.ReturnStmt {
+       // uses reports whether stmts contain a "use" of variable v.
+       uses := func(f *File, v *types.Var, stmts []ast.Node) bool {
+               found := false
+               for _, stmt := range stmts {
+                       ast.Inspect(stmt, func(n ast.Node) bool {
+                               if id, ok := n.(*ast.Ident); ok {
+                                       if f.pkg.uses[id] == v {
+                                               found = true
+                                       }
+                               }
+                               return !found
+                       })
+               }
+               return found
+       }
+
+       // blockUses computes "uses" for each block, caching the result.
+       memo := make(map[*cfg.Block]bool)
+       blockUses := func(f *File, v *types.Var, b *cfg.Block) bool {
+               res, ok := memo[b]
+               if !ok {
+                       res = uses(f, v, b.Nodes)
+                       memo[b] = res
+               }
+               return res
+       }
+
+       // Find the var's defining block in the CFG,
+       // plus the rest of the statements of that block.
+       var defblock *cfg.Block
+       var rest []ast.Node
+outer:
+       for _, b := range g.Blocks {
+               for i, n := range b.Nodes {
+                       if n == stmt {
+                               defblock = b
+                               rest = b.Nodes[i+1:]
+                               break outer
+                       }
+               }
+       }
+       if defblock == nil {
+               panic("internal error: can't find defining block for cancel var")
+       }
+
+       // Is v "used" in the remainder of its defining block?
+       if uses(f, v, rest) {
+               return nil
+       }
+
+       // Does the defining block return without using v?
+       if ret := defblock.Return(); ret != nil {
+               return ret
+       }
+
+       // Search the CFG depth-first for a path, from defblock to a
+       // return block, in which v is never "used".
+       seen := make(map[*cfg.Block]bool)
+       var search func(blocks []*cfg.Block) *ast.ReturnStmt
+       search = func(blocks []*cfg.Block) *ast.ReturnStmt {
+               for _, b := range blocks {
+                       if !seen[b] {
+                               seen[b] = true
+
+                               // Prune the search if the block uses v.
+                               if blockUses(f, v, b) {
+                                       continue
+                               }
+
+                               // Found path to return statement?
+                               if ret := b.Return(); ret != nil {
+                                       if debugLostCancel {
+                                               fmt.Printf("found path to return in block %s\n", b)
+                                       }
+                                       return ret // found
+                               }
+
+                               // Recur
+                               if ret := search(b.Succs); ret != nil {
+                                       if debugLostCancel {
+                                               fmt.Printf(" from block %s\n", b)
+                                       }
+                                       return ret
+                               }
+                       }
+               }
+               return nil
+       }
+       return search(defblock.Succs)
+}
+
+var noReturnFuncs = map[string]bool{
+       "(*testing.common).FailNow": true,
+       "(*testing.common).Fatal":   true,
+       "(*testing.common).Fatalf":  true,
+       "(*testing.common).Skip":    true,
+       "(*testing.common).SkipNow": true,
+       "(*testing.common).Skipf":   true,
+       "log.Fatal":                 true,
+       "log.Fatalf":                true,
+       "log.Fatalln":               true,
+       "os.Exit":                   true,
+       "panic":                     true,
+       "runtime.Goexit":            true,
+}
+
+// callName returns the canonical name of the builtin, method, or
+// function called by call, if known.
+func callName(info *types.Info, call *ast.CallExpr) string {
+       switch fun := call.Fun.(type) {
+       case *ast.Ident:
+               // builtin, e.g. "panic"
+               if obj, ok := info.Uses[fun].(*types.Builtin); ok {
+                       return obj.Name()
+               }
+       case *ast.SelectorExpr:
+               if sel, ok := info.Selections[fun]; ok {
+                       // method call, e.g. "(*testing.common).Fatal"
+                       meth := sel.Obj()
+                       return fmt.Sprintf("(%s).%s",
+                               meth.Type().(*types.Signature).Recv().Type(),
+                               meth.Name())
+               }
+               if obj, ok := info.Uses[fun.Sel]; ok {
+                       // qualified identifier, e.g. "os.Exit"
+                       return fmt.Sprintf("%s.%s",
+                               obj.Pkg().Path(),
+                               obj.Name())
+               }
+       }
+
+       // function with no name, or defined in missing imported package
+       return ""
+}
diff --git a/src/cmd/vet/testdata/lostcancel.go b/src/cmd/vet/testdata/lostcancel.go
new file mode 100644 (file)
index 0000000..143456e
--- /dev/null
@@ -0,0 +1,137 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package testdata
+
+import (
+       "context"
+       "log"
+       "os"
+       "testing"
+)
+
+// Check the three functions and assignment forms (var, :=, =) we look for.
+// (Do these early: line numbers are fragile.)
+func _() {
+       var ctx, cancel = context.WithCancel() // ERROR "the cancel function is not used on all paths \(possible context leak\)"
+} // ERROR "this return statement may be reached without using the cancel var defined on line 17"
+
+func _() {
+       ctx, cancel2 := context.WithDeadline() // ERROR "the cancel2 function is not used..."
+} // ERROR "may be reached without using the cancel2 var defined on line 21"
+
+func _() {
+       var ctx context.Context
+       var cancel3 func()
+       ctx, cancel3 = context.WithTimeout() // ERROR "function is not used..."
+} // ERROR "this return statement may be reached without using the cancel3 var defined on line 27"
+
+func _() {
+       ctx, _ := context.WithCancel()  // ERROR "the cancel function returned by context.WithCancel should be called, not discarded, to avoid a context leak"
+       ctx, _ = context.WithTimeout()  // ERROR "the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak"
+       ctx, _ = context.WithDeadline() // ERROR "the cancel function returned by context.WithDeadline should be called, not discarded, to avoid a context leak"
+}
+
+func _() {
+       ctx, cancel := context.WithCancel()
+       defer cancel() // ok
+}
+
+func _() {
+       ctx, cancel := context.WithCancel() // ERROR "not used on all paths"
+       if condition {
+               cancel()
+       }
+       return // ERROR "this return statement may be reached without using the cancel var"
+}
+
+func _() {
+       ctx, cancel := context.WithCancel()
+       if condition {
+               cancel()
+       } else {
+               // ok: infinite loop
+               for {
+                       print(0)
+               }
+       }
+}
+
+func _() {
+       ctx, cancel := context.WithCancel() // ERROR "not used on all paths"
+       if condition {
+               cancel()
+       } else {
+               for i := 0; i < 10; i++ {
+                       print(0)
+               }
+       }
+} // ERROR "this return statement may be reached without using the cancel var"
+
+func _() {
+       ctx, cancel := context.WithCancel()
+       // ok: used on all paths
+       switch someInt {
+       case 0:
+               new(testing.T).FailNow()
+       case 1:
+               log.Fatal()
+       case 2:
+               cancel()
+       case 3:
+               print("hi")
+               fallthrough
+       default:
+               os.Exit(1)
+       }
+}
+
+func _() {
+       ctx, cancel := context.WithCancel() // ERROR "not used on all paths"
+       switch someInt {
+       case 0:
+               new(testing.T).FailNow()
+       case 1:
+               log.Fatal()
+       case 2:
+               cancel()
+       case 3:
+               print("hi") // falls through to implicit return
+       default:
+               os.Exit(1)
+       }
+} // ERROR "this return statement may be reached without using the cancel var"
+
+func _(ch chan int) int {
+       ctx, cancel := context.WithCancel() // ERROR "not used on all paths"
+       select {
+       case <-ch:
+               new(testing.T).FailNow()
+       case y <- ch:
+               print("hi") // falls through to implicit return
+       case ch <- 1:
+               cancel()
+       default:
+               os.Exit(1)
+       }
+} // ERROR "this return statement may be reached without using the cancel var"
+
+func _(ch chan int) int {
+       ctx, cancel := context.WithCancel()
+       // A blocking select must execute one of its cases.
+       select {
+       case <-ch:
+               panic()
+       }
+}
+
+func _() {
+       go func() {
+               ctx, cancel := context.WithCancel() // ERROR "not used on all paths"
+               print(ctx)
+       }() // ERROR "may be reached without using the cancel var"
+}
+
+var condition bool
+var someInt int