]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/cgo: always use a function literal for pointer checking
authorIan Lance Taylor <iant@golang.org>
Fri, 14 Oct 2016 21:53:59 +0000 (14:53 -0700)
committerIan Lance Taylor <iant@golang.org>
Wed, 19 Oct 2016 21:20:50 +0000 (21:20 +0000)
The pointer checking code needs to know the exact type of the parameter
expected by the C function, so that it can use a type assertion to
convert the empty interface returned by cgoCheckPointer to the correct
type. Previously this was done by using a type conversion, but that
meant that the code accepted arguments that were convertible to the
parameter type, rather than arguments that were assignable as in a
normal function call. In other words, some code that should not have
passed type checking was accepted.

This CL changes cgo to always use a function literal for pointer
checking. Now the argument is passed to the function literal, which has
the correct argument type, so type checking is performed just as for a
function call as it should be.

Since we now always use a function literal, simplify the checking code
to run as a statement by itself. It now no longer needs to return a
value, and we no longer need a type assertion.

This does have the cost of introducing another function call into any
call to a C function that requires pointer checking, but the cost of the
additional call should be minimal compared to the cost of pointer
checking.

Fixes #16591.

Change-Id: I220165564cf69db9fd5f746532d7f977a5b2c989
Reviewed-on: https://go-review.googlesource.com/31233
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
misc/cgo/errors/issue16591.go [new file with mode: 0644]
misc/cgo/errors/test.bash
misc/cgo/test/callback.go
src/cmd/cgo/gcc.go
src/cmd/cgo/out.go
src/runtime/cgocall.go

diff --git a/misc/cgo/errors/issue16591.go b/misc/cgo/errors/issue16591.go
new file mode 100644 (file)
index 0000000..10eb840
--- /dev/null
@@ -0,0 +1,17 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Issue 16591: Test that we detect an invalid call that was being
+// hidden by a type conversion inserted by cgo checking.
+
+package p
+
+// void f(int** p) { }
+import "C"
+
+type x *C.int
+
+func F(p *x) {
+       C.f(p) // ERROR HERE
+}
index 84d44d8a338443d65683412ec4a454d9bc58c948..cb442507a6470b57be58aa9eaf58ccec3e66dde0 100755 (executable)
@@ -46,6 +46,7 @@ check issue13423.go
 expect issue13635.go C.uchar C.schar C.ushort C.uint C.ulong C.longlong C.ulonglong C.complexfloat C.complexdouble
 check issue13830.go
 check issue16116.go
+check issue16591.go
 
 if ! go build issue14669.go; then
        exit 1
