// argument and then calls the original function.
// This returns whether the package needs to import unsafe as _cgo_unsafe.
func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
+ params := name.FuncType.Params
+ args := call.Call.Args
+
// Avoid a crash if the number of arguments is
// less than the number of parameters.
// This will be caught when the generated file is compiled.
- if len(call.Call.Args) < len(name.FuncType.Params) {
+ if len(args) < len(params) {
return false
}
any := false
- for i, param := range name.FuncType.Params {
- if p.needsPointerCheck(f, param.Go, call.Call.Args[i]) {
+ for i, param := range params {
+ if p.needsPointerCheck(f, param.Go, args[i]) {
any = true
break
}
// Using a function literal like this lets us do correct
// argument type checking, and works correctly if the call is
// deferred.
+ var sb bytes.Buffer
+ sb.WriteString("func(")
+
needsUnsafe := false
- params := make([]*ast.Field, len(name.FuncType.Params))
- nargs := make([]ast.Expr, len(name.FuncType.Params))
- var stmts []ast.Stmt
- for i, param := range name.FuncType.Params {
- // params is going to become the parameters of the
- // function literal.
- // nargs is going to become the list of arguments made
- // by the call within the function literal.
- // nparam is the parameter of the function literal that
- // corresponds to param.
-
- origArg := call.Call.Args[i]
- nparam := ast.NewIdent(fmt.Sprintf("_cgo%d", i))
- nargs[i] = nparam
-
- // The Go version of the C type might use unsafe.Pointer,
- // but the file might not import unsafe.
- // Rewrite the Go type if necessary to use _cgo_unsafe.
+
+ for i, param := range params {
+ if i > 0 {
+ sb.WriteString(", ")
+ }
+
+ fmt.Fprintf(&sb, "_cgo%d ", i)
+
ptype := p.rewriteUnsafe(param.Go)
if ptype != param.Go {
needsUnsafe = true
}
+ sb.WriteString(gofmtLine(ptype))
+ }
- params[i] = &ast.Field{
- Names: []*ast.Ident{nparam},
- Type: ptype,
- }
-
- if !p.needsPointerCheck(f, param.Go, origArg) {
- continue
- }
+ sb.WriteString(")")
- // Run the cgo pointer checks on nparam.
+ result := false
+ twoResults := false
- // Change the function literal to call the real function
- // with the parameter passed through _cgoCheckPointer.
- c := &ast.CallExpr{
- Fun: ast.NewIdent("_cgoCheckPointer"),
- Args: []ast.Expr{
- nparam,
- },
+ // Check whether this call expects two results.
+ for _, ref := range f.Ref {
+ if ref.Expr != &call.Call.Fun {
+ continue
}
-
- // Add optional additional arguments for an address
- // expression.
- c.Args = p.checkAddrArgs(f, c.Args, origArg)
-
- stmt := &ast.ExprStmt{
- X: c,
+ if ref.Context == ctxCall2 {
+ sb.WriteString(" (")
+ result = true
+ twoResults = true
}
- stmts = append(stmts, stmt)
+ break
}
- const cgoMarker = "__cgo__###__marker__"
- fcall := &ast.CallExpr{
- Fun: ast.NewIdent(cgoMarker),
- Args: nargs,
- }
- ftype := &ast.FuncType{
- Params: &ast.FieldList{
- List: params,
- },
- }
+ // Add the result type, if any.
if name.FuncType.Result != nil {
rtype := p.rewriteUnsafe(name.FuncType.Result.Go)
if rtype != name.FuncType.Result.Go {
needsUnsafe = true
}
- ftype.Results = &ast.FieldList{
- List: []*ast.Field{
- &ast.Field{
- Type: rtype,
- },
- },
+ if !twoResults {
+ sb.WriteString(" ")
}
+ sb.WriteString(gofmtLine(rtype))
+ result = true
}
- // If this call expects two results, we have to
- // adjust the results of the function we generated.
- for _, ref := range f.Ref {
- if ref.Expr == &call.Call.Fun && ref.Context == ctxCall2 {
- if ftype.Results == nil {
- // An explicit void argument
- // looks odd but it seems to
- // be how cgo has worked historically.
- ftype.Results = &ast.FieldList{
- List: []*ast.Field{
- &ast.Field{
- Type: ast.NewIdent("_Ctype_void"),
- },
- },
- }
- }
- ftype.Results.List = append(ftype.Results.List,
- &ast.Field{
- Type: ast.NewIdent("error"),
- })
+ // Add the second result type, if any.
+ if twoResults {
+ if name.FuncType.Result == nil {
+ // An explicit void result looks odd but it
+ // seems to be how cgo has worked historically.
+ sb.WriteString("_Ctype_void")
}
+ sb.WriteString(", error)")
}
- var fbody ast.Stmt
- if ftype.Results == nil {
- fbody = &ast.ExprStmt{
- X: fcall,
+ sb.WriteString(" { ")
+
+ for i, param := range params {
+ arg := args[i]
+ if !p.needsPointerCheck(f, param.Go, arg) {
+ continue
}
- } else {
- fbody = &ast.ReturnStmt{
- Results: []ast.Expr{fcall},
+
+ // Check for &a[i].
+ if p.checkIndex(&sb, f, arg, i) {
+ continue
+ }
+
+ // Check for &x.
+ if p.checkAddr(&sb, arg, i) {
+ continue
}
+
+ fmt.Fprintf(&sb, "_cgoCheckPointer(_cgo%d); ", i)
}
- lit := &ast.FuncLit{
- Type: ftype,
- Body: &ast.BlockStmt{
- List: append(stmts, fbody),
- },
+
+ if result {
+ sb.WriteString("return ")
}
- text := strings.Replace(gofmt(lit), "\n", ";", -1)
- repl := strings.Split(text, cgoMarker)
- f.Edit.Insert(f.offset(call.Call.Fun.Pos()), repl[0])
- f.Edit.Insert(f.offset(call.Call.Fun.End()), repl[1])
+
+ // Now we are ready to call the C function.
+ // To work smoothly with rewriteRef we leave the call in place
+ // and just insert our new arguments between the function
+ // and the old arguments.
+ f.Edit.Insert(f.offset(call.Call.Fun.Pos()), sb.String())
+
+ sb.Reset()
+ sb.WriteString("(")
+ for i := range params {
+ if i > 0 {
+ sb.WriteString(", ")
+ }
+ fmt.Fprintf(&sb, "_cgo%d", i)
+ }
+ sb.WriteString("); }")
+
+ f.Edit.Insert(f.offset(call.Call.Lparen), sb.String())
return needsUnsafe
}
}
}
-// checkAddrArgs tries to add arguments to the call of
-// _cgoCheckPointer when the argument is an address expression. We
-// pass true to mean that the argument is an address operation of
-// something other than a slice index, which means that it's only
-// necessary to check the specific element pointed to, not the entire
-// object. This is for &s.f, where f is a field in a struct. We can
-// pass a slice or array, meaning that we should check the entire
-// slice or array but need not check any other part of the object.
-// This is for &s.a[i], where we need to check all of a. However, we
-// only pass the slice or array if we can refer to it without side
-// effects.
-func (p *Package) checkAddrArgs(f *File, args []ast.Expr, x ast.Expr) []ast.Expr {
+// checkIndex checks whether arg the form &a[i], possibly inside type
+// conversions. If so, and if a has no side effects, it writes
+// _cgoCheckPointer(_cgoNN, a) to sb and returns true. This tells
+// _cgoCheckPointer to check the complete contents of the slice.
+func (p *Package) checkIndex(sb *bytes.Buffer, f *File, arg ast.Expr, i int) bool {
// Strip type conversions.
+ x := arg
for {
c, ok := x.(*ast.CallExpr)
if !ok || len(c.Args) != 1 || !p.isType(c.Fun) {
}
u, ok := x.(*ast.UnaryExpr)
if !ok || u.Op != token.AND {
- return args
+ return false
}
index, ok := u.X.(*ast.IndexExpr)
if !ok {
- // This is the address of something that is not an
- // index expression. We only need to examine the
- // single value to which it points.
- // TODO: what if true is shadowed?
- return append(args, ast.NewIdent("true"))
- }
- if !p.hasSideEffects(f, index.X) {
- // Examine the entire slice.
- return append(args, index.X)
- }
- // Treat the pointer as unknown.
- return args
+ return false
+ }
+ if p.hasSideEffects(f, index.X) {
+ return false
+ }
+
+ fmt.Fprintf(sb, "_cgoCheckPointer(_cgo%d, %s); ", i, gofmtLine(index.X))
+
+ return true
+}
+
+// checkAddr checks whether arg has the form &x, possibly inside type
+// conversions. If so it writes _cgoCheckPointer(_cgoNN, true) to sb
+// and returns true. This tells _cgoCheckPointer to check just the
+// contents of the pointer being passed, not any other part of the
+// memory allocation. This is run after checkIndex, which looks for
+// the special case of &a[i], which requires different checks.
+func (p *Package) checkAddr(sb *bytes.Buffer, arg ast.Expr, i int) bool {
+ // Strip type conversions.
+ px := &arg
+ for {
+ c, ok := (*px).(*ast.CallExpr)
+ if !ok || len(c.Args) != 1 || !p.isType(c.Fun) {
+ break
+ }
+ px = &c.Args[0]
+ }
+ if u, ok := (*px).(*ast.UnaryExpr); !ok || u.Op != token.AND {
+ return false
+ }
+
+ // Use "0 == 0" to do the right thing in the unlikely event
+ // that "true" is shadowed.
+ fmt.Fprintf(sb, "_cgoCheckPointer(_cgo%d, 0 == 0); ", i)
+
+ return true
}
// hasSideEffects returns whether the expression x has any side