]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: fix miscompilation of "defer delete(m, k)"
authorMatthew Dempsky <mdempsky@google.com>
Tue, 6 Mar 2018 22:36:49 +0000 (14:36 -0800)
committerMatthew Dempsky <mdempsky@google.com>
Tue, 6 Mar 2018 23:33:28 +0000 (23:33 +0000)
Previously, for slow map key types (i.e., any type other than a 32-bit
or 64-bit plain memory type), we would rewrite

    defer delete(m, k)

into

    ktmp := k
    defer delete(m, &ktmp)

However, if the defer statement was inside a loop, we would end up
reusing the same ktmp value for all of the deferred deletes.

We already rewrite

    defer print(x, y, z)

into

    defer func(a1, a2, a3) {
        print(a1, a2, a3)
    }(x, y, z)

This CL generalizes this rewrite to also apply for slow map deletes.

This could be extended to apply even more generally to other builtins,
but as discussed on #24259, there are cases where we must *not* do
this (e.g., "defer recover()"). However, if we elect to do this more
generally, this CL should still make that easier.

Lastly, while here, fix a few isues in wrapCall (nee walkprintfunc):

1) lookupN appends the generation number to the symbol anyway, so "%d"
was being literally included in the generated function names.

2) walkstmt will be called when the function is compiled later anyway,
so no need to do it now.

Fixes #24259.

Change-Id: I70286867c64c69c18e9552f69e3f4154a0fc8b04
Reviewed-on: https://go-review.googlesource.com/99017
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/cmd/compile/internal/gc/order.go
src/cmd/compile/internal/gc/walk.go
src/runtime/map_test.go

index 7d04e10e1de2598ecf6c452c9685f21ad0967389..8ae1dbcbef3a6ce2618e2cdde99be27a4e86a9db 100644 (file)
@@ -618,25 +618,7 @@ func (o *Order) stmt(n *Node) {
        // Special: order arguments to inner call but not call itself.
        case ODEFER, OPROC:
                t := o.markTemp()
-
-               switch n.Left.Op {
-               // Delete will take the address of the key.
-               // Copy key into new temp and do not clean it
-               // (it persists beyond the statement).
-               case ODELETE:
-                       o.exprList(n.Left.List)
-
-                       if mapfast(n.Left.List.First().Type) == mapslow {
-                               t1 := o.markTemp()
-                               np := n.Left.List.Addr(1) // map key
-                               *np = o.copyExpr(*np, (*np).Type, false)
-                               o.popTemp(t1)
-                       }
-
-               default:
-                       o.call(n.Left)
-               }
-
+               o.call(n.Left)
                o.out = append(o.out, n)
                o.cleanTemp(t)
 
index f1ef2341eece9dc6ef9f45379a345f653c51f5a1..8770684d876b1b2b491bc615c9d24fa2b596cca7 100644 (file)
@@ -242,9 +242,18 @@ func walkstmt(n *Node) *Node {
 
        case ODEFER:
                Curfn.Func.SetHasDefer(true)
+               fallthrough
+       case OPROC:
                switch n.Left.Op {
                case OPRINT, OPRINTN:
-                       n.Left = walkprintfunc(n.Left, &n.Ninit)
+                       n.Left = wrapCall(n.Left, &n.Ninit)
+
+               case ODELETE:
+                       if mapfast(n.Left.List.First().Type) == mapslow {
+                               n.Left = wrapCall(n.Left, &n.Ninit)
+                       } else {
+                               n.Left = walkexpr(n.Left, &n.Ninit)
+                       }
 
                case OCOPY:
                        n.Left = copyany(n.Left, &n.Ninit, true)
@@ -273,21 +282,6 @@ func walkstmt(n *Node) *Node {
                walkstmtlist(n.Nbody.Slice())
                walkstmtlist(n.Rlist.Slice())
 
-       case OPROC:
-               switch n.Left.Op {
-               case OPRINT, OPRINTN:
-                       n.Left = walkprintfunc(n.Left, &n.Ninit)
-
-               case OCOPY:
-                       n.Left = copyany(n.Left, &n.Ninit, true)
-
-               default:
-                       n.Left = walkexpr(n.Left, &n.Ninit)
-               }
-
-               // make room for size & fn arguments.
-               adjustargs(n, 2*Widthptr)
-
        case ORETURN:
                walkexprlist(n.List.Slice(), &n.Ninit)
                if n.List.Len() == 0 {
@@ -3847,45 +3841,43 @@ func candiscard(n *Node) bool {
        return true
 }
 
-// rewrite
-//     print(x, y, z)
+// Rewrite
+//     go builtin(x, y, z)
 // into
-//     func(a1, a2, a3) {
-//             print(a1, a2, a3)
+//     go func(a1, a2, a3) {
+//             builtin(a1, a2, a3)
 //     }(x, y, z)
-// and same for println.
+// for print, println, and delete.
 
-var walkprintfunc_prgen int
+var wrapCall_prgen int
 
-// The result of walkprintfunc MUST be assigned back to n, e.g.
-//     n.Left = walkprintfunc(n.Left, init)
-func walkprintfunc(n *Node, init *Nodes) *Node {
+// The result of wrapCall MUST be assigned back to n, e.g.
+//     n.Left = wrapCall(n.Left, init)
+func wrapCall(n *Node, init *Nodes) *Node {
        if n.Ninit.Len() != 0 {
                walkstmtlist(n.Ninit.Slice())
                init.AppendNodes(&n.Ninit)
        }
 
        t := nod(OTFUNC, nil, nil)
-       var printargs []*Node
-       for i, n1 := range n.List.Slice() {
+       var args []*Node
+       for i, arg := range n.List.Slice() {
                buf := fmt.Sprintf("a%d", i)
-               a := namedfield(buf, n1.Type)
+               a := namedfield(buf, arg.Type)
                t.List.Append(a)
-               printargs = append(printargs, a.Left)
+               args = append(args, a.Left)
        }
 
        oldfn := Curfn
        Curfn = nil
 
-       walkprintfunc_prgen++
-       sym := lookupN("print·%d", walkprintfunc_prgen)
+       wrapCall_prgen++
+       sym := lookupN("wrap·", wrapCall_prgen)
        fn := dclfunc(sym, t)
 
        a := nod(n.Op, nil, nil)
-       a.List.Set(printargs)
+       a.List.Set(args)
        a = typecheck(a, Etop)
-       a = walkstmt(a)
-
        fn.Nbody.Set1(a)
 
        funcbody()
index b12b09eeb61c514e3fe65ac5e3d1242b1768f63f..d1b268bda4fb1131e04010cd7bb091c35150ceb9 100644 (file)
@@ -875,3 +875,24 @@ func BenchmarkMapDelete(b *testing.B) {
        b.Run("Int64", runWith(benchmarkMapDeleteInt64, 100, 1000, 10000))
        b.Run("Str", runWith(benchmarkMapDeleteStr, 100, 1000, 10000))
 }
+
+func TestDeferDeleteSlow(t *testing.T) {
+       ks := []complex128{0, 1, 2, 3}
+
+       m := make(map[interface{}]int)
+       for i, k := range ks {
+               m[k] = i
+       }
+       if len(m) != len(ks) {
+               t.Errorf("want %d elements, got %d", len(ks), len(m))
+       }
+
+       func() {
+               for _, k := range ks {
+                       defer delete(m, k)
+               }
+       }()
+       if len(m) != 0 {
+               t.Errorf("want 0 elements, got %d", len(m))
+       }
+}