]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: recognize integer ranges in switch statements
authorJosh Bleecher Snyder <josharian@gmail.com>
Fri, 17 Jun 2016 23:27:23 +0000 (16:27 -0700)
committerJosh Bleecher Snyder <josharian@gmail.com>
Tue, 30 Aug 2016 21:20:25 +0000 (21:20 +0000)
Consider a switch statement like:

switch x {
case 1:
  // ...
case 2, 3, 4, 5, 6:
  // ...
case 5:
  // ...
}

Prior to this CL, the generated code treated
2, 3, 4, 5, and 6 independently in a binary search.
With this CL, the generated code checks whether
2 <= x && x <= 6.
walkinrange then optimizes that range check
into a single unsigned comparison.

Experiments suggest that the best min range size
is 2, using binary size as a proxy for optimization.

Binary sizes before/after this CL:

cmd/compile: 14209728 / 14165360
cmd/go:       9543100 /  9539004

Change-Id: If2f7fb97ca80468fa70351ef540866200c4c996c
Reviewed-on: https://go-review.googlesource.com/26770
Run-TryBot: Josh Bleecher Snyder <josharian@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
src/cmd/compile/internal/gc/fmt.go
src/cmd/compile/internal/gc/swt.go
src/cmd/compile/internal/gc/syntax.go

