]> Cypherpunks repositories - gostls13.git/commitdiff
[dev.regabi] cmd/compile: cleanup for concrete types - inl
authorRuss Cox <rsc@golang.org>
Thu, 10 Dec 2020 23:46:45 +0000 (18:46 -0500)
committerRuss Cox <rsc@golang.org>
Thu, 17 Dec 2020 04:43:21 +0000 (04:43 +0000)
An automated rewrite will add concrete type assertions after
a test of n.Op(), when n can be safely type-asserted
(meaning, n is not reassigned a different type, n is not reassigned
and then used outside the scope of the type assertion,
and so on).

This sequence of CLs handles the code that the automated
rewrite does not: adding specific types to function arguments,
adjusting code not to call n.Left() etc when n may have multiple
representations, and so on.

This CL focuses on inl.go.

Passes buildall w/ toolstash -cmp.

Change-Id: Iaaee7664cd43e264d9e49d252e3afa7cf719939b
Reviewed-on: https://go-review.googlesource.com/c/go/+/277926
Trust: Russ Cox <rsc@golang.org>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
src/cmd/compile/internal/gc/inl.go

index 3a19efd3252418304542a0cf1c5a14885ed7855e..e940e416fd3eb748d0f588ae4634f283bbada136 100644 (file)
@@ -320,22 +320,26 @@ func (v *hairyVisitor) doNode(n ir.Node) error {
        switch n.Op() {
        // Call is okay if inlinable and we have the budget for the body.
        case ir.OCALLFUNC:
+               n := n.(*ir.CallExpr)
                // Functions that call runtime.getcaller{pc,sp} can not be inlined
                // because getcaller{pc,sp} expect a pointer to the caller's first argument.
                //
                // runtime.throw is a "cheap call" like panic in normal code.
-               if n.Left().Op() == ir.ONAME && n.Left().Class() == ir.PFUNC && isRuntimePkg(n.Left().Sym().Pkg) {
-                       fn := n.Left().Sym().Name
-                       if fn == "getcallerpc" || fn == "getcallersp" {
-                               return errors.New("call to " + fn)
-                       }
-                       if fn == "throw" {
-                               v.budget -= inlineExtraThrowCost
-                               break
+               if n.Left().Op() == ir.ONAME {
+                       name := n.Left().(*ir.Name)
+                       if name.Class() == ir.PFUNC && isRuntimePkg(name.Sym().Pkg) {
+                               fn := name.Sym().Name
+                               if fn == "getcallerpc" || fn == "getcallersp" {
+                                       return errors.New("call to " + fn)
+                               }
+                               if fn == "throw" {
+                                       v.budget -= inlineExtraThrowCost
+                                       break
+                               }
                        }
                }
 
-               if isIntrinsicCall(n.(*ir.CallExpr)) {
+               if isIntrinsicCall(n) {
                        // Treat like any other node.
                        break
                }
@@ -401,11 +405,15 @@ func (v *hairyVisitor) doNode(n ir.Node) error {
                // These nodes don't produce code; omit from inlining budget.
                return nil
 
-       case ir.OFOR, ir.OFORUNTIL, ir.OSWITCH:
-               // ORANGE, OSELECT in "unhandled" above
+       case ir.OFOR, ir.OFORUNTIL:
                if n.Sym() != nil {
                        return errors.New("labeled control")
                }
+       case ir.OSWITCH:
+               if n.Sym() != nil {
+                       return errors.New("labeled control")
+               }
+       // case ir.ORANGE, ir.OSELECT in "unhandled" above
 
        case ir.OBREAK, ir.OCONTINUE:
                if n.Sym() != nil {
@@ -488,7 +496,7 @@ func inlcalls(fn *ir.Func) {
 }
 
 // Turn an OINLCALL into a statement.
-func inlconv2stmt(inlcall ir.Node) ir.Node {
+func inlconv2stmt(inlcall *ir.InlinedCallExpr) ir.Node {
        n := ir.NodAt(inlcall.Pos(), ir.OBLOCK, nil, nil)
        n.SetList(inlcall.Init())
        n.PtrList().AppendNodes(inlcall.PtrBody())
@@ -498,7 +506,7 @@ func inlconv2stmt(inlcall ir.Node) ir.Node {
 // Turn an OINLCALL into a single valued expression.
 // The result of inlconv2expr MUST be assigned back to n, e.g.
 //     n.Left = inlconv2expr(n.Left)
-func inlconv2expr(n ir.Node) ir.Node {
+func inlconv2expr(n *ir.InlinedCallExpr) ir.Node {
        r := n.Rlist().First()
        return initExpr(append(n.Init().Slice(), n.Body().Slice()...), r)
 }
@@ -508,7 +516,7 @@ func inlconv2expr(n ir.Node) ir.Node {
 // containing the inlined statements on the first list element so
 // order will be preserved. Used in return, oas2func and call
 // statements.
-func inlconv2list(n ir.Node) []ir.Node {
+func inlconv2list(n *ir.InlinedCallExpr) []ir.Node {
        if n.Op() != ir.OINLCALL || n.Rlist().Len() == 0 {
                base.Fatalf("inlconv2list %+v\n", n)
        }
@@ -538,9 +546,9 @@ func inlnode(n ir.Node, maxCost int32, inlMap map[*ir.Func]bool, edit func(ir.No
 
        switch n.Op() {
        case ir.ODEFER, ir.OGO:
-               switch n.Left().Op() {
+               switch call := n.Left(); call.Op() {
                case ir.OCALLFUNC, ir.OCALLMETH:
-                       n.Left().SetNoInline(true)
+                       call.SetNoInline(true)
                }
 
        // TODO do them here (or earlier),
@@ -559,11 +567,13 @@ func inlnode(n ir.Node, maxCost int32, inlMap map[*ir.Func]bool, edit func(ir.No
 
        ir.EditChildren(n, edit)
 
-       if n.Op() == ir.OAS2FUNC && n.Rlist().First().Op() == ir.OINLCALL {
-               n.PtrRlist().Set(inlconv2list(n.Rlist().First()))
-               n.SetOp(ir.OAS2)
-               n.SetTypecheck(0)
-               n = typecheck(n, ctxStmt)
+       if as := n; as.Op() == ir.OAS2FUNC {
+               if as.Rlist().First().Op() == ir.OINLCALL {
+                       as.PtrRlist().Set(inlconv2list(as.Rlist().First().(*ir.InlinedCallExpr)))
+                       as.SetOp(ir.OAS2)
+                       as.SetTypecheck(0)
+                       n = typecheck(as, ctxStmt)
+               }
        }
 
        // with all the branches out of the way, it is now time to
@@ -576,45 +586,46 @@ func inlnode(n ir.Node, maxCost int32, inlMap map[*ir.Func]bool, edit func(ir.No
                }
        }
 
-       var call ir.Node
+       var call *ir.CallExpr
        switch n.Op() {
        case ir.OCALLFUNC:
-               call = n
+               call = n.(*ir.CallExpr)
                if base.Flag.LowerM > 3 {
-                       fmt.Printf("%v:call to func %+v\n", ir.Line(n), n.Left())
+                       fmt.Printf("%v:call to func %+v\n", ir.Line(n), call.Left())
                }
-               if isIntrinsicCall(n.(*ir.CallExpr)) {
+               if isIntrinsicCall(call) {
                        break
                }
-               if fn := inlCallee(n.Left()); fn != nil && fn.Inl != nil {
-                       n = mkinlcall(n, fn, maxCost, inlMap, edit)
+               if fn := inlCallee(call.Left()); fn != nil && fn.Inl != nil {
+                       n = mkinlcall(call, fn, maxCost, inlMap, edit)
                }
 
        case ir.OCALLMETH:
-               call = n
+               call = n.(*ir.CallExpr)
                if base.Flag.LowerM > 3 {
-                       fmt.Printf("%v:call to meth %L\n", ir.Line(n), n.Left().Right())
+                       fmt.Printf("%v:call to meth %v\n", ir.Line(n), call.Left().(*ir.SelectorExpr).Sel)
                }
 
                // typecheck should have resolved ODOTMETH->type, whose nname points to the actual function.
-               if n.Left().Type() == nil {
-                       base.Fatalf("no function type for [%p] %+v\n", n.Left(), n.Left())
+               if call.Left().Type() == nil {
+                       base.Fatalf("no function type for [%p] %+v\n", call.Left(), call.Left())
                }
 
-               n = mkinlcall(n, methodExprName(n.Left()).Func(), maxCost, inlMap, edit)
+               n = mkinlcall(call, methodExprName(call.Left()).Func(), maxCost, inlMap, edit)
        }
 
        base.Pos = lno
 
        if n.Op() == ir.OINLCALL {
-               switch call.(*ir.CallExpr).Use {
+               ic := n.(*ir.InlinedCallExpr)
+               switch call.Use {
                default:
                        ir.Dump("call", call)
                        base.Fatalf("call missing use")
                case ir.CallUseExpr:
-                       n = inlconv2expr(n)
+                       n = inlconv2expr(ic)
                case ir.CallUseStmt:
-                       n = inlconv2stmt(n)
+                       n = inlconv2stmt(ic)
                case ir.CallUseList:
                        // leave for caller to convert
                }
@@ -627,8 +638,8 @@ func inlnode(n ir.Node, maxCost int32, inlMap map[*ir.Func]bool, edit func(ir.No
 // that it refers to if statically known. Otherwise, it returns nil.
 func inlCallee(fn ir.Node) *ir.Func {
        fn = staticValue(fn)
-       switch {
-       case fn.Op() == ir.OMETHEXPR:
+       switch fn.Op() {
+       case ir.OMETHEXPR:
                n := methodExprName(fn)
                // Check that receiver type matches fn.Left.
                // TODO(mdempsky): Handle implicit dereference
@@ -637,9 +648,11 @@ func inlCallee(fn ir.Node) *ir.Func {
                        return nil
                }
                return n.Func()
-       case fn.Op() == ir.ONAME && fn.Class() == ir.PFUNC:
-               return fn.Func()
-       case fn.Op() == ir.OCLOSURE:
+       case ir.ONAME:
+               if fn.Class() == ir.PFUNC {
+                       return fn.Func()
+               }
+       case ir.OCLOSURE:
                c := fn.Func()
                caninl(c)
                return c
@@ -650,7 +663,7 @@ func inlCallee(fn ir.Node) *ir.Func {
 func staticValue(n ir.Node) ir.Node {
        for {
                if n.Op() == ir.OCONVNOP {
-                       n = n.Left()
+                       n = n.(*ir.ConvExpr).Left()
                        continue
                }
 
@@ -665,8 +678,12 @@ func staticValue(n ir.Node) ir.Node {
 // staticValue1 implements a simple SSA-like optimization. If n is a local variable
 // that is initialized and never reassigned, staticValue1 returns the initializer
 // expression. Otherwise, it returns nil.
-func staticValue1(n ir.Node) ir.Node {
-       if n.Op() != ir.ONAME || n.Class() != ir.PAUTO || n.Name().Addrtaken() {
+func staticValue1(nn ir.Node) ir.Node {
+       if nn.Op() != ir.ONAME {
+               return nil
+       }
+       n := nn.(*ir.Name)
+       if n.Class() != ir.PAUTO || n.Name().Addrtaken() {
                return nil
        }
 
@@ -695,7 +712,7 @@ FindRHS:
                base.Fatalf("RHS is nil: %v", defn)
        }
 
-       if reassigned(n.(*ir.Name)) {
+       if reassigned(n) {
                return nil
        }
 
@@ -757,7 +774,7 @@ var inlgen int
 // parameters.
 // The result of mkinlcall MUST be assigned back to n, e.g.
 //     n.Left = mkinlcall(n.Left, fn, isddd)
-func mkinlcall(n ir.Node, fn *ir.Func, maxCost int32, inlMap map[*ir.Func]bool, edit func(ir.Node) ir.Node) ir.Node {
+func mkinlcall(n *ir.CallExpr, fn *ir.Func, maxCost int32, inlMap map[*ir.Func]bool, edit func(ir.Node) ir.Node) ir.Node {
        if fn.Inl == nil {
                if logopt.Enabled() {
                        logopt.LogOpt(n.Pos(), "cannotInlineCall", "inline", ir.FuncName(Curfn),
@@ -830,8 +847,9 @@ func mkinlcall(n ir.Node, fn *ir.Func, maxCost int32, inlMap map[*ir.Func]bool,
        if n.Op() == ir.OCALLFUNC {
                callee := n.Left()
                for callee.Op() == ir.OCONVNOP {
-                       ninit.AppendNodes(callee.PtrInit())
-                       callee = callee.Left()
+                       conv := callee.(*ir.ConvExpr)
+                       ninit.AppendNodes(conv.PtrInit())
+                       callee = conv.Left()
                }
                if callee.Op() != ir.ONAME && callee.Op() != ir.OCLOSURE && callee.Op() != ir.OMETHEXPR {
                        base.Fatalf("unexpected callee expression: %v", callee)
@@ -952,16 +970,17 @@ func mkinlcall(n ir.Node, fn *ir.Func, maxCost int32, inlMap map[*ir.Func]bool,
        as := ir.Nod(ir.OAS2, nil, nil)
        as.SetColas(true)
        if n.Op() == ir.OCALLMETH {
-               if n.Left().Left() == nil {
+               sel := n.Left().(*ir.SelectorExpr)
+               if sel.Left() == nil {
                        base.Fatalf("method call without receiver: %+v", n)
                }
-               as.PtrRlist().Append(n.Left().Left())
+               as.PtrRlist().Append(sel.Left())
        }
        as.PtrRlist().Append(n.List().Slice()...)
 
        // For non-dotted calls to variadic functions, we assign the
        // variadic parameter's temp name separately.
-       var vas ir.Node
+       var vas *ir.AssignStmt
 
        if recv := fn.Type().Recv(); recv != nil {
                as.PtrList().Append(inlParam(recv, as, inlvars))
@@ -984,14 +1003,15 @@ func mkinlcall(n ir.Node, fn *ir.Func, maxCost int32, inlMap map[*ir.Func]bool,
                }
                varargs := as.List().Slice()[x:]
 
-               vas = ir.Nod(ir.OAS, nil, nil)
+               vas = ir.NewAssignStmt(base.Pos, nil, nil)
                vas.SetLeft(inlParam(param, vas, inlvars))
                if len(varargs) == 0 {
                        vas.SetRight(nodnil())
                        vas.Right().SetType(param.Type)
                } else {
-                       vas.SetRight(ir.Nod(ir.OCOMPLIT, nil, ir.TypeNode(param.Type)))
-                       vas.Right().PtrList().Set(varargs)
+                       lit := ir.Nod(ir.OCOMPLIT, nil, ir.TypeNode(param.Type))
+                       lit.PtrList().Set(varargs)
+                       vas.SetRight(lit)
                }
        }
 
@@ -1229,13 +1249,20 @@ func (subst *inlsubst) node(n ir.Node) ir.Node {
                typecheckslice(init, ctxStmt)
                return ir.NewBlockStmt(base.Pos, init)
 
-       case ir.OGOTO, ir.OLABEL:
-               m := ir.Copy(n)
+       case ir.OGOTO:
+               m := ir.Copy(n).(*ir.BranchStmt)
                m.SetPos(subst.updatedPos(m.Pos()))
                m.PtrInit().Set(nil)
                p := fmt.Sprintf("%s·%d", n.Sym().Name, inlgen)
                m.SetSym(lookup(p))
+               return m
 
+       case ir.OLABEL:
+               m := ir.Copy(n).(*ir.LabelStmt)
+               m.SetPos(subst.updatedPos(m.Pos()))
+               m.PtrInit().Set(nil)
+               p := fmt.Sprintf("%s·%d", n.Sym().Name, inlgen)
+               m.SetSym(lookup(p))
                return m
        }
 
@@ -1280,36 +1307,38 @@ func devirtualize(fn *ir.Func) {
        Curfn = fn
        ir.VisitList(fn.Body(), func(n ir.Node) {
                if n.Op() == ir.OCALLINTER {
-                       devirtualizeCall(n)
+                       devirtualizeCall(n.(*ir.CallExpr))
                }
        })
 }
 
-func devirtualizeCall(call ir.Node) {
-       recv := staticValue(call.Left().Left())
-       if recv.Op() != ir.OCONVIFACE {
+func devirtualizeCall(call *ir.CallExpr) {
+       sel := call.Left().(*ir.SelectorExpr)
+       r := staticValue(sel.Left())
+       if r.Op() != ir.OCONVIFACE {
                return
        }
+       recv := r.(*ir.ConvExpr)
 
        typ := recv.Left().Type()
        if typ.IsInterface() {
                return
        }
 
-       dt := ir.NodAt(call.Left().Pos(), ir.ODOTTYPE, call.Left().Left(), nil)
+       dt := ir.NodAt(sel.Pos(), ir.ODOTTYPE, sel.Left(), nil)
        dt.SetType(typ)
-       x := typecheck(nodlSym(call.Left().Pos(), ir.OXDOT, dt, call.Left().Sym()), ctxExpr|ctxCallee)
+       x := typecheck(nodlSym(sel.Pos(), ir.OXDOT, dt, sel.Sym()), ctxExpr|ctxCallee)
        switch x.Op() {
        case ir.ODOTMETH:
                if base.Flag.LowerM != 0 {
-                       base.WarnfAt(call.Pos(), "devirtualizing %v to %v", call.Left(), typ)
+                       base.WarnfAt(call.Pos(), "devirtualizing %v to %v", sel, typ)
                }
                call.SetOp(ir.OCALLMETH)
                call.SetLeft(x)
        case ir.ODOTINTER:
                // Promoted method from embedded interface-typed field (#42279).
                if base.Flag.LowerM != 0 {
-                       base.WarnfAt(call.Pos(), "partially devirtualizing %v to %v", call.Left(), typ)
+                       base.WarnfAt(call.Pos(), "partially devirtualizing %v to %v", sel, typ)
                }
                call.SetOp(ir.OCALLINTER)
                call.SetLeft(x)