]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: optimize generated struct/array equality code
authorKeith Randall <khr@golang.org>
Wed, 16 Sep 2020 17:13:37 +0000 (10:13 -0700)
committerKeith Randall <khr@golang.org>
Wed, 28 Oct 2020 04:33:41 +0000 (04:33 +0000)
Use a standard "not-equal" label that we can jump to when we
detect that the arguments are not equal. This prevents the
recombination that was noticed in #39428.

Fixes #39428

Change-Id: Ib7c6b3539f4f6046043fd7148f940fcdcab70427
Reviewed-on: https://go-review.googlesource.com/c/go/+/255317
Trust: Keith Randall <khr@golang.org>
Run-TryBot: Keith Randall <khr@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Josh Bleecher Snyder <josharian@gmail.com>
src/cmd/compile/internal/gc/alg.go

index 2ab69c2c564cd0f1c6e145360bcb65c6b1f260ce..2f7fa27bb9100ab6571feac54e0da5b70fe18947 100644 (file)
@@ -529,6 +529,10 @@ func geneq(t *types.Type) *obj.LSym {
        fn := dclfunc(sym, tfn)
        np := asNode(tfn.Type.Params().Field(0).Nname)
        nq := asNode(tfn.Type.Params().Field(1).Nname)
+       nr := asNode(tfn.Type.Results().Field(0).Nname)
+
+       // Label to jump to if an equality test fails.
+       neq := autolabel(".neq")
 
        // We reach here only for types that have equality but
        // cannot be handled by the standard algorithms,
@@ -555,13 +559,13 @@ func geneq(t *types.Type) *obj.LSym {
                // for i := 0; i < nelem; i++ {
                //   if eq(p[i], q[i]) {
                //   } else {
-               //     return
+               //     goto neq
                //   }
                // }
                //
                // TODO(josharian): consider doing some loop unrolling
                // for larger nelem as well, processing a few elements at a time in a loop.
-               checkAll := func(unroll int64, eq func(pi, qi *Node) *Node) {
+               checkAll := func(unroll int64, last bool, eq func(pi, qi *Node) *Node) {
                        // checkIdx generates a node to check for equality at index i.
                        checkIdx := func(i *Node) *Node {
                                // pi := p[i]
@@ -576,37 +580,38 @@ func geneq(t *types.Type) *obj.LSym {
                        }
 
                        if nelem <= unroll {
+                               if last {
+                                       // Do last comparison in a different manner.
+                                       nelem--
+                               }
                                // Generate a series of checks.
-                               var cond *Node
                                for i := int64(0); i < nelem; i++ {
-                                       c := nodintconst(i)
-                                       check := checkIdx(c)
-                                       if cond == nil {
-                                               cond = check
-                                               continue
-                                       }
-                                       cond = nod(OANDAND, cond, check)
+                                       // if check {} else { goto neq }
+                                       nif := nod(OIF, checkIdx(nodintconst(i)), nil)
+                                       nif.Rlist.Append(nodSym(OGOTO, nil, neq))
+                                       fn.Nbody.Append(nif)
+                               }
+                               if last {
+                                       fn.Nbody.Append(nod(OAS, nr, checkIdx(nodintconst(nelem))))
+                               }
+                       } else {
+                               // Generate a for loop.
+                               // for i := 0; i < nelem; i++
+                               i := temp(types.Types[TINT])
+                               init := nod(OAS, i, nodintconst(0))
+                               cond := nod(OLT, i, nodintconst(nelem))
+                               post := nod(OAS, i, nod(OADD, i, nodintconst(1)))
+                               loop := nod(OFOR, cond, post)
+                               loop.Ninit.Append(init)
+                               // if eq(pi, qi) {} else { goto neq }
+                               nif := nod(OIF, checkIdx(i), nil)
+                               nif.Rlist.Append(nodSym(OGOTO, nil, neq))
+                               loop.Nbody.Append(nif)
+                               fn.Nbody.Append(loop)
+                               if last {
+                                       fn.Nbody.Append(nod(OAS, nr, nodbool(true)))
                                }
-                               nif := nod(OIF, cond, nil)
-                               nif.Rlist.Append(nod(ORETURN, nil, nil))
-                               fn.Nbody.Append(nif)
-                               return
                        }
-
-                       // Generate a for loop.
-                       // for i := 0; i < nelem; i++
-                       i := temp(types.Types[TINT])
-                       init := nod(OAS, i, nodintconst(0))
-                       cond := nod(OLT, i, nodintconst(nelem))
-                       post := nod(OAS, i, nod(OADD, i, nodintconst(1)))
-                       loop := nod(OFOR, cond, post)
-                       loop.Ninit.Append(init)
-                       // if eq(pi, qi) {} else { return }
-                       check := checkIdx(i)
-                       nif := nod(OIF, check, nil)
-                       nif.Rlist.Append(nod(ORETURN, nil, nil))
-                       loop.Nbody.Append(nif)
-                       fn.Nbody.Append(loop)
                }
 
                switch t.Elem().Etype {
@@ -614,32 +619,28 @@ func geneq(t *types.Type) *obj.LSym {
                        // Do two loops. First, check that all the lengths match (cheap).
                        // Second, check that all the contents match (expensive).
                        // TODO: when the array size is small, unroll the length match checks.
-                       checkAll(3, func(pi, qi *Node) *Node {
+                       checkAll(3, false, func(pi, qi *Node) *Node {
                                // Compare lengths.
                                eqlen, _ := eqstring(pi, qi)
                                return eqlen
                        })
-                       checkAll(1, func(pi, qi *Node) *Node {
+                       checkAll(1, true, func(pi, qi *Node) *Node {
                                // Compare contents.
                                _, eqmem := eqstring(pi, qi)
                                return eqmem
                        })
                case TFLOAT32, TFLOAT64:
-                       checkAll(2, func(pi, qi *Node) *Node {
+                       checkAll(2, true, func(pi, qi *Node) *Node {
                                // p[i] == q[i]
                                return nod(OEQ, pi, qi)
                        })
                // TODO: pick apart structs, do them piecemeal too
                default:
-                       checkAll(1, func(pi, qi *Node) *Node {
+                       checkAll(1, true, func(pi, qi *Node) *Node {
                                // p[i] == q[i]
                                return nod(OEQ, pi, qi)
                        })
                }
-               // return true
-               ret := nod(ORETURN, nil, nil)
-               ret.List.Append(nodbool(true))
-               fn.Nbody.Append(ret)
 
        case TSTRUCT:
                // Build a list of conditions to satisfy.
@@ -717,20 +718,40 @@ func geneq(t *types.Type) *obj.LSym {
                        flatConds = append(flatConds, c...)
                }
 
-               var cond *Node
                if len(flatConds) == 0 {
-                       cond = nodbool(true)
+                       fn.Nbody.Append(nod(OAS, nr, nodbool(true)))
                } else {
-                       cond = flatConds[0]
-                       for _, c := range flatConds[1:] {
-                               cond = nod(OANDAND, cond, c)
+                       for _, c := range flatConds[:len(flatConds)-1] {
+                               // if cond {} else { goto neq }
+                               n := nod(OIF, c, nil)
+                               n.Rlist.Append(nodSym(OGOTO, nil, neq))
+                               fn.Nbody.Append(n)
                        }
+                       fn.Nbody.Append(nod(OAS, nr, flatConds[len(flatConds)-1]))
                }
+       }
 
-               ret := nod(ORETURN, nil, nil)
-               ret.List.Append(cond)
-               fn.Nbody.Append(ret)
+       // ret:
+       //   return
+       ret := autolabel(".ret")
+       fn.Nbody.Append(nodSym(OLABEL, nil, ret))
+       fn.Nbody.Append(nod(ORETURN, nil, nil))
+
+       // neq:
+       //   r = false
+       //   return (or goto ret)
+       fn.Nbody.Append(nodSym(OLABEL, nil, neq))
+       fn.Nbody.Append(nod(OAS, nr, nodbool(false)))
+       if EqCanPanic(t) || hasCall(fn) {
+               // Epilogue is large, so share it with the equal case.
+               fn.Nbody.Append(nodSym(OGOTO, nil, ret))
+       } else {
+               // Epilogue is small, so don't bother sharing.
+               fn.Nbody.Append(nod(ORETURN, nil, nil))
        }
+       // TODO(khr): the epilogue size detection condition above isn't perfect.
+       // We should really do a generic CL that shares epilogues across
+       // the board. See #24936.
 
        if Debug.r != 0 {
                dumplist("geneq body", fn.Nbody)
@@ -762,6 +783,39 @@ func geneq(t *types.Type) *obj.LSym {
        return closure
 }
 
+func hasCall(n *Node) bool {
+       if n.Op == OCALL || n.Op == OCALLFUNC {
+               return true
+       }
+       if n.Left != nil && hasCall(n.Left) {
+               return true
+       }
+       if n.Right != nil && hasCall(n.Right) {
+               return true
+       }
+       for _, x := range n.Ninit.Slice() {
+               if hasCall(x) {
+                       return true
+               }
+       }
+       for _, x := range n.Nbody.Slice() {
+               if hasCall(x) {
+                       return true
+               }
+       }
+       for _, x := range n.List.Slice() {
+               if hasCall(x) {
+                       return true
+               }
+       }
+       for _, x := range n.Rlist.Slice() {
+               if hasCall(x) {
+                       return true
+               }
+       }
+       return false
+}
+
 // eqfield returns the node
 //     p.field == q.field
 func eqfield(p *Node, q *Node, field *types.Sym) *Node {