]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/cgo: don't update each call in place
authorIan Lance Taylor <iant@golang.org>
Fri, 2 Nov 2018 05:06:51 +0000 (22:06 -0700)
committerIan Lance Taylor <iant@golang.org>
Fri, 2 Nov 2018 05:35:56 +0000 (05:35 +0000)
Updating each call in place broke when there were multiple cgo calls
used as arguments to another cgo call where some required rewriting.
Instead, rewrite calls to strings via the existing mangling mechanism,
and only substitute the top level call in place.

Fixes #28540

Change-Id: Ifd66f04c205adc4ad6dd5ee8e79e57dce17e86bb
Reviewed-on: https://go-review.googlesource.com/c/146860
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
misc/cgo/test/twoargs.go [new file with mode: 0644]
src/cmd/cgo/gcc.go
src/cmd/cgo/main.go

diff --git a/misc/cgo/test/twoargs.go b/misc/cgo/test/twoargs.go
new file mode 100644 (file)
index 0000000..ca0534c
--- /dev/null
@@ -0,0 +1,22 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Crash from call with two arguments that need pointer checking.
+// No runtime test; just make sure it compiles.
+
+package cgotest
+
+/*
+static void twoargs1(void *p, int n) {}
+static void *twoargs2() { return 0; }
+static int twoargs3(void * p) { return 0; }
+*/
+import "C"
+
+import "unsafe"
+
+func twoargsF() {
+       v := []string{}
+       C.twoargs1(C.twoargs2(), C.twoargs3(unsafe.Pointer(&v)))
+}
index 1e746ce577c73cd98512fd0e34eb676713190e0f..e8be785bf62c22c61c3745407f86d422c1a14755 100644 (file)
@@ -722,20 +722,18 @@ func (p *Package) mangleName(n *Name) {
 func (p *Package) rewriteCalls(f *File) bool {
        needsUnsafe := false
        // Walk backward so that in C.f1(C.f2()) we rewrite C.f2 first.
-       for i := len(f.Calls) - 1; i >= 0; i-- {
-               call := f.Calls[i]
-               // This is a call to C.xxx; set goname to "xxx".
-               goname := call.Call.Fun.(*ast.SelectorExpr).Sel.Name
-               if goname == "malloc" {
+       for _, call := range f.Calls {
+               if call.Done {
                        continue
                }
-               name := f.Name[goname]
-               if name.Kind != "func" {
-                       // Probably a type conversion.
-                       continue
-               }
-               if p.rewriteCall(f, call, name) {
-                       needsUnsafe = true
+               start := f.offset(call.Call.Pos())
+               end := f.offset(call.Call.End())
+               str, nu := p.rewriteCall(f, call)
+               if str != "" {
+                       f.Edit.Replace(start, end, str)
+                       if nu {
+                               needsUnsafe = true
+                       }
                }
        }
        return needsUnsafe
@@ -745,8 +743,29 @@ func (p *Package) rewriteCalls(f *File) bool {
 // If any pointer checks are required, we rewrite the call into a
 // function literal that calls _cgoCheckPointer for each pointer
 // argument and then calls the original function.
-// This returns whether the package needs to import unsafe as _cgo_unsafe.
-func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
+// This returns the rewritten call and whether the package needs to
+// import unsafe as _cgo_unsafe.
+// If it returns the empty string, the call did not need to be rewritten.
+func (p *Package) rewriteCall(f *File, call *Call) (string, bool) {
+       // This is a call to C.xxx; set goname to "xxx".
+       // It may have already been mangled by rewriteName.
+       var goname string
+       switch fun := call.Call.Fun.(type) {
+       case *ast.SelectorExpr:
+               goname = fun.Sel.Name
+       case *ast.Ident:
+               goname = strings.TrimPrefix(fun.Name, "_C2func_")
+               goname = strings.TrimPrefix(goname, "_Cfunc_")
+       }
+       if goname == "" || goname == "malloc" {
+               return "", false
+       }
+       name := f.Name[goname]
+       if name == nil || name.Kind != "func" {
+               // Probably a type conversion.
+               return "", false
+       }
+
        params := name.FuncType.Params
        args := call.Call.Args
 
@@ -754,7 +773,7 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
        // less than the number of parameters.
        // This will be caught when the generated file is compiled.
        if len(args) < len(params) {
-               return false
+               return "", false
        }
 
        any := false
@@ -765,7 +784,7 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
                }
        }
        if !any {
-               return false
+               return "", false
        }
 
        // We need to rewrite this call.
@@ -848,7 +867,10 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
        // Write _cgoCheckPointer calls to sbCheck.
        var sbCheck bytes.Buffer
        for i, param := range params {
-               arg := p.mangle(f, &args[i])
+               arg, nu := p.mangle(f, &args[i])
+               if nu {
+                       needsUnsafe = true
+               }
 
                // Explicitly convert untyped constants to the
                // parameter type, to avoid a type mismatch.
@@ -893,12 +915,12 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
                sb.WriteString("return ")
        }
 
-       // Now we are ready to call the C function.
-       // To work smoothly with rewriteRef we leave the call in place
-       // and just replace the old arguments with our new ones.
-       f.Edit.Insert(f.offset(call.Call.Fun.Pos()), sb.String())
+       m, nu := p.mangle(f, &call.Call.Fun)
+       if nu {
+               needsUnsafe = true
+       }
+       sb.WriteString(gofmtLine(m))
 
-       sb.Reset()
        sb.WriteString("(")
        for i := range params {
                if i > 0 {
@@ -916,9 +938,7 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
        }
        sb.WriteString("()")
 
-       f.Edit.Replace(f.offset(call.Call.Lparen), f.offset(call.Call.Rparen)+1, sb.String())
-
-       return needsUnsafe
+       return sb.String(), needsUnsafe
 }
 
 // needsPointerCheck returns whether the type t needs a pointer check.
@@ -1025,32 +1045,54 @@ func (p *Package) hasPointer(f *File, t ast.Expr, top bool) bool {
        }
 }
 
-// mangle replaces references to C names in arg with the mangled names.
-// It removes the corresponding references in f.Ref, so that we don't
-// try to do the replacement again in rewriteRef.
-func (p *Package) mangle(f *File, arg *ast.Expr) ast.Expr {
+// mangle replaces references to C names in arg with the mangled names,
+// rewriting calls when it finds them.
+// It removes the corresponding references in f.Ref and f.Calls, so that we
+// don't try to do the replacement again in rewriteRef or rewriteCall.
+func (p *Package) mangle(f *File, arg *ast.Expr) (ast.Expr, bool) {
+       needsUnsafe := false
        f.walk(arg, ctxExpr, func(f *File, arg interface{}, context astContext) {
                px, ok := arg.(*ast.Expr)
                if !ok {
                        return
                }
                sel, ok := (*px).(*ast.SelectorExpr)
-               if !ok {
+               if ok {
+                       if l, ok := sel.X.(*ast.Ident); !ok || l.Name != "C" {
+                               return
+                       }
+
+                       for _, r := range f.Ref {
+                               if r.Expr == px {
+                                       *px = p.rewriteName(f, r)
+                                       r.Done = true
+                                       break
+                               }
+                       }
+
                        return
                }
-               if l, ok := sel.X.(*ast.Ident); !ok || l.Name != "C" {
+
+               call, ok := (*px).(*ast.CallExpr)
+               if !ok {
                        return
                }
 
-               for _, r := range f.Ref {
-                       if r.Expr == px {
-                               *px = p.rewriteName(f, r)
-                               r.Done = true
-                               break
+               for _, c := range f.Calls {
+                       if !c.Done && c.Call.Lparen == call.Lparen {
+                               cstr, nu := p.rewriteCall(f, c)
+                               if cstr != "" {
+                                       // Smuggle the rewritten call through an ident.
+                                       *px = ast.NewIdent(cstr)
+                                       if nu {
+                                               needsUnsafe = true
+                                       }
+                                       c.Done = true
+                               }
                        }
                }
        })
-       return *arg
+       return *arg, needsUnsafe
 }
 
 // checkIndex checks whether arg the form &a[i], possibly inside type
index 626ffe2390e2d4a24accec403ecde080e6d134aa..3098a4a63d3dcc86ebb0b86d410a02948735fcb3 100644 (file)
@@ -81,6 +81,7 @@ func nameKeys(m map[string]*Name) []string {
 type Call struct {
        Call     *ast.CallExpr
        Deferred bool
+       Done     bool
 }
 
 // A Ref refers to an expression of the form C.xxx in the AST.