]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: explicitly manage default and nil switch cases
authorJosh Bleecher Snyder <josharian@gmail.com>
Tue, 31 May 2016 20:11:15 +0000 (13:11 -0700)
committerJosh Bleecher Snyder <josharian@gmail.com>
Mon, 22 Aug 2016 19:56:21 +0000 (19:56 +0000)
Rather than juggle default and nil cases as part
of a slice, handle them explicitly.

Change-Id: I97b200c9d3f23fe1a438acdbf3d13b0cf7e0851e
Reviewed-on: https://go-review.googlesource.com/26761
TryBot-Result: Gobot Gobot <gobot@golang.org>
Run-TryBot: Josh Bleecher Snyder <josharian@gmail.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
src/cmd/compile/internal/gc/swt.go
src/cmd/compile/internal/gc/swt_test.go

index 07d324c593878d972703032b6b92878770aca0fc..848eca8915ddcb35c9d59bf2af9483fca8a82c01 100644 (file)
@@ -17,14 +17,11 @@ const (
 )
 
 const (
-       caseKindDefault = iota // default:
-
        // expression switch
-       caseKindExprConst // case 5:
-       caseKindExprVar   // case x:
+       caseKindExprConst = iota // case 5:
+       caseKindExprVar          // case x:
 
        // type switch
-       caseKindTypeNil   // case nil:
        caseKindTypeConst // case time.Time: (concrete type, has type hash)
        caseKindTypeVar   // case io.Reader: (interface type)
 )
@@ -52,6 +49,13 @@ type caseClause struct {
        typ     uint8  // type of case
 }
 
