]> Cypherpunks repositories - gostls13.git/commitdiff
gc: inlining (disabled without -l)
authorLuuk van Dijk <lvd@golang.org>
Wed, 14 Dec 2011 14:05:33 +0000 (15:05 +0100)
committerLuuk van Dijk <lvd@golang.org>
Wed, 14 Dec 2011 14:05:33 +0000 (15:05 +0100)
Cross- and intra package inlining of single assignments or return <expression>.
Minus some hairy cases, currently including other calls, expressions with closures and ... arguments.

R=rsc, rogpeppe, adg, gri
CC=golang-dev
https://golang.org/cl/5400043

src/cmd/gc/Makefile
src/cmd/gc/export.c
src/cmd/gc/go.h
src/cmd/gc/go.y
src/cmd/gc/inl.c [new file with mode: 0644]
src/cmd/gc/lex.c

index e6e3a74f448e8a457b6a848e3c16e27fa540c40e..623add4a7fdbe4fa1efcb5890b89f746c4a3e38c 100644 (file)
@@ -27,6 +27,7 @@ OFILES=\
        fmt.$O\
        gen.$O\
        init.$O\
+       inl.$O\
        lex.$O\
        md5.$O\
        mparith1.$O\
index 5951aa2e8815c2cd9a58bb55c0effd668851065b..00bbaf31f1bac7b334175c8c8ac18f4d997fd79c 100644 (file)
@@ -7,7 +7,7 @@
 #include       "go.h"
 #include       "y.tab.h"
 
-static void    dumpexporttype(Type*);
+static void    dumpexporttype(Type *t);
 
 // Mark n's symbol as exported
 void
@@ -78,7 +78,7 @@ dumppkg(Pkg *p)
 {
        char *suffix;
 
-       if(p == nil || p == localpkg || p->exported)
+       if(p == nil || p == localpkg || p->exported || p == builtinpkg)
                return;
        p->exported = 1;
        suffix = "";
@@ -87,6 +87,68 @@ dumppkg(Pkg *p)
        Bprint(bout, "\timport %s \"%Z\"%s\n", p->name, p->path, suffix);
 }
 
+// Look for anything we need for the inline body
+static void reexportdep(Node *n);
+static void
+reexportdeplist(NodeList *ll)
+{
+       for(; ll ;ll=ll->next)
+               reexportdep(ll->n);
+}
+
+static void
+reexportdep(Node *n)
+{
+       Type *t;
+
+       if(!n)
+               return;
+
+       switch(n->op) {
+       case ONAME:
+               switch(n->class&~PHEAP) {
+               case PFUNC:
+               case PEXTERN:
+                       if (n->sym && n->sym->pkg != localpkg && n->sym->pkg != builtinpkg)
+                               exportlist = list(exportlist, n);
+               }
+               break;
+
+       case OTYPE:
+       case OLITERAL:
+               if (n->sym && n->sym->pkg != localpkg && n->sym->pkg != builtinpkg)
+                       exportlist = list(exportlist, n);
+               break;
+
+       // for operations that need a type when rendered, put the type on the export list.
+       case OCONV:
+       case OCONVIFACE:
+       case OCONVNOP:
+       case ODOTTYPE:
+       case OSTRUCTLIT:
+       case OPTRLIT:
+               t = n->type;
+               if(!t->sym && t->type)
+                       t = t->type;
+               if (t && t->sym && t->sym->def && t->sym->pkg != localpkg  && t->sym->pkg != builtinpkg) {
+//                     print("reexport convnop %+hN\n", t->sym->def);
+                       exportlist = list(exportlist, t->sym->def);
+               }
+               break;
+       }
+
+       reexportdep(n->left);
+       reexportdep(n->right);
+       reexportdeplist(n->list);
+       reexportdeplist(n->rlist);
+       reexportdeplist(n->ninit);
+       reexportdep(n->ntest);
+       reexportdep(n->nincr);
+       reexportdeplist(n->nbody);
+       reexportdeplist(n->nelse);
+}
+
+
 static void
 dumpexportconst(Sym *s)
 {
@@ -123,9 +185,13 @@ dumpexportvar(Sym *s)
        t = n->type;
        dumpexporttype(t);
 
-       if(t->etype == TFUNC && n->class == PFUNC)
-               Bprint(bout, "\tfunc %#S%#hT\n", s, t);
-       else
+       if(t->etype == TFUNC && n->class == PFUNC) {
+               if (n->inl) {
+                       Bprint(bout, "\tfunc %#S%#hT { %#H }\n", s, t, n->inl);
+                       reexportdeplist(n->inl);
+               } else
+                       Bprint(bout, "\tfunc %#S%#hT\n", s, t);
+       } else
                Bprint(bout, "\tvar %#S %#T\n", s, t);
 }
 
