]> Cypherpunks repositories - gostls13.git/commitdiff
[dev.regabi] cmd/compile: use []*CaseStmt in {Select,Switch}Stmt
authorMatthew Dempsky <mdempsky@google.com>
Sun, 27 Dec 2020 06:23:45 +0000 (22:23 -0800)
committerMatthew Dempsky <mdempsky@google.com>
Mon, 28 Dec 2020 07:45:09 +0000 (07:45 +0000)
Select and switch statements only ever contain case statements, so
change their Cases fields from Nodes to []*CaseStmt. This allows
removing a bunch of type assertions throughout the compiler.

CaseStmt should be renamed to CaseClause, and SelectStmt should
probably have its own CommClause type instead (like in go/ast and
cmd/compile/internal/syntax), but this is a good start.

Passes toolstash -cmp.

Change-Id: I2d41d616d44512c2be421e1e2ff13d0ee8b238ad
Reviewed-on: https://go-review.googlesource.com/c/go/+/280442
Trust: Matthew Dempsky <mdempsky@google.com>
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Cuong Manh Le <cuong.manhle.vn@gmail.com>
13 files changed:
src/cmd/compile/internal/escape/escape.go
src/cmd/compile/internal/ir/mknode.go
src/cmd/compile/internal/ir/node_gen.go
src/cmd/compile/internal/ir/stmt.go
src/cmd/compile/internal/ir/visit.go
src/cmd/compile/internal/noder/noder.go
src/cmd/compile/internal/typecheck/iexport.go
src/cmd/compile/internal/typecheck/iimport.go
src/cmd/compile/internal/typecheck/stmt.go
src/cmd/compile/internal/typecheck/typecheck.go
src/cmd/compile/internal/walk/order.go
src/cmd/compile/internal/walk/select.go
src/cmd/compile/internal/walk/switch.go

