From: Alan Donovan Date: Wed, 15 Jun 2016 03:50:39 +0000 (-0400) Subject: cmd/vet: -lostcancel: check for discarded result of context.WithCancel X-Git-Tag: go1.7rc1~74 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=b65cb7f198836faf6605051b95bd60a169fa5e8b;p=gostls13.git cmd/vet: -lostcancel: check for discarded result of context.WithCancel The cfg subpackage builds a control flow graph of ast.Nodes. The lostcancel module checks this graph to find paths, from a call to WithCancel to a return statement, on which the cancel variable is not used. (A trivial case is simply assigning the cancel result to the blank identifier.) In a sample of 50,000 source files, the new check found 2068 blank assignments and 118 return-before-cancel errors. I manually inspected 20 of the latter and didn't find a single false positive among them. Change-Id: I84cd49445f9f8d04908b04881eb1496a96611205 Reviewed-on: https://go-review.googlesource.com/24150 Reviewed-by: Robert Griesemer Reviewed-by: Rob Pike --- diff --git a/src/cmd/vet/deadcode.go b/src/cmd/vet/deadcode.go index abede47a45..b1077aef38 100644 --- a/src/cmd/vet/deadcode.go +++ b/src/cmd/vet/deadcode.go @@ -29,6 +29,8 @@ type deadState struct { } // checkUnreachable checks a function body for dead code. +// +// TODO(adonovan): use the new cfg package, which is more precise. func checkUnreachable(f *File, node ast.Node) { var body *ast.BlockStmt switch n := node.(type) { diff --git a/src/cmd/vet/doc.go b/src/cmd/vet/doc.go index c697f3bc36..bb8dcf171f 100644 --- a/src/cmd/vet/doc.go +++ b/src/cmd/vet/doc.go @@ -91,6 +91,15 @@ Flag: -tests Mistakes involving tests including functions with incorrect names or signatures and example tests that document identifiers not in the package. +Failure to call the cancellation function returned by context.WithCancel. + +Flag: -lostcancel + +The cancellation function returned by context.WithCancel, WithTimeout, +and WithDeadline must be called or the new context will remain live +until its parent context is cancelled. +(The background context is never cancelled.) + Methods Flag: -methods diff --git a/src/cmd/vet/internal/cfg/builder.go b/src/cmd/vet/internal/cfg/builder.go new file mode 100644 index 0000000000..79c906bca0 --- /dev/null +++ b/src/cmd/vet/internal/cfg/builder.go @@ -0,0 +1,510 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cfg + +// This file implements the CFG construction pass. + +import ( + "fmt" + "go/ast" + "go/token" +) + +type builder struct { + cfg *CFG + mayReturn func(*ast.CallExpr) bool + current *Block + lblocks map[*ast.Object]*lblock // labeled blocks + targets *targets // linked stack of branch targets +} + +func (b *builder) stmt(_s ast.Stmt) { + // The label of the current statement. If non-nil, its _goto + // target is always set; its _break and _continue are set only + // within the body of switch/typeswitch/select/for/range. + // It is effectively an additional default-nil parameter of stmt(). + var label *lblock +start: + switch s := _s.(type) { + case *ast.BadStmt, + *ast.SendStmt, + *ast.IncDecStmt, + *ast.GoStmt, + *ast.DeferStmt, + *ast.EmptyStmt, + *ast.AssignStmt: + // No effect on control flow. + b.add(s) + + case *ast.ExprStmt: + b.add(s) + if call, ok := s.X.(*ast.CallExpr); ok && !b.mayReturn(call) { + // Calls to panic, os.Exit, etc, never return. + b.current = b.newUnreachableBlock("unreachable.call") + } + + case *ast.DeclStmt: + // Treat each var ValueSpec as a separate statement. + d := s.Decl.(*ast.GenDecl) + if d.Tok == token.VAR { + for _, spec := range d.Specs { + if spec, ok := spec.(*ast.ValueSpec); ok { + b.add(spec) + } + } + } + + case *ast.LabeledStmt: + label = b.labeledBlock(s.Label) + b.jump(label._goto) + b.current = label._goto + _s = s.Stmt + goto start // effectively: tailcall stmt(g, s.Stmt, label) + + case *ast.ReturnStmt: + b.add(s) + b.current = b.newUnreachableBlock("unreachable.return") + + case *ast.BranchStmt: + var block *Block + switch s.Tok { + case token.BREAK: + if s.Label != nil { + if lb := b.labeledBlock(s.Label); lb != nil { + block = lb._break + } + } else { + for t := b.targets; t != nil && block == nil; t = t.tail { + block = t._break + } + } + + case token.CONTINUE: + if s.Label != nil { + if lb := b.labeledBlock(s.Label); lb != nil { + block = lb._continue + } + } else { + for t := b.targets; t != nil && block == nil; t = t.tail { + block = t._continue + } + } + + case token.FALLTHROUGH: + for t := b.targets; t != nil; t = t.tail { + block = t._fallthrough + } + + case token.GOTO: + block = b.labeledBlock(s.Label)._goto + } + if block == nil { + block = b.newBlock("undefined.branch") + } + b.jump(block) + b.current = b.newUnreachableBlock("unreachable.branch") + + case *ast.BlockStmt: + b.stmtList(s.List) + + case *ast.IfStmt: + if s.Init != nil { + b.stmt(s.Init) + } + then := b.newBlock("if.then") + done := b.newBlock("if.done") + _else := done + if s.Else != nil { + _else = b.newBlock("if.else") + } + b.add(s.Cond) + b.ifelse(then, _else) + b.current = then + b.stmt(s.Body) + b.jump(done) + + if s.Else != nil { + b.current = _else + b.stmt(s.Else) + b.jump(done) + } + + b.current = done + + case *ast.SwitchStmt: + b.switchStmt(s, label) + + case *ast.TypeSwitchStmt: + b.typeSwitchStmt(s, label) + + case *ast.SelectStmt: + b.selectStmt(s, label) + + case *ast.ForStmt: + b.forStmt(s, label) + + case *ast.RangeStmt: + b.rangeStmt(s, label) + + default: + panic(fmt.Sprintf("unexpected statement kind: %T", s)) + } +} + +func (b *builder) stmtList(list []ast.Stmt) { + for _, s := range list { + b.stmt(s) + } +} + +func (b *builder) switchStmt(s *ast.SwitchStmt, label *lblock) { + if s.Init != nil { + b.stmt(s.Init) + } + if s.Tag != nil { + b.add(s.Tag) + } + done := b.newBlock("switch.done") + if label != nil { + label._break = done + } + // We pull the default case (if present) down to the end. + // But each fallthrough label must point to the next + // body block in source order, so we preallocate a + // body block (fallthru) for the next case. + // Unfortunately this makes for a confusing block order. + var defaultBody *[]ast.Stmt + var defaultFallthrough *Block + var fallthru, defaultBlock *Block + ncases := len(s.Body.List) + for i, clause := range s.Body.List { + body := fallthru + if body == nil { + body = b.newBlock("switch.body") // first case only + } + + // Preallocate body block for the next case. + fallthru = done + if i+1 < ncases { + fallthru = b.newBlock("switch.body") + } + + cc := clause.(*ast.CaseClause) + if cc.List == nil { + // Default case. + defaultBody = &cc.Body + defaultFallthrough = fallthru + defaultBlock = body + continue + } + + var nextCond *Block + for _, cond := range cc.List { + nextCond = b.newBlock("switch.next") + b.add(cond) // one half of the tag==cond condition + b.ifelse(body, nextCond) + b.current = nextCond + } + b.current = body + b.targets = &targets{ + tail: b.targets, + _break: done, + _fallthrough: fallthru, + } + b.stmtList(cc.Body) + b.targets = b.targets.tail + b.jump(done) + b.current = nextCond + } + if defaultBlock != nil { + b.jump(defaultBlock) + b.current = defaultBlock + b.targets = &targets{ + tail: b.targets, + _break: done, + _fallthrough: defaultFallthrough, + } + b.stmtList(*defaultBody) + b.targets = b.targets.tail + } + b.jump(done) + b.current = done +} + +func (b *builder) typeSwitchStmt(s *ast.TypeSwitchStmt, label *lblock) { + if s.Init != nil { + b.stmt(s.Init) + } + if s.Assign != nil { + b.add(s.Assign) + } + + done := b.newBlock("typeswitch.done") + if label != nil { + label._break = done + } + var default_ *ast.CaseClause + for _, clause := range s.Body.List { + cc := clause.(*ast.CaseClause) + if cc.List == nil { + default_ = cc + continue + } + body := b.newBlock("typeswitch.body") + var next *Block + for _, casetype := range cc.List { + next = b.newBlock("typeswitch.next") + // casetype is a type, so don't call b.add(casetype). + // This block logically contains a type assertion, + // x.(casetype), but it's unclear how to represent x. + _ = casetype + b.ifelse(body, next) + b.current = next + } + b.current = body + b.typeCaseBody(cc, done) + b.current = next + } + if default_ != nil { + b.typeCaseBody(default_, done) + } else { + b.jump(done) + } + b.current = done +} + +func (b *builder) typeCaseBody(cc *ast.CaseClause, done *Block) { + b.targets = &targets{ + tail: b.targets, + _break: done, + } + b.stmtList(cc.Body) + b.targets = b.targets.tail + b.jump(done) +} + +func (b *builder) selectStmt(s *ast.SelectStmt, label *lblock) { + // First evaluate channel expressions. + // TODO(adonovan): fix: evaluate only channel exprs here. + for _, clause := range s.Body.List { + if comm := clause.(*ast.CommClause).Comm; comm != nil { + b.stmt(comm) + } + } + + done := b.newBlock("select.done") + if label != nil { + label._break = done + } + + var defaultBody *[]ast.Stmt + for _, cc := range s.Body.List { + clause := cc.(*ast.CommClause) + if clause.Comm == nil { + defaultBody = &clause.Body + continue + } + body := b.newBlock("select.body") + next := b.newBlock("select.next") + b.ifelse(body, next) + b.current = body + b.targets = &targets{ + tail: b.targets, + _break: done, + } + switch comm := clause.Comm.(type) { + case *ast.ExprStmt: // <-ch + // nop + case *ast.AssignStmt: // x := <-states[state].Chan + b.add(comm.Lhs[0]) + } + b.stmtList(clause.Body) + b.targets = b.targets.tail + b.jump(done) + b.current = next + } + if defaultBody != nil { + b.targets = &targets{ + tail: b.targets, + _break: done, + } + b.stmtList(*defaultBody) + b.targets = b.targets.tail + b.jump(done) + } + b.current = done +} + +func (b *builder) forStmt(s *ast.ForStmt, label *lblock) { + // ...init... + // jump loop + // loop: + // if cond goto body else done + // body: + // ...body... + // jump post + // post: (target of continue) + // ...post... + // jump loop + // done: (target of break) + if s.Init != nil { + b.stmt(s.Init) + } + body := b.newBlock("for.body") + done := b.newBlock("for.done") // target of 'break' + loop := body // target of back-edge + if s.Cond != nil { + loop = b.newBlock("for.loop") + } + cont := loop // target of 'continue' + if s.Post != nil { + cont = b.newBlock("for.post") + } + if label != nil { + label._break = done + label._continue = cont + } + b.jump(loop) + b.current = loop + if loop != body { + b.add(s.Cond) + b.ifelse(body, done) + b.current = body + } + b.targets = &targets{ + tail: b.targets, + _break: done, + _continue: cont, + } + b.stmt(s.Body) + b.targets = b.targets.tail + b.jump(cont) + + if s.Post != nil { + b.current = cont + b.stmt(s.Post) + b.jump(loop) // back-edge + } + b.current = done +} + +func (b *builder) rangeStmt(s *ast.RangeStmt, label *lblock) { + b.add(s.X) + + if s.Key != nil { + b.add(s.Key) + } + if s.Value != nil { + b.add(s.Value) + } + + // ... + // loop: (target of continue) + // if ... goto body else done + // body: + // ... + // jump loop + // done: (target of break) + + loop := b.newBlock("range.loop") + b.jump(loop) + b.current = loop + + body := b.newBlock("range.body") + done := b.newBlock("range.done") + b.ifelse(body, done) + b.current = body + + if label != nil { + label._break = done + label._continue = loop + } + b.targets = &targets{ + tail: b.targets, + _break: done, + _continue: loop, + } + b.stmt(s.Body) + b.targets = b.targets.tail + b.jump(loop) // back-edge + b.current = done +} + +// -------- helpers -------- + +// Destinations associated with unlabeled for/switch/select stmts. +// We push/pop one of these as we enter/leave each construct and for +// each BranchStmt we scan for the innermost target of the right type. +// +type targets struct { + tail *targets // rest of stack + _break *Block + _continue *Block + _fallthrough *Block +} + +// Destinations associated with a labeled block. +// We populate these as labels are encountered in forward gotos or +// labeled statements. +// +type lblock struct { + _goto *Block + _break *Block + _continue *Block +} + +// labeledBlock returns the branch target associated with the +// specified label, creating it if needed. +// +func (b *builder) labeledBlock(label *ast.Ident) *lblock { + lb := b.lblocks[label.Obj] + if lb == nil { + lb = &lblock{_goto: b.newBlock(label.Name)} + if b.lblocks == nil { + b.lblocks = make(map[*ast.Object]*lblock) + } + b.lblocks[label.Obj] = lb + } + return lb +} + +// newBlock appends a new unconnected basic block to b.cfg's block +// slice and returns it. +// It does not automatically become the current block. +// comment is an optional string for more readable debugging output. +func (b *builder) newBlock(comment string) *Block { + g := b.cfg + block := &Block{ + index: int32(len(g.Blocks)), + comment: comment, + } + block.Succs = block.succs2[:0] + g.Blocks = append(g.Blocks, block) + return block +} + +func (b *builder) newUnreachableBlock(comment string) *Block { + block := b.newBlock(comment) + block.unreachable = true + return block +} + +func (b *builder) add(n ast.Node) { + b.current.Nodes = append(b.current.Nodes, n) +} + +// jump adds an edge from the current block to the target block, +// and sets b.current to nil. +func (b *builder) jump(target *Block) { + b.current.Succs = append(b.current.Succs, target) + b.current = nil +} + +// ifelse emits edges from the current block to the t and f blocks, +// and sets b.current to nil. +func (b *builder) ifelse(t, f *Block) { + b.current.Succs = append(b.current.Succs, t, f) + b.current = nil +} diff --git a/src/cmd/vet/internal/cfg/cfg.go b/src/cmd/vet/internal/cfg/cfg.go new file mode 100644 index 0000000000..e4d5bfe5d2 --- /dev/null +++ b/src/cmd/vet/internal/cfg/cfg.go @@ -0,0 +1,142 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This package constructs a simple control-flow graph (CFG) of the +// statements and expressions within a single function. +// +// Use cfg.New to construct the CFG for a function body. +// +// The blocks of the CFG contain all the function's non-control +// statements. The CFG does not contain control statements such as If, +// Switch, Select, and Branch, but does contain their subexpressions. +// For example, this source code: +// +// if x := f(); x != nil { +// T() +// } else { +// F() +// } +// +// produces this CFG: +// +// 1: x := f() +// x != nil +// succs: 2, 3 +// 2: T() +// succs: 4 +// 3: F() +// succs: 4 +// 4: +// +// The CFG does contain Return statements; even implicit returns are +// materialized (at the position of the function's closing brace). +// +// The CFG does not record conditions associated with conditional branch +// edges, nor the short-circuit semantics of the && and || operators, +// nor abnormal control flow caused by panic. If you need this +// information, use golang.org/x/tools/go/ssa instead. +// +package cfg + +// Although the vet tool has type information, it is often extremely +// fragmentary, so for simplicity this package does not depend on +// go/types. Consequently control-flow conditions are ignored even +// when constant, and "mayReturn" information must be provided by the +// client. +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/token" +) + +// A CFG represents the control-flow graph of a single function. +// +// The entry point is Blocks[0]; there may be multiple return blocks. +type CFG struct { + Blocks []*Block // block[0] is entry; order otherwise undefined +} + +// A Block represents a basic block: a list of statements and +// expressions that are always evaluated sequentially. +// +// A block may have 0-2 successors: zero for a return block or a block +// that calls a function such as panic that never returns; one for a +// normal (jump) block; and two for a conditional (if) block. +type Block struct { + Nodes []ast.Node // statements, expressions, and ValueSpecs + Succs []*Block // successor nodes in the graph + + comment string // for debugging + index int32 // index within CFG.Blocks + unreachable bool // is block of stmts following return/panic/for{} + succs2 [2]*Block // underlying array for Succs +} + +// New returns a new control-flow graph for the specified function body, +// which must be non-nil. +// +// The CFG builder calls mayReturn to determine whether a given function +// call may return. For example, calls to panic, os.Exit, and log.Fatal +// do not return, so the builder can remove infeasible graph edges +// following such calls. The builder calls mayReturn only for a +// CallExpr beneath an ExprStmt. +func New(body *ast.BlockStmt, mayReturn func(*ast.CallExpr) bool) *CFG { + b := builder{ + mayReturn: mayReturn, + cfg: new(CFG), + } + b.current = b.newBlock("entry") + b.stmt(body) + + // Does control fall off the end of the function's body? + // Make implicit return explicit. + if b.current != nil && !b.current.unreachable { + b.add(&ast.ReturnStmt{ + Return: body.End() - 1, + }) + } + + return b.cfg +} + +func (b *Block) String() string { + return fmt.Sprintf("block %d (%s)", b.index, b.comment) +} + +// Return returns the return statement at the end of this block if present, nil otherwise. +func (b *Block) Return() (ret *ast.ReturnStmt) { + if len(b.Nodes) > 0 { + ret, _ = b.Nodes[len(b.Nodes)-1].(*ast.ReturnStmt) + } + return +} + +// Format formats the control-flow graph for ease of debugging. +func (g *CFG) Format(fset *token.FileSet) string { + var buf bytes.Buffer + for _, b := range g.Blocks { + fmt.Fprintf(&buf, ".%d: # %s\n", b.index, b.comment) + for _, n := range b.Nodes { + fmt.Fprintf(&buf, "\t%s\n", formatNode(fset, n)) + } + if len(b.Succs) > 0 { + fmt.Fprintf(&buf, "\tsuccs:") + for _, succ := range b.Succs { + fmt.Fprintf(&buf, " %d", succ.index) + } + buf.WriteByte('\n') + } + buf.WriteByte('\n') + } + return buf.String() +} + +func formatNode(fset *token.FileSet, n ast.Node) string { + var buf bytes.Buffer + format.Node(&buf, fset, n) + // Indent secondary lines by a tab. + return string(bytes.Replace(buf.Bytes(), []byte("\n"), []byte("\n\t"), -1)) +} diff --git a/src/cmd/vet/internal/cfg/cfg_test.go b/src/cmd/vet/internal/cfg/cfg_test.go new file mode 100644 index 0000000000..5d98f136bd --- /dev/null +++ b/src/cmd/vet/internal/cfg/cfg_test.go @@ -0,0 +1,184 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cfg + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/token" + "testing" +) + +const src = `package main + +import "log" + +func f1() { + live() + return + dead() +} + +func f2() { + for { + live() + } + dead() +} + +func f3() { + if true { // even known values are ignored + return + } + for true { // even known values are ignored + live() + } + for { + live() + } + dead() +} + +func f4(x int) { + switch x { + case 1: + live() + fallthrough + case 2: + live() + log.Fatal() + default: + panic("oops") + } + dead() +} + +func f4(ch chan int) { + select { + case <-ch: + live() + return + default: + live() + panic("oops") + } + dead() +} + +func f5(unknown bool) { + for { + if unknown { + break + } + continue + dead() + } + live() +} + +func f6(unknown bool) { +outer: + for { + for { + break outer + dead() + } + dead() + } + live() +} + +func f7() { + for { + break nosuchlabel + dead() + } + dead() +} + +func f8() { + select{} + dead() +} + +func f9(ch chan int) { + select { + case <-ch: + return + } + dead() +} + +func f10(ch chan int) { + select { + case <-ch: + return + dead() + default: + } + live() +} +` + +func TestDeadCode(t *testing.T) { + // We'll use dead code detection to verify the CFG. + + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "dummy.go", src, parser.Mode(0)) + if err != nil { + t.Fatal(err) + } + for _, decl := range f.Decls { + if decl, ok := decl.(*ast.FuncDecl); ok { + g := New(decl.Body, mayReturn) + + // Mark blocks reachable from entry. + live := make(map[*Block]bool) + var visit func(*Block) + visit = func(b *Block) { + if !live[b] { + live[b] = true + for _, succ := range b.Succs { + visit(succ) + } + } + } + visit(g.Blocks[0]) + + // Print statements in unreachable blocks + // (in order determined by builder). + var buf bytes.Buffer + for _, b := range g.Blocks { + if !live[b] { + for _, n := range b.Nodes { + fmt.Fprintf(&buf, "\t%s\n", formatNode(fset, n)) + } + } + } + + // Check that the result contains "dead" at least once but not "live". + if !bytes.Contains(buf.Bytes(), []byte("dead")) || + bytes.Contains(buf.Bytes(), []byte("live")) { + t.Errorf("unexpected dead statements in function %s:\n%s", + decl.Name.Name, + &buf) + t.Logf("control flow graph:\n%s", g.Format(fset)) + } + } + } +} + +// A trivial mayReturn predicate that looks only at syntax, not types. +func mayReturn(call *ast.CallExpr) bool { + switch fun := call.Fun.(type) { + case *ast.Ident: + return fun.Name != "panic" + case *ast.SelectorExpr: + return fun.Sel.Name != "Fatal" + } + return true +} diff --git a/src/cmd/vet/lostcancel.go b/src/cmd/vet/lostcancel.go new file mode 100644 index 0000000000..708b6f3029 --- /dev/null +++ b/src/cmd/vet/lostcancel.go @@ -0,0 +1,296 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "cmd/vet/internal/cfg" + "fmt" + "go/ast" + "go/types" + "strconv" +) + +func init() { + register("lostcancel", + "check for failure to call cancellation function returned by context.WithCancel", + checkLostCancel, + funcDecl, funcLit) +} + +const debugLostCancel = false + +var contextPackage = "context" + +// checkLostCancel reports a failure to the call the cancel function +// returned by context.WithCancel, either because the variable was +// assigned to the blank identifier, or because there exists a +// control-flow path from the call to a return statement and that path +// does not "use" the cancel function. Any reference to the variable +// counts as a use, even within a nested function literal. +// +// checkLostCancel analyzes a single named or literal function. +func checkLostCancel(f *File, node ast.Node) { + // Fast path: bypass check if file doesn't use context.WithCancel. + if !hasImport(f.file, contextPackage) { + return + } + + // Maps each cancel variable to its defining ValueSpec/AssignStmt. + cancelvars := make(map[*types.Var]ast.Node) + + // Find the set of cancel vars to analyze. + stack := make([]ast.Node, 0, 32) + ast.Inspect(node, func(n ast.Node) bool { + switch n.(type) { + case *ast.FuncLit: + if len(stack) > 0 { + return false // don't stray into nested functions + } + case nil: + stack = stack[:len(stack)-1] // pop + return true + } + stack = append(stack, n) // push + + // Look for [{AssignStmt,ValueSpec} CallExpr SelectorExpr]: + // + // ctx, cancel := context.WithCancel(...) + // ctx, cancel = context.WithCancel(...) + // var ctx, cancel = context.WithCancel(...) + // + if isContextWithCancel(f, n) && isCall(stack[len(stack)-2]) { + var id *ast.Ident // id of cancel var + stmt := stack[len(stack)-3] + switch stmt := stmt.(type) { + case *ast.ValueSpec: + if len(stmt.Names) > 1 { + id = stmt.Names[1] + } + case *ast.AssignStmt: + if len(stmt.Lhs) > 1 { + id, _ = stmt.Lhs[1].(*ast.Ident) + } + } + if id != nil { + if id.Name == "_" { + f.Badf(id.Pos(), "the cancel function returned by context.%s should be called, not discarded, to avoid a context leak", + n.(*ast.SelectorExpr).Sel.Name) + } else if v, ok := f.pkg.uses[id].(*types.Var); ok { + cancelvars[v] = stmt + } else if v, ok := f.pkg.defs[id].(*types.Var); ok { + cancelvars[v] = stmt + } + } + } + + return true + }) + + if len(cancelvars) == 0 { + return // no need to build CFG + } + + // Tell the CFG builder which functions never return. + info := &types.Info{Uses: f.pkg.uses, Selections: f.pkg.selectors} + mayReturn := func(call *ast.CallExpr) bool { + name := callName(info, call) + return !noReturnFuncs[name] + } + + // Build the CFG. + var g *cfg.CFG + switch node := node.(type) { + case *ast.FuncDecl: + g = cfg.New(node.Body, mayReturn) + case *ast.FuncLit: + g = cfg.New(node.Body, mayReturn) + } + + // Print CFG. + if debugLostCancel { + fmt.Println(g.Format(f.fset)) + } + + // Examine the CFG for each variable in turn. + // (It would be more efficient to analyze all cancelvars in a + // single pass over the AST, but seldom is there more than one.) + for v, stmt := range cancelvars { + if ret := lostCancelPath(f, g, v, stmt); ret != nil { + lineno := f.fset.Position(stmt.Pos()).Line + f.Badf(stmt.Pos(), "the %s function is not used on all paths (possible context leak)", v.Name()) + f.Badf(ret.Pos(), "this return statement may be reached without using the %s var defined on line %d", v.Name(), lineno) + } + } +} + +func isCall(n ast.Node) bool { _, ok := n.(*ast.CallExpr); return ok } + +func hasImport(f *ast.File, path string) bool { + for _, imp := range f.Imports { + v, _ := strconv.Unquote(imp.Path.Value) + if v == path { + return true + } + } + return false +} + +// isContextWithCancel reports whether n is one of the qualified identifiers +// context.With{Cancel,Timeout,Deadline}. +func isContextWithCancel(f *File, n ast.Node) bool { + if sel, ok := n.(*ast.SelectorExpr); ok { + switch sel.Sel.Name { + case "WithCancel", "WithTimeout", "WithDeadline": + if x, ok := sel.X.(*ast.Ident); ok { + if pkgname, ok := f.pkg.uses[x].(*types.PkgName); ok { + return pkgname.Imported().Path() == contextPackage + } + // Import failed, so we can't check package path. + // Just check the local package name (heuristic). + return x.Name == "context" + } + } + } + return false +} + +// lostCancelPath finds a path through the CFG, from stmt (which defines +// the 'cancel' variable v) to a return statement, that doesn't "use" v. +// If it finds one, it returns the return statement (which may be synthetic). +func lostCancelPath(f *File, g *cfg.CFG, v *types.Var, stmt ast.Node) *ast.ReturnStmt { + // uses reports whether stmts contain a "use" of variable v. + uses := func(f *File, v *types.Var, stmts []ast.Node) bool { + found := false + for _, stmt := range stmts { + ast.Inspect(stmt, func(n ast.Node) bool { + if id, ok := n.(*ast.Ident); ok { + if f.pkg.uses[id] == v { + found = true + } + } + return !found + }) + } + return found + } + + // blockUses computes "uses" for each block, caching the result. + memo := make(map[*cfg.Block]bool) + blockUses := func(f *File, v *types.Var, b *cfg.Block) bool { + res, ok := memo[b] + if !ok { + res = uses(f, v, b.Nodes) + memo[b] = res + } + return res + } + + // Find the var's defining block in the CFG, + // plus the rest of the statements of that block. + var defblock *cfg.Block + var rest []ast.Node +outer: + for _, b := range g.Blocks { + for i, n := range b.Nodes { + if n == stmt { + defblock = b + rest = b.Nodes[i+1:] + break outer + } + } + } + if defblock == nil { + panic("internal error: can't find defining block for cancel var") + } + + // Is v "used" in the remainder of its defining block? + if uses(f, v, rest) { + return nil + } + + // Does the defining block return without using v? + if ret := defblock.Return(); ret != nil { + return ret + } + + // Search the CFG depth-first for a path, from defblock to a + // return block, in which v is never "used". + seen := make(map[*cfg.Block]bool) + var search func(blocks []*cfg.Block) *ast.ReturnStmt + search = func(blocks []*cfg.Block) *ast.ReturnStmt { + for _, b := range blocks { + if !seen[b] { + seen[b] = true + + // Prune the search if the block uses v. + if blockUses(f, v, b) { + continue + } + + // Found path to return statement? + if ret := b.Return(); ret != nil { + if debugLostCancel { + fmt.Printf("found path to return in block %s\n", b) + } + return ret // found + } + + // Recur + if ret := search(b.Succs); ret != nil { + if debugLostCancel { + fmt.Printf(" from block %s\n", b) + } + return ret + } + } + } + return nil + } + return search(defblock.Succs) +} + +var noReturnFuncs = map[string]bool{ + "(*testing.common).FailNow": true, + "(*testing.common).Fatal": true, + "(*testing.common).Fatalf": true, + "(*testing.common).Skip": true, + "(*testing.common).SkipNow": true, + "(*testing.common).Skipf": true, + "log.Fatal": true, + "log.Fatalf": true, + "log.Fatalln": true, + "os.Exit": true, + "panic": true, + "runtime.Goexit": true, +} + +// callName returns the canonical name of the builtin, method, or +// function called by call, if known. +func callName(info *types.Info, call *ast.CallExpr) string { + switch fun := call.Fun.(type) { + case *ast.Ident: + // builtin, e.g. "panic" + if obj, ok := info.Uses[fun].(*types.Builtin); ok { + return obj.Name() + } + case *ast.SelectorExpr: + if sel, ok := info.Selections[fun]; ok { + // method call, e.g. "(*testing.common).Fatal" + meth := sel.Obj() + return fmt.Sprintf("(%s).%s", + meth.Type().(*types.Signature).Recv().Type(), + meth.Name()) + } + if obj, ok := info.Uses[fun.Sel]; ok { + // qualified identifier, e.g. "os.Exit" + return fmt.Sprintf("%s.%s", + obj.Pkg().Path(), + obj.Name()) + } + } + + // function with no name, or defined in missing imported package + return "" +} diff --git a/src/cmd/vet/testdata/lostcancel.go b/src/cmd/vet/testdata/lostcancel.go new file mode 100644 index 0000000000..143456e52f --- /dev/null +++ b/src/cmd/vet/testdata/lostcancel.go @@ -0,0 +1,137 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testdata + +import ( + "context" + "log" + "os" + "testing" +) + +// Check the three functions and assignment forms (var, :=, =) we look for. +// (Do these early: line numbers are fragile.) +func _() { + var ctx, cancel = context.WithCancel() // ERROR "the cancel function is not used on all paths \(possible context leak\)" +} // ERROR "this return statement may be reached without using the cancel var defined on line 17" + +func _() { + ctx, cancel2 := context.WithDeadline() // ERROR "the cancel2 function is not used..." +} // ERROR "may be reached without using the cancel2 var defined on line 21" + +func _() { + var ctx context.Context + var cancel3 func() + ctx, cancel3 = context.WithTimeout() // ERROR "function is not used..." +} // ERROR "this return statement may be reached without using the cancel3 var defined on line 27" + +func _() { + ctx, _ := context.WithCancel() // ERROR "the cancel function returned by context.WithCancel should be called, not discarded, to avoid a context leak" + ctx, _ = context.WithTimeout() // ERROR "the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak" + ctx, _ = context.WithDeadline() // ERROR "the cancel function returned by context.WithDeadline should be called, not discarded, to avoid a context leak" +} + +func _() { + ctx, cancel := context.WithCancel() + defer cancel() // ok +} + +func _() { + ctx, cancel := context.WithCancel() // ERROR "not used on all paths" + if condition { + cancel() + } + return // ERROR "this return statement may be reached without using the cancel var" +} + +func _() { + ctx, cancel := context.WithCancel() + if condition { + cancel() + } else { + // ok: infinite loop + for { + print(0) + } + } +} + +func _() { + ctx, cancel := context.WithCancel() // ERROR "not used on all paths" + if condition { + cancel() + } else { + for i := 0; i < 10; i++ { + print(0) + } + } +} // ERROR "this return statement may be reached without using the cancel var" + +func _() { + ctx, cancel := context.WithCancel() + // ok: used on all paths + switch someInt { + case 0: + new(testing.T).FailNow() + case 1: + log.Fatal() + case 2: + cancel() + case 3: + print("hi") + fallthrough + default: + os.Exit(1) + } +} + +func _() { + ctx, cancel := context.WithCancel() // ERROR "not used on all paths" + switch someInt { + case 0: + new(testing.T).FailNow() + case 1: + log.Fatal() + case 2: + cancel() + case 3: + print("hi") // falls through to implicit return + default: + os.Exit(1) + } +} // ERROR "this return statement may be reached without using the cancel var" + +func _(ch chan int) int { + ctx, cancel := context.WithCancel() // ERROR "not used on all paths" + select { + case <-ch: + new(testing.T).FailNow() + case y <- ch: + print("hi") // falls through to implicit return + case ch <- 1: + cancel() + default: + os.Exit(1) + } +} // ERROR "this return statement may be reached without using the cancel var" + +func _(ch chan int) int { + ctx, cancel := context.WithCancel() + // A blocking select must execute one of its cases. + select { + case <-ch: + panic() + } +} + +func _() { + go func() { + ctx, cancel := context.WithCancel() // ERROR "not used on all paths" + print(ctx) + }() // ERROR "may be reached without using the cancel var" +} + +var condition bool +var someInt int