]> Cypherpunks repositories - gostls13.git/commitdiff
go/types, types2: ensure that named types never expand infinitely
authorRobert Findley <rfindley@google.com>
Sun, 8 May 2022 01:22:17 +0000 (21:22 -0400)
committerRobert Findley <rfindley@google.com>
Mon, 6 Jun 2022 15:42:19 +0000 (15:42 +0000)
During type-checking, newly created instances share a type checking
Context which de-duplicates identical instances. However, when
unexpanded types escape the type-checking pass or are created via calls
to Instantiate, they lack this shared context. As reported in #52728,
this may lead to infinitely many identical but distinct types that are
reachable via the API.

This CL introduces a new invariant that ensures we don't create such
infinitely expanding chains: instances created during expansion share a
context with the type that led to their creation. During expansion, the
expanding type passes its Context to any newly created instances.

This ensures that cycles will eventually terminate with a previously
seen instance. For example, if we have an instantiation chain
T1[P]->T2[P]->T3[P]->T1[P], by virtue of this Context passing the
expansion of T3[P] will find the instantiation T1[P].

In general, storing a Context in a Named type could lead to pinning
types in memory unnecessarily, but in this case the Context pins only
those types that are reachable from the original instance. This seems
like a reasonable compromise between lazy and eager expansion.

Our treatment of Context was a little haphazard: Checker.bestContext
made it easy to get a context at any point, but made it harder to reason
about which context is being used. To fix this, replace bestContext with
Checker.context, which returns the type-checking context and panics on a
nil receiver. Update all call-sites to verify that the Checker is
non-nil when context is called.

Also make it a panic to call subst with a nil context. Instead, update
subst to explicitly accept a local (=instance) context along with a
global context, and require that one of them is non-nil. Thread this
through to the call to Checker.instance, and handle context updating
there.

Fixes #52728

Change-Id: Ib7f26eb8c406290325bc3212fda25421a37a1e8e
Reviewed-on: https://go-review.googlesource.com/c/go/+/404885
Reviewed-by: Robert Griesemer <gri@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Robert Findley <rfindley@google.com>

18 files changed:
src/cmd/compile/internal/types2/call.go
src/cmd/compile/internal/types2/infer.go
src/cmd/compile/internal/types2/instantiate.go
src/cmd/compile/internal/types2/named.go
src/cmd/compile/internal/types2/named_test.go
src/cmd/compile/internal/types2/predicates.go
src/cmd/compile/internal/types2/signature.go
src/cmd/compile/internal/types2/subst.go
src/cmd/compile/internal/types2/typexpr.go
src/go/types/call.go
src/go/types/infer.go
src/go/types/instantiate.go
src/go/types/named.go
src/go/types/named_test.go
src/go/types/predicates.go
src/go/types/signature.go
src/go/types/subst.go
src/go/types/typexpr.go

index 3ade147dfe2e4cdc5da61d62e30221b4aeeb0fbe..b1ea6917fb1a21a235e8f0b48e707033653b2995 100644 (file)
@@ -72,13 +72,13 @@ func (check *Checker) instantiateSignature(pos syntax.Pos, typ *Signature, targs
                }()
        }
 
-       inst := check.instance(pos, typ, targs, check.bestContext(nil)).(*Signature)
+       inst := check.instance(pos, typ, targs, nil, check.context()).(*Signature)
        assert(len(xlist) <= len(targs))
 
        // verify instantiation lazily (was issue #50450)
        check.later(func() {
                tparams := typ.TypeParams().list()
-               if i, err := check.verify(pos, tparams, targs); err != nil {
+               if i, err := check.verify(pos, tparams, targs, check.context()); err != nil {
                        // best position for error reporting
                        pos := pos
                        if i < len(xlist) {
@@ -395,7 +395,7 @@ func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []T
                // need to compute it from the adjusted list; otherwise we can
                // simply use the result signature's parameter list.
                if adjusted {
-                       sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(sig.TypeParams().list(), targs), nil).(*Tuple)
+                       sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(sig.TypeParams().list(), targs), nil, check.context()).(*Tuple)
                } else {
                        sigParams = rsig.params
                }
index 8ab568be596109e942c5c756c9f7ee43973c4d58..b0c6a4fceac275c08e5c26fb26854183c4d9a8b2 100644 (file)
@@ -110,11 +110,11 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type,
 
                        renameMap := makeRenameMap(tparams, tparams2)
                        for i, tparam := range tparams {
-                               tparams2[i].bound = check.subst(pos, tparam.bound, renameMap, nil)
+                               tparams2[i].bound = check.subst(pos, tparam.bound, renameMap, nil, check.context())
                        }
 
                        tparams = tparams2
-                       params = check.subst(pos, params, renameMap, nil).(*Tuple)
+                       params = check.subst(pos, params, renameMap, nil, check.context()).(*Tuple)
                }
        }
 
@@ -188,7 +188,7 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type,
        //           but that doesn't impact the isParameterized check for now).
        if params.Len() > 0 {
                smap := makeSubstMap(tparams, targs)
-               params = check.subst(nopos, params, smap, nil).(*Tuple)
+               params = check.subst(nopos, params, smap, nil, check.context()).(*Tuple)
        }
 
        // Unify parameter and argument types for generic parameters with typed arguments
@@ -224,7 +224,7 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type,
                        }
                }
                smap := makeSubstMap(tparams, targs)
