From: Josh Bleecher Snyder Date: Fri, 17 Jun 2016 23:27:23 +0000 (-0700) Subject: cmd/compile: recognize integer ranges in switch statements X-Git-Tag: go1.8beta1~1563 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=8c85e23087d90e831a70ccd199cac49a38d91027;p=gostls13.git cmd/compile: recognize integer ranges in switch statements Consider a switch statement like: switch x { case 1: // ... case 2, 3, 4, 5, 6: // ... case 5: // ... } Prior to this CL, the generated code treated 2, 3, 4, 5, and 6 independently in a binary search. With this CL, the generated code checks whether 2 <= x && x <= 6. walkinrange then optimizes that range check into a single unsigned comparison. Experiments suggest that the best min range size is 2, using binary size as a proxy for optimization. Binary sizes before/after this CL: cmd/compile: 14209728 / 14165360 cmd/go: 9543100 / 9539004 Change-Id: If2f7fb97ca80468fa70351ef540866200c4c996c Reviewed-on: https://go-review.googlesource.com/26770 Run-TryBot: Josh Bleecher Snyder TryBot-Result: Gobot Gobot Reviewed-by: Matthew Dempsky --- diff --git a/src/cmd/compile/internal/gc/fmt.go b/src/cmd/compile/internal/gc/fmt.go index 96cc393465..dd2d12634b 100644 --- a/src/cmd/compile/internal/gc/fmt.go +++ b/src/cmd/compile/internal/gc/fmt.go @@ -880,17 +880,27 @@ func (p *printer) stmtfmt(n *Node) *printer { case OXCASE: if n.List.Len() != 0 { - p.f("case %v: %v", hconv(n.List, FmtComma), n.Nbody) + p.f("case %v", hconv(n.List, FmtComma)) } else { - p.f("default: %v", n.Nbody) + p.s("default") } + p.f(": %v", n.Nbody) case OCASE: - if n.Left != nil { - p.f("case %v: %v", n.Left, n.Nbody) - } else { - p.f("default: %v", n.Nbody) + switch { + case n.Left != nil: + // single element + p.f("case %v", n.Left) + case n.List.Len() > 0: + // range + if n.List.Len() != 2 { + Fatalf("bad OCASE list length", n.List) + } + p.f("case %v..%v", n.List.First(), n.List.Second()) + default: + p.s("default") } + p.f(": %v", n.Nbody) case OBREAK, OCONTINUE, diff --git a/src/cmd/compile/internal/gc/swt.go b/src/cmd/compile/internal/gc/swt.go index 42b46c3917..c838c2fdcd 100644 --- a/src/cmd/compile/internal/gc/swt.go +++ b/src/cmd/compile/internal/gc/swt.go @@ -16,7 +16,10 @@ const ( switchKindType // switch a.(type) {...} ) -const binarySearchMin = 4 // minimum number of cases for binary search +const ( + binarySearchMin = 4 // minimum number of cases for binary search + integerRangeMin = 2 // minimum size of integer ranges +) // An exprSwitch walks an expression switch. type exprSwitch struct { @@ -291,7 +294,18 @@ func (s *exprSwitch) walkCases(cc []caseClause) *Node { lno := setlineno(n) a := Nod(OIF, nil, nil) - if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE { + if rng := n.List.Slice(); rng != nil { + // Integer range. + // exprname is a temp or a constant, + // so it is safe to evaluate twice. + // In most cases, this conjunction will be + // rewritten by walkinrange into a single comparison. + low := Nod(OGE, s.exprname, rng[0]) + high := Nod(OLE, s.exprname, rng[1]) + a.Left = Nod(OANDAND, low, high) + a.Left = typecheck(a.Left, Erv) + a.Left = walkexpr(a.Left, nil) // give walk the opportunity to optimize the range check + } else if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE { a.Left = Nod(OEQ, s.exprname, n.Left) // if name == val a.Left = typecheck(a.Left, Erv) } else if s.kind == switchKindTrue { @@ -312,7 +326,13 @@ func (s *exprSwitch) walkCases(cc []caseClause) *Node { // find the middle and recur half := len(cc) / 2 a := Nod(OIF, nil, nil) - mid := cc[half-1].node.Left + n := cc[half-1].node + var mid *Node + if rng := n.List.Slice(); rng != nil { + mid = rng[1] // high end of range + } else { + mid = n.Left + } le := Nod(OLE, s.exprname, mid) if Isconst(mid, CTSTR) { // Search by length and then by value; see caseClauseByConstVal. @@ -368,9 +388,51 @@ func casebody(sw *Node, typeswvar *Node) { n.List.Set(nil) cas = append(cas, n) default: - // expand multi-valued cases - for _, n1 := range n.List.Slice() { - cas = append(cas, Nod(OCASE, n1, jmp)) + // Expand multi-valued cases and detect ranges of integer cases. + if typeswvar != nil || sw.Left.Type.IsInterface() || !n.List.First().Type.IsInteger() || n.List.Len() < integerRangeMin { + // Can't use integer ranges. Expand each case into a separate node. + for _, n1 := range n.List.Slice() { + cas = append(cas, Nod(OCASE, n1, jmp)) + } + break + } + // Find integer ranges within runs of constants. + s := n.List.Slice() + j := 0 + for j < len(s) { + // Find a run of constants. + var run int + for run = j; run < len(s) && Isconst(s[run], CTINT); run++ { + } + if run-j >= integerRangeMin { + // Search for integer ranges in s[j:run]. + // Typechecking is done, so all values are already in an appropriate range. + search := s[j:run] + sort.Sort(constIntNodesByVal(search)) + for beg, end := 0, 1; end <= len(search); end++ { + if end < len(search) && search[end].Int64() == search[end-1].Int64()+1 { + continue + } + if end-beg >= integerRangeMin { + // Record range in List. + c := Nod(OCASE, nil, jmp) + c.List.Set2(search[beg], search[end-1]) + cas = append(cas, c) + } else { + // Not large enough for range; record separately. + for _, n := range search[beg:end] { + cas = append(cas, Nod(OCASE, n, jmp)) + } + } + beg = end + } + j = run + } + // Advance to next constant, adding individual non-constant + // or as-yet-unhandled constant cases as we go. + for ; j < len(s) && (j < run || !Isconst(s[j], CTINT)); j++ { + cas = append(cas, Nod(OCASE, s[j], jmp)) + } } } @@ -418,7 +480,7 @@ func casebody(sw *Node, typeswvar *Node) { func (s *exprSwitch) genCaseClauses(clauses []*Node) caseClauses { var cc caseClauses for _, n := range clauses { - if n.Left == nil { + if n.Left == nil && n.List.Len() == 0 { // default case if cc.defjmp != nil { Fatalf("duplicate default case not detected during typechecking") @@ -427,6 +489,9 @@ func (s *exprSwitch) genCaseClauses(clauses []*Node) caseClauses { continue } c := caseClause{node: n, ordinal: len(cc.list)} + if n.List.Len() > 0 { + c.isconst = true + } switch consttype(n.Left) { case CTFLT, CTINT, CTRUNE, CTSTR: c.isconst = true @@ -533,14 +598,34 @@ func (s *exprSwitch) checkDupCases(cc []caseClause) { if ct := consttype(c.node.Left); ct < 0 || ct == CTBOOL { continue } - val := c.node.Left.Val().Interface() - prev, dup := seen[val] - if !dup { - seen[val] = c.node + if c.node.Left != nil { + // Single constant. + val := c.node.Left.Val().Interface() + prev, dup := seen[val] + if !dup { + seen[val] = c.node + continue + } + setlineno(c.node) + Yyerror("duplicate case %v in switch\n\tprevious case at %v", prev.Left, prev.Line()) + continue + } + if c.node.List.Len() == 2 { + // Range of integers. + low := c.node.List.Index(0).Int64() + high := c.node.List.Index(1).Int64() + for i := low; i <= high; i++ { + prev, dup := seen[i] + if !dup { + seen[i] = c.node + continue + } + setlineno(c.node) + Yyerror("duplicate case %v in switch\n\tprevious case at %v", prev.Left, prev.Line()) + } continue } - setlineno(c.node) - Yyerror("duplicate case %v in switch\n\tprevious case at %v", prev.Left, prev.Line()) + Fatalf("bad caseClause node in checkDupCases: %v", c.node) } return } @@ -784,8 +869,24 @@ type caseClauseByConstVal []caseClause func (x caseClauseByConstVal) Len() int { return len(x) } func (x caseClauseByConstVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] } func (x caseClauseByConstVal) Less(i, j int) bool { - v1 := x[i].node.Left.Val().U - v2 := x[j].node.Left.Val().U + // n1 and n2 might be individual constants or integer ranges. + // We have checked for duplicates already, + // so ranges can be safely represented by any value in the range. + n1 := x[i].node + var v1 interface{} + if s := n1.List.Slice(); s != nil { + v1 = s[0].Val().U + } else { + v1 = n1.Left.Val().U + } + + n2 := x[j].node + var v2 interface{} + if s := n2.List.Slice(); s != nil { + v2 = s[0].Val().U + } else { + v2 = n2.Left.Val().U + } switch v1 := v1.(type) { case *Mpflt: @@ -821,3 +922,11 @@ func (x caseClauseByType) Less(i, j int) bool { } return c1.ordinal < c2.ordinal } + +type constIntNodesByVal []*Node + +func (x constIntNodesByVal) Len() int { return len(x) } +func (x constIntNodesByVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] } +func (x constIntNodesByVal) Less(i, j int) bool { + return x[i].Val().U.(*Mpint).Cmp(x[j].Val().U.(*Mpint)) < 0 +} diff --git a/src/cmd/compile/internal/gc/syntax.go b/src/cmd/compile/internal/gc/syntax.go index 001e15b327..bced2429b1 100644 --- a/src/cmd/compile/internal/gc/syntax.go +++ b/src/cmd/compile/internal/gc/syntax.go @@ -436,7 +436,7 @@ const ( // statements OBLOCK // { List } (block of code) OBREAK // break - OCASE // case Left: Nbody (select case after processing; Left==nil means default) + OCASE // case Left or List[0]..List[1]: Nbody (select case after processing; Left==nil and List==nil means default) OXCASE // case List: Nbody (select case before processing; List==nil means default) OCONTINUE // continue ODEFER // defer Left (Left must be call)