@@ -148,7 +214,6 @@ dumpexporttype(Type *t)
 
        if(t == T)
                return;
-
        if(t->printed || t == types[t->etype] || t == bytetype || t == runetype || t == errortype)
                return;
        t->printed = 1;
@@ -177,7 +242,11 @@ dumpexporttype(Type *t)
        Bprint(bout, "\ttype %#S %#lT\n", t->sym, t);
        for(i=0; i<n; i++) {
                f = m[i];
-               Bprint(bout, "\tfunc (%#T) %#hhS%#hT\n", getthisx(f->type)->type, f->sym, f->type);
+               if (f->type->nname && f->type->nname->inl) { // nname was set by caninl
+                       Bprint(bout, "\tfunc (%#T) %#hhS%#hT { %#H }\n", getthisx(f->type)->type, f->sym, f->type, f->type->nname->inl);
+                       reexportdeplist(f->type->nname->inl);
+               } else
+                       Bprint(bout, "\tfunc (%#T) %#hhS%#hT\n", getthisx(f->type)->type, f->sym, f->type);
        }
 }
 
index 6d16ea739b34acfac6ef75ea6f0edd29c9c4ca38..10441a5c3f72c608367126dbdeb11d293d98207d 100644 (file)
@@ -141,7 +141,7 @@ struct      Type
        uchar   local;          // created in this file
        uchar   deferwidth;
        uchar   broke;
-       uchar   isddd;  // TFIELD is ... argument
+       uchar   isddd;          // TFIELD is ... argument
        uchar   align;
 
        Node*   nod;            // canonical OTYPE node
@@ -260,6 +260,7 @@ struct      Node
        NodeList*       exit;
        NodeList*       cvars;  // closure params
        NodeList*       dcl;    // autodcl for this func/closure
+       NodeList*       inl;    // copy of the body for use in inlining
 
        // OLITERAL/OREGISTER
        Val     val;
@@ -280,6 +281,9 @@ struct      Node
        Node*   outer;  // outer PPARAMREF in nested closure
        Node*   closure;        // ONAME/PHEAP <-> ONAME/PPARAMREF
 
+       // ONAME substitute while inlining
+       Node* inlvar;
+
        // OPACK
        Pkg*    pkg;
        
@@ -473,6 +477,7 @@ enum
        // misc
        ODDD,
        ODDDARG,
+       OINLCALL,       // intermediary representation of an inlined call
 
        // for back ends
        OCMP, ODEC, OEXTEND, OINC, OREGISTER, OINDREG,
@@ -986,6 +991,12 @@ Node*      temp(Type*);
 void   fninit(NodeList *n);
 Sym*   renameinit(void);
 
+/*
+ *     inl.c
+ */
+void   caninl(Node *fn);
+void   inlcalls(Node *fn);
+
 /*
  *     lex.c
  */
index f71658920ade2bb66d809499cd0a774d48c63bed..d64a3f82b43d97557afae8afa70dfb28508ec306 100644 (file)
@@ -1295,6 +1295,12 @@ hidden_fndcl:
                checkwidth($$->type);
                addmethod($4, $$->type, 0);
                funchdr($$);
+               
+               // inl.c's inlnode in on a dotmeth node expects to find the inlineable body as
+               // (dotmeth's type)->nname->inl, and dotmeth's type has been pulled
+               // out by typecheck's lookdot as this $$->ttype.  So by providing
+               // this back link here we avoid special casing there.
+               $$->type->nname = $$;
        }
 
 fntype:
@@ -1790,11 +1796,15 @@ hidden_import:
                if($2 == N)
                        break;
 
