From 69aed4712d73c9c1b70be3e2e222eb55391e2fb0 Mon Sep 17 00:00:00 2001 From: Keith Randall Date: Sat, 25 Jun 2022 11:11:07 -0700 Subject: [PATCH] cmd/compile: use better splitting condition for string binary search Currently we use a full cmpstring to do the comparison for each split in the binary search for a string switch. Instead, split by comparing a single byte of the input string with a constant. That will give us a much faster split (although it might be not quite as good a split). Fixes #53333 R=go1.20 Change-Id: I28c7209342314f367071e4aa1f2beb6ec9ff7123 Reviewed-on: https://go-review.googlesource.com/c/go/+/414894 TryBot-Result: Gopher Robot Run-TryBot: Keith Randall Reviewed-by: David Chase Reviewed-by: Heschi Kreinick --- src/cmd/compile/internal/walk/switch.go | 92 +++++++++++++++++++++++-- test/codegen/switch.go | 27 ++++++++ 2 files changed, 114 insertions(+), 5 deletions(-) diff --git a/src/cmd/compile/internal/walk/switch.go b/src/cmd/compile/internal/walk/switch.go index 82da1562c0..d38ba500f2 100644 --- a/src/cmd/compile/internal/walk/switch.go +++ b/src/cmd/compile/internal/walk/switch.go @@ -190,10 +190,6 @@ func (s *exprSwitch) flush() { } runs = append(runs, cc[start:]) - if len(runs) == 1 { - s.search(runs[0], &s.done) - return - } // We have strings of more than one length. Generate an // outer switch which switches on the length of the string // and an inner switch in each case which resolves all the @@ -232,7 +228,7 @@ func (s *exprSwitch) flush() { // Search within this run of same-length strings. pos := run[0].pos s.done.Append(ir.NewLabelStmt(pos, label)) - s.search(run, &s.done) + stringSearch(s.exprname, run, &s.done) s.done.Append(ir.NewBranchStmt(pos, ir.OGOTO, endLabel)) // Add length case to outer switch. @@ -678,3 +674,89 @@ func binarySearch(n int, out *ir.Nodes, less func(i int) ir.Node, leaf func(i in do(0, n, out) } + +func stringSearch(expr ir.Node, cc []exprClause, out *ir.Nodes) { + if len(cc) < 4 { + // Short list, just do brute force equality checks. + for _, c := range cc { + nif := ir.NewIfStmt(base.Pos.WithNotStmt(), typecheck.DefaultLit(typecheck.Expr(c.test(expr)), nil), []ir.Node{c.jmp}, nil) + out.Append(nif) + out = &nif.Else + } + return + } + + // The strategy here is to find a simple test to divide the set of possible strings + // that might match expr approximately in half. + // The test we're going to use is to do an ordered comparison of a single byte + // of expr to a constant. We will pick the index of that byte and the value we're + // comparing against to make the split as even as possible. + // if expr[3] <= 'd' { ... search strings with expr[3] at 'd' or lower ... } + // else { ... search strings with expr[3] at 'e' or higher ... } + // + // To add complication, we will do the ordered comparison in the signed domain. + // The reason for this is to prevent CSE from merging the load used for the + // ordered comparison with the load used for the later equality check. + // if expr[3] <= 'd' { ... if expr[0] == 'f' && expr[1] == 'o' && expr[2] == 'o' && expr[3] == 'd' { ... } } + // If we did both expr[3] loads in the unsigned domain, they would be CSEd, and that + // would in turn defeat the combining of expr[0]...expr[3] into a single 4-byte load. + // See issue 48222. + // By using signed loads for the ordered comparison and unsigned loads for the + // equality comparison, they don't get CSEd and the equality comparisons will be + // done using wider loads. + + n := len(ir.StringVal(cc[0].lo)) // Length of the constant strings. + bestScore := int64(0) // measure of how good the split is. + bestIdx := 0 // split using expr[bestIdx] + bestByte := int8(0) // compare expr[bestIdx] against bestByte + for idx := 0; idx < n; idx++ { + for b := int8(-128); b < 127; b++ { + le := 0 + for _, c := range cc { + s := ir.StringVal(c.lo) + if int8(s[idx]) <= b { + le++ + } + } + score := int64(le) * int64(len(cc)-le) + if score > bestScore { + bestScore = score + bestIdx = idx + bestByte = b + } + } + } + + // The split must be at least 1:n-1 because we have at least 2 distinct strings; they + // have to be different somewhere. + // TODO: what if the best split is still pretty bad? + if bestScore == 0 { + base.Fatalf("unable to split string set") + } + + // Convert expr to a []int8 + slice := ir.NewConvExpr(base.Pos, ir.OSTR2BYTESTMP, types.NewSlice(types.Types[types.TINT8]), expr) + slice.SetTypecheck(1) // legacy typechecker doesn't handle this op + // Load the byte we're splitting on. + load := ir.NewIndexExpr(base.Pos, slice, ir.NewInt(int64(bestIdx))) + // Compare with the value we're splitting on. + cmp := ir.Node(ir.NewBinaryExpr(base.Pos, ir.OLE, load, ir.NewInt(int64(bestByte)))) + cmp = typecheck.DefaultLit(typecheck.Expr(cmp), nil) + nif := ir.NewIfStmt(base.Pos, cmp, nil, nil) + + var le []exprClause + var gt []exprClause + for _, c := range cc { + s := ir.StringVal(c.lo) + if int8(s[bestIdx]) <= bestByte { + le = append(le, c) + } else { + gt = append(gt, c) + } + } + stringSearch(expr, le, &nif.Body) + stringSearch(expr, gt, &nif.Else) + out.Append(nif) + + // TODO: if expr[bestIdx] has enough different possible values, use a jump table. +} diff --git a/test/codegen/switch.go b/test/codegen/switch.go index af3762869a..c3c24e2e11 100644 --- a/test/codegen/switch.go +++ b/test/codegen/switch.go @@ -72,3 +72,30 @@ func length(x string) int { return len(x) } } + +// Use single-byte ordered comparisons for binary searching strings. +// See issue 53333. +func mimetype(ext string) string { + // amd64: `CMPB\s1\(.*\), \$104$`,-`cmpstring` + // arm64: `MOVB\s1\(R.*\), R.*$`, `CMPW\s\$104, R.*$`, -`cmpstring` + switch ext { + // amd64: `CMPL\s\(.*\), \$1836345390$` + // arm64: `CMPW\s\$1836345390, R.*$` + case ".htm": + return "A" + // amd64: `CMPL\s\(.*\), \$1953457454$` + // arm64: `CMPW\s\$1953457454, R.*$` + case ".eot": + return "B" + // amd64: `CMPL\s\(.*\), \$1735815982$` + // arm64: `CMPW\s\$1735815982, R.*$` + case ".svg": + return "C" + // amd64: `CMPL\s\(.*\), \$1718907950$` + // arm64: `CMPW\s\$1718907950, R.*$` + case ".ttf": + return "D" + default: + return "" + } +} -- 2.50.0