]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: fix mishandling of unsafe-uintptr arguments in go/defer
authorCuong Manh Le <cuong.manhle.vn@gmail.com>
Tue, 8 Sep 2020 08:28:43 +0000 (15:28 +0700)
committerCuong Manh Le <cuong.manhle.vn@gmail.com>
Wed, 9 Sep 2020 07:50:01 +0000 (07:50 +0000)
Currently, the statement:

go g(uintptr(f()))

gets rewritten into:

tmp := f()
newproc(8, g, uintptr(tmp))
runtime.KeepAlive(tmp)

which doesn't guarantee that tmp is still alive by time the g call is
scheduled to run.

This CL fixes the issue, by wrapping g call in a closure:

go func(p unsafe.Pointer) {
g(uintptr(p))
}(f())

then this will be rewritten into:

tmp := f()
go func(p unsafe.Pointer) {
g(uintptr(p))
runtime.KeepAlive(p)
}(tmp)
runtime.KeepAlive(tmp)  // superfluous, but harmless

So the unsafe.Pointer p will be kept alive at the time g call runs.

Updates #24491

Change-Id: Ic10821251cbb1b0073daec92b82a866c6ebaf567
Reviewed-on: https://go-review.googlesource.com/c/go/+/253457
Run-TryBot: Cuong Manh Le <cuong.manhle.vn@gmail.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/cmd/compile/internal/gc/order.go
src/cmd/compile/internal/gc/syntax.go
src/cmd/compile/internal/gc/walk.go
test/fixedbugs/issue24491.go [new file with mode: 0644]

index aa91160e5c9a3b93e0b66a39b71f914f7f143080..412f073a8d4066d69e2741a70448fa75de62ae16 100644 (file)
@@ -502,6 +502,7 @@ func (o *Order) call(n *Node) {
                        x := o.copyExpr(arg.Left, arg.Left.Type, false)
                        x.Name.SetKeepalive(true)
                        arg.Left = x
+                       n.SetNeedsWrapper(true)
                }
        }
 
index 47e5e59156c92074b224944491dca00b4450858c..5580f789c5b236e885c9117a59b68bbed4a894cb 100644 (file)
@@ -141,19 +141,20 @@ const (
        nodeInitorder, _                   // tracks state during init1; two bits
        _, _                               // second nodeInitorder bit
        _, nodeHasBreak
-       _, nodeNoInline  // used internally by inliner to indicate that a function call should not be inlined; set for OCALLFUNC and OCALLMETH only
-       _, nodeImplicit  // implicit OADDR or ODEREF; ++/-- statement represented as OASOP; or ANDNOT lowered to OAND
-       _, nodeIsDDD     // is the argument variadic
-       _, nodeDiag      // already printed error about this
-       _, nodeColas     // OAS resulting from :=
-       _, nodeNonNil    // guaranteed to be non-nil
-       _, nodeTransient // storage can be reused immediately after this statement
-       _, nodeBounded   // bounds check unnecessary
-       _, nodeHasCall   // expression contains a function call
-       _, nodeLikely    // if statement condition likely
-       _, nodeHasVal    // node.E contains a Val
-       _, nodeHasOpt    // node.E contains an Opt
-       _, nodeEmbedded  // ODCLFIELD embedded type
+       _, nodeNoInline     // used internally by inliner to indicate that a function call should not be inlined; set for OCALLFUNC and OCALLMETH only
+       _, nodeImplicit     // implicit OADDR or ODEREF; ++/-- statement represented as OASOP; or ANDNOT lowered to OAND
+       _, nodeIsDDD        // is the argument variadic
+       _, nodeDiag         // already printed error about this
+       _, nodeColas        // OAS resulting from :=
+       _, nodeNonNil       // guaranteed to be non-nil
+       _, nodeTransient    // storage can be reused immediately after this statement
+       _, nodeBounded      // bounds check unnecessary
+       _, nodeHasCall      // expression contains a function call
+       _, nodeLikely       // if statement condition likely
+       _, nodeHasVal       // node.E contains a Val
+       _, nodeHasOpt       // node.E contains an Opt
+       _, nodeEmbedded     // ODCLFIELD embedded type
+       _, nodeNeedsWrapper // OCALLxxx node that needs to be wrapped
 )
 
 func (n *Node) Class() Class     { return Class(n.flags.get3(nodeClass)) }
@@ -286,6 +287,20 @@ func (n *Node) SetIota(x int64) {
        n.Xoffset = x
 }
 
