Update #12416.
Change-Id: Iccbcb12709d1ca9bea87274f44f93cfcebadb070
Reviewed-on: https://go-review.googlesource.com/17048
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: David Crawshaw <crawshaw@golang.org>
                fail:      true,
                expensive: true,
        },
+       {
+               // Exported functions may not return Go pointers.
+               name: "export1",
+               c:    `extern unsigned char *GoFn();`,
+               support: `//export GoFn
+                          func GoFn() *byte { return new(byte) }`,
+               body: `C.GoFn()`,
+               fail: true,
+       },
+       {
+               // Returning a C pointer is fine.
+               name: "exportok",
+               c: `#include <stdlib.h>
+                    extern unsigned char *GoFn();`,
+               support: `//export GoFn
+                          func GoFn() *byte { return (*byte)(C.malloc(1)) }`,
+               body: `C.GoFn()`,
+       },
 }
 
 func main() {
 
        pc := make([]uintptr, 100)
        n := 0
        name := []string{
-               "test.goCallback",
                "runtime.call16",
                "runtime.cgocallbackg1",
                "runtime.cgocallbackg",
                "runtime.goexit",
        }
        if unsafe.Sizeof((*byte)(nil)) == 8 {
-               name[1] = "runtime.call32"
+               name[0] = "runtime.call32"
        }
        nestedCall(func() {
-               n = runtime.Callers(2, pc)
+               n = runtime.Callers(4, pc)
        })
        if n != len(name) {
                t.Errorf("expected %d frames, got %d", len(name), n)
 
 package main
 
 /*
-extern int *GoFn(void);
+extern int *GoFn(int *);
 
 // Yes, you can have definitions if you use //export, as long as they are weak.
 int f(void) __attribute__ ((weak));
 
 int f() {
-  int *p = GoFn();
+  int i;
+  int *p = GoFn(&i);
   if (*p != 12345)
     return 0;
   return 1;
 import "C"
 
 //export GoFn
-func GoFn() *C.int {
-       i := C.int(12345)
-       return &i
+func GoFn(p *C.int) *C.int {
+       *p = C.int(12345)
+       return p
 }
 
 func main() {
 
 // hasPointer is used by needsPointerCheck.  If top is true it returns
 // whether t is or contains a pointer that might point to a pointer.
 // If top is false it returns whether t is or contains a pointer.
+// f may be nil.
 func (p *Package) hasPointer(f *File, t ast.Expr, top bool) bool {
        switch t := t.(type) {
        case *ast.ArrayType:
                        // pointer.
                        return true
                }
+               if f == nil {
+                       // Conservative approach: assume pointer.
+                       return true
+               }
                name := f.Name[t.Sel.Name]
                if name != nil && name.Kind == "type" && name.Type != nil && name.Type.Go != nil {
                        return p.hasPointer(f, name.Type.Go, top)
                // This is the address of something that is not an
                // index expression.  We only need to examine the
                // single value to which it points.
-               // TODO: what is true is shadowed?
+               // TODO: what if true is shadowed?
                return append(args, ast.NewIdent("true"))
        }
        if !p.hasSideEffects(f, index.X) {
 
                }
                fmt.Fprintf(fgcc, "}\n")
 
-               // Build the wrapper function compiled by gc.
-               goname := exp.Func.Name.Name
+               // Build the wrapper function compiled by cmd/compile.
+               goname := "_cgoexpwrap" + cPrefix + "_"
                if fn.Recv != nil {
-                       goname = "_cgoexpwrap" + cPrefix + "_" + fn.Recv.List[0].Names[0].Name + "_" + goname
+                       goname += fn.Recv.List[0].Names[0].Name + "_"
                }
-               fmt.Fprintf(fgo2, "//go:cgo_export_dynamic %s\n", goname)
+               goname += exp.Func.Name.Name
+               fmt.Fprintf(fgo2, "//go:cgo_export_dynamic %s\n", exp.Func.Name.Name)
                fmt.Fprintf(fgo2, "//go:linkname _cgoexp%s_%s _cgoexp%s_%s\n", cPrefix, exp.ExpName, cPrefix, exp.ExpName)
                fmt.Fprintf(fgo2, "//go:cgo_export_static _cgoexp%s_%s\n", cPrefix, exp.ExpName)
                fmt.Fprintf(fgo2, "//go:nosplit\n") // no split stack, so no use of m or g
 
                fmt.Fprintf(fm, "int _cgoexp%s_%s;\n", cPrefix, exp.ExpName)
 
-               // Calling a function with a receiver from C requires
-               // a Go wrapper function.
+               // This code uses printer.Fprint, not conf.Fprint,
+               // because we don't want //line comments in the middle
+               // of the function types.
+               fmt.Fprintf(fgo2, "\n")
+               fmt.Fprintf(fgo2, "func %s(", goname)
+               comma := false
                if fn.Recv != nil {
-                       fmt.Fprintf(fgo2, "func %s(recv ", goname)
-                       conf.Fprint(fgo2, fset, fn.Recv.List[0].Type)
-                       forFieldList(fntype.Params,
+                       fmt.Fprintf(fgo2, "recv ")
+                       printer.Fprint(fgo2, fset, fn.Recv.List[0].Type)
+                       comma = true
+               }
+               forFieldList(fntype.Params,
+                       func(i int, aname string, atype ast.Expr) {
+                               if comma {
+                                       fmt.Fprintf(fgo2, ", ")
+                               }
+                               fmt.Fprintf(fgo2, "p%d ", i)
+                               printer.Fprint(fgo2, fset, atype)
+                               comma = true
+                       })
+               fmt.Fprintf(fgo2, ")")
+               if gccResult != "void" {
+                       fmt.Fprint(fgo2, " (")
+                       forFieldList(fntype.Results,
                                func(i int, aname string, atype ast.Expr) {
-                                       fmt.Fprintf(fgo2, ", p%d ", i)
-                                       conf.Fprint(fgo2, fset, atype)
+                                       if i > 0 {
+                                               fmt.Fprint(fgo2, ", ")
+                                       }
+                                       fmt.Fprintf(fgo2, "r%d ", i)
+                                       printer.Fprint(fgo2, fset, atype)
                                })
-                       fmt.Fprintf(fgo2, ")")
-                       if gccResult != "void" {
-                               fmt.Fprint(fgo2, " (")
-                               forFieldList(fntype.Results,
-                                       func(i int, aname string, atype ast.Expr) {
-                                               if i > 0 {
-                                                       fmt.Fprint(fgo2, ", ")
-                                               }
-                                               conf.Fprint(fgo2, fset, atype)
-                                       })
-                               fmt.Fprint(fgo2, ")")
-                       }
-                       fmt.Fprint(fgo2, " {\n")
+                       fmt.Fprint(fgo2, ")")
+               }
+               fmt.Fprint(fgo2, " {\n")
+               if gccResult == "void" {
                        fmt.Fprint(fgo2, "\t")
-                       if gccResult != "void" {
-                               fmt.Fprint(fgo2, "return ")
-                       }
-                       fmt.Fprintf(fgo2, "recv.%s(", exp.Func.Name)
-                       forFieldList(fntype.Params,
+               } else {
+                       // Verify that any results don't contain any
+                       // Go pointers.
+                       addedDefer := false
+                       forFieldList(fntype.Results,
                                func(i int, aname string, atype ast.Expr) {
-                                       if i > 0 {
-                                               fmt.Fprint(fgo2, ", ")
+                                       if !p.hasPointer(nil, atype, false) {
+                                               return
+                                       }
+                                       if !addedDefer {
+                                               fmt.Fprint(fgo2, "\tdefer func() {\n")
+                                               addedDefer = true
                                        }
-                                       fmt.Fprintf(fgo2, "p%d", i)
+                                       fmt.Fprintf(fgo2, "\t\t_cgoCheckResult(r%d)\n", i)
                                })
-                       fmt.Fprint(fgo2, ")\n")
-                       fmt.Fprint(fgo2, "}\n")
+                       if addedDefer {
+                               fmt.Fprint(fgo2, "\t}()\n")
+                       }
+                       fmt.Fprint(fgo2, "\treturn ")
+               }
+               if fn.Recv != nil {
+                       fmt.Fprintf(fgo2, "recv.")
                }
+               fmt.Fprintf(fgo2, "%s(", exp.Func.Name)
+               forFieldList(fntype.Params,
+                       func(i int, aname string, atype ast.Expr) {
+                               if i > 0 {
+                                       fmt.Fprint(fgo2, ", ")
+                               }
+                               fmt.Fprintf(fgo2, "p%d", i)
+                       })
+               fmt.Fprint(fgo2, ")\n")
+               fmt.Fprint(fgo2, "}\n")
        }
 
        fmt.Fprintf(fgcch, "%s", gccExportHeaderEpilog)
 
 //go:linkname _cgoCheckPointer runtime.cgoCheckPointer
 func _cgoCheckPointer(interface{}, ...interface{}) interface{}
+
+//go:linkname _cgoCheckResult runtime.cgoCheckResult
+func _cgoCheckResult(interface{})
 `
 
 const goStringDef = `
 
                switch aep._type.kind & kindMask {
                case kindBool:
                        pt := (*ptrtype)(unsafe.Pointer(t))
-                       cgoCheckArg(pt.elem, p, true, false)
+                       cgoCheckArg(pt.elem, p, true, false, cgoCheckPointerFail)
                        return ptr
                case kindSlice:
                        // Check the slice rather than the pointer.
                }
        }
 
-       cgoCheckArg(t, ep.data, t.kind&kindDirectIface == 0, top)
+       cgoCheckArg(t, ep.data, t.kind&kindDirectIface == 0, top, cgoCheckPointerFail)
        return ptr
 }
 
 const cgoCheckPointerFail = "cgo argument has Go pointer to Go pointer"
+const cgoResultFail = "cgo result has Go pointer"
 
 // cgoCheckArg is the real work of cgoCheckPointer.  The argument p
 // is either a pointer to the value (of type t), or the value itself,
 // depending on indir.  The top parameter is whether we are at the top
 // level, where Go pointers are allowed.
-func cgoCheckArg(t *_type, p unsafe.Pointer, indir, top bool) {
+func cgoCheckArg(t *_type, p unsafe.Pointer, indir, top bool, msg string) {
        if t.kind&kindNoPointers != 0 {
                // If the type has no pointers there is nothing to do.
                return
                        if at.len != 1 {
                                throw("can't happen")
                        }
-                       cgoCheckArg(at.elem, p, at.elem.kind&kindDirectIface == 0, top)
+                       cgoCheckArg(at.elem, p, at.elem.kind&kindDirectIface == 0, top, msg)
                        return
                }
                for i := uintptr(0); i < at.len; i++ {
-                       cgoCheckArg(at.elem, p, true, top)
+                       cgoCheckArg(at.elem, p, true, top, msg)
                        p = add(p, at.elem.size)
                }
        case kindChan, kindMap:
                // These types contain internal pointers that will
                // always be allocated in the Go heap.  It's never OK
                // to pass them to C.
-               panic(errorString(cgoCheckPointerFail))
+               panic(errorString(msg))
        case kindFunc:
                if indir {
                        p = *(*unsafe.Pointer)(p)
                if !cgoIsGoPointer(p) {
                        return
                }
-               panic(errorString(cgoCheckPointerFail))
+               panic(errorString(msg))
        case kindInterface:
                it := *(**_type)(p)
                if it == nil {
                // constant.  A type not known at compile time will be
                // in the heap and will not be OK.
                if inheap(uintptr(unsafe.Pointer(it))) {
-                       panic(errorString(cgoCheckPointerFail))
+                       panic(errorString(msg))
                }
                p = *(*unsafe.Pointer)(add(p, sys.PtrSize))
                if !cgoIsGoPointer(p) {
                        return
                }
                if !top {
-                       panic(errorString(cgoCheckPointerFail))
+                       panic(errorString(msg))
                }
-               cgoCheckArg(it, p, it.kind&kindDirectIface == 0, false)
+               cgoCheckArg(it, p, it.kind&kindDirectIface == 0, false, msg)
        case kindSlice:
                st := (*slicetype)(unsafe.Pointer(t))
                s := (*slice)(p)
                        return
                }
                if !top {
-                       panic(errorString(cgoCheckPointerFail))
+                       panic(errorString(msg))
                }
                for i := 0; i < s.cap; i++ {
-                       cgoCheckArg(st.elem, p, true, false)
+                       cgoCheckArg(st.elem, p, true, false, msg)
                        p = add(p, st.elem.size)
                }
        case kindStruct:
                        if len(st.fields) != 1 {
                                throw("can't happen")
                        }
-                       cgoCheckArg(st.fields[0].typ, p, st.fields[0].typ.kind&kindDirectIface == 0, top)
+                       cgoCheckArg(st.fields[0].typ, p, st.fields[0].typ.kind&kindDirectIface == 0, top, msg)
                        return
                }
                for _, f := range st.fields {
-                       cgoCheckArg(f.typ, add(p, f.offset), true, top)
+                       cgoCheckArg(f.typ, add(p, f.offset), true, top, msg)
                }
        case kindPtr, kindUnsafePointer:
                if indir {
                        return
                }
                if !top {
-                       panic(errorString(cgoCheckPointerFail))
+                       panic(errorString(msg))
                }
 
-               cgoCheckUnknownPointer(p)
+               cgoCheckUnknownPointer(p, msg)
        }
 }
 
 // cgoCheckUnknownPointer is called for an arbitrary pointer into Go
 // memory.  It checks whether that Go memory contains any other
 // pointer into Go memory.  If it does, we panic.
-func cgoCheckUnknownPointer(p unsafe.Pointer) {
+func cgoCheckUnknownPointer(p unsafe.Pointer, msg string) {
        if cgoInRange(p, mheap_.arena_start, mheap_.arena_used) {
                if !inheap(uintptr(p)) {
                        // This pointer is either to a stack or to an
                        }
                        if bits&bitPointer != 0 {
                                if cgoIsGoPointer(*(*unsafe.Pointer)(unsafe.Pointer(base + i))) {
-                                       panic(errorString(cgoCheckPointerFail))
+                                       panic(errorString(msg))
                                }
                        }
                        hbits = hbits.next()
                if cgoInRange(p, datap.data, datap.edata) || cgoInRange(p, datap.bss, datap.ebss) {
                        // We have no way to know the size of the object.
                        // We have to assume that it might contain a pointer.
-                       panic(errorString(cgoCheckPointerFail))
+                       panic(errorString(msg))
                }
                // In the text or noptr sections, we know that the
                // pointer does not point to a Go pointer.
 func cgoInRange(p unsafe.Pointer, start, end uintptr) bool {
        return start <= uintptr(p) && uintptr(p) < end
 }
+
+// cgoCheckResult is called to check the result parameter of an
+// exported Go function.  It panics if the result is or contains a Go
+// pointer.
+func cgoCheckResult(val interface{}) {
+       if debug.cgocheck == 0 {
+               return
+       }
+
+       ep := (*eface)(unsafe.Pointer(&val))
+       t := ep._type
+       cgoCheckArg(t, ep.data, t.kind&kindDirectIface == 0, false, cgoResultFail)
+}