-               inferred := check.subst(arg.Pos(), tpar, smap, nil)
+               inferred := check.subst(arg.Pos(), tpar, smap, nil, check.context())
                if inferred != tpar {
                        check.errorf(arg, "%s %s of %s does not match inferred type %s for %s", kind, targ, arg.expr, inferred, tpar)
                } else {
@@ -626,7 +626,7 @@ func (check *Checker) inferB(pos syntax.Pos, tparams []*TypeParam, targs []Type)
                n := 0
                for _, index := range dirty {
                        t0 := types[index]
-                       if t1 := check.subst(nopos, t0, smap, nil); t1 != t0 {
+                       if t1 := check.subst(nopos, t0, smap, nil, check.context()); t1 != t0 {
                                types[index] = t1
                                dirty[n] = index
                                n++
index ed6206e150fe0179bf9f39a336bb231f026f2df0..f338e28d2e5875af10e0734179622134b1ac35a8 100644 (file)
@@ -40,6 +40,9 @@ import (
 // count is incorrect; for *Named types, a panic may occur later inside the
 // *Named API.
 func Instantiate(ctxt *Context, orig Type, targs []Type, validate bool) (Type, error) {
+       if ctxt == nil {
+               ctxt = NewContext()
+       }
        if validate {
                var tparams []*TypeParam
                switch t := orig.(type) {
@@ -51,34 +54,71 @@ func Instantiate(ctxt *Context, orig Type, targs []Type, validate bool) (Type, e
                if len(targs) != len(tparams) {
                        return nil, fmt.Errorf("got %d type arguments but %s has %d type parameters", len(targs), orig, len(tparams))
                }
-               if i, err := (*Checker)(nil).verify(nopos, tparams, targs); err != nil {
+               if i, err := (*Checker)(nil).verify(nopos, tparams, targs, ctxt); err != nil {
                        return nil, &ArgumentError{i, err}
                }
        }
 
-       inst := (*Checker)(nil).instance(nopos, orig, targs, ctxt)
+       inst := (*Checker)(nil).instance(nopos, orig, targs, nil, ctxt)
        return inst, nil
 }
 
-// instance creates a type or function instance using the given original type
-// typ and arguments targs. For Named types the resulting instance will be
-// unexpanded. check may be nil.
-func (check *Checker) instance(pos syntax.Pos, orig Type, targs []Type, ctxt *Context) (res Type) {
-       var h string
-       if ctxt != nil {
-               h = ctxt.instanceHash(orig, targs)
-               // typ may already have been instantiated with identical type arguments. In
-               // that case, re-use the existing instance.
-               if inst := ctxt.lookup(h, orig, targs); inst != nil {
-                       return inst
+// instance resolves a type or function instance for the given original type
+// and type arguments. It looks for an existing identical instance in the given
+// contexts, creating a new instance if none is found.
+//
+// If local is non-nil, it is the context associated with a Named instance
+// type currently being expanded. If global is non-nil, it is the context
+// associated with the current type-checking pass or call to Instantiate. At
+// least one of local or global must be non-nil.
+//
+// For Named types the resulting instance may be unexpanded.
+func (check *Checker) instance(pos syntax.Pos, orig Type, targs []Type, local, global *Context) (res Type) {
+       // The order of the contexts below matters: we always prefer instances in
+       // local in order to preserve reference cycles.
+       //
+       // Invariant: if local != nil, the returned instance will be the instance
+       // recorded in local.
+       var ctxts []*Context
+       if local != nil {
+               ctxts = append(ctxts, local)
+       }
+       if global != nil {
+               ctxts = append(ctxts, global)
+       }
+       assert(len(ctxts) > 0)
+
+       // Compute all hashes; hashes may differ across contexts due to different
+       // unique IDs for Named types within the hasher.
+       hashes := make([]string, len(ctxts))
+       for i, ctxt := range ctxts {
+               hashes[i] = ctxt.instanceHash(orig, targs)
+       }
+
+       // If local is non-nil, updateContexts return the type recorded in
+       // local.
+       updateContexts := func(res Type) Type {
+               for i := len(ctxts) - 1; i >= 0; i-- {
+                       res = ctxts[i].update(hashes[i], orig, targs, res)
+               }
+               return res
+       }
+
+       // typ may already have been instantiated with identical type arguments. In
+       // that case, re-use the existing instance.
+       for i, ctxt := range ctxts {
+               if inst := ctxt.lookup(hashes[i], orig, targs); inst != nil {
+                       return updateContexts(inst)
                }
        }
 
        switch orig := orig.(type) {
        case *Named:
-               res = check.newNamedInstance(pos, orig, targs)
+               res = check.newNamedInstance(pos, orig, targs, local) // substituted lazily
 
        case *Signature:
+               assert(local == nil) // function instances cannot be reached from Named types
+
                tparams := orig.TypeParams()
                if !check.validateTArgLen(pos, tparams.Len(), len(targs)) {
                        return Typ[Invalid]
@@ -86,7 +126,7 @@ func (check *Checker) instance(pos syntax.Pos, orig Type, targs []Type, ctxt *Co
                if tparams.Len() == 0 {
                        return orig // nothing to do (minor optimization)
                }
-               sig := check.subst(pos, orig, makeSubstMap(tparams.list(), targs), ctxt).(*Signature)
+               sig := check.subst(pos, orig, makeSubstMap(tparams.list(), targs), nil, global).(*Signature)
                // If the signature doesn't use its type parameters, subst
                // will not make a copy. In that case, make a copy now (so
                // we can set tparams to nil w/o causing side-effects).
@@ -104,13 +144,8 @@ func (check *Checker) instance(pos syntax.Pos, orig Type, targs []Type, ctxt *Co
                panic(fmt.Sprintf("%v: cannot instantiate %v", pos, orig))
        }
 
-       if ctxt != nil {
-               // It's possible that we've lost a race to add named to the context.
-               // In this case, use whichever instance is recorded in the context.
-               res = ctxt.update(h, orig, targs, res)
-       }
-
-       return res
+       // Update all contexts; it's possible that we've lost a race.
+       return updateContexts(res)
 }
 
 // validateTArgLen verifies that the length of targs and tparams matches,
@@ -128,7 +163,7 @@ func (check *Checker) validateTArgLen(pos syntax.Pos, ntparams, ntargs int) bool
        return true
 }
 
-func (check *Checker) verify(pos syntax.Pos, tparams []*TypeParam, targs []Type) (int, error) {
+func (check *Checker) verify(pos syntax.Pos, tparams []*TypeParam, targs []Type, ctxt *Context) (int, error) {
        smap := makeSubstMap(tparams, targs)
        for i, tpar := range tparams {
                // Ensure that we have a (possibly implicit) interface as type bound (issue #51048).
@@ -137,7 +172,7 @@ func (check *Checker) verify(pos syntax.Pos, tparams []*TypeParam, targs []Type)
                // as the instantiated type; before we can use it for bounds checking we
                // need to instantiate it with the type arguments with which we instantiated
                // the parameterized type.
-               bound := check.subst(pos, tpar.bound, smap, nil)
+               bound := check.subst(pos, tpar.bound, smap, nil, ctxt)
                if err := check.implements(targs[i], bound); err != nil {
                        return i, err
                }
index 133fb6fa88e0caf4fe47b10e00164d900fca7018..720e500cd50bb6880b38404509f96c93ca07504c 100644 (file)
@@ -79,6 +79,16 @@ import (
 // Identical to compare them. For instantiated named types, their obj is a
 // synthetic placeholder that records their position of the corresponding
 // instantiation in the source (if they were constructed during type checking).
+//
+// To prevent infinite expansion of named instances that are created outside of
+// type-checking, instances share a Context with other instances created during
+// their expansion. Via the pidgeonhole principle, this guarantees that in the
+// presence of a cycle of named types, expansion will eventually find an
+// existing instance in the Context and short-circuit the expansion.
+//
+// Once an instance is complete, we can nil out this shared Context to unpin
+// memory, though this Context may still be held by other incomplete instances
+// in its "lineage".
 
 // A Named represents a named (defined) type.
 type Named struct {
@@ -115,6 +125,7 @@ type instance struct {
        orig            *Named    // original, uninstantiated type
        targs           *TypeList // type arguments
        expandedMethods int       // number of expanded methods; expandedMethods <= len(orig.methods)
+       ctxt            *Context  // local Context; set to nil after full expansion
 }
 
 // namedState represents the possible states that a named type may assume.
@@ -143,7 +154,7 @@ func NewNamed(obj *TypeName, underlying Type, methods []*Func) *Named {
 // After resolution, the type parameters, methods, and underlying type of n are
 // accessible; but if n is an instantiated type, its methods may still be
 // unexpanded.
-func (n *Named) resolve(ctxt *Context) *Named {
+func (n *Named) resolve() *Named {
        if n.state() >= resolved { // avoid locking below
                return n
        }
@@ -162,8 +173,8 @@ func (n *Named) resolve(ctxt *Context) *Named {
                assert(n.loader == nil)     // instances are created by instantiation, in which case n.loader is nil
 
                orig := n.inst.orig
-               orig.resolve(ctxt)
-               underlying := n.expandUnderlying(ctxt)
+               orig.resolve()
+               underlying := n.expandUnderlying()
 
                n.tparams = orig.tparams
                n.underlying = underlying
@@ -171,6 +182,7 @@ func (n *Named) resolve(ctxt *Context) *Named {
 
                if len(orig.methods) == 0 {
                        n.setState(complete) // nothing further to do
+                       n.inst.ctxt = nil
                } else {
                        n.setState(resolved)
                }
@@ -225,11 +237,11 @@ func (check *Checker) newNamed(obj *TypeName, underlying Type, methods []*Func)
        return typ
 }
 
-func (check *Checker) newNamedInstance(pos syntax.Pos, orig *Named, targs []Type) *Named {
+func (check *Checker) newNamedInstance(pos syntax.Pos, orig *Named, targs []Type, local *Context) *Named {
        assert(len(targs) > 0)
 
        obj := NewTypeName(pos, orig.obj.pkg, orig.obj.name, nil)
-       inst := &instance{orig: orig, targs: newTypeList(targs)}
+       inst := &instance{orig: orig, targs: newTypeList(targs), ctxt: local}
        typ := &Named{check: check, obj: obj, inst: inst}
        obj.typ = typ
        // Ensure that typ is always expanded and sanity-checked.
@@ -280,13 +292,13 @@ func (t *Named) Origin() *Named {
 
 // TypeParams returns the type parameters of the named type t, or nil.
 // The result is non-nil for an (originally) generic type even if it is instantiated.
-func (t *Named) TypeParams() *TypeParamList { return t.resolve(nil).tparams }
+func (t *Named) TypeParams() *TypeParamList { return t.resolve().tparams }
 
 // SetTypeParams sets the type parameters of the named type t.
 // t must not have type arguments.
 func (t *Named) SetTypeParams(tparams []*TypeParam) {
        assert(t.inst == nil)
-       t.resolve(nil).tparams = bindTParams(tparams)
+       t.resolve().tparams = bindTParams(tparams)
 }
 
 // TypeArgs returns the type arguments used to instantiate the named type t.
@@ -298,17 +310,17 @@ func (t *Named) TypeArgs() *TypeList {
 }
 
 // NumMethods returns the number of explicit methods defined for t.
-//
-// For an ordinary or instantiated type t, the receiver base type of these
-// methods will be the named type t. For an uninstantiated generic type t, each
-// method receiver will be instantiated with its receiver type parameters.
 func (t *Named) NumMethods() int {
-       return len(t.Origin().resolve(nil).methods)
+       return len(t.Origin().resolve().methods)
 }
 
 // Method returns the i'th method of named type t for 0 <= i < t.NumMethods().
+//
+// For an ordinary or instantiated type t, the receiver base type of this
+// method is the named type t. For an uninstantiated generic type t, each
+// method receiver is instantiated with its receiver type parameters.
 func (t *Named) Method(i int) *Func {
-       t.resolve(nil)
+       t.resolve()
 
        if t.state() >= complete {
                return t.methods[i]
@@ -326,6 +338,7 @@ func (t *Named) Method(i int) *Func {
        }
 
        if t.methods[i] == nil {
+               assert(t.inst.ctxt != nil) // we should still have a context remaining from the resolution phase
                t.methods[i] = t.expandMethod(i)
                t.inst.expandedMethods++
 
@@ -333,6 +346,7 @@ func (t *Named) Method(i int) *Func {
                // type as fully expanded.
                if t.inst.expandedMethods == len(orig.methods) {
                        t.setState(complete)
+                       t.inst.ctxt = nil // no need for a context anymore
                }
        }
 
@@ -372,9 +386,12 @@ func (t *Named) expandMethod(i int) *Func {
        // and type parameters. This check is necessary in the presence of invalid
        // code.
        if origSig.RecvTypeParams().Len() == t.inst.targs.Len() {
-               ctxt := check.bestContext(nil)
                smap := makeSubstMap(origSig.RecvTypeParams().list(), t.inst.targs.list())
-               sig = check.subst(origm.pos, origSig, smap, ctxt).(*Signature)
+               var global *Context
+               if check != nil {
+                       global = check.context()
+               }
+               sig = check.subst(origm.pos, origSig, smap, t.inst.ctxt, global).(*Signature)
        }
 
        if sig == origSig {
@@ -405,7 +422,7 @@ func (t *Named) SetUnderlying(underlying Type) {
        if _, ok := underlying.(*Named); ok {
                panic("underlying type must not be *Named")
        }
-       t.resolve(nil).underlying = underlying
+       t.resolve().underlying = underlying
        if t.fromRHS == nil {
                t.fromRHS = underlying // for cycle detection
        }
@@ -415,17 +432,20 @@ func (t *Named) SetUnderlying(underlying Type) {
 // t must not have type arguments.
 func (t *Named) AddMethod(m *Func) {
        assert(t.inst == nil)
-       t.resolve(nil)
+       t.resolve()
        if i, _ := lookupMethod(t.methods, m.pkg, m.name, false); i < 0 {
                t.methods = append(t.methods, m)
        }
 }
 
-func (t *Named) Underlying() Type { return t.resolve(nil).underlying }
+func (t *Named) Underlying() Type { return t.resolve().underlying }
 func (t *Named) String() string   { return TypeString(t, nil) }
 
 // ----------------------------------------------------------------------------
 // Implementation
+//
+// TODO(rfindley): reorganize the loading and expansion methods under this
+// heading.
 
 // under returns the expanded underlying type of n0; possibly by following
 // forward chains of named types. If an underlying type is found, resolve
@@ -522,7 +542,7 @@ func (n *Named) setUnderlying(typ Type) {
 }
 
 func (n *Named) lookupMethod(pkg *Package, name string, foldCase bool) (int, *Func) {
-       n.resolve(nil)
+       n.resolve()
        // If n is an instance, we may not have yet instantiated all of its methods.
        // Look up the method index in orig, and only instantiate method at the
        // matching index (if any).
@@ -534,26 +554,17 @@ func (n *Named) lookupMethod(pkg *Package, name string, foldCase bool) (int, *Fu
        return i, n.Method(i)
 }
 
-// bestContext returns the best available context. In order of preference:
-// - the given ctxt, if non-nil
-// - check.ctxt, if check is non-nil
-// - a new Context
-func (check *Checker) bestContext(ctxt *Context) *Context {
-       if ctxt != nil {
-               return ctxt
+// context returns the type-checker context.
+func (check *Checker) context() *Context {
+       if check.ctxt == nil {
+               check.ctxt = NewContext()
        }
-       if check != nil {
-               if check.ctxt == nil {
-                       check.ctxt = NewContext()
-               }
-               return check.ctxt
-       }
-       return NewContext()
+       return check.ctxt
 }
 
 // expandUnderlying substitutes type arguments in the underlying type n.orig,
 // returning the result. Returns Typ[Invalid] if there was an error.
-func (n *Named) expandUnderlying(ctxt *Context) Type {
+func (n *Named) expandUnderlying() Type {
        check := n.check
        if check != nil && check.conf.Trace {
                check.trace(n.obj.pos, "-- Named.expandUnderlying %s", n)
@@ -565,6 +576,9 @@ func (n *Named) expandUnderlying(ctxt *Context) Type {
        }
 
        assert(n.inst.orig.underlying != nil)
+       if n.inst.ctxt == nil {
+               n.inst.ctxt = NewContext()
+       }
 
        orig := n.inst.orig
        targs := n.inst.targs
@@ -580,16 +594,20 @@ func (n *Named) expandUnderlying(ctxt *Context) Type {
                return Typ[Invalid]
        }
 
-       // We must always have a context, to avoid infinite recursion.
-       ctxt = check.bestContext(ctxt)
-       h := ctxt.instanceHash(orig, targs.list())
-       // ensure that an instance is recorded for h to avoid infinite recursion.
-       ctxt.update(h, orig, targs.list(), n)
+       // Ensure that an instance is recorded before substituting, so that we
+       // resolve n for any recursive references.
+       h := n.inst.ctxt.instanceHash(orig, targs.list())
+       n2 := n.inst.ctxt.update(h, orig, n.TypeArgs().list(), n)
+       assert(n == n2)
 
        smap := makeSubstMap(orig.tparams.list(), targs.list())
-       underlying := n.check.subst(n.obj.pos, orig.underlying, smap, ctxt)
-       // If the underlying type of n is an interface, we need to set the receiver
-       // of its methods accurately -- we set the receiver of interface methods on
+       var global *Context
+       if check != nil {
+               global = check.context()
+       }
+       underlying := n.check.subst(n.obj.pos, orig.underlying, smap, n.inst.ctxt, global)
+       // If the underlying type of n is an interface, we need to set the receiver of
+       // its methods accurately -- we set the receiver of interface methods on
        // the RHS of a type declaration to the defined type.
        if iface, _ := underlying.(*Interface); iface != nil {
                if methods, copied := replaceRecvType(iface.methods, orig, n); copied {
index 14a982048a2bcc28d14d0b5dd66b89a81dd329d1..e5e8eddb054ef66b8ffc2a581d1d222045eec3a4 100644 (file)
@@ -7,6 +7,7 @@ package types2_test
 import (
        "testing"
 
+       "cmd/compile/internal/syntax"
        . "cmd/compile/internal/types2"
 )
 
@@ -73,3 +74,47 @@ func mustInstantiate(tb testing.TB, orig Type, targs ...Type) Type {
        }
        return inst
 }
+
+// Test that types do not expand infinitely, as in golang/go#52715.
+func TestFiniteTypeExpansion(t *testing.T) {
+       const src = `
+package p
+
+type Tree[T any] struct {
+       *Node[T]
+}
+
+func (*Tree[R]) N(r R) R { return r }
+
+type Node[T any] struct {
+       *Tree[T]
+}
+
+func (Node[Q]) M(Q) {}
+
+type Inst = *Tree[int]
+`
+
+       f, err := parseSrc("foo.go", src)
+       if err != nil {
+               t.Fatal(err)
+       }
+       pkg := NewPackage("p", f.PkgName.Value)
+       if err := NewChecker(nil, pkg, nil).Files([]*syntax.File{f}); err != nil {
+               t.Fatal(err)
+       }
+
+       firstFieldType := func(n *Named) *Named {
+               return n.Underlying().(*Struct).Field(0).Type().(*Pointer).Elem().(*Named)
+       }
+
+       Inst := pkg.Scope().Lookup("Inst").Type().(*Pointer).Elem().(*Named)
+       Node := firstFieldType(Inst)
+       Tree := firstFieldType(Node)
+       if !Identical(Inst, Tree) {
+               t.Fatalf("Not a cycle: got %v, want %v", Tree, Inst)
+       }
+       if Inst != Tree {
+               t.Errorf("Duplicate instances in cycle: %s (%p) -> %s (%p) -> %s (%p)", Inst, Inst, Node, Node, Tree, Tree)
+       }
+}
index 6b6c21c780195618c7c8c3e2ba21629a1dfdbd5f..f7b5b16204e5c823a94f818d0018595f1437136b 100644 (file)
@@ -283,18 +283,19 @@ func identical(x, y Type, cmpTags bool, p *ifacePair) bool {
                        }
                        smap := makeSubstMap(ytparams, targs)
 
-                       var check *Checker // ok to call subst on a nil *Checker
+                       var check *Checker   // ok to call subst on a nil *Checker
+                       ctxt := NewContext() // need a non-nil Context for the substitution below
 
                        // Constraints must be pair-wise identical, after substitution.
                        for i, xtparam := range xtparams {
-                               ybound := check.subst(nopos, ytparams[i].bound, smap, nil)
+                               ybound := check.subst(nopos, ytparams[i].bound, smap, nil, ctxt)
                                if !identical(xtparam.bound, ybound, cmpTags, p) {
                                        return false
                                }
                        }
 
-                       yparams = check.subst(nopos, y.params, smap, nil).(*Tuple)
-                       yresults = check.subst(nopos, y.results, smap, nil).(*Tuple)
+                       yparams = check.subst(nopos, y.params, smap, nil, ctxt).(*Tuple)
+                       yresults = check.subst(nopos, y.results, smap, nil, ctxt).(*Tuple)
                }
 
                return x.variadic == y.variadic &&
index 92d3aadf888b4f841951ee141d1f05530aa2e258..1b61b368d2a3fb65fa5069414870b2df6185db85 100644 (file)
@@ -143,7 +143,7 @@ func (check *Checker) funcType(sig *Signature, recvPar *syntax.Field, tparams []
                                        // recvTPar.bound is (possibly) parameterized in the context of the
                                        // receiver type declaration. Substitute parameters for the current
                                        // context.
-                                       tpar.bound = check.subst(tpar.obj.pos, recvTPar.bound, smap, nil)
+                                       tpar.bound = check.subst(tpar.obj.pos, recvTPar.bound, smap, nil, check.context())
                                }
                        } else if len(tparams) < len(recvTParams) {
                                // Reporting an error here is a stop-gap measure to avoid crashes in the
index 9af1a71cfebccf16f3da9ea1b06e564c23a15a7d..4a4c8f960a607c4d8b586020eb768080d48e3f67 100644 (file)
@@ -49,7 +49,9 @@ func (m substMap) lookup(tpar *TypeParam) Type {
 // from the incoming type.
 //
 // If the given context is non-nil, it is used in lieu of check.Config.Context.
-func (check *Checker) subst(pos syntax.Pos, typ Type, smap substMap, ctxt *Context) Type {
+func (check *Checker) subst(pos syntax.Pos, typ Type, smap substMap, local, global *Context) Type {
+       assert(local != nil || global != nil)
+
        if smap.empty() {
                return typ
        }
@@ -64,19 +66,20 @@ func (check *Checker) subst(pos syntax.Pos, typ Type, smap substMap, ctxt *Conte
 
        // general case
        subst := subster{
-               pos:   pos,
-               smap:  smap,
-               check: check,
-               ctxt:  check.bestContext(ctxt),
+               pos:    pos,
+               smap:   smap,
+               check:  check,
+               local:  local,
+               global: global,
        }
        return subst.typ(typ)
 }
 
 type subster struct {
-       pos   syntax.Pos
-       smap  substMap
-       check *Checker // nil if called via Instantiate
-       ctxt  *Context
+       pos           syntax.Pos
+       smap          substMap
+       check         *Checker // nil if called via Instantiate
+       local, global *Context
 }
 
 func (subst *subster) typ(typ Type) Type {
@@ -247,25 +250,11 @@ func (subst *subster) typ(typ Type) Type {
                        return t // nothing to substitute
                }
 
-               // before creating a new named type, check if we have this one already
-               h := subst.ctxt.instanceHash(orig, newTArgs)
-               dump(">>> new type hash: %s", h)
-               if named := subst.ctxt.lookup(h, orig, newTArgs); named != nil {
-                       dump(">>> found %s", named)
-                       return named
-               }
-
                // Create a new instance and populate the context to avoid endless
                // recursion. The position used here is irrelevant because validation only
                // occurs on t (we don't call validType on named), but we use subst.pos to
                // help with debugging.
-               return subst.check.instance(subst.pos, orig, newTArgs, subst.ctxt)
-
-               // Note that if we were to expose substitution more generally (not just in
-               // the context of a declaration), we'd have to substitute in
-               // named.underlying as well.
-               //
-               // But this is unnecessary for now.
+               return subst.check.instance(subst.pos, orig, newTArgs, subst.local, subst.global)
 
        case *TypeParam:
                return subst.smap.lookup(t)
index ea13eb622d82e529a217ac99303c860c3acbf20e..f0cd236050fd215c5698b6b382f90bfb1281292d 100644 (file)
@@ -433,8 +433,7 @@ func (check *Checker) instantiatedType(x syntax.Expr, xlist []syntax.Expr, def *
        }
 
        // create the instance
-       ctxt := check.bestContext(nil)
-       inst := check.instance(x.Pos(), orig, targs, ctxt).(*Named)
+       inst := check.instance(x.Pos(), orig, targs, nil, check.context()).(*Named)
        def.setUnderlying(inst)
 
        // orig.tparams may not be set up, so we need to do expansion later.
@@ -445,7 +444,7 @@ func (check *Checker) instantiatedType(x syntax.Expr, xlist []syntax.Expr, def *
                check.recordInstance(x, inst.TypeArgs().list(), inst)
 
                if check.validateTArgLen(x.Pos(), inst.TypeParams().Len(), inst.TypeArgs().Len()) {
-                       if i, err := check.verify(x.Pos(), inst.TypeParams().list(), inst.TypeArgs().list()); err != nil {
+                       if i, err := check.verify(x.Pos(), inst.TypeParams().list(), inst.TypeArgs().list(), check.context()); err != nil {
                                // best position for error reporting
                                pos := x.Pos()
                                if i < len(xlist) {
index 3c7c3226f622e3e0afdac08908d3d321a26c6c80..c580885a5ab16bec67eb1b4a722b15cf12aba476 100644 (file)
@@ -73,13 +73,13 @@ func (check *Checker) instantiateSignature(pos token.Pos, typ *Signature, targs
                }()
        }
 
-       inst := check.instance(pos, typ, targs, check.bestContext(nil)).(*Signature)
+       inst := check.instance(pos, typ, targs, nil, check.context()).(*Signature)
        assert(len(xlist) <= len(targs))
 
        // verify instantiation lazily (was issue #50450)
        check.later(func() {
                tparams := typ.TypeParams().list()
-               if i, err := check.verify(pos, tparams, targs); err != nil {
+               if i, err := check.verify(pos, tparams, targs, check.context()); err != nil {
                        // best position for error reporting
                        pos := pos
                        if i < len(xlist) {
@@ -400,7 +400,7 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type
                // need to compute it from the adjusted list; otherwise we can
                // simply use the result signature's parameter list.
                if adjusted {
-                       sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(sig.TypeParams().list(), targs), nil).(*Tuple)
+                       sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(sig.TypeParams().list(), targs), nil, check.context()).(*Tuple)
                } else {
                        sigParams = rsig.params
                }
index ebe6d8ced7f308db563bf92f256f3dfbc5bd5443..1aa26126387ad3fc21e55bd21695846001827c02 100644 (file)
@@ -110,11 +110,11 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type,
 
                        renameMap := makeRenameMap(tparams, tparams2)
                        for i, tparam := range tparams {
-                               tparams2[i].bound = check.subst(posn.Pos(), tparam.bound, renameMap, nil)
+                               tparams2[i].bound = check.subst(posn.Pos(), tparam.bound, renameMap, nil, check.context())
                        }
 
                        tparams = tparams2
-                       params = check.subst(posn.Pos(), params, renameMap, nil).(*Tuple)
+                       params = check.subst(posn.Pos(), params, renameMap, nil, check.context()).(*Tuple)
                }
        }
 
@@ -188,7 +188,7 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type,
        //           but that doesn't impact the isParameterized check for now).
        if params.Len() > 0 {
                smap := makeSubstMap(tparams, targs)
-               params = check.subst(token.NoPos, params, smap, nil).(*Tuple)
+               params = check.subst(token.NoPos, params, smap, nil, check.context()).(*Tuple)
        }
 
        // Unify parameter and argument types for generic parameters with typed arguments
@@ -225,7 +225,7 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type,
                }
                smap := makeSubstMap(tparams, targs)
                // TODO(rFindley): pass a positioner here, rather than arg.Pos().
-               inferred := check.subst(arg.Pos(), tpar, smap, nil)
+               inferred := check.subst(arg.Pos(), tpar, smap, nil, check.context())
                // _CannotInferTypeArgs indicates a failure of inference, though the actual
                // error may be better attributed to a user-provided type argument (hence
                // _InvalidTypeArg). We can't differentiate these cases, so fall back on
@@ -626,7 +626,7 @@ func (check *Checker) inferB(posn positioner, tparams []*TypeParam, targs []Type
                n := 0
                for _, index := range dirty {
                        t0 := types[index]
-                       if t1 := check.subst(token.NoPos, t0, smap, nil); t1 != t0 {
+                       if t1 := check.subst(token.NoPos, t0, smap, nil, check.context()); t1 != t0 {
                                types[index] = t1
                                dirty[n] = index
                                n++
index d420a615728e7904f1a9c88949e6157ce6a7386d..6091b0b38147a779ce4177063c1b9239da09763a 100644 (file)
@@ -40,6 +40,9 @@ import (
 // count is incorrect; for *Named types, a panic may occur later inside the
 // *Named API.
 func Instantiate(ctxt *Context, orig Type, targs []Type, validate bool) (Type, error) {
+       if ctxt == nil {
+               ctxt = NewContext()
+       }
        if validate {
                var tparams []*TypeParam
                switch t := orig.(type) {
@@ -51,34 +54,71 @@ func Instantiate(ctxt *Context, orig Type, targs []Type, validate bool) (Type, e
                if len(targs) != len(tparams) {
                        return nil, fmt.Errorf("got %d type arguments but %s has %d type parameters", len(targs), orig, len(tparams))
                }
-               if i, err := (*Checker)(nil).verify(token.NoPos, tparams, targs); err != nil {
+               if i, err := (*Checker)(nil).verify(token.NoPos, tparams, targs, ctxt); err != nil {
                        return nil, &ArgumentError{i, err}
                }
        }
 
-       inst := (*Checker)(nil).instance(token.NoPos, orig, targs, ctxt)
+       inst := (*Checker)(nil).instance(token.NoPos, orig, targs, nil, ctxt)
        return inst, nil
 }
 
-// instance creates a type or function instance using the given original type
-// typ and arguments targs. For Named types the resulting instance will be
-// unexpanded. check may be nil.
-func (check *Checker) instance(pos token.Pos, orig Type, targs []Type, ctxt *Context) (res Type) {
-       var h string
-       if ctxt != nil {
-               h = ctxt.instanceHash(orig, targs)
-               // typ may already have been instantiated with identical type arguments. In
-               // that case, re-use the existing instance.
-               if inst := ctxt.lookup(h, orig, targs); inst != nil {
-                       return inst
+// instance resolves a type or function instance for the given original type
+// and type arguments. It looks for an existing identical instance in the given
+// contexts, creating a new instance if none is found.
+//
+// If local is non-nil, it is the context associated with a Named instance
+// type currently being expanded. If global is non-nil, it is the context
+// associated with the current type-checking pass or call to Instantiate. At
+// least one of local or global must be non-nil.
+//
+// For Named types the resulting instance may be unexpanded.
+func (check *Checker) instance(pos token.Pos, orig Type, targs []Type, local, global *Context) (res Type) {
+       // The order of the contexts below matters: we always prefer instances in
+       // local in order to preserve reference cycles.
+       //
+       // Invariant: if local != nil, the returned instance will be the instance
+       // recorded in local.
+       var ctxts []*Context
+       if local != nil {
+               ctxts = append(ctxts, local)
+       }
+       if global != nil {
+               ctxts = append(ctxts, global)
+       }
+       assert(len(ctxts) > 0)
+
+       // Compute all hashes; hashes may differ across contexts due to different
+       // unique IDs for Named types within the hasher.
+       hashes := make([]string, len(ctxts))
+       for i, ctxt := range ctxts {
+               hashes[i] = ctxt.instanceHash(orig, targs)
+       }
+
+       // If local is non-nil, updateContexts return the type recorded in
+       // local.
+       updateContexts := func(res Type) Type {
+               for i := len(ctxts) - 1; i >= 0; i-- {
+                       res = ctxts[i].update(hashes[i], orig, targs, res)
+               }
+               return res
+       }
+
+       // typ may already have been instantiated with identical type arguments. In
+       // that case, re-use the existing instance.
+       for i, ctxt := range ctxts {
+               if inst := ctxt.lookup(hashes[i], orig, targs); inst != nil {
+                       return updateContexts(inst)
                }
        }
 
        switch orig := orig.(type) {
        case *Named:
-               res = check.newNamedInstance(pos, orig, targs)
+               res = check.newNamedInstance(pos, orig, targs, local) // substituted lazily
 
        case *Signature:
+               assert(local == nil) // function instances cannot be reached from Named types
+
                tparams := orig.TypeParams()
                if !check.validateTArgLen(pos, tparams.Len(), len(targs)) {
                        return Typ[Invalid]
@@ -86,7 +126,7 @@ func (check *Checker) instance(pos token.Pos, orig Type, targs []Type, ctxt *Con
                if tparams.Len() == 0 {
                        return orig // nothing to do (minor optimization)
                }
-               sig := check.subst(pos, orig, makeSubstMap(tparams.list(), targs), ctxt).(*Signature)
+               sig := check.subst(pos, orig, makeSubstMap(tparams.list(), targs), nil, global).(*Signature)
                // If the signature doesn't use its type parameters, subst
                // will not make a copy. In that case, make a copy now (so
                // we can set tparams to nil w/o causing side-effects).
@@ -104,13 +144,8 @@ func (check *Checker) instance(pos token.Pos, orig Type, targs []Type, ctxt *Con
                panic(fmt.Sprintf("%v: cannot instantiate %v", pos, orig))
        }
 
-       if ctxt != nil {
-               // It's possible that we've lost a race to add named to the context.
-               // In this case, use whichever instance is recorded in the context.
-               res = ctxt.update(h, orig, targs, res)
-       }
-
-       return res
+       // Update all contexts; it's possible that we've lost a race.
+       return updateContexts(res)
 }
 
 // validateTArgLen verifies that the length of targs and tparams matches,
@@ -128,7 +163,7 @@ func (check *Checker) validateTArgLen(pos token.Pos, ntparams, ntargs int) bool
        return true
 }
 
-func (check *Checker) verify(pos token.Pos, tparams []*TypeParam, targs []Type) (int, error) {
+func (check *Checker) verify(pos token.Pos, tparams []*TypeParam, targs []Type, ctxt *Context) (int, error) {
        smap := makeSubstMap(tparams, targs)
        for i, tpar := range tparams {
                // Ensure that we have a (possibly implicit) interface as type bound (issue #51048).
@@ -137,7 +172,7 @@ func (check *Checker) verify(pos token.Pos, tparams []*TypeParam, targs []Type)
                // as the instantiated type; before we can use it for bounds checking we
                // need to instantiate it with the type arguments with which we instantiated
                // the parameterized type.
-               bound := check.subst(pos, tpar.bound, smap, nil)
+               bound := check.subst(pos, tpar.bound, smap, nil, ctxt)
                if err := check.implements(targs[i], bound); err != nil {
                        return i, err
                }
index 71a26f96a16c79fac4a04a60dfc4dde7007439af..63f0a22323b7a6fcfe0652f7c883da3c56691881 100644 (file)
@@ -79,6 +79,16 @@ import (
 // Identical to compare them. For instantiated named types, their obj is a
 // synthetic placeholder that records their position of the corresponding
 // instantiation in the source (if they were constructed during type checking).
+//
+// To prevent infinite expansion of named instances that are created outside of
+// type-checking, instances share a Context with other instances created during
+// their expansion. Via the pidgeonhole principle, this guarantees that in the
+// presence of a cycle of named types, expansion will eventually find an
+// existing instance in the Context and short-circuit the expansion.
+//
+// Once an instance is complete, we can nil out this shared Context to unpin
+// memory, though this Context may still be held by other incomplete instances
+// in its "lineage".
 
 // A Named represents a named (defined) type.
 type Named struct {
@@ -115,6 +125,7 @@ type instance struct {
        orig            *Named    // original, uninstantiated type
        targs           *TypeList // type arguments
        expandedMethods int       // number of expanded methods; expandedMethods <= len(orig.methods)
+       ctxt            *Context  // local Context; set to nil after full expansion
 }
 
 // namedState represents the possible states that a named type may assume.
@@ -143,7 +154,7 @@ func NewNamed(obj *TypeName, underlying Type, methods []*Func) *Named {
 // After resolution, the type parameters, methods, and underlying type of n are
 // accessible; but if n is an instantiated type, its methods may still be
 // unexpanded.
-func (n *Named) resolve(ctxt *Context) *Named {
+func (n *Named) resolve() *Named {
        if n.state() >= resolved { // avoid locking below
                return n
        }
@@ -162,8 +173,8 @@ func (n *Named) resolve(ctxt *Context) *Named {
                assert(n.loader == nil)     // instances are created by instantiation, in which case n.loader is nil
 
                orig := n.inst.orig
-               orig.resolve(ctxt)
-               underlying := n.expandUnderlying(ctxt)
+               orig.resolve()
+               underlying := n.expandUnderlying()
 
                n.tparams = orig.tparams
                n.underlying = underlying
@@ -171,6 +182,7 @@ func (n *Named) resolve(ctxt *Context) *Named {
 
                if len(orig.methods) == 0 {
                        n.setState(complete) // nothing further to do
+                       n.inst.ctxt = nil
                } else {
                        n.setState(resolved)
                }
@@ -225,11 +237,11 @@ func (check *Checker) newNamed(obj *TypeName, underlying Type, methods []*Func)
        return typ
 }
 
-func (check *Checker) newNamedInstance(pos token.Pos, orig *Named, targs []Type) *Named {
+func (check *Checker) newNamedInstance(pos token.Pos, orig *Named, targs []Type, local *Context) *Named {
        assert(len(targs) > 0)
 
        obj := NewTypeName(pos, orig.obj.pkg, orig.obj.name, nil)
-       inst := &instance{orig: orig, targs: newTypeList(targs)}
+       inst := &instance{orig: orig, targs: newTypeList(targs), ctxt: local}
        typ := &Named{check: check, obj: obj, inst: inst}
        obj.typ = typ
        // Ensure that typ is always expanded and sanity-checked.
@@ -280,13 +292,13 @@ func (t *Named) Origin() *Named {
 
 // TypeParams returns the type parameters of the named type t, or nil.
 // The result is non-nil for an (originally) generic type even if it is instantiated.
-func (t *Named) TypeParams() *TypeParamList { return t.resolve(nil).tparams }
+func (t *Named) TypeParams() *TypeParamList { return t.resolve().tparams }
 
 // SetTypeParams sets the type parameters of the named type t.
 // t must not have type arguments.
 func (t *Named) SetTypeParams(tparams []*TypeParam) {
        assert(t.inst == nil)
-       t.resolve(nil).tparams = bindTParams(tparams)
+       t.resolve().tparams = bindTParams(tparams)
 }
 
 // TypeArgs returns the type arguments used to instantiate the named type t.
@@ -298,17 +310,17 @@ func (t *Named) TypeArgs() *TypeList {
 }
 
 // NumMethods returns the number of explicit methods defined for t.
-//
-// For an ordinary or instantiated type t, the receiver base type of these
-// methods will be the named type t. For an uninstantiated generic type t, each
-// method receiver will be instantiated with its receiver type parameters.
 func (t *Named) NumMethods() int {
-       return len(t.Origin().resolve(nil).methods)
+       return len(t.Origin().resolve().methods)
 }
 
 // Method returns the i'th method of named type t for 0 <= i < t.NumMethods().
+//
+// For an ordinary or instantiated type t, the receiver base type of this
+// method is the named type t. For an uninstantiated generic type t, each
+// method receiver is instantiated with its receiver type parameters.
 func (t *Named) Method(i int) *Func {
-       t.resolve(nil)
+       t.resolve()
 
        if t.state() >= complete {
                return t.methods[i]
@@ -326,6 +338,7 @@ func (t *Named) Method(i int) *Func {
        }
 
        if t.methods[i] == nil {
+               assert(t.inst.ctxt != nil) // we should still have a context remaining from the resolution phase
                t.methods[i] = t.expandMethod(i)
                t.inst.expandedMethods++
 
@@ -333,6 +346,7 @@ func (t *Named) Method(i int) *Func {
                // type as fully expanded.
                if t.inst.expandedMethods == len(orig.methods) {
                        t.setState(complete)
+                       t.inst.ctxt = nil // no need for a context anymore
                }
        }
 
@@ -372,9 +386,12 @@ func (t *Named) expandMethod(i int) *Func {
        // and type parameters. This check is necessary in the presence of invalid
        // code.
        if origSig.RecvTypeParams().Len() == t.inst.targs.Len() {
-               ctxt := check.bestContext(nil)
                smap := makeSubstMap(origSig.RecvTypeParams().list(), t.inst.targs.list())
-               sig = check.subst(origm.pos, origSig, smap, ctxt).(*Signature)
+               var global *Context
+               if check != nil {
+                       global = check.context()
+               }
+               sig = check.subst(origm.pos, origSig, smap, t.inst.ctxt, global).(*Signature)
        }
 
        if sig == origSig {
@@ -405,7 +422,7 @@ func (t *Named) SetUnderlying(underlying Type) {
        if _, ok := underlying.(*Named); ok {
                panic("underlying type must not be *Named")
        }
-       t.resolve(nil).underlying = underlying
+       t.resolve().underlying = underlying
        if t.fromRHS == nil {
                t.fromRHS = underlying // for cycle detection
        }
@@ -415,17 +432,20 @@ func (t *Named) SetUnderlying(underlying Type) {
 // t must not have type arguments.
 func (t *Named) AddMethod(m *Func) {
        assert(t.inst == nil)
-       t.resolve(nil)
+       t.resolve()
        if i, _ := lookupMethod(t.methods, m.pkg, m.name, false); i < 0 {
                t.methods = append(t.methods, m)
        }
 }
 
-func (t *Named) Underlying() Type { return t.resolve(nil).underlying }
+func (t *Named) Underlying() Type { return t.resolve().underlying }
 func (t *Named) String() string   { return TypeString(t, nil) }
 
 // ----------------------------------------------------------------------------
 // Implementation
+//
+// TODO(rfindley): reorganize the loading and expansion methods under this
+// heading.
 
 // under returns the expanded underlying type of n0; possibly by following
 // forward chains of named types. If an underlying type is found, resolve
@@ -522,7 +542,7 @@ func (n *Named) setUnderlying(typ Type) {
 }
 
 func (n *Named) lookupMethod(pkg *Package, name string, foldCase bool) (int, *Func) {
-       n.resolve(nil)
+       n.resolve()
        // If n is an instance, we may not have yet instantiated all of its methods.
        // Look up the method index in orig, and only instantiate method at the
        // matching index (if any).
@@ -534,26 +554,17 @@ func (n *Named) lookupMethod(pkg *Package, name string, foldCase bool) (int, *Fu
        return i, n.Method(i)
 }
 
-// bestContext returns the best available context. In order of preference:
-// - the given ctxt, if non-nil
-// - check.ctxt, if check is non-nil
-// - a new Context
-func (check *Checker) bestContext(ctxt *Context) *Context {
-       if ctxt != nil {
-               return ctxt
+// context returns the type-checker context.
+func (check *Checker) context() *Context {
+       if check.ctxt == nil {
+               check.ctxt = NewContext()
        }
-       if check != nil {
-               if check.ctxt == nil {
-                       check.ctxt = NewContext()
-               }
-               return check.ctxt
-       }
-       return NewContext()
+       return check.ctxt
 }
 
 // expandUnderlying substitutes type arguments in the underlying type n.orig,
 // returning the result. Returns Typ[Invalid] if there was an error.
-func (n *Named) expandUnderlying(ctxt *Context) Type {
+func (n *Named) expandUnderlying() Type {
        check := n.check
        if check != nil && trace {
                check.trace(n.obj.pos, "-- Named.expandUnderlying %s", n)
@@ -565,6 +576,9 @@ func (n *Named) expandUnderlying(ctxt *Context) Type {
        }
 
        assert(n.inst.orig.underlying != nil)
+       if n.inst.ctxt == nil {
+               n.inst.ctxt = NewContext()
+       }
 
        orig := n.inst.orig
        targs := n.inst.targs
@@ -580,16 +594,20 @@ func (n *Named) expandUnderlying(ctxt *Context) Type {
                return Typ[Invalid]
        }
 
-       // We must always have a context, to avoid infinite recursion.
-       ctxt = check.bestContext(ctxt)
-       h := ctxt.instanceHash(orig, targs.list())
-       // ensure that an instance is recorded for h to avoid infinite recursion.
-       ctxt.update(h, orig, targs.list(), n)
+       // Ensure that an instance is recorded before substituting, so that we
+       // resolve n for any recursive references.
+       h := n.inst.ctxt.instanceHash(orig, targs.list())
+       n2 := n.inst.ctxt.update(h, orig, n.TypeArgs().list(), n)
+       assert(n == n2)
 
        smap := makeSubstMap(orig.tparams.list(), targs.list())
-       underlying := n.check.subst(n.obj.pos, orig.underlying, smap, ctxt)
-       // If the underlying type of n is an interface, we need to set the receiver
-       // of its methods accurately -- we set the receiver of interface methods on
+       var global *Context
+       if check != nil {
+               global = check.context()
+       }
+       underlying := n.check.subst(n.obj.pos, orig.underlying, smap, n.inst.ctxt, global)
+       // If the underlying type of n is an interface, we need to set the receiver of
+       // its methods accurately -- we set the receiver of interface methods on
        // the RHS of a type declaration to the defined type.
        if iface, _ := underlying.(*Interface); iface != nil {
                if methods, copied := replaceRecvType(iface.methods, orig, n); copied {
index 74cdb4888967e9b7860801745e9ca81c8a5ec0c3..0fe17418f400e96640f348e2c73f6185df62e597 100644 (file)
@@ -5,6 +5,9 @@
 package types_test
 
 import (
+       "go/ast"
+       "go/parser"
+       "go/token"
        "testing"
 
        . "go/types"
@@ -86,3 +89,48 @@ func mustInstantiate(tb testing.TB, orig Type, targs ...Type) Type {
        }
        return inst
 }
+
+// Test that types do not expand infinitely, as in golang/go#52715.
+func TestFiniteTypeExpansion(t *testing.T) {
+       const src = `
+package p
+
+type Tree[T any] struct {
+       *Node[T]
+}
+
+func (*Tree[R]) N(r R) R { return r }
+
+type Node[T any] struct {
+       *Tree[T]
+}
+
+func (Node[Q]) M(Q) {}
+
+type Inst = *Tree[int]
+`
+
+       fset := token.NewFileSet()
+       f, err := parser.ParseFile(fset, "foo.go", src, 0)
+       if err != nil {
+               t.Fatal(err)
+       }
+       pkg := NewPackage("p", f.Name.Name)
+       if err := NewChecker(nil, fset, pkg, nil).Files([]*ast.File{f}); err != nil {
+               t.Fatal(err)
+       }
+
+       firstFieldType := func(n *Named) *Named {
+               return n.Underlying().(*Struct).Field(0).Type().(*Pointer).Elem().(*Named)
+       }
+
+       Inst := pkg.Scope().Lookup("Inst").Type().(*Pointer).Elem().(*Named)
+       Node := firstFieldType(Inst)
+       Tree := firstFieldType(Node)
+       if !Identical(Inst, Tree) {
+               t.Fatalf("Not a cycle: got %v, want %v", Tree, Inst)
+       }
+       if Inst != Tree {
+               t.Errorf("Duplicate instances in cycle: %s (%p) -> %s (%p) -> %s (%p)", Inst, Inst, Node, Node, Tree, Tree)
+       }
+}
index 6e08b76e406029acdd0a204a6c47b7c21add317a..25db4acf4a7e1197f579f1ae1670f2921add92ba 100644 (file)
@@ -285,18 +285,19 @@ func identical(x, y Type, cmpTags bool, p *ifacePair) bool {
                        }
                        smap := makeSubstMap(ytparams, targs)
 
-                       var check *Checker // ok to call subst on a nil *Checker
+                       var check *Checker   // ok to call subst on a nil *Checker
+                       ctxt := NewContext() // need a non-nil Context for the substitution below
 
                        // Constraints must be pair-wise identical, after substitution.
                        for i, xtparam := range xtparams {
-                               ybound := check.subst(token.NoPos, ytparams[i].bound, smap, nil)
+                               ybound := check.subst(token.NoPos, ytparams[i].bound, smap, nil, ctxt)
                                if !identical(xtparam.bound, ybound, cmpTags, p) {
                                        return false
                                }
                        }
 
-                       yparams = check.subst(token.NoPos, y.params, smap, nil).(*Tuple)
-                       yresults = check.subst(token.NoPos, y.results, smap, nil).(*Tuple)
+                       yparams = check.subst(token.NoPos, y.params, smap, nil, ctxt).(*Tuple)
+                       yresults = check.subst(token.NoPos, y.results, smap, nil, ctxt).(*Tuple)
                }
 
                return x.variadic == y.variadic &&
index 4b63f0e6f0b42daf24789ed91028560e30c27e35..82177a1c58d5a12767cbb2fdfb779f117bd4d6a6 100644 (file)
@@ -150,7 +150,7 @@ func (check *Checker) funcType(sig *Signature, recvPar *ast.FieldList, ftyp *ast
                                        // recvTPar.bound is (possibly) parameterized in the context of the
                                        // receiver type declaration. Substitute parameters for the current
                                        // context.
-                                       tpar.bound = check.subst(tpar.obj.pos, recvTPar.bound, smap, nil)
+                                       tpar.bound = check.subst(tpar.obj.pos, recvTPar.bound, smap, nil, check.context())
                                }
                        } else if len(tparams) < len(recvTParams) {
                                // Reporting an error here is a stop-gap measure to avoid crashes in the
index 110298cbaeea3e5482f63ebb634855fb506709b4..36987a4c95300343b71c867004ac1fced4453663 100644 (file)
@@ -49,7 +49,9 @@ func (m substMap) lookup(tpar *TypeParam) Type {
 // result type is different from the incoming type.
 //
 // If the given context is non-nil, it is used in lieu of check.Config.Context
-func (check *Checker) subst(pos token.Pos, typ Type, smap substMap, ctxt *Context) Type {
+func (check *Checker) subst(pos token.Pos, typ Type, smap substMap, local, global *Context) Type {
+       assert(local != nil || global != nil)
+
        if smap.empty() {
                return typ
        }
@@ -64,19 +66,20 @@ func (check *Checker) subst(pos token.Pos, typ Type, smap substMap, ctxt *Contex
 
        // general case
        subst := subster{
-               pos:   pos,
-               smap:  smap,
-               check: check,
-               ctxt:  check.bestContext(ctxt),
+               pos:    pos,
+               smap:   smap,
+               check:  check,
+               local:  local,
+               global: global,
        }
        return subst.typ(typ)
 }
 
 type subster struct {
-       pos   token.Pos
-       smap  substMap
-       check *Checker // nil if called via Instantiate
-       ctxt  *Context
+       pos           token.Pos
+       smap          substMap
+       check         *Checker // nil if called via Instantiate
+       local, global *Context
 }
 
 func (subst *subster) typ(typ Type) Type {
@@ -247,25 +250,11 @@ func (subst *subster) typ(typ Type) Type {
                        return t // nothing to substitute
                }
 
-               // before creating a new named type, check if we have this one already
-               h := subst.ctxt.instanceHash(orig, newTArgs)
-               dump(">>> new type hash: %s", h)
-               if named := subst.ctxt.lookup(h, orig, newTArgs); named != nil {
-                       dump(">>> found %s", named)
-                       return named
-               }
-
                // Create a new instance and populate the context to avoid endless
                // recursion. The position used here is irrelevant because validation only
                // occurs on t (we don't call validType on named), but we use subst.pos to
                // help with debugging.
-               return subst.check.instance(subst.pos, orig, newTArgs, subst.ctxt)
-
-               // Note that if we were to expose substitution more generally (not just in
-               // the context of a declaration), we'd have to substitute in
-               // named.underlying as well.
-               //
-               // But this is unnecessary for now.
+               return subst.check.instance(subst.pos, orig, newTArgs, subst.local, subst.global)
 
        case *TypeParam:
                return subst.smap.lookup(t)
index 05bd51a82b2b4b8628f537c72c9fbd082e38ddf0..a881d33654c4ffd5aabaf8f7aac74dce7ffcc516 100644 (file)
@@ -417,8 +417,7 @@ func (check *Checker) instantiatedType(ix *typeparams.IndexExpr, def *Named) (re
        }
 
        // create the instance
-       ctxt := check.bestContext(nil)
-       inst := check.instance(ix.Pos(), orig, targs, ctxt).(*Named)
+       inst := check.instance(ix.Pos(), orig, targs, nil, check.context()).(*Named)
        def.setUnderlying(inst)
 
        // orig.tparams may not be set up, so we need to do expansion later.
@@ -429,7 +428,7 @@ func (check *Checker) instantiatedType(ix *typeparams.IndexExpr, def *Named) (re
                check.recordInstance(ix.Orig, inst.TypeArgs().list(), inst)
 
                if check.validateTArgLen(ix.Pos(), inst.TypeParams().Len(), inst.TypeArgs().Len()) {
-                       if i, err := check.verify(ix.Pos(), inst.TypeParams().list(), inst.TypeArgs().list()); err != nil {
+                       if i, err := check.verify(ix.Pos(), inst.TypeParams().list(), inst.TypeArgs().list(), check.context()); err != nil {
                                // best position for error reporting
                                pos := ix.Pos()
                                if i < len(ix.Indices) {