]> Cypherpunks repositories - gostls13.git/commitdiff
[dev.regabi] cmd/compile: cleanup noder
authorMatthew Dempsky <mdempsky@google.com>
Wed, 23 Dec 2020 10:00:39 +0000 (02:00 -0800)
committerMatthew Dempsky <mdempsky@google.com>
Wed, 23 Dec 2020 10:46:45 +0000 (10:46 +0000)
Similar to previous CL: take advantage of better constructor APIs for
translating ASTs from syntax to ir.

Passes toolstash -cmp.

Change-Id: I40970775e7dd5afe2a0b7593ce3bd73237562457
Reviewed-on: https://go-review.googlesource.com/c/go/+/279972
Trust: Matthew Dempsky <mdempsky@google.com>
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Cuong Manh Le <cuong.manhle.vn@gmail.com>
src/cmd/compile/internal/noder/noder.go

index a684673c8f490b56061219b41bb24ec0305d9733..c73e2d7fc53843a03963278db744c95a01a63847 100644 (file)
@@ -377,11 +377,7 @@ func (p *noder) importDecl(imp *syntax.ImportDecl) {
 func (p *noder) varDecl(decl *syntax.VarDecl) []ir.Node {
        names := p.declNames(ir.ONAME, decl.NameList)
        typ := p.typeExprOrNil(decl.Type)
-
-       var exprs []ir.Node
-       if decl.Values != nil {
-               exprs = p.exprList(decl.Values)
-       }
+       exprs := p.exprList(decl.Values)
 
        if pragma, ok := decl.Pragma.(*pragmas); ok {
                if len(pragma.Embeds) > 0 {
@@ -620,10 +616,14 @@ func (p *noder) param(param *syntax.Field, dddOk, final bool) *ir.Field {
 }
 
 func (p *noder) exprList(expr syntax.Expr) []ir.Node {
-       if list, ok := expr.(*syntax.ListExpr); ok {
-               return p.exprs(list.ElemList)
+       switch expr := expr.(type) {
+       case nil:
+               return nil
+       case *syntax.ListExpr:
+               return p.exprs(expr.ElemList)
+       default:
+               return []ir.Node{p.expr(expr)}
        }
-       return []ir.Node{p.expr(expr)}
 }
 
 func (p *noder) exprs(exprs []syntax.Expr) []ir.Node {
@@ -642,17 +642,14 @@ func (p *noder) expr(expr syntax.Expr) ir.Node {
        case *syntax.Name:
                return p.mkname(expr)
        case *syntax.BasicLit:
-               n := ir.NewLiteral(p.basicLit(expr))
+               n := ir.NewBasicLit(p.pos(expr), p.basicLit(expr))
                if expr.Kind == syntax.RuneLit {
                        n.SetType(types.UntypedRune)
                }
                n.SetDiag(expr.Bad) // avoid follow-on errors if there was a syntax error
                return n
        case *syntax.CompositeLit:
-               n := ir.NewCompLitExpr(p.pos(expr), ir.OCOMPLIT, nil, nil)
-               if expr.Type != nil {
-                       n.Ntype = ir.Node(p.expr(expr.Type)).(ir.Ntype)
-               }
+               n := ir.NewCompLitExpr(p.pos(expr), ir.OCOMPLIT, p.typeExpr(expr.Type), nil)
                l := p.exprs(expr.ElemList)
                for i, e := range l {
                        l[i] = p.wrapname(expr.ElemList[i], e)
@@ -695,7 +692,7 @@ func (p *noder) expr(expr syntax.Expr) ir.Node {
                n.SetSliceBounds(index[0], index[1], index[2])
                return n
        case *syntax.AssertExpr:
-               return ir.NewTypeAssertExpr(p.pos(expr), p.expr(expr.X), p.typeExpr(expr.Type).(ir.Ntype))
+               return ir.NewTypeAssertExpr(p.pos(expr), p.expr(expr.X), p.typeExpr(expr.Type))
        case *syntax.Operation:
                if expr.Op == syntax.Add && expr.Y != nil {
                        return p.sum(expr)
@@ -719,8 +716,7 @@ func (p *noder) expr(expr syntax.Expr) ir.Node {
                }
                return ir.NewBinaryExpr(pos, op, x, y)
        case *syntax.CallExpr:
-               n := ir.NewCallExpr(p.pos(expr), ir.OCALL, p.expr(expr.Fun), nil)
-               n.Args.Set(p.exprs(expr.ArgList))
+               n := ir.NewCallExpr(p.pos(expr), ir.OCALL, p.expr(expr.Fun), p.exprs(expr.ArgList))
                n.IsDDD = expr.HasDots
                return n
 
@@ -987,7 +983,7 @@ func (p *noder) stmt(stmt syntax.Stmt) ir.Node {
 func (p *noder) stmtFall(stmt syntax.Stmt, fallOK bool) ir.Node {
        p.setlineno(stmt)
        switch stmt := stmt.(type) {
-       case *syntax.EmptyStmt:
+       case nil, *syntax.EmptyStmt:
                return nil
        case *syntax.LabeledStmt:
                return p.labeledStmt(stmt, fallOK)
@@ -1060,12 +1056,7 @@ func (p *noder) stmtFall(stmt syntax.Stmt, fallOK bool) ir.Node {
                }
                return ir.NewGoDeferStmt(p.pos(stmt), op, p.expr(stmt.Call))
        case *syntax.ReturnStmt:
-               var results []ir.Node
-               if stmt.Results != nil {
-                       results = p.exprList(stmt.Results)
-               }
-               n := ir.NewReturnStmt(p.pos(stmt), nil)
-               n.Results.Set(results)
+               n := ir.NewReturnStmt(p.pos(stmt), p.exprList(stmt.Results))
                if len(n.Results) == 0 && ir.CurFunc != nil {
                        for _, ln := range ir.CurFunc.Dcl {
                                if ln.Class_ == ir.PPARAM {
@@ -1159,14 +1150,9 @@ func (p *noder) blockStmt(stmt *syntax.BlockStmt) []ir.Node {
 
 func (p *noder) ifStmt(stmt *syntax.IfStmt) ir.Node {
        p.openScope(stmt.Pos())
-       n := ir.NewIfStmt(p.pos(stmt), nil, nil, nil)
-       if stmt.Init != nil {
-               *n.PtrInit() = []ir.Node{p.stmt(stmt.Init)}
-       }
-       if stmt.Cond != nil {
-               n.Cond = p.expr(stmt.Cond)
-       }
-       n.Body.Set(p.blockStmt(stmt.Then))
+       init := p.simpleStmt(stmt.Init)
+       n := ir.NewIfStmt(p.pos(stmt), p.expr(stmt.Cond), p.blockStmt(stmt.Then), nil)
+       *n.PtrInit() = init
        if stmt.Else != nil {
                e := p.stmt(stmt.Else)
                if e.Op() == ir.OBLOCK {
@@ -1197,30 +1183,17 @@ func (p *noder) forStmt(stmt *syntax.ForStmt) ir.Node {
                return n
        }
 
-       n := ir.NewForStmt(p.pos(stmt), nil, nil, nil, nil)
-       if stmt.Init != nil {
-               *n.PtrInit() = []ir.Node{p.stmt(stmt.Init)}
-       }
-       if stmt.Cond != nil {
-               n.Cond = p.expr(stmt.Cond)
-       }
-       if stmt.Post != nil {
-               n.Post = p.stmt(stmt.Post)
-       }
-       n.Body.Set(p.blockStmt(stmt.Body))
+       n := ir.NewForStmt(p.pos(stmt), p.simpleStmt(stmt.Init), p.expr(stmt.Cond), p.stmt(stmt.Post), p.blockStmt(stmt.Body))
        p.closeAnotherScope()
        return n
 }
 
 func (p *noder) switchStmt(stmt *syntax.SwitchStmt) ir.Node {
        p.openScope(stmt.Pos())
-       n := ir.NewSwitchStmt(p.pos(stmt), nil, nil)
-       if stmt.Init != nil {
-               *n.PtrInit() = []ir.Node{p.stmt(stmt.Init)}
-       }
-       if stmt.Tag != nil {
-               n.Tag = p.expr(stmt.Tag)
-       }
+
+       init := p.simpleStmt(stmt.Init)
+       n := ir.NewSwitchStmt(p.pos(stmt), p.expr(stmt.Tag), nil)
+       *n.PtrInit() = init
 
        var tswitch *ir.TypeSwitchGuard
        if l := n.Tag; l != nil && l.Op() == ir.OTYPESW {
@@ -1241,10 +1214,7 @@ func (p *noder) caseClauses(clauses []*syntax.CaseClause, tswitch *ir.TypeSwitch
                }
                p.openScope(clause.Pos())
 
-               n := ir.NewCaseStmt(p.pos(clause), nil, nil)
-               if clause.Cases != nil {
-                       n.List.Set(p.exprList(clause.Cases))
-               }
+               n := ir.NewCaseStmt(p.pos(clause), p.exprList(clause.Cases), nil)
                if tswitch != nil && tswitch.Tag != nil {
                        nn := typecheck.NewName(tswitch.Tag.Sym())
                        typecheck.Declare(nn, typecheck.DeclContext)
@@ -1283,13 +1253,18 @@ func (p *noder) caseClauses(clauses []*syntax.CaseClause, tswitch *ir.TypeSwitch
 }
 
 func (p *noder) selectStmt(stmt *syntax.SelectStmt) ir.Node {
-       n := ir.NewSelectStmt(p.pos(stmt), nil)
-       n.Cases.Set(p.commClauses(stmt.Body, stmt.Rbrace))
-       return n
+       return ir.NewSelectStmt(p.pos(stmt), p.commClauses(stmt.Body, stmt.Rbrace))
+}
+
+func (p *noder) simpleStmt(stmt syntax.SimpleStmt) []ir.Node {
+       if stmt == nil {
+               return nil
+       }
+       return []ir.Node{p.stmt(stmt)}
 }
 
 func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []ir.Node {
-       nodes := make([]ir.Node, 0, len(clauses))
+       nodes := make([]ir.Node, len(clauses))
        for i, clause := range clauses {
                p.setlineno(clause)
                if i > 0 {
@@ -1297,12 +1272,7 @@ func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []i
                }
                p.openScope(clause.Pos())
 
-               n := ir.NewCaseStmt(p.pos(clause), nil, nil)
-               if clause.Comm != nil {
-                       n.List = []ir.Node{p.stmt(clause.Comm)}
-               }
-               n.Body.Set(p.stmts(clause.Body))
-               nodes = append(nodes, n)
+               nodes[i] = ir.NewCaseStmt(p.pos(clause), p.simpleStmt(clause.Comm), p.stmts(clause.Body))
        }
        if len(clauses) > 0 {
                p.closeScope(rbrace)