+               $2->inl = $3;
+
                funcbody($2);
                importlist = list(importlist, $2);
 
                if(debug['E']) {
-                       print("import [%Z] func %lN \n", $2->sym->pkg->path, $2);
+                       print("import [%Z] func %lN \n", importpkg->path, $2);
+                       if(debug['l'] > 2 && $2->inl)
+                               print("inl body:%+H\n", $2->inl);
                }
        }
 
diff --git a/src/cmd/gc/inl.c b/src/cmd/gc/inl.c
new file mode 100644 (file)
index 0000000..196a6ef
--- /dev/null
@@ -0,0 +1,688 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+//
+// The inlining facility makes 2 passes: first caninl determines which
+// functions are suitable for inlining, and for those that are it
+// saves a copy of the body. Then inlcalls walks each function body to
+// expand calls to inlinable functions.
+//
+
+#include <u.h>
+#include <libc.h>
+#include "go.h"
+
+// Used by caninl.
+static Node*   inlcopy(Node *n);
+static NodeList* inlcopylist(NodeList *ll);
+static int     ishairy(Node *n);
+static int     ishairylist(NodeList *ll); 
+
+// Used by inlcalls
+static void    inlnodelist(NodeList *l);
+static void    inlnode(Node **np);
+static void    mkinlcall(Node **np, Node *fn);
+static Node*   inlvar(Node *n);
+static Node*   retvar(Type *n, int i);
+static Node*   newlabel(void);
+static Node*   inlsubst(Node *n);
+static NodeList* inlsubstlist(NodeList *ll);
+
+// Used during inlsubst[list]
+static Node *inlfn;            // function currently being inlined
+static Node *inlretlabel;      // target of the goto substituted in place of a return
+static NodeList *inlretvars;   // temp out variables
+
+
+// Caninl determines whether fn is inlineable. Currently that means:
+// fn is exactly 1 statement, either a return or an assignment, and
+// some temporary constraints marked TODO.  If fn is inlineable, saves
+// fn->nbody in fn->inl and substitutes it with a copy.
+void
+caninl(Node *fn)
+{
+       Node *savefn;
+       Type *t;
+
+       if(fn->op != ODCLFUNC)
+               fatal("caninl %N", fn);
+       if(!fn->nname)
+               fatal("caninl no nname %+N", fn);
+
+       // exactly 1 statement
+       if(fn->nbody == nil || fn->nbody->next != nil)
+               return;
+
+       // the single statement should be a return or an assignment.
+       switch(fn->nbody->n->op) {
+       default:
+               return;
+       case ORETURN:
+       case OAS:
+       case OAS2:
+         // case OEMPTY:  // TODO
+               break;
+       }
+
+       // can't handle ... args yet
+       for(t=fn->type->type->down->down->type; t; t=t->down)
+               if(t->isddd)
+                       return;
+
+       // TODO Anything non-trivial
+       if(ishairy(fn))
+               return;
+
+       savefn = curfn;
+       curfn = fn;
+
+       fn->nname->inl = fn->nbody;
+       fn->nbody = inlcopylist(fn->nname->inl);
+
+       // hack, TODO, check for better way to link method nodes back to the thing with the ->inl
+       // this is so export can find the body of a method
+       fn->type->nname = fn->nname;
+
+       if(debug['l']>1)
+               print("%L: can inline %#N as: %#T { %#H }\n", fn->lineno, fn->nname, fn->type, fn->nname->inl);
+
+       curfn = savefn;
+}
+
+// Look for anything we want to punt on.
+static int
+ishairylist(NodeList *ll)
+{
+       for(;ll;ll=ll->next)
+               if(ishairy(ll->n))
+                       return 1;
+       return 0;
+}
+
+static int
+ishairy(Node *n)
+{
+       if(!n)
+               return 0;
+
+       switch(n->op) {
+       case OPROC:
+       case ODEFER:
+       case OCALL:
+       case OCALLFUNC:
+       case OCALLINTER:
+       case OCALLMETH:
+       case OCLOSURE:
+               return 1;
+       }
+
+       return  ishairy(n->left) ||
+               ishairy(n->right) ||
+               ishairylist(n->list) ||
+               ishairylist(n->rlist) ||
+               ishairylist(n->ninit) ||
+               ishairy(n->ntest) ||
+               ishairy(n->nincr) ||
+               ishairylist(n->nbody) ||
+               ishairylist(n->nelse);
+}
+
+// Inlcopy and inlcopylist recursively copy the body of a function.
+// Any name-like node of non-local class is marked for re-export by adding it to
+// the exportlist.
+static NodeList*
+inlcopylist(NodeList *ll)
+{
+       NodeList *l;
+
+       l = nil;
+       for(; ll; ll=ll->next)
+               l = list(l, inlcopy(ll->n));
+       return l;
+}
+
+static Node*
+inlcopy(Node *n)
+{
+       Node *m;
+
+       if(n == N)
+               return N;
+
+       switch(n->op) {
+       case ONAME:
+       case OTYPE:
+       case OLITERAL:
+               return n;
+       }
+
+       m = nod(OXXX, N, N);
+       *m = *n;
+       m->inl = nil;
+       m->left   = inlcopy(n->left);
+       m->right  = inlcopy(n->right);
+       m->list   = inlcopylist(n->list);
+       m->rlist  = inlcopylist(n->rlist);
+       m->ninit  = inlcopylist(n->ninit);
+       m->ntest  = inlcopy(n->ntest);
+       m->nincr  = inlcopy(n->nincr);
+       m->nbody  = inlcopylist(n->nbody);
+       m->nelse  = inlcopylist(n->nelse);
+
+       return m;
+}
+
+
+// Inlcalls/nodelist/node walks fn's statements and expressions and substitutes any
+// calls made to inlineable functions.  This is the external entry point.
+void
+inlcalls(Node *fn)
+{
+       Node *savefn;
+
+       savefn = curfn;
+       curfn = fn;
+       inlnode(&fn);
+       if(fn != curfn)
+               fatal("inlnode replaced curfn");
+       curfn = savefn;
+}
+
+// Turn an OINLCALL into a statement.
+static void
+inlconv2stmt(Node *n)
+{
+       n->op = OBLOCK;
+       n->list = n->nbody;
+       n->nbody = nil;
+       n->rlist = nil;
+}
+
+// Turn an OINLCALL into a single valued expression.
+static void
+inlconv2expr(Node *n)
+{
+       n->op = OCONVNOP;
+       n->left = n->rlist->n;
+       n->rlist = nil;
+       n->ninit = concat(n->ninit, n->nbody);
+       n->nbody = nil;
+}
+
+// Turn the OINLCALL in n->list into an expression list on n.
+// Used in return and call statements.
+static void
+inlgluelist(Node *n)
+{
+       Node *c;
+
+       c = n->list->n;
+       n->ninit = concat(n->ninit, c->ninit);
+       n->ninit = concat(n->ninit, c->nbody);
+       n->list  = c->rlist;
+} 
+
+// Turn the OINLCALL in n->rlist->n into an expression list on n.
+// Used in OAS2FUNC.
+static void
+inlgluerlist(Node *n)
+{
+       Node *c;
+
+       c = n->rlist->n;
+       n->ninit = concat(n->ninit, c->ninit);
+       n->ninit = concat(n->ninit, c->nbody);
+       n->rlist = c->rlist;
+} 
+
+static void
+inlnodelist(NodeList *l)
+{
+       for(; l; l=l->next)
+               inlnode(&l->n);
+}
+
+// inlnode recurses over the tree to find inlineable calls, which will
+// be turned into OINLCALLs by mkinlcall.  When the recursion comes
+// back up will examine left, right, list, rlist, ninit, ntest, nincr,
+// nbody and nelse and use one of the 4 inlconv/glue functions above
+// to turn the OINLCALL into an expression, a statement, or patch it
+// in to this nodes list or rlist as appropriate.
+// NOTE it makes no sense to pass the glue functions down the recursion to the level where the OINLCALL gets created because they have to edit /this/ n,
+// so you'd have to push that one down as well, but then you may as well do it here.  so this is cleaner and shorter and less complicated.
+static void
+inlnode(Node **np)
+{
+       Node *n;
+       NodeList *l;
+
+       if(*np == nil)
+               return;
+
+       n = *np;
+
+       switch(n->op) {
+       case ODEFER:
+       case OPROC:
+               // inhibit inlining of their argument
+               switch(n->left->op) {
+               case OCALLFUNC:
+               case OCALLMETH:
+                       n->left->etype = n->op;
+               }
+
+       case OCLOSURE:
+               // TODO.  do them here rather than in lex.c phase 6b
+               return;
+       }
+
+       inlnodelist(n->ninit);
+       for(l=n->ninit; l; l=l->next)
+               if(l->n->op == OINLCALL)
+                       inlconv2stmt(l->n);
+
+       inlnode(&n->left);
+       if(n->left && n->left->op == OINLCALL)
+               inlconv2expr(n->left);
+
+       inlnode(&n->right);
+       if(n->right && n->right->op == OINLCALL)
+               inlconv2expr(n->right);
+
+       inlnodelist(n->list);
+       switch(n->op) {
+       case OBLOCK:
+               for(l=n->list; l; l=l->next)
+                       if(l->n->op == OINLCALL)
+                               inlconv2stmt(l->n);
+               break;
+
+       case ORETURN:
+               if(count(n->list) == 1 && curfn->type->outtuple > 1 && n->list->n->op == OINLCALL) {
+                       inlgluelist(n);
+                       break;
+               }
+               
+               goto list_dflt;
+
+       case OCALLMETH:
+       case OCALLINTER:
+       case OCALLFUNC:
+               // if we just replaced arg in f(arg()) with an inlined call
+               // and arg returns multiple values, glue as list
+               if(count(n->list) == 1 && n->list->n->op == OINLCALL && count(n->list->n->rlist) > 1) {
+                       inlgluelist(n);
+                       break;
+               }
+
+               // fallthrough
+       default:
+       list_dflt:
+               for(l=n->list; l; l=l->next)
+                       if(l->n->op == OINLCALL)
+                               inlconv2expr(l->n);
+       }
+
+       inlnodelist(n->rlist);
+       switch(n->op) {
+       case OAS2FUNC:
+               if(n->rlist->n->op == OINLCALL) {
+                       inlgluerlist(n);
+                       n->op = OAS2;
+                       n->typecheck = 0;
+                       typecheck(np, Etop);
+                       break;
+               }
+
+               // fallthrough
+       default:
+               for(l=n->rlist; l; l=l->next)
+                       if(l->n->op == OINLCALL)
+                               inlconv2expr(l->n);
+
+       }
+
+       inlnode(&n->ntest);
+       if(n->ntest && n->ntest->op == OINLCALL)
+               inlconv2expr(n->ntest);
+
+       inlnode(&n->nincr);
+       if(n->nincr && n->nincr->op == OINLCALL)
+               inlconv2stmt(n->nincr);
+
+       inlnodelist(n->nbody);
+       for(l=n->nbody; l; l=l->next)
+               if(l->n->op == OINLCALL)
+                       inlconv2stmt(l->n);
+
+       inlnodelist(n->nelse);
+       for(l=n->nelse; l; l=l->next)
+               if(l->n->op == OINLCALL)
+                       inlconv2stmt(l->n);
+
+       // with all the branches out of the way, it is now time to
+       // transmogrify this node itself unless inhibited by the
+       // switch at the top of this function.
+       switch(n->op) {
+       case OCALLFUNC:
+       case OCALLMETH:
+               if (n->etype == OPROC || n->etype == ODEFER)
+                       return;
+       }
+
+       switch(n->op) {
+       case OCALLFUNC:
+               if(debug['l']>3)
+                       print("%L:call to func %lN\n", n->lineno, n->left);
+               mkinlcall(np, n->left);
+               break;
+
+       case OCALLMETH:
+               if(debug['l']>3)
+                       print("%L:call to meth %lN\n", n->lineno, n->left->right);
+               // typecheck resolved ODOTMETH->type, whose nname points to the actual function.
+               if(n->left->type->nname) 
+                       mkinlcall(np, n->left->type->nname);
+               else
+                       fatal("no function definition for [%p] %+T\n", n->left->type, n->left->type);
+               break;
+       }
+}
+
+// if *np is a call, and fn is a function with an inlinable body, substitute *np with an OINLCALL.
+// On return ninit has the parameter assignments, the nbody is the
+// inlined function body and list, rlist contain the input, output
+// parameters.
+static void
+mkinlcall(Node **np, Node *fn)
+{
+       int i;
+       Node *n, *call, *saveinlfn, *as;
+       NodeList *dcl, *ll, *ninit, *body;
+       Type *t;
+
+       if (fn->inl == nil)
+               return;
+
+       n = *np;
+
+       // Bingo, we have a function node, and it has an inlineable body
+       if(debug['l']>1)
+               print("%L: inlining call to %S %#T { %#H }\n", n->lineno, fn->sym, fn->type, fn->inl);
+
+       if(debug['l']>2)
+               print("%L: Before inlining: %+N\n", n->lineno, n);
+
+       saveinlfn = inlfn;
+       inlfn = fn;
+
+       ninit = n->ninit;
+
+       if (fn->defn) // local function
+               dcl = fn->defn->dcl;
+       else // imported function
+               dcl = fn->dcl;
+
+       // Make temp names to use instead of the originals for anything but the outparams
+       for(ll = dcl; ll; ll=ll->next)
+               if(ll->n->op == ONAME && ll->n->class != PPARAMOUT) {
+                       ll->n->inlvar = inlvar(ll->n);
+                       ninit = list(ninit, nod(ODCL, ll->n->inlvar, N));  // otherwise gen won't emit the allocations for heapallocs
+               }
+
+       // assign arguments to the parameters' temp names
+       if(fn->type->thistuple) {
+               if (!n->left->op == ODOTMETH || !n->left->left)
+                       fatal("method call without receiver: %+N", n);
+               t = getthisx(fn->type)->type;
+               if(t != T && t->nname) {
+                       if(!t->nname->inlvar)
+                               fatal("missing inlvar for %N\n", t->nname);
+                       as = nod(OAS, t->nname->inlvar, n->left->left);
+                       typecheck(&as, Etop);
+                       ninit = list(ninit, as);
+               } // else if !ONAME add to init anyway?
+       }
+
+       as = nod(OAS2, N, N);
+       if(fn->type->intuple > 1 && n->list && !n->list->next) {
+               // TODO check that n->list->n is a call?
+               as->rlist = n->list;
+               for(t = getinargx(fn->type)->type; t; t=t->down) {
+                       if(t->nname) {
+                               if(!t->nname->inlvar)
+                                       fatal("missing inlvar for %N\n", t->nname);
+                               as->list = list(as->list, t->nname->inlvar);
+                       } else {
+                               as->list = list(as->list, temp(t->type));
+                       }
+               }               
+       } else {
+               ll = n->list;
+               for(t = getinargx(fn->type)->type; t && ll; t=t->down) {
+                       if(t->nname) {
+                               if(!t->nname->inlvar)
+                                       fatal("missing inlvar for %N\n", t->nname);
+                               as->list = list(as->list, t->nname->inlvar);
+                               as->rlist = list(as->rlist, ll->n);
+                       }
+                       ll=ll->next;
+               }
+               if(ll || t)
+                       fatal("arg count mismatch: %#T  vs %,H\n",  getinargx(fn->type), n->list);
+       }
+
+       if (as->rlist) {
+               typecheck(&as, Etop);
+               ninit = list(ninit, as);
+       }
+
+       // make the outparams.  No need to declare because currently they'll only be used in the assignment that replaces returns.
+       inlretvars = nil;
+       i = 0;
+       for(t = getoutargx(fn->type)->type; t; t = t->down)
+               inlretvars = list(inlretvars, retvar(t, i++));
+       
+       inlretlabel = newlabel();
+       body = inlsubstlist(fn->inl);
+
+       body = list(body, nod(OGOTO, inlretlabel, N));  // avoid 'not used' when function doesnt have return
+       body = list(body, nod(OLABEL, inlretlabel, N));
+
+       typechecklist(body, Etop);
+
+       call = nod(OINLCALL, N, N);
+       call->ninit = ninit;
+       call->nbody = body;
+       call->rlist = inlretvars;
+       call->type = n->type;
+       call->lineno = n->lineno;
+       call->typecheck = 1;
+
+       *np = call;
+
+       inlfn = saveinlfn;
+       if(debug['l']>2)
+               print("%L: After inlining %+N\n\n", n->lineno, *np);
+
+}
+
+// Every time we expand a function we generate a new set of tmpnames,
+// PAUTO's in the calling functions, and link them off of the
+// PPARAM's, PAUTOS and PPARAMOUTs of the called function. 
+static Node*
+inlvar(Node *var)
+{
+       Node *n;
+
+       if(debug['l']>3)
+               print("inlvar %+N\n", var);
+
+       n = newname(var->sym);
+       n->type = var->type;
+       n->class = PAUTO;
+       n->used = 1;
+       n->curfn = curfn;   // the calling function, not the called one
+       curfn->dcl = list(curfn->dcl, n);
+       return n;
+}
+
+// Make a new pparamref
+static Node*
+inlref(Node *var)
+{
+       Node *n;
+
+       if (!var->closure)
+               fatal("No ->closure: %N", var);
+
+       if (!var->closure->inlvar)
+               fatal("No ->closure->inlref: %N", var);
+
+       n = nod(OXXX, N, N);
+       *n = *var;
+
+//     if(debug['l']>1)
+//             print("inlref: %N -> %N\n", var, var->closure->inlvar);
+
+       var = var->closure->inlvar;
+
+       return n;
+}
+
+// Synthesize a variable to store the inlined function's results in.
+static Node*
+retvar(Type *t, int i)
+{
+       Node *n;
+
+       snprint(namebuf, sizeof(namebuf), ".r%d", i);
+       n = newname(lookup(namebuf));
+       n->type = t->type;
+       n->class = PAUTO;
+       n->used = 1;
+       n->curfn = curfn;   // the calling function, not the called one
+       curfn->dcl = list(curfn->dcl, n);
+       return n;
+}
+
+static Node*
+newlabel(void)
+{
+       Node *n;
+       static int label;
+       
+       label++;
+       snprint(namebuf, sizeof(namebuf), ".inlret%.6d", label);
+       n = newname(lookup(namebuf));
+       n->etype = 1;  // flag 'safe' for escape analysis (no backjumps)
+       return n;
+}
+
+// inlsubst and inlsubstlist recursively copy the body of the saved
+// pristine ->inl body of the function while substituting references
+// to input/output parameters with ones to the tmpnames, and
+// substituting returns with assignments to the output.
+static NodeList*
+inlsubstlist(NodeList *ll)
+{
+       NodeList *l;
+
+       l = nil;
+       for(; ll; ll=ll->next)
+               l = list(l, inlsubst(ll->n));
+       return l;
+}
+
+static int closuredepth;
+
+static Node*
+inlsubst(Node *n)
+{
+       Node *m, *as;
+       NodeList *ll;
+
+       if(n == N)
+               return N;
+
+       switch(n->op) {
+       case ONAME:
+               if(n->inlvar) { // These will be set during inlnode
+                       if (debug['l']>2)
+                               print ("substituting name %N  ->  %N\n", n, n->inlvar);
+                       return n->inlvar;
+               }
+               if (debug['l']>2)
+                       print ("not substituting name %N\n", n);
+               return n;
+
+       case OLITERAL:
+       case OTYPE:
+               return n;
+
+       case ORETURN:
+               // only rewrite returns belonging to this function, not nested ones.
+               if (closuredepth > 0)
+                       break;
+               
+//             dump("Return before substitution", n);
+               m = nod(OGOTO, inlretlabel, N);
+               m->ninit  = inlsubstlist(n->ninit);
+
+               // rewrite naked return for function with return values to return PPARAMOUTs
+               if(count(n->list) == 0 && inlfn->type->outtuple > 0) {
+                       for(ll = inlfn->dcl; ll; ll=ll->next)
+                               if(ll->n->op == ONAME && ll->n->class == PPARAMOUT)
+                                       n->list = list(n->list, ll->n);
+
+//                     dump("Return naked -> dressed ", n);
+               }
+
+               if(inlretvars && n->list) {
+                       as = nod(OAS2, N, N);
+                       as->list = inlretvars;
+                       as->rlist = inlsubstlist(n->list);
+                       typecheck(&as, Etop);
+                       m->ninit = list(m->ninit, as);
+               }
+
+               typechecklist(m->ninit, Etop);
+               typecheck(&m, Etop);
+//             dump("Return after substitution", m);
+               return m;
+       }
+
+
+       m = nod(OXXX, N, N);
+       *m = *n;
+       m->ninit = nil;
+       
+       if(n->op == OCLOSURE) {
+               closuredepth++;
+
+               for(ll = m->dcl; ll; ll=ll->next)
+                       if(ll->n->op == ONAME) {
+                               ll->n->inlvar = inlvar(ll->n);
+                               m->ninit = list(m->ninit, nod(ODCL, ll->n->inlvar, N));  // otherwise gen won't emit the allocations for heapallocs
+                       }
+               
+               for (ll=m->cvars; ll; ll=ll->next)
+                       if (ll->n->op == ONAME)
+                               ll->n->cvars = list(ll->n->cvars, inlref(ll->n));
+       }
+       
+       m->left   = inlsubst(n->left);
+       m->right  = inlsubst(n->right);
+       m->list   = inlsubstlist(n->list);
+       m->rlist  = inlsubstlist(n->rlist);
+       m->ninit  = concat(m->ninit, inlsubstlist(n->ninit));
+       m->ntest  = inlsubst(n->ntest);
+       m->nincr  = inlsubst(n->nincr);
+       m->nbody  = inlsubstlist(n->nbody);
+       m->nelse  = inlsubstlist(n->nelse);
+
+       if(n->op == OCLOSURE)
+               closuredepth--;
+
+       return m;
+}
index 8c544f6b92f99be3b5f42d96531c65ee76306a21..b582ab5c4f298812f75842ed0f7077b877170048 100644 (file)
@@ -336,11 +336,43 @@ main(int argc, char *argv[])
        if(nsavederrors+nerrors)
                errorexit();
 
