]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: deal with closures in generic functions and instantiated function values
authorDan Scales <danscales@google.com>
Sun, 21 Feb 2021 18:54:38 +0000 (10:54 -0800)
committerDan Scales <danscales@google.com>
Fri, 26 Feb 2021 18:57:20 +0000 (18:57 +0000)
 - Deal with closures in generic functions by fixing the stenciling code

 - Deal with instantiated function values (instantiated generic
   functions that are not immediately called) during stenciling. This
   requires changing the OFUNCINST node to an ONAME node for the
   appropriately instantiated function. We do this in a second pass,
   since this is uncommon, but requires editing the tree at multiple
   levels.

 - Check global assignments (as well as functions) for generic function
   instantiations.

 - Fix a bug in (*subst).typ where a generic type in a generic function
   may definitely not use all the type args of the function, so we need
   to translate the rparams of the type based on the tparams/targs of
   the function.

 - Added new test combine.go that tests out closures in generic
   functions and instantiated function values.

 - Added one new variant to the settable test.

 - Enabling inlining functions with closures for -G=3. (For now, set
   Ntype on closures in -G=3 mode to keep compatibility with later parts
   of compiler, and allow inlining of functions with closures.)

Change-Id: Iea63d5704c322e42e2f750a83adc8b44f911d4ec
Reviewed-on: https://go-review.googlesource.com/c/go/+/296269
Reviewed-by: Robert Griesemer <gri@golang.org>
Run-TryBot: Dan Scales <danscales@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Dan Scales <danscales@google.com>

src/cmd/compile/internal/inline/inl.go
src/cmd/compile/internal/noder/expr.go
src/cmd/compile/internal/noder/stencil.go
test/typeparam/combine.go [new file with mode: 0644]
test/typeparam/settable.go

