]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: SSA, don't let write barrier clobber return values
authorKeith Randall <khr@golang.org>
Fri, 27 May 2016 21:07:37 +0000 (14:07 -0700)
committerKeith Randall <khr@golang.org>
Fri, 27 May 2016 22:11:45 +0000 (22:11 +0000)
When we do *p = f(), we might need to copy the return value from
f to p with a write barrier.  The write barrier itself is a call,
so we need to copy the return value of f to a temporary location
before we call the write barrier function.  Otherwise, the call
itself (specifically, marshalling the args to typedmemmove) will
clobber the value we're trying to write.

Fixes #15854

Change-Id: I5703da87634d91a9884e3ec098d7b3af713462e7
Reviewed-on: https://go-review.googlesource.com/23522
Reviewed-by: David Chase <drchase@google.com>
Run-TryBot: Keith Randall <khr@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/cmd/compile/internal/gc/fixedbugs_test.go [new file with mode: 0644]
src/cmd/compile/internal/gc/ssa.go

diff --git a/src/cmd/compile/internal/gc/fixedbugs_test.go b/src/cmd/compile/internal/gc/fixedbugs_test.go
new file mode 100644 (file)
index 0000000..19b1d9a
--- /dev/null
@@ -0,0 +1,50 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gc
+
+import "testing"
+
+type T struct {
+       x [2]int64 // field that will be clobbered. Also makes type not SSAable.
+       p *byte    // has a pointer
+}
+
+//go:noinline
+func makeT() T {
+       return T{}
+}
+
+var g T
+
+var sink []byte
+
+func TestIssue15854(t *testing.T) {
+       for i := 0; i < 10000; i++ {
+               if g.x[0] != 0 {
+                       t.Fatalf("g.x[0] clobbered with %x\n", g.x[0])
+               }
+               // The bug was in the following assignment. The return
+               // value of makeT() is not copied out of the args area of
+               // stack frame in a timely fashion. So when write barriers
+               // are enabled, the marshaling of the args for the write
+               // barrier call clobbers the result of makeT() before it is
+               // read by the write barrier code.
+               g = makeT()
+               sink = make([]byte, 1000) // force write barriers to eventually happen
+       }
+}
+func TestIssue15854b(t *testing.T) {
+       const N = 10000
+       a := make([]T, N)
+       for i := 0; i < N; i++ {
+               a = append(a, makeT())
+               sink = make([]byte, 1000) // force write barriers to eventually happen
+       }
+       for i, v := range a {
+               if v.x[0] != 0 {
+                       t.Fatalf("a[%d].x[0] clobbered with %x\n", i, v.x[0])
+               }
+       }
+}
index a107f91ef3e70a33b68a0476a07a5bad047c7657..b604044cb73642cb2b22f7c3393f03d00e637479 100644 (file)
@@ -580,8 +580,8 @@ func (s *state) stmt(n *Node) {
 
        case OAS2DOTTYPE:
                res, resok := s.dottype(n.Rlist.First(), true)
-               s.assign(n.List.First(), res, needwritebarrier(n.List.First(), n.Rlist.First()), false, n.Lineno, 0)
-               s.assign(n.List.Second(), resok, false, false, n.Lineno, 0)
+               s.assign(n.List.First(), res, needwritebarrier(n.List.First(), n.Rlist.First()), false, n.Lineno, 0, false)
+               s.assign(n.List.Second(), resok, false, false, n.Lineno, 0, false)
                return
 
        case ODCL:
@@ -700,13 +700,14 @@ func (s *state) stmt(n *Node) {
                        }
                }
                var r *ssa.Value
