]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile/internal/inline: refactor fixpoint algorithm
authorMatthew Dempsky <mdempsky@google.com>
Wed, 28 Feb 2024 09:28:43 +0000 (01:28 -0800)
committerGopher Robot <gobot@golang.org>
Wed, 28 Feb 2024 17:41:01 +0000 (17:41 +0000)
This CL refactors the interleaved fixpoint algorithm so that calls can
be inlined in any order. This has no immediate effect, but it will
allow a subsequent CL to prioritize calls by inlheur score.

Change-Id: I11a84d228e9c94732ee75f0d3c99bc90d83fea09
Reviewed-on: https://go-review.googlesource.com/c/go/+/567695
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Matthew Dempsky <mdempsky@google.com>
Reviewed-by: Than McIntosh <thanm@google.com>
src/cmd/compile/internal/inline/interleaved/interleaved.go
src/cmd/compile/internal/ir/expr.go

index c5334d0300f2a3122b3bd8ba444f9910b4d7c742..e55b0f1aeeb7b71a67fd4520b27c0a896ec34138 100644 (file)
@@ -83,39 +83,108 @@ func DevirtualizeAndInlineFunc(fn *ir.Func, profile *pgo.Profile) {
                        fmt.Printf("%v: function %v considered 'big'; reducing max cost of inlinees\n", ir.Line(fn), fn)
                }
 
-               // Walk fn's body and apply devirtualization and inlining.
-               var inlCalls []*ir.InlinedCallExpr
-               var edit func(ir.Node) ir.Node
-               edit = func(n ir.Node) ir.Node {
+               match := func(n ir.Node) bool {
                        switch n := n.(type) {
+                       case *ir.CallExpr:
+                               return true
                        case *ir.TailCallStmt:
                                n.Call.NoInline = true // can't inline yet
                        }
+                       return false
+               }
+
+               edit := func(n ir.Node) ir.Node {
+                       call, ok := n.(*ir.CallExpr)
+                       if !ok { // previously inlined
+                               return nil
+                       }
+
+                       devirtualize.StaticCall(call)
+                       if inlCall := inline.TryInlineCall(fn, call, bigCaller, profile); inlCall != nil {
+                               return inlCall
+                       }
+                       return nil
+               }
+
+               fixpoint(fn, match, edit)
+       })
+}
+
+// fixpoint repeatedly edits a function until it stabilizes.
+//
+// First, fixpoint applies match to every node n within fn. Then it
+// iteratively applies edit to each node satisfying match(n).
+//
+// If edit(n) returns nil, no change is made. Otherwise, the result
+// replaces n in fn's body, and fixpoint iterates at least once more.
+//
+// After an iteration where all edit calls return nil, fixpoint
+// returns.
+func fixpoint(fn *ir.Func, match func(ir.Node) bool, edit func(ir.Node) ir.Node) {
+       // Consider the expression "f(g())". We want to be able to replace
+       // "g()" in-place with its inlined representation. But if we first
+       // replace "f(...)" with its inlined representation, then "g()" will
+       // instead appear somewhere within this new AST.
+       //
+       // To mitigate this, each matched node n is wrapped in a ParenExpr,
+       // so we can reliably replace n in-place by assigning ParenExpr.X.
+       // It's safe to use ParenExpr here, because typecheck already
+       // removed them all.
+
+       var parens []*ir.ParenExpr
+       var mark func(ir.Node) ir.Node
+       mark = func(n ir.Node) ir.Node {
+               if _, ok := n.(*ir.ParenExpr); ok {
+                       return n // already visited n.X before wrapping
+               }
 
-                       ir.EditChildren(n, edit)
+               ok := match(n)
 
-                       if call, ok := n.(*ir.CallExpr); ok {
-                               devirtualize.StaticCall(call)
+               ir.EditChildren(n, mark)
 
-                               if inlCall := inline.TryInlineCall(fn, call, bigCaller, profile); inlCall != nil {
-                                       inlCalls = append(inlCalls, inlCall)
-                                       n = inlCall
-                               }
+               if ok {
+                       paren := ir.NewParenExpr(n.Pos(), n)
+                       paren.SetType(n.Type())
+                       paren.SetTypecheck(n.Typecheck())
+
+                       parens = append(parens, paren)
+                       n = paren
+               }
+
+               return n
+       }
+       ir.EditChildren(fn, mark)
+
+       // Edit until stable.
+       for {
+               done := true
+
+               for i := 0; i < len(parens); i++ { // can't use "range parens" here
+                       paren := parens[i]
+                       if new := edit(paren.X); new != nil {
+                               // Update AST and recursively mark nodes.
+                               paren.X = new
+                               ir.EditChildren(new, mark) // mark may append to parens
+                               done = false
                        }
+               }
 
-                       return n
+               if done {
+                       break
                }
-               ir.EditChildren(fn, edit)
-
-               // If we inlined any calls, we want to recursively visit their
-               // bodies for further devirtualization and inlining. However, we
-               // need to wait until *after* the original function body has been
-               // expanded, or else inlCallee can have false positives (e.g.,
-               // #54632).
-               for len(inlCalls) > 0 {
-                       call := inlCalls[0]
-                       inlCalls = inlCalls[1:]
-                       ir.EditChildren(call, edit)
+       }
+
+       // Finally, remove any parens we inserted.
+       if len(parens) == 0 {
+               return // short circuit
+       }
+       var unparen func(ir.Node) ir.Node
+       unparen = func(n ir.Node) ir.Node {
+               if paren, ok := n.(*ir.ParenExpr); ok {
+                       n = paren.X
                }
-       })
+               ir.EditChildren(n, unparen)
+               return n
+       }
+       ir.EditChildren(fn, unparen)
 }
index da5b437f99e6a45333217353e80a177ac717b4fa..345828c1638a8979badbe22f7237658d6cc52411 100644 (file)
@@ -856,13 +856,19 @@ func IsAddressable(n Node) bool {
 // "g()" expression.
 func StaticValue(n Node) Node {
        for {
-               if n.Op() == OCONVNOP {
-                       n = n.(*ConvExpr).X
-                       continue
-               }
-
-               if n.Op() == OINLCALL {
-                       n = n.(*InlinedCallExpr).SingleResult()
+               switch n1 := n.(type) {
+               case *ConvExpr:
+                       if n1.Op() == OCONVNOP {
+                               n = n1.X
+                               continue
+                       }
+               case *InlinedCallExpr:
+                       if n1.Op() == OINLCALL {
+                               n = n1.SingleResult()
+                               continue
+                       }
+               case *ParenExpr:
+                       n = n1.X
                        continue
                }