]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/cgo: rewrite pointer checking to use more function literals
authorIan Lance Taylor <iant@golang.org>
Wed, 17 Oct 2018 15:04:01 +0000 (08:04 -0700)
committerIan Lance Taylor <iant@golang.org>
Thu, 1 Nov 2018 21:54:54 +0000 (21:54 +0000)
Fixes #14210
Fixes #25941

Change-Id: Idde2d032290da3edb742b5b4f6ffeb625f05b494
Reviewed-on: https://go-review.googlesource.com/c/142884
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
misc/cgo/errors/ptr_test.go
src/cmd/cgo/gcc.go
src/cmd/cgo/main.go

index fe8dfff1d891b76eea60b67d06f4c0b9441e8759..165c2d407c6b20e4764949640bce28c25e75fda9 100644 (file)
@@ -357,6 +357,55 @@ var ptrTests = []ptrTest{
                body:    `r, _, _ := os.Pipe(); r.SetDeadline(time.Now().Add(C.US * time.Microsecond))`,
                fail:    false,
        },
+       {
+               // Test for double evaluation of channel receive.
+               name:    "chan-recv",
+               c:       `void f(char** p) {}`,
+               imports: []string{"time"},
+               body:    `c := make(chan []*C.char, 2); c <- make([]*C.char, 1); go func() { time.Sleep(10 * time.Second); panic("received twice from chan") }(); C.f(&(<-c)[0]);`,
+               fail:    false,
+       },
+       {
+               // Test that converting the address of a struct field
+               // to unsafe.Pointer still just checks that field.
+               // Issue #25941.
+               name:    "struct-field",
+               c:       `void f(void* p) {}`,
+               imports: []string{"unsafe"},
+               support: `type S struct { p *int; a [8]byte; u uintptr }`,
+               body:    `s := &S{p: new(int)}; C.f(unsafe.Pointer(&s.a))`,
+               fail:    false,
+       },
+       {
+               // Test that converting multiple struct field
+               // addresses to unsafe.Pointer still just checks those
+               // fields. Issue #25941.
+               name:    "struct-field-2",
+               c:       `void f(void* p, int r, void* s) {}`,
+               imports: []string{"unsafe"},
+               support: `type S struct { a [8]byte; p *int; b int64; }`,
+               body:    `s := &S{p: new(int)}; C.f(unsafe.Pointer(&s.a), 32, unsafe.Pointer(&s.b))`,
+               fail:    false,
+       },
+       {
+               // Test that second argument to cgoCheckPointer is
+               // evaluated when a deferred function is deferred, not
+               // when it is run.
+               name:    "defer2",
+               c:       `void f(char **pc) {}`,
+               support: `type S1 struct { s []*C.char }; type S2 struct { ps *S1 }`,
+               body:    `p := &S2{&S1{[]*C.char{nil}}}; defer C.f(&p.ps.s[0]); p.ps = nil`,
+               fail:    false,
+       },
+       {
+               // Test that indexing into a function call still
+               // examines only the slice being indexed.
+               name:    "buffer",
+               c:       `void f(void *p) {}`,
+               imports: []string{"bytes", "unsafe"},
+               body:    `var b bytes.Buffer; b.WriteString("a"); C.f(unsafe.Pointer(&b.Bytes()[0]))`,
+               fail:    false,
+       },
 }
 
 func TestPointerChecks(t *testing.T) {
index 45bf90ffc2a27b4c1e09ec4b97744e611a8e8155..1e746ce577c73cd98512fd0e34eb676713190e0f 100644 (file)
@@ -721,7 +721,9 @@ func (p *Package) mangleName(n *Name) {
 // This returns whether the package needs to import unsafe as _cgo_unsafe.
 func (p *Package) rewriteCalls(f *File) bool {
        needsUnsafe := false
-       for _, call := range f.Calls {
+       // Walk backward so that in C.f1(C.f2()) we rewrite C.f2 first.
+       for i := len(f.Calls) - 1; i >= 0; i-- {
+               call := f.Calls[i]
                // This is a call to C.xxx; set goname to "xxx".
                goname := call.Call.Fun.(*ast.SelectorExpr).Sel.Name
                if goname == "malloc" {
@@ -768,103 +770,132 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
 
        // We need to rewrite this call.
        //
-       // We are going to rewrite C.f(p) to
-       //    func (_cgo0 ptype) {
+       // Rewrite C.f(p) to
+       //    func() {
+       //            _cgo0 := p
        //            _cgoCheckPointer(_cgo0)
        //            C.f(_cgo0)
-       //    }(p)
-       // Using a function literal like this lets us do correct
-       // argument type checking, and works correctly if the call is
-       // deferred.
-       var sb bytes.Buffer
-       sb.WriteString("func(")
-
-       needsUnsafe := false
-
-       for i, param := range params {
-               if i > 0 {
-                       sb.WriteString(", ")
-               }
-
-               fmt.Fprintf(&sb, "_cgo%d ", i)
+       //    }()
+       // Using a function literal like this lets us evaluate the
+       // function arguments only once while doing pointer checks.
+       // This is particularly useful when passing additional arguments
+       // to _cgoCheckPointer, as done in checkIndex and checkAddr.
+       //
+       // When the function argument is a conversion to unsafe.Pointer,
+       // we unwrap the conversion before checking the pointer,
+       // and then wrap again when calling C.f. This lets us check
+       // the real type of the pointer in some cases. See issue #25941.
+       //
+       // When the call to C.f is deferred, we use an additional function
+       // literal to evaluate the arguments at the right time.
+       //    defer func() func() {
+       //            _cgo0 := p
+       //            return func() {
+       //                    _cgoCheckPointer(_cgo0)
+       //                    C.f(_cgo0)
+       //            }
+       //    }()()
+       // This works because the defer statement evaluates the first
+       // function literal in order to get the function to call.
 
-               ptype := p.rewriteUnsafe(param.Go)
-               if ptype != param.Go {
-                       needsUnsafe = true
-               }
-               sb.WriteString(gofmtLine(ptype))
+       var sb bytes.Buffer
+       sb.WriteString("func() ")
+       if call.Deferred {
+               sb.WriteString("func() ")
        }
 
-       sb.WriteString(")")
-
+       needsUnsafe := false
        result := false
        twoResults := false
-
-       // Check whether this call expects two results.
-       for _, ref := range f.Ref {
-               if ref.Expr != &call.Call.Fun {
-                       continue
-               }
-               if ref.Context == ctxCall2 {
-                       sb.WriteString(" (")
-                       result = true
-                       twoResults = true
+       if !call.Deferred {
+               // Check whether this call expects two results.
+               for _, ref := range f.Ref {
+                       if ref.Expr != &call.Call.Fun {
+                               continue
+                       }
+                       if ref.Context == ctxCall2 {
+                               sb.WriteString("(")
+                               result = true
+                               twoResults = true
+                       }
+                       break
                }
-               break
-       }
 
-       // Add the result type, if any.
-       if name.FuncType.Result != nil {
-               rtype := p.rewriteUnsafe(name.FuncType.Result.Go)
-               if rtype != name.FuncType.Result.Go {
-                       needsUnsafe = true
-               }
-               if !twoResults {
-                       sb.WriteString(" ")
+               // Add the result type, if any.
+               if name.FuncType.Result != nil {
+                       rtype := p.rewriteUnsafe(name.FuncType.Result.Go)
+                       if rtype != name.FuncType.Result.Go {
+                               needsUnsafe = true
+                       }
+                       sb.WriteString(gofmtLine(rtype))
+                       result = true
                }
-               sb.WriteString(gofmtLine(rtype))
-               result = true
-       }
 
-       // Add the second result type, if any.
-       if twoResults {
-               if name.FuncType.Result == nil {
-                       // An explicit void result looks odd but it
-                       // seems to be how cgo has worked historically.
-                       sb.WriteString("_Ctype_void")
+               // Add the second result type, if any.
+               if twoResults {
+                       if name.FuncType.Result == nil {
+                               // An explicit void result looks odd but it
+                               // seems to be how cgo has worked historically.
+                               sb.WriteString("_Ctype_void")
+                       }
+                       sb.WriteString(", error)")
                }
-               sb.WriteString(", error)")
        }
 
-       sb.WriteString(" { ")
+       sb.WriteString("{ ")
 
+       // Define _cgoN for each argument value.
+       // Write _cgoCheckPointer calls to sbCheck.
+       var sbCheck bytes.Buffer
        for i, param := range params {
-               arg := args[i]
-               if !p.needsPointerCheck(f, param.Go, arg) {
+               arg := p.mangle(f, &args[i])
+
+               // Explicitly convert untyped constants to the
+               // parameter type, to avoid a type mismatch.
+               if p.isConst(f, arg) {
+                       ptype := p.rewriteUnsafe(param.Go)
+                       if ptype != param.Go {
+                               needsUnsafe = true
+                       }
+                       arg = &ast.CallExpr{
+                               Fun:  ptype,
+                               Args: []ast.Expr{arg},
+                       }
+               }
+
+               if !p.needsPointerCheck(f, param.Go, args[i]) {
+                       fmt.Fprintf(&sb, "_cgo%d := %s; ", i, gofmtLine(arg))
                        continue
                }
 
                // Check for &a[i].
-               if p.checkIndex(&sb, f, arg, i) {
+               if p.checkIndex(&sb, &sbCheck, arg, i) {
                        continue
                }
 
                // Check for &x.
-               if p.checkAddr(&sb, arg, i) {
+               if p.checkAddr(&sb, &sbCheck, arg, i) {
                        continue
                }
 
-               fmt.Fprintf(&sb, "_cgoCheckPointer(_cgo%d); ", i)
+               fmt.Fprintf(&sb, "_cgo%d := %s; ", i, gofmtLine(arg))
+               fmt.Fprintf(&sbCheck, "_cgoCheckPointer(_cgo%d); ", i)
+       }
+
+       if call.Deferred {
+               sb.WriteString("return func() { ")
        }
 
+       // Write out the calls to _cgoCheckPointer.
+       sb.WriteString(sbCheck.String())
+
        if result {
                sb.WriteString("return ")
        }
 
        // Now we are ready to call the C function.
        // To work smoothly with rewriteRef we leave the call in place
-       // and just insert our new arguments between the function
-       // and the old arguments.
+       // and just replace the old arguments with our new ones.
        f.Edit.Insert(f.offset(call.Call.Fun.Pos()), sb.String())
 
        sb.Reset()
@@ -875,9 +906,17 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
                }
                fmt.Fprintf(&sb, "_cgo%d", i)
        }
-       sb.WriteString("); }")
+       sb.WriteString("); ")
+       if call.Deferred {
+               sb.WriteString("}")
+       }
+       sb.WriteString("}")
+       if call.Deferred {
+               sb.WriteString("()")
+       }
+       sb.WriteString("()")
 
-       f.Edit.Insert(f.offset(call.Call.Lparen), sb.String())
+       f.Edit.Replace(f.offset(call.Call.Lparen), f.offset(call.Call.Rparen)+1, sb.String())
 
        return needsUnsafe
 }
@@ -986,11 +1025,44 @@ func (p *Package) hasPointer(f *File, t ast.Expr, top bool) bool {
        }
 }
 
+// mangle replaces references to C names in arg with the mangled names.
+// It removes the corresponding references in f.Ref, so that we don't
+// try to do the replacement again in rewriteRef.
+func (p *Package) mangle(f *File, arg *ast.Expr) ast.Expr {
+       f.walk(arg, ctxExpr, func(f *File, arg interface{}, context astContext) {
+               px, ok := arg.(*ast.Expr)
+               if !ok {
+                       return
+               }
+               sel, ok := (*px).(*ast.SelectorExpr)
+               if !ok {
+                       return
+               }
+               if l, ok := sel.X.(*ast.Ident); !ok || l.Name != "C" {
+                       return
+               }
+
+               for _, r := range f.Ref {
+                       if r.Expr == px {
+                               *px = p.rewriteName(f, r)
+                               r.Done = true
+                               break
+                       }
+               }
+       })
+       return *arg
+}
+
 // checkIndex checks whether arg the form &a[i], possibly inside type
-// conversions. If so, and if a has no side effects, it writes
-// _cgoCheckPointer(_cgoNN, a) to sb and returns true. This tells
-// _cgoCheckPointer to check the complete contents of the slice.
-func (p *Package) checkIndex(sb *bytes.Buffer, f *File, arg ast.Expr, i int) bool {
+// conversions. If so, it writes
+//    _cgoIndexNN := a
+//    _cgoNN := &cgoIndexNN[i] // with type conversions, if any
+// to sb, and writes
+//    _cgoCheckPointer(_cgoNN, _cgoIndexNN)
+// to sbCheck, and returns true. This tells _cgoCheckPointer to check
+// the complete contents of the slice or array being indexed, but no
+// other part of the memory allocation.
+func (p *Package) checkIndex(sb, sbCheck *bytes.Buffer, arg ast.Expr, i int) bool {
        // Strip type conversions.
        x := arg
        for {
@@ -1008,22 +1080,29 @@ func (p *Package) checkIndex(sb *bytes.Buffer, f *File, arg ast.Expr, i int) boo
        if !ok {
                return false
        }
-       if p.hasSideEffects(f, index.X) {
-               return false
-       }
 
-       fmt.Fprintf(sb, "_cgoCheckPointer(_cgo%d, %s); ", i, gofmtLine(index.X))
+       fmt.Fprintf(sb, "_cgoIndex%d := %s; ", i, gofmtLine(index.X))
+       origX := index.X
+       index.X = ast.NewIdent(fmt.Sprintf("_cgoIndex%d", i))
+       fmt.Fprintf(sb, "_cgo%d := %s; ", i, gofmtLine(arg))
+       index.X = origX
+
+       fmt.Fprintf(sbCheck, "_cgoCheckPointer(_cgo%d, _cgoIndex%d); ", i, i)
 
        return true
 }
 
 // checkAddr checks whether arg has the form &x, possibly inside type
-// conversions. If so it writes _cgoCheckPointer(_cgoNN, true) to sb
-// and returns true. This tells _cgoCheckPointer to check just the
-// contents of the pointer being passed, not any other part of the
-// memory allocation. This is run after checkIndex, which looks for
-// the special case of &a[i], which requires different checks.
-func (p *Package) checkAddr(sb *bytes.Buffer, arg ast.Expr, i int) bool {
+// conversions. If so it writes
+//    _cgoBaseNN := &x
+//    _cgoNN := _cgoBaseNN // with type conversions, if any
+// to sb, and writes
+//    _cgoCheckPointer(_cgoBaseNN, true)
+// to sbCheck, and returns true. This tells _cgoCheckPointer to check
+// just the contents of the pointer being passed, not any other part
+// of the memory allocation. This is run after checkIndex, which looks
+// for the special case of &a[i], which requires different checks.
+func (p *Package) checkAddr(sb, sbCheck *bytes.Buffer, arg ast.Expr, i int) bool {
        // Strip type conversions.
        px := &arg
        for {
@@ -1037,27 +1116,20 @@ func (p *Package) checkAddr(sb *bytes.Buffer, arg ast.Expr, i int) bool {
                return false
        }
 
+       fmt.Fprintf(sb, "_cgoBase%d := %s; ", i, gofmtLine(*px))
+
+       origX := *px
+       *px = ast.NewIdent(fmt.Sprintf("_cgoBase%d", i))
+       fmt.Fprintf(sb, "_cgo%d := %s; ", i, gofmtLine(arg))
+       *px = origX
+
        // Use "0 == 0" to do the right thing in the unlikely event
        // that "true" is shadowed.
-       fmt.Fprintf(sb, "_cgoCheckPointer(_cgo%d, 0 == 0); ", i)
+       fmt.Fprintf(sbCheck, "_cgoCheckPointer(_cgoBase%d, 0 == 0); ", i)
 
        return true
 }
 
-// hasSideEffects returns whether the expression x has any side
-// effects.  x is an expression, not a statement, so the only side
-// effect is a function call.
-func (p *Package) hasSideEffects(f *File, x ast.Expr) bool {
-       found := false
-       f.walk(x, ctxExpr,
-               func(f *File, x interface{}, context astContext) {
-                       if _, ok := x.(*ast.CallExpr); ok {
-                               found = true
-                       }
-               })
-       return found
-}
-
 // isType returns whether the expression is definitely a type.
 // This is conservative--it returns false for an unknown identifier.
 func (p *Package) isType(t ast.Expr) bool {
@@ -1087,6 +1159,9 @@ func (p *Package) isType(t ast.Expr) bool {
 
                        return true
                }
+               if strings.HasPrefix(t.Name, "_Ctype_") {
+                       return true
+               }
        case *ast.StarExpr:
                return p.isType(t.X)
        case *ast.ArrayType, *ast.StructType, *ast.FuncType, *ast.InterfaceType,
@@ -1097,6 +1172,29 @@ func (p *Package) isType(t ast.Expr) bool {
        return false
 }
 
+// isConst returns whether x is an untyped constant.
+func (p *Package) isConst(f *File, x ast.Expr) bool {
+       switch x := x.(type) {
+       case *ast.BasicLit:
+               return true
+       case *ast.SelectorExpr:
+               id, ok := x.X.(*ast.Ident)
+               if !ok || id.Name != "C" {
+                       return false
+               }
+               name := f.Name[x.Sel.Name]
+               if name != nil {
+                       return name.IsConst()
+               }
+       case *ast.Ident:
+               return x.Name == "nil" ||
+                       strings.HasPrefix(x.Name, "_Ciconst_") ||
+                       strings.HasPrefix(x.Name, "_Cfconst_") ||
+                       strings.HasPrefix(x.Name, "_Csconst_")
+       }
+       return false
+}
+
 // rewriteUnsafe returns a version of t with references to unsafe.Pointer
 // rewritten to use _cgo_unsafe.Pointer instead.
 func (p *Package) rewriteUnsafe(t ast.Expr) ast.Expr {
@@ -1205,11 +1303,13 @@ func (p *Package) rewriteRef(f *File) {
                *r.Expr = expr
 
                // Record source-level edit for cgo output.
-               repl := gofmt(expr)
-               if r.Name.Kind != "type" {
-                       repl = "(" + repl + ")"
+               if !r.Done {
+                       repl := gofmt(expr)
+                       if r.Name.Kind != "type" {
+                               repl = "(" + repl + ")"
+                       }
+                       f.Edit.Replace(f.offset(old.Pos()), f.offset(old.End()), repl)
                }
-               f.Edit.Replace(f.offset(old.Pos()), f.offset(old.End()), repl)
        }
 
        // Remove functions only used as expressions, so their respective
index 5bcb9754d71e5c0d224554572f1d3222eaaadc89..626ffe2390e2d4a24accec403ecde080e6d134aa 100644 (file)
@@ -88,6 +88,7 @@ type Ref struct {
        Name    *Name
        Expr    *ast.Expr
        Context astContext
+       Done    bool
 }
 
 func (r *Ref) Pos() token.Pos {