--- /dev/null
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file contains tests for the unmarshal checker.
+
+package testdata
+
+import (
+ "bytes"
+ "encoding/gob"
+ "encoding/json"
+ "encoding/xml"
+ "errors"
+ "fmt"
+)
+
+func _() {
+ type t struct {
+ a int
+ }
+ var v t
+ var r io.Reader
+
+ json.Unmarshal([]byte{}, v) // ERROR "call of Unmarshal passes non-pointer as second argument"
+ json.Unmarshal([]byte{}, &v)
+ json.NewDecoder(r).Decode(v) // ERROR "call of Decode passes non-pointer"
+ json.NewDecoder(r).Decode(&v)
+ gob.NewDecoder(r).Decode(v) // ERROR "call of Decode passes non-pointer"
+ gob.NewDecoder(r).Decode(&v)
+ xml.Unmarshal([]byte{}, v) // ERROR "call of Unmarshal passes non-pointer as second argument"
+ xml.Unmarshal([]byte{}, &v)
+ xml.NewDecoder(r).Decode(v) // ERROR "call of Decode passes non-pointer"
+ xml.NewDecoder(r).Decode(&v)
+
+ var p *t
+ json.Unmarshal([]byte{}, p)
+ json.Unmarshal([]byte{}, *p) // ERROR "call of Unmarshal passes non-pointer as second argument"
+ json.NewDecoder(r).Decode(p)
+ json.NewDecoder(r).Decode(*p) // ERROR "call of Decode passes non-pointer"
+ gob.NewDecoder(r).Decode(p)
+ gob.NewDecoder(r).Decode(*p) // ERROR "call of Decode passes non-pointer"
+ xml.Unmarshal([]byte{}, p)
+ xml.Unmarshal([]byte{}, *p) // ERROR "call of Unmarshal passes non-pointer as second argument"
+ xml.NewDecoder(r).Decode(p)
+ xml.NewDecoder(r).Decode(*p) // ERROR "call of Decode passes non-pointer"
+
+ var i interface{}
+ json.Unmarshal([]byte{}, i)
+ json.NewDecoder(r).Decode(i)
+
+ json.Unmarshal([]byte{}, nil) // ERROR "call of Unmarshal passes non-pointer as second argument"
+ json.Unmarshal([]byte{}, []t{}) // ERROR "call of Unmarshal passes non-pointer as second argument"
+ json.Unmarshal([]byte{}, map[string]int{}) // ERROR "call of Unmarshal passes non-pointer as second argument"
+ json.NewDecoder(r).Decode(nil) // ERROR "call of Decode passes non-pointer"
+ json.NewDecoder(r).Decode([]t{}) // ERROR "call of Decode passes non-pointer"
+ json.NewDecoder(r).Decode(map[string]int{}) // ERROR "call of Decode passes non-pointer"
+
+ json.Unmarshal(func() ([]byte, interface{}) { return []byte{}, v }())
+}
--- /dev/null
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file defines the check for passing non-pointer or non-interface
+// types to unmarshal and decode functions.
+
+package main
+
+import (
+ "go/ast"
+ "go/types"
+ "strings"
+)
+
+func init() {
+ register("unmarshal",
+ "check for passing non-pointer or non-interface types to unmarshal and decode functions",
+ checkUnmarshalArg,
+ callExpr)
+}
+
+var pointerArgFuncs = map[string]int{
+ "encoding/json.Unmarshal": 1,
+ "(*encoding/json.Decoder).Decode": 0,
+ "(*encoding/gob.Decoder).Decode": 0,
+ "encoding/xml.Unmarshal": 1,
+ "(*encoding/xml.Decoder).Decode": 0,
+}
+
+func checkUnmarshalArg(f *File, n ast.Node) {
+ call, ok := n.(*ast.CallExpr)
+ if !ok {
+ return // not a call statement
+ }
+ fun := unparen(call.Fun)
+
+ if f.pkg.types[fun].IsType() {
+ return // a conversion, not a call
+ }
+
+ info := &types.Info{Uses: f.pkg.uses, Selections: f.pkg.selectors}
+ name := callName(info, call)
+
+ arg, ok := pointerArgFuncs[name]
+ if !ok {
+ return // not a function we are interested in
+ }
+
+ if len(call.Args) < arg+1 {
+ return // not enough arguments, e.g. called with return values of another function
+ }
+
+ typ := f.pkg.types[call.Args[arg]]
+
+ if typ.Type == nil {
+ return // type error prevents further analysis
+ }
+
+ switch typ.Type.Underlying().(type) {
+ case *types.Pointer, *types.Interface:
+ return
+ }
+
+ shortname := name[strings.LastIndexByte(name, '.')+1:]
+ switch arg {
+ case 0:
+ f.Badf(call.Lparen, "call of %s passes non-pointer", shortname)
+ case 1:
+ f.Badf(call.Lparen, "call of %s passes non-pointer as second argument", shortname)
+ }
+}