]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/cgo, runtime: add checks for passing pointers from Go to C
authorIan Lance Taylor <iant@golang.org>
Fri, 16 Oct 2015 22:26:00 +0000 (15:26 -0700)
committerIan Lance Taylor <iant@golang.org>
Tue, 10 Nov 2015 22:22:10 +0000 (22:22 +0000)
This implements part of the proposal in issue 12416 by adding dynamic
checks for passing pointers from Go to C.  This code is intended to be
on at all times.  It does not try to catch every case.  It does not
implement checks on calling Go functions from C.

The new cgo checks may be disabled using GODEBUG=cgocheck=0.

Update #12416.

Change-Id: I48de130e7e2e83fb99a1e176b2c856be38a4d3c8
Reviewed-on: https://go-review.googlesource.com/16003
Reviewed-by: Russ Cox <rsc@golang.org>
misc/cgo/errors/ptr.go [new file with mode: 0644]
misc/cgo/errors/test.bash
misc/cgo/test/callback.go
src/cmd/cgo/ast.go
src/cmd/cgo/doc.go
src/cmd/cgo/gcc.go
src/cmd/cgo/main.go
src/cmd/cgo/out.go
src/runtime/cgocall.go
src/runtime/runtime1.go
src/runtime/type.go

diff --git a/misc/cgo/errors/ptr.go b/misc/cgo/errors/ptr.go
new file mode 100644 (file)
index 0000000..b417d48
--- /dev/null
@@ -0,0 +1,267 @@
+// Copyright 2015 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests that cgo detects invalid pointer passing at runtime.
+
+package main
+
+import (
+       "bufio"
+       "bytes"
+       "fmt"
+       "io"
+       "io/ioutil"
+       "os"
+       "os/exec"
+       "path/filepath"
+       "runtime"
+       "strings"
+       "sync"
+)
+
+// ptrTest is the tests without the boilerplate.
+type ptrTest struct {
+       c       string   // the cgo comment
+       imports []string // a list of imports
+       support string   // supporting functions
+       body    string   // the body of the main function
+       fail    bool     // whether the test should fail
+}
+
+var ptrTests = []ptrTest{
+       {
+               // Passing a pointer to a struct that contains a Go pointer.
+               c:    `typedef struct s { int *p; } s; void f(s *ps) {}`,
+               body: `C.f(&C.s{new(C.int)})`,
+               fail: true,
+       },
+       {
+               // Passing a pointer to a struct that contains a Go pointer.
+               c:    `typedef struct s { int *p; } s; void f(s *ps) {}`,
+               body: `p := &C.s{new(C.int)}; C.f(p)`,
+               fail: true,
+       },
+       {
+               // Passing a pointer to an int field of a Go struct
+               // that (irrelevantly) contains a Go pointer.
+               c:    `struct s { int i; int *p; }; void f(int *p) {}`,
+               body: `p := &C.struct_s{i: 0, p: new(C.int)}; C.f(&p.i)`,
+               fail: false,
+       },
+       {
+               // Passing a pointer to a pointer field of a Go struct.
+               c:    `struct s { int i; int *p; }; void f(int **p) {}`,
+               body: `p := &C.struct_s{i: 0, p: new(C.int)}; C.f(&p.p)`,
+               fail: true,
+       },
+       {
+               // Passing a pointer to a pointer field of a Go
+               // struct, where the field does not contain a Go
+               // pointer, but another field (irrelevantly) does.
+               c:    `struct s { int *p1; int *p2; }; void f(int **p) {}`,
+               body: `p := &C.struct_s{p1: nil, p2: new(C.int)}; C.f(&p.p1)`,
+               fail: false,
+       },
+       {
+               // Passing the address of a slice with no Go pointers.
+               c:       `void f(void **p) {}`,
+               imports: []string{"unsafe"},
+               body:    `s := []unsafe.Pointer{nil}; C.f(&s[0])`,
+               fail:    false,
+       },
+       {
+               // Passing the address of a slice with a Go pointer.
+               c:       `void f(void **p) {}`,
+               imports: []string{"unsafe"},
+               body:    `i := 0; s := []unsafe.Pointer{unsafe.Pointer(&i)}; C.f(&s[0])`,
+               fail:    true,
+       },
+       {
+               // Passing the address of a slice with a Go pointer,
+               // where we are passing the address of an element that
+               // is not a Go pointer.
+               c:       `void f(void **p) {}`,
+               imports: []string{"unsafe"},
+               body:    `i := 0; s := []unsafe.Pointer{nil, unsafe.Pointer(&i)}; C.f(&s[0])`,
+               fail:    true,
+       },
+       {
+               // Passing the address of a slice that is an element
+               // in a struct only looks at the slice.
+               c:       `void f(void **p) {}`,
+               imports: []string{"unsafe"},
+               support: `type S struct { p *int; s []unsafe.Pointer }`,
+               body:    `i := 0; p := &S{p:&i, s:[]unsafe.Pointer{nil}}; C.f(&p.s[0])`,
+               fail:    false,
+       },
+       {
+               // Passing the address of a static variable with no
+               // pointers doesn't matter.
+               c:       `void f(char** parg) {}`,
+               support: `var hello = [...]C.char{'h', 'e', 'l', 'l', 'o'}`,
+               body:    `parg := [1]*C.char{&hello[0]}; C.f(&parg[0])`,
+               fail:    false,
+       },
+       {
+               // Passing the address of a static variable with
+               // pointers does matter.
+               c:       `void f(char*** parg) {}`,
+               support: `var hello = [...]*C.char{new(C.char)}`,
+               body:    `parg := [1]**C.char{&hello[0]}; C.f(&parg[0])`,
+               fail:    true,
+       },
+}
+
+func main() {
+       os.Exit(doTests())
+}
+
+func doTests() int {
+       dir, err := ioutil.TempDir("", "cgoerrors")
+       if err != nil {
+               fmt.Fprintln(os.Stderr, err)
+               return 2
+       }
+       defer os.RemoveAll(dir)
+
+       workers := runtime.NumCPU() + 1
+
+       var wg sync.WaitGroup
+       c := make(chan int)
+       errs := make(chan int)
+       for i := 0; i < workers; i++ {
+               wg.Add(1)
+               go func() {
+                       worker(dir, c, errs)
+                       wg.Done()
+               }()
+       }
+
+       for i := range ptrTests {
+               c <- i
+       }
+       close(c)
+
+       go func() {
+               wg.Wait()
+               close(errs)
+       }()
+
+       tot := 0
+       for e := range errs {
+               tot += e
+       }
+       return tot
+}
+
+func worker(dir string, c, errs chan int) {
+       e := 0
+       for i := range c {
+               if !doOne(dir, i) {
+                       e++
+               }
+       }
+       if e > 0 {
+               errs <- e
+       }
+}
+
+func doOne(dir string, i int) bool {
+       t := &ptrTests[i]
+
+       name := filepath.Join(dir, fmt.Sprintf("t%d.go", i))
+       f, err := os.Create(name)
+       if err != nil {
+               fmt.Fprintln(os.Stderr, err)
+               return false
+       }
+
+       b := bufio.NewWriter(f)
+       fmt.Fprintln(b, `package main`)
+       fmt.Fprintln(b)
+       fmt.Fprintln(b, `/*`)
+       fmt.Fprintln(b, t.c)
+       fmt.Fprintln(b, `*/`)
+       fmt.Fprintln(b, `import "C"`)
+       fmt.Fprintln(b)
+       for _, imp := range t.imports {
+               fmt.Fprintln(b, `import "`+imp+`"`)
+       }
+       if len(t.imports) > 0 {
+               fmt.Fprintln(b)
+       }
+       if len(t.support) > 0 {
+               fmt.Fprintln(b, t.support)
+               fmt.Fprintln(b)
+       }
+       fmt.Fprintln(b, `func main() {`)
+       fmt.Fprintln(b, t.body)
+       fmt.Fprintln(b, `}`)
+
+       if err := b.Flush(); err != nil {
+               fmt.Fprintf(os.Stderr, "flushing %s: %v\n", name, err)
+               return false
+       }
+       if err := f.Close(); err != nil {
+               fmt.Fprintln(os.Stderr, "closing %s: %v\n", name, err)
+               return false
+       }
+
+       cmd := exec.Command("go", "run", name)
+       cmd.Dir = dir
+       buf, err := cmd.CombinedOutput()
+
+       ok := true
+       if t.fail {
+               if err == nil {
+                       var errbuf bytes.Buffer
+                       fmt.Fprintf(&errbuf, "test %d did not fail as expected\n", i)
+                       reportTestOutput(&errbuf, i, buf)
+                       os.Stderr.Write(errbuf.Bytes())
+                       ok = false
+               } else if !bytes.Contains(buf, []byte("Go pointer")) {
+                       var errbuf bytes.Buffer
+                       fmt.Fprintf(&errbuf, "test %d output does not contain expected error\n", i)
+                       reportTestOutput(&errbuf, i, buf)
+                       os.Stderr.Write(errbuf.Bytes())
+                       ok = false
+               }
+       } else {
+               if err != nil {
+                       var errbuf bytes.Buffer
+                       fmt.Fprintf(&errbuf, "test %d failed unexpectedly: %v\n", i, err)
+                       reportTestOutput(&errbuf, i, buf)
+                       os.Stderr.Write(errbuf.Bytes())
+                       ok = false
+               }
+       }
+
+       if t.fail && ok {
+               cmd = exec.Command("go", "run", name)
+               cmd.Dir = dir
+               env := []string{"GODEBUG=cgocheck=0"}
+               for _, e := range os.Environ() {
+                       if !strings.HasPrefix(e, "GODEBUG=") {
+                               env = append(env, e)
+                       }
+               }
+               cmd.Env = env
+               buf, err := cmd.CombinedOutput()
+               if err != nil {
+                       var errbuf bytes.Buffer
+                       fmt.Fprintf(&errbuf, "test %d failed unexpectedly with GODEBUG=cgocheck=0: %v\n", i, err)
+                       reportTestOutput(&errbuf, i, buf)
+                       os.Stderr.Write(errbuf.Bytes())
+                       ok = false
+               }
+       }
+
+       return ok
+}
+
+func reportTestOutput(w io.Writer, i int, buf []byte) {
+       fmt.Fprintf(w, "=== test %d output ===\n", i)
+       fmt.Fprintf(w, "%s", buf)
+       fmt.Fprintf(w, "=== end of test %d output ===\n", i)
+}
index 25ab2499408e91f5d327969e1a7d78fd20279665..a061419992fa32efd44b55fc15563ab40f9a15c0 100755 (executable)
@@ -34,5 +34,9 @@ check issue8442.go
 check issue11097a.go
 check issue11097b.go
 
