]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: fix ordering problems in struct equality
authorKeith Randall <khr@golang.org>
Mon, 15 Jun 2020 16:17:18 +0000 (09:17 -0700)
committerKeith Randall <khr@golang.org>
Mon, 15 Jun 2020 18:03:08 +0000 (18:03 +0000)
Make sure that if a field comparison might panic, we evaluate
(and short circuit if not equal) all previous fields, and don't
evaluate any subsequent fields.

Add a bunch more tests to the equality+panic checker.

Update #8606

Change-Id: I6a159bbc8da5b2b7ee835c0cd1fc565575b58c46
Reviewed-on: https://go-review.googlesource.com/c/go/+/237919
Run-TryBot: Keith Randall <khr@golang.org>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/cmd/compile/internal/gc/alg.go
test/fixedbugs/issue8606.go

index ecbed1a3c9c0f1f3f880ab57546a6e8ba880732b..e2e2374717e7db08efd7e677c43dc145a2590042 100644 (file)
@@ -63,6 +63,26 @@ func IncomparableField(t *types.Type) *types.Field {
        return nil
 }
 
+// EqCanPanic reports whether == on type t could panic (has an interface somewhere).
+// t must be comparable.
+func EqCanPanic(t *types.Type) bool {
+       switch t.Etype {
+       default:
+               return false
+       case TINTER:
+               return true
+       case TARRAY:
+               return EqCanPanic(t.Elem())
+       case TSTRUCT:
+               for _, f := range t.FieldSlice() {
+                       if !f.Sym.IsBlank() && EqCanPanic(f.Type) {
+                               return true
+                       }
+               }
+               return false
+       }
+}
+
 // algtype is like algtype1, except it returns the fixed-width AMEMxx variants
 // instead of the general AMEM kind when possible.
 func algtype(t *types.Type) AlgKind {
@@ -624,14 +644,19 @@ func geneq(t *types.Type) *obj.LSym {
 
        case TSTRUCT:
                // Build a list of conditions to satisfy.
-               // Track their order so that we can preserve aspects of that order.
+               // The conditions are a list-of-lists. Conditions are reorderable
+               // within each inner list. The outer lists must be evaluated in order.
+               // Even within each inner list, track their order so that we can preserve
+               // aspects of that order. (TODO: latter part needed?)
                type nodeIdx struct {
                        n   *Node
                        idx int
                }
-               var conds []nodeIdx
+               var conds [][]nodeIdx
+               conds = append(conds, []nodeIdx{})
                and := func(n *Node) {
-                       conds = append(conds, nodeIdx{n: n, idx: len(conds)})
+                       i := len(conds) - 1
+                       conds[i] = append(conds[i], nodeIdx{n: n, idx: len(conds[i])})
                }
 
                // Walk the struct using memequal for runs of AMEM
@@ -647,6 +672,10 @@ func geneq(t *types.Type) *obj.LSym {
 
                        // Compare non-memory fields with field equality.
                        if !IsRegularMemory(f.Type) {
+                               if EqCanPanic(f.Type) {
+                                       // Enforce ordering by starting a new set of reorderable conditions.
+                                       conds = append(conds, []nodeIdx{})
+                               }
                                p := nodSym(OXDOT, np, f.Sym)
                                q := nodSym(OXDOT, nq, f.Sym)
                                switch {
@@ -657,6 +686,10 @@ func geneq(t *types.Type) *obj.LSym {
                                default:
                                        and(nod(OEQ, p, q))
                                }
+                               if EqCanPanic(f.Type) {
+                                       // Also enforce ordering after something that can panic.
+                                       conds = append(conds, []nodeIdx{})
+                               }
                                i++
                                continue
                        }
@@ -680,20 +713,24 @@ func geneq(t *types.Type) *obj.LSym {
 
                // Sort conditions to put runtime calls last.
                // Preserve the rest of the ordering.
-               sort.SliceStable(conds, func(i, j int) bool {
-                       x, y := conds[i], conds[j]
-                       if (x.n.Op != OCALL) == (y.n.Op != OCALL) {
-                               return x.idx < y.idx
-                       }
-                       return x.n.Op != OCALL
-               })
+               var flatConds []nodeIdx
+               for _, c := range conds {
+                       sort.SliceStable(c, func(i, j int) bool {
+                               x, y := c[i], c[j]
+                               if (x.n.Op != OCALL) == (y.n.Op != OCALL) {
+                                       return x.idx < y.idx
+                               }
+                               return x.n.Op != OCALL
+                       })
+                       flatConds = append(flatConds, c...)
+               }
 
                var cond *Node
-               if len(conds) == 0 {
+               if len(flatConds) == 0 {
                        cond = nodbool(true)
                } else {
-                       cond = conds[0].n
-                       for _, c := range conds[1:] {
+                       cond = flatConds[0].n
+                       for _, c := range flatConds[1:] {
                                cond = nod(OANDAND, cond, c.n)
                        }
                }
index 8122b1d2b657268b1230d0ee66dcce3014895655..8c850696954746f911d1a8b2e3dc4e8a9df43978 100644 (file)
@@ -12,22 +12,65 @@ import "fmt"
 
 func main() {
        type A [2]interface{}
+       type A2 [6]interface{}
        type S struct{ x, y interface{} }
+       type S2 struct{ x, y, z, a, b, c interface{} }
+       type T1 struct {
+               i interface{}
+               a int64
+               j interface{}
+       }
+       type T2 struct {
+               i       interface{}
+               a, b, c int64
+               j       interface{}
+       }
+       type T3 struct {
+               i interface{}
+               s string
+               j interface{}
+       }
+       b := []byte{1}
 
        for _, test := range []struct {
                panic bool
                a, b  interface{}
        }{
-               {false, A{1, []byte{1}}, A{2, []byte{1}}},
-               {true, A{[]byte{1}, 1}, A{[]byte{1}, 2}},
-               {false, S{1, []byte{1}}, S{2, []byte{1}}},
-               {true, S{[]byte{1}, 1}, S{[]byte{1}, 2}},
-               {false, A{1, []byte{1}}, A{"2", []byte{1}}},
-               {true, A{[]byte{1}, 1}, A{[]byte{1}, "2"}},
-               {false, S{1, []byte{1}}, S{"2", []byte{1}}},
-               {true, S{[]byte{1}, 1}, S{[]byte{1}, "2"}},
+               {false, A{1, b}, A{2, b}},
+               {true, A{b, 1}, A{b, 2}},
+               {false, A{1, b}, A{"2", b}},
+               {true, A{b, 1}, A{b, "2"}},
+
+               {false, A2{1, b}, A2{2, b}},
+               {true, A2{b, 1}, A2{b, 2}},
+               {false, A2{1, b}, A2{"2", b}},
+               {true, A2{b, 1}, A2{b, "2"}},
+
+               {false, S{1, b}, S{2, b}},
+               {true, S{b, 1}, S{b, 2}},
+               {false, S{1, b}, S{"2", b}},
+               {true, S{b, 1}, S{b, "2"}},
+
+               {false, S2{x: 1, y: b}, S2{x: 2, y: b}},
+               {true, S2{x: b, y: 1}, S2{x: b, y: 2}},
+               {false, S2{x: 1, y: b}, S2{x: "2", y: b}},
+               {true, S2{x: b, y: 1}, S2{x: b, y: "2"}},
+
+               {true, T1{i: b, a: 1}, T1{i: b, a: 2}},
+               {false, T1{a: 1, j: b}, T1{a: 2, j: b}},
+               {true, T2{i: b, a: 1}, T2{i: b, a: 2}},
+               {false, T2{a: 1, j: b}, T2{a: 2, j: b}},
+               {true, T3{i: b, s: "foo"}, T3{i: b, s: "bar"}},
+               {false, T3{s: "foo", j: b}, T3{s: "bar", j: b}},
+               {true, T3{i: b, s: "fooz"}, T3{i: b, s: "bar"}},
+               {false, T3{s: "fooz", j: b}, T3{s: "bar", j: b}},
        } {
                f := func() {
+                       defer func() {
+                               if recover() != nil {
+                                       panic(fmt.Sprintf("comparing %#v and %#v panicked", test.a, test.b))
+                               }
+                       }()
                        if test.a == test.b {
                                panic(fmt.Sprintf("values %#v and %#v should not be equal", test.a, test.b))
                        }