+func (n *Node) NeedsWrapper() bool {
+       return n.flags&nodeNeedsWrapper != 0
+}
+
+// SetNeedsWrapper indicates that OCALLxxx node needs to be wrapped by a closure.
+func (n *Node) SetNeedsWrapper(b bool) {
+       switch n.Op {
+       case OCALLFUNC, OCALLMETH, OCALLINTER:
+       default:
+               Fatalf("Node.SetNeedsWrapper %v", n.Op)
+       }
+       n.flags.set(nodeNeedsWrapper, b)
+}
+
 // mayBeShared reports whether n may occur in multiple places in the AST.
 // Extra care must be taken when mutating such a node.
 func (n *Node) mayBeShared() bool {
index 0158af8700f1d8ba49e539b90fc93bb1ef3949bd..ab7f857031fc7b6e0304802d2d5dfd846bd680ba 100644 (file)
@@ -232,7 +232,11 @@ func walkstmt(n *Node) *Node {
                        n.Left = copyany(n.Left, &n.Ninit, true)
 
                default:
-                       n.Left = walkexpr(n.Left, &n.Ninit)
+                       if n.Left.NeedsWrapper() {
+                               n.Left = wrapCall(n.Left, &n.Ninit)
+                       } else {
+                               n.Left = walkexpr(n.Left, &n.Ninit)
+                       }
                }
 
        case OFOR, OFORUNTIL:
@@ -3857,6 +3861,14 @@ func candiscard(n *Node) bool {
 //             builtin(a1, a2, a3)
 //     }(x, y, z)
 // for print, println, and delete.
+//
+// Rewrite
+//     go f(x, y, uintptr(unsafe.Pointer(z)))
+// into
+//     go func(a1, a2, a3) {
+//             builtin(a1, a2, uintptr(a3))
+//     }(x, y, unsafe.Pointer(z))
+// for function contains unsafe-uintptr arguments.
 
 var wrapCall_prgen int
 
@@ -3868,9 +3880,17 @@ func wrapCall(n *Node, init *Nodes) *Node {
                init.AppendNodes(&n.Ninit)
        }
 
+       isBuiltinCall := n.Op != OCALLFUNC && n.Op != OCALLMETH && n.Op != OCALLINTER
+       // origArgs keeps track of what argument is uintptr-unsafe/unsafe-uintptr conversion.
+       origArgs := make([]*Node, n.List.Len())
        t := nod(OTFUNC, nil, nil)
        for i, arg := range n.List.Slice() {
                s := lookupN("a", i)
+               if !isBuiltinCall && arg.Op == OCONVNOP && arg.Type.Etype == TUINTPTR && arg.Left.Type.Etype == TUNSAFEPTR {
+                       origArgs[i] = arg
+                       arg = arg.Left
+                       n.List.SetIndex(i, arg)
+               }
                t.List.Append(symfield(s, arg.Type))
        }
 
@@ -3878,10 +3898,22 @@ func wrapCall(n *Node, init *Nodes) *Node {
        sym := lookupN("wrap·", wrapCall_prgen)
        fn := dclfunc(sym, t)
 
-       a := nod(n.Op, nil, nil)
-       a.List.Set(paramNnames(t.Type))
-       a = typecheck(a, ctxStmt)
-       fn.Nbody.Set1(a)
+       args := paramNnames(t.Type)
+       for i, origArg := range origArgs {
+               if origArg == nil {
+                       continue
+               }
+               arg := nod(origArg.Op, args[i], nil)
+               arg.Type = origArg.Type
+               args[i] = arg
+       }
+       call := nod(n.Op, nil, nil)
+       if !isBuiltinCall {
+               call.Op = OCALL
+               call.Left = n.Left
+       }
+       call.List.Set(args)
+       fn.Nbody.Set1(call)
 
        funcbody()
 
@@ -3889,12 +3921,12 @@ func wrapCall(n *Node, init *Nodes) *Node {
        typecheckslice(fn.Nbody.Slice(), ctxStmt)
        xtop = append(xtop, fn)
 
-       a = nod(OCALL, nil, nil)
-       a.Left = fn.Func.Nname
-       a.List.Set(n.List.Slice())
-       a = typecheck(a, ctxStmt)
-       a = walkexpr(a, init)
-       return a
+       call = nod(OCALL, nil, nil)
+       call.Left = fn.Func.Nname
+       call.List.Set(n.List.Slice())
+       call = typecheck(call, ctxStmt)
+       call = walkexpr(call, init)
+       return call
 }
 
 // substArgTypes substitutes the given list of types for
diff --git a/test/fixedbugs/issue24491.go b/test/fixedbugs/issue24491.go
new file mode 100644 (file)
index 0000000..4703368
--- /dev/null
@@ -0,0 +1,45 @@
+// run
+
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This test makes sure unsafe-uintptr arguments are handled correctly.
+
+package main
+
+import (
+       "runtime"
+       "unsafe"
+)
+
+var done = make(chan bool, 1)
+
+func setup() unsafe.Pointer {
+       s := "ok"
+       runtime.SetFinalizer(&s, func(p *string) { *p = "FAIL" })
+       return unsafe.Pointer(&s)
+}
+
+//go:noinline
+//go:uintptrescapes
+func test(s string, p uintptr) {
+       runtime.GC()
+       if *(*string)(unsafe.Pointer(p)) != "ok" {
+               panic(s + " return unexpected result")
+       }
+       done <- true
+}
+
+func main() {
+       test("normal", uintptr(setup()))
+       <-done
+
+       go test("go", uintptr(setup()))
+       <-done
+
+       func() {
+               defer test("defer", uintptr(setup()))
+       }()
+       <-done
+}