]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: fix mishandling of unsafe-uintptr arguments with call method in go/defer
authorCuong Manh Le <cuong.manhle.vn@gmail.com>
Fri, 26 Feb 2021 03:17:09 +0000 (10:17 +0700)
committerCuong Manh Le <cuong.manhle.vn@gmail.com>
Fri, 26 Feb 2021 10:18:26 +0000 (10:18 +0000)
In CL 253457, we did the same fix for direct function calls. But for
method calls, the receiver argument also need to be passed through the
wrapper function, which we are not doing so the compiler crashes with
the code in #44415.

It will be nicer if we can rewrite OCALLMETHOD to normal OCALLFUNC, but
that will be for future CL. The passing receiver argument to wrapper
function is easier for backporting to go1.16 branch.

Fixes #44415

Change-Id: I03607a64429042c6066ce673931db9769deb3124
Reviewed-on: https://go-review.googlesource.com/c/go/+/296490
Trust: Cuong Manh Le <cuong.manhle.vn@gmail.com>
Run-TryBot: Cuong Manh Le <cuong.manhle.vn@gmail.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Go Bot <gobot@golang.org>

src/cmd/compile/internal/walk/stmt.go
test/fixedbugs/issue24491a.go

index 46a621c2ba74d9aa6b236e8d1316af82ab87dce0..0c851506cbe8e3aa6e60c9c099d25fa6b42e46e7 100644 (file)
@@ -253,15 +253,22 @@ func wrapCall(n *ir.CallExpr, init *ir.Nodes) ir.Node {
                }
        }
 
+       wrapArgs := n.Args
+       // If there's a receiver argument, it needs to be passed through the wrapper too.
+       if n.Op() == ir.OCALLMETH || n.Op() == ir.OCALLINTER {
+               recv := n.X.(*ir.SelectorExpr).X
+               wrapArgs = append([]ir.Node{recv}, wrapArgs...)
+       }
+
        // origArgs keeps track of what argument is uintptr-unsafe/unsafe-uintptr conversion.
-       origArgs := make([]ir.Node, len(n.Args))
+       origArgs := make([]ir.Node, len(wrapArgs))
        var funcArgs []*ir.Field
-       for i, arg := range n.Args {
+       for i, arg := range wrapArgs {
                s := typecheck.LookupNum("a", i)
                if !isBuiltinCall && arg.Op() == ir.OCONVNOP && arg.Type().IsUintptr() && arg.(*ir.ConvExpr).X.Type().IsUnsafePtr() {
                        origArgs[i] = arg
                        arg = arg.(*ir.ConvExpr).X
-                       n.Args[i] = arg
+                       wrapArgs[i] = arg
                }
                funcArgs = append(funcArgs, ir.NewField(base.Pos, s, nil, arg.Type()))
        }
@@ -278,6 +285,12 @@ func wrapCall(n *ir.CallExpr, init *ir.Nodes) ir.Node {
                }
                args[i] = ir.NewConvExpr(base.Pos, origArg.Op(), origArg.Type(), args[i])
        }
+       if n.Op() == ir.OCALLMETH || n.Op() == ir.OCALLINTER {
+               // Move wrapped receiver argument back to its appropriate place.
+               recv := typecheck.Expr(args[0])
+               n.X.(*ir.SelectorExpr).X = recv
+               args = args[1:]
+       }
        call := ir.NewCallExpr(base.Pos, n.Op(), n.X, args)
        if !isBuiltinCall {
                call.SetOp(ir.OCALL)
@@ -291,6 +304,6 @@ func wrapCall(n *ir.CallExpr, init *ir.Nodes) ir.Node {
        typecheck.Stmts(fn.Body)
        typecheck.Target.Decls = append(typecheck.Target.Decls, fn)
 
-       call = ir.NewCallExpr(base.Pos, ir.OCALL, fn.Nname, n.Args)
+       call = ir.NewCallExpr(base.Pos, ir.OCALL, fn.Nname, wrapArgs)
        return walkExpr(typecheck.Stmt(call), init)
 }
index 8accf8c0a376f0dca01e860560a1b2244c2f62e1..d30b65b233c98c90851fc52067bebec3d03424df 100644 (file)
@@ -48,6 +48,14 @@ func f() int {
        return test("return", uintptr(setup()), uintptr(setup()), uintptr(setup()), uintptr(setup()))
 }
 
+type S struct{}
+
+//go:noinline
+//go:uintptrescapes
+func (S) test(s string, p, q uintptr, rest ...uintptr) int {
+       return test(s, p, q, rest...)
+}
+
 func main() {
        test("normal", uintptr(setup()), uintptr(setup()), uintptr(setup()), uintptr(setup()))
        <-done
@@ -60,6 +68,29 @@ func main() {
        }()
        <-done
 
+       func() {
+               for {
+                       defer test("defer in for loop", uintptr(setup()), uintptr(setup()), uintptr(setup()), uintptr(setup()))
+                       break
+               }
+       }()
+
+       <-done
+       func() {
+               s := &S{}
+               defer s.test("method call", uintptr(setup()), uintptr(setup()), uintptr(setup()), uintptr(setup()))
+       }()
+       <-done
+
+       func() {
+               s := &S{}
+               for {
+                       defer s.test("defer method loop", uintptr(setup()), uintptr(setup()), uintptr(setup()), uintptr(setup()))
+                       break
+               }
+       }()
+       <-done
+
        f()
        <-done
 }