+if ! go run ptr.go; then
+       exit 1
+fi
+
 rm -rf errs _obj
 exit 0
index 3967e711d1efd4c2bdd9742a7f655b231731b9ac..7ead6b38c1c6318487ebc684c5aec7524f13eaf7 100644 (file)
@@ -19,20 +19,47 @@ import (
        "path"
        "runtime"
        "strings"
+       "sync"
        "testing"
        "unsafe"
 )
 
+// Pass a func value from nestedCall to goCallback using an integer token.
+var callbackMutex sync.Mutex
+var callbackToken int
+var callbackFuncs = make(map[int]func())
+
 // nestedCall calls into C, back into Go, and finally to f.
 func nestedCall(f func()) {
-       // NOTE: Depends on representation of f.
        // callback(x) calls goCallback(x)
-       C.callback(*(*unsafe.Pointer)(unsafe.Pointer(&f)))
+       callbackMutex.Lock()
+       callbackToken++
+       i := callbackToken
+       callbackFuncs[i] = f
+       callbackMutex.Unlock()
+
+       // Pass the address of i because the C function was written to
+       // take a pointer.  We could pass an int if we felt like
+       // rewriting the C code.
+       C.callback(unsafe.Pointer(&i))
+
+       callbackMutex.Lock()
+       delete(callbackFuncs, i)
+       callbackMutex.Unlock()
 }
 
 //export goCallback
 func goCallback(p unsafe.Pointer) {
-       (*(*func())(unsafe.Pointer(&p)))()
+       i := *(*int)(p)
+
+       callbackMutex.Lock()
+       f := callbackFuncs[i]
+       callbackMutex.Unlock()
+
+       if f == nil {
+               panic("missing callback function")
+       }
+       f()
 }
 
 func testCallback(t *testing.T) {
index 8bbd1cc52e66e06edf8ec04665b9da2757d9ac03..c3a24c2b76315ca6ad1439694d415e67888be997 100644 (file)
@@ -124,7 +124,7 @@ func (f *File) ReadGo(name string) {
        if f.Ref == nil {
                f.Ref = make([]*Ref, 0, 8)
        }
-       f.walk(ast2, "prog", (*File).saveRef)
+       f.walk(ast2, "prog", (*File).saveExprs)
 
        // Accumulate exported functions.
        // The comments are only on ast1 but we need to
@@ -163,52 +163,72 @@ func commentText(g *ast.CommentGroup) string {
        return strings.Join(pieces, "")
 }
 
+// Save various references we are going to need later.
+func (f *File) saveExprs(x interface{}, context string) {
+       switch x := x.(type) {
+       case *ast.Expr:
+               switch (*x).(type) {
+               case *ast.SelectorExpr:
+                       f.saveRef(x, context)
+               }
+       case *ast.CallExpr:
+               f.saveCall(x)
+       }
+}
+
 // Save references to C.xxx for later processing.
-func (f *File) saveRef(x interface{}, context string) {
-       n, ok := x.(*ast.Expr)
-       if !ok {
+func (f *File) saveRef(n *ast.Expr, context string) {
+       sel := (*n).(*ast.SelectorExpr)
+       // For now, assume that the only instance of capital C is when
+       // used as the imported package identifier.
+       // The parser should take care of scoping in the future, so
+       // that we will be able to distinguish a "top-level C" from a
+       // local C.
+       if l, ok := sel.X.(*ast.Ident); !ok || l.Name != "C" {
                return
        }
-       if sel, ok := (*n).(*ast.SelectorExpr); ok {
-               // For now, assume that the only instance of capital C is
-               // when used as the imported package identifier.
-               // The parser should take care of scoping in the future,
-               // so that we will be able to distinguish a "top-level C"
-               // from a local C.
-               if l, ok := sel.X.(*ast.Ident); ok && l.Name == "C" {
-                       if context == "as2" {
-                               context = "expr"
-                       }
-                       if context == "embed-type" {
-                               error_(sel.Pos(), "cannot embed C type")
-                       }
-                       goname := sel.Sel.Name
-                       if goname == "errno" {
-                               error_(sel.Pos(), "cannot refer to errno directly; see documentation")
-                               return
-                       }
-                       if goname == "_CMalloc" {
-                               error_(sel.Pos(), "cannot refer to C._CMalloc; use C.malloc")
-                               return
-                       }
-                       if goname == "malloc" {
-                               goname = "_CMalloc"
-                       }
-                       name := f.Name[goname]
-                       if name == nil {
-                               name = &Name{
-                                       Go: goname,
-                               }
-                               f.Name[goname] = name
-                       }
-                       f.Ref = append(f.Ref, &Ref{
-                               Name:    name,
-                               Expr:    n,
-                               Context: context,
-                       })
-                       return
+       if context == "as2" {
+               context = "expr"
+       }
+       if context == "embed-type" {
+               error_(sel.Pos(), "cannot embed C type")
+       }
+       goname := sel.Sel.Name
+       if goname == "errno" {
+               error_(sel.Pos(), "cannot refer to errno directly; see documentation")
+               return
+       }
+       if goname == "_CMalloc" {
+               error_(sel.Pos(), "cannot refer to C._CMalloc; use C.malloc")
+               return
+       }
+       if goname == "malloc" {
+               goname = "_CMalloc"
+       }
+       name := f.Name[goname]
+       if name == nil {
+               name = &Name{
+                       Go: goname,
                }
+               f.Name[goname] = name
+       }
+       f.Ref = append(f.Ref, &Ref{
+               Name:    name,
+               Expr:    n,
+               Context: context,
+       })
+}
+
+// Save calls to C.xxx for later processing.
+func (f *File) saveCall(call *ast.CallExpr) {
+       sel, ok := call.Fun.(*ast.SelectorExpr)
+       if !ok {
+               return
+       }
+       if l, ok := sel.X.(*ast.Ident); !ok || l.Name != "C" {
+               return
        }
+       f.Calls = append(f.Calls, call)
 }
 
 // If a function should be exported add it to ExpFunc.
index a4675bd448342b0ea09ab27e1c4d4484d2e16418..84826784acbe3840bbd0e1614fb3373036400105 100644 (file)
@@ -222,6 +222,51 @@ definitions and declarations, then the two output files will produce
 duplicate symbols and the linker will fail. To avoid this, definitions
 must be placed in preambles in other files, or in C source files.
 
+Passing pointers
+
+Go is a garbage collected language, and the garbage collector needs to
+know the location of every pointer to Go memory.  Because of this,
+there are restrictions on passing pointers between Go and C.
+
+In this section the term Go pointer means a pointer to memory
+allocated by Go (such as by using the & operator or calling the
+predefined new function) and the term C pointer means a pointer to
+memory allocated by C (such as by a call to C.malloc).  Whether a
+pointer is a Go pointer or a C pointer is a dynamic property
+determined by how the memory was allocated; it has nothing to do with
+the type of the pointer.
+
+Go code may pass a Go pointer to C provided the Go memory to which it
+points does not contain any Go pointers.  The C code must preserve
+this property: it must not store any Go pointers into Go memory, even
+temporarily.  When passing a pointer to a field in a struct, the Go
+memory in question is the memory occupied by the field, not the entire
+struct.  When passing a pointer to an element in an array or slice,
+the Go memory in question is the entire array or the entire backing
+array of the slice.
+
+C code may not keep a copy of a Go pointer after the call returns.
+
+If Go code passes a Go pointer to a C function, the C function must
+return.  There is no specific time limit, but a C function that simply
+blocks holding a Go pointer while other goroutines are running may
+eventually cause the program to run out of memory and fail (because
+the garbage collector may not be able to make progress).
+
+A Go function called by C code may not return a Go pointer.  A Go
+function called by C code may take C pointers as arguments, and it may
+store non-pointer or C pointer data through those pointers, but it may
+not store a Go pointer into memory pointed to by a C pointer.  A Go
+function called by C code may take a Go pointer as an argument, but it
+must preserve the property that the Go memory to which it points does
+not contain any Go pointers.
+
+These rules are partially enforced by cgo by default.  It is possible
+to defeat this enforcement by using the unsafe package, and of course
+there is nothing stopping the C code from doing anything it likes.
+However, programs that break these rules are likely to fail in
+unexpected and unpredictable ways.
+
 Using cgo directly
 
 Usage:
index 198c05452f372253a26ee5c3c4c9dc622d4603d9..5173b2d0f69c65ebad144a94345f9f0160db0223 100644 (file)
@@ -167,6 +167,7 @@ func (p *Package) Translate(f *File) {
        if len(needType) > 0 {
                p.loadDWARF(f, needType)
        }
+       p.rewriteCalls(f)
        p.rewriteRef(f)
 }
 
@@ -570,6 +571,278 @@ func (p *Package) mangleName(n *Name) {
        n.Mangle = prefix + n.Kind + "_" + n.Go
 }
 
+// rewriteCalls rewrites all calls that pass pointers to check that
+// they follow the rules for passing pointers between Go and C.
+func (p *Package) rewriteCalls(f *File) {
+       for _, call := range f.Calls {
+               // This is a call to C.xxx; set goname to "xxx".
+               goname := call.Fun.(*ast.SelectorExpr).Sel.Name
+               if goname == "malloc" {
+                       continue
+               }
+               name := f.Name[goname]
+               if name.Kind != "func" {
+                       // Probably a type conversion.
+                       continue
+               }
+               p.rewriteCall(f, call, name)
+       }
+}
+
+// rewriteCall rewrites one call to add pointer checks.  We replace
+// each pointer argument x with _cgoCheckPointer(x).(T).
+func (p *Package) rewriteCall(f *File, call *ast.CallExpr, name *Name) {
+       for i, param := range name.FuncType.Params {
+               // An untyped nil does not need a pointer check, and
+               // when _cgoCheckPointer returns the untyped nil the
+               // type assertion we are going to insert will fail.
+               // Easier to just skip nil arguments.
+               // TODO: Note that this fails if nil is shadowed.
+               if id, ok := call.Args[i].(*ast.Ident); ok && id.Name == "nil" {
+                       continue
+               }
+
+               if !p.needsPointerCheck(f, param.Go) {
+                       continue
+               }
+
+               if len(call.Args) <= i {
+                       // Avoid a crash; this will be caught when the
+                       // generated file is compiled.
+                       return
+               }
+
+               c := &ast.CallExpr{
+                       Fun: ast.NewIdent("_cgoCheckPointer"),
+                       Args: []ast.Expr{
+                               call.Args[i],
+                       },
+               }
+
+               // Add optional additional arguments for an address
+               // expression.
+               if u, ok := call.Args[i].(*ast.UnaryExpr); ok && u.Op == token.AND {
+                       c.Args = p.checkAddrArgs(f, c.Args, u.X)
+               }
+
+               // _cgoCheckPointer returns interface{}.
+               // We need to type assert that to the type we want.
+               // If the Go version of this C type uses
+               // unsafe.Pointer, we can't use a type assertion,
+               // because the Go file might not import unsafe.
+               // Instead we use a local variant of _cgoCheckPointer.
+
+               var arg ast.Expr
+               if n := p.unsafeCheckPointerName(param.Go); n != "" {
+                       c.Fun = ast.NewIdent(n)
+                       arg = c
+               } else {
+                       // In order for the type assertion to succeed,
+                       // we need it to match the actual type of the
+                       // argument.  The only type we have is the
+                       // type of the function parameter.  We know
+                       // that the argument type must be assignable
+                       // to the function parameter type, or the code
+                       // would not compile, but there is nothing
+                       // requiring that the types be exactly the
+                       // same.  Add a type conversion to the
+                       // argument so that the type assertion will
+                       // succeed.
+                       c.Args[0] = &ast.CallExpr{
+                               Fun: param.Go,
+                               Args: []ast.Expr{
+                                       c.Args[0],
+                               },
+                       }
+
+                       arg = &ast.TypeAssertExpr{
+                               X:    c,
+                               Type: param.Go,
+                       }
+               }
+
+               call.Args[i] = arg
+       }
+}
+
+// needsPointerCheck returns whether the type t needs a pointer check.
+// This is true if t is a pointer and if the value to which it points
+// might contain a pointer.
+func (p *Package) needsPointerCheck(f *File, t ast.Expr) bool {
+       return p.hasPointer(f, t, true)
+}
+
+// hasPointer is used by needsPointerCheck.  If top is true it returns
+// whether t is or contains a pointer that might point to a pointer.
+// If top is false it returns whether t is or contains a pointer.
+func (p *Package) hasPointer(f *File, t ast.Expr, top bool) bool {
+       switch t := t.(type) {
+       case *ast.ArrayType:
+               if t.Len == nil {
+                       if !top {
+                               return true
+                       }
+                       return p.hasPointer(f, t.Elt, false)
+               }
+               return p.hasPointer(f, t.Elt, top)
+       case *ast.StructType:
+               for _, field := range t.Fields.List {
+                       if p.hasPointer(f, field.Type, top) {
+                               return true
+                       }
+               }
+               return false
+       case *ast.StarExpr: // Pointer type.
+               if !top {
+                       return true
+               }
+               return p.hasPointer(f, t.X, false)
+       case *ast.FuncType, *ast.InterfaceType, *ast.MapType, *ast.ChanType:
+               return true
+       case *ast.Ident:
+               // TODO: Handle types defined within function.
+               for _, d := range p.Decl {
+                       gd, ok := d.(*ast.GenDecl)
+                       if !ok || gd.Tok != token.TYPE {
+                               continue
+                       }
+                       for _, spec := range gd.Specs {
+                               ts, ok := spec.(*ast.TypeSpec)
+                               if !ok {
+                                       continue
+                               }
+                               if ts.Name.Name == t.Name {
+                                       return p.hasPointer(f, ts.Type, top)
+                               }
+                       }
+               }
+               if def := typedef[t.Name]; def != nil {
+                       return p.hasPointer(f, def.Go, top)
+               }
+               if t.Name == "string" {
+                       return !top
+               }
+               if t.Name == "error" {
+                       return true
+               }
+               if goTypes[t.Name] != nil {
+                       return false
+               }
+               // We can't figure out the type.  Conservative
+               // approach is to assume it has a pointer.
+               return true
+       case *ast.SelectorExpr:
+               if l, ok := t.X.(*ast.Ident); !ok || l.Name != "C" {
+                       // Type defined in a different package.
+                       // Conservative approach is to assume it has a
+                       // pointer.
+                       return true
+               }
+               name := f.Name[t.Sel.Name]
+               if name != nil && name.Kind == "type" && name.Type != nil && name.Type.Go != nil {
+                       return p.hasPointer(f, name.Type.Go, top)
+               }
+               // We can't figure out the type.  Conservative
+               // approach is to assume it has a pointer.
+               return true
+       default:
+               error_(t.Pos(), "could not understand type %s", gofmt(t))
+               return true
+       }
+}
+
+// checkAddrArgs tries to add arguments to the call of
+// _cgoCheckPointer when the argument is an address expression.  We
+// pass true to mean that the argument is an address operation of
+// something other than a slice index, which means that it's only
+// necessary to check the specific element pointed to, not the entire
+// object.  This is for &s.f, where f is a field in a struct.  We can
+// pass a slice or array, meaning that we should check the entire
+// slice or array but need not check any other part of the object.
+// This is for &s.a[i], where we need to check all of a.  However, we
+// only pass the slice or array if we can refer to it without side
+// effects.
+func (p *Package) checkAddrArgs(f *File, args []ast.Expr, x ast.Expr) []ast.Expr {
+       index, ok := x.(*ast.IndexExpr)
+       if !ok {
+               // This is the address of something that is not an
+               // index expression.  We only need to examine the
+               // single value to which it points.
+               // TODO: what is true is shadowed?
+               return append(args, ast.NewIdent("true"))
+       }
+       if !p.hasSideEffects(f, index.X) {
+               // Examine the entire slice.
+               return append(args, index.X)
+       }
+       // Treat the pointer as unknown.
+       return args
+}
+
+// hasSideEffects returns whether the expression x has any side
+// effects.  x is an expression, not a statement, so the only side
+// effect is a function call.
+func (p *Package) hasSideEffects(f *File, x ast.Expr) bool {
+       found := false
+       f.walk(x, "expr",
+               func(f *File, x interface{}, context string) {
+                       switch x.(type) {
+                       case *ast.CallExpr:
+                               found = true
+                       }
+               })
+       return found
+}
+
+// unsafeCheckPointerName is given the Go version of a C type.  If the
+// type uses unsafe.Pointer, we arrange to build a version of
+// _cgoCheckPointer that returns that type.  This avoids using a type
+// assertion to unsafe.Pointer in our copy of user code.  We return
+// the name of the _cgoCheckPointer function we are going to build, or
+// the empty string if the type does not use unsafe.Pointer.
+func (p *Package) unsafeCheckPointerName(t ast.Expr) string {
+       if !p.hasUnsafePointer(t) {
+               return ""
+       }
+       var buf bytes.Buffer
+       conf.Fprint(&buf, fset, t)
+       s := buf.String()
+       for i, t := range p.CgoChecks {
+               if s == t {
+                       return p.unsafeCheckPointerNameIndex(i)
+               }
+       }
+       p.CgoChecks = append(p.CgoChecks, s)
+       return p.unsafeCheckPointerNameIndex(len(p.CgoChecks) - 1)
+}
+
+// hasUnsafePointer returns whether the Go type t uses unsafe.Pointer.
+// t is the Go version of a C type, so we don't need to handle every case.
+// We only care about direct references, not references via typedefs.
+func (p *Package) hasUnsafePointer(t ast.Expr) bool {
+       switch t := t.(type) {
+       case *ast.Ident:
+               return t.Name == "unsafe.Pointer"
+       case *ast.ArrayType:
+               return p.hasUnsafePointer(t.Elt)
+       case *ast.StructType:
+               for _, f := range t.Fields.List {
+                       if p.hasUnsafePointer(f.Type) {
+                               return true
+                       }
+               }
+       case *ast.StarExpr: // Pointer type.
+               return p.hasUnsafePointer(t.X)
+       }
+       return false
+}
+
+// unsafeCheckPointerNameIndex returns the name to use for a
+// _cgoCheckPointer variant based on the index in the CgoChecks slice.
+func (p *Package) unsafeCheckPointerNameIndex(i int) string {
+       return fmt.Sprintf("_cgoCheckPointer%d", i)
+}
+
 // rewriteRef rewrites all the C.xxx references in f.AST to refer to the
 // Go equivalents, now that we have figured out the meaning of all
 // the xxx.  In *godefs mode, rewriteRef replaces the names
index 5e7520db04e55e83d9f392a933e9fd18cdf6ec4d..3f8b7f816a81da2dedc2970218915a0be6f96e7f 100644 (file)
@@ -42,6 +42,7 @@ type Package struct {
        GoFiles     []string // list of Go files
        GccFiles    []string // list of gcc output files
        Preamble    string   // collected preamble for _cgo_export.h
+       CgoChecks   []string // see unsafeCheckPointerName
 }
 
 // A File collects information about a single Go input file.
@@ -51,6 +52,7 @@ type File struct {
        Package  string              // Package name
        Preamble string              // C preamble (doc comment on import "C")
        Ref      []*Ref              // all references to C.xxx in AST
+       Calls    []*ast.CallExpr     // all calls to C.xxx in AST
        ExpFunc  []*ExpFunc          // exported functions for this file
        Name     map[string]*Name    // map from Go name to Name
 }
index 86184e5df42fa5e3086f520bc9c8a8b142973335..a6184f3b62e0f61ae42deb04bb61c6a37b1ff316 100644 (file)
@@ -108,6 +108,13 @@ func (p *Package) writeDefs() {
                fmt.Fprint(fgo2, goProlog)
        }
 
+       for i, t := range p.CgoChecks {
+               n := p.unsafeCheckPointerNameIndex(i)
+               fmt.Fprintf(fgo2, "\nfunc %s(p interface{}, args ...interface{}) %s {\n", n, t)
+               fmt.Fprintf(fgo2, "\treturn _cgoCheckPointer(p, args...).(%s)\n", t)
+               fmt.Fprintf(fgo2, "}\n")
+       }
+
        gccgoSymbolPrefix := p.gccgoSymbolPrefix()
 
        cVars := make(map[string]bool)
@@ -1241,6 +1248,9 @@ func _cgo_runtime_cmalloc(uintptr) unsafe.Pointer
 
 //go:linkname _cgo_runtime_cgocallback runtime.cgocallback
 func _cgo_runtime_cgocallback(unsafe.Pointer, unsafe.Pointer, uintptr)
+
+//go:linkname _cgoCheckPointer runtime.cgoCheckPointer
+func _cgoCheckPointer(interface{}, ...interface{}) interface{}
 `
 
 const goStringDef = `
index d39e6602465aead2657374d7000921d58b5e649e..4ce778fc05daf7deb8de6c22d7ad14126fdcaaea 100644 (file)
@@ -302,3 +302,252 @@ func cgounimpl() {
 }
 
 var racecgosync uint64 // represents possible synchronization in C code
+
+// Pointer checking for cgo code.
+
+// We want to detect all cases where a program that does not use
+// unsafe makes a cgo call passing a Go pointer to memory that
+// contains a Go pointer.  Here a Go pointer is defined as a pointer
+// to memory allocated by the Go runtime.  Programs that use unsafe
+// can evade this restriction easily, so we don't try to catch them.
+// The cgo program will rewrite all possibly bad pointer arguments to
+// call cgoCheckPointer, where we can catch cases of a Go pointer
+// pointing to a Go pointer.
+
+// Complicating matters, taking the address of a slice or array
+// element permits the C program to access all elements of the slice
+// or array.  In that case we will see a pointer to a single element,
+// but we need to check the entire data structure.
+
+// The cgoCheckPointer call takes additional arguments indicating that
+// it was called on an address expression.  An additional argument of
+// true means that it only needs to check a single element.  An
+// additional argument of a slice or array means that it needs to
+// check the entire slice/array, but nothing else.  Otherwise, the
+// pointer could be anything, and we check the entire heap object,
+// which is conservative but safe.
+
+// When and if we implement a moving garbage collector,
+// cgoCheckPointer will pin the pointer for the duration of the cgo
+// call.  (This is necessary but not sufficient; the cgo program will
+// also have to change to pin Go pointers that can not point to Go
+// pointers.)
+
+// cgoCheckPointer checks if the argument contains a Go pointer that
+// points to a Go pointer, and panics if it does.  It returns the pointer.
+func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} {
+       if debug.cgocheck == 0 {
+               return ptr
+       }
+
+       ep := (*eface)(unsafe.Pointer(&ptr))
+       t := ep._type
+
+       top := true
+       if len(args) > 0 && t.kind&kindMask == kindPtr {
+               p := ep.data
+               if t.kind&kindDirectIface == 0 {
+                       p = *(*unsafe.Pointer)(p)
+               }
+               if !cgoIsGoPointer(p) {
+                       return ptr
+               }
+               aep := (*eface)(unsafe.Pointer(&args[0]))
+               switch aep._type.kind & kindMask {
+               case kindBool:
+                       pt := (*ptrtype)(unsafe.Pointer(t))
+                       cgoCheckArg(pt.elem, p, true, false)
+                       return ptr
+               case kindSlice:
+                       // Check the slice rather than the pointer.
+                       ep = aep
+                       t = ep._type
+               case kindArray:
+                       // Check the array rather than the pointer.
+                       // Pass top as false since we have a pointer
+                       // to the array.
+                       ep = aep
+                       t = ep._type
+                       top = false
+               default:
+                       throw("can't happen")
+               }
+       }
+
+       cgoCheckArg(t, ep.data, t.kind&kindDirectIface == 0, top)
+       return ptr
+}
+
+const cgoCheckPointerFail = "cgo argument has Go pointer to Go pointer"
+
+// cgoCheckArg is the real work of cgoCheckPointer.  The argument p,
+// is either a pointer to the value (of type t), or the value itself,
+// depending on indir.  The top parameter is whether we are at the top
+// level, where Go pointers are allowed.
+func cgoCheckArg(t *_type, p unsafe.Pointer, indir, top bool) {
+       if t.kind&kindNoPointers != 0 {
+               // If the type has no pointers there is nothing to do.
+               return
+       }
+
+       switch t.kind & kindMask {
+       default:
+               throw("can't happen")
+       case kindArray:
+               at := (*arraytype)(unsafe.Pointer(t))
+               if !indir {
+                       if at.len != 1 {
+                               throw("can't happen")
+                       }
+                       cgoCheckArg(at.elem, p, at.elem.kind&kindDirectIface == 0, top)
+                       return
+               }
+               for i := uintptr(0); i < at.len; i++ {
+                       cgoCheckArg(at.elem, p, true, top)
+                       p = unsafe.Pointer(uintptr(p) + at.elem.size)
+               }
+       case kindChan, kindMap:
+               // These types contain internal pointers that will
+               // always be allocated in the Go heap.  It's never OK
+               // to pass them to C.
+               panic(errorString(cgoCheckPointerFail))
+       case kindFunc:
+               if indir {
+                       p = *(*unsafe.Pointer)(p)
+               }
+               if !cgoIsGoPointer(p) {
+                       return
+               }
+               panic(errorString(cgoCheckPointerFail))
+       case kindInterface:
+               it := *(**_type)(p)
+               if it == nil {
+                       return
+               }
+               // A type known at compile time is OK since it's
+               // constant.  A type not known at compile time will be
+               // in the heap and will not be OK.
+               if inheap(uintptr(unsafe.Pointer(it))) {
+                       panic(errorString(cgoCheckPointerFail))
+               }
+               p = *(*unsafe.Pointer)(unsafe.Pointer(uintptr(p) + ptrSize))
+               if !cgoIsGoPointer(p) {
+                       return
+               }
+               if !top {
+                       panic(errorString(cgoCheckPointerFail))
+               }
+               cgoCheckArg(it, p, it.kind&kindDirectIface == 0, false)
+       case kindSlice:
+               st := (*slicetype)(unsafe.Pointer(t))
+               s := (*slice)(p)
+               p = s.array
+               if !cgoIsGoPointer(p) {
+                       return
+               }
+               if !top {
+                       panic(errorString(cgoCheckPointerFail))
+               }
+               for i := 0; i < s.cap; i++ {
+                       cgoCheckArg(st.elem, p, true, false)
+                       p = unsafe.Pointer(uintptr(p) + st.elem.size)
+               }
+       case kindStruct:
+               st := (*structtype)(unsafe.Pointer(t))
+               if !indir {
+                       if len(st.fields) != 1 {
+                               throw("can't happen")
+                       }
+                       cgoCheckArg(st.fields[0].typ, p, st.fields[0].typ.kind&kindDirectIface == 0, top)
+                       return
+               }
+               for _, f := range st.fields {
+                       cgoCheckArg(f.typ, unsafe.Pointer(uintptr(p)+f.offset), true, top)
+               }
+       case kindPtr, kindUnsafePointer:
+               if indir {
+                       p = *(*unsafe.Pointer)(p)
+               }
+
+               if !cgoIsGoPointer(p) {
+                       return
+               }
+               if !top {
+                       panic(errorString(cgoCheckPointerFail))
+               }
+
+               cgoCheckUnknownPointer(p)
+       }
+}
+
+// cgoCheckUnknownPointer is called for an arbitrary pointer into Go
+// memory.  It checks whether that Go memory contains any other
+// pointer into Go memory.  If it does, we panic.
+func cgoCheckUnknownPointer(p unsafe.Pointer) {
+       if cgoInRange(p, mheap_.arena_start, mheap_.arena_used) {
+               if !inheap(uintptr(p)) {
+                       // This pointer is either to a stack or to an
+                       // unused span.  Escape analysis should
+                       // prevent the former and the latter should
+                       // not happen.
+                       panic(errorString("cgo argument has invalid Go pointer"))
+               }
+
+               base, hbits, span := heapBitsForObject(uintptr(p), 0, 0)
+               if base == 0 {
+                       return
+               }
+               n := span.elemsize
+               for i := uintptr(0); i < n; i += ptrSize {
+                       bits := hbits.bits()
+                       if i >= 2*ptrSize && bits&bitMarked == 0 {
+                               // No more possible pointers.
+                               break
+                       }
+                       if bits&bitPointer != 0 {
+                               if cgoIsGoPointer(*(*unsafe.Pointer)(unsafe.Pointer(base + i))) {
+                                       panic(errorString(cgoCheckPointerFail))
+                               }
+                       }
+                       hbits = hbits.next()
+               }
+
+               return
+       }
+
+       for datap := &firstmoduledata; datap != nil; datap = datap.next {
+               if cgoInRange(p, datap.data, datap.edata) || cgoInRange(p, datap.bss, datap.ebss) {
+                       // We have no way to know the size of the object.
+                       // We have to assume that it might contain a pointer.
+                       panic(errorString(cgoCheckPointerFail))
+               }
+               // In the text or noptr sections, we know that the
+               // pointer does not point to a Go pointer.
+       }
+}
+
+// cgoIsGoPointer returns whether the pointer is a Go pointer--a
+// pointer to Go memory.  We only care about Go memory that might
+// contain pointers.
+func cgoIsGoPointer(p unsafe.Pointer) bool {
+       if p == nil {
+               return false
+       }
+
+       if cgoInRange(p, mheap_.arena_start, mheap_.arena_used) {
+               return true
+       }
+
+       for datap := &firstmoduledata; datap != nil; datap = datap.next {
+               if cgoInRange(p, datap.data, datap.edata) || cgoInRange(p, datap.bss, datap.ebss) {
+                       return true
+               }
+       }
+
+       return false
+}
+
+// cgoInRange returns whether p is between start and end.
+func cgoInRange(p unsafe.Pointer, start, end uintptr) bool {
+       return start <= uintptr(p) && uintptr(p) < end
+}
index f9b11b4de19af127276edae5e8c74a58d4d252a4..9a468443fd5f851609776f4bc76cefbc86cb0695 100644 (file)
@@ -308,6 +308,7 @@ type dbgVar struct {
 // already have an initial value.
 var debug struct {
        allocfreetrace    int32
+       cgocheck          int32
        efence            int32
        gccheckmark       int32
        gcpacertrace      int32
@@ -326,6 +327,7 @@ var debug struct {
 
 var dbgvars = []dbgVar{
        {"allocfreetrace", &debug.allocfreetrace},
+       {"cgocheck", &debug.cgocheck},
        {"efence", &debug.efence},
        {"gccheckmark", &debug.gccheckmark},
        {"gcpacertrace", &debug.gcpacertrace},
@@ -344,6 +346,7 @@ var dbgvars = []dbgVar{
 
 func parsedebugvars() {
        // defaults
+       debug.cgocheck = 1
        debug.invalidptr = 1
 
        for p := gogetenv("GODEBUG"); p != ""; {
index c8d7554fca8a03810d27384fdbeb2e0d40270d45..d5f3bb1ef02e06b50390153df3a950e839471753 100644 (file)
@@ -70,6 +70,13 @@ type maptype struct {
        needkeyupdate bool   // true if we need to update key on an overwrite
 }
 
+type arraytype struct {
+       typ   _type
+       elem  *_type
+       slice *_type
+       len   uintptr
+}
+
 type chantype struct {
        typ  _type
        elem *_type
@@ -92,3 +99,16 @@ type ptrtype struct {
        typ  _type
        elem *_type
 }
+
+type structfield struct {
+       name    *string
+       pkgpath *string
+       typ     *_type
+       tag     *string
+       offset  uintptr
+}
+
+type structtype struct {
+       typ    _type
+       fields []structfield
+}