]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/vet: detect defer resp.Body.Close() before error check
authorFrancesc Campoy <campoy@golang.org>
Tue, 8 Nov 2016 07:37:21 +0000 (23:37 -0800)
committerFrancesc Campoy Flores <campoy@golang.org>
Thu, 10 Nov 2016 20:38:11 +0000 (20:38 +0000)
This check detects the code

resp, err := http.Get("http://foo.com")
defer resp.Body.Close()
if err != nil {
...
}

For every call to a function on the net/http package or any method
on http.Client that returns (*http.Response, error), it checks
whether the next line is a defer statement that calls on the response.

Fixes #17780.

Change-Id: I9d70edcbfa2bad205bf7f45281597d074c795977
Reviewed-on: https://go-review.googlesource.com/32911
Reviewed-by: Rob Pike <r@golang.org>
src/cmd/vet/doc.go
src/cmd/vet/httpresponse.go [new file with mode: 0644]
src/cmd/vet/testdata/httpresponse.go [new file with mode: 0644]

index 2baa53099d6c74f63b8b024fa35cc8dce6a599aa..5cbe116abe134f48415bea22b60dba43fd6ce18d 100644 (file)
@@ -84,12 +84,12 @@ Flag: -copylocks
 
 Locks that are erroneously passed by value.
 
-Tests and documentation examples
+HTTP responses used incorrectly
 
-Flag: -tests
+Flag: -httpresponse
 
-Mistakes involving tests including functions with incorrect names or signatures
-and example tests that document identifiers not in the package.
+Mistakes deferring a function call on an HTTP response before
+checking whether the error returned with the response was nil.
 
 Failure to call the cancelation function returned by WithCancel
 
@@ -162,6 +162,13 @@ Flag: -structtags
 Struct tags that do not follow the format understood by reflect.StructTag.Get.
 Well-known encoding struct tags (json, xml) used with unexported fields.
 
+Tests and documentation examples
+
+Flag: -tests
+
+Mistakes involving tests including functions with incorrect names or signatures
+and example tests that document identifiers not in the package.
+
 Unreachable code
 
 Flag: -unreachable