index 31d157b1651fc7cfd43963495ac2951d4836801c..d8f0111d2de693248263b992c3c29e27aa8821cf 100644 (file)
@@ -369,7 +369,6 @@ func (e *escape) stmt(n ir.Node) {
 
                var ks []hole
                for _, cas := range n.Cases { // cases
-                       cas := cas.(*ir.CaseStmt)
                        if typesw && n.Tag.(*ir.TypeSwitchGuard).Tag != nil {
                                cv := cas.Var
                                k := e.dcl(cv) // type switch variables have no ODCL.
@@ -391,7 +390,6 @@ func (e *escape) stmt(n ir.Node) {
        case ir.OSELECT:
                n := n.(*ir.SelectStmt)
                for _, cas := range n.Cases {
-                       cas := cas.(*ir.CaseStmt)
                        e.stmt(cas.Comm)
                        e.block(cas.Body)
                }
index f5dacee622f9d659d9375b5f219addf2054a239e..edf3ee501c1ae78fcb29d948e0c4b309552653c3 100644 (file)
@@ -37,6 +37,7 @@ func main() {
        nodeType := lookup("Node")
        ntypeType := lookup("Ntype")
        nodesType := lookup("Nodes")
+       slicePtrCaseStmtType := types.NewSlice(types.NewPointer(lookup("CaseStmt")))
        ptrFieldType := types.NewPointer(lookup("Field"))
        slicePtrFieldType := types.NewSlice(ptrFieldType)
        ptrIdentType := types.NewPointer(lookup("Ident"))
@@ -76,6 +77,8 @@ func main() {
                                switch {
                                case is(nodesType):
                                        fmt.Fprintf(&buf, "c.%s = c.%s.Copy()\n", name, name)
+                               case is(slicePtrCaseStmtType):
+                                       fmt.Fprintf(&buf, "c.%s = copyCases(c.%s)\n", name, name)
                                case is(ptrFieldType):
                                        fmt.Fprintf(&buf, "if c.%s != nil { c.%s = c.%s.copy() }\n", name, name, name)
                                case is(slicePtrFieldType):
@@ -94,6 +97,8 @@ func main() {
                                fmt.Fprintf(&buf, "err = maybeDo(n.%s, err, do)\n", name)
                        case is(nodesType):
                                fmt.Fprintf(&buf, "err = maybeDoList(n.%s, err, do)\n", name)
+                       case is(slicePtrCaseStmtType):
+                               fmt.Fprintf(&buf, "err = maybeDoCases(n.%s, err, do)\n", name)
                        case is(ptrFieldType):
                                fmt.Fprintf(&buf, "err = maybeDoField(n.%s, err, do)\n", name)
                        case is(slicePtrFieldType):
@@ -113,6 +118,8 @@ func main() {
                                fmt.Fprintf(&buf, "n.%s = toNtype(maybeEdit(n.%s, edit))\n", name, name)
                        case is(nodesType):
                                fmt.Fprintf(&buf, "editList(n.%s, edit)\n", name)
+                       case is(slicePtrCaseStmtType):
+                               fmt.Fprintf(&buf, "editCases(n.%s, edit)\n", name)
                        case is(ptrFieldType):
                                fmt.Fprintf(&buf, "editField(n.%s, edit)\n", name)
                        case is(slicePtrFieldType):
index ecb39563c46f3f118f295750a79a08f97605206b..041855bbe93044b8b4f647a39f47aebcc1aa2583 100644 (file)
@@ -781,20 +781,20 @@ func (n *SelectStmt) Format(s fmt.State, verb rune) { FmtNode(n, s, verb) }
 func (n *SelectStmt) copy() Node {
        c := *n
        c.init = c.init.Copy()
-       c.Cases = c.Cases.Copy()
+       c.Cases = copyCases(c.Cases)
        c.Compiled = c.Compiled.Copy()
        return &c
 }
 func (n *SelectStmt) doChildren(do func(Node) error) error {
        var err error
        err = maybeDoList(n.init, err, do)
-       err = maybeDoList(n.Cases, err, do)
+       err = maybeDoCases(n.Cases, err, do)
        err = maybeDoList(n.Compiled, err, do)
        return err
 }
 func (n *SelectStmt) editChildren(edit func(Node) Node) {
        editList(n.init, edit)
-       editList(n.Cases, edit)
+       editCases(n.Cases, edit)
        editList(n.Compiled, edit)
 }
 
@@ -945,7 +945,7 @@ func (n *SwitchStmt) Format(s fmt.State, verb rune) { FmtNode(n, s, verb) }
 func (n *SwitchStmt) copy() Node {
        c := *n
        c.init = c.init.Copy()
-       c.Cases = c.Cases.Copy()
+       c.Cases = copyCases(c.Cases)
        c.Compiled = c.Compiled.Copy()
        return &c
 }
@@ -953,14 +953,14 @@ func (n *SwitchStmt) doChildren(do func(Node) error) error {
        var err error
        err = maybeDoList(n.init, err, do)
        err = maybeDo(n.Tag, err, do)
-       err = maybeDoList(n.Cases, err, do)
+       err = maybeDoCases(n.Cases, err, do)
        err = maybeDoList(n.Compiled, err, do)
        return err
 }
 func (n *SwitchStmt) editChildren(edit func(Node) Node) {
        editList(n.init, edit)
        n.Tag = maybeEdit(n.Tag, edit)
-       editList(n.Cases, edit)
+       editCases(n.Cases, edit)
        editList(n.Compiled, edit)
 }
 
index cfda6fd234dabd5ed65d1558880c8a76aaa21da3..ce775a8529ef36497591365c9d78d5742677f622 100644 (file)
@@ -191,6 +191,37 @@ func NewCaseStmt(pos src.XPos, list, body []Node) *CaseStmt {
        return n
 }
 
+func copyCases(list []*CaseStmt) []*CaseStmt {
+       if list == nil {
+               return nil
+       }
+       c := make([]*CaseStmt, len(list))
+       copy(c, list)
+       return c
+}
+
+func maybeDoCases(list []*CaseStmt, err error, do func(Node) error) error {
+       if err != nil {
+               return err
+       }
+       for _, x := range list {
+               if x != nil {
+                       if err := do(x); err != nil {
+                               return err
+                       }
+               }
+       }
+       return nil
+}
+
+func editCases(list []*CaseStmt, edit func(Node) Node) {
+       for i, x := range list {
+               if x != nil {
+                       list[i] = edit(x).(*CaseStmt)
+               }
+       }
+}
+
 // A ForStmt is a non-range for loop: for Init; Cond; Post { Body }
 // Op can be OFOR or OFORUNTIL (!Cond).
 type ForStmt struct {
@@ -334,18 +365,18 @@ func (n *ReturnStmt) SetOrig(x Node) { n.orig = x }
 type SelectStmt struct {
        miniStmt
        Label    *types.Sym
-       Cases    Nodes
+       Cases    []*CaseStmt
        HasBreak bool
 
        // TODO(rsc): Instead of recording here, replace with a block?
        Compiled Nodes // compiled form, after walkswitch
 }
 
-func NewSelectStmt(pos src.XPos, cases []Node) *SelectStmt {
+func NewSelectStmt(pos src.XPos, cases []*CaseStmt) *SelectStmt {
        n := &SelectStmt{}
        n.pos = pos
        n.op = OSELECT
-       n.Cases.Set(cases)
+       n.Cases = cases
        return n
 }
 
@@ -367,7 +398,7 @@ func NewSendStmt(pos src.XPos, ch, value Node) *SendStmt {
 type SwitchStmt struct {
        miniStmt
        Tag      Node
-       Cases    Nodes // list of *CaseStmt
+       Cases    []*CaseStmt
        Label    *types.Sym
        HasBreak bool
 
@@ -375,11 +406,11 @@ type SwitchStmt struct {
        Compiled Nodes // compiled form, after walkswitch
 }
 
-func NewSwitchStmt(pos src.XPos, tag Node, cases []Node) *SwitchStmt {
+func NewSwitchStmt(pos src.XPos, tag Node, cases []*CaseStmt) *SwitchStmt {
        n := &SwitchStmt{Tag: tag}
        n.pos = pos
        n.op = OSWITCH
-       n.Cases.Set(cases)
+       n.Cases = cases
        return n
 }
 
index a1c345968f703acda48cb953c67a045b2156d549..8839e1664d3b581b614fb299ce66b7813811c1f8 100644 (file)
@@ -217,10 +217,9 @@ func EditChildren(n Node, edit func(Node) Node) {
 // Note that editList only calls edit on the nodes in the list, not their children.
 // If x's children should be processed, edit(x) must call EditChildren(x, edit) itself.
 func editList(list Nodes, edit func(Node) Node) {
-       s := list
        for i, x := range list {
                if x != nil {
-                       s[i] = edit(x)
+                       list[i] = edit(x)
                }
        }
 }
index ad66b6c8509ca18ad90a4b4394050709440e8e38..b974448338f3add4d399b0a162ef7fc5fde48d6c 100644 (file)
@@ -1202,14 +1202,14 @@ func (p *noder) switchStmt(stmt *syntax.SwitchStmt) ir.Node {
        if l := n.Tag; l != nil && l.Op() == ir.OTYPESW {
                tswitch = l.(*ir.TypeSwitchGuard)
        }
-       n.Cases.Set(p.caseClauses(stmt.Body, tswitch, stmt.Rbrace))
+       n.Cases = p.caseClauses(stmt.Body, tswitch, stmt.Rbrace)
 
        p.closeScope(stmt.Rbrace)
        return n
 }
 
-func (p *noder) caseClauses(clauses []*syntax.CaseClause, tswitch *ir.TypeSwitchGuard, rbrace syntax.Pos) []ir.Node {
-       nodes := make([]ir.Node, 0, len(clauses))
+func (p *noder) caseClauses(clauses []*syntax.CaseClause, tswitch *ir.TypeSwitchGuard, rbrace syntax.Pos) []*ir.CaseStmt {
+       nodes := make([]*ir.CaseStmt, 0, len(clauses))
        for i, clause := range clauses {
                p.setlineno(clause)
                if i > 0 {
@@ -1266,8 +1266,8 @@ func (p *noder) simpleStmt(stmt syntax.SimpleStmt) []ir.Node {
        return []ir.Node{p.stmt(stmt)}
 }
 
-func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []ir.Node {
-       nodes := make([]ir.Node, len(clauses))
+func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []*ir.CaseStmt {
+       nodes := make([]*ir.CaseStmt, len(clauses))
        for i, clause := range clauses {
                p.setlineno(clause)
                if i > 0 {
index 0c813a71ef41a160e8159137f664f283a469caa1..19437a069e11f3cb685e7fceb5cb5954c67f2747 100644 (file)
@@ -1181,10 +1181,9 @@ func isNamedTypeSwitch(x ir.Node) bool {
        return ok && guard.Tag != nil
 }
 
-func (w *exportWriter) caseList(cases []ir.Node, namedTypeSwitch bool) {
+func (w *exportWriter) caseList(cases []*ir.CaseStmt, namedTypeSwitch bool) {
        w.uint64(uint64(len(cases)))
        for _, cas := range cases {
-               cas := cas.(*ir.CaseStmt)
                w.pos(cas.Pos())
                w.stmtList(cas.List)
                if namedTypeSwitch {
index 8285c418e9f8a820e5f867d6b5205b5b5965a1b9..fd8314b66217acaedcec1a43bb2ed16f4fed9390 100644 (file)
@@ -767,10 +767,10 @@ func (r *importReader) stmtList() []ir.Node {
        return list
 }
 
-func (r *importReader) caseList(switchExpr ir.Node) []ir.Node {
+func (r *importReader) caseList(switchExpr ir.Node) []*ir.CaseStmt {
        namedTypeSwitch := isNamedTypeSwitch(switchExpr)
 
-       cases := make([]ir.Node, r.uint64())
+       cases := make([]*ir.CaseStmt, r.uint64())
        for i := range cases {
                cas := ir.NewCaseStmt(r.pos(), nil, nil)
                cas.List.Set(r.stmtList())
index 7e74b730bc117a56c34848a7f483c86c5c921e0d..03c3e399eb456eaf511cf341801ec223f9104fdb 100644 (file)
@@ -364,8 +364,6 @@ func tcSelect(sel *ir.SelectStmt) {
        lno := ir.SetPos(sel)
        Stmts(sel.Init())
        for _, ncase := range sel.Cases {
-               ncase := ncase.(*ir.CaseStmt)
-
                if len(ncase.List) == 0 {
                        // default
                        if def != nil {
@@ -508,7 +506,6 @@ func tcSwitchExpr(n *ir.SwitchStmt) {
        var defCase ir.Node
        var cs constSet
        for _, ncase := range n.Cases {
-               ncase := ncase.(*ir.CaseStmt)
                ls := ncase.List
                if len(ls) == 0 { // default:
                        if defCase != nil {
@@ -577,7 +574,6 @@ func tcSwitchType(n *ir.SwitchStmt) {
        var defCase, nilCase ir.Node
        var ts typeSet
        for _, ncase := range n.Cases {
-               ncase := ncase.(*ir.CaseStmt)
                ls := ncase.List
                if len(ls) == 0 { // default:
                        if defCase != nil {
index b779f9ceb0f1403bbca22700f8742e624c21f610..dabfee3bf9ca7dbcf8355dca01bf41f70dcd3cf3 100644 (file)
@@ -2103,7 +2103,6 @@ func isTermNode(n ir.Node) bool {
                }
                def := false
                for _, cas := range n.Cases {
-                       cas := cas.(*ir.CaseStmt)
                        if !isTermNodes(cas.Body) {
                                return false
                        }
@@ -2119,7 +2118,6 @@ func isTermNode(n ir.Node) bool {
                        return false
                }
                for _, cas := range n.Cases {
-                       cas := cas.(*ir.CaseStmt)
                        if !isTermNodes(cas.Body) {
                                return false
                        }
@@ -2218,9 +2216,6 @@ func deadcodeslice(nn *ir.Nodes) {
                case ir.OBLOCK:
                        n := n.(*ir.BlockStmt)
                        deadcodeslice(&n.List)
-               case ir.OCASE:
-                       n := n.(*ir.CaseStmt)
-                       deadcodeslice(&n.Body)
                case ir.OFOR:
                        n := n.(*ir.ForStmt)
                        deadcodeslice(&n.Body)
@@ -2233,10 +2228,14 @@ func deadcodeslice(nn *ir.Nodes) {
                        deadcodeslice(&n.Body)
                case ir.OSELECT:
                        n := n.(*ir.SelectStmt)
-                       deadcodeslice(&n.Cases)
+                       for _, cas := range n.Cases {
+                               deadcodeslice(&cas.Body)
+                       }
                case ir.OSWITCH:
                        n := n.(*ir.SwitchStmt)
-                       deadcodeslice(&n.Cases)
+                       for _, cas := range n.Cases {
+                               deadcodeslice(&cas.Body)
+                       }
                }
 
                if cut {
index 1e41cfc6aaf9ab1dd95255a585cfc0bc3136addd..ebbd4675701451a9c5f603e7556032542c04cee7 100644 (file)
@@ -914,7 +914,6 @@ func (o *orderState) stmt(n ir.Node) {
                n := n.(*ir.SelectStmt)
                t := o.markTemp()
                for _, ncas := range n.Cases {
-                       ncas := ncas.(*ir.CaseStmt)
                        r := ncas.Comm
                        ir.SetPos(ncas)
 
@@ -996,7 +995,6 @@ func (o *orderState) stmt(n ir.Node) {
                // Also insert any ninit queued during the previous loop.
                // (The temporary cleaning must follow that ninit work.)
                for _, cas := range n.Cases {
-                       cas := cas.(*ir.CaseStmt)
                        orderBlock(&cas.Body, o.free)
                        cas.Body.Prepend(o.cleanTempNoPop(t)...)
 
@@ -1036,13 +1034,12 @@ func (o *orderState) stmt(n ir.Node) {
                n := n.(*ir.SwitchStmt)
                if base.Debug.Libfuzzer != 0 && !hasDefaultCase(n) {
                        // Add empty "default:" case for instrumentation.
-                       n.Cases.Append(ir.NewCaseStmt(base.Pos, nil, nil))
+                       n.Cases = append(n.Cases, ir.NewCaseStmt(base.Pos, nil, nil))
                }
 
                t := o.markTemp()
                n.Tag = o.expr(n.Tag, nil)
                for _, ncas := range n.Cases {
-                       ncas := ncas.(*ir.CaseStmt)
                        o.exprListInPlace(ncas.List)
                        orderBlock(&ncas.Body, o.free)
                }
@@ -1056,7 +1053,6 @@ func (o *orderState) stmt(n ir.Node) {
 
 func hasDefaultCase(n *ir.SwitchStmt) bool {
        for _, ncas := range n.Cases {
-               ncas := ncas.(*ir.CaseStmt)
                if len(ncas.List) == 0 {
                        return true
                }
index 5e03732169f6680f677b2e28983d68b9f1492bc7..0b7e7e99fbfb2cd40786aceacef694d8586320a8 100644 (file)
@@ -21,7 +21,7 @@ func walkSelect(sel *ir.SelectStmt) {
        sel.PtrInit().Set(nil)
 
        init = append(init, walkSelectCases(sel.Cases)...)
-       sel.Cases = ir.Nodes{}
+       sel.Cases = nil
 
        sel.Compiled.Set(init)
        walkStmtList(sel.Compiled)
@@ -29,7 +29,7 @@ func walkSelect(sel *ir.SelectStmt) {
        base.Pos = lno
 }
 
-func walkSelectCases(cases ir.Nodes) []ir.Node {
+func walkSelectCases(cases []*ir.CaseStmt) []ir.Node {
        ncas := len(cases)
        sellineno := base.Pos
 
@@ -40,7 +40,7 @@ func walkSelectCases(cases ir.Nodes) []ir.Node {
 
        // optimization: one-case select: single op.
        if ncas == 1 {
-               cas := cases[0].(*ir.CaseStmt)
+               cas := cases[0]
                ir.SetPos(cas)
                l := cas.Init()
                if cas.Comm != nil { // not default:
@@ -75,7 +75,6 @@ func walkSelectCases(cases ir.Nodes) []ir.Node {
        // this rewrite is used by both the general code and the next optimization.
        var dflt *ir.CaseStmt
        for _, cas := range cases {
-               cas := cas.(*ir.CaseStmt)
                ir.SetPos(cas)
                n := cas.Comm
                if n == nil {
@@ -99,9 +98,9 @@ func walkSelectCases(cases ir.Nodes) []ir.Node {
 
        // optimization: two-case select but one is default: single non-blocking op.
        if ncas == 2 && dflt != nil {
-               cas := cases[0].(*ir.CaseStmt)
+               cas := cases[0]
                if cas == dflt {
-                       cas = cases[1].(*ir.CaseStmt)
+                       cas = cases[1]
                }
 
                n := cas.Comm
@@ -170,7 +169,6 @@ func walkSelectCases(cases ir.Nodes) []ir.Node {
 
        // register cases
        for _, cas := range cases {
-               cas := cas.(*ir.CaseStmt)
                ir.SetPos(cas)
 
                init = append(init, cas.Init()...)
index 141d2e5e053ffb772d7aab5cbceae61d1729e870..de0b471b34e490366dc83a80a5fbf20e95b4195f 100644 (file)
@@ -71,7 +71,6 @@ func walkSwitchExpr(sw *ir.SwitchStmt) {
        var defaultGoto ir.Node
        var body ir.Nodes
        for _, ncase := range sw.Cases {
-               ncase := ncase.(*ir.CaseStmt)
                label := typecheck.AutoLabel(".s")
                jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, label)
 
@@ -96,7 +95,7 @@ func walkSwitchExpr(sw *ir.SwitchStmt) {
                        body.Append(br)
                }
        }
-       sw.Cases.Set(nil)
+       sw.Cases = nil
 
        if defaultGoto == nil {
                br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
@@ -259,7 +258,6 @@ func allCaseExprsAreSideEffectFree(sw *ir.SwitchStmt) bool {
        // enough.
 
        for _, ncase := range sw.Cases {
-               ncase := ncase.(*ir.CaseStmt)
                for _, v := range ncase.List {
                        if v.Op() != ir.OLITERAL {
                                return false
@@ -325,7 +323,6 @@ func walkSwitchType(sw *ir.SwitchStmt) {
        var defaultGoto, nilGoto ir.Node
        var body ir.Nodes
        for _, ncase := range sw.Cases {
-               ncase := ncase.(*ir.CaseStmt)
                caseVar := ncase.Var
 
                // For single-type cases with an interface type,
@@ -384,7 +381,7 @@ func walkSwitchType(sw *ir.SwitchStmt) {
                body.Append(ncase.Body...)
                body.Append(br)
        }
-       sw.Cases.Set(nil)
+       sw.Cases = nil
 
        if defaultGoto == nil {
                defaultGoto = br