The cfg subpackage builds a control flow graph of ast.Nodes.
The lostcancel module checks this graph to find paths, from a call to
WithCancel to a return statement, on which the cancel variable is
not used. (A trivial case is simply assigning the cancel result to
the blank identifier.)
In a sample of 50,000 source files, the new check found 2068 blank
assignments and 118 return-before-cancel errors. I manually inspected
20 of the latter and didn't find a single false positive among them.
Change-Id: I84cd49445f9f8d04908b04881eb1496a96611205
Reviewed-on: https://go-review.googlesource.com/24150
Reviewed-by: Robert Griesemer <gri@golang.org>
Reviewed-by: Rob Pike <r@golang.org>
}
// checkUnreachable checks a function body for dead code.
+//
+// TODO(adonovan): use the new cfg package, which is more precise.
func checkUnreachable(f *File, node ast.Node) {
var body *ast.BlockStmt
switch n := node.(type) {
Mistakes involving tests including functions with incorrect names or signatures
and example tests that document identifiers not in the package.
+Failure to call the cancellation function returned by context.WithCancel.
+
+Flag: -lostcancel
+
+The cancellation function returned by context.WithCancel, WithTimeout,
+and WithDeadline must be called or the new context will remain live
+until its parent context is cancelled.
+(The background context is never cancelled.)
+
Methods
Flag: -methods
--- /dev/null
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cfg
+
+// This file implements the CFG construction pass.
+
+import (
+ "fmt"
+ "go/ast"
+ "go/token"
+)
+
+type builder struct {
+ cfg *CFG
+ mayReturn func(*ast.CallExpr) bool
+ current *Block
+ lblocks map[*ast.Object]*lblock // labeled blocks
+ targets *targets // linked stack of branch targets
+}
+
+func (b *builder) stmt(_s ast.Stmt) {
+ // The label of the current statement. If non-nil, its _goto
+ // target is always set; its _break and _continue are set only
+ // within the body of switch/typeswitch/select/for/range.
+ // It is effectively an additional default-nil parameter of stmt().
+ var label *lblock
+start:
+ switch s := _s.(type) {
+ case *ast.BadStmt,
+ *ast.SendStmt,
+ *ast.IncDecStmt,
+ *ast.GoStmt,
+ *ast.DeferStmt,
+ *ast.EmptyStmt,
+ *ast.AssignStmt:
+ // No effect on control flow.
+ b.add(s)
+
+ case *ast.ExprStmt:
+ b.add(s)
+ if call, ok := s.X.(*ast.CallExpr); ok && !b.mayReturn(call) {
+ // Calls to panic, os.Exit, etc, never return.
+ b.current = b.newUnreachableBlock("unreachable.call")
+ }
+
+ case *ast.DeclStmt:
+ // Treat each var ValueSpec as a separate statement.
+ d := s.Decl.(*ast.GenDecl)
+ if d.Tok == token.VAR {
+ for _, spec := range d.Specs {
+ if spec, ok := spec.(*ast.ValueSpec); ok {
+ b.add(spec)
+ }
+ }
+ }
+
+ case *ast.LabeledStmt:
+ label = b.labeledBlock(s.Label)
+ b.jump(label._goto)
+ b.current = label._goto
+ _s = s.Stmt
+ goto start // effectively: tailcall stmt(g, s.Stmt, label)
+
+ case *ast.ReturnStmt:
+ b.add(s)
+ b.current = b.newUnreachableBlock("unreachable.return")
+
+ case *ast.BranchStmt:
+ var block *Block
+ switch s.Tok {
+ case token.BREAK:
+ if s.Label != nil {
+ if lb := b.labeledBlock(s.Label); lb != nil {
+ block = lb._break
+ }
+ } else {
+ for t := b.targets; t != nil && block == nil; t = t.tail {
+ block = t._break
+ }
+ }
+
+ case token.CONTINUE:
+ if s.Label != nil {
+ if lb := b.labeledBlock(s.Label); lb != nil {
+ block = lb._continue
+ }
+ } else {
+ for t := b.targets; t != nil && block == nil; t = t.tail {
+ block = t._continue
+ }
+ }
+
+ case token.FALLTHROUGH:
+ for t := b.targets; t != nil; t = t.tail {
+ block = t._fallthrough
+ }
+
+ case token.GOTO:
+ block = b.labeledBlock(s.Label)._goto
+ }
+ if block == nil {
+ block = b.newBlock("undefined.branch")
+ }
+ b.jump(block)
+ b.current = b.newUnreachableBlock("unreachable.branch")
+
+ case *ast.BlockStmt:
+ b.stmtList(s.List)
+
+ case *ast.IfStmt:
+ if s.Init != nil {
+ b.stmt(s.Init)
+ }
+ then := b.newBlock("if.then")
+ done := b.newBlock("if.done")
+ _else := done
+ if s.Else != nil {
+ _else = b.newBlock("if.else")
+ }
+ b.add(s.Cond)
+ b.ifelse(then, _else)
+ b.current = then
+ b.stmt(s.Body)
+ b.jump(done)
+
+ if s.Else != nil {
+ b.current = _else
+ b.stmt(s.Else)
+ b.jump(done)
+ }
+
+ b.current = done
+
+ case *ast.SwitchStmt:
+ b.switchStmt(s, label)
+
+ case *ast.TypeSwitchStmt:
+ b.typeSwitchStmt(s, label)
+
+ case *ast.SelectStmt:
+ b.selectStmt(s, label)
+
+ case *ast.ForStmt:
+ b.forStmt(s, label)
+
+ case *ast.RangeStmt:
+ b.rangeStmt(s, label)
+
+ default:
+ panic(fmt.Sprintf("unexpected statement kind: %T", s))
+ }
+}
+
+func (b *builder) stmtList(list []ast.Stmt) {
+ for _, s := range list {
+ b.stmt(s)
+ }
+}
+
+func (b *builder) switchStmt(s *ast.SwitchStmt, label *lblock) {
+ if s.Init != nil {
+ b.stmt(s.Init)
+ }
+ if s.Tag != nil {
+ b.add(s.Tag)
+ }
+ done := b.newBlock("switch.done")
+ if label != nil {
+ label._break = done
+ }
+ // We pull the default case (if present) down to the end.
+ // But each fallthrough label must point to the next
+ // body block in source order, so we preallocate a
+ // body block (fallthru) for the next case.
+ // Unfortunately this makes for a confusing block order.
+ var defaultBody *[]ast.Stmt
+ var defaultFallthrough *Block
+ var fallthru, defaultBlock *Block
+ ncases := len(s.Body.List)
+ for i, clause := range s.Body.List {
+ body := fallthru
+ if body == nil {
+ body = b.newBlock("switch.body") // first case only
+ }
+
+ // Preallocate body block for the next case.
+ fallthru = done
+ if i+1 < ncases {
+ fallthru = b.newBlock("switch.body")
+ }
+
+ cc := clause.(*ast.CaseClause)
+ if cc.List == nil {
+ // Default case.
+ defaultBody = &cc.Body
+ defaultFallthrough = fallthru
+ defaultBlock = body
+ continue
+ }
+
+ var nextCond *Block
+ for _, cond := range cc.List {
+ nextCond = b.newBlock("switch.next")
+ b.add(cond) // one half of the tag==cond condition
+ b.ifelse(body, nextCond)
+ b.current = nextCond
+ }
+ b.current = body
+ b.targets = &targets{
+ tail: b.targets,
+ _break: done,
+ _fallthrough: fallthru,
+ }
+ b.stmtList(cc.Body)
+ b.targets = b.targets.tail
+ b.jump(done)
+ b.current = nextCond
+ }
+ if defaultBlock != nil {
+ b.jump(defaultBlock)
+ b.current = defaultBlock
+ b.targets = &targets{
+ tail: b.targets,
+ _break: done,
+ _fallthrough: defaultFallthrough,
+ }
+ b.stmtList(*defaultBody)
+ b.targets = b.targets.tail
+ }
+ b.jump(done)
+ b.current = done
+}
+
+func (b *builder) typeSwitchStmt(s *ast.TypeSwitchStmt, label *lblock) {
+ if s.Init != nil {
+ b.stmt(s.Init)
+ }
+ if s.Assign != nil {
+ b.add(s.Assign)
+ }
+
+ done := b.newBlock("typeswitch.done")
+ if label != nil {
+ label._break = done
+ }
+ var default_ *ast.CaseClause
+ for _, clause := range s.Body.List {
+ cc := clause.(*ast.CaseClause)
+ if cc.List == nil {
+ default_ = cc
+ continue
+ }
+ body := b.newBlock("typeswitch.body")
+ var next *Block
+ for _, casetype := range cc.List {
+ next = b.newBlock("typeswitch.next")
+ // casetype is a type, so don't call b.add(casetype).
+ // This block logically contains a type assertion,
+ // x.(casetype), but it's unclear how to represent x.
+ _ = casetype
+ b.ifelse(body, next)
+ b.current = next
+ }
+ b.current = body
+ b.typeCaseBody(cc, done)
+ b.current = next
+ }
+ if default_ != nil {
+ b.typeCaseBody(default_, done)
+ } else {
+ b.jump(done)
+ }
+ b.current = done
+}
+
+func (b *builder) typeCaseBody(cc *ast.CaseClause, done *Block) {
+ b.targets = &targets{
+ tail: b.targets,
+ _break: done,
+ }
+ b.stmtList(cc.Body)
+ b.targets = b.targets.tail
+ b.jump(done)
+}
+
+func (b *builder) selectStmt(s *ast.SelectStmt, label *lblock) {
+ // First evaluate channel expressions.
+ // TODO(adonovan): fix: evaluate only channel exprs here.
+ for _, clause := range s.Body.List {
+ if comm := clause.(*ast.CommClause).Comm; comm != nil {
+ b.stmt(comm)
+ }
+ }
+
+ done := b.newBlock("select.done")
+ if label != nil {
+ label._break = done
+ }
+
+ var defaultBody *[]ast.Stmt
+ for _, cc := range s.Body.List {
+ clause := cc.(*ast.CommClause)
+ if clause.Comm == nil {
+ defaultBody = &clause.Body
+ continue
+ }
+ body := b.newBlock("select.body")
+ next := b.newBlock("select.next")
+ b.ifelse(body, next)
+ b.current = body
+ b.targets = &targets{
+ tail: b.targets,
+ _break: done,
+ }
+ switch comm := clause.Comm.(type) {
+ case *ast.ExprStmt: // <-ch
+ // nop
+ case *ast.AssignStmt: // x := <-states[state].Chan
+ b.add(comm.Lhs[0])
+ }
+ b.stmtList(clause.Body)
+ b.targets = b.targets.tail
+ b.jump(done)
+ b.current = next
+ }
+ if defaultBody != nil {
+ b.targets = &targets{
+ tail: b.targets,
+ _break: done,
+ }
+ b.stmtList(*defaultBody)
+ b.targets = b.targets.tail
+ b.jump(done)
+ }
+ b.current = done
+}
+
+func (b *builder) forStmt(s *ast.ForStmt, label *lblock) {
+ // ...init...
+ // jump loop
+ // loop:
+ // if cond goto body else done
+ // body:
+ // ...body...
+ // jump post
+ // post: (target of continue)
+ // ...post...
+ // jump loop
+ // done: (target of break)
+ if s.Init != nil {
+ b.stmt(s.Init)
+ }
+ body := b.newBlock("for.body")
+ done := b.newBlock("for.done") // target of 'break'
+ loop := body // target of back-edge
+ if s.Cond != nil {
+ loop = b.newBlock("for.loop")
+ }
+ cont := loop // target of 'continue'
+ if s.Post != nil {
+ cont = b.newBlock("for.post")
+ }
+ if label != nil {
+ label._break = done
+ label._continue = cont
+ }
+ b.jump(loop)
+ b.current = loop
+ if loop != body {
+ b.add(s.Cond)
+ b.ifelse(body, done)
+ b.current = body
+ }
+ b.targets = &targets{
+ tail: b.targets,
+ _break: done,
+ _continue: cont,
+ }
+ b.stmt(s.Body)
+ b.targets = b.targets.tail
+ b.jump(cont)
+
+ if s.Post != nil {
+ b.current = cont
+ b.stmt(s.Post)
+ b.jump(loop) // back-edge
+ }
+ b.current = done
+}
+
+func (b *builder) rangeStmt(s *ast.RangeStmt, label *lblock) {
+ b.add(s.X)
+
+ if s.Key != nil {
+ b.add(s.Key)
+ }
+ if s.Value != nil {
+ b.add(s.Value)
+ }
+
+ // ...
+ // loop: (target of continue)
+ // if ... goto body else done
+ // body:
+ // ...
+ // jump loop
+ // done: (target of break)
+
+ loop := b.newBlock("range.loop")
+ b.jump(loop)
+ b.current = loop
+
+ body := b.newBlock("range.body")
+ done := b.newBlock("range.done")
+ b.ifelse(body, done)
+ b.current = body
+
+ if label != nil {
+ label._break = done
+ label._continue = loop
+ }
+ b.targets = &targets{
+ tail: b.targets,
+ _break: done,
+ _continue: loop,
+ }
+ b.stmt(s.Body)
+ b.targets = b.targets.tail
+ b.jump(loop) // back-edge
+ b.current = done
+}
+
+// -------- helpers --------
+
+// Destinations associated with unlabeled for/switch/select stmts.
+// We push/pop one of these as we enter/leave each construct and for
+// each BranchStmt we scan for the innermost target of the right type.
+//
+type targets struct {
+ tail *targets // rest of stack
+ _break *Block
+ _continue *Block
+ _fallthrough *Block
+}
+
+// Destinations associated with a labeled block.
+// We populate these as labels are encountered in forward gotos or
+// labeled statements.
+//
+type lblock struct {
+ _goto *Block
+ _break *Block
+ _continue *Block
+}
+
+// labeledBlock returns the branch target associated with the
+// specified label, creating it if needed.
+//
+func (b *builder) labeledBlock(label *ast.Ident) *lblock {
+ lb := b.lblocks[label.Obj]
+ if lb == nil {
+ lb = &lblock{_goto: b.newBlock(label.Name)}
+ if b.lblocks == nil {
+ b.lblocks = make(map[*ast.Object]*lblock)
+ }
+ b.lblocks[label.Obj] = lb
+ }
+ return lb
+}
+
+// newBlock appends a new unconnected basic block to b.cfg's block
+// slice and returns it.
+// It does not automatically become the current block.
+// comment is an optional string for more readable debugging output.
+func (b *builder) newBlock(comment string) *Block {
+ g := b.cfg
+ block := &Block{
+ index: int32(len(g.Blocks)),
+ comment: comment,
+ }
+ block.Succs = block.succs2[:0]
+ g.Blocks = append(g.Blocks, block)
+ return block
+}
+
+func (b *builder) newUnreachableBlock(comment string) *Block {
+ block := b.newBlock(comment)
+ block.unreachable = true
+ return block
+}
+
+func (b *builder) add(n ast.Node) {
+ b.current.Nodes = append(b.current.Nodes, n)
+}
+
+// jump adds an edge from the current block to the target block,
+// and sets b.current to nil.
+func (b *builder) jump(target *Block) {
+ b.current.Succs = append(b.current.Succs, target)
+ b.current = nil
+}
+
+// ifelse emits edges from the current block to the t and f blocks,
+// and sets b.current to nil.
+func (b *builder) ifelse(t, f *Block) {
+ b.current.Succs = append(b.current.Succs, t, f)
+ b.current = nil
+}
--- /dev/null
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This package constructs a simple control-flow graph (CFG) of the
+// statements and expressions within a single function.
+//
+// Use cfg.New to construct the CFG for a function body.
+//
+// The blocks of the CFG contain all the function's non-control
+// statements. The CFG does not contain control statements such as If,
+// Switch, Select, and Branch, but does contain their subexpressions.
+// For example, this source code:
+//
+// if x := f(); x != nil {
+// T()
+// } else {
+// F()
+// }
+//
+// produces this CFG:
+//
+// 1: x := f()
+// x != nil
+// succs: 2, 3
+// 2: T()
+// succs: 4
+// 3: F()
+// succs: 4
+// 4:
+//
+// The CFG does contain Return statements; even implicit returns are
+// materialized (at the position of the function's closing brace).
+//
+// The CFG does not record conditions associated with conditional branch
+// edges, nor the short-circuit semantics of the && and || operators,
+// nor abnormal control flow caused by panic. If you need this
+// information, use golang.org/x/tools/go/ssa instead.
+//
+package cfg
+
+// Although the vet tool has type information, it is often extremely
+// fragmentary, so for simplicity this package does not depend on
+// go/types. Consequently control-flow conditions are ignored even
+// when constant, and "mayReturn" information must be provided by the
+// client.
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/format"
+ "go/token"
+)
+
+// A CFG represents the control-flow graph of a single function.
+//
+// The entry point is Blocks[0]; there may be multiple return blocks.
+type CFG struct {
+ Blocks []*Block // block[0] is entry; order otherwise undefined
+}
+
+// A Block represents a basic block: a list of statements and
+// expressions that are always evaluated sequentially.
+//
+// A block may have 0-2 successors: zero for a return block or a block
+// that calls a function such as panic that never returns; one for a
+// normal (jump) block; and two for a conditional (if) block.
+type Block struct {
+ Nodes []ast.Node // statements, expressions, and ValueSpecs
+ Succs []*Block // successor nodes in the graph
+
+ comment string // for debugging
+ index int32 // index within CFG.Blocks
+ unreachable bool // is block of stmts following return/panic/for{}
+ succs2 [2]*Block // underlying array for Succs
+}
+
+// New returns a new control-flow graph for the specified function body,
+// which must be non-nil.
+//
+// The CFG builder calls mayReturn to determine whether a given function
+// call may return. For example, calls to panic, os.Exit, and log.Fatal
+// do not return, so the builder can remove infeasible graph edges
+// following such calls. The builder calls mayReturn only for a
+// CallExpr beneath an ExprStmt.
+func New(body *ast.BlockStmt, mayReturn func(*ast.CallExpr) bool) *CFG {
+ b := builder{
+ mayReturn: mayReturn,
+ cfg: new(CFG),
+ }
+ b.current = b.newBlock("entry")
+ b.stmt(body)
+
+ // Does control fall off the end of the function's body?
+ // Make implicit return explicit.
+ if b.current != nil && !b.current.unreachable {
+ b.add(&ast.ReturnStmt{
+ Return: body.End() - 1,
+ })
+ }
+
+ return b.cfg
+}
+
+func (b *Block) String() string {
+ return fmt.Sprintf("block %d (%s)", b.index, b.comment)
+}
+
+// Return returns the return statement at the end of this block if present, nil otherwise.
+func (b *Block) Return() (ret *ast.ReturnStmt) {
+ if len(b.Nodes) > 0 {
+ ret, _ = b.Nodes[len(b.Nodes)-1].(*ast.ReturnStmt)
+ }
+ return
+}
+
+// Format formats the control-flow graph for ease of debugging.
+func (g *CFG) Format(fset *token.FileSet) string {
+ var buf bytes.Buffer
+ for _, b := range g.Blocks {
+ fmt.Fprintf(&buf, ".%d: # %s\n", b.index, b.comment)
+ for _, n := range b.Nodes {
+ fmt.Fprintf(&buf, "\t%s\n", formatNode(fset, n))
+ }
+ if len(b.Succs) > 0 {
+ fmt.Fprintf(&buf, "\tsuccs:")
+ for _, succ := range b.Succs {
+ fmt.Fprintf(&buf, " %d", succ.index)
+ }
+ buf.WriteByte('\n')
+ }
+ buf.WriteByte('\n')
+ }
+ return buf.String()
+}
+
+func formatNode(fset *token.FileSet, n ast.Node) string {
+ var buf bytes.Buffer
+ format.Node(&buf, fset, n)
+ // Indent secondary lines by a tab.
+ return string(bytes.Replace(buf.Bytes(), []byte("\n"), []byte("\n\t"), -1))
+}
--- /dev/null
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cfg
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "testing"
+)
+
+const src = `package main
+
+import "log"
+
+func f1() {
+ live()
+ return
+ dead()
+}
+
+func f2() {
+ for {
+ live()
+ }
+ dead()
+}
+
+func f3() {
+ if true { // even known values are ignored
+ return
+ }
+ for true { // even known values are ignored
+ live()
+ }
+ for {
+ live()
+ }
+ dead()
+}
+
+func f4(x int) {
+ switch x {
+ case 1:
+ live()
+ fallthrough
+ case 2:
+ live()
+ log.Fatal()
+ default:
+ panic("oops")
+ }
+ dead()
+}
+
+func f4(ch chan int) {
+ select {
+ case <-ch:
+ live()
+ return
+ default:
+ live()
+ panic("oops")
+ }
+ dead()
+}
+
+func f5(unknown bool) {
+ for {
+ if unknown {
+ break
+ }
+ continue
+ dead()
+ }
+ live()
+}
+
+func f6(unknown bool) {
+outer:
+ for {
+ for {
+ break outer
+ dead()
+ }
+ dead()
+ }
+ live()
+}
+
+func f7() {
+ for {
+ break nosuchlabel
+ dead()
+ }
+ dead()
+}
+
+func f8() {
+ select{}
+ dead()
+}
+
+func f9(ch chan int) {
+ select {
+ case <-ch:
+ return
+ }
+ dead()
+}
+
+func f10(ch chan int) {
+ select {
+ case <-ch:
+ return
+ dead()
+ default:
+ }
+ live()
+}
+`
+
+func TestDeadCode(t *testing.T) {
+ // We'll use dead code detection to verify the CFG.
+
+ fset := token.NewFileSet()
+ f, err := parser.ParseFile(fset, "dummy.go", src, parser.Mode(0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, decl := range f.Decls {
+ if decl, ok := decl.(*ast.FuncDecl); ok {
+ g := New(decl.Body, mayReturn)
+
+ // Mark blocks reachable from entry.
+ live := make(map[*Block]bool)
+ var visit func(*Block)
+ visit = func(b *Block) {
+ if !live[b] {
+ live[b] = true
+ for _, succ := range b.Succs {
+ visit(succ)
+ }
+ }
+ }
+ visit(g.Blocks[0])
+
+ // Print statements in unreachable blocks
+ // (in order determined by builder).
+ var buf bytes.Buffer
+ for _, b := range g.Blocks {
+ if !live[b] {
+ for _, n := range b.Nodes {
+ fmt.Fprintf(&buf, "\t%s\n", formatNode(fset, n))
+ }
+ }
+ }
+
+ // Check that the result contains "dead" at least once but not "live".
+ if !bytes.Contains(buf.Bytes(), []byte("dead")) ||
+ bytes.Contains(buf.Bytes(), []byte("live")) {
+ t.Errorf("unexpected dead statements in function %s:\n%s",
+ decl.Name.Name,
+ &buf)
+ t.Logf("control flow graph:\n%s", g.Format(fset))
+ }
+ }
+ }
+}
+
+// A trivial mayReturn predicate that looks only at syntax, not types.
+func mayReturn(call *ast.CallExpr) bool {
+ switch fun := call.Fun.(type) {
+ case *ast.Ident:
+ return fun.Name != "panic"
+ case *ast.SelectorExpr:
+ return fun.Sel.Name != "Fatal"
+ }
+ return true
+}
--- /dev/null
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "cmd/vet/internal/cfg"
+ "fmt"
+ "go/ast"
+ "go/types"
+ "strconv"
+)
+
+func init() {
+ register("lostcancel",
+ "check for failure to call cancellation function returned by context.WithCancel",
+ checkLostCancel,
+ funcDecl, funcLit)
+}
+
+const debugLostCancel = false
+
+var contextPackage = "context"
+
+// checkLostCancel reports a failure to the call the cancel function
+// returned by context.WithCancel, either because the variable was
+// assigned to the blank identifier, or because there exists a
+// control-flow path from the call to a return statement and that path
+// does not "use" the cancel function. Any reference to the variable
+// counts as a use, even within a nested function literal.
+//
+// checkLostCancel analyzes a single named or literal function.
+func checkLostCancel(f *File, node ast.Node) {
+ // Fast path: bypass check if file doesn't use context.WithCancel.
+ if !hasImport(f.file, contextPackage) {
+ return
+ }
+
+ // Maps each cancel variable to its defining ValueSpec/AssignStmt.
+ cancelvars := make(map[*types.Var]ast.Node)
+
+ // Find the set of cancel vars to analyze.
+ stack := make([]ast.Node, 0, 32)
+ ast.Inspect(node, func(n ast.Node) bool {
+ switch n.(type) {
+ case *ast.FuncLit:
+ if len(stack) > 0 {
+ return false // don't stray into nested functions
+ }
+ case nil:
+ stack = stack[:len(stack)-1] // pop
+ return true
+ }
+ stack = append(stack, n) // push
+
+ // Look for [{AssignStmt,ValueSpec} CallExpr SelectorExpr]:
+ //
+ // ctx, cancel := context.WithCancel(...)
+ // ctx, cancel = context.WithCancel(...)
+ // var ctx, cancel = context.WithCancel(...)
+ //
+ if isContextWithCancel(f, n) && isCall(stack[len(stack)-2]) {
+ var id *ast.Ident // id of cancel var
+ stmt := stack[len(stack)-3]
+ switch stmt := stmt.(type) {
+ case *ast.ValueSpec:
+ if len(stmt.Names) > 1 {
+ id = stmt.Names[1]
+ }
+ case *ast.AssignStmt:
+ if len(stmt.Lhs) > 1 {
+ id, _ = stmt.Lhs[1].(*ast.Ident)
+ }
+ }
+ if id != nil {
+ if id.Name == "_" {
+ f.Badf(id.Pos(), "the cancel function returned by context.%s should be called, not discarded, to avoid a context leak",
+ n.(*ast.SelectorExpr).Sel.Name)
+ } else if v, ok := f.pkg.uses[id].(*types.Var); ok {
+ cancelvars[v] = stmt
+ } else if v, ok := f.pkg.defs[id].(*types.Var); ok {
+ cancelvars[v] = stmt
+ }
+ }
+ }
+
+ return true
+ })
+
+ if len(cancelvars) == 0 {
+ return // no need to build CFG
+ }
+
+ // Tell the CFG builder which functions never return.
+ info := &types.Info{Uses: f.pkg.uses, Selections: f.pkg.selectors}
+ mayReturn := func(call *ast.CallExpr) bool {
+ name := callName(info, call)
+ return !noReturnFuncs[name]
+ }
+
+ // Build the CFG.
+ var g *cfg.CFG
+ switch node := node.(type) {
+ case *ast.FuncDecl:
+ g = cfg.New(node.Body, mayReturn)
+ case *ast.FuncLit:
+ g = cfg.New(node.Body, mayReturn)
+ }
+
+ // Print CFG.
+ if debugLostCancel {
+ fmt.Println(g.Format(f.fset))
+ }
+
+ // Examine the CFG for each variable in turn.
+ // (It would be more efficient to analyze all cancelvars in a
+ // single pass over the AST, but seldom is there more than one.)
+ for v, stmt := range cancelvars {
+ if ret := lostCancelPath(f, g, v, stmt); ret != nil {
+ lineno := f.fset.Position(stmt.Pos()).Line
+ f.Badf(stmt.Pos(), "the %s function is not used on all paths (possible context leak)", v.Name())
+ f.Badf(ret.Pos(), "this return statement may be reached without using the %s var defined on line %d", v.Name(), lineno)
+ }
+ }
+}
+
+func isCall(n ast.Node) bool { _, ok := n.(*ast.CallExpr); return ok }
+
+func hasImport(f *ast.File, path string) bool {
+ for _, imp := range f.Imports {
+ v, _ := strconv.Unquote(imp.Path.Value)
+ if v == path {
+ return true
+ }
+ }
+ return false
+}
+
+// isContextWithCancel reports whether n is one of the qualified identifiers
+// context.With{Cancel,Timeout,Deadline}.
+func isContextWithCancel(f *File, n ast.Node) bool {
+ if sel, ok := n.(*ast.SelectorExpr); ok {
+ switch sel.Sel.Name {
+ case "WithCancel", "WithTimeout", "WithDeadline":
+ if x, ok := sel.X.(*ast.Ident); ok {
+ if pkgname, ok := f.pkg.uses[x].(*types.PkgName); ok {
+ return pkgname.Imported().Path() == contextPackage
+ }
+ // Import failed, so we can't check package path.
+ // Just check the local package name (heuristic).
+ return x.Name == "context"
+ }
+ }
+ }
+ return false
+}
+
+// lostCancelPath finds a path through the CFG, from stmt (which defines
+// the 'cancel' variable v) to a return statement, that doesn't "use" v.
+// If it finds one, it returns the return statement (which may be synthetic).
+func lostCancelPath(f *File, g *cfg.CFG, v *types.Var, stmt ast.Node) *ast.ReturnStmt {
+ // uses reports whether stmts contain a "use" of variable v.
+ uses := func(f *File, v *types.Var, stmts []ast.Node) bool {
+ found := false
+ for _, stmt := range stmts {
+ ast.Inspect(stmt, func(n ast.Node) bool {
+ if id, ok := n.(*ast.Ident); ok {
+ if f.pkg.uses[id] == v {
+ found = true
+ }
+ }
+ return !found
+ })
+ }
+ return found
+ }
+
+ // blockUses computes "uses" for each block, caching the result.
+ memo := make(map[*cfg.Block]bool)
+ blockUses := func(f *File, v *types.Var, b *cfg.Block) bool {
+ res, ok := memo[b]
+ if !ok {
+ res = uses(f, v, b.Nodes)
+ memo[b] = res
+ }
+ return res
+ }
+
+ // Find the var's defining block in the CFG,
+ // plus the rest of the statements of that block.
+ var defblock *cfg.Block
+ var rest []ast.Node
+outer:
+ for _, b := range g.Blocks {
+ for i, n := range b.Nodes {
+ if n == stmt {
+ defblock = b
+ rest = b.Nodes[i+1:]
+ break outer
+ }
+ }
+ }
+ if defblock == nil {
+ panic("internal error: can't find defining block for cancel var")
+ }
+
+ // Is v "used" in the remainder of its defining block?
+ if uses(f, v, rest) {
+ return nil
+ }
+
+ // Does the defining block return without using v?
+ if ret := defblock.Return(); ret != nil {
+ return ret
+ }
+
+ // Search the CFG depth-first for a path, from defblock to a
+ // return block, in which v is never "used".
+ seen := make(map[*cfg.Block]bool)
+ var search func(blocks []*cfg.Block) *ast.ReturnStmt
+ search = func(blocks []*cfg.Block) *ast.ReturnStmt {
+ for _, b := range blocks {
+ if !seen[b] {
+ seen[b] = true
+
+ // Prune the search if the block uses v.
+ if blockUses(f, v, b) {
+ continue
+ }
+
+ // Found path to return statement?
+ if ret := b.Return(); ret != nil {
+ if debugLostCancel {
+ fmt.Printf("found path to return in block %s\n", b)
+ }
+ return ret // found
+ }
+
+ // Recur
+ if ret := search(b.Succs); ret != nil {
+ if debugLostCancel {
+ fmt.Printf(" from block %s\n", b)
+ }
+ return ret
+ }
+ }
+ }
+ return nil
+ }
+ return search(defblock.Succs)
+}
+
+var noReturnFuncs = map[string]bool{
+ "(*testing.common).FailNow": true,
+ "(*testing.common).Fatal": true,
+ "(*testing.common).Fatalf": true,
+ "(*testing.common).Skip": true,
+ "(*testing.common).SkipNow": true,
+ "(*testing.common).Skipf": true,
+ "log.Fatal": true,
+ "log.Fatalf": true,
+ "log.Fatalln": true,
+ "os.Exit": true,
+ "panic": true,
+ "runtime.Goexit": true,
+}
+
+// callName returns the canonical name of the builtin, method, or
+// function called by call, if known.
+func callName(info *types.Info, call *ast.CallExpr) string {
+ switch fun := call.Fun.(type) {
+ case *ast.Ident:
+ // builtin, e.g. "panic"
+ if obj, ok := info.Uses[fun].(*types.Builtin); ok {
+ return obj.Name()
+ }
+ case *ast.SelectorExpr:
+ if sel, ok := info.Selections[fun]; ok {
+ // method call, e.g. "(*testing.common).Fatal"
+ meth := sel.Obj()
+ return fmt.Sprintf("(%s).%s",
+ meth.Type().(*types.Signature).Recv().Type(),
+ meth.Name())
+ }
+ if obj, ok := info.Uses[fun.Sel]; ok {
+ // qualified identifier, e.g. "os.Exit"
+ return fmt.Sprintf("%s.%s",
+ obj.Pkg().Path(),
+ obj.Name())
+ }
+ }
+
+ // function with no name, or defined in missing imported package
+ return ""
+}
--- /dev/null
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package testdata
+
+import (
+ "context"
+ "log"
+ "os"
+ "testing"
+)
+
+// Check the three functions and assignment forms (var, :=, =) we look for.
+// (Do these early: line numbers are fragile.)
+func _() {
+ var ctx, cancel = context.WithCancel() // ERROR "the cancel function is not used on all paths \(possible context leak\)"
+} // ERROR "this return statement may be reached without using the cancel var defined on line 17"
+
+func _() {
+ ctx, cancel2 := context.WithDeadline() // ERROR "the cancel2 function is not used..."
+} // ERROR "may be reached without using the cancel2 var defined on line 21"
+
+func _() {
+ var ctx context.Context
+ var cancel3 func()
+ ctx, cancel3 = context.WithTimeout() // ERROR "function is not used..."
+} // ERROR "this return statement may be reached without using the cancel3 var defined on line 27"
+
+func _() {
+ ctx, _ := context.WithCancel() // ERROR "the cancel function returned by context.WithCancel should be called, not discarded, to avoid a context leak"
+ ctx, _ = context.WithTimeout() // ERROR "the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak"
+ ctx, _ = context.WithDeadline() // ERROR "the cancel function returned by context.WithDeadline should be called, not discarded, to avoid a context leak"
+}
+
+func _() {
+ ctx, cancel := context.WithCancel()
+ defer cancel() // ok
+}
+
+func _() {
+ ctx, cancel := context.WithCancel() // ERROR "not used on all paths"
+ if condition {
+ cancel()
+ }
+ return // ERROR "this return statement may be reached without using the cancel var"
+}
+
+func _() {
+ ctx, cancel := context.WithCancel()
+ if condition {
+ cancel()
+ } else {
+ // ok: infinite loop
+ for {
+ print(0)
+ }
+ }
+}
+
+func _() {
+ ctx, cancel := context.WithCancel() // ERROR "not used on all paths"
+ if condition {
+ cancel()
+ } else {
+ for i := 0; i < 10; i++ {
+ print(0)
+ }
+ }
+} // ERROR "this return statement may be reached without using the cancel var"
+
+func _() {
+ ctx, cancel := context.WithCancel()
+ // ok: used on all paths
+ switch someInt {
+ case 0:
+ new(testing.T).FailNow()
+ case 1:
+ log.Fatal()
+ case 2:
+ cancel()
+ case 3:
+ print("hi")
+ fallthrough
+ default:
+ os.Exit(1)
+ }
+}
+
+func _() {
+ ctx, cancel := context.WithCancel() // ERROR "not used on all paths"
+ switch someInt {
+ case 0:
+ new(testing.T).FailNow()
+ case 1:
+ log.Fatal()
+ case 2:
+ cancel()
+ case 3:
+ print("hi") // falls through to implicit return
+ default:
+ os.Exit(1)
+ }
+} // ERROR "this return statement may be reached without using the cancel var"
+
+func _(ch chan int) int {
+ ctx, cancel := context.WithCancel() // ERROR "not used on all paths"
+ select {
+ case <-ch:
+ new(testing.T).FailNow()
+ case y <- ch:
+ print("hi") // falls through to implicit return
+ case ch <- 1:
+ cancel()
+ default:
+ os.Exit(1)
+ }
+} // ERROR "this return statement may be reached without using the cancel var"
+
+func _(ch chan int) int {
+ ctx, cancel := context.WithCancel()
+ // A blocking select must execute one of its cases.
+ select {
+ case <-ch:
+ panic()
+ }
+}
+
+func _() {
+ go func() {
+ ctx, cancel := context.WithCancel() // ERROR "not used on all paths"
+ print(ctx)
+ }() // ERROR "may be reached without using the cancel var"
+}
+
+var condition bool
+var someInt int