diff --git a/src/cmd/vet/httpresponse.go b/src/cmd/vet/httpresponse.go
new file mode 100644 (file)
index 0000000..f667edb
--- /dev/null
@@ -0,0 +1,153 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file contains the check for http.Response values being used before
+// checking for errors.
+
+package main
+
+import (
+       "go/ast"
+       "go/types"
+)
+
+var (
+       httpResponseType types.Type
+       httpClientType   types.Type
+)
+
+func init() {
+       if typ := importType("net/http", "Response"); typ != nil {
+               httpResponseType = typ
+       }
+       if typ := importType("net/http", "Client"); typ != nil {
+               httpClientType = typ
+       }
+       // if http.Response or http.Client are not defined don't register this check.
+       if httpResponseType == nil || httpClientType == nil {
+               return
+       }
+
+       register("httpresponse",
+               "check errors are checked before using an http Response",
+               checkHTTPResponse, callExpr)
+}
+
+func checkHTTPResponse(f *File, node ast.Node) {
+       call := node.(*ast.CallExpr)
+       if !isHTTPFuncOrMethodOnClient(f, call) {
+               return // the function call is not related to this check.
+       }
+
+       finder := &blockStmtFinder{node: call}
+       ast.Walk(finder, f.file)
+       stmts := finder.stmts()
+       if len(stmts) < 2 {
+               return // the call to the http function is the last statement of the block.
+       }
+
+       asg, ok := stmts[0].(*ast.AssignStmt)
+       if !ok {
+               return // the first statement is not assignment.
+       }
+       resp := rootIdent(asg.Lhs[0])
+       if resp == nil {
+               return // could not find the http.Response in the assignment.
+       }
+
+       def, ok := stmts[1].(*ast.DeferStmt)
+       if !ok {
+               return // the following statement is not a defer.
+       }
+       root := rootIdent(def.Call.Fun)
+       if root == nil {
+               return // could not find the receiver of the defer call.
+       }
+
+       if resp.Obj == root.Obj {
+               f.Badf(root.Pos(), "using %s before checking for errors", resp.Name)
+       }
+}
+
+// isHTTPFuncOrMethodOnClient checks whether the given call expression is on
+// either a function of the net/http package or a method of http.Client that
+// returns (*http.Response, error).
+func isHTTPFuncOrMethodOnClient(f *File, expr *ast.CallExpr) bool {
+       fun, _ := expr.Fun.(*ast.SelectorExpr)
+       sig, _ := f.pkg.types[fun].Type.(*types.Signature)
+       if sig == nil {
+               return false // the call is not on of the form x.f()
+       }
+
+       res := sig.Results()
+       if res.Len() != 2 {
+               return false // the function called does not return two values.
+       }
+       if ptr, ok := res.At(0).Type().(*types.Pointer); !ok || !types.Identical(ptr.Elem(), httpResponseType) {
+               return false // the first return type is not *http.Response.
+       }
+       if !types.Identical(res.At(1).Type().Underlying(), errorType) {
+               return false // the second return type is not error
+       }
+
+       typ := f.pkg.types[fun.X].Type
+       if typ == nil {
+               id, ok := fun.X.(*ast.Ident)
+               return ok && id.Name == "http" // function in net/http package.
+       }
+
+       if types.Identical(typ, httpClientType) {
+               return true // method on http.Client.
+       }
+       ptr, ok := typ.(*types.Pointer)
+       return ok && types.Identical(ptr.Elem(), httpClientType) // method on *http.Client.
+}
+
+// blockStmtFinder is an ast.Visitor that given any ast node can find the
+// statement containing it and its succeeding statements in the same block.
+type blockStmtFinder struct {
+       node  ast.Node       // target of search
+       stmt  ast.Stmt       // innermost statement enclosing argument to Visit
+       block *ast.BlockStmt // innermost block enclosing argument to Visit.
+}
+
+// Visit finds f.node performing a search down the ast tree.
+// It keeps the last block statement and statement seen for later use.
+func (f *blockStmtFinder) Visit(node ast.Node) ast.Visitor {
+       if node == nil || f.node.Pos() < node.Pos() || f.node.End() > node.End() {
+               return nil // not here
+       }
+       switch n := node.(type) {
+       case *ast.BlockStmt:
+               f.block = n
+       case ast.Stmt:
+               f.stmt = n
+       }
+       if f.node.Pos() == node.Pos() && f.node.End() == node.End() {
+               return nil // found
+       }
+       return f // keep looking
+}
+
+// stmts returns the statements of f.block starting from the one including f.node.
+func (f *blockStmtFinder) stmts() []ast.Stmt {
+       for i, v := range f.block.List {
+               if f.stmt == v {
+                       return f.block.List[i:]
+               }
+       }
+       return nil
+}
+
+// rootIdent finds the root identifier x in a chain of selections x.y.z, or nil if not found.
+func rootIdent(n ast.Node) *ast.Ident {
+       switch n := n.(type) {
+       case *ast.SelectorExpr:
+               return rootIdent(n.X)
+       case *ast.Ident:
+               return n
+       default:
+               return nil
+       }
+}
diff --git a/src/cmd/vet/testdata/httpresponse.go b/src/cmd/vet/testdata/httpresponse.go
new file mode 100644 (file)
index 0000000..7302a64
--- /dev/null
@@ -0,0 +1,85 @@
+package testdata
+
+import (
+       "log"
+       "net/http"
+)
+
+func goodHTTPGet() {
+       res, err := http.Get("http://foo.com")
+       if err != nil {
+               log.Fatal(err)
+       }
+       defer res.Body.Close()
+}
+
+func badHTTPGet() {
+       res, err := http.Get("http://foo.com")
+       defer res.Body.Close() // ERROR "using res before checking for errors"
+       if err != nil {
+               log.Fatal(err)
+       }
+}
+
+func badHTTPHead() {
+       res, err := http.Head("http://foo.com")
+       defer res.Body.Close() // ERROR "using res before checking for errors"
+       if err != nil {
+               log.Fatal(err)
+       }
+}
+
+func goodClientGet() {
+       client := http.DefaultClient
+       res, err := client.Get("http://foo.com")
+       if err != nil {
+               log.Fatal(err)
+       }
+       defer res.Body.Close()
+}
+
+func badClientPtrGet() {
+       client := http.DefaultClient
+       resp, err := client.Get("http://foo.com")
+       defer resp.Body.Close() // ERROR "using resp before checking for errors"
+       if err != nil {
+               log.Fatal(err)
+       }
+}
+
+func badClientGet() {
+       client := http.Client{}
+       resp, err := client.Get("http://foo.com")
+       defer resp.Body.Close() // ERROR "using resp before checking for errors"
+       if err != nil {
+               log.Fatal(err)
+       }
+}
+
+func badClientPtrDo() {
+       client := http.DefaultClient
+       req, err := http.NewRequest("GET", "http://foo.com", nil)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       resp, err := client.Do(req)
+       defer resp.Body.Close() // ERROR "using resp before checking for errors"
+       if err != nil {
+               log.Fatal(err)
+       }
+}
+
+func badClientDo() {
+       var client http.Client
+       req, err := http.NewRequest("GET", "http://foo.com", nil)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       resp, err := client.Do(req)
+       defer resp.Body.Close() // ERROR "using resp before checking for errors"
+       if err != nil {
+               log.Fatal(err)
+       }
+}