f.saveRef(x, context)
}
case *ast.CallExpr:
- f.saveCall(x)
+ f.saveCall(x, context)
}
}
}
// Save calls to C.xxx for later processing.
-func (f *File) saveCall(call *ast.CallExpr) {
+func (f *File) saveCall(call *ast.CallExpr, context string) {
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return
if l, ok := sel.X.(*ast.Ident); !ok || l.Name != "C" {
return
}
- f.Calls = append(f.Calls, call)
+ c := &Call{Call: call, Deferred: context == "defer"}
+ f.Calls = append(f.Calls, c)
}
// If a function should be exported add it to ExpFunc.
case *ast.GoStmt:
f.walk(n.Call, "expr", visit)
case *ast.DeferStmt:
- f.walk(n.Call, "expr", visit)
+ f.walk(n.Call, "defer", visit)
case *ast.ReturnStmt:
f.walk(n.Results, "expr", visit)
case *ast.BranchStmt:
func (p *Package) rewriteCalls(f *File) {
for _, call := range f.Calls {
// This is a call to C.xxx; set goname to "xxx".
- goname := call.Fun.(*ast.SelectorExpr).Sel.Name
+ goname := call.Call.Fun.(*ast.SelectorExpr).Sel.Name
if goname == "malloc" {
continue
}
// rewriteCall rewrites one call to add pointer checks. We replace
// each pointer argument x with _cgoCheckPointer(x).(T).
-func (p *Package) rewriteCall(f *File, call *ast.CallExpr, name *Name) {
+func (p *Package) rewriteCall(f *File, call *Call, name *Name) {
+ any := false
for i, param := range name.FuncType.Params {
- if len(call.Args) <= i {
+ if len(call.Call.Args) <= i {
// Avoid a crash; this will be caught when the
// generated file is compiled.
return
}
+ if p.needsPointerCheck(f, param.Go, call.Call.Args[i]) {
+ any = true
+ break
+ }
+ }
+ if !any {
+ return
+ }
- // An untyped nil does not need a pointer check, and
- // when _cgoCheckPointer returns the untyped nil the
- // type assertion we are going to insert will fail.
- // Easier to just skip nil arguments.
- // TODO: Note that this fails if nil is shadowed.
- if id, ok := call.Args[i].(*ast.Ident); ok && id.Name == "nil" {
- continue
+ // We need to rewrite this call.
+ //
+ // We are going to rewrite C.f(p) to C.f(_cgoCheckPointer(p)).
+ // If the call to C.f is deferred, that will check p at the
+ // point of the defer statement, not when the function is called, so
+ // rewrite to func(_cgo0 ptype) { C.f(_cgoCheckPointer(_cgo0)) }(p)
+
+ var dargs []ast.Expr
+ if call.Deferred {
+ dargs = make([]ast.Expr, len(name.FuncType.Params))
+ }
+ for i, param := range name.FuncType.Params {
+ origArg := call.Call.Args[i]
+ darg := origArg
+
+ if call.Deferred {
+ dargs[i] = darg
+ darg = ast.NewIdent(fmt.Sprintf("_cgo%d", i))
+ call.Call.Args[i] = darg
}
- if !p.needsPointerCheck(f, param.Go) {
+ if !p.needsPointerCheck(f, param.Go, origArg) {
continue
}
c := &ast.CallExpr{
Fun: ast.NewIdent("_cgoCheckPointer"),
Args: []ast.Expr{
- call.Args[i],
+ darg,
},
}
// Add optional additional arguments for an address
// expression.
- c.Args = p.checkAddrArgs(f, c.Args, call.Args[i])
+ c.Args = p.checkAddrArgs(f, c.Args, origArg)
// _cgoCheckPointer returns interface{}.
// We need to type assert that to the type we want.
}
}
- call.Args[i] = arg
+ call.Call.Args[i] = arg
+ }
+
+ if call.Deferred {
+ params := make([]*ast.Field, len(name.FuncType.Params))
+ for i, param := range name.FuncType.Params {
+ ptype := param.Go
+ if p.hasUnsafePointer(ptype) {
+ // Avoid generating unsafe.Pointer by using
+ // interface{}. This works because we are
+ // going to call a _cgoCheckPointer function
+ // anyhow.
+ ptype = &ast.InterfaceType{
+ Methods: &ast.FieldList{},
+ }
+ }
+ params[i] = &ast.Field{
+ Names: []*ast.Ident{
+ ast.NewIdent(fmt.Sprintf("_cgo%d", i)),
+ },
+ Type: ptype,
+ }
+ }
+
+ dbody := &ast.CallExpr{
+ Fun: call.Call.Fun,
+ Args: call.Call.Args,
+ }
+ call.Call.Fun = &ast.FuncLit{
+ Type: &ast.FuncType{
+ Params: &ast.FieldList{
+ List: params,
+ },
+ },
+ Body: &ast.BlockStmt{
+ List: []ast.Stmt{
+ &ast.ExprStmt{
+ X: dbody,
+ },
+ },
+ },
+ }
+ call.Call.Args = dargs
+ call.Call.Lparen = token.NoPos
+ call.Call.Rparen = token.NoPos
+
+ // There is a Ref pointing to the old call.Call.Fun.
+ for _, ref := range f.Ref {
+ if ref.Expr == &call.Call.Fun {
+ ref.Expr = &dbody.Fun
+ }
+ }
}
}
// needsPointerCheck returns whether the type t needs a pointer check.
// This is true if t is a pointer and if the value to which it points
// might contain a pointer.
-func (p *Package) needsPointerCheck(f *File, t ast.Expr) bool {
+func (p *Package) needsPointerCheck(f *File, t ast.Expr, arg ast.Expr) bool {
+ // An untyped nil does not need a pointer check, and when
+ // _cgoCheckPointer returns the untyped nil the type assertion we
+ // are going to insert will fail. Easier to just skip nil arguments.
+ // TODO: Note that this fails if nil is shadowed.
+ if id, ok := arg.(*ast.Ident); ok && id.Name == "nil" {
+ return false
+ }
+
return p.hasPointer(f, t, true)
}