]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: modify switches of strings to use jump table for lengths
authorKeith Randall <khr@golang.org>
Sun, 6 Mar 2022 20:07:54 +0000 (12:07 -0800)
committerKeith Randall <khr@google.com>
Thu, 14 Apr 2022 21:16:11 +0000 (21:16 +0000)
Reorganize the way we rewrite expression switches on strings, so that
jump tables are naturally used for the outer switch on the string length.

The changes to the prove pass in this CL are required so as to not repeat
the test for string length in each case.

name                         old time/op  new time/op  delta
SwitchStringPredictable    2.28ns ± 9%  2.08ns ± 5%   -9.04%  (p=0.000 n=10+10)
SwitchStringUnpredictable  10.5ns ± 1%   9.5ns ± 1%   -9.08%  (p=0.000 n=9+10)

Update #5496
Update #34381

Change-Id: Ie6846b1dd27f3e472f7c30dfcc598c68d440b997
Reviewed-on: https://go-review.googlesource.com/c/go/+/395714
Run-TryBot: Keith Randall <khr@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: Keith Randall <khr@google.com>
src/cmd/compile/internal/ssa/prove.go
src/cmd/compile/internal/test/switch_test.go
src/cmd/compile/internal/walk/switch.go

index 8f86e161127810e147530d727a16fbb142679e5b..26176af07c836b8092adc0e7d4ef9c139275cf49 100644 (file)
@@ -16,6 +16,10 @@ const (
        unknown branch = iota
        positive
        negative
+       // The outedges from a jump table are jumpTable0,
+       // jumpTable0+1, jumpTable0+2, etc. There could be an
+       // arbitrary number so we can't list them all here.
+       jumpTable0
 )
 
 // relation represents the set of possible relations between
