"sort"
)
-const (
- // expression switch
- switchKindExpr = iota // switch a {...} or switch 5 {...}
- switchKindTrue // switch true {...} or switch {...}
- switchKindFalse // switch false {...}
-)
-
-const (
- binarySearchMin = 4 // minimum number of cases for binary search
- integerRangeMin = 2 // minimum size of integer ranges
-)
-
-// An exprSwitch walks an expression switch.
-type exprSwitch struct {
- exprname *Node // node for the expression being switched on
- kind int // kind of switch statement (switchKind*)
-}
-
-// A typeSwitch walks a type switch.
-type typeSwitch struct {
- hashname *Node // node for the hash of the type of the variable being switched on
- facename *Node // node for the concrete type of the variable being switched on
- okname *Node // boolean node used for comma-ok type assertions
-}
-
-// A caseClause is a single case clause in a switch statement.
-type caseClause struct {
- node *Node // points at case statement
- ordinal int // position in switch
- hash uint32 // hash of a type switch
- // isconst indicates whether this case clause is a constant,
- // for the purposes of the switch code generation.
- // For expression switches, that's generally literals (case 5:, not case x:).
- // For type switches, that's concrete types (case time.Time:), not interfaces (case io.Reader:).
- isconst bool
-}
-
-// caseClauses are all the case clauses in a switch statement.
-type caseClauses struct {
- list []caseClause // general cases
- defjmp *Node // OGOTO for default case or OBREAK if no default case present
- niljmp *Node // OGOTO for nil type case in a type switch
-}
-
// typecheckswitch typechecks a switch statement.
func typecheckswitch(n *Node) {
typecheckslice(n.Ninit.Slice(), ctxStmt)
yyerrorl(n.Pos, "cannot type switch on non-interface value %L", n.Left.Right)
t = nil
}
- n.Type = t // TODO(mdempsky): Remove; statements aren't typed.
// We don't actually declare the type switch's guarded
// declaration itself. So if there are no cases, we won't
t = nil
}
}
- n.Type = t // TODO(mdempsky): Remove; statements aren't typed.
var defCase *Node
var cs constSet
// walkswitch walks a switch statement.
func walkswitch(sw *Node) {
- // convert switch {...} to switch true {...}
- if sw.Left == nil {
- sw.Left = nodbool(true)
- sw.Left = typecheck(sw.Left, ctxExpr)
- sw.Left = defaultlit(sw.Left, nil)
+ // Guard against double walk, see #25776.
+ if sw.List.Len() == 0 && sw.Nbody.Len() > 0 {
+ return // Was fatal, but eliminating every possible source of double-walking is hard
}
- if sw.Left.Op == OTYPESW {
- var s typeSwitch
- s.walk(sw)
+ if sw.Left != nil && sw.Left.Op == OTYPESW {
+ walkTypeSwitch(sw)
} else {
- var s exprSwitch
- s.walk(sw)
+ walkExprSwitch(sw)
}
}
-// walk generates an AST implementing sw.
-// sw is an expression switch.
-// The AST is generally of the form of a linear
-// search using if..goto, although binary search
-// is used with long runs of constants.
-func (s *exprSwitch) walk(sw *Node) {
- // Guard against double walk, see #25776.
- if sw.List.Len() == 0 && sw.Nbody.Len() > 0 {
- return // Was fatal, but eliminating every possible source of double-walking is hard
- }
-
- casebody(sw, nil)
+// walkExprSwitch generates an AST implementing sw. sw is an
+// expression switch.
+func walkExprSwitch(sw *Node) {
+ lno := setlineno(sw)
cond := sw.Left
sw.Left = nil
- s.kind = switchKindExpr
- if Isconst(cond, CTBOOL) {
- s.kind = switchKindTrue
- if !cond.Val().U.(bool) {
- s.kind = switchKindFalse
- }
+ // convert switch {...} to switch true {...}
+ if cond == nil {
+ cond = nodbool(true)
+ cond = typecheck(cond, ctxExpr)
+ cond = defaultlit(cond, nil)
}
// Given "switch string(byteslice)",
- // with all cases being constants (or the default case),
+ // with all cases being side-effect free,
// use a zero-cost alias of the byte slice.
- // In theory, we could be more aggressive,
- // allowing any side-effect-free expressions in cases,
- // but it's a bit tricky because some of that information
- // is unavailable due to the introduction of temporaries during order.
- // Restricting to constants is simple and probably powerful enough.
// Do this before calling walkexpr on cond,
// because walkexpr will lower the string
// conversion into a runtime call.
// See issue 24937 for more discussion.
- if cond.Op == OBYTES2STR {
- ok := true
- for _, cas := range sw.List.Slice() {
- if cas.Op != OCASE {
- Fatalf("switch string(byteslice) bad op: %v", cas.Op)
- }
- if cas.Left != nil && !Isconst(cas.Left, CTSTR) {
- ok = false
- break
- }
- }
- if ok {
- cond.Op = OBYTES2STRTMP
- }
+ if cond.Op == OBYTES2STR && allCaseExprsAreSideEffectFree(sw) {
+ cond.Op = OBYTES2STRTMP
}
cond = walkexpr(cond, &sw.Ninit)
- t := sw.Type
- if t == nil {
- return
+ if cond.Op != OLITERAL {
+ cond = copyexpr(cond, cond.Type, &sw.Nbody)
}
- // convert the switch into OIF statements
- var cas []*Node
- if s.kind == switchKindTrue || s.kind == switchKindFalse {
- s.exprname = nodbool(s.kind == switchKindTrue)
- } else if consttype(cond) > 0 {
- // leave constants to enable dead code elimination (issue 9608)
- s.exprname = cond
- } else {
- s.exprname = temp(cond.Type)
- cas = []*Node{nod(OAS, s.exprname, cond)} // This gets walk()ed again in walkstmtlist just before end of this function. See #29562.
- typecheckslice(cas, ctxStmt)
+ lineno = lno
+
+ s := exprSwitch{
+ exprname: cond,
}
- // Enumerate the cases and prepare the default case.
- clauses := s.genCaseClauses(sw.List.Slice())
- sw.List.Set(nil)
- cc := clauses.list
-
- // handle the cases in order
- for len(cc) > 0 {
- run := 1
- if okforcmp[t.Etype] && cc[0].isconst {
- // do binary search on runs of constants
- for ; run < len(cc) && cc[run].isconst; run++ {
+ br := nod(OBREAK, nil, nil)
+ var defaultGoto *Node
+ var body Nodes
+ for _, ncase := range sw.List.Slice() {
+ label := autolabel(".s")
+ jmp := npos(ncase.Pos, nodSym(OGOTO, nil, label))
+
+ // Process case dispatch.
+ if ncase.List.Len() == 0 {
+ if defaultGoto != nil {
+ Fatalf("duplicate default case not detected during typechecking")
}
- // sort and compile constants
- sort.Sort(caseClauseByConstVal(cc[:run]))
+ defaultGoto = jmp
}
- a := s.walkCases(cc[:run])
- cas = append(cas, a)
- cc = cc[run:]
+ for _, n1 := range ncase.List.Slice() {
+ s.Add(ncase.Pos, n1, jmp)
+ }
+
+ // Process body.
+ body.Append(npos(ncase.Pos, nodSym(OLABEL, nil, label)))
+ body.Append(ncase.Nbody.Slice()...)
+ if !hasFall(ncase.Nbody.Slice()) {
+ body.Append(br)
+ }
}
+ sw.List.Set(nil)
- // handle default case
- if nerrors == 0 {
- cas = append(cas, clauses.defjmp)
- sw.Nbody.Prepend(cas...)
- walkstmtlist(sw.Nbody.Slice())
+ if defaultGoto == nil {
+ defaultGoto = br
}
-}
-// walkCases generates an AST implementing the cases in cc.
-func (s *exprSwitch) walkCases(cc []caseClause) *Node {
- if len(cc) < binarySearchMin {
- // linear search
- var cas []*Node
- for _, c := range cc {
- n := c.node
- lno := setlineno(n)
+ s.Emit(&sw.Nbody)
+ sw.Nbody.Append(defaultGoto)
+ sw.Nbody.AppendNodes(&body)
+ walkstmtlist(sw.Nbody.Slice())
+}
- a := nod(OIF, nil, nil)
- if rng := n.List.Slice(); rng != nil {
- // Integer range.
- // exprname is a temp or a constant,
- // so it is safe to evaluate twice.
- // In most cases, this conjunction will be
- // rewritten by walkinrange into a single comparison.
- low := nod(OGE, s.exprname, rng[0])
- high := nod(OLE, s.exprname, rng[1])
- a.Left = nod(OANDAND, low, high)
- } else if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE {
- a.Left = nod(OEQ, s.exprname, n.Left) // if name == val
- } else if s.kind == switchKindTrue {
- a.Left = n.Left // if val
- } else {
- // s.kind == switchKindFalse
- a.Left = nod(ONOT, n.Left, nil) // if !val
- }
- a.Left = typecheck(a.Left, ctxExpr)
- a.Left = defaultlit(a.Left, nil)
- a.Nbody.Set1(n.Right) // goto l
+// An exprSwitch walks an expression switch.
+type exprSwitch struct {
+ exprname *Node // value being switched on
- cas = append(cas, a)
- lineno = lno
- }
- return liststmt(cas)
- }
+ done Nodes
+ clauses []exprClause
+}
- // find the middle and recur
- half := len(cc) / 2
- a := nod(OIF, nil, nil)
- n := cc[half-1].node
- var mid *Node
- if rng := n.List.Slice(); rng != nil {
- mid = rng[1] // high end of range
- } else {
- mid = n.Left
- }
- le := nod(OLE, s.exprname, mid)
- if Isconst(mid, CTSTR) {
- // Search by length and then by value; see caseClauseByConstVal.
- lenlt := nod(OLT, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
- leneq := nod(OEQ, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
- a.Left = nod(OOROR, lenlt, nod(OANDAND, leneq, le))
- } else {
- a.Left = le
- }
- a.Left = typecheck(a.Left, ctxExpr)
- a.Left = defaultlit(a.Left, nil)
- a.Nbody.Set1(s.walkCases(cc[:half]))
- a.Rlist.Set1(s.walkCases(cc[half:]))
- return a
+type exprClause struct {
+ pos src.XPos
+ lo, hi *Node
+ jmp *Node
}
-// casebody builds separate lists of statements and cases.
-// It makes labels between cases and statements
-// and deals with fallthrough, break, and unreachable statements.
-func casebody(sw *Node, typeswvar *Node) {
- if sw.List.Len() == 0 {
+func (s *exprSwitch) Add(pos src.XPos, expr, jmp *Node) {
+ c := exprClause{pos: pos, lo: expr, hi: expr, jmp: jmp}
+ if okforcmp[s.exprname.Type.Etype] && expr.Op == OLITERAL {
+ s.clauses = append(s.clauses, c)
return
}
- lno := setlineno(sw)
+ s.flush()
+ s.clauses = append(s.clauses, c)
+ s.flush()
+}
- var cas []*Node // cases
- var stat []*Node // statements
- var def *Node // defaults
- br := nod(OBREAK, nil, nil)
+func (s *exprSwitch) Emit(out *Nodes) {
+ s.flush()
+ out.AppendNodes(&s.done)
+}
- for _, n := range sw.List.Slice() {
- setlineno(n)
- if n.Op != OXCASE {
- Fatalf("casebody %v", n.Op)
- }
- n.Op = OCASE
- needvar := n.List.Len() != 1 || n.List.First().Op == OLITERAL
-
- lbl := autolabel(".s")
- jmp := nodSym(OGOTO, nil, lbl)
- switch n.List.Len() {
- case 0:
- // default
- if def != nil {
- yyerrorl(n.Pos, "more than one default case")
- }
- // reuse original default case
- n.Right = jmp
- def = n
- case 1:
- // one case -- reuse OCASE node
- n.Left = n.List.First()
- n.Right = jmp
- n.List.Set(nil)
- cas = append(cas, n)
- default:
- // Expand multi-valued cases and detect ranges of integer cases.
- if typeswvar != nil || sw.Left.Type.IsInterface() || !n.List.First().Type.IsInteger() || n.List.Len() < integerRangeMin {
- // Can't use integer ranges. Expand each case into a separate node.
- for _, n1 := range n.List.Slice() {
- cas = append(cas, nod(OCASE, n1, jmp))
- }
- break
- }
- // Find integer ranges within runs of constants.
- s := n.List.Slice()
- j := 0
- for j < len(s) {
- // Find a run of constants.
- var run int
- for run = j; run < len(s) && Isconst(s[run], CTINT); run++ {
- }
- if run-j >= integerRangeMin {
- // Search for integer ranges in s[j:run].
- // Typechecking is done, so all values are already in an appropriate range.
- search := s[j:run]
- sort.Sort(constIntNodesByVal(search))
- for beg, end := 0, 1; end <= len(search); end++ {
- if end < len(search) && search[end].Int64() == search[end-1].Int64()+1 {
- continue
- }
- if end-beg >= integerRangeMin {
- // Record range in List.
- c := nod(OCASE, nil, jmp)
- c.List.Set2(search[beg], search[end-1])
- cas = append(cas, c)
- } else {
- // Not large enough for range; record separately.
- for _, n := range search[beg:end] {
- cas = append(cas, nod(OCASE, n, jmp))
- }
- }
- beg = end
- }
- j = run
- }
- // Advance to next constant, adding individual non-constant
- // or as-yet-unhandled constant cases as we go.
- for ; j < len(s) && (j < run || !Isconst(s[j], CTINT)); j++ {
- cas = append(cas, nod(OCASE, s[j], jmp))
- }
+func (s *exprSwitch) flush() {
+ cc := s.clauses
+ s.clauses = nil
+ if len(cc) == 0 {
+ return
+ }
+
+ // Caution: If len(cc) == 1, then cc[0] might not an OLITERAL.
+ // The code below is structured to implicitly handle this case
+ // (e.g., sort.Slice doesn't need to invoke the less function
+ // when there's only a single slice element).
+
+ // Sort strings by length and then by value.
+ // It is much cheaper to compare lengths than values,
+ // and all we need here is consistency.
+ // We respect this sorting below.
+ sort.Slice(cc, func(i, j int) bool {
+ vi := cc[i].lo.Val()
+ vj := cc[j].lo.Val()
+
+ if s.exprname.Type.IsString() {
+ si := vi.U.(string)
+ sj := vj.U.(string)
+ if len(si) != len(sj) {
+ return len(si) < len(sj)
}
+ return si < sj
}
- stat = append(stat, nodSym(OLABEL, nil, lbl))
- if typeswvar != nil && needvar && n.Rlist.Len() != 0 {
- l := []*Node{
- nod(ODCL, n.Rlist.First(), nil),
- nod(OAS, n.Rlist.First(), typeswvar),
+ return compareOp(vi, OLT, vj)
+ })
+
+ // Merge consecutive integer cases.
+ if s.exprname.Type.IsInteger() {
+ merged := cc[:1]
+ for _, c := range cc[1:] {
+ last := &merged[len(merged)-1]
+ if last.jmp == c.jmp && last.hi.Int64()+1 == c.lo.Int64() {
+ last.hi = c.lo
+ } else {
+ merged = append(merged, c)
}
- typecheckslice(l, ctxStmt)
- stat = append(stat, l...)
- }
- stat = append(stat, n.Nbody.Slice()...)
-
- // Search backwards for the index of the fallthrough
- // statement. Do not assume it'll be in the last
- // position, since in some cases (e.g. when the statement
- // list contains autotmp_ variables), one or more OVARKILL
- // nodes will be at the end of the list.
- fallIndex := len(stat) - 1
- for stat[fallIndex].Op == OVARKILL {
- fallIndex--
- }
- last := stat[fallIndex]
- if last.Op != OFALL {
- stat = append(stat, br)
}
+ cc = merged
}
- stat = append(stat, br)
- if def != nil {
- cas = append(cas, def)
- }
+ binarySearch(len(cc), &s.done,
+ func(i int) *Node {
+ mid := cc[i-1].hi
- sw.List.Set(cas)
- sw.Nbody.Set(stat)
- lineno = lno
+ le := nod(OLE, s.exprname, mid)
+ if s.exprname.Type.IsString() {
+ // Compare strings by length and then
+ // by value; see sort.Slice above.
+ lenlt := nod(OLT, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
+ leneq := nod(OEQ, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
+ le = nod(OOROR, lenlt, nod(OANDAND, leneq, le))
+ }
+ return le
+ },
+ func(i int, out *Nodes) {
+ c := &cc[i]
+
+ nif := nodl(c.pos, OIF, c.test(s.exprname), nil)
+ nif.Left = typecheck(nif.Left, ctxExpr)
+ nif.Left = defaultlit(nif.Left, nil)
+ nif.Nbody.Set1(c.jmp)
+ out.Append(nif)
+ },
+ )
}
-// genCaseClauses generates the caseClauses value for clauses.
-func (s *exprSwitch) genCaseClauses(clauses []*Node) caseClauses {
- var cc caseClauses
- for _, n := range clauses {
- if n.Left == nil && n.List.Len() == 0 {
- // default case
- if cc.defjmp != nil {
- Fatalf("duplicate default case not detected during typechecking")
- }
- cc.defjmp = n.Right
- continue
- }
- c := caseClause{node: n, ordinal: len(cc.list)}
- if n.List.Len() > 0 {
- c.isconst = true
- }
- switch consttype(n.Left) {
- case CTFLT, CTINT, CTRUNE, CTSTR:
- c.isconst = true
- }
- cc.list = append(cc.list, c)
+func (c *exprClause) test(exprname *Node) *Node {
+ // Integer range.
+ if c.hi != c.lo {
+ low := nodl(c.pos, OGE, exprname, c.lo)
+ high := nodl(c.pos, OLE, exprname, c.hi)
+ return nodl(c.pos, OANDAND, low, high)
}
- if cc.defjmp == nil {
- cc.defjmp = nod(OBREAK, nil, nil)
+ // Optimize "switch true { ...}" and "switch false { ... }".
+ if Isconst(exprname, CTBOOL) && !c.lo.Type.IsInterface() {
+ if exprname.Val().U.(bool) {
+ return c.lo
+ } else {
+ return nodl(c.pos, ONOT, c.lo, nil)
+ }
}
- return cc
+
+ return nodl(c.pos, OEQ, exprname, c.lo)
}
-// genCaseClauses generates the caseClauses value for clauses.
-func (s *typeSwitch) genCaseClauses(clauses []*Node) caseClauses {
- var cc caseClauses
- for _, n := range clauses {
- switch {
- case n.Left == nil:
- // default case
- if cc.defjmp != nil {
- Fatalf("duplicate default case not detected during typechecking")
- }
- cc.defjmp = n.Right
- continue
- case n.Left.Op == OLITERAL:
- // nil case in type switch
- if cc.niljmp != nil {
- Fatalf("duplicate nil case not detected during typechecking")
- }
- cc.niljmp = n.Right
- continue
+func allCaseExprsAreSideEffectFree(sw *Node) bool {
+ // In theory, we could be more aggressive, allowing any
+ // side-effect-free expressions in cases, but it's a bit
+ // tricky because some of that information is unavailable due
+ // to the introduction of temporaries during order.
+ // Restricting to constants is simple and probably powerful
+ // enough.
+
+ for _, ncase := range sw.List.Slice() {
+ if ncase.Op != OXCASE {
+ Fatalf("switch string(byteslice) bad op: %v", ncase.Op)
}
-
- // general case
- c := caseClause{
- node: n,
- ordinal: len(cc.list),
- isconst: !n.Left.Type.IsInterface(),
- hash: typehash(n.Left.Type),
+ for _, v := range ncase.List.Slice() {
+ if v.Op != OLITERAL {
+ return false
+ }
}
- cc.list = append(cc.list, c)
- }
-
- if cc.defjmp == nil {
- cc.defjmp = nod(OBREAK, nil, nil)
}
-
- return cc
+ return true
}
-// walk generates an AST that implements sw,
-// where sw is a type switch.
-// The AST is generally of the form of a linear
-// search using if..goto, although binary search
-// is used with long runs of concrete types.
-func (s *typeSwitch) walk(sw *Node) {
- cond := sw.Left
- sw.Left = nil
+// hasFall reports whether stmts ends with a "fallthrough" statement.
+func hasFall(stmts []*Node) bool {
+ // Search backwards for the index of the fallthrough
+ // statement. Do not assume it'll be in the last
+ // position, since in some cases (e.g. when the statement
+ // list contains autotmp_ variables), one or more OVARKILL
+ // nodes will be at the end of the list.
- if cond == nil {
- sw.List.Set(nil)
- return
+ i := len(stmts) - 1
+ for i >= 0 && stmts[i].Op == OVARKILL {
+ i--
}
- if cond.Right == nil {
- yyerrorl(sw.Pos, "type switch must have an assignment")
- return
- }
-
- cond.Right = walkexpr(cond.Right, &sw.Ninit)
- if !cond.Right.Type.IsInterface() {
- yyerrorl(sw.Pos, "type switch must be on an interface")
- return
- }
-
- var cas []*Node
-
- // predeclare temporary variables and the boolean var
- s.facename = temp(cond.Right.Type)
+ return i >= 0 && stmts[i].Op == OFALL
+}
- a := nod(OAS, s.facename, cond.Right)
- a = typecheck(a, ctxStmt)
- cas = append(cas, a)
+// walkTypeSwitch generates an AST that implements sw, where sw is a
+// type switch.
+func walkTypeSwitch(sw *Node) {
+ var s typeSwitch
+ s.facename = sw.Left.Right
+ sw.Left = nil
+ s.facename = walkexpr(s.facename, &sw.Ninit)
+ s.facename = copyexpr(s.facename, s.facename.Type, &sw.Nbody)
s.okname = temp(types.Types[TBOOL])
- s.okname = typecheck(s.okname, ctxExpr)
-
- s.hashname = temp(types.Types[TUINT32])
- s.hashname = typecheck(s.hashname, ctxExpr)
- // set up labels and jumps
- casebody(sw, s.facename)
-
- clauses := s.genCaseClauses(sw.List.Slice())
- sw.List.Set(nil)
- def := clauses.defjmp
+ // Get interface descriptor word.
+ // For empty interfaces this will be the type.
+ // For non-empty interfaces this will be the itab.
+ itab := nod(OITAB, s.facename, nil)
// For empty interfaces, do:
// if e._type == nil {
// }
// h := e._type.hash
// Use a similar strategy for non-empty interfaces.
-
- // Get interface descriptor word.
- // For empty interfaces this will be the type.
- // For non-empty interfaces this will be the itab.
- itab := nod(OITAB, s.facename, nil)
-
- // Check for nil first.
- i := nod(OIF, nil, nil)
- i.Left = nod(OEQ, itab, nodnil())
- if clauses.niljmp != nil {
- // Do explicit nil case right here.
- i.Nbody.Set1(clauses.niljmp)
- } else {
- // Jump to default case.
- lbl := autolabel(".s")
- i.Nbody.Set1(nodSym(OGOTO, nil, lbl))
- // Wrap default case with label.
- blk := nod(OBLOCK, nil, nil)
- blk.List.Set2(nodSym(OLABEL, nil, lbl), def)
- def = blk
- }
- i.Left = typecheck(i.Left, ctxExpr)
- i.Left = defaultlit(i.Left, nil)
- cas = append(cas, i)
+ ifNil := nod(OIF, nil, nil)
+ ifNil.Left = nod(OEQ, itab, nodnil())
+ ifNil.Left = typecheck(ifNil.Left, ctxExpr)
+ ifNil.Left = defaultlit(ifNil.Left, nil)
+ // ifNil.Nbody assigned at end.
+ sw.Nbody.Append(ifNil)
// Load hash from type or itab.
- h := nodSym(ODOTPTR, itab, nil)
- h.Type = types.Types[TUINT32]
- h.SetTypecheck(1)
- if cond.Right.Type.IsEmptyInterface() {
- h.Xoffset = int64(2 * Widthptr) // offset of hash in runtime._type
+ dotHash := nodSym(ODOTPTR, itab, nil)
+ dotHash.Type = types.Types[TUINT32]
+ dotHash.SetTypecheck(1)
+ if s.facename.Type.IsEmptyInterface() {
+ dotHash.Xoffset = int64(2 * Widthptr) // offset of hash in runtime._type
} else {
- h.Xoffset = int64(2 * Widthptr) // offset of hash in runtime.itab
+ dotHash.Xoffset = int64(2 * Widthptr) // offset of hash in runtime.itab
}
- h.SetBounded(true) // guaranteed not to fault
- a = nod(OAS, s.hashname, h)
- a = typecheck(a, ctxStmt)
- cas = append(cas, a)
+ dotHash.SetBounded(true) // guaranteed not to fault
+ s.hashname = copyexpr(dotHash, dotHash.Type, &sw.Nbody)
- cc := clauses.list
+ br := nod(OBREAK, nil, nil)
+ var defaultGoto, nilGoto *Node
+ var body Nodes
+ for _, ncase := range sw.List.Slice() {
+ var caseVar *Node
+ if ncase.Rlist.Len() != 0 {
+ caseVar = ncase.Rlist.First()
+ }
- // insert type equality check into each case block
- for _, c := range cc {
- c.node.Right = s.typeone(c.node)
- }
+ // For single-type cases, we initialize the case
+ // variable as part of the type assertion; but in
+ // other cases, we initialize it in the body.
+ singleType := ncase.List.Len() == 1 && ncase.List.First().Op == OTYPE
- // generate list of if statements, binary search for constant sequences
- for len(cc) > 0 {
- if !cc[0].isconst {
- n := cc[0].node
- cas = append(cas, n.Right)
- cc = cc[1:]
- continue
- }
+ label := autolabel(".s")
- // identify run of constants
- var run int
- for run = 1; run < len(cc) && cc[run].isconst; run++ {
+ jmp := npos(ncase.Pos, nodSym(OGOTO, nil, label))
+ if ncase.List.Len() == 0 { // default:
+ if defaultGoto != nil {
+ Fatalf("duplicate default case not detected during typechecking")
+ }
+ defaultGoto = jmp
}
- // sort by hash
- sort.Sort(caseClauseByType(cc[:run]))
+ for _, n1 := range ncase.List.Slice() {
+ if n1.isNil() { // case nil:
+ if nilGoto != nil {
+ Fatalf("duplicate nil case not detected during typechecking")
+ }
+ nilGoto = jmp
+ continue
+ }
- // for debugging: linear search
- if false {
- for i := 0; i < run; i++ {
- n := cc[i].node
- cas = append(cas, n.Right)
+ if singleType {
+ s.Add(n1.Type, caseVar, jmp)
+ } else {
+ s.Add(n1.Type, nil, jmp)
}
- continue
}
- // combine adjacent cases with the same hash
- var batch []caseClause
- for i, j := 0, 0; i < run; i = j {
- hash := []*Node{cc[i].node.Right}
- for j = i + 1; j < run && cc[i].hash == cc[j].hash; j++ {
- hash = append(hash, cc[j].node.Right)
+ body.Append(npos(ncase.Pos, nodSym(OLABEL, nil, label)))
+ if caseVar != nil && !singleType {
+ l := []*Node{
+ nodl(ncase.Pos, ODCL, caseVar, nil),
+ nodl(ncase.Pos, OAS, caseVar, s.facename),
}
- cc[i].node.Right = liststmt(hash)
- batch = append(batch, cc[i])
+ typecheckslice(l, ctxStmt)
+ body.Append(l...)
}
+ body.Append(ncase.Nbody.Slice()...)
+ body.Append(br)
+ }
+ sw.List.Set(nil)
- // binary search among cases to narrow by hash
- cas = append(cas, s.walkCases(batch))
- cc = cc[run:]
+ if defaultGoto == nil {
+ defaultGoto = br
}
- // handle default case
- if nerrors == 0 {
- cas = append(cas, def)
- sw.Nbody.Prepend(cas...)
- sw.List.Set(nil)
- walkstmtlist(sw.Nbody.Slice())
+ if nilGoto != nil {
+ ifNil.Nbody.Set1(nilGoto)
+ } else {
+ // TODO(mdempsky): Just use defaultGoto directly.
+
+ // Jump to default case.
+ label := autolabel(".s")
+ ifNil.Nbody.Set1(nodSym(OGOTO, nil, label))
+ // Wrap default case with label.
+ blk := nod(OBLOCK, nil, nil)
+ blk.List.Set2(nodSym(OLABEL, nil, label), defaultGoto)
+ defaultGoto = blk
}
+
+ s.Emit(&sw.Nbody)
+ sw.Nbody.Append(defaultGoto)
+ sw.Nbody.AppendNodes(&body)
+
+ walkstmtlist(sw.Nbody.Slice())
}
-// typeone generates an AST that jumps to the
-// case body if the variable is of type t.
-func (s *typeSwitch) typeone(t *Node) *Node {
- var name *Node
- var init Nodes
- if t.Rlist.Len() == 0 {
- name = nblank
- nblank = typecheck(nblank, ctxExpr|ctxAssign)
- } else {
- name = t.Rlist.First()
- init.Append(nod(ODCL, name, nil))
- a := nod(OAS, name, nil)
- a = typecheck(a, ctxStmt)
- init.Append(a)
- }
-
- a := nod(OAS2, nil, nil)
- a.List.Set2(name, s.okname) // name, ok =
- b := nod(ODOTTYPE, s.facename, nil)
- b.Type = t.Left.Type // interface.(type)
- a.Rlist.Set1(b)
- a = typecheck(a, ctxStmt)
- a = walkexpr(a, &init)
- init.Append(a)
-
- c := nod(OIF, nil, nil)
- c.Left = s.okname
- c.Nbody.Set1(t.Right) // if ok { goto l }
-
- init.Append(c)
- return init.asblock()
+// A typeSwitch walks a type switch.
+type typeSwitch struct {
+ // Temporary variables (i.e., ONAMEs) used by type switch dispatch logic:
+ facename *Node // value being type-switched on
+ hashname *Node // type hash of the value being type-switched on
+ okname *Node // boolean used for comma-ok type assertions
+
+ done Nodes
+ clauses []typeClause
}
-// walkCases generates an AST implementing the cases in cc.
-func (s *typeSwitch) walkCases(cc []caseClause) *Node {
- if len(cc) < binarySearchMin {
- var cas []*Node
- for _, c := range cc {
- n := c.node
- if !c.isconst {
- Fatalf("typeSwitch walkCases")
- }
- a := nod(OIF, nil, nil)
- a.Left = nod(OEQ, s.hashname, nodintconst(int64(c.hash)))
- a.Left = typecheck(a.Left, ctxExpr)
- a.Left = defaultlit(a.Left, nil)
- a.Nbody.Set1(n.Right)
- cas = append(cas, a)
- }
- return liststmt(cas)
- }
-
- // find the middle and recur
- half := len(cc) / 2
- a := nod(OIF, nil, nil)
- a.Left = nod(OLE, s.hashname, nodintconst(int64(cc[half-1].hash)))
- a.Left = typecheck(a.Left, ctxExpr)
- a.Left = defaultlit(a.Left, nil)
- a.Nbody.Set1(s.walkCases(cc[:half]))
- a.Rlist.Set1(s.walkCases(cc[half:]))
- return a
+type typeClause struct {
+ hash uint32
+ body Nodes
}
-// caseClauseByConstVal sorts clauses by constant value to enable binary search.
-type caseClauseByConstVal []caseClause
-
-func (x caseClauseByConstVal) Len() int { return len(x) }
-func (x caseClauseByConstVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
-func (x caseClauseByConstVal) Less(i, j int) bool {
- // n1 and n2 might be individual constants or integer ranges.
- // We have checked for duplicates already,
- // so ranges can be safely represented by any value in the range.
- n1 := x[i].node
- var v1 interface{}
- if s := n1.List.Slice(); s != nil {
- v1 = s[0].Val().U
+func (s *typeSwitch) Add(typ *types.Type, caseVar *Node, jmp *Node) {
+ var body Nodes
+ if caseVar != nil {
+ l := []*Node{
+ nod(ODCL, caseVar, nil),
+ nod(OAS, caseVar, nil),
+ }
+ typecheckslice(l, ctxStmt)
+ body.Append(l...)
} else {
- v1 = n1.Left.Val().U
+ caseVar = nblank
+ }
+
+ // cv, ok = iface.(type)
+ as := nod(OAS2, nil, nil)
+ as.List.Set2(caseVar, s.okname) // cv, ok =
+ dot := nod(ODOTTYPE, s.facename, nil)
+ dot.Type = typ // iface.(type)
+ as.Rlist.Set1(dot)
+ as = typecheck(as, ctxStmt)
+ as = walkexpr(as, &body)
+ body.Append(as)
+
+ // if ok { goto label }
+ nif := nod(OIF, nil, nil)
+ nif.Left = s.okname
+ nif.Nbody.Set1(jmp)
+ body.Append(nif)
+
+ if !typ.IsInterface() {
+ s.clauses = append(s.clauses, typeClause{
+ hash: typehash(typ),
+ body: body,
+ })
+ return
}
- n2 := x[j].node
- var v2 interface{}
- if s := n2.List.Slice(); s != nil {
- v2 = s[0].Val().U
- } else {
- v2 = n2.Left.Val().U
- }
-
- switch v1 := v1.(type) {
- case *Mpflt:
- return v1.Cmp(v2.(*Mpflt)) < 0
- case *Mpint:
- return v1.Cmp(v2.(*Mpint)) < 0
- case string:
- // Sort strings by length and then by value.
- // It is much cheaper to compare lengths than values,
- // and all we need here is consistency.
- // We respect this sorting in exprSwitch.walkCases.
- a := v1
- b := v2.(string)
- if len(a) != len(b) {
- return len(a) < len(b)
- }
- return a < b
- }
+ s.flush()
+ s.done.AppendNodes(&body)
+}
- Fatalf("caseClauseByConstVal passed bad clauses %v < %v", x[i].node.Left, x[j].node.Left)
- return false
+func (s *typeSwitch) Emit(out *Nodes) {
+ s.flush()
+ out.AppendNodes(&s.done)
}
-type caseClauseByType []caseClause
+func (s *typeSwitch) flush() {
+ cc := s.clauses
+ s.clauses = nil
+ if len(cc) == 0 {
+ return
+ }
+
+ sort.Slice(cc, func(i, j int) bool { return cc[i].hash < cc[j].hash })
-func (x caseClauseByType) Len() int { return len(x) }
-func (x caseClauseByType) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
-func (x caseClauseByType) Less(i, j int) bool {
- c1, c2 := x[i], x[j]
- // sort by hash code, then ordinal (for the rare case of hash collisions)
- if c1.hash != c2.hash {
- return c1.hash < c2.hash
+ // Combine adjacent cases with the same hash.
+ merged := cc[:1]
+ for _, c := range cc[1:] {
+ last := &merged[len(merged)-1]
+ if last.hash == c.hash {
+ last.body.AppendNodes(&c.body)
+ } else {
+ merged = append(merged, c)
+ }
}
- return c1.ordinal < c2.ordinal
+ cc = merged
+
+ binarySearch(len(cc), &s.done,
+ func(i int) *Node {
+ return nod(OLE, s.hashname, nodintconst(int64(cc[i-1].hash)))
+ },
+ func(i int, out *Nodes) {
+ // TODO(mdempsky): Omit hash equality check if
+ // there's only one type.
+ c := cc[i]
+ a := nod(OIF, nil, nil)
+ a.Left = nod(OEQ, s.hashname, nodintconst(int64(c.hash)))
+ a.Left = typecheck(a.Left, ctxExpr)
+ a.Left = defaultlit(a.Left, nil)
+ a.Nbody.AppendNodes(&c.body)
+ out.Append(a)
+ },
+ )
}
-type constIntNodesByVal []*Node
+// binarySearch constructs a binary search tree for handling n cases,
+// and appends it to out. It's used for efficiently implementing
+// switch statements.
+//
+// less(i) should return a boolean expression. If it evaluates true,
+// then cases [0, i) will be tested; otherwise, cases [i, n).
+//
+// base(i, out) should append statements to out to test the i'th case.
+func binarySearch(n int, out *Nodes, less func(i int) *Node, base func(i int, out *Nodes)) {
+ const binarySearchMin = 4 // minimum number of cases for binary search
+
+ var do func(lo, hi int, out *Nodes)
+ do = func(lo, hi int, out *Nodes) {
+ n := hi - lo
+ if n < binarySearchMin {
+ for i := lo; i < hi; i++ {
+ base(i, out)
+ }
+ return
+ }
+
+ half := lo + n/2
+ nif := nod(OIF, nil, nil)
+ nif.Left = less(half)
+ nif.Left = typecheck(nif.Left, ctxExpr)
+ nif.Left = defaultlit(nif.Left, nil)
+ do(lo, half, &nif.Nbody)
+ do(half, hi, &nif.Rlist)
+ out.Append(nif)
+ }
-func (x constIntNodesByVal) Len() int { return len(x) }
-func (x constIntNodesByVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
-func (x constIntNodesByVal) Less(i, j int) bool {
- return x[i].Val().U.(*Mpint).Cmp(x[j].Val().U.(*Mpint)) < 0
+ do(0, n, out)
}