]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile/internal/liveness: enhance mergelocals for addr-taken candidates
authorThan McIntosh <thanm@google.com>
Wed, 3 Apr 2024 16:06:23 +0000 (16:06 +0000)
committerThan McIntosh <thanm@google.com>
Tue, 9 Apr 2024 17:42:19 +0000 (17:42 +0000)
It is possible to have situations where a given ir.Name is
non-address-taken at the source level, but whose address is
materialized in order to accommodate the needs of arch-dependent
memory ops. The issue here is that the SymAddr op will show up as
touching a variable of interest, but the subsequent memory op will
not. This is generally not an issue for computing whether something is
live across a call, but it is problematic for collecting the more
fine-grained live interval info that drives stack slot merging.

As an example, consider this Go code:

    package p
    type T struct {
    x [10]int
    f float64
    }
    func ABC(i, j int) int {
    var t T
    t.x[i&3] = j
    return t.x[j&3]
    }

On amd64 the code sequences we'll see for accesses to "t" might look like

    v10 = VarDef <mem> {t} v1
    v5 = MOVOstoreconst <mem> {t} [val=0,off=0] v2 v10
    v23 = LEAQ <*T> {t} [8] v2 : DI
    v12 = DUFFZERO <mem> [80] v23 v5
    v14 = ANDQconst <int> [3] v7 : AX
    v19 = MOVQstoreidx8 <mem> {t} v2 v14 v8 v12
    v22 = ANDQconst <int> [3] v8 : BX
    v24 = MOVQloadidx8 <int> {t} v2 v22 v19 : AX
    v25 = MakeResult <int,mem> v24 v19 : <>

Note that the the loads and stores (ex: v19, v24) all refer directly
to "t", which means that regular live analysis will work fine for
identifying variable lifetimes. The DUFFZERO is (in effect) an
indirect write, but since there are accesses immediately after it we
wind up with the same live intervals.

Now the same code with GOARCH=ppc64:

    v10 = VarDef <mem> {t} v1
    v20 = MOVDaddr <*T> {t} v2 : R20
    v12 = LoweredZero <mem> [88] v20 v10
     v3 = CLRLSLDI <int> [212543] v7 : R5
    v15 = MOVDaddr <*T> {t} v2 : R6
    v19 = MOVDstoreidx <mem> v15 v3 v8 v12
    v29 = CLRLSLDI <int> [212543] v8 : R4
    v24 = MOVDloadidx <int> v15 v29 v19 : R3
    v25 = MakeResult <int,mem> v24 v19 : <>

Here instead of memory ops that refer directly to the symbol, we take
the address of "t" (ex: v15) and then pass the address to memory ops
(where the ops themselves no longer refer to the symbol).

This patch enhances the stack slot merging liveness analysis to handle
cases like the PPC64 one above. We add a new phase in candidate
selection that collects more precise use information for merge
candidates, and screens out candidates that are too difficult to
analyze. The phase make a forward pass over each basic block looking
for instructions of the form vK := SymAddr(N) where N is a raw
candidate. It then creates an entry in a map with key vK and value
holding name and the vK use count. As the walk continues, we check for
uses of of vK: when we see one, record it in a side table as an
upwards exposed use of N. At each vK use we also decrement the use
count in the map entry, and if we hit zero, remove the map entry. If
we hit the end of the basic block and we still have map entries, this
implies that the address in question "escapes" the block -- at that
point to be conservative we just evict the name in question from the
candidate set.

Although this CL fixes the issues that forced a revert of the original
merging CL, this CL doesn't enable stack slot merging by default; a
subsequent CL will do that.

Updates #62737.
Updates #65532.
Updates #65495.

Change-Id: Id41d359a677767a8e7ac1e962ae23f7becb4031f
Reviewed-on: https://go-review.googlesource.com/c/go/+/576735
Reviewed-by: Cherry Mui <cherryyz@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/cmd/compile/internal/base/debug.go
src/cmd/compile/internal/liveness/mergelocals.go
src/cmd/compile/internal/liveness/plive.go
src/cmd/compile/internal/ssagen/pgen.go
src/cmd/compile/internal/test/mergelocals_test.go
src/cmd/compile/internal/test/testdata/mergelocals/integration.go