-       // Phase 4: escape analysis.
+       // Phase 4: Inlining
+       if (debug['l']) {               // TODO only if debug['l'] > 1, otherwise lazily when used.
+               // Typecheck imported function bodies
+               for(l=importlist; l; l=l->next) {
+                       if (l->n->inl == nil)
+                               continue;
+                       curfn = l->n;
+                       saveerrors();
+                       importpkg = l->n->sym->pkg;
+                       if (debug['l']>2)
+                               print("typecheck import [%S] %lN { %#H }\n", l->n->sym, l->n, l->n->inl);
+                       typechecklist(l->n->inl, Etop);
+                       importpkg = nil;
+               }
+               curfn = nil;
+               
+               if(nsavederrors+nerrors)
+                       errorexit();
+       }
+
+       if (debug['l']) {
+               // Find functions that can be inlined and clone them before walk expands them.
+               for(l=xtop; l; l=l->next)
+                       if(l->n->op == ODCLFUNC)
+                               caninl(l->n);
+               
+               // Expand inlineable calls in all functions
+               for(l=xtop; l; l=l->next)
+                       if(l->n->op == ODCLFUNC)
+                               inlcalls(l->n);
+       }
+
+       // Phase 5: escape analysis.
        if(!debug['N'])
                escapes();
 
-       // Phase 5: Compile top level functions.
+       // Phase 6: Compile top level functions.
        for(l=xtop; l; l=l->next)
                if(l->n->op == ODCLFUNC)
                        funccompile(l->n, 0);
@@ -348,7 +380,7 @@ main(int argc, char *argv[])
        if(nsavederrors+nerrors == 0)
                fninit(xtop);
 
-       // Phase 5b: Compile all closures.
+       // Phase 6b: Compile all closures.
        while(closures) {
                l = closures;
                closures = nil;
@@ -356,7 +388,7 @@ main(int argc, char *argv[])
                        funccompile(l->n, 1);
        }
 
-       // Phase 6: check external declarations.
+       // Phase 7: check external declarations.
        for(l=externdcl; l; l=l->next)
                if(l->n->op == ONAME)
                        typecheck(&l->n, Erv);