]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/cgo: check pointers for deferred C calls at the right time
authorIan Lance Taylor <iant@golang.org>
Wed, 1 Jun 2016 22:24:14 +0000 (15:24 -0700)
committerIan Lance Taylor <iant@golang.org>
Fri, 3 Jun 2016 20:51:39 +0000 (20:51 +0000)
We used to check time at the point of the defer statement. This change
fixes cgo to check them when the deferred function is executed.

Fixes #15921.

Change-Id: I72a10e26373cad6ad092773e9ebec4add29b9561
Reviewed-on: https://go-review.googlesource.com/23650
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Austin Clements <austin@google.com>
misc/cgo/errors/ptr.go
src/cmd/cgo/ast.go
src/cmd/cgo/gcc.go
src/cmd/cgo/main.go

index 27eb78e36cfbc0bc958ba2d08faa531365637a16..e39f0413e4799eaf8e1e91386cb5d3998deb54c6 100644 (file)
@@ -314,6 +314,14 @@ var ptrTests = []ptrTest{
                body:    `i := 0; p := S{u:uintptr(unsafe.Pointer(&i))}; q := (*S)(C.malloc(C.size_t(unsafe.Sizeof(p)))); *q = p; C.f(unsafe.Pointer(q))`,
                fail:    false,
        },
+       {
+               // Check deferred pointers when they are used, not
+               // when the defer statement is run.
+               name: "defer",
+               c:    `typedef struct s { int *p; } s; void f(s *ps) {}`,
+               body: `p := &C.s{}; defer C.f(p); p.p = new(C.int)`,
+               fail: true,
+       },
 }
 
 func main() {
index 823da43c1d00e213e2d77c09b05906cc527a2d0b..000ecd446855c47b2befdd94783c6e6e5b88099b 100644 (file)
@@ -172,7 +172,7 @@ func (f *File) saveExprs(x interface{}, context string) {
                        f.saveRef(x, context)
                }
        case *ast.CallExpr:
-               f.saveCall(x)
+               f.saveCall(x, context)
        }
 }
 
@@ -220,7 +220,7 @@ func (f *File) saveRef(n *ast.Expr, context string) {
 }
 
 // Save calls to C.xxx for later processing.