index 08ccef30656747da15cd35a51225581cb2ff8fe5..672e3909e4989c4f35225e32221e603a8fab93a0 100644 (file)
@@ -45,6 +45,7 @@ type DebugFlags struct {
        MergeLocalsDumpFunc   string `help:"dump specified func in merge locals"`
        MergeLocalsHash       string `help:"hash value for debugging stack slot merging of local variables" concurrent:"ok"`
        MergeLocalsTrace      int    `help:"trace debug output for locals merging"`
+       MergeLocalsHTrace     int    `help:"hash-selected trace debug output for locals merging"`
        Nil                   int    `help:"print information about nil checks"`
        NoOpenDefer           int    `help:"disable open-coded defers" concurrent:"ok"`
        NoRefName             int    `help:"do not include referenced symbol names in object file" concurrent:"ok"`
index 82440beb6c51cee7e94fe955b2bffb8fe3d89337..aae57cb0666677bc765f22ec862729b103663e93 100644 (file)
@@ -53,12 +53,22 @@ type candRegion struct {
 // of its auto variables can safely share the same stack slot, returning
 // a state object that describes how the overlap should be done.
 func MergeLocals(fn *ir.Func, f *ssa.Func) *MergeLocalsState {
-       cands, idx, regions := collectMergeCandidates(fn)
-       if len(regions) == 0 {
+
+       // Create a container object for useful state info and then
+       // call collectMergeCandidates to see if there are vars suitable
+       // for stack slot merging.
+       cs := &cstate{
+               fn:    fn,
+               f:     f,
+               trace: base.Debug.MergeLocalsTrace,
+       }
+       cs.collectMergeCandidates()
+       if len(cs.regions) == 0 {
                return nil
        }
-       lv := newliveness(fn, f, cands, idx, 0)
 
+       // Kick off liveness analysis.
+       //
        // If we have a local variable such as "r2" below that's written
        // but then not read, something like:
        //
@@ -75,16 +85,17 @@ func MergeLocals(fn *ir.Func, f *ssa.Func) *MergeLocalsState {
        // can ignore "r2" completely during liveness analysis for stack
        // maps, however for stack slock merging we most definitely want
        // to treat the writes as "uses".
-       lv.conservativeWrites = true
+       cs.lv = newliveness(fn, f, cs.cands, cs.nameToSlot, 0)
+       cs.lv.conservativeWrites = true
+       cs.lv.prologue()
+       cs.lv.solve()
 
-       lv.prologue()
-       lv.solve()
-       cs := &cstate{
-               fn:        fn,
-               ibuilders: make([]IntervalsBuilder, len(cands)),
-       }
-       computeIntervals(lv, cs)
-       rv := performMerging(lv, cs, regions)
+       // Compute intervals for each candidate based on the liveness and
+       // on block effects.
+       cs.computeIntervals()
+
+       // Perform merging within each region of the candidates list.
+       rv := cs.performMerging()
        if err := rv.check(); err != nil {
                base.FatalfAt(fn.Pos(), "invalid mergelocals state: %v", err)
        }
@@ -245,20 +256,20 @@ func (mls *MergeLocalsState) String() string {
 }
 
 // collectMergeCandidates visits all of the AUTO vars declared in
-// function fn and returns a list of candidate variables for merging /
-// overlapping. Return values are: 1) a slice of ir.Name's
-// corresponding to the candidates, 2) a map that maps ir.Name to slot
-// in the slice, and 3) a slice containing regions (start/end pairs)
-// corresponding to variables that could be overlapped provided that
-// their lifetimes are disjoint.
-func collectMergeCandidates(fn *ir.Func) ([]*ir.Name, map[*ir.Name]int32, []candRegion) {
-       m := make(map[*ir.Name]int32)
+// function fn and identifies a list of candidate variables for
+// merging / overlapping. On return the "cands" field of cs will be
+// filled in with our set of potentially overlappable candidate
+// variables, the "regions" field will hold regions/sequence of
+// compatible vars within the candidates list, "nameToSlot" field will
+// be populated, and the "indirectUE" field will be filled in with
+// information about indirect upwards-exposed uses in the func.
+func (cs *cstate) collectMergeCandidates() {
        var cands []*ir.Name
-       var regions []candRegion
 
        // Collect up the available set of appropriate AUTOs in the
-       // function as a first step.
-       for _, n := range fn.Dcl {
+       // function as a first step, and bail if we have fewer than
+       // two candidates.
+       for _, n := range cs.fn.Dcl {
                if !n.Used() {
                        continue
                }
@@ -268,49 +279,60 @@ func collectMergeCandidates(fn *ir.Func) ([]*ir.Name, map[*ir.Name]int32, []cand
                cands = append(cands, n)
        }
        if len(cands) < 2 {
-               return nil, nil, nil
+               return
        }
 
        // Sort by pointerness, size, and then name.
        sort.SliceStable(cands, func(i, j int) bool {
-               ci, cj := cands[i], cands[j]
-               ihp, jhp := 0, 0
-               var ilsym, jlsym *obj.LSym
-               if ci.Type().HasPointers() {
-                       ihp = 1
-                       ilsym, _, _ = reflectdata.GCSym(ci.Type())
-               }
-               if cj.Type().HasPointers() {
-                       jhp = 1
-                       jlsym, _, _ = reflectdata.GCSym(cj.Type())
-               }
-               if ihp != jhp {
-                       return ihp < jhp
-               }
-               if ci.Type().Size() != cj.Type().Size() {
-                       return ci.Type().Size() < cj.Type().Size()
-               }
-               if ihp != 0 && jhp != 0 && ilsym != jlsym {
-                       // FIXME: find less clunky way to do this
-                       return fmt.Sprintf("%v", ilsym) < fmt.Sprintf("%v", jlsym)
-               }
-               if ci.Sym().Name != cj.Sym().Name {
-                       return ci.Sym().Name < cj.Sym().Name
-               }
-               return fmt.Sprintf("%v", ci.Pos()) < fmt.Sprintf("%v", ci.Pos())
+               return nameLess(cands[i], cands[j])
        })
 
-       if base.Debug.MergeLocalsTrace > 1 {
-               fmt.Fprintf(os.Stderr, "=-= raw cand list for func %v:\n", fn)
+       if cs.trace > 1 {
+               fmt.Fprintf(os.Stderr, "=-= raw cand list for func %v:\n", cs.fn)
                for i := range cands {
                        dumpCand(cands[i], i)
                }
        }
 
-       // Now generate a pruned candidate list-- we only want to return a
-       // non-empty list if there is some possibility of overlapping two
-       // vars.
+       // Now generate an initial pruned candidate list and regions list.
+       // This may be empty if we don't have enough compatible candidates.
+       initial, _ := genRegions(cands)
+       if len(initial) < 2 {
+               return
+       }
+
+       // When bisecting it can be handy to see debug trace output for
+       // only those functions that hashdebug selects; set this up here.
+       cs.setupHashTrace(initial)
+
+       // Create and populate an indirect use table that we'll use
+       // during interval construction. As part of this process we may
+       // wind up tossing out additional candidates, so check to make
+       // sure we still have something to work with.
+       cs.cands, cs.regions = cs.populateIndirectUseTable(initial)
+       if len(cs.cands) < 2 {
+               return
+       }
+
+       // At this point we have a final pruned set of candidates and a
+       // corresponding set of regions for the candidates. Build a
+       // name-to-slot map for the candidates.
+       cs.nameToSlot = make(map[*ir.Name]int32)
+       for i, n := range cs.cands {
+               cs.nameToSlot[n] = int32(i)
+       }
+
+       if cs.trace > 1 {
+               fmt.Fprintf(os.Stderr, "=-= pruned candidate list for fn %v:\n", cs.fn)
+               for i := range cs.cands {
+                       dumpCand(cs.cands[i], i)
+               }
+       }
+}
+
+func genRegions(cands []*ir.Name) ([]*ir.Name, []candRegion) {
        var pruned []*ir.Name
+       var regions []candRegion
        st := 0
        for {
                en := nextRegion(cands, st)
@@ -334,19 +356,264 @@ func collectMergeCandidates(fn *ir.Func) ([]*ir.Name, map[*ir.Name]int32, []cand
                st = en + 1
        }
        if len(pruned) < 2 {
-               return nil, nil, nil
+               return nil, nil
+       }
+       return pruned, regions
+}
+
+func (cs *cstate) dumpFunc() {
+       fmt.Fprintf(os.Stderr, "=-= mergelocalsdumpfunc %v:\n", cs.fn)
+       ii := 0
+       for k, b := range cs.f.Blocks {
+               fmt.Fprintf(os.Stderr, "b%d:\n", k)
+               for _, v := range b.Values {
+                       pos := base.Ctxt.PosTable.Pos(v.Pos)
+                       fmt.Fprintf(os.Stderr, "=-= %d L%d|C%d %s\n", ii, pos.RelLine(), pos.RelCol(), v.LongString())
+                       ii++
+               }
+       }
+}
+
+func (cs *cstate) dumpFuncIfSelected() {
+       if base.Debug.MergeLocalsDumpFunc == "" {
+               return
        }
-       for i, n := range pruned {
-               m[n] = int32(i)
+       if !strings.HasSuffix(fmt.Sprintf("%v", cs.fn),
+               base.Debug.MergeLocalsDumpFunc) {
+               return
        }
+       cs.dumpFunc()
+}
 
-       if base.Debug.MergeLocalsTrace > 1 {
-               fmt.Fprintf(os.Stderr, "=-= pruned candidate list for func %v:\n", fn)
-               for i := range pruned {
-                       dumpCand(pruned[i], i)
+func (cs *cstate) setupHashTrace(cands []*ir.Name) {
+       if base.Debug.MergeLocalsHTrace == 0 || base.Debug.MergeLocalsHash == "" {
+               return
+       }
+
+       // With this trace variant, check to see whether any of the
+       // candidates are selected-- if yes then enable tracing. Hack:
+       // create a new hashdebug with verbosity turned off and use that
+       // to test, so as not to confuse bisect.
+       modified := strings.ReplaceAll(base.Debug.MergeLocalsHash, "v", "q")
+       quiethd := base.NewHashDebug("qmergelocals", modified, nil)
+       found := false
+       for _, cand := range cands {
+               if !quiethd.MatchPosWithInfo(cand.Pos(), "quiet", nil) {
+                       found = true
+                       fmt.Fprintf(os.Stderr, "=-= MergeLocalsHTrace fn=%v n=%s match\n",
+                               cs.fn, cand.Sym().Name)
+                       break
                }
        }
-       return pruned, m, regions
+       if found {
+               cs.trace = base.Debug.MergeLocalsHTrace
+       }
+}
+
+// populateIndirectUseTable creates and populates the "indirectUE" table
+// within cs by doing some additional analysis of how the vars in
+// cands are accessed in the function.
+//
+// It is possible to have situations where a given ir.Name is
+// non-address-taken at the source level, but whose address is
+// materialized in order to accomodate the needs of
+// architecture-dependent operations or one sort or another (examples
+// include things like LoweredZero/DuffZero, etc). The issue here is
+// that the SymAddr op will show up as touching a variable of
+// interest, but the subsequent memory op will not. This is generally
+// not an issue for computing whether something is live across a call,
+// but it is problematic for collecting the more fine-grained live
+// interval info that drives stack slot merging.
+//
+// To handle this problem, make a forward pass over each basic block
+// looking for instructions of the form vK := SymAddr(N) where N is a
+// raw candidate. Create an entry in a map at that point from vK to
+// its use count. Continue the walk, looking for uses of vK: when we
+// see one, record it in a side table as an upwards exposed use of N.
+// Each time we see a use, decrement the use count in the map, and if
+// we hit zero, remove the map entry. If we hit the end of the basic
+// block and we still have map entries, then evict the name in
+// question from the candidate set.
+func (cs *cstate) populateIndirectUseTable(cands []*ir.Name) ([]*ir.Name, []candRegion) {
+
+       // main indirect UE table, this is what we're producing in this func
+       indirectUE := make(map[ssa.ID][]*ir.Name)
+
+       // this map holds the current set of candidates; the set may
+       // shrink if we have to evict any candidates.
+       rawcands := make(map[*ir.Name]struct{})
+
+       // maps ssa value V to the ir.Name it is taking the addr of,
+       // plus a count of the uses we've seen of V during a block walk.
+       pendingUses := make(map[ssa.ID]nameCount)
+
+       // A temporary indirect UE tab just for the current block
+       // being processed; used to help with evictions.
+       blockIndirectUE := make(map[ssa.ID][]*ir.Name)
+
+       // temporary map used to record evictions in a given block.
+       evicted := make(map[*ir.Name]bool)
+       for _, n := range cands {
+               rawcands[n] = struct{}{}
+       }
+       for k := 0; k < len(cs.f.Blocks); k++ {
+               genmapclear(pendingUses)
+               genmapclear(blockIndirectUE)
+               b := cs.f.Blocks[k]
+               for _, v := range b.Values {
+                       if n, e := affectedVar(v); n != nil {
+                               if _, ok := rawcands[n]; ok {
+                                       if e&ssa.SymAddr != 0 && v.Uses != 0 {
+                                               // we're taking the address of candidate var n
+                                               if _, ok := pendingUses[v.ID]; ok {
+                                                       // should never happen
+                                                       base.FatalfAt(v.Pos, "internal error: apparent multiple defs for SSA value %d", v.ID)
+                                               }
+                                               // Stash an entry in pendingUses recording
+                                               // that we took the address of "n" via this
+                                               // val.
+                                               pendingUses[v.ID] = nameCount{n: n, count: v.Uses}
+                                               if cs.trace > 2 {
+                                                       fmt.Fprintf(os.Stderr, "=-= SymAddr(%s) on %s\n",
+                                                               n.Sym().Name, v.LongString())
+                                               }
+                                       }
+                               }
+                       }
+                       for _, arg := range v.Args {
+                               if nc, ok := pendingUses[arg.ID]; ok {
+                                       // We found a use of some value that took the
+                                       // address of nc.n. Record this inst as a
+                                       // potential indirect use.
+                                       if cs.trace > 2 {
+                                               fmt.Fprintf(os.Stderr, "=-= add indirectUE(%s) count=%d on %s\n", nc.n.Sym().Name, nc.count, v.LongString())
+                                       }
+                                       blockIndirectUE[v.ID] = append(blockIndirectUE[v.ID], nc.n)
+                                       nc.count--
+                                       if nc.count == 0 {
+                                               // That was the last use of the value. Clean
+                                               // up the entry in pendingUses.
+                                               if cs.trace > 2 {
+                                                       fmt.Fprintf(os.Stderr, "=-= last use of v%d\n",
+                                                               arg.ID)
+                                               }
+                                               delete(pendingUses, arg.ID)
+                                       } else {
+                                               // Not the last use; record the decremented
+                                               // use count and move on.
+                                               pendingUses[arg.ID] = nc
+                                       }
+                               }
+                       }
+               }
+
+               // We've reached the end of this basic block: if we have any
+               // leftover entries in pendingUses, then evict the
+               // corresponding names from the candidate set. The idea here
+               // is that if we materialized the address of some local and
+               // that value is flowing out of the block off somewhere else,
+               // we're going to treat that local as truly address-taken and
+               // not have it be a merge candidate.
+               genmapclear(evicted)
+               if len(pendingUses) != 0 {
+                       for id, nc := range pendingUses {
+                               if cs.trace > 2 {
+                                       fmt.Fprintf(os.Stderr, "=-= evicting %q due to pendingUse %d count %d\n", nc.n.Sym().Name, id, nc.count)
+                               }
+                               delete(rawcands, nc.n)
+                               evicted[nc.n] = true
+                       }
+               }
+               // Copy entries from blockIndirectUE into final indirectUE. Skip
+               // anything that we evicted in the loop above.
+               for id, sl := range blockIndirectUE {
+                       for _, n := range sl {
+                               if evicted[n] {
+                                       continue
+                               }
+                               indirectUE[id] = append(indirectUE[id], n)
+                               if cs.trace > 2 {
+                                       fmt.Fprintf(os.Stderr, "=-= add final indUE v%d name %s\n", id, n.Sym().Name)
+                               }
+                       }
+               }
+       }
+       if len(rawcands) < 2 {
+               return nil, nil
+       }
+       cs.indirectUE = indirectUE
+       if cs.trace > 2 {
+               fmt.Fprintf(os.Stderr, "=-= iuetab:\n")
+               ids := make([]ssa.ID, 0, len(indirectUE))
+               for k := range indirectUE {
+                       ids = append(ids, k)
+               }
+               sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] })
+               for _, id := range ids {
+                       fmt.Fprintf(os.Stderr, "  v%d:", id)
+                       for _, n := range indirectUE[id] {
+                               fmt.Fprintf(os.Stderr, " %s", n.Sym().Name)
+                       }
+                       fmt.Fprintf(os.Stderr, "\n")
+               }
+       }
+
+       pruned := cands[:0]
+       for k := range rawcands {
+               pruned = append(pruned, k)
+       }
+       sort.Slice(pruned, func(i, j int) bool {
+               return nameLess(pruned[i], pruned[j])
+       })
+       var regions []candRegion
+       pruned, regions = genRegions(pruned)
+       if len(pruned) < 2 {
+               return nil, nil
+       }
+       return pruned, regions
+}
+
+// FIXME: bootstrap tool compiler is build with a "go 1.20" go.mod, so
+// we are not allowed to use map clear yet. Use this helper instead.
+func genmapclear[KT comparable, VT any](m map[KT]VT) {
+       for k := range m {
+               delete(m, k)
+       }
+}
+
+type nameCount struct {
+       n     *ir.Name
+       count int32
+}
+
+// nameLess compares ci with cj to see if ci should be less than cj
+// in a relative ordering of candidate variables. This is used to
+// sort vars by size, pointerness, and GC shape.
+func nameLess(ci, cj *ir.Name) bool {
+       ihp, jhp := 0, 0
+       var ilsym, jlsym *obj.LSym
+       if ci.Type().HasPointers() {
+               ihp = 1
+               ilsym, _, _ = reflectdata.GCSym(ci.Type())
+       }
+       if cj.Type().HasPointers() {
+               jhp = 1
+               jlsym, _, _ = reflectdata.GCSym(cj.Type())
+       }
+       if ihp != jhp {
+               return ihp < jhp
+       }
+       if ci.Type().Size() != cj.Type().Size() {
+               return ci.Type().Size() < cj.Type().Size()
+       }
+       if ihp != 0 && jhp != 0 && ilsym != jlsym {
+               // FIXME: find less clunky way to do this
+               return fmt.Sprintf("%v", ilsym) < fmt.Sprintf("%v", jlsym)
+       }
+       if ci.Sym().Name != cj.Sym().Name {
+               return ci.Sym().Name < cj.Sym().Name
+       }
+       return fmt.Sprintf("%v", ci.Pos()) < fmt.Sprintf("%v", cj.Pos())
 }
 
 // nextRegion starts at location idx and walks forward in the cands
@@ -384,9 +651,19 @@ func nextRegion(cands []*ir.Name, idx int) int {
        return n - 1
 }
 
+// cstate holds state information we'll need during the analysis
+// phase of stack slot merging but can be discarded when the analysis
+// is done.
 type cstate struct {
-       fn        *ir.Func
-       ibuilders []IntervalsBuilder
+       fn         *ir.Func
+       f          *ssa.Func
+       lv         *liveness
+       cands      []*ir.Name
+       nameToSlot map[*ir.Name]int32
+       regions    []candRegion
+       indirectUE map[ssa.ID][]*ir.Name
+       ivs        []Intervals
+       trace      int // debug trace level
 }
 
 // mergeVisitRegion tries to perform overlapping of variables with a
@@ -397,8 +674,8 @@ type cstate struct {
 // first element in the st->en range, then walk the rest of the
 // elements adding in vars whose lifetimes don't overlap with the
 // first element, then repeat the process until we run out of work to do.
-func (mls *MergeLocalsState) mergeVisitRegion(lv *liveness, ivs []Intervals, st, en int) {
-       if base.Debug.MergeLocalsTrace > 1 {
+func (cs *cstate) mergeVisitRegion(mls *MergeLocalsState, st, en int) {
+       if cs.trace > 1 {
                fmt.Fprintf(os.Stderr, "=-= mergeVisitRegion(st=%d, en=%d)\n", st, en)
        }
        n := en - st + 1
@@ -415,8 +692,9 @@ func (mls *MergeLocalsState) mergeVisitRegion(lv *liveness, ivs []Intervals, st,
        }
 
        navail := n
-       cands := lv.vars
-       if base.Debug.MergeLocalsTrace > 1 {
+       cands := cs.cands
+       ivs := cs.ivs
+       if cs.trace > 1 {
                fmt.Fprintf(os.Stderr, "  =-= navail = %d\n", navail)
        }
        for navail >= 2 {
@@ -424,7 +702,7 @@ func (mls *MergeLocalsState) mergeVisitRegion(lv *liveness, ivs []Intervals, st,
                used.Set(int32(leader - st))
                navail--
 
-               if base.Debug.MergeLocalsTrace > 1 {
+               if cs.trace > 1 {
                        fmt.Fprintf(os.Stderr, "  =-= begin leader %d used=%s\n", leader,
                                used.String())
                }
@@ -443,7 +721,7 @@ func (mls *MergeLocalsState) mergeVisitRegion(lv *liveness, ivs []Intervals, st,
                        if used.Get(int32(succ - st)) {
                                continue
                        }
-                       if base.Debug.MergeLocalsTrace > 1 {
+                       if cs.trace > 1 {
                                fmt.Fprintf(os.Stderr, "  =-= overlap of %d[%v] {%s} with %d[%v] {%s} is: %v\n", leader, cands[leader], lints.String(), succ, cands[succ], ivs[succ].String(), lints.Overlaps(ivs[succ]))
                        }
 
@@ -470,7 +748,7 @@ func (mls *MergeLocalsState) mergeVisitRegion(lv *liveness, ivs []Intervals, st,
                        for i := range elems {
                                used.Set(int32(elems[i] - st))
                        }
-                       if base.Debug.MergeLocalsTrace > 1 {
+                       if cs.trace > 1 {
                                fmt.Fprintf(os.Stderr, "=-= overlapping %+v:\n", sl)
                                for i := range sl {
                                        dumpCand(mls.vars[sl[i]], sl[i])
@@ -486,59 +764,32 @@ func (mls *MergeLocalsState) mergeVisitRegion(lv *liveness, ivs []Intervals, st,
 // performMerging carries out variable merging within each of the
 // candidate ranges in regions, returning a state object
 // that describes the variable overlaps.
-func performMerging(lv *liveness, cs *cstate, regions []candRegion) *MergeLocalsState {
-       cands := lv.vars
+func (cs *cstate) performMerging() *MergeLocalsState {
+       cands := cs.cands
+
        mls := &MergeLocalsState{
                partition: make(map[*ir.Name][]int),
        }
 
-       // Finish intervals construction.
-       ivs := make([]Intervals, len(cands))
-       for i := range cands {
-               var err error
-               ivs[i], err = cs.ibuilders[i].Finish()
-               if err != nil {
-                       ninstr := 0
-                       if base.Debug.MergeLocalsTrace != 0 {
-                               iidx := 0
-                               for k := 0; k < len(lv.f.Blocks); k++ {
-                                       b := lv.f.Blocks[k]
-                                       fmt.Fprintf(os.Stderr, "\n")
-                                       for _, v := range b.Values {
-                                               fmt.Fprintf(os.Stderr, " b%d %d: %s\n", k, iidx, v.LongString())
-                                               iidx++
-                                               ninstr++
-                                       }
-                               }
-                       }
-                       base.FatalfAt(cands[i].Pos(), "interval construct error for var %q in func %q (%d instrs): %v", cands[i].Sym().Name, ir.FuncName(cs.fn), ninstr, err)
-                       return nil
-               }
-       }
-
        // Dump state before attempting overlap.
-       if base.Debug.MergeLocalsTrace > 1 {
+       if cs.trace > 1 {
                fmt.Fprintf(os.Stderr, "=-= cands live before overlap:\n")
                for i := range cands {
                        c := cands[i]
                        fmt.Fprintf(os.Stderr, "%d: %v sz=%d ivs=%s\n",
-                               i, c.Sym().Name, c.Type().Size(), ivs[i].String())
+                               i, c.Sym().Name, c.Type().Size(), cs.ivs[i].String())
                }
-               fmt.Fprintf(os.Stderr, "=-= regions (%d): ", len(regions))
-               for _, cr := range regions {
+               fmt.Fprintf(os.Stderr, "=-= regions (%d): ", len(cs.regions))
+               for _, cr := range cs.regions {
                        fmt.Fprintf(os.Stderr, " [%d,%d]", cr.st, cr.en)
                }
                fmt.Fprintf(os.Stderr, "\n")
        }
 
-       if base.Debug.MergeLocalsTrace > 1 {
-               fmt.Fprintf(os.Stderr, "=-= len(regions) = %d\n", len(regions))
-       }
-
        // Apply a greedy merge/overlap strategy within each region
        // of compatible variables.
-       for _, cr := range regions {
-               mls.mergeVisitRegion(lv, ivs, cr.st, cr.en)
+       for _, cr := range cs.regions {
+               cs.mergeVisitRegion(mls, cr.st, cr.en)
        }
        if len(mls.vars) == 0 {
                return nil
@@ -550,23 +801,13 @@ func performMerging(lv *liveness, cs *cstate, regions []candRegion) *MergeLocals
 // of the function we're compiling, building up an Intervals object
 // for each candidate variable by looking for upwards exposed uses
 // and kills.
-func computeIntervals(lv *liveness, cs *cstate) {
+func (cs *cstate) computeIntervals() {
+       lv := cs.lv
+       ibuilders := make([]IntervalsBuilder, len(cs.cands))
        nvars := int32(len(lv.vars))
        liveout := bitvec.New(nvars)
 
-       if base.Debug.MergeLocalsDumpFunc != "" &&
-               strings.HasSuffix(fmt.Sprintf("%v", cs.fn), base.Debug.MergeLocalsDumpFunc) {
-               fmt.Fprintf(os.Stderr, "=-= mergelocalsdumpfunc %v:\n", cs.fn)
-               ii := 0
-               for k, b := range lv.f.Blocks {
-                       fmt.Fprintf(os.Stderr, "b%d:\n", k)
-                       for _, v := range b.Values {
-                               pos := base.Ctxt.PosTable.Pos(v.Pos)
-                               fmt.Fprintf(os.Stderr, "=-= %d L%d|C%d %s\n", ii, pos.RelLine(), pos.RelCol(), v.LongString())
-                               ii++
-                       }
-               }
-       }
+       cs.dumpFuncIfSelected()
 
        // Count instructions.
        ninstr := 0
@@ -581,7 +822,7 @@ func computeIntervals(lv *liveness, cs *cstate) {
                b := lv.f.Blocks[k]
                be := lv.blockEffects(b)
 
-               if base.Debug.MergeLocalsTrace > 2 {
+               if cs.trace > 2 {
                        fmt.Fprintf(os.Stderr, "=-= liveout from tail of b%d: ", k)
                        for j := range lv.vars {
                                if be.liveout.Get(int32(j)) {
@@ -602,17 +843,17 @@ func computeIntervals(lv *liveness, cs *cstate) {
                        blockLiveOut := be.liveout.Get(int32(j))
                        if isLive {
                                if !blockLiveOut {
-                                       if base.Debug.MergeLocalsTrace > 2 {
+                                       if cs.trace > 2 {
                                                fmt.Fprintf(os.Stderr, "=+= at instr %d block boundary kill of %v\n", iidx, lv.vars[j])
                                        }
-                                       cs.ibuilders[j].Kill(iidx)
+                                       ibuilders[j].Kill(iidx)
                                }
                        } else if blockLiveOut {
-                               if base.Debug.MergeLocalsTrace > 2 {
+                               if cs.trace > 2 {
                                        fmt.Fprintf(os.Stderr, "=+= at block-end instr %d %v becomes live\n",
                                                iidx, lv.vars[j])
                                }
-                               cs.ibuilders[j].Live(iidx)
+                               ibuilders[j].Live(iidx)
                        }
                }
 
@@ -624,7 +865,7 @@ func computeIntervals(lv *liveness, cs *cstate) {
                for i := len(b.Values) - 1; i >= 0; i-- {
                        v := b.Values[i]
 
-                       if base.Debug.MergeLocalsTrace > 2 {
+                       if cs.trace > 2 {
                                fmt.Fprintf(os.Stderr, "=-= b%d instr %d: %s\n", k, iidx, v.LongString())
                        }
 
@@ -639,24 +880,92 @@ func computeIntervals(lv *liveness, cs *cstate) {
                                panic("should never happen")
                        }
                        if iskilled && liveout.Get(pos) {
-                               cs.ibuilders[pos].Kill(iidx)
+                               ibuilders[pos].Kill(iidx)
                                liveout.Unset(pos)
-                               if base.Debug.MergeLocalsTrace > 2 {
+                               if cs.trace > 2 {
                                        fmt.Fprintf(os.Stderr, "=+= at instr %d kill of %v\n",
                                                iidx, lv.vars[pos])
                                }
                        } else if becomeslive && !liveout.Get(pos) {
-                               cs.ibuilders[pos].Live(iidx)
+                               ibuilders[pos].Live(iidx)
                                liveout.Set(pos)
-                               if base.Debug.MergeLocalsTrace > 2 {
+                               if cs.trace > 2 {
                                        fmt.Fprintf(os.Stderr, "=+= at instr %d upwards-exposed use of %v\n",
                                                iidx, lv.vars[pos])
                                }
                        }
+
+                       if cs.indirectUE != nil {
+                               // Now handle "indirect" upwards-exposed uses.
+                               ues := cs.indirectUE[v.ID]
+                               for _, n := range ues {
+                                       if pos, ok := lv.idx[n]; ok {
+                                               if !liveout.Get(pos) {
+                                                       ibuilders[pos].Live(iidx)
+                                                       liveout.Set(pos)
+                                                       if cs.trace > 2 {
+                                                               fmt.Fprintf(os.Stderr, "=+= at instr %d v%d indirect upwards-exposed use of %v\n", iidx, v.ID, lv.vars[pos])
+                                                       }
+                                               }
+                                       }
+                               }
+                       }
                        iidx--
                }
 
-               if b == lv.f.Entry {
+               // This check disabled for now due to the way scheduling works
+               // for ops that materialize values of local variables. For
+               // many architecture we have rewrite rules of this form:
+               //
+               // (LocalAddr <t> {sym} base mem) && t.Elem().HasPointers() => (MOVDaddr {sym} (SPanchored base mem))
+               // (LocalAddr <t> {sym} base _)  && !t.Elem().HasPointers() => (MOVDaddr {sym} base)
+               //
+               // which are designed to ensure that if you have a pointerful
+               // variable "abc" sequence
+               //
+               //    v30 = VarDef <mem> {abc} v21
+               //    v31 = LocalAddr <*SB> {abc} v2 v30
+               //    v32 = Zero <mem> {SB} [2056] v31 v30
+               //
+               // this will be lowered into
+               //
+               //    v30 = VarDef <mem> {sb} v21
+               //   v106 = SPanchored <uintptr> v2 v30
+               //    v31 = MOVDaddr <*SB> {sb} v106
+               //     v3 = DUFFZERO <mem> [2056] v31 v30
+               //
+               // Note the SPanchored: this ensures that the scheduler won't
+               // move the MOVDaddr earlier than the vardef. With a variable
+               // "xyz" that has no pointers, howver, if we start with
+               //
+               //    v66 = VarDef <mem> {t2} v65
+               //    v67 = LocalAddr <*T> {t2} v2 v66
+               //    v68 = Zero <mem> {T} [2056] v67 v66
+               //
+               // we might lower to
+               //
+               //    v66 = VarDef <mem> {t2} v65
+               //    v29 = MOVDaddr <*T> {t2} [2032] v2
+               //    v43 = LoweredZero <mem> v67 v29 v66
+               //    v68 = Zero [2056] v2 v43
+               //
+               // where that MOVDaddr can float around arbitrarily, meaning
+               // that we may see an upwards-exposed use to it before the
+               // VarDef.
+               //
+               // One avenue to restoring the check below would be to change
+               // the rewrite rules to something like
+               //
+               // (LocalAddr <t> {sym} base mem) && (t.Elem().HasPointers() || isMergeCandidate(t) => (MOVDaddr {sym} (SPanchored base mem))
+               //
+               // however that change will have to be carefully evaluated,
+               // since it would constrain the scheduler for _all_ LocalAddr
+               // ops for potential merge candidates, even if we don't
+               // actually succeed in any overlaps. This will be revisitged in
+               // a later CL if possible.
+               //
+               const checkLiveOnEntry = false
+               if checkLiveOnEntry && b == lv.f.Entry {
                        for j, v := range lv.vars {
                                if liveout.Get(int32(j)) {
                                        lv.f.Fatalf("%v %L recorded as live on entry",
@@ -668,20 +977,33 @@ func computeIntervals(lv *liveness, cs *cstate) {
        if iidx != -1 {
                panic("iidx underflow")
        }
+
+       // Finish intervals construction.
+       ivs := make([]Intervals, len(cs.cands))
+       for i := range cs.cands {
+               var err error
+               ivs[i], err = ibuilders[i].Finish()
+               if err != nil {
+                       cs.dumpFunc()
+                       base.FatalfAt(cs.cands[i].Pos(), "interval construct error for var %q in func %q (%d instrs): %v", cs.cands[i].Sym().Name, ir.FuncName(cs.fn), ninstr, err)
+               }
+       }
+       cs.ivs = ivs
+}
+
+func fmtFullPos(p src.XPos) string {
+       var sb strings.Builder
+       sep := ""
+       base.Ctxt.AllPos(p, func(pos src.Pos) {
+               fmt.Fprintf(&sb, sep)
+               sep = "|"
+               file := filepath.Base(pos.Filename())
+               fmt.Fprintf(&sb, "%s:%d:%d", file, pos.Line(), pos.Col())
+       })
+       return sb.String()
 }
 
 func dumpCand(c *ir.Name, i int) {
-       fmtFullPos := func(p src.XPos) string {
-               var sb strings.Builder
-               sep := ""
-               base.Ctxt.AllPos(p, func(pos src.Pos) {
-                       fmt.Fprintf(&sb, sep)
-                       sep = "|"
-                       file := filepath.Base(pos.Filename())
-                       fmt.Fprintf(&sb, "%s:%d:%d", file, pos.Line(), pos.Col())
-               })
-               return sb.String()
-       }
        fmt.Fprintf(os.Stderr, " %d: %s %q sz=%d hp=%v t=%v\n",
                i, fmtFullPos(c.Pos()), c.Sym().Name, c.Type().Size(),
                c.Type().HasPointers(), c.Type())
index ab1a7df93030becb99cfbfd8d4ee5d6f81a926ca..dd48d10bc5755ed5d40cc8161f5e937d680c4348 100644 (file)
@@ -481,13 +481,13 @@ func (lv *liveness) pointerMap(liveout bitvec.BitVec, vars []*ir.Name, args, loc
                        }
                        fallthrough // PPARAMOUT in registers acts memory-allocates like an AUTO
                case ir.PAUTO:
-                       typebits.Set(node.Type(), node.FrameOffset()+lv.stkptrsize, locals)
                        if checkForDuplicateSlots {
                                if prev, ok := slotsSeen[node.FrameOffset()]; ok {
                                        base.FatalfAt(node.Pos(), "two vars live at pointerMap generation: %q and %q", prev.Sym().Name, node.Sym().Name)
                                }
                                slotsSeen[node.FrameOffset()] = node
                        }
+                       typebits.Set(node.Type(), node.FrameOffset()+lv.stkptrsize, locals)
                }
        }
 }
index bef9049126f9d2b765566da480edef27037fd24d..f8d1ce827392bebad281a7c72c933c8065f1e73e 100644 (file)
@@ -256,11 +256,15 @@ func (s *ssafn) AllocFrame(f *ssa.Func) {
        }
 
        if base.Debug.MergeLocalsTrace > 1 {
-               fmt.Fprintf(os.Stderr, "=-= stack layout for %v:\n", fn)
+               prolog := false
                for i, v := range fn.Dcl {
                        if v.Op() != ir.ONAME || (v.Class != ir.PAUTO && !(v.Class == ir.PPARAMOUT && v.IsOutputParamInRegisters())) {
                                continue
                        }
+                       if !prolog {
+                               fmt.Fprintf(os.Stderr, "=-= stack layout for %v:\n", fn)
+                               prolog = true
+                       }
                        fmt.Fprintf(os.Stderr, " %d: %q frameoff %d used=%v\n", i, v.Sym().Name, v.FrameOffset(), v.Used())
                }
        }
index f070197c800d5392ecda5620da4fee10491b260d..2c554cf05e4864ed5b7d5a4036c1ddd2b055b5e3 100644 (file)
@@ -12,19 +12,19 @@ import (
        "cmd/internal/src"
        "internal/testenv"
        "path/filepath"
-       "slices"
        "sort"
        "strings"
        "testing"
 )
 
+func mkiv(name string) *ir.Name {
+       i32 := types.Types[types.TINT32]
+       s := typecheck.Lookup(name)
+       v := ir.NewNameAt(src.NoXPos, s, i32)
+       return v
+}
+
 func TestMergeLocalState(t *testing.T) {
-       mkiv := func(name string) *ir.Name {
-               i32 := types.Types[types.TINT32]
-               s := typecheck.Lookup(name)
-               v := ir.NewNameAt(src.NoXPos, s, i32)
-               return v
-       }
        v1 := mkiv("v1")
        v2 := mkiv("v2")
        v3 := mkiv("v3")
@@ -126,22 +126,25 @@ func TestMergeLocalsIntegration(t *testing.T) {
        // get overlapped, then another clump of 2 that share the same
        // frame offset.
        //
-       // The expected output blob we're interested in looks like this:
+       // The expected output blob we're interested might look like
+       // this (for amd64):
        //
        // =-= stack layout for ABC:
-       //  2: "p1" frameoff -8200 used=true
-       //  3: "xp3" frameoff -8200 used=true
-       //  4: "xp4" frameoff -8200 used=true
-       //  5: "p2" frameoff -16400 used=true
-       //  6: "s" frameoff -24592 used=true
-       //  7: "v1" frameoff -32792 used=true
-       //  8: "v3" frameoff -32792 used=true
-       //  9: "v2" frameoff -40992 used=true
+       // 2: "p1" frameoff -8200 used=true
+       // 3: "xp3" frameoff -8200 used=true
+       // 4: "xp4" frameoff -8200 used=true
+       // 5: "p2" frameoff -16400 used=true
+       // 6: "r" frameoff -16408 used=true
+       // 7: "s" frameoff -24600 used=true
+       // 8: "v2" frameoff -32800 used=true
+       // 9: "v3" frameoff -32800 used=true
        //
        tmpdir := t.TempDir()
        src := filepath.Join("testdata", "mergelocals", "integration.go")
        obj := filepath.Join(tmpdir, "p.a")
-       out, err := testenv.Command(t, testenv.GoToolPath(t), "tool", "compile", "-p=p", "-c", "1", "-o", obj, "-d=mergelocalstrace=2,mergelocals=1", src).CombinedOutput()
+       out, err := testenv.Command(t, testenv.GoToolPath(t), "tool", "compile",
+               "-p=p", "-c", "1", "-o", obj, "-d=mergelocalstrace=2,mergelocals=1",
+               src).CombinedOutput()
        if err != nil {
                t.Fatalf("failed to compile: %v\n%s", err, out)
        }
@@ -157,8 +160,11 @@ func TestMergeLocalsIntegration(t *testing.T) {
                        continue
                }
                fields := strings.Fields(line)
-               if len(fields) != 5 {
-                       t.Fatalf("bad trace output line: %s", line)
+               wantFields := 5
+               if len(fields) != wantFields {
+                       t.Logf(string(out))
+                       t.Fatalf("bad trace output line, wanted %d fields got %d: %s",
+                               wantFields, len(fields), line)
                }
                vname := fields[1]
                frameoff := fields[3]
@@ -168,17 +174,22 @@ func TestMergeLocalsIntegration(t *testing.T) {
        wantvnum := 8
        gotvnum := len(vars)
        if wantvnum != gotvnum {
+               t.Logf(string(out))
                t.Fatalf("expected trace output on %d vars got %d\n", wantvnum, gotvnum)
        }
 
-       // We expect one clump of 3, another clump of 2, and the rest singletons.
-       expected := []int{1, 1, 1, 2, 3}
+       // Expect at least one clump of at least 3.
+       n3 := 0
        got := []int{}
        for _, v := range varsAtFrameOffset {
+               if v > 2 {
+                       n3++
+               }
                got = append(got, v)
        }
        sort.Ints(got)
-       if !slices.Equal(got, expected) {
-               t.Fatalf("expected variable clumps %+v not equal to what we got: %+v", expected, got)
+       if n3 == 0 {
+               t.Logf(string(out))
+               t.Fatalf("expected at least one clump of 3, got: %+v", got)
        }
 }
index d640c6fce83aa036e98d7729eb269e97646897f5..21779f0c9e5f59132ef6f2d8d2827fae401d5dfb 100644 (file)
@@ -32,37 +32,42 @@ type Single struct {
        x  [1023]int
 }
 
+var G int
+
+//go:noinline
+func clobber() {
+       G++
+}
+
 func ABC(i, j int) int {
        r := 0
 
-       // here v1 interferes with v2 but could be overlapped with v3.
-       // we can also overlap v1 with v3.
-       var v1 Vanilla
+       // here v2 and v3 can be overlapped.
+       clobber()
        if i < 101 {
                var v2 Vanilla
-               v1.x[i] = j
-               r += v1.x[j]
                v2.x[i] = j
                r += v2.x[j]
        }
-
-       {
+       if j != 303 {
                var v3 Vanilla2
                v3.x[i] = j
                r += v3.x[j]
        }
+       clobber()
 
+       // not an overlap candidate (only one var of this size).
        var s Single
        s.x[i] = j
        r += s.x[j]
 
-       // Here p1 and p2 interfere, but p1 could be overlapped with xp3.
+       // Here p1 and p2 interfere, but p1 could be overlapped with xp3 + xp4.
        var p1, p2 Pointery
        p1.x[i] = j
        r += p1.x[j]
        p2.x[i] = j
        r += p2.x[j]
-       {
+       if j != 505 {
                var xp3 Pointery2
                xp3.x[i] = j
                r += xp3.x[j]
@@ -79,5 +84,5 @@ func ABC(i, j int) int {
                r += xp4.x[j]
        }
 
-       return r
+       return r + G
 }