index 21d1df59eddcbf84edd62f670060a9e4b5388168..b88bf134bc1e40a73109259155da0904f47621a8 100644 (file)
@@ -186,6 +186,7 @@ func testCallbackCallers(t *testing.T) {
                "runtime.asmcgocall",
                "runtime.cgocall",
                "test._Cfunc_callback",
+               "test.nestedCall.func1",
                "test.nestedCall",
                "test.testCallbackCallers",
                "test.TestCallbackCallers",
index 5df94ac54e7e2aa0ce98f06044f5bb4662a0be88..714d6360ccb5b49fb22eb1722d7af10c10e6db70 100644 (file)
@@ -639,34 +639,57 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
 
        // We need to rewrite this call.
        //
-       // We are going to rewrite C.f(p) to C.f(_cgoCheckPointer(p)).
-       // If the call to C.f is deferred, that will check p at the
-       // point of the defer statement, not when the function is called, so
-       // rewrite to func(_cgo0 ptype) { C.f(_cgoCheckPointer(_cgo0)) }(p)
-
+       // We are going to rewrite C.f(p) to
+       //    func (_cgo0 ptype) {
+       //            _cgoCheckPointer(_cgo0)
+       //            C.f(_cgo0)
+       //    }(p)
+       // Using a function literal like this lets us do correct
+       // argument type checking, and works correctly if the call is
+       // deferred.
        needsUnsafe := false
-       var dargs []ast.Expr
-       if call.Deferred {
-               dargs = make([]ast.Expr, len(name.FuncType.Params))
-       }
+       params := make([]*ast.Field, len(name.FuncType.Params))
+       args := make([]ast.Expr, len(name.FuncType.Params))
+       var stmts []ast.Stmt
        for i, param := range name.FuncType.Params {
+               // params is going to become the parameters of the
+               // function literal.
+               // args is going to become the list of arguments to the
+               // function literal.
+               // nparam is the parameter of the function literal that
+               // corresponds to param.
+
                origArg := call.Call.Args[i]
-               darg := origArg
+               args[i] = origArg
+               nparam := ast.NewIdent(fmt.Sprintf("_cgo%d", i))
+
+               // The Go version of the C type might use unsafe.Pointer,
+               // but the file might not import unsafe.
+               // Rewrite the Go type if necessary to use _cgo_unsafe.
+               ptype := p.rewriteUnsafe(param.Go)
+               if ptype != param.Go {
+                       needsUnsafe = true
+               }
 
-               if call.Deferred {
-                       dargs[i] = darg
-                       darg = ast.NewIdent(fmt.Sprintf("_cgo%d", i))
-                       call.Call.Args[i] = darg
+               params[i] = &ast.Field{
+                       Names: []*ast.Ident{nparam},
+                       Type:  ptype,
                }
 
+               call.Call.Args[i] = nparam
+
                if !p.needsPointerCheck(f, param.Go, origArg) {
                        continue
                }
 
+               // Run the cgo pointer checks on nparam.
+
+               // Change the function literal to call the real function
+               // with the parameter passed through _cgoCheckPointer.
                c := &ast.CallExpr{
                        Fun: ast.NewIdent("_cgoCheckPointer"),
                        Args: []ast.Expr{
-                               darg,
+                               nparam,
                        },
                }
 
@@ -674,77 +697,64 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
                // expression.
                c.Args = p.checkAddrArgs(f, c.Args, origArg)
 
-               // The Go version of the C type might use unsafe.Pointer,
-               // but the file might not import unsafe.
-               // Rewrite the Go type if necessary to use _cgo_unsafe.
-               ptype := p.rewriteUnsafe(param.Go)
-               if ptype != param.Go {
-                       needsUnsafe = true
-               }
-
-               // In order for the type assertion to succeed, we need
-               // it to match the actual type of the argument. The
-               // only type we have is the type of the function
-               // parameter. We know that the argument type must be
-               // assignable to the function parameter type, or the
-               // code would not compile, but there is nothing
-               // requiring that the types be exactly the same. Add a
-               // type conversion to the argument so that the type
-               // assertion will succeed.
-               c.Args[0] = &ast.CallExpr{
-                       Fun: ptype,
-                       Args: []ast.Expr{
-                               c.Args[0],
-                       },
-               }
-
-               call.Call.Args[i] = &ast.TypeAssertExpr{
-                       X:    c,
-                       Type: ptype,
+               stmt := &ast.ExprStmt{
+                       X: c,
                }
+               stmts = append(stmts, stmt)
        }
 
-       if call.Deferred {
-               params := make([]*ast.Field, len(name.FuncType.Params))
-               for i, param := range name.FuncType.Params {
-                       ptype := p.rewriteUnsafe(param.Go)
-                       if ptype != param.Go {
-                               needsUnsafe = true
-                       }
-                       params[i] = &ast.Field{
-                               Names: []*ast.Ident{
-                                       ast.NewIdent(fmt.Sprintf("_cgo%d", i)),
-                               },
-                               Type: ptype,
-                       }
+       fcall := &ast.CallExpr{
+               Fun:  call.Call.Fun,
+               Args: call.Call.Args,
+       }
+       ftype := &ast.FuncType{
+               Params: &ast.FieldList{
+                       List: params,
+               },
+       }
+       var fbody ast.Stmt
+       if name.FuncType.Result == nil {
+               fbody = &ast.ExprStmt{
+                       X: fcall,
                }
-
-               dbody := &ast.CallExpr{
-                       Fun:  call.Call.Fun,
-                       Args: call.Call.Args,
+       } else {
+               fbody = &ast.ReturnStmt{
+                       Results: []ast.Expr{fcall},
                }
-               call.Call.Fun = &ast.FuncLit{
-                       Type: &ast.FuncType{
-                               Params: &ast.FieldList{
-                                       List: params,
-                               },
-                       },
-                       Body: &ast.BlockStmt{
-                               List: []ast.Stmt{
-                                       &ast.ExprStmt{
-                                               X: dbody,
-                                       },
+               rtype := p.rewriteUnsafe(name.FuncType.Result.Go)
+               if rtype != name.FuncType.Result.Go {
+                       needsUnsafe = true
+               }
+               ftype.Results = &ast.FieldList{
+                       List: []*ast.Field{
+                               &ast.Field{
+                                       Type: rtype,
                                },
                        },
                }
-               call.Call.Args = dargs
-               call.Call.Lparen = token.NoPos
-               call.Call.Rparen = token.NoPos
+       }
+       call.Call.Fun = &ast.FuncLit{
+               Type: ftype,
+               Body: &ast.BlockStmt{
+                       List: append(stmts, fbody),
+               },
+       }
+       call.Call.Args = args
+       call.Call.Lparen = token.NoPos
+       call.Call.Rparen = token.NoPos
 
-               // There is a Ref pointing to the old call.Call.Fun.
-               for _, ref := range f.Ref {
-                       if ref.Expr == &call.Call.Fun {
-                               ref.Expr = &dbody.Fun
+       // There is a Ref pointing to the old call.Call.Fun.
+       for _, ref := range f.Ref {
+               if ref.Expr == &call.Call.Fun {
+                       ref.Expr = &fcall.Fun
+
+                       // If this call expects two results, we have to
+                       // adjust the results of the  function we generated.
+                       if ref.Context == "call2" {
+                               ftype.Results.List = append(ftype.Results.List,
+                                       &ast.Field{
+                                               Type: ast.NewIdent("error"),
+                                       })
                        }
                }
        }
index e0b9cc46a87c532c4ae8a70b0ab6c305d4357ce7..25031c8d48c4ec1186b919eb931282ca37191ab3 100644 (file)
@@ -1379,14 +1379,14 @@ func _cgo_runtime_cgocall(unsafe.Pointer, uintptr) int32
 func _cgo_runtime_cgocallback(unsafe.Pointer, unsafe.Pointer, uintptr, uintptr)
 
 //go:linkname _cgoCheckPointer runtime.cgoCheckPointer
-func _cgoCheckPointer(interface{}, ...interface{}) interface{}
+func _cgoCheckPointer(interface{}, ...interface{})
 
 //go:linkname _cgoCheckResult runtime.cgoCheckResult
 func _cgoCheckResult(interface{})
 `
 
 const gccgoGoProlog = `
-func _cgoCheckPointer(interface{}, ...interface{}) interface{}
+func _cgoCheckPointer(interface{}, ...interface{})
 
 func _cgoCheckResult(interface{})
 `
@@ -1566,18 +1566,17 @@ typedef struct __go_empty_interface {
        void *__object;
 } Eface;
 
-extern Eface runtimeCgoCheckPointer(Eface, Slice)
+extern void runtimeCgoCheckPointer(Eface, Slice)
        __asm__("runtime.cgoCheckPointer")
        __attribute__((weak));
 
-extern Eface localCgoCheckPointer(Eface, Slice)
+extern void localCgoCheckPointer(Eface, Slice)
        __asm__("GCCGOSYMBOLPREF._cgoCheckPointer");
 
-Eface localCgoCheckPointer(Eface ptr, Slice args) {
+void localCgoCheckPointer(Eface ptr, Slice args) {
        if(runtimeCgoCheckPointer) {
-               return runtimeCgoCheckPointer(ptr, args);
+               runtimeCgoCheckPointer(ptr, args);
        }
-       return ptr;
 }
 
 extern void runtimeCgoCheckResult(Eface)
index 7d358b3346abe213cfd33b20e8db77afbe7d7431..4542cb7b098ece1d0946d3cd7b8d3159a62a7704 100644 (file)
@@ -370,10 +370,10 @@ var racecgosync uint64 // represents possible synchronization in C code
 // pointers.)
 
 // cgoCheckPointer checks if the argument contains a Go pointer that
-// points to a Go pointer, and panics if it does. It returns the pointer.
-func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} {
+// points to a Go pointer, and panics if it does.
+func cgoCheckPointer(ptr interface{}, args ...interface{}) {
        if debug.cgocheck == 0 {
-               return ptr
+               return
        }
 
        ep := (*eface)(unsafe.Pointer(&ptr))
@@ -386,7 +386,7 @@ func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} {
                        p = *(*unsafe.Pointer)(p)
                }
                if !cgoIsGoPointer(p) {
-                       return ptr
+                       return
                }
                aep := (*eface)(unsafe.Pointer(&args[0]))
                switch aep._type.kind & kindMask {
@@ -397,7 +397,7 @@ func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} {
                        }
                        pt := (*ptrtype)(unsafe.Pointer(t))
                        cgoCheckArg(pt.elem, p, true, false, cgoCheckPointerFail)
-                       return ptr
+                       return
                case kindSlice:
                        // Check the slice rather than the pointer.
                        ep = aep
@@ -415,7 +415,6 @@ func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} {
        }
 
        cgoCheckArg(t, ep.data, t.kind&kindDirectIface == 0, top, cgoCheckPointerFail)
-       return ptr
 }
 
 const cgoCheckPointerFail = "cgo argument has Go pointer to Go pointer"