@@ -940,20 +944,31 @@ func prove(f *Func) {
 // getBranch returns the range restrictions added by p
 // when reaching b. p is the immediate dominator of b.
 func getBranch(sdom SparseTree, p *Block, b *Block) branch {
-       if p == nil || p.Kind != BlockIf {
+       if p == nil {
                return unknown
        }
-       // If p and p.Succs[0] are dominators it means that every path
-       // from entry to b passes through p and p.Succs[0]. We care that
-       // no path from entry to b passes through p.Succs[1]. If p.Succs[0]
-       // has one predecessor then (apart from the degenerate case),
-       // there is no path from entry that can reach b through p.Succs[1].
-       // TODO: how about p->yes->b->yes, i.e. a loop in yes.
-       if sdom.IsAncestorEq(p.Succs[0].b, b) && len(p.Succs[0].b.Preds) == 1 {
-               return positive
-       }
-       if sdom.IsAncestorEq(p.Succs[1].b, b) && len(p.Succs[1].b.Preds) == 1 {
-               return negative
+       switch p.Kind {
+       case BlockIf:
+               // If p and p.Succs[0] are dominators it means that every path
+               // from entry to b passes through p and p.Succs[0]. We care that
+               // no path from entry to b passes through p.Succs[1]. If p.Succs[0]
+               // has one predecessor then (apart from the degenerate case),
+               // there is no path from entry that can reach b through p.Succs[1].
+               // TODO: how about p->yes->b->yes, i.e. a loop in yes.
+               if sdom.IsAncestorEq(p.Succs[0].b, b) && len(p.Succs[0].b.Preds) == 1 {
+                       return positive
+               }
+               if sdom.IsAncestorEq(p.Succs[1].b, b) && len(p.Succs[1].b.Preds) == 1 {
+                       return negative
+               }
+       case BlockJumpTable:
+               // TODO: this loop can lead to quadratic behavior, as
+               // getBranch can be called len(p.Succs) times.
+               for i, e := range p.Succs {
+                       if sdom.IsAncestorEq(e.b, b) && len(e.b.Preds) == 1 {
+                               return jumpTable0 + branch(i)
+                       }
+               }
        }
        return unknown
 }
@@ -984,11 +999,36 @@ func addIndVarRestrictions(ft *factsTable, b *Block, iv indVar) {
 // branching from Block b in direction br.
 func addBranchRestrictions(ft *factsTable, b *Block, br branch) {
        c := b.Controls[0]
-       switch br {
-       case negative:
+       switch {
+       case br == negative:
                addRestrictions(b, ft, boolean, nil, c, eq)
-       case positive:
+       case br == positive:
                addRestrictions(b, ft, boolean, nil, c, lt|gt)
+       case br >= jumpTable0:
+               idx := br - jumpTable0
+               val := int64(idx)
+               if v, off := isConstDelta(c); v != nil {
+                       // Establish the bound on the underlying value we're switching on,
+                       // not on the offset-ed value used as the jump table index.
+                       c = v
+                       val -= off
+               }
+               old, ok := ft.limits[c.ID]
+               if !ok {
+                       old = noLimit
+               }
+               ft.limitStack = append(ft.limitStack, limitFact{c.ID, old})
+               if val < old.min || val > old.max || uint64(val) < old.umin || uint64(val) > old.umax {
+                       ft.unsat = true
+                       if b.Func.pass.debug > 2 {
+                               b.Func.Warnl(b.Pos, "block=%s outedge=%d %s=%d unsat", b, idx, c, val)
+                       }
+               } else {
+                       ft.limits[c.ID] = limit{val, val, uint64(val), uint64(val)}
+                       if b.Func.pass.debug > 2 {
+                               b.Func.Warnl(b.Pos, "block=%s outedge=%d %s=%d", b, idx, c, val)
+                       }
+               }
        default:
                panic("unknown branch")
        }
@@ -1343,10 +1383,14 @@ func removeBranch(b *Block, branch branch) {
                // attempt to preserve statement marker.
                b.Pos = b.Pos.WithIsStmt()
        }
-       b.Kind = BlockFirst
-       b.ResetControls()
-       if branch == positive {
-               b.swapSuccessors()
+       if branch == positive || branch == negative {
+               b.Kind = BlockFirst
+               b.ResetControls()
+               if branch == positive {
+                       b.swapSuccessors()
+               }
+       } else {
+               // TODO: figure out how to remove an entry from a jump table
        }
 }
 
index 6f7bfcf3d8c5fa728dbfd6096bbff137cf3c109f..30dee6257e3804164ff9daa133a423bf2f8e885a 100644 (file)
@@ -77,6 +77,49 @@ func benchmarkSwitch32(b *testing.B, predictable bool) {
        sink = n
 }
 
+func BenchmarkSwitchStringPredictable(b *testing.B) {
+       benchmarkSwitchString(b, true)
+}
+func BenchmarkSwitchStringUnpredictable(b *testing.B) {
+       benchmarkSwitchString(b, false)
+}
+func benchmarkSwitchString(b *testing.B, predictable bool) {
+       a := []string{
+               "foo",
+               "foo1",
+               "foo22",
+               "foo333",
+               "foo4444",
+               "foo55555",
+               "foo666666",
+               "foo7777777",
+       }
+       n := 0
+       rng := newRNG()
+       for i := 0; i < b.N; i++ {
+               rng = rng.next(predictable)
+               switch a[rng.value()&7] {
+               case "foo":
+                       n += 1
+               case "foo1":
+                       n += 2
+               case "foo22":
+                       n += 3
+               case "foo333":
+                       n += 4
+               case "foo4444":
+                       n += 5
+               case "foo55555":
+                       n += 6
+               case "foo666666":
+                       n += 7
+               case "foo7777777":
+                       n += 8
+               }
+       }
+       sink = n
+}
+
 // A simple random number generator used to make switches conditionally predictable.
 type rng uint64
 
index a4003ecea42adcdf3ebf24592bad3ae825d6bc70..6a2dbe1753e30a8dd8bb6d879da3a013554670c8 100644 (file)
@@ -67,6 +67,7 @@ func walkSwitchExpr(sw *ir.SwitchStmt) {
        base.Pos = lno
 
        s := exprSwitch{
+               pos:      lno,
                exprname: cond,
        }
 
@@ -113,6 +114,7 @@ func walkSwitchExpr(sw *ir.SwitchStmt) {
 
 // An exprSwitch walks an expression switch.
 type exprSwitch struct {
+       pos      src.XPos
        exprname ir.Node // value being switched on
 
        done    ir.Nodes
@@ -183,17 +185,59 @@ func (s *exprSwitch) flush() {
                }
                runs = append(runs, cc[start:])
 
-               // Perform two-level binary search.
-               binarySearch(len(runs), &s.done,
-                       func(i int) ir.Node {
-                               return ir.NewBinaryExpr(base.Pos, ir.OLE, ir.NewUnaryExpr(base.Pos, ir.OLEN, s.exprname), ir.NewInt(runLen(runs[i-1])))
-                       },
-                       func(i int, nif *ir.IfStmt) {
-                               run := runs[i]
-                               nif.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, ir.NewUnaryExpr(base.Pos, ir.OLEN, s.exprname), ir.NewInt(runLen(run)))
-                               s.search(run, &nif.Body)
-                       },
-               )
+               if len(runs) == 1 {
+                       s.search(runs[0], &s.done)
+                       return
+               }
+               // We have strings of more than one length. Generate an
+               // outer switch which switches on the length of the string
+               // and an inner switch in each case which resolves all the
+               // strings of the same length. The code looks something like this:
+
+               // goto outerLabel
+               // len5:
+               //   ... search among length 5 strings ...
+               //   goto endLabel
+               // len8:
+               //   ... search among length 8 strings ...
+               //   goto endLabel
+               // ... other lengths ...
+               // outerLabel:
+               // switch len(s) {
+               //   case 5: goto len5
+               //   case 8: goto len8
+               //   ... other lengths ...
+               // }
+               // endLabel:
+
+               outerLabel := typecheck.AutoLabel(".s")
+               endLabel := typecheck.AutoLabel(".s")
+
+               // Jump around all the individual switches for each length.
+               s.done.Append(ir.NewBranchStmt(s.pos, ir.OGOTO, outerLabel))
+
+               var outer exprSwitch
+               outer.exprname = ir.NewUnaryExpr(s.pos, ir.OLEN, s.exprname)
+               outer.exprname.SetType(types.Types[types.TINT])
+
+               for _, run := range runs {
+                       // Target label to jump to when we match this length.
+                       label := typecheck.AutoLabel(".s")
+
+                       // Search within this run of same-length strings.
+                       pos := run[0].pos
+                       s.done.Append(ir.NewLabelStmt(pos, label))
+                       s.search(run, &s.done)
+                       s.done.Append(ir.NewBranchStmt(pos, ir.OGOTO, endLabel))
+
+                       // Add length case to outer switch.
+                       cas := ir.NewBasicLit(pos, constant.MakeInt64(runLen(run)))
+                       jmp := ir.NewBranchStmt(pos, ir.OGOTO, label)
+                       outer.Add(pos, cas, jmp)
+               }
+               s.done.Append(ir.NewLabelStmt(s.pos, outerLabel))
+               outer.Emit(&s.done)
+               s.done.Append(ir.NewLabelStmt(s.pos, endLabel))
                return
        }
 
@@ -278,7 +322,6 @@ func (s *exprSwitch) tryJumpTable(cc []exprClause, out *ir.Nodes) bool {
                }
        }
        out.Append(jt)
-       // TODO: handle the size portion of string switches using a jump table.
        return true
 }
 
@@ -587,6 +630,7 @@ func (s *typeSwitch) flush() {
        }
        cc = merged
 
+       // TODO: figure out if we could use a jump table using some low bits of the type hashes.
        binarySearch(len(cc), &s.done,
                func(i int) ir.Node {
                        return ir.NewBinaryExpr(base.Pos, ir.OLE, s.hashname, ir.NewInt(int64(cc[i-1].hash)))