-func (f *File) saveCall(call *ast.CallExpr) {
+func (f *File) saveCall(call *ast.CallExpr, context string) {
        sel, ok := call.Fun.(*ast.SelectorExpr)
        if !ok {
                return
@@ -228,7 +228,8 @@ func (f *File) saveCall(call *ast.CallExpr) {
        if l, ok := sel.X.(*ast.Ident); !ok || l.Name != "C" {
                return
        }
-       f.Calls = append(f.Calls, call)
+       c := &Call{Call: call, Deferred: context == "defer"}
+       f.Calls = append(f.Calls, c)
 }
 
 // If a function should be exported add it to ExpFunc.
@@ -401,7 +402,7 @@ func (f *File) walk(x interface{}, context string, visit func(*File, interface{}
        case *ast.GoStmt:
                f.walk(n.Call, "expr", visit)
        case *ast.DeferStmt:
-               f.walk(n.Call, "expr", visit)
+               f.walk(n.Call, "defer", visit)
        case *ast.ReturnStmt:
                f.walk(n.Results, "expr", visit)
        case *ast.BranchStmt:
index 451798244f6db2a48fd0793a74622faa96b85d92..21854c5ea326f7a4c4686a6152eb2e65330734d8 100644 (file)
@@ -581,7 +581,7 @@ func (p *Package) mangleName(n *Name) {
 func (p *Package) rewriteCalls(f *File) {
        for _, call := range f.Calls {
                // This is a call to C.xxx; set goname to "xxx".
-               goname := call.Fun.(*ast.SelectorExpr).Sel.Name
+               goname := call.Call.Fun.(*ast.SelectorExpr).Sel.Name
                if goname == "malloc" {
                        continue
                }
@@ -596,37 +596,58 @@ func (p *Package) rewriteCalls(f *File) {
 
 // rewriteCall rewrites one call to add pointer checks. We replace
 // each pointer argument x with _cgoCheckPointer(x).(T).
-func (p *Package) rewriteCall(f *File, call *ast.CallExpr, name *Name) {
+func (p *Package) rewriteCall(f *File, call *Call, name *Name) {
+       any := false
        for i, param := range name.FuncType.Params {
-               if len(call.Args) <= i {
+               if len(call.Call.Args) <= i {
                        // Avoid a crash; this will be caught when the
                        // generated file is compiled.
                        return
                }
+               if p.needsPointerCheck(f, param.Go, call.Call.Args[i]) {
+                       any = true
+                       break
+               }
+       }
+       if !any {
+               return
+       }
 
-               // An untyped nil does not need a pointer check, and
-               // when _cgoCheckPointer returns the untyped nil the
-               // type assertion we are going to insert will fail.
-               // Easier to just skip nil arguments.
-               // TODO: Note that this fails if nil is shadowed.
-               if id, ok := call.Args[i].(*ast.Ident); ok && id.Name == "nil" {
-                       continue
+       // We need to rewrite this call.
+       //
+       // We are going to rewrite C.f(p) to C.f(_cgoCheckPointer(p)).
+       // If the call to C.f is deferred, that will check p at the
+       // point of the defer statement, not when the function is called, so
+       // rewrite to func(_cgo0 ptype) { C.f(_cgoCheckPointer(_cgo0)) }(p)
+
+       var dargs []ast.Expr
+       if call.Deferred {
+               dargs = make([]ast.Expr, len(name.FuncType.Params))
+       }
+       for i, param := range name.FuncType.Params {
+               origArg := call.Call.Args[i]
+               darg := origArg
+
+               if call.Deferred {
+                       dargs[i] = darg
+                       darg = ast.NewIdent(fmt.Sprintf("_cgo%d", i))
+                       call.Call.Args[i] = darg
                }
 
-               if !p.needsPointerCheck(f, param.Go) {
+               if !p.needsPointerCheck(f, param.Go, origArg) {
                        continue
                }
 
                c := &ast.CallExpr{
                        Fun: ast.NewIdent("_cgoCheckPointer"),
                        Args: []ast.Expr{
-                               call.Args[i],
+                               darg,
                        },
                }
 
                // Add optional additional arguments for an address
                // expression.
-               c.Args = p.checkAddrArgs(f, c.Args, call.Args[i])
+               c.Args = p.checkAddrArgs(f, c.Args, origArg)
 
                // _cgoCheckPointer returns interface{}.
                // We need to type assert that to the type we want.
@@ -664,14 +685,73 @@ func (p *Package) rewriteCall(f *File, call *ast.CallExpr, name *Name) {
                        }
                }
 
-               call.Args[i] = arg
+               call.Call.Args[i] = arg
+       }
+
+       if call.Deferred {
+               params := make([]*ast.Field, len(name.FuncType.Params))
+               for i, param := range name.FuncType.Params {
+                       ptype := param.Go
+                       if p.hasUnsafePointer(ptype) {
+                               // Avoid generating unsafe.Pointer by using
+                               // interface{}. This works because we are
+                               // going to call a _cgoCheckPointer function
+                               // anyhow.
+                               ptype = &ast.InterfaceType{
+                                       Methods: &ast.FieldList{},
+                               }
+                       }
+                       params[i] = &ast.Field{
+                               Names: []*ast.Ident{
+                                       ast.NewIdent(fmt.Sprintf("_cgo%d", i)),
+                               },
+                               Type: ptype,
+                       }
+               }
+
+               dbody := &ast.CallExpr{
+                       Fun:  call.Call.Fun,
+                       Args: call.Call.Args,
+               }
+               call.Call.Fun = &ast.FuncLit{
+                       Type: &ast.FuncType{
+                               Params: &ast.FieldList{
+                                       List: params,
+                               },
+                       },
+                       Body: &ast.BlockStmt{
+                               List: []ast.Stmt{
+                                       &ast.ExprStmt{
+                                               X: dbody,
+                                       },
+                               },
+                       },
+               }
+               call.Call.Args = dargs
+               call.Call.Lparen = token.NoPos
+               call.Call.Rparen = token.NoPos
+
+               // There is a Ref pointing to the old call.Call.Fun.
+               for _, ref := range f.Ref {
+                       if ref.Expr == &call.Call.Fun {
+                               ref.Expr = &dbody.Fun
+                       }
+               }
        }
 }
 
 // needsPointerCheck returns whether the type t needs a pointer check.
 // This is true if t is a pointer and if the value to which it points
 // might contain a pointer.
-func (p *Package) needsPointerCheck(f *File, t ast.Expr) bool {
+func (p *Package) needsPointerCheck(f *File, t ast.Expr, arg ast.Expr) bool {
+       // An untyped nil does not need a pointer check, and when
+       // _cgoCheckPointer returns the untyped nil the type assertion we
+       // are going to insert will fail.  Easier to just skip nil arguments.
+       // TODO: Note that this fails if nil is shadowed.
+       if id, ok := arg.(*ast.Ident); ok && id.Name == "nil" {
+               return false
+       }
+
        return p.hasPointer(f, t, true)
 }
 
index cbdeb0f9ca6169be71777161e26ad9a93dca5f95..e2a387a09d33f150a1613ceb1c2029ad89dfcd8a 100644 (file)
@@ -52,7 +52,7 @@ type File struct {
        Package  string              // Package name
        Preamble string              // C preamble (doc comment on import "C")
        Ref      []*Ref              // all references to C.xxx in AST
-       Calls    []*ast.CallExpr     // all calls to C.xxx in AST
+       Calls    []*Call             // all calls to C.xxx in AST
        ExpFunc  []*ExpFunc          // exported functions for this file
        Name     map[string]*Name    // map from Go name to Name
 }
@@ -66,6 +66,12 @@ func nameKeys(m map[string]*Name) []string {
        return ks
 }
 
+// A Call refers to a call of a C.xxx function in the AST.
+type Call struct {
+       Call     *ast.CallExpr
+       Deferred bool
+}
+
 // A Ref refers to an expression of the form C.xxx in the AST.
 type Ref struct {
        Name    *Name