index e961b10844808ab545465d9a347e2c7ccd3624d5..1d049298d724b2090e89e4e30b26d490ddbb28fc 100644 (file)
@@ -354,7 +354,7 @@ func (v *hairyVisitor) doNode(n ir.Node) bool {
                return true
 
        case ir.OCLOSURE:
-               if base.Debug.InlFuncsWithClosures == 0 || base.Flag.G > 0 {
+               if base.Debug.InlFuncsWithClosures == 0 {
                        v.reason = "not inlining functions with closures"
                        return true
                }
index b166d34eadc9f00b69bb960011ca5e9ce69b7792..3fded144dcdaf30a28177c5dbeca37a81da63a80 100644 (file)
@@ -325,19 +325,22 @@ func (g *irgen) compLit(typ types2.Type, lit *syntax.CompositeLit) ir.Node {
        return typecheck.Expr(ir.NewCompLitExpr(g.pos(lit), ir.OCOMPLIT, ir.TypeNode(g.typ(typ)), exprs))
 }
 
-func (g *irgen) funcLit(typ types2.Type, expr *syntax.FuncLit) ir.Node {
+func (g *irgen) funcLit(typ2 types2.Type, expr *syntax.FuncLit) ir.Node {
        fn := ir.NewFunc(g.pos(expr))
        fn.SetIsHiddenClosure(ir.CurFunc != nil)
 
        fn.Nname = ir.NewNameAt(g.pos(expr), typecheck.ClosureName(ir.CurFunc))
        ir.MarkFunc(fn.Nname)
-       fn.Nname.SetType(g.typ(typ))
+       typ := g.typ(typ2)
        fn.Nname.Func = fn
        fn.Nname.Defn = fn
+       // Set Ntype for now to be compatible with later parts of compile, remove later.
+       fn.Nname.Ntype = ir.TypeNode(typ)
+       typed(typ, fn.Nname)
+       fn.SetTypecheck(1)
 
        fn.OClosure = ir.NewClosureExpr(g.pos(expr), fn)
-       fn.OClosure.SetType(fn.Nname.Type())
-       fn.OClosure.SetTypecheck(1)
+       typed(typ, fn.OClosure)
 
        g.funcBody(fn, nil, expr.Type, expr.Body)
 
index 69461a819055cd997bf6921d5ed11e72bfca4c4d..fb1bbfedc8b428910bdbcdaa9706e5f4c1258ace 100644 (file)
@@ -27,19 +27,37 @@ func (g *irgen) stencil() {
        // functions calling other generic functions.
        for i := 0; i < len(g.target.Decls); i++ {
                decl := g.target.Decls[i]
-               if decl.Op() != ir.ODCLFUNC || decl.Type().NumTParams() > 0 {
-                       // Skip any non-function declarations and skip generic functions
+
+               // Look for function instantiations in bodies of non-generic
+               // functions or in global assignments (ignore global type and
+               // constant declarations).
+               switch decl.Op() {
+               case ir.ODCLFUNC:
+                       if decl.Type().HasTParam() {
+                               // Skip any generic functions
+                               continue
+                       }
+
+               case ir.OAS:
+
+               case ir.OAS2:
+
+               default:
                        continue
                }
 
-               // For each non-generic function, search for any function calls using
-               // generic function instantiations. (We don't yet handle generic
-               // function instantiations that are not immediately called.)
-               // Then create the needed instantiated function if it hasn't been
-               // created yet, and change to calling that function directly.
-               f := decl.(*ir.Func)
+               // For all non-generic code, search for any function calls using
+               // generic function instantiations. Then create the needed
+               // instantiated function if it hasn't been created yet, and change
+               // to calling that function directly.
                modified := false
-               ir.VisitList(f.Body, func(n ir.Node) {
+               foundFuncInst := false
+               ir.Visit(decl, func(n ir.Node) {
+                       if n.Op() == ir.OFUNCINST {
+                               // We found a function instantiation that is not
+                               // immediately called.
+                               foundFuncInst = true
+                       }
                        if n.Op() != ir.OCALLFUNC || n.(*ir.CallExpr).X.Op() != ir.OFUNCINST {
                                return
                        }
@@ -47,19 +65,7 @@ func (g *irgen) stencil() {
                        // instantiation.
                        call := n.(*ir.CallExpr)
                        inst := call.X.(*ir.InstExpr)
-                       sym := makeInstName(inst)
-                       //fmt.Printf("Found generic func call in %v to %v\n", f, s)
-                       st := g.target.Stencils[sym]
-                       if st == nil {
-                               // If instantiation doesn't exist yet, create it and add
-                               // to the list of decls.
-                               st = genericSubst(sym, inst)
-                               g.target.Stencils[sym] = st
-                               g.target.Decls = append(g.target.Decls, st)
-                               if base.Flag.W > 1 {
-                                       ir.Dump(fmt.Sprintf("\nstenciled %v", st), st)
-                               }
-                       }
+                       st := g.getInstantiation(inst)
                        // Replace the OFUNCINST with a direct reference to the
                        // new stenciled function
                        call.X = st.Nname
@@ -76,6 +82,26 @@ func (g *irgen) stencil() {
                        }
                        modified = true
                })
+
+               // If we found an OFUNCINST without a corresponding call in the
+               // above decl, then traverse the nodes of decl again (with
+               // EditChildren rather than Visit), where we actually change the
+               // OFUNCINST node to an ONAME for the instantiated function.
+               // EditChildren is more expensive than Visit, so we only do this
+               // in the infrequent case of an OFUNCINSt without a corresponding
+               // call.
+               if foundFuncInst {
+                       var edit func(ir.Node) ir.Node
+                       edit = func(x ir.Node) ir.Node {
+                               if x.Op() == ir.OFUNCINST {
+                                       st := g.getInstantiation(x.(*ir.InstExpr))
+                                       return st.Nname
+                               }
+                               ir.EditChildren(x, edit)
+                               return x
+                       }
+                       edit(decl)
+               }
                if base.Flag.W > 1 && modified {
                        ir.Dump(fmt.Sprintf("\nmodified %v", decl), decl)
                }
@@ -83,18 +109,39 @@ func (g *irgen) stencil() {
 
 }
 
-// makeInstName makes the unique name for a stenciled generic function, based on
-// the name of the function and the types of the type params.
-func makeInstName(inst *ir.InstExpr) *types.Sym {
-       b := bytes.NewBufferString("#")
+// getInstantiation gets the instantiated function corresponding to inst. If the
+// instantiated function is not already cached, then it calls genericStub to
+// create the new instantiation.
+func (g *irgen) getInstantiation(inst *ir.InstExpr) *ir.Func {
+       var sym *types.Sym
        if meth, ok := inst.X.(*ir.SelectorExpr); ok {
                // Write the name of the generic method, including receiver type
-               b.WriteString(meth.Selection.Nname.Sym().Name)
+               sym = makeInstName(meth.Selection.Nname.Sym(), inst.Targs)
        } else {
-               b.WriteString(inst.X.(*ir.Name).Name().Sym().Name)
+               sym = makeInstName(inst.X.(*ir.Name).Name().Sym(), inst.Targs)
        }
+       //fmt.Printf("Found generic func call in %v to %v\n", f, s)
+       st := g.target.Stencils[sym]
+       if st == nil {
+               // If instantiation doesn't exist yet, create it and add
+               // to the list of decls.
+               st = g.genericSubst(sym, inst)
+               g.target.Stencils[sym] = st
+               g.target.Decls = append(g.target.Decls, st)
+               if base.Flag.W > 1 {
+                       ir.Dump(fmt.Sprintf("\nstenciled %v", st), st)
+               }
+       }
+       return st
+}
+
+// makeInstName makes the unique name for a stenciled generic function, based on
+// the name of the function and the targs.
+func makeInstName(fnsym *types.Sym, targs []ir.Node) *types.Sym {
+       b := bytes.NewBufferString("#")
+       b.WriteString(fnsym.Name)
        b.WriteString("[")
-       for i, targ := range inst.Targs {
+       for i, targ := range targs {
                if i > 0 {
                        b.WriteString(",")
                }
@@ -107,6 +154,7 @@ func makeInstName(inst *ir.InstExpr) *types.Sym {
 // Struct containing info needed for doing the substitution as we create the
 // instantiation of a generic function with specified type arguments.
 type subster struct {
+       g       *irgen
        newf    *ir.Func // Func node for the new stenciled function
        tparams []*types.Field
        targs   []ir.Node
@@ -121,7 +169,7 @@ type subster struct {
 // inst. For a method with a generic receiver, it returns an instantiated function
 // type where the receiver becomes the first parameter. Otherwise the instantiated
 // method would still need to be transformed by later compiler phases.
-func genericSubst(name *types.Sym, inst *ir.InstExpr) *ir.Func {
+func (g *irgen) genericSubst(name *types.Sym, inst *ir.InstExpr) *ir.Func {
        var nameNode *ir.Name
        var tparams []*types.Field
        if selExpr, ok := inst.X.(*ir.SelectorExpr); ok {
@@ -148,6 +196,7 @@ func genericSubst(name *types.Sym, inst *ir.InstExpr) *ir.Func {
        name.Def = newf.Nname
 
        subst := &subster{
+               g:       g,
                newf:    newf,
                tparams: tparams,
                targs:   inst.Targs,
@@ -198,6 +247,9 @@ func (subst *subster) node(n ir.Node) ir.Node {
                                return v
                        }
                        m := ir.NewNameAt(name.Pos(), name.Sym())
+                       if name.IsClosureVar() {
+                               m.SetIsClosureVar(true)
+                       }
                        t := x.Type()
                        newt := subst.typ(t)
                        m.SetType(newt)
@@ -219,10 +271,12 @@ func (subst *subster) node(n ir.Node) ir.Node {
                                // t can be nil only if this is a call that has no
                                // return values, so allow that and otherwise give
                                // an error.
-                               if _, isCallExpr := m.(*ir.CallExpr); !isCallExpr {
+                               _, isCallExpr := m.(*ir.CallExpr)
+                               _, isStructKeyExpr := m.(*ir.StructKeyExpr)
+                               if !isCallExpr && !isStructKeyExpr {
                                        base.Fatalf(fmt.Sprintf("Nil type for %v", x))
                                }
-                       } else {
+                       } else if x.Op() != ir.OCLOSURE {
                                m.SetType(subst.typ(x.Type()))
                        }
                }
@@ -270,14 +324,27 @@ func (subst *subster) node(n ir.Node) ir.Node {
                        if oldfn.ClosureCalled() {
                                newfn.SetClosureCalled(true)
                        }
+                       newfn.SetIsHiddenClosure(true)
                        m.(*ir.ClosureExpr).Func = newfn
-                       newfn.Nname = ir.NewNameAt(oldfn.Nname.Pos(), oldfn.Nname.Sym())
-                       newfn.Nname.SetType(oldfn.Nname.Type())
-                       newfn.Nname.Ntype = subst.node(oldfn.Nname.Ntype).(ir.Ntype)
+                       newsym := makeInstName(oldfn.Nname.Sym(), subst.targs)
+                       newfn.Nname = ir.NewNameAt(oldfn.Nname.Pos(), newsym)
+                       newfn.Nname.Func = newfn
+                       newfn.Nname.Defn = newfn
+                       ir.MarkFunc(newfn.Nname)
+                       newfn.OClosure = m.(*ir.ClosureExpr)
+
+                       saveNewf := subst.newf
+                       subst.newf = newfn
+                       newfn.Dcl = subst.namelist(oldfn.Dcl)
+                       newfn.ClosureVars = subst.namelist(oldfn.ClosureVars)
                        newfn.Body = subst.list(oldfn.Body)
-                       // Make shallow copy of the Dcl and ClosureVar slices
-                       newfn.Dcl = append([]*ir.Name(nil), oldfn.Dcl...)
-                       newfn.ClosureVars = append([]*ir.Name(nil), oldfn.ClosureVars...)
+                       subst.newf = saveNewf
+
+                       // Set Ntype for now to be compatible with later parts of compiler
+                       newfn.Nname.Ntype = subst.node(oldfn.Nname.Ntype).(ir.Ntype)
+                       typed(subst.typ(oldfn.Nname.Type()), newfn.Nname)
+                       newfn.SetTypecheck(1)
+                       subst.g.target.Decls = append(subst.g.target.Decls, newfn)
                }
                return m
        }
@@ -285,6 +352,20 @@ func (subst *subster) node(n ir.Node) ir.Node {
        return edit(n)
 }
 
+func (subst *subster) namelist(l []*ir.Name) []*ir.Name {
+       s := make([]*ir.Name, len(l))
+       for i, n := range l {
+               s[i] = subst.node(n).(*ir.Name)
+               if n.Defn != nil {
+                       s[i].Defn = subst.node(n.Defn)
+               }
+               if n.Outer != nil {
+                       s[i].Outer = subst.node(n.Outer).(*ir.Name)
+               }
+       }
+       return s
+}
+
 func (subst *subster) list(l []ir.Node) []ir.Node {
        s := make([]ir.Node, len(l))
        for i, n := range l {
@@ -293,7 +374,9 @@ func (subst *subster) list(l []ir.Node) []ir.Node {
        return s
 }
 
-// tstruct substitutes type params in a structure type
+// tstruct substitutes type params in types of the fields of a structure type. For
+// each field, if Nname is set, tstruct also translates the Nname using subst.vars, if
+// Nname is in subst.vars.
 func (subst *subster) tstruct(t *types.Type) *types.Type {
        if t.NumFields() == 0 {
                return t
@@ -301,7 +384,7 @@ func (subst *subster) tstruct(t *types.Type) *types.Type {
        var newfields []*types.Field
        for i, f := range t.Fields().Slice() {
                t2 := subst.typ(f.Type)
-               if t2 != f.Type && newfields == nil {
+               if (t2 != f.Type || f.Nname != nil) && newfields == nil {
                        newfields = make([]*types.Field, t.NumFields())
                        for j := 0; j < i; j++ {
                                newfields[j] = t.Field(j)
@@ -309,6 +392,12 @@ func (subst *subster) tstruct(t *types.Type) *types.Type {
                }
                if newfields != nil {
                        newfields[i] = types.NewField(f.Pos, f.Sym, t2)
+                       if f.Nname != nil {
+                               // f.Nname may not be in subst.vars[] if this is
+                               // a function name or a function instantiation type
+                               // that we are translating
+                               newfields[i].Nname = subst.vars[f.Nname.(*ir.Name)]
+                       }
                }
        }
        if newfields != nil {
@@ -319,14 +408,14 @@ func (subst *subster) tstruct(t *types.Type) *types.Type {
 }
 
 // instTypeName creates a name for an instantiated type, based on the type args
-func instTypeName(name string, targs []ir.Node) string {
+func instTypeName(name string, targs []*types.Type) string {
        b := bytes.NewBufferString(name)
        b.WriteByte('[')
        for i, targ := range targs {
                if i > 0 {
                        b.WriteByte(',')
                }
-               b.WriteString(targ.Type().String())
+               b.WriteString(targ.String())
        }
        b.WriteByte(']')
        return b.String()
@@ -415,10 +504,17 @@ func (subst *subster) typ(t *types.Type) *types.Type {
                        // Since we've substituted types, we also need to change
                        // the defined name of the type, by removing the old types
                        // (in brackets) from the name, and adding the new types.
+
+                       // Translate the type params for this type according to
+                       // the tparam/targs mapping of the function.
+                       neededTargs := make([]*types.Type, len(t.RParams))
+                       for i, rparam := range t.RParams {
+                               neededTargs[i] = subst.typ(rparam)
+                       }
                        oldname := t.Sym().Name
                        i := strings.Index(oldname, "[")
                        oldname = oldname[:i]
-                       sym := t.Sym().Pkg.Lookup(instTypeName(oldname, subst.targs))
+                       sym := t.Sym().Pkg.Lookup(instTypeName(oldname, neededTargs))
                        if sym.Def != nil {
                                // We've already created this instantiated defined type.
                                return sym.Def.Type()
diff --git a/test/typeparam/combine.go b/test/typeparam/combine.go
new file mode 100644 (file)
index 0000000..d4a2988
--- /dev/null
@@ -0,0 +1,65 @@
+// run -gcflags=-G=3
+
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+       "fmt"
+)
+
+type _Gen[A any] func() (A, bool)
+
+func combine[T1, T2, T any](g1 _Gen[T1], g2 _Gen[T2], join func(T1, T2) T) _Gen[T] {
+    return func() (T, bool) {
+        var t T
+        t1, ok := g1()
+        if !ok {
+            return t, false
+        }
+        t2, ok := g2()
+        if !ok {
+            return t, false
+        }
+        return join(t1, t2), true
+    }
+}
+
+type _Pair[A, B any] struct {
+       A A
+       B B
+}
+
+func _NewPair[A, B any](a A, b B) _Pair[A, B] {
+       return _Pair[A, B]{a, b}
+}
+
+func _Combine2[A, B any](ga _Gen[A], gb _Gen[B]) _Gen[_Pair[A, B]] {
+    return combine(ga, gb, _NewPair[A, B])
+}
+
+func main() {
+       var g1 _Gen[int] = func() (int, bool) { return 3, true }
+       var g2 _Gen[string] = func() (string, bool) { return "x", false }
+       var g3 _Gen[string] = func() (string, bool) { return "y", true }
+
+       gc := combine(g1, g2, _NewPair[int, string])
+       if got, ok := gc(); ok {
+               panic(fmt.Sprintf("got %v, %v, wanted -/false", got, ok))
+       }
+       gc2 := _Combine2(g1, g2)
+       if got, ok := gc2(); ok {
+               panic(fmt.Sprintf("got %v, %v, wanted -/false", got, ok))
+       }
+
+       gc3 := combine(g1, g3, _NewPair[int, string])
+       if got, ok := gc3(); !ok || got.A != 3 || got.B != "y" {
+               panic(fmt.Sprintf("got %v, %v, wanted {3, y}, true", got, ok))
+       }
+       gc4 := _Combine2(g1, g3)
+       if got, ok := gc4(); !ok || got.A != 3 || got.B != "y" {
+               panic (fmt.Sprintf("got %v, %v, wanted {3, y}, true", got, ok))
+       }
+}
index 3bd141f78427548c7ac359877689535926409054..7532953a773f00e692bfa544bd1d34dee439a703 100644 (file)
@@ -11,7 +11,24 @@ import (
        "strconv"
 )
 
-func fromStrings3[T any](s []string, set func(*T, string)) []T {
+type Setter[B any] interface {
+        Set(string)
+       type *B
+}
+
+func fromStrings1[T any, PT Setter[T]](s []string) []T {
+        result := make([]T, len(s))
+        for i, v := range s {
+                // The type of &result[i] is *T which is in the type list
+                // of Setter, so we can convert it to PT.
+                p := PT(&result[i])
+                // PT has a Set method.
+                p.Set(v)
+        }
+        return result
+}
+
+func fromStrings2[T any](s []string, set func(*T, string)) []T {
         results := make([]T, len(s))
         for i, v := range s {
                 set(&results[i], v)
@@ -30,8 +47,12 @@ func (p *Settable) Set(s string) {
 }
 
 func main() {
-        s := fromStrings3([]string{"1"},
-                func(p *Settable, s string) { p.Set(s) })
+        s := fromStrings1[Settable, *Settable]([]string{"1"})
+        if len(s) != 1 || s[0] != 1 {
+                panic(fmt.Sprintf("got %v, want %v", s, []int{1}))
+        }
+
+        s = fromStrings2([]string{"1"}, func(p *Settable, s string) { p.Set(s) })
         if len(s) != 1 || s[0] != 1 {
                 panic(fmt.Sprintf("got %v, want %v", s, []int{1}))
         }