if len(needType) > 0 {
p.loadDWARF(f, needType)
}
- p.rewriteCalls(f)
+ if p.rewriteCalls(f) {
+ // Add `import _cgo_unsafe "unsafe"` as the first decl
+ // after the package statement.
+ imp := &ast.GenDecl{
+ Tok: token.IMPORT,
+ Specs: []ast.Spec{
+ &ast.ImportSpec{
+ Name: ast.NewIdent("_cgo_unsafe"),
+ Path: &ast.BasicLit{
+ Kind: token.STRING,
+ Value: `"unsafe"`,
+ },
+ },
+ },
+ }
+ f.AST.Decls = append([]ast.Decl{imp}, f.AST.Decls...)
+ }
p.rewriteRef(f)
}
// rewriteCalls rewrites all calls that pass pointers to check that
// they follow the rules for passing pointers between Go and C.
-func (p *Package) rewriteCalls(f *File) {
+// This returns whether the package needs to import unsafe as _cgo_unsafe.
+func (p *Package) rewriteCalls(f *File) bool {
+ needsUnsafe := false
for _, call := range f.Calls {
// This is a call to C.xxx; set goname to "xxx".
goname := call.Call.Fun.(*ast.SelectorExpr).Sel.Name
// Probably a type conversion.
continue
}
- p.rewriteCall(f, call, name)
+ if p.rewriteCall(f, call, name) {
+ needsUnsafe = true
+ }
}
+ return needsUnsafe
}
// rewriteCall rewrites one call to add pointer checks. We replace
// each pointer argument x with _cgoCheckPointer(x).(T).
-func (p *Package) rewriteCall(f *File, call *Call, name *Name) {
+// This returns whether the package needs to import unsafe as _cgo_unsafe.
+func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
// Avoid a crash if the number of arguments is
// less than the number of parameters.
// This will be caught when the generated file is compiled.
if len(call.Call.Args) < len(name.FuncType.Params) {
- return
+ return false
}
any := false
}
}
if !any {
- return
+ return false
}
// We need to rewrite this call.
// point of the defer statement, not when the function is called, so
// rewrite to func(_cgo0 ptype) { C.f(_cgoCheckPointer(_cgo0)) }(p)
+ needsUnsafe := false
var dargs []ast.Expr
if call.Deferred {
dargs = make([]ast.Expr, len(name.FuncType.Params))
// expression.
c.Args = p.checkAddrArgs(f, c.Args, origArg)
- // _cgoCheckPointer returns interface{}.
- // We need to type assert that to the type we want.
- // If the Go version of this C type uses
- // unsafe.Pointer, we can't use a type assertion,
- // because the Go file might not import unsafe.
- // Instead we use a local variant of _cgoCheckPointer.
-
- var arg ast.Expr
- if n := p.unsafeCheckPointerName(param.Go, call.Deferred); n != "" {
- c.Fun = ast.NewIdent(n)
- arg = c
- } else {
- // In order for the type assertion to succeed,
- // we need it to match the actual type of the
- // argument. The only type we have is the
- // type of the function parameter. We know
- // that the argument type must be assignable
- // to the function parameter type, or the code
- // would not compile, but there is nothing
- // requiring that the types be exactly the
- // same. Add a type conversion to the
- // argument so that the type assertion will
- // succeed.
- c.Args[0] = &ast.CallExpr{
- Fun: param.Go,
- Args: []ast.Expr{
- c.Args[0],
- },
- }
-
- arg = &ast.TypeAssertExpr{
- X: c,
- Type: param.Go,
- }
+ // The Go version of the C type might use unsafe.Pointer,
+ // but the file might not import unsafe.
+ // Rewrite the Go type if necessary to use _cgo_unsafe.
+ ptype := p.rewriteUnsafe(param.Go)
+ if ptype != param.Go {
+ needsUnsafe = true
+ }
+
+ // In order for the type assertion to succeed, we need
+ // it to match the actual type of the argument. The
+ // only type we have is the type of the function
+ // parameter. We know that the argument type must be
+ // assignable to the function parameter type, or the
+ // code would not compile, but there is nothing
+ // requiring that the types be exactly the same. Add a
+ // type conversion to the argument so that the type
+ // assertion will succeed.
+ c.Args[0] = &ast.CallExpr{
+ Fun: ptype,
+ Args: []ast.Expr{
+ c.Args[0],
+ },
}
- call.Call.Args[i] = arg
+ call.Call.Args[i] = &ast.TypeAssertExpr{
+ X: c,
+ Type: ptype,
+ }
}
if call.Deferred {
params := make([]*ast.Field, len(name.FuncType.Params))
for i, param := range name.FuncType.Params {
- ptype := param.Go
- if p.hasUnsafePointer(ptype) {
- // Avoid generating unsafe.Pointer by using
- // interface{}. This works because we are
- // going to call a _cgoCheckPointer function
- // anyhow.
- ptype = &ast.InterfaceType{
- Methods: &ast.FieldList{},
- }
+ ptype := p.rewriteUnsafe(param.Go)
+ if ptype != param.Go {
+ needsUnsafe = true
}
params[i] = &ast.Field{
Names: []*ast.Ident{
}
}
}
+
+ return needsUnsafe
}
// needsPointerCheck returns whether the type t needs a pointer check.
return false
}
-// unsafeCheckPointerName is given the Go version of a C type. If the
-// type uses unsafe.Pointer, we arrange to build a version of
-// _cgoCheckPointer that returns that type. This avoids using a type
-// assertion to unsafe.Pointer in our copy of user code. We return
-// the name of the _cgoCheckPointer function we are going to build, or
-// the empty string if the type does not use unsafe.Pointer.
-//
-// The deferred parameter is true if this check is for the argument of
-// a deferred function. In that case we need to use an empty interface
-// as the argument type, because the deferred function we introduce in
-// rewriteCall will use an empty interface type, and we can't add a
-// type assertion. This is handled by keeping a separate list, and
-// writing out the lists separately in writeDefs.
-func (p *Package) unsafeCheckPointerName(t ast.Expr, deferred bool) string {
- if !p.hasUnsafePointer(t) {
- return ""
- }
- var buf bytes.Buffer
- conf.Fprint(&buf, fset, t)
- s := buf.String()
- checks := &p.CgoChecks
- if deferred {
- checks = &p.DeferredCgoChecks
- }
- for i, t := range *checks {
- if s == t {
- return p.unsafeCheckPointerNameIndex(i, deferred)
- }
- }
- *checks = append(*checks, s)
- return p.unsafeCheckPointerNameIndex(len(*checks)-1, deferred)
-}
-
-// hasUnsafePointer returns whether the Go type t uses unsafe.Pointer.
-// t is the Go version of a C type, so we don't need to handle every case.
-// We only care about direct references, not references via typedefs.
-func (p *Package) hasUnsafePointer(t ast.Expr) bool {
+// rewriteUnsafe returns a version of t with references to unsafe.Pointer
+// rewritten to use _cgo_unsafe.Pointer instead.
+func (p *Package) rewriteUnsafe(t ast.Expr) ast.Expr {
switch t := t.(type) {
case *ast.Ident:
// We don't see a SelectorExpr for unsafe.Pointer;
// this is created by code in this file.
- return t.Name == "unsafe.Pointer"
+ if t.Name == "unsafe.Pointer" {
+ return ast.NewIdent("_cgo_unsafe.Pointer")
+ }
case *ast.ArrayType:
- return p.hasUnsafePointer(t.Elt)
+ t1 := p.rewriteUnsafe(t.Elt)
+ if t1 != t.Elt {
+ r := *t
+ r.Elt = t1
+ return &r
+ }
case *ast.StructType:
+ changed := false
+ fields := *t.Fields
+ fields.List = nil
for _, f := range t.Fields.List {
- if p.hasUnsafePointer(f.Type) {
- return true
+ ft := p.rewriteUnsafe(f.Type)
+ if ft == f.Type {
+ fields.List = append(fields.List, f)
+ } else {
+ fn := *f
+ fn.Type = ft
+ fields.List = append(fields.List, &fn)
+ changed = true
}
}
+ if changed {
+ r := *t
+ r.Fields = &fields
+ return &r
+ }
case *ast.StarExpr: // Pointer type.
- return p.hasUnsafePointer(t.X)
- }
- return false
-}
-
-// unsafeCheckPointerNameIndex returns the name to use for a
-// _cgoCheckPointer variant based on the index in the CgoChecks slice.
-func (p *Package) unsafeCheckPointerNameIndex(i int, deferred bool) string {
- if deferred {
- return fmt.Sprintf("_cgoCheckPointerInDefer%d", i)
+ x1 := p.rewriteUnsafe(t.X)
+ if x1 != t.X {
+ r := *t
+ r.X = x1
+ return &r
+ }
}
- return fmt.Sprintf("_cgoCheckPointer%d", i)
+ return t
}
// rewriteRef rewrites all the C.xxx references in f.AST to refer to the