]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: make CSE faster
authorKeith Randall <khr@golang.org>
Tue, 4 Oct 2016 21:35:45 +0000 (14:35 -0700)
committerKeith Randall <khr@golang.org>
Wed, 5 Oct 2016 17:00:08 +0000 (17:00 +0000)
To refine a set of possibly equivalent values, the old CSE algorithm
picked one value, compared it against all the others, and made two sets
out of the results (the values that match the picked value and the
values that didn't).  Unfortunately, this leads to O(n^2) behavior. The
picked value ends up being equal to no other values, we make size 1 and
size n-1 sets, and then recurse on the size n-1 set.

Instead, sort the set by the equivalence classes of its arguments.  Then
we just look for spots in the sorted list where the equivalence classes
of the arguments change.  This lets us do a multi-way split for O(n lg
n) time.

This change makes cmpDepth unnecessary.

The refinement portion used to call the type comparator.  That is
unnecessary as the type was already part of the initial partition.

Lowers time of 16361 from 8 sec to 3 sec.
Lowers time of 15112 from 282 sec to 20 sec. That's kind of unfair, as
CL 30257 changed it from 21 sec to 282 sec. But that CL fixed other bad
compile times (issue #17127) by large factors, so net still a big win.

Fixes #15112
Fixes #16361

Change-Id: I351ce111bae446608968c6d48710eeb6a3d8e527
Reviewed-on: https://go-review.googlesource.com/30354
Reviewed-by: Todd Neal <todd@tneal.org>
src/cmd/compile/internal/ssa/cse.go

index 532232de57c9783164cd934a6448028be519ba43..24f071bcfd0ff3822672af05d7b729bb2944a8b8 100644 (file)
@@ -9,10 +9,6 @@ import (
        "sort"
 )
 
-const (
-       cmpDepth = 1
-)
-
 // cse does common-subexpression elimination on the Function.
 // Values are just relinked, nothing is deleted. A subsequent deadcode
 // pass is required to actually remove duplicate expressions.
@@ -60,7 +56,8 @@ func cse(f *Func) {
                        valueEqClass[v.ID] = -v.ID
                }
        }
-       for i, e := range partition {
+       var pNum ID = 1
+       for _, e := range partition {
                if f.pass.debug > 1 && len(e) > 500 {
                        fmt.Printf("CSE.large partition (%d): ", len(e))
                        for j := 0; j < 3; j++ {
@@ -70,20 +67,22 @@ func cse(f *Func) {
                }
 
                for _, v := range e {
-                       valueEqClass[v.ID] = ID(i)
+                       valueEqClass[v.ID] = pNum
                }
                if f.pass.debug > 2 && len(e) > 1 {
-                       fmt.Printf("CSE.partition #%d:", i)
+                       fmt.Printf("CSE.partition #%d:", pNum)
                        for _, v := range e {
                                fmt.Printf(" %s", v.String())
                        }
                        fmt.Printf("\n")
                }
+               pNum++
        }
 
-       // Find an equivalence class where some members of the class have
-       // non-equivalent arguments. Split the equivalence class appropriately.
-       // Repeat until we can't find any more splits.
+       // Split equivalence classes at points where they have
+       // non-equivalent arguments.  Repeat until we can't find any
+       // more splits.
+       var splitPoints []int
        for {
                changed := false
 
@@ -91,39 +90,51 @@ func cse(f *Func) {
                // we process new additions as they arrive, avoiding O(n^2) behavior.
                for i := 0; i < len(partition); i++ {
                        e := partition[i]
-                       v := e[0]
-                       // all values in this equiv class that are not equivalent to v get moved
-                       // into another equiv class.
-                       // To avoid allocating while building that equivalence class,
-                       // move the values equivalent to v to the beginning of e
-                       // and other values to the end of e.
-                       allvals := e
-               eqloop:
-                       for j := 1; j < len(e); {
-                               w := e[j]
-                               equivalent := true
-                               for i := 0; i < len(v.Args); i++ {
-                                       if valueEqClass[v.Args[i].ID] != valueEqClass[w.Args[i].ID] {
-                                               equivalent = false
+
+                       // Sort by eq class of arguments.
+                       sort.Sort(partitionByArgClass{e, valueEqClass})
+
+                       // Find split points.
+                       splitPoints = append(splitPoints[:0], 0)
+                       for j := 1; j < len(e); j++ {
+                               v, w := e[j-1], e[j]
+                               eqArgs := true
+                               for k, a := range v.Args {
+                                       b := w.Args[k]
+                                       if valueEqClass[a.ID] != valueEqClass[b.ID] {
+                                               eqArgs = false
                                                break
                                        }
                                }
-                               if !equivalent || v.Type.Compare(w.Type) != CMPeq {
-                                       // w is not equivalent to v.
-                                       // move it to the end and shrink e.
-                                       e[j], e[len(e)-1] = e[len(e)-1], e[j]
-                                       e = e[:len(e)-1]
-                                       valueEqClass[w.ID] = ID(len(partition))
-                                       changed = true
-                                       continue eqloop
+                               if !eqArgs {
+                                       splitPoints = append(splitPoints, j)
                                }
-                               // v and w are equivalent. Keep w in e.
-                               j++
                        }
-                       partition[i] = e
-                       if len(e) < len(allvals) {
-                               partition = append(partition, allvals[len(e):])
+                       if len(splitPoints) == 1 {
+                               continue // no splits, leave equivalence class alone.
                        }
+
+                       // Move another equivalence class down in place of e.
+                       partition[i] = partition[len(partition)-1]
+                       partition = partition[:len(partition)-1]
+                       i--
+
+                       // Add new equivalence classes for the parts of e we found.
+                       splitPoints = append(splitPoints, len(e))
+                       for j := 0; j < len(splitPoints)-1; j++ {
+                               f := e[splitPoints[j]:splitPoints[j+1]]
+                               if len(f) == 1 {
+                                       // Don't add singletons.
+                                       valueEqClass[f[0].ID] = -f[0].ID
+                                       continue
+                               }
+                               for _, v := range f {
+                                       valueEqClass[v.ID] = pNum
+                               }
+                               pNum++
+                               partition = append(partition, f)
+                       }
+                       changed = true
                }
 
                if !changed {
@@ -253,7 +264,7 @@ func partitionValues(a []*Value, auxIDs auxmap) []eqclass {
                j := 1
                for ; j < len(a); j++ {
                        w := a[j]
-                       if cmpVal(v, w, auxIDs, cmpDepth) != CMPeq {
+                       if cmpVal(v, w, auxIDs) != CMPeq {
                                break
                        }
                }
@@ -274,7 +285,7 @@ func lt2Cmp(isLt bool) Cmp {
 
 type auxmap map[interface{}]int32
 
-func cmpVal(v, w *Value, auxIDs auxmap, depth int) Cmp {
+func cmpVal(v, w *Value, auxIDs auxmap) Cmp {
        // Try to order these comparison by cost (cheaper first)
        if v.Op != w.Op {
                return lt2Cmp(v.Op < w.Op)
@@ -308,18 +319,6 @@ func cmpVal(v, w *Value, auxIDs auxmap, depth int) Cmp {
                return lt2Cmp(auxIDs[v.Aux] < auxIDs[w.Aux])
        }
 
-       if depth > 0 {
-               for i := range v.Args {
-                       if v.Args[i] == w.Args[i] {
-                               // skip comparing equal args
-                               continue
-                       }
-                       if ac := cmpVal(v.Args[i], w.Args[i], auxIDs, depth-1); ac != CMPeq {
-                               return ac
-                       }
-               }
-       }
-
        return CMPeq
 }
 
@@ -334,7 +333,7 @@ func (sv sortvalues) Swap(i, j int) { sv.a[i], sv.a[j] = sv.a[j], sv.a[i] }
 func (sv sortvalues) Less(i, j int) bool {
        v := sv.a[i]
        w := sv.a[j]
-       if cmp := cmpVal(v, w, sv.auxIDs, cmpDepth); cmp != CMPeq {
+       if cmp := cmpVal(v, w, sv.auxIDs); cmp != CMPeq {
                return cmp == CMPlt
        }
 
@@ -354,3 +353,25 @@ func (sv partitionByDom) Less(i, j int) bool {
        w := sv.a[j]
        return sv.sdom.domorder(v.Block) < sv.sdom.domorder(w.Block)
 }
+
+type partitionByArgClass struct {
+       a       []*Value // array of values
+       eqClass []ID     // equivalence class IDs of values
+}
+
+func (sv partitionByArgClass) Len() int      { return len(sv.a) }
+func (sv partitionByArgClass) Swap(i, j int) { sv.a[i], sv.a[j] = sv.a[j], sv.a[i] }
+func (sv partitionByArgClass) Less(i, j int) bool {
+       v := sv.a[i]
+       w := sv.a[j]
+       for i, a := range v.Args {
+               b := w.Args[i]
+               if sv.eqClass[a.ID] < sv.eqClass[b.ID] {
+                       return true
+               }
+               if sv.eqClass[a.ID] > sv.eqClass[b.ID] {
+                       return false
+               }
+       }
+       return false
+}