]> Cypherpunks repositories - gostls13.git/commitdiff
[dev.regabi] cmd/compile: separate CommStmt from CaseStmt
authorMatthew Dempsky <mdempsky@google.com>
Sun, 27 Dec 2020 06:42:17 +0000 (22:42 -0800)
committerMatthew Dempsky <mdempsky@google.com>
Mon, 28 Dec 2020 08:01:27 +0000 (08:01 +0000)
Like go/ast and cmd/compile/internal/syntax before it, package ir now
has separate concrete representations for switch-case clauses and
select-communication clauses.

Passes toolstash -cmp.

Change-Id: I32667cbae251fe7881be0f434388478433b2414f
Reviewed-on: https://go-review.googlesource.com/c/go/+/280443
Trust: Matthew Dempsky <mdempsky@google.com>
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Cuong Manh Le <cuong.manhle.vn@gmail.com>
src/cmd/compile/internal/ir/mknode.go
src/cmd/compile/internal/ir/node_gen.go
src/cmd/compile/internal/ir/stmt.go
src/cmd/compile/internal/noder/noder.go
src/cmd/compile/internal/typecheck/iexport.go
src/cmd/compile/internal/typecheck/iimport.go
src/cmd/compile/internal/walk/select.go

index edf3ee501c1ae78fcb29d948e0c4b309552653c3..bc6fa3cd305dfd3c9db2747e864d0669b0ae38ef 100644 (file)
@@ -38,6 +38,7 @@ func main() {
        ntypeType := lookup("Ntype")
        nodesType := lookup("Nodes")
        slicePtrCaseStmtType := types.NewSlice(types.NewPointer(lookup("CaseStmt")))
+       slicePtrCommStmtType := types.NewSlice(types.NewPointer(lookup("CommStmt")))
        ptrFieldType := types.NewPointer(lookup("Field"))
        slicePtrFieldType := types.NewSlice(ptrFieldType)
        ptrIdentType := types.NewPointer(lookup("Ident"))
@@ -79,6 +80,8 @@ func main() {
                                        fmt.Fprintf(&buf, "c.%s = c.%s.Copy()\n", name, name)
                                case is(slicePtrCaseStmtType):
                                        fmt.Fprintf(&buf, "c.%s = copyCases(c.%s)\n", name, name)
+                               case is(slicePtrCommStmtType):
+                                       fmt.Fprintf(&buf, "c.%s = copyComms(c.%s)\n", name, name)
                                case is(ptrFieldType):
                                        fmt.Fprintf(&buf, "if c.%s != nil { c.%s = c.%s.copy() }\n", name, name, name)
                                case is(slicePtrFieldType):
@@ -99,6 +102,8 @@ func main() {
                                fmt.Fprintf(&buf, "err = maybeDoList(n.%s, err, do)\n", name)
                        case is(slicePtrCaseStmtType):
                                fmt.Fprintf(&buf, "err = maybeDoCases(n.%s, err, do)\n", name)
+                       case is(slicePtrCommStmtType):
+                               fmt.Fprintf(&buf, "err = maybeDoComms(n.%s, err, do)\n", name)
                        case is(ptrFieldType):
                                fmt.Fprintf(&buf, "err = maybeDoField(n.%s, err, do)\n", name)
                        case is(slicePtrFieldType):
@@ -120,6 +125,8 @@ func main() {
                                fmt.Fprintf(&buf, "editList(n.%s, edit)\n", name)
                        case is(slicePtrCaseStmtType):
                                fmt.Fprintf(&buf, "editCases(n.%s, edit)\n", name)
+                       case is(slicePtrCommStmtType):
+                               fmt.Fprintf(&buf, "editComms(n.%s, edit)\n", name)
                        case is(ptrFieldType):
                                fmt.Fprintf(&buf, "editField(n.%s, edit)\n", name)
                        case is(slicePtrFieldType):
index 041855bbe93044b8b4f647a39f47aebcc1aa2583..5796544b4843fe9d4e955c2f7318d061f462e6db 100644 (file)
@@ -239,7 +239,6 @@ func (n *CaseStmt) doChildren(do func(Node) error) error {
        err = maybeDoList(n.init, err, do)
        err = maybeDo(n.Var, err, do)
        err = maybeDoList(n.List, err, do)
-       err = maybeDo(n.Comm, err, do)
        err = maybeDoList(n.Body, err, do)
        return err
 }
@@ -247,7 +246,6 @@ func (n *CaseStmt) editChildren(edit func(Node) Node) {
        editList(n.init, edit)
        n.Var = maybeEdit(n.Var, edit)
        editList(n.List, edit)
-       n.Comm = maybeEdit(n.Comm, edit)
        editList(n.Body, edit)
 }
 
@@ -295,6 +293,29 @@ func (n *ClosureReadExpr) editChildren(edit func(Node) Node) {
        editList(n.init, edit)
 }
 
+func (n *CommStmt) Format(s fmt.State, verb rune) { FmtNode(n, s, verb) }
+func (n *CommStmt) copy() Node {
+       c := *n
+       c.init = c.init.Copy()
+       c.List = c.List.Copy()
+       c.Body = c.Body.Copy()
+       return &c
+}
+func (n *CommStmt) doChildren(do func(Node) error) error {
+       var err error
+       err = maybeDoList(n.init, err, do)
+       err = maybeDoList(n.List, err, do)
+       err = maybeDo(n.Comm, err, do)
+       err = maybeDoList(n.Body, err, do)
+       return err
+}
+func (n *CommStmt) editChildren(edit func(Node) Node) {
+       editList(n.init, edit)
+       editList(n.List, edit)
+       n.Comm = maybeEdit(n.Comm, edit)
+       editList(n.Body, edit)
+}
+
 func (n *CompLitExpr) Format(s fmt.State, verb rune) { FmtNode(n, s, verb) }
 func (n *CompLitExpr) copy() Node {
        c := *n
@@ -781,20 +802,20 @@ func (n *SelectStmt) Format(s fmt.State, verb rune) { FmtNode(n, s, verb) }
 func (n *SelectStmt) copy() Node {
        c := *n
        c.init = c.init.Copy()
-       c.Cases = copyCases(c.Cases)
+       c.Cases = copyComms(c.Cases)
        c.Compiled = c.Compiled.Copy()
        return &c
 }
 func (n *SelectStmt) doChildren(do func(Node) error) error {
        var err error
        err = maybeDoList(n.init, err, do)
-       err = maybeDoCases(n.Cases, err, do)
+       err = maybeDoComms(n.Cases, err, do)
        err = maybeDoList(n.Compiled, err, do)
        return err
 }
 func (n *SelectStmt) editChildren(edit func(Node) Node) {
        editList(n.init, edit)
-       editCases(n.Cases, edit)
+       editComms(n.Cases, edit)
        editList(n.Compiled, edit)
 }
 
index ce775a8529ef36497591365c9d78d5742677f622..181a0fd582a911c4a6eb5d3df273b6f4777bf56b 100644 (file)
@@ -178,19 +178,17 @@ type CaseStmt struct {
        miniStmt
        Var  Node  // declared variable for this case in type switch
        List Nodes // list of expressions for switch, early select
-       Comm Node  // communication case (Exprs[0]) after select is type-checked
        Body Nodes
 }
 
 func NewCaseStmt(pos src.XPos, list, body []Node) *CaseStmt {
-       n := &CaseStmt{}
+       n := &CaseStmt{List: list, Body: body}
        n.pos = pos
        n.op = OCASE
-       n.List.Set(list)
-       n.Body.Set(body)
        return n
 }
 
+// TODO(mdempsky): Generate these with mknode.go.
 func copyCases(list []*CaseStmt) []*CaseStmt {
        if list == nil {
                return nil
@@ -199,7 +197,6 @@ func copyCases(list []*CaseStmt) []*CaseStmt {
        copy(c, list)
        return c
 }
-
 func maybeDoCases(list []*CaseStmt, err error, do func(Node) error) error {
        if err != nil {
                return err
@@ -213,7 +210,6 @@ func maybeDoCases(list []*CaseStmt, err error, do func(Node) error) error {
        }
        return nil
 }
-
 func editCases(list []*CaseStmt, edit func(Node) Node) {
        for i, x := range list {
                if x != nil {
@@ -222,6 +218,50 @@ func editCases(list []*CaseStmt, edit func(Node) Node) {
        }
 }
 
+type CommStmt struct {
+       miniStmt
+       List Nodes // list of expressions for switch, early select
+       Comm Node  // communication case (Exprs[0]) after select is type-checked
+       Body Nodes
+}
+
+func NewCommStmt(pos src.XPos, list, body []Node) *CommStmt {
+       n := &CommStmt{List: list, Body: body}
+       n.pos = pos
+       n.op = OCASE
+       return n
+}
+
+// TODO(mdempsky): Generate these with mknode.go.
+func copyComms(list []*CommStmt) []*CommStmt {
+       if list == nil {
+               return nil
+       }
+       c := make([]*CommStmt, len(list))
+       copy(c, list)
+       return c
+}
+func maybeDoComms(list []*CommStmt, err error, do func(Node) error) error {
+       if err != nil {
+               return err
+       }
+       for _, x := range list {
+               if x != nil {
+                       if err := do(x); err != nil {
+                               return err
+                       }
+               }
+       }
+       return nil
+}
+func editComms(list []*CommStmt, edit func(Node) Node) {
+       for i, x := range list {
+               if x != nil {
+                       list[i] = edit(x).(*CommStmt)
+               }
+       }
+}
+
 // A ForStmt is a non-range for loop: for Init; Cond; Post { Body }
 // Op can be OFOR or OFORUNTIL (!Cond).
 type ForStmt struct {
@@ -365,18 +405,17 @@ func (n *ReturnStmt) SetOrig(x Node) { n.orig = x }
 type SelectStmt struct {
        miniStmt
        Label    *types.Sym
-       Cases    []*CaseStmt
+       Cases    []*CommStmt
        HasBreak bool
 
        // TODO(rsc): Instead of recording here, replace with a block?
        Compiled Nodes // compiled form, after walkswitch
 }
 
-func NewSelectStmt(pos src.XPos, cases []*CaseStmt) *SelectStmt {
-       n := &SelectStmt{}
+func NewSelectStmt(pos src.XPos, cases []*CommStmt) *SelectStmt {
+       n := &SelectStmt{Cases: cases}
        n.pos = pos
        n.op = OSELECT
-       n.Cases = cases
        return n
 }
 
@@ -407,10 +446,9 @@ type SwitchStmt struct {
 }
 
 func NewSwitchStmt(pos src.XPos, tag Node, cases []*CaseStmt) *SwitchStmt {
-       n := &SwitchStmt{Tag: tag}
+       n := &SwitchStmt{Tag: tag, Cases: cases}
        n.pos = pos
        n.op = OSWITCH
-       n.Cases = cases
        return n
 }
 
index b974448338f3add4d399b0a162ef7fc5fde48d6c..ff699cd54d4a605af1ebbace036ee277112885c7 100644 (file)
@@ -1266,8 +1266,8 @@ func (p *noder) simpleStmt(stmt syntax.SimpleStmt) []ir.Node {
        return []ir.Node{p.stmt(stmt)}
 }
 
-func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []*ir.CaseStmt {
-       nodes := make([]*ir.CaseStmt, len(clauses))
+func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []*ir.CommStmt {
+       nodes := make([]*ir.CommStmt, len(clauses))
        for i, clause := range clauses {
                p.setlineno(clause)
                if i > 0 {
@@ -1275,7 +1275,7 @@ func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []*
                }
                p.openScope(clause.Pos())
 
-               nodes[i] = ir.NewCaseStmt(p.pos(clause), p.simpleStmt(clause.Comm), p.stmts(clause.Body))
+               nodes[i] = ir.NewCommStmt(p.pos(clause), p.simpleStmt(clause.Comm), p.stmts(clause.Body))
        }
        if len(clauses) > 0 {
                p.closeScope(rbrace)
index 19437a069e11f3cb685e7fceb5cb5954c67f2747..ef2c4527a934ea176e148422bc708ca4886bf96a 100644 (file)
@@ -1144,7 +1144,7 @@ func (w *exportWriter) stmt(n ir.Node) {
                w.op(n.Op())
                w.pos(n.Pos())
                w.stmtList(n.Init())
-               w.caseList(n.Cases, false)
+               w.commList(n.Cases)
 
        case ir.OSWITCH:
                n := n.(*ir.SwitchStmt)
@@ -1193,6 +1193,15 @@ func (w *exportWriter) caseList(cases []*ir.CaseStmt, namedTypeSwitch bool) {
        }
 }
 
+func (w *exportWriter) commList(cases []*ir.CommStmt) {
+       w.uint64(uint64(len(cases)))
+       for _, cas := range cases {
+               w.pos(cas.Pos())
+               w.stmtList(cas.List)
+               w.stmtList(cas.Body)
+       }
+}
+
 func (w *exportWriter) exprList(list ir.Nodes) {
        for _, n := range list {
                w.expr(n)
index fd8314b66217acaedcec1a43bb2ed16f4fed9390..ba7ea2f156afaedae206a74dd49e057eff3a7a94 100644 (file)
@@ -789,6 +789,14 @@ func (r *importReader) caseList(switchExpr ir.Node) []*ir.CaseStmt {
        return cases
 }
 
+func (r *importReader) commList() []*ir.CommStmt {
+       cases := make([]*ir.CommStmt, r.uint64())
+       for i := range cases {
+               cases[i] = ir.NewCommStmt(r.pos(), r.stmtList(), r.stmtList())
+       }
+       return cases
+}
+
 func (r *importReader) exprList() []ir.Node {
        var list []ir.Node
        for {
@@ -1035,7 +1043,7 @@ func (r *importReader) node() ir.Node {
        case ir.OSELECT:
                pos := r.pos()
                init := r.stmtList()
-               n := ir.NewSelectStmt(pos, r.caseList(nil))
+               n := ir.NewSelectStmt(pos, r.commList())
                n.PtrInit().Set(init)
                return n
 
index 0b7e7e99fbfb2cd40786aceacef694d8586320a8..f51684c9b6681716c5bc846dd96dd6a616b81750 100644 (file)
@@ -29,7 +29,7 @@ func walkSelect(sel *ir.SelectStmt) {
        base.Pos = lno
 }
 
-func walkSelectCases(cases []*ir.CaseStmt) []ir.Node {
+func walkSelectCases(cases []*ir.CommStmt) []ir.Node {
        ncas := len(cases)
        sellineno := base.Pos
 
@@ -73,7 +73,7 @@ func walkSelectCases(cases []*ir.CaseStmt) []ir.Node {
 
        // convert case value arguments to addresses.
        // this rewrite is used by both the general code and the next optimization.
-       var dflt *ir.CaseStmt
+       var dflt *ir.CommStmt
        for _, cas := range cases {
                ir.SetPos(cas)
                n := cas.Comm
@@ -146,7 +146,7 @@ func walkSelectCases(cases []*ir.CaseStmt) []ir.Node {
        if dflt != nil {
                ncas--
        }
-       casorder := make([]*ir.CaseStmt, ncas)
+       casorder := make([]*ir.CommStmt, ncas)
        nsends, nrecvs := 0, 0
 
        var init []ir.Node
@@ -242,7 +242,7 @@ func walkSelectCases(cases []*ir.CaseStmt) []ir.Node {
        }
 
        // dispatch cases
-       dispatch := func(cond ir.Node, cas *ir.CaseStmt) {
+       dispatch := func(cond ir.Node, cas *ir.CommStmt) {
                cond = typecheck.Expr(cond)
                cond = typecheck.DefaultLit(cond, nil)