+               var isVolatile bool
                needwb := n.Op == OASWB && rhs != nil
                deref := !canSSAType(t)
                if deref {
                        if rhs == nil {
                                r = nil // Signal assign to use OpZero.
                        } else {
-                               r = s.addr(rhs, false)
+                               r, isVolatile = s.addr(rhs, false)
                        }
                } else {
                        if rhs == nil {
@@ -755,7 +756,7 @@ func (s *state) stmt(n *Node) {
                        }
                }
 
-               s.assign(n.Left, r, needwb, deref, n.Lineno, skip)
+               s.assign(n.Left, r, needwb, deref, n.Lineno, skip, isVolatile)
 
        case OIF:
                bThen := s.f.NewBlock(ssa.BlockPlain)
@@ -1438,10 +1439,10 @@ func (s *state) expr(n *Node) *ssa.Value {
                if s.canSSA(n) {
                        return s.variable(n, n.Type)
                }
-               addr := s.addr(n, false)
+               addr, _ := s.addr(n, false)
                return s.newValue2(ssa.OpLoad, n.Type, addr, s.mem())
        case OCLOSUREVAR:
-               addr := s.addr(n, false)
+               addr, _ := s.addr(n, false)
                return s.newValue2(ssa.OpLoad, n.Type, addr, s.mem())
        case OLITERAL:
                switch u := n.Val().U.(type) {
@@ -1910,7 +1911,9 @@ func (s *state) expr(n *Node) *ssa.Value {
                return s.expr(n.Left)
 
        case OADDR:
-               return s.addr(n.Left, n.Bounded)
+               a, _ := s.addr(n.Left, n.Bounded)
+               // Note we know the volatile result is false because you can't write &f() in Go.
+               return a
 
        case OINDREG:
                if int(n.Reg) != Thearch.REGSP {
@@ -1930,7 +1933,7 @@ func (s *state) expr(n *Node) *ssa.Value {
                        v := s.expr(n.Left)
                        return s.newValue1I(ssa.OpStructSelect, n.Type, int64(fieldIdx(n)), v)
                }
-               p := s.addr(n, false)
+               p, _ := s.addr(n, false)
                return s.newValue2(ssa.OpLoad, n.Type, p, s.mem())
 
        case ODOTPTR:
@@ -1957,11 +1960,11 @@ func (s *state) expr(n *Node) *ssa.Value {
                        }
                        return s.newValue2(ssa.OpLoad, Types[TUINT8], ptr, s.mem())
                case n.Left.Type.IsSlice():
-                       p := s.addr(n, false)
+                       p, _ := s.addr(n, false)
                        return s.newValue2(ssa.OpLoad, n.Left.Type.Elem(), p, s.mem())
                case n.Left.Type.IsArray():
                        // TODO: fix when we can SSA arrays of length 1.
-                       p := s.addr(n, false)
+                       p, _ := s.addr(n, false)
                        return s.newValue2(ssa.OpLoad, n.Left.Type.Elem(), p, s.mem())
                default:
                        s.Fatalf("bad type for index %v", n.Left.Type)
@@ -2126,7 +2129,7 @@ func (s *state) append(n *Node, inplace bool) *ssa.Value {
 
        var slice, addr *ssa.Value
        if inplace {
-               addr = s.addr(sn, false)
+               addr, _ = s.addr(sn, false)
                slice = s.newValue2(ssa.OpLoad, n.Type, addr, s.mem())
        } else {
                slice = s.expr(sn)
@@ -2197,15 +2200,21 @@ func (s *state) append(n *Node, inplace bool) *ssa.Value {
        }
 
        // Evaluate args
-       args := make([]*ssa.Value, 0, nargs)
-       store := make([]bool, 0, nargs)
+       type argRec struct {
+               // if store is true, we're appending the value v.  If false, we're appending the
+               // value at *v.  If store==false, isVolatile reports whether the source
+               // is in the outargs section of the stack frame.
+               v          *ssa.Value
+               store      bool
+               isVolatile bool
+       }
+       args := make([]argRec, 0, nargs)
        for _, n := range n.List.Slice()[1:] {
                if canSSAType(n.Type) {
-                       args = append(args, s.expr(n))
-                       store = append(store, true)
+                       args = append(args, argRec{v: s.expr(n), store: true})
                } else {
-                       args = append(args, s.addr(n, false))
-                       store = append(store, false)
+                       v, isVolatile := s.addr(n, false)
+                       args = append(args, argRec{v: v, isVolatile: isVolatile})
                }
        }
 
@@ -2219,17 +2228,17 @@ func (s *state) append(n *Node, inplace bool) *ssa.Value {
        // TODO: maybe just one writeBarrier.enabled check?
        for i, arg := range args {
                addr := s.newValue2(ssa.OpPtrIndex, pt, p2, s.constInt(Types[TINT], int64(i)))
-               if store[i] {
+               if arg.store {
                        if haspointers(et) {
-                               s.insertWBstore(et, addr, arg, n.Lineno, 0)
+                               s.insertWBstore(et, addr, arg.v, n.Lineno, 0)
                        } else {
-                               s.vars[&memVar] = s.newValue3I(ssa.OpStore, ssa.TypeMem, et.Size(), addr, arg, s.mem())
+                               s.vars[&memVar] = s.newValue3I(ssa.OpStore, ssa.TypeMem, et.Size(), addr, arg.v, s.mem())
                        }
                } else {
                        if haspointers(et) {
-                               s.insertWBmove(et, addr, arg, n.Lineno)
+                               s.insertWBmove(et, addr, arg.v, n.Lineno, arg.isVolatile)
                        } else {
-                               s.vars[&memVar] = s.newValue3I(ssa.OpMove, ssa.TypeMem, et.Size(), addr, arg, s.mem())
+                               s.vars[&memVar] = s.newValue3I(ssa.OpMove, ssa.TypeMem, et.Size(), addr, arg.v, s.mem())
                        }
                }
        }
@@ -2301,9 +2310,10 @@ const (
 // Right has already been evaluated to ssa, left has not.
 // If deref is true, then we do left = *right instead (and right has already been nil-checked).
 // If deref is true and right == nil, just do left = 0.
+// If deref is true, rightIsVolatile reports whether right points to volatile (clobbered by a call) storage.
 // Include a write barrier if wb is true.
 // skip indicates assignments (at the top level) that can be avoided.
-func (s *state) assign(left *Node, right *ssa.Value, wb, deref bool, line int32, skip skipMask) {
+func (s *state) assign(left *Node, right *ssa.Value, wb, deref bool, line int32, skip skipMask, rightIsVolatile bool) {
        if left.Op == ONAME && isblank(left) {
                return
        }
@@ -2344,7 +2354,7 @@ func (s *state) assign(left *Node, right *ssa.Value, wb, deref bool, line int32,
                        }
 
                        // Recursively assign the new value we've made to the base of the dot op.
-                       s.assign(left.Left, new, false, false, line, 0)
+                       s.assign(left.Left, new, false, false, line, 0, rightIsVolatile)
                        // TODO: do we need to update named values here?
                        return
                }
@@ -2354,7 +2364,7 @@ func (s *state) assign(left *Node, right *ssa.Value, wb, deref bool, line int32,
                return
        }
        // Left is not ssa-able. Compute its address.
-       addr := s.addr(left, false)
+       addr, _ := s.addr(left, false)
        if left.Op == ONAME && skip == 0 {
                s.vars[&memVar] = s.newValue1A(ssa.OpVarDef, ssa.TypeMem, left, s.mem())
        }
@@ -2365,7 +2375,7 @@ func (s *state) assign(left *Node, right *ssa.Value, wb, deref bool, line int32,
                        return
                }
                if wb {
-                       s.insertWBmove(t, addr, right, line)
+                       s.insertWBmove(t, addr, right, line, rightIsVolatile)
                        return
                }
                s.vars[&memVar] = s.newValue3I(ssa.OpMove, ssa.TypeMem, t.Size(), addr, right, s.mem())
@@ -2684,10 +2694,12 @@ func (s *state) lookupSymbol(n *Node, sym interface{}) interface{} {
 }
 
 // addr converts the address of the expression n to SSA, adds it to s and returns the SSA result.
+// Also returns a bool reporting whether the returned value is "volatile", that is it
+// points to the outargs section and thus the referent will be clobbered by any call.
 // The value that the returned Value represents is guaranteed to be non-nil.
 // If bounded is true then this address does not require a nil check for its operand
 // even if that would otherwise be implied.
-func (s *state) addr(n *Node, bounded bool) *ssa.Value {
+func (s *state) addr(n *Node, bounded bool) (*ssa.Value, bool) {
        t := Ptrto(n.Type)
        switch n.Op {
        case ONAME:
@@ -2700,41 +2712,41 @@ func (s *state) addr(n *Node, bounded bool) *ssa.Value {
                        if n.Xoffset != 0 {
                                v = s.entryNewValue1I(ssa.OpOffPtr, v.Type, n.Xoffset, v)
                        }
-                       return v
+                       return v, false
                case PPARAM:
                        // parameter slot
                        v := s.decladdrs[n]
                        if v != nil {
-                               return v
+                               return v, false
                        }
                        if n.String() == ".fp" {
                                // Special arg that points to the frame pointer.
                                // (Used by the race detector, others?)
                                aux := s.lookupSymbol(n, &ssa.ArgSymbol{Typ: n.Type, Node: n})
-                               return s.entryNewValue1A(ssa.OpAddr, t, aux, s.sp)
+                               return s.entryNewValue1A(ssa.OpAddr, t, aux, s.sp), false
                        }
                        s.Fatalf("addr of undeclared ONAME %v. declared: %v", n, s.decladdrs)
-                       return nil
+                       return nil, false
                case PAUTO:
                        aux := s.lookupSymbol(n, &ssa.AutoSymbol{Typ: n.Type, Node: n})
-                       return s.newValue1A(ssa.OpAddr, t, aux, s.sp)
+                       return s.newValue1A(ssa.OpAddr, t, aux, s.sp), false
                case PPARAMOUT: // Same as PAUTO -- cannot generate LEA early.
                        // ensure that we reuse symbols for out parameters so
                        // that cse works on their addresses
                        aux := s.lookupSymbol(n, &ssa.ArgSymbol{Typ: n.Type, Node: n})
-                       return s.newValue1A(ssa.OpAddr, t, aux, s.sp)
+                       return s.newValue1A(ssa.OpAddr, t, aux, s.sp), false
                default:
                        s.Unimplementedf("variable address class %v not implemented", classnames[n.Class])
-                       return nil
+                       return nil, false
                }
        case OINDREG:
                // indirect off a register
                // used for storing/loading arguments/returns to/from callees
                if int(n.Reg) != Thearch.REGSP {
                        s.Unimplementedf("OINDREG of non-SP register %s in addr: %v", obj.Rconv(int(n.Reg)), n)
-                       return nil
+                       return nil, false
                }
-               return s.entryNewValue1I(ssa.OpOffPtr, t, n.Xoffset, s.sp)
+               return s.entryNewValue1I(ssa.OpOffPtr, t, n.Xoffset, s.sp), true
        case OINDEX:
                if n.Left.Type.IsSlice() {
                        a := s.expr(n.Left)
@@ -2745,37 +2757,37 @@ func (s *state) addr(n *Node, bounded bool) *ssa.Value {
                                s.boundsCheck(i, len)
                        }
                        p := s.newValue1(ssa.OpSlicePtr, t, a)
-                       return s.newValue2(ssa.OpPtrIndex, t, p, i)
+                       return s.newValue2(ssa.OpPtrIndex, t, p, i), false
                } else { // array
-                       a := s.addr(n.Left, bounded)
+                       a, isVolatile := s.addr(n.Left, bounded)
                        i := s.expr(n.Right)
                        i = s.extendIndex(i)
                        len := s.constInt(Types[TINT], n.Left.Type.NumElem())
                        if !n.Bounded {
                                s.boundsCheck(i, len)
                        }
-                       return s.newValue2(ssa.OpPtrIndex, Ptrto(n.Left.Type.Elem()), a, i)
+                       return s.newValue2(ssa.OpPtrIndex, Ptrto(n.Left.Type.Elem()), a, i), isVolatile
                }
        case OIND:
-               return s.exprPtr(n.Left, bounded, n.Lineno)
+               return s.exprPtr(n.Left, bounded, n.Lineno), false
        case ODOT:
-               p := s.addr(n.Left, bounded)
-               return s.newValue1I(ssa.OpOffPtr, t, n.Xoffset, p)
+               p, isVolatile := s.addr(n.Left, bounded)
+               return s.newValue1I(ssa.OpOffPtr, t, n.Xoffset, p), isVolatile
        case ODOTPTR:
                p := s.exprPtr(n.Left, bounded, n.Lineno)
-               return s.newValue1I(ssa.OpOffPtr, t, n.Xoffset, p)
+               return s.newValue1I(ssa.OpOffPtr, t, n.Xoffset, p), false
        case OCLOSUREVAR:
                return s.newValue1I(ssa.OpOffPtr, t, n.Xoffset,
-                       s.entryNewValue0(ssa.OpGetClosurePtr, Ptrto(Types[TUINT8])))
+                       s.entryNewValue0(ssa.OpGetClosurePtr, Ptrto(Types[TUINT8]))), false
        case OCONVNOP:
-               addr := s.addr(n.Left, bounded)
-               return s.newValue1(ssa.OpCopy, t, addr) // ensure that addr has the right type
+               addr, isVolatile := s.addr(n.Left, bounded)
+               return s.newValue1(ssa.OpCopy, t, addr), isVolatile // ensure that addr has the right type
        case OCALLFUNC, OCALLINTER, OCALLMETH:
-               return s.call(n, callNormal)
+               return s.call(n, callNormal), true
 
        default:
                s.Unimplementedf("unhandled addr %v", n.Op)
-               return nil
+               return nil, false
        }
 }
 
@@ -3007,7 +3019,7 @@ func (s *state) rtcall(fn *Node, returns bool, results []*Type, args ...*ssa.Val
 
 // insertWBmove inserts the assignment *left = *right including a write barrier.
 // t is the type being assigned.
-func (s *state) insertWBmove(t *Type, left, right *ssa.Value, line int32) {
+func (s *state) insertWBmove(t *Type, left, right *ssa.Value, line int32, rightIsVolatile bool) {
        // if writeBarrier.enabled {
        //   typedmemmove(&t, left, right)
        // } else {
@@ -3038,8 +3050,25 @@ func (s *state) insertWBmove(t *Type, left, right *ssa.Value, line int32) {
        b.AddEdgeTo(bElse)
 
        s.startBlock(bThen)
-       taddr := s.newValue1A(ssa.OpAddr, Types[TUINTPTR], &ssa.ExternSymbol{Typ: Types[TUINTPTR], Sym: typenamesym(t)}, s.sb)
-       s.rtcall(typedmemmove, true, nil, taddr, left, right)
+
+       if !rightIsVolatile {
+               // Issue typedmemmove call.
+               taddr := s.newValue1A(ssa.OpAddr, Types[TUINTPTR], &ssa.ExternSymbol{Typ: Types[TUINTPTR], Sym: typenamesym(t)}, s.sb)
+               s.rtcall(typedmemmove, true, nil, taddr, left, right)
+       } else {
+               // Copy to temp location if the source is volatile (will be clobbered by
+               // a function call).  Marshaling the args to typedmemmove might clobber the
+               // value we're trying to move.
+               tmp := temp(t)
+               s.vars[&memVar] = s.newValue1A(ssa.OpVarDef, ssa.TypeMem, tmp, s.mem())
+               tmpaddr, _ := s.addr(tmp, true)
+               s.vars[&memVar] = s.newValue3I(ssa.OpMove, ssa.TypeMem, t.Size(), tmpaddr, right, s.mem())
+               // Issue typedmemmove call.
+               taddr := s.newValue1A(ssa.OpAddr, Types[TUINTPTR], &ssa.ExternSymbol{Typ: Types[TUINTPTR], Sym: typenamesym(t)}, s.sb)
+               s.rtcall(typedmemmove, true, nil, taddr, left, tmpaddr)
+               // Mark temp as dead.
+               s.vars[&memVar] = s.newValue1A(ssa.OpVarKill, ssa.TypeMem, tmp, s.mem())
+       }
        s.endBlock().AddEdgeTo(bEnd)
 
        s.startBlock(bElse)