index 96cc39346584facf714380b370388feb854381e8..dd2d12634b5899be0ec0fa997442b397bac684d9 100644 (file)
@@ -880,17 +880,27 @@ func (p *printer) stmtfmt(n *Node) *printer {
 
        case OXCASE:
                if n.List.Len() != 0 {
-                       p.f("case %v: %v", hconv(n.List, FmtComma), n.Nbody)
+                       p.f("case %v", hconv(n.List, FmtComma))
                } else {
-                       p.f("default: %v", n.Nbody)
+                       p.s("default")
                }
+               p.f(": %v", n.Nbody)
 
        case OCASE:
-               if n.Left != nil {
-                       p.f("case %v: %v", n.Left, n.Nbody)
-               } else {
-                       p.f("default: %v", n.Nbody)
+               switch {
+               case n.Left != nil:
+                       // single element
+                       p.f("case %v", n.Left)
+               case n.List.Len() > 0:
+                       // range
+                       if n.List.Len() != 2 {
+                               Fatalf("bad OCASE list length", n.List)
+                       }
+                       p.f("case %v..%v", n.List.First(), n.List.Second())
+               default:
+                       p.s("default")
                }
+               p.f(": %v", n.Nbody)
 
        case OBREAK,
                OCONTINUE,
index 42b46c391779a5c09f33a1083e4b390262934938..c838c2fdcd3190ef7b7632aa02502a78bc10ad6a 100644 (file)
@@ -16,7 +16,10 @@ const (
        switchKindType // switch a.(type) {...}
 )
 
-const binarySearchMin = 4 // minimum number of cases for binary search
+const (
+       binarySearchMin = 4 // minimum number of cases for binary search
+       integerRangeMin = 2 // minimum size of integer ranges
+)
 
 // An exprSwitch walks an expression switch.
 type exprSwitch struct {
@@ -291,7 +294,18 @@ func (s *exprSwitch) walkCases(cc []caseClause) *Node {
                        lno := setlineno(n)
 
                        a := Nod(OIF, nil, nil)
-                       if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE {
+                       if rng := n.List.Slice(); rng != nil {
+                               // Integer range.
+                               // exprname is a temp or a constant,
+                               // so it is safe to evaluate twice.
+                               // In most cases, this conjunction will be
+                               // rewritten by walkinrange into a single comparison.
+                               low := Nod(OGE, s.exprname, rng[0])
+                               high := Nod(OLE, s.exprname, rng[1])
+                               a.Left = Nod(OANDAND, low, high)
+                               a.Left = typecheck(a.Left, Erv)
+                               a.Left = walkexpr(a.Left, nil) // give walk the opportunity to optimize the range check
+                       } else if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE {
                                a.Left = Nod(OEQ, s.exprname, n.Left) // if name == val
                                a.Left = typecheck(a.Left, Erv)
                        } else if s.kind == switchKindTrue {
@@ -312,7 +326,13 @@ func (s *exprSwitch) walkCases(cc []caseClause) *Node {
        // find the middle and recur
        half := len(cc) / 2
        a := Nod(OIF, nil, nil)
-       mid := cc[half-1].node.Left
+       n := cc[half-1].node
+       var mid *Node
+       if rng := n.List.Slice(); rng != nil {
+               mid = rng[1] // high end of range
+       } else {
+               mid = n.Left
+       }
        le := Nod(OLE, s.exprname, mid)
        if Isconst(mid, CTSTR) {
                // Search by length and then by value; see caseClauseByConstVal.
@@ -368,9 +388,51 @@ func casebody(sw *Node, typeswvar *Node) {
                        n.List.Set(nil)
                        cas = append(cas, n)
                default:
-                       // expand multi-valued cases
-                       for _, n1 := range n.List.Slice() {
-                               cas = append(cas, Nod(OCASE, n1, jmp))
+                       // Expand multi-valued cases and detect ranges of integer cases.
+                       if typeswvar != nil || sw.Left.Type.IsInterface() || !n.List.First().Type.IsInteger() || n.List.Len() < integerRangeMin {
+                               // Can't use integer ranges. Expand each case into a separate node.
+                               for _, n1 := range n.List.Slice() {
+                                       cas = append(cas, Nod(OCASE, n1, jmp))
+                               }
+                               break
+                       }
+                       // Find integer ranges within runs of constants.
+                       s := n.List.Slice()
+                       j := 0
+                       for j < len(s) {
+                               // Find a run of constants.
+                               var run int
+                               for run = j; run < len(s) && Isconst(s[run], CTINT); run++ {
+                               }
+                               if run-j >= integerRangeMin {
+                                       // Search for integer ranges in s[j:run].
+                                       // Typechecking is done, so all values are already in an appropriate range.
+                                       search := s[j:run]
+                                       sort.Sort(constIntNodesByVal(search))
+                                       for beg, end := 0, 1; end <= len(search); end++ {
+                                               if end < len(search) && search[end].Int64() == search[end-1].Int64()+1 {
+                                                       continue
+                                               }
+                                               if end-beg >= integerRangeMin {
+                                                       // Record range in List.
+                                                       c := Nod(OCASE, nil, jmp)
+                                                       c.List.Set2(search[beg], search[end-1])
+                                                       cas = append(cas, c)
+                                               } else {
+                                                       // Not large enough for range; record separately.
+                                                       for _, n := range search[beg:end] {
+                                                               cas = append(cas, Nod(OCASE, n, jmp))
+                                                       }
+                                               }
+                                               beg = end
+                                       }
+                                       j = run
+                               }
+                               // Advance to next constant, adding individual non-constant
+                               // or as-yet-unhandled constant cases as we go.
+                               for ; j < len(s) && (j < run || !Isconst(s[j], CTINT)); j++ {
+                                       cas = append(cas, Nod(OCASE, s[j], jmp))
+                               }
                        }
                }
 
@@ -418,7 +480,7 @@ func casebody(sw *Node, typeswvar *Node) {
 func (s *exprSwitch) genCaseClauses(clauses []*Node) caseClauses {
        var cc caseClauses
        for _, n := range clauses {
-               if n.Left == nil {
+               if n.Left == nil && n.List.Len() == 0 {
                        // default case
                        if cc.defjmp != nil {
                                Fatalf("duplicate default case not detected during typechecking")
@@ -427,6 +489,9 @@ func (s *exprSwitch) genCaseClauses(clauses []*Node) caseClauses {
                        continue
                }
                c := caseClause{node: n, ordinal: len(cc.list)}
+               if n.List.Len() > 0 {
+                       c.isconst = true
+               }
                switch consttype(n.Left) {
                case CTFLT, CTINT, CTRUNE, CTSTR:
                        c.isconst = true
@@ -533,14 +598,34 @@ func (s *exprSwitch) checkDupCases(cc []caseClause) {
                        if ct := consttype(c.node.Left); ct < 0 || ct == CTBOOL {
                                continue
                        }
-                       val := c.node.Left.Val().Interface()
-                       prev, dup := seen[val]
-                       if !dup {
-                               seen[val] = c.node
+                       if c.node.Left != nil {
+                               // Single constant.
+                               val := c.node.Left.Val().Interface()
+                               prev, dup := seen[val]
+                               if !dup {
+                                       seen[val] = c.node
+                                       continue
+                               }
+                               setlineno(c.node)
+                               Yyerror("duplicate case %v in switch\n\tprevious case at %v", prev.Left, prev.Line())
+                               continue
+                       }
+                       if c.node.List.Len() == 2 {
+                               // Range of integers.
+                               low := c.node.List.Index(0).Int64()
+                               high := c.node.List.Index(1).Int64()
+                               for i := low; i <= high; i++ {
+                                       prev, dup := seen[i]
+                                       if !dup {
+                                               seen[i] = c.node
+                                               continue
+                                       }
+                                       setlineno(c.node)
+                                       Yyerror("duplicate case %v in switch\n\tprevious case at %v", prev.Left, prev.Line())
+                               }
                                continue
                        }
-                       setlineno(c.node)
-                       Yyerror("duplicate case %v in switch\n\tprevious case at %v", prev.Left, prev.Line())
+                       Fatalf("bad caseClause node in checkDupCases: %v", c.node)
                }
                return
        }
@@ -784,8 +869,24 @@ type caseClauseByConstVal []caseClause
 func (x caseClauseByConstVal) Len() int      { return len(x) }
 func (x caseClauseByConstVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
 func (x caseClauseByConstVal) Less(i, j int) bool {
-       v1 := x[i].node.Left.Val().U
-       v2 := x[j].node.Left.Val().U
+       // n1 and n2 might be individual constants or integer ranges.
+       // We have checked for duplicates already,
+       // so ranges can be safely represented by any value in the range.
+       n1 := x[i].node
+       var v1 interface{}
+       if s := n1.List.Slice(); s != nil {
+               v1 = s[0].Val().U
+       } else {
+               v1 = n1.Left.Val().U
+       }
+
+       n2 := x[j].node
+       var v2 interface{}
+       if s := n2.List.Slice(); s != nil {
+               v2 = s[0].Val().U
+       } else {
+               v2 = n2.Left.Val().U
+       }
 
        switch v1 := v1.(type) {
        case *Mpflt:
@@ -821,3 +922,11 @@ func (x caseClauseByType) Less(i, j int) bool {
        }
        return c1.ordinal < c2.ordinal
 }
+
+type constIntNodesByVal []*Node
+
+func (x constIntNodesByVal) Len() int      { return len(x) }
+func (x constIntNodesByVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
+func (x constIntNodesByVal) Less(i, j int) bool {
+       return x[i].Val().U.(*Mpint).Cmp(x[j].Val().U.(*Mpint)) < 0
+}
index 001e15b3279a9d7220be49ddf5aba6c0db6642cc..bced2429b1eda68824b11d2da86aa8a120df3a8c 100644 (file)
@@ -436,7 +436,7 @@ const (
        // statements
        OBLOCK    // { List } (block of code)
        OBREAK    // break
-       OCASE     // case Left: Nbody (select case after processing; Left==nil means default)
+       OCASE     // case Left or List[0]..List[1]: Nbody (select case after processing; Left==nil and List==nil means default)
        OXCASE    // case List: Nbody (select case before processing; List==nil means default)
        OCONTINUE // continue
        ODEFER    // defer Left (Left must be call)