+// caseClauses are all the case clauses in a switch statement.
+type caseClauses struct {
+       list   []caseClause // general cases
+       defjmp *Node        // OGOTO for default case or OBREAK if no default case present
+       niljmp *Node        // OGOTO for nil type case in a type switch
+}
+
 // typecheckswitch typechecks a switch statement.
 func typecheckswitch(n *Node) {
        lno := lineno
@@ -248,16 +252,10 @@ func (s *exprSwitch) walk(sw *Node) {
                typecheckslice(cas, Etop)
        }
 
-       // enumerate the cases, and lop off the default case
-       cc := caseClauses(sw, s.kind)
+       // Enumerate the cases and prepare the default case.
+       clauses := genCaseClauses(sw, s.kind)
        sw.List.Set(nil)
-       var def *Node
-       if len(cc) > 0 && cc[0].typ == caseKindDefault {
-               def = cc[0].node.Right
-               cc = cc[1:]
-       } else {
-               def = Nod(OBREAK, nil, nil)
-       }
+       cc := clauses.list
 
        // handle the cases in order
        for len(cc) > 0 {
@@ -283,14 +281,14 @@ func (s *exprSwitch) walk(sw *Node) {
 
        // handle default case
        if nerrors == 0 {
-               cas = append(cas, def)
+               cas = append(cas, clauses.defjmp)
                sw.Nbody.Set(append(cas, sw.Nbody.Slice()...))
                walkstmtlist(sw.Nbody.Slice())
        }
 }
 
 // walkCases generates an AST implementing the cases in cc.
-func (s *exprSwitch) walkCases(cc []*caseClause) *Node {
+func (s *exprSwitch) walkCases(cc []caseClause) *Node {
        if len(cc) < binarySearchMin {
                // linear search
                var cas []*Node
@@ -422,27 +420,35 @@ func casebody(sw *Node, typeswvar *Node) {
        lineno = lno
 }
 
-// caseClauses generates a slice of caseClauses
+// genCaseClauses generates the caseClauses value
 // corresponding to the clauses in the switch statement sw.
 // Kind is the kind of switch statement.
-func caseClauses(sw *Node, kind int) []*caseClause {
-       var cc []*caseClause
+func genCaseClauses(sw *Node, kind int) caseClauses {
+       var cc caseClauses
        for _, n := range sw.List.Slice() {
-               c := new(caseClause)
-               cc = append(cc, c)
-               c.ordinal = len(cc)
-               c.node = n
-
                if n.Left == nil {
-                       c.typ = caseKindDefault
+                       // default case
+                       if cc.defjmp != nil {
+                               Fatalf("duplicate default case not detected during typechecking")
+                       }
+                       cc.defjmp = n.Right
                        continue
                }
 
+               if kind == switchKindType && n.Left.Op == OLITERAL {
+                       // nil case in type switch
+                       if cc.niljmp != nil {
+                               Fatalf("duplicate nil case not detected during typechecking")
+                       }
+                       cc.niljmp = n.Right
+                       continue
+               }
+
+               // general case
+               c := caseClause{node: n, ordinal: len(cc.list)}
                if kind == switchKindType {
                        // type switch
                        switch {
-                       case n.Left.Op == OLITERAL:
-                               c.typ = caseKindTypeNil
                        case n.Left.Type.IsInterface():
                                c.typ = caseKindTypeVar
                        default:
@@ -458,22 +464,24 @@ func caseClauses(sw *Node, kind int) []*caseClause {
                                c.typ = caseKindExprVar
                        }
                }
+               cc.list = append(cc.list, c)
        }
 
-       if cc == nil {
-               return nil
+       if cc.defjmp == nil {
+               cc.defjmp = Nod(OBREAK, nil, nil)
+       }
+
+       if cc.list == nil {
+               return cc
        }
 
        // sort by value and diagnose duplicate cases
        if kind == switchKindType {
                // type switch
-               sort.Sort(caseClauseByType(cc))
-               for i, c1 := range cc {
-                       if c1.typ == caseKindTypeNil || c1.typ == caseKindDefault {
-                               break
-                       }
-                       for _, c2 := range cc[i+1:] {
-                               if c2.typ == caseKindTypeNil || c2.typ == caseKindDefault || c1.hash != c2.hash {
+               sort.Sort(caseClauseByType(cc.list))
+               for i, c1 := range cc.list {
+                       for _, c2 := range cc.list[i+1:] {
+                               if c1.hash != c2.hash {
                                        break
                                }
                                if Eqtype(c1.node.Left.Type, c2.node.Left.Type) {
@@ -483,12 +491,12 @@ func caseClauses(sw *Node, kind int) []*caseClause {
                }
        } else {
                // expression switch
-               sort.Sort(caseClauseByExpr(cc))
-               for i, c1 := range cc {
-                       if i+1 == len(cc) {
+               sort.Sort(caseClauseByExpr(cc.list))
+               for i, c1 := range cc.list {
+                       if i+1 == len(cc.list) {
                                break
                        }
-                       c2 := cc[i+1]
+                       c2 := cc.list[i+1]
                        if exprcmp(c1, c2) != 0 {
                                continue
                        }
@@ -498,7 +506,7 @@ func caseClauses(sw *Node, kind int) []*caseClause {
        }
 
        // put list back in processing order
-       sort.Sort(caseClauseByOrd(cc))
+       sort.Sort(caseClauseByOrd(cc.list))
        return cc
 }
 
@@ -545,20 +553,9 @@ func (s *typeSwitch) walk(sw *Node) {
        // set up labels and jumps
        casebody(sw, s.facename)
 
-       cc := caseClauses(sw, switchKindType)
+       clauses := genCaseClauses(sw, switchKindType)
        sw.List.Set(nil)
-       var def *Node
-       if len(cc) > 0 && cc[0].typ == caseKindDefault {
-               def = cc[0].node.Right
-               cc = cc[1:]
-       } else {
-               def = Nod(OBREAK, nil, nil)
-       }
-       var typenil *Node
-       if len(cc) > 0 && cc[0].typ == caseKindTypeNil {
-               typenil = cc[0].node.Right
-               cc = cc[1:]
-       }
+       def := clauses.defjmp
 
        // For empty interfaces, do:
        //     if e._type == nil {
@@ -573,9 +570,9 @@ func (s *typeSwitch) walk(sw *Node) {
        // Check for nil first.
        i := Nod(OIF, nil, nil)
        i.Left = Nod(OEQ, typ, nodnil())
-       if typenil != nil {
+       if clauses.niljmp != nil {
                // Do explicit nil case right here.
-               i.Nbody.Set1(typenil)
+               i.Nbody.Set1(clauses.niljmp)
        } else {
                // Jump to default case.
                lbl := autolabel(".s")
@@ -602,6 +599,8 @@ func (s *typeSwitch) walk(sw *Node) {
        a = typecheck(a, Etop)
        cas = append(cas, a)
 
+       cc := clauses.list
+
        // insert type equality check into each case block
        for _, c := range cc {
                n := c.node
@@ -696,7 +695,7 @@ func (s *typeSwitch) typeone(t *Node) *Node {
 }
 
 // walkCases generates an AST implementing the cases in cc.
-func (s *typeSwitch) walkCases(cc []*caseClause) *Node {
+func (s *typeSwitch) walkCases(cc []caseClause) *Node {
        if len(cc) < binarySearchMin {
                var cas []*Node
                for _, c := range cc {
@@ -723,31 +722,13 @@ func (s *typeSwitch) walkCases(cc []*caseClause) *Node {
        return a
 }
 
-type caseClauseByOrd []*caseClause
-
-func (x caseClauseByOrd) Len() int      { return len(x) }
-func (x caseClauseByOrd) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
-func (x caseClauseByOrd) Less(i, j int) bool {
-       c1, c2 := x[i], x[j]
-       switch {
-       // sort default first
-       case c1.typ == caseKindDefault:
-               return true
-       case c2.typ == caseKindDefault:
-               return false
-
-       // sort nil second
-       case c1.typ == caseKindTypeNil:
-               return true
-       case c2.typ == caseKindTypeNil:
-               return false
-       }
+type caseClauseByOrd []caseClause
 
-       // sort by ordinal
-       return c1.ordinal < c2.ordinal
-}
+func (x caseClauseByOrd) Len() int           { return len(x) }
+func (x caseClauseByOrd) Swap(i, j int)      { x[i], x[j] = x[j], x[i] }
+func (x caseClauseByOrd) Less(i, j int) bool { return x[i].ordinal < x[j].ordinal }
 
-type caseClauseByExpr []*caseClause
+type caseClauseByExpr []caseClause
 
 func (x caseClauseByExpr) Len() int      { return len(x) }
 func (x caseClauseByExpr) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
@@ -755,7 +736,7 @@ func (x caseClauseByExpr) Less(i, j int) bool {
        return exprcmp(x[i], x[j]) < 0
 }
 
-func exprcmp(c1, c2 *caseClause) int {
+func exprcmp(c1, c2 caseClause) int {
        // sort non-constants last
        if c1.typ != caseKindExprConst {
                return +1
@@ -814,7 +795,7 @@ func exprcmp(c1, c2 *caseClause) int {
        return 0
 }
 
-type caseClauseByType []*caseClause
+type caseClauseByType []caseClause
 
 func (x caseClauseByType) Len() int      { return len(x) }
 func (x caseClauseByType) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
index c1ee8955cf6c93d74f7555c00b223dc22c052c5f..4e5dbcae505181f9525294ba9ea1878ad96c369d 100644 (file)
@@ -134,7 +134,7 @@ func TestExprcmp(t *testing.T) {
                },
        }
        for i, d := range testdata {
-               got := exprcmp(&d.a, &d.b)
+               got := exprcmp(d.a, d.b)
                if d.want != got {
                        t.Errorf("%d: exprcmp(a, b) = %d; want %d", i, got, d.want)
                        t.Logf("\ta = caseClause{node: %#v, typ: %#v}", d.a.node, d.a.typ)