From 2ecf52b841cd48e76df1fe721d29a972c22bf93f Mon Sep 17 00:00:00 2001 From: Matthew Dempsky Date: Sat, 26 Dec 2020 22:42:17 -0800 Subject: [PATCH] [dev.regabi] cmd/compile: separate CommStmt from CaseStmt Like go/ast and cmd/compile/internal/syntax before it, package ir now has separate concrete representations for switch-case clauses and select-communication clauses. Passes toolstash -cmp. Change-Id: I32667cbae251fe7881be0f434388478433b2414f Reviewed-on: https://go-review.googlesource.com/c/go/+/280443 Trust: Matthew Dempsky Run-TryBot: Matthew Dempsky TryBot-Result: Go Bot Reviewed-by: Cuong Manh Le --- src/cmd/compile/internal/ir/mknode.go | 7 +++ src/cmd/compile/internal/ir/node_gen.go | 31 ++++++++-- src/cmd/compile/internal/ir/stmt.go | 62 +++++++++++++++---- src/cmd/compile/internal/noder/noder.go | 6 +- src/cmd/compile/internal/typecheck/iexport.go | 11 +++- src/cmd/compile/internal/typecheck/iimport.go | 10 ++- src/cmd/compile/internal/walk/select.go | 8 +-- 7 files changed, 109 insertions(+), 26 deletions(-) diff --git a/src/cmd/compile/internal/ir/mknode.go b/src/cmd/compile/internal/ir/mknode.go index edf3ee501c..bc6fa3cd30 100644 --- a/src/cmd/compile/internal/ir/mknode.go +++ b/src/cmd/compile/internal/ir/mknode.go @@ -38,6 +38,7 @@ func main() { ntypeType := lookup("Ntype") nodesType := lookup("Nodes") slicePtrCaseStmtType := types.NewSlice(types.NewPointer(lookup("CaseStmt"))) + slicePtrCommStmtType := types.NewSlice(types.NewPointer(lookup("CommStmt"))) ptrFieldType := types.NewPointer(lookup("Field")) slicePtrFieldType := types.NewSlice(ptrFieldType) ptrIdentType := types.NewPointer(lookup("Ident")) @@ -79,6 +80,8 @@ func main() { fmt.Fprintf(&buf, "c.%s = c.%s.Copy()\n", name, name) case is(slicePtrCaseStmtType): fmt.Fprintf(&buf, "c.%s = copyCases(c.%s)\n", name, name) + case is(slicePtrCommStmtType): + fmt.Fprintf(&buf, "c.%s = copyComms(c.%s)\n", name, name) case is(ptrFieldType): fmt.Fprintf(&buf, "if c.%s != nil { c.%s = c.%s.copy() }\n", name, name, name) case is(slicePtrFieldType): @@ -99,6 +102,8 @@ func main() { fmt.Fprintf(&buf, "err = maybeDoList(n.%s, err, do)\n", name) case is(slicePtrCaseStmtType): fmt.Fprintf(&buf, "err = maybeDoCases(n.%s, err, do)\n", name) + case is(slicePtrCommStmtType): + fmt.Fprintf(&buf, "err = maybeDoComms(n.%s, err, do)\n", name) case is(ptrFieldType): fmt.Fprintf(&buf, "err = maybeDoField(n.%s, err, do)\n", name) case is(slicePtrFieldType): @@ -120,6 +125,8 @@ func main() { fmt.Fprintf(&buf, "editList(n.%s, edit)\n", name) case is(slicePtrCaseStmtType): fmt.Fprintf(&buf, "editCases(n.%s, edit)\n", name) + case is(slicePtrCommStmtType): + fmt.Fprintf(&buf, "editComms(n.%s, edit)\n", name) case is(ptrFieldType): fmt.Fprintf(&buf, "editField(n.%s, edit)\n", name) case is(slicePtrFieldType): diff --git a/src/cmd/compile/internal/ir/node_gen.go b/src/cmd/compile/internal/ir/node_gen.go index 041855bbe9..5796544b48 100644 --- a/src/cmd/compile/internal/ir/node_gen.go +++ b/src/cmd/compile/internal/ir/node_gen.go @@ -239,7 +239,6 @@ func (n *CaseStmt) doChildren(do func(Node) error) error { err = maybeDoList(n.init, err, do) err = maybeDo(n.Var, err, do) err = maybeDoList(n.List, err, do) - err = maybeDo(n.Comm, err, do) err = maybeDoList(n.Body, err, do) return err } @@ -247,7 +246,6 @@ func (n *CaseStmt) editChildren(edit func(Node) Node) { editList(n.init, edit) n.Var = maybeEdit(n.Var, edit) editList(n.List, edit) - n.Comm = maybeEdit(n.Comm, edit) editList(n.Body, edit) } @@ -295,6 +293,29 @@ func (n *ClosureReadExpr) editChildren(edit func(Node) Node) { editList(n.init, edit) } +func (n *CommStmt) Format(s fmt.State, verb rune) { FmtNode(n, s, verb) } +func (n *CommStmt) copy() Node { + c := *n + c.init = c.init.Copy() + c.List = c.List.Copy() + c.Body = c.Body.Copy() + return &c +} +func (n *CommStmt) doChildren(do func(Node) error) error { + var err error + err = maybeDoList(n.init, err, do) + err = maybeDoList(n.List, err, do) + err = maybeDo(n.Comm, err, do) + err = maybeDoList(n.Body, err, do) + return err +} +func (n *CommStmt) editChildren(edit func(Node) Node) { + editList(n.init, edit) + editList(n.List, edit) + n.Comm = maybeEdit(n.Comm, edit) + editList(n.Body, edit) +} + func (n *CompLitExpr) Format(s fmt.State, verb rune) { FmtNode(n, s, verb) } func (n *CompLitExpr) copy() Node { c := *n @@ -781,20 +802,20 @@ func (n *SelectStmt) Format(s fmt.State, verb rune) { FmtNode(n, s, verb) } func (n *SelectStmt) copy() Node { c := *n c.init = c.init.Copy() - c.Cases = copyCases(c.Cases) + c.Cases = copyComms(c.Cases) c.Compiled = c.Compiled.Copy() return &c } func (n *SelectStmt) doChildren(do func(Node) error) error { var err error err = maybeDoList(n.init, err, do) - err = maybeDoCases(n.Cases, err, do) + err = maybeDoComms(n.Cases, err, do) err = maybeDoList(n.Compiled, err, do) return err } func (n *SelectStmt) editChildren(edit func(Node) Node) { editList(n.init, edit) - editCases(n.Cases, edit) + editComms(n.Cases, edit) editList(n.Compiled, edit) } diff --git a/src/cmd/compile/internal/ir/stmt.go b/src/cmd/compile/internal/ir/stmt.go index ce775a8529..181a0fd582 100644 --- a/src/cmd/compile/internal/ir/stmt.go +++ b/src/cmd/compile/internal/ir/stmt.go @@ -178,19 +178,17 @@ type CaseStmt struct { miniStmt Var Node // declared variable for this case in type switch List Nodes // list of expressions for switch, early select - Comm Node // communication case (Exprs[0]) after select is type-checked Body Nodes } func NewCaseStmt(pos src.XPos, list, body []Node) *CaseStmt { - n := &CaseStmt{} + n := &CaseStmt{List: list, Body: body} n.pos = pos n.op = OCASE - n.List.Set(list) - n.Body.Set(body) return n } +// TODO(mdempsky): Generate these with mknode.go. func copyCases(list []*CaseStmt) []*CaseStmt { if list == nil { return nil @@ -199,7 +197,6 @@ func copyCases(list []*CaseStmt) []*CaseStmt { copy(c, list) return c } - func maybeDoCases(list []*CaseStmt, err error, do func(Node) error) error { if err != nil { return err @@ -213,7 +210,6 @@ func maybeDoCases(list []*CaseStmt, err error, do func(Node) error) error { } return nil } - func editCases(list []*CaseStmt, edit func(Node) Node) { for i, x := range list { if x != nil { @@ -222,6 +218,50 @@ func editCases(list []*CaseStmt, edit func(Node) Node) { } } +type CommStmt struct { + miniStmt + List Nodes // list of expressions for switch, early select + Comm Node // communication case (Exprs[0]) after select is type-checked + Body Nodes +} + +func NewCommStmt(pos src.XPos, list, body []Node) *CommStmt { + n := &CommStmt{List: list, Body: body} + n.pos = pos + n.op = OCASE + return n +} + +// TODO(mdempsky): Generate these with mknode.go. +func copyComms(list []*CommStmt) []*CommStmt { + if list == nil { + return nil + } + c := make([]*CommStmt, len(list)) + copy(c, list) + return c +} +func maybeDoComms(list []*CommStmt, err error, do func(Node) error) error { + if err != nil { + return err + } + for _, x := range list { + if x != nil { + if err := do(x); err != nil { + return err + } + } + } + return nil +} +func editComms(list []*CommStmt, edit func(Node) Node) { + for i, x := range list { + if x != nil { + list[i] = edit(x).(*CommStmt) + } + } +} + // A ForStmt is a non-range for loop: for Init; Cond; Post { Body } // Op can be OFOR or OFORUNTIL (!Cond). type ForStmt struct { @@ -365,18 +405,17 @@ func (n *ReturnStmt) SetOrig(x Node) { n.orig = x } type SelectStmt struct { miniStmt Label *types.Sym - Cases []*CaseStmt + Cases []*CommStmt HasBreak bool // TODO(rsc): Instead of recording here, replace with a block? Compiled Nodes // compiled form, after walkswitch } -func NewSelectStmt(pos src.XPos, cases []*CaseStmt) *SelectStmt { - n := &SelectStmt{} +func NewSelectStmt(pos src.XPos, cases []*CommStmt) *SelectStmt { + n := &SelectStmt{Cases: cases} n.pos = pos n.op = OSELECT - n.Cases = cases return n } @@ -407,10 +446,9 @@ type SwitchStmt struct { } func NewSwitchStmt(pos src.XPos, tag Node, cases []*CaseStmt) *SwitchStmt { - n := &SwitchStmt{Tag: tag} + n := &SwitchStmt{Tag: tag, Cases: cases} n.pos = pos n.op = OSWITCH - n.Cases = cases return n } diff --git a/src/cmd/compile/internal/noder/noder.go b/src/cmd/compile/internal/noder/noder.go index b974448338..ff699cd54d 100644 --- a/src/cmd/compile/internal/noder/noder.go +++ b/src/cmd/compile/internal/noder/noder.go @@ -1266,8 +1266,8 @@ func (p *noder) simpleStmt(stmt syntax.SimpleStmt) []ir.Node { return []ir.Node{p.stmt(stmt)} } -func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []*ir.CaseStmt { - nodes := make([]*ir.CaseStmt, len(clauses)) +func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []*ir.CommStmt { + nodes := make([]*ir.CommStmt, len(clauses)) for i, clause := range clauses { p.setlineno(clause) if i > 0 { @@ -1275,7 +1275,7 @@ func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []* } p.openScope(clause.Pos()) - nodes[i] = ir.NewCaseStmt(p.pos(clause), p.simpleStmt(clause.Comm), p.stmts(clause.Body)) + nodes[i] = ir.NewCommStmt(p.pos(clause), p.simpleStmt(clause.Comm), p.stmts(clause.Body)) } if len(clauses) > 0 { p.closeScope(rbrace) diff --git a/src/cmd/compile/internal/typecheck/iexport.go b/src/cmd/compile/internal/typecheck/iexport.go index 19437a069e..ef2c4527a9 100644 --- a/src/cmd/compile/internal/typecheck/iexport.go +++ b/src/cmd/compile/internal/typecheck/iexport.go @@ -1144,7 +1144,7 @@ func (w *exportWriter) stmt(n ir.Node) { w.op(n.Op()) w.pos(n.Pos()) w.stmtList(n.Init()) - w.caseList(n.Cases, false) + w.commList(n.Cases) case ir.OSWITCH: n := n.(*ir.SwitchStmt) @@ -1193,6 +1193,15 @@ func (w *exportWriter) caseList(cases []*ir.CaseStmt, namedTypeSwitch bool) { } } +func (w *exportWriter) commList(cases []*ir.CommStmt) { + w.uint64(uint64(len(cases))) + for _, cas := range cases { + w.pos(cas.Pos()) + w.stmtList(cas.List) + w.stmtList(cas.Body) + } +} + func (w *exportWriter) exprList(list ir.Nodes) { for _, n := range list { w.expr(n) diff --git a/src/cmd/compile/internal/typecheck/iimport.go b/src/cmd/compile/internal/typecheck/iimport.go index fd8314b662..ba7ea2f156 100644 --- a/src/cmd/compile/internal/typecheck/iimport.go +++ b/src/cmd/compile/internal/typecheck/iimport.go @@ -789,6 +789,14 @@ func (r *importReader) caseList(switchExpr ir.Node) []*ir.CaseStmt { return cases } +func (r *importReader) commList() []*ir.CommStmt { + cases := make([]*ir.CommStmt, r.uint64()) + for i := range cases { + cases[i] = ir.NewCommStmt(r.pos(), r.stmtList(), r.stmtList()) + } + return cases +} + func (r *importReader) exprList() []ir.Node { var list []ir.Node for { @@ -1035,7 +1043,7 @@ func (r *importReader) node() ir.Node { case ir.OSELECT: pos := r.pos() init := r.stmtList() - n := ir.NewSelectStmt(pos, r.caseList(nil)) + n := ir.NewSelectStmt(pos, r.commList()) n.PtrInit().Set(init) return n diff --git a/src/cmd/compile/internal/walk/select.go b/src/cmd/compile/internal/walk/select.go index 0b7e7e99fb..f51684c9b6 100644 --- a/src/cmd/compile/internal/walk/select.go +++ b/src/cmd/compile/internal/walk/select.go @@ -29,7 +29,7 @@ func walkSelect(sel *ir.SelectStmt) { base.Pos = lno } -func walkSelectCases(cases []*ir.CaseStmt) []ir.Node { +func walkSelectCases(cases []*ir.CommStmt) []ir.Node { ncas := len(cases) sellineno := base.Pos @@ -73,7 +73,7 @@ func walkSelectCases(cases []*ir.CaseStmt) []ir.Node { // convert case value arguments to addresses. // this rewrite is used by both the general code and the next optimization. - var dflt *ir.CaseStmt + var dflt *ir.CommStmt for _, cas := range cases { ir.SetPos(cas) n := cas.Comm @@ -146,7 +146,7 @@ func walkSelectCases(cases []*ir.CaseStmt) []ir.Node { if dflt != nil { ncas-- } - casorder := make([]*ir.CaseStmt, ncas) + casorder := make([]*ir.CommStmt, ncas) nsends, nrecvs := 0, 0 var init []ir.Node @@ -242,7 +242,7 @@ func walkSelectCases(cases []*ir.CaseStmt) []ir.Node { } // dispatch cases - dispatch := func(cond ir.Node, cas *ir.CaseStmt) { + dispatch := func(cond ir.Node, cas *ir.CommStmt) { cond = typecheck.Expr(cond) cond = typecheck.DefaultLit(cond, nil) -- 2.48.1