]> Cypherpunks repositories - gostls13.git/commitdiff
go/types, types2: use a type lookup by identity in method lookup
authorRobert Findley <rfindley@google.com>
Wed, 4 May 2022 21:41:39 +0000 (17:41 -0400)
committerRobert Findley <rfindley@google.com>
Mon, 9 May 2022 19:54:59 +0000 (19:54 +0000)
Named type identity is no longer canonical. For correctness, named types
need to be compared with types.Identical. Our method set algorithm was
not doing this: it was using a map to de-duplicate named types, relying
on their pointer identity. As a result it was possible to get incorrect
results or even infinite recursion, as encountered in #52715.

To fix this, look up types by identity in NewMethodSet and
LookupFieldOrMethod. This does a linear search among types with equal
origin. Alternatively we could use a *Context to do a hash lookup, but
in practice we will be considering a small number of types, and so
performance is not a concern and a linear lookup is simpler. This also
means we don't have to rely on our type hash being perfect, which we
don't depend on elsewhere.

Also add more tests for NewMethodSet and LookupFieldOrMethod involving
generics.

Fixes #52715
Fixes #51580

Change-Id: I04dfeff54347bc3544d95a30224c640ef448e9b7
Reviewed-on: https://go-review.googlesource.com/c/go/+/404099
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Robert Griesemer <gri@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>

src/cmd/compile/internal/types2/api_test.go
src/cmd/compile/internal/types2/lookup.go
src/go/types/api_test.go
src/go/types/lookup.go
src/go/types/methodset.go
src/go/types/methodset_test.go

index fde7291b03ecb6a62fb1310efdd4c5c6bbb36494..f7cdd1d21e3a7500101446d91b280b3b601b7787 100644 (file)
@@ -1618,19 +1618,41 @@ func TestLookupFieldOrMethod(t *testing.T) {
                {"var x T; type T struct{ f int }", true, []int{0}, false},
                {"var x T; type T struct{ a, b, f, c int }", true, []int{2}, false},
 
+               // field lookups on a generic type
+               {"var x T[int]; type T[P any] struct{}", false, nil, false},
+               {"var x T[int]; type T[P any] struct{ f P }", true, []int{0}, false},
+               {"var x T[int]; type T[P any] struct{ a, b, f, c P }", true, []int{2}, false},
+
                // method lookups
                {"var a T; type T struct{}; func (T) f() {}", true, []int{0}, false},
                {"var a *T; type T struct{}; func (T) f() {}", true, []int{0}, true},
                {"var a T; type T struct{}; func (*T) f() {}", true, []int{0}, false},
                {"var a *T; type T struct{}; func (*T) f() {}", true, []int{0}, true}, // TODO(gri) should this report indirect = false?
 
+               // method lookups on a generic type
+               {"var a T[int]; type T[P any] struct{}; func (T[P]) f() {}", true, []int{0}, false},
+               {"var a *T[int]; type T[P any] struct{}; func (T[P]) f() {}", true, []int{0}, true},
+               {"var a T[int]; type T[P any] struct{}; func (*T[P]) f() {}", true, []int{0}, false},
+               {"var a *T[int]; type T[P any] struct{}; func (*T[P]) f() {}", true, []int{0}, true}, // TODO(gri) should this report indirect = false?
+
                // collisions
                {"type ( E1 struct{ f int }; E2 struct{ f int }; x struct{ E1; *E2 })", false, []int{1, 0}, false},
                {"type ( E1 struct{ f int }; E2 struct{}; x struct{ E1; *E2 }); func (E2) f() {}", false, []int{1, 0}, false},
 
+               // collisions on a generic type
+               {"type ( E1[P any] struct{ f P }; E2[P any] struct{ f P }; x struct{ E1[int]; *E2[int] })", false, []int{1, 0}, false},
+               {"type ( E1[P any] struct{ f P }; E2[P any] struct{}; x struct{ E1[int]; *E2[int] }); func (E2[P]) f() {}", false, []int{1, 0}, false},
+
                // outside methodset
                // (*T).f method exists, but value of type T is not addressable
                {"var x T; type T struct{}; func (*T) f() {}", false, nil, true},
+
+               // outside method set of a generic type
+               {"var x T[int]; type T[P any] struct{}; func (*T[P]) f() {}", false, nil, true},
+
+               // recursive generic types; see golang/go#52715
+               {"var a T[int]; type ( T[P any] struct { *N[P] }; N[P any] struct { *T[P] } ); func (N[P]) f() {}", true, []int{0, 0}, true},
+               {"var a T[int]; type ( T[P any] struct { *N[P] }; N[P any] struct { *T[P] } ); func (T[P]) f() {}", true, []int{0}, false},
        }
 
        for _, test := range tests {
@@ -1665,6 +1687,37 @@ func TestLookupFieldOrMethod(t *testing.T) {
        }
 }
 
+// Test for golang/go#52715
+func TestLookupFieldOrMethod_RecursiveGeneric(t *testing.T) {
+       const src = `
+package pkg
+
+type Tree[T any] struct {
+       *Node[T]
+}
+
+func (*Tree[R]) N(r R) R { return r }
+
+type Node[T any] struct {
+       *Tree[T]
+}
+
+type Instance = *Tree[int]
+`
+
+       f, err := parseSrc("foo.go", src)
+       if err != nil {
+               panic(err)
+       }
+       pkg := NewPackage("pkg", f.PkgName.Value)
+       if err := NewChecker(nil, pkg, nil).Files([]*syntax.File{f}); err != nil {
+               panic(err)
+       }
+
+       T := pkg.Scope().Lookup("Instance").Type()
+       _, _, _ = LookupFieldOrMethod(T, false, pkg, "M") // verify that LookupFieldOrMethod terminates
+}
+
 func sameSlice(a, b []int) bool {
        if len(a) != len(b) {
                return false
index 684bbf7a8bce5510f14d11c4078ad00c7730f258..482b6bd8ef9f530b938faf9055721f0d60913fb0 100644 (file)
@@ -81,11 +81,6 @@ func LookupFieldOrMethod(T Type, addressable bool, pkg *Package, name string) (o
        return
 }
 
-// TODO(gri) The named type consolidation and seen maps below must be
-// indexed by unique keys for a given type. Verify that named
-// types always have only one representation (even when imported
-// indirectly via different packages.)
-
 // lookupFieldOrMethod should only be called by LookupFieldOrMethod and missingMethod.
 // If foldCase is true, the lookup for methods will include looking for any method
 // which case-folds to the same as 'name' (used for giving helpful error messages).
@@ -110,14 +105,12 @@ func lookupFieldOrMethod(T Type, addressable bool, pkg *Package, name string, fo
        // Start with typ as single entry at shallowest depth.
        current := []embeddedType{{typ, nil, isPtr, false}}
 
-       // Named types that we have seen already, allocated lazily.
+       // seen tracks named types that we have seen already, allocated lazily.
        // Used to avoid endless searches in case of recursive types.
-       // Since only Named types can be used for recursive types, we
-       // only need to track those.
-       // (If we ever allow type aliases to construct recursive types,
-       // we must use type identity rather than pointer equality for
-       // the map key comparison, as we do in consolidateMultiples.)
-       var seen map[*Named]bool
+       //
+       // We must use a lookup on identity rather than a simple map[*Named]bool as
+       // instantiated types may be identical but not equal.
+       var seen instanceLookup
 
        // search current depth
        for len(current) > 0 {
@@ -130,7 +123,7 @@ func lookupFieldOrMethod(T Type, addressable bool, pkg *Package, name string, fo
                        // If we have a named type, we may have associated methods.
                        // Look for those first.
                        if named, _ := typ.(*Named); named != nil {
-                               if seen[named] {
+                               if alt := seen.lookup(named); alt != nil {
                                        // We have seen this type before, at a more shallow depth
                                        // (note that multiples of this type at the current depth
                                        // were consolidated before). The type at that depth shadows
@@ -138,10 +131,7 @@ func lookupFieldOrMethod(T Type, addressable bool, pkg *Package, name string, fo
                                        // this one.
                                        continue
                                }
-                               if seen == nil {
-                                       seen = make(map[*Named]bool)
-                               }
-                               seen[named] = true
+                               seen.add(named)
 
                                // look for a matching attached method
                                named.resolve(nil)
@@ -271,6 +261,27 @@ func lookupType(m map[Type]int, typ Type) (int, bool) {
        return 0, false
 }
 
+type instanceLookup struct {
+       m map[*Named][]*Named
+}
+
+func (l *instanceLookup) lookup(inst *Named) *Named {
+       for _, t := range l.m[inst.Origin()] {
+               if Identical(inst, t) {
+                       return t
+               }
+       }
+       return nil
+}
+
+func (l *instanceLookup) add(inst *Named) {
+       if l.m == nil {
+               l.m = make(map[*Named][]*Named)
+       }
+       insts := l.m[inst.Origin()]
+       l.m[inst.Origin()] = append(insts, inst)
+}
+
 // MissingMethod returns (nil, false) if V implements T, otherwise it
 // returns a missing method required by T and whether it is missing or
 // just has the wrong type.
index 0ad97c59227d0e70f54a834ea59b1c9f0e7f1847..0daeff7fc0ab836b79999628948ad4136db36c73 100644 (file)
@@ -1630,23 +1630,45 @@ func TestLookupFieldOrMethod(t *testing.T) {
                {"var x T; type T struct{ f int }", true, []int{0}, false},
                {"var x T; type T struct{ a, b, f, c int }", true, []int{2}, false},
 
+               // field lookups on a generic type
+               {"var x T[int]; type T[P any] struct{}", false, nil, false},
+               {"var x T[int]; type T[P any] struct{ f P }", true, []int{0}, false},
+               {"var x T[int]; type T[P any] struct{ a, b, f, c P }", true, []int{2}, false},
+
                // method lookups
                {"var a T; type T struct{}; func (T) f() {}", true, []int{0}, false},
                {"var a *T; type T struct{}; func (T) f() {}", true, []int{0}, true},
                {"var a T; type T struct{}; func (*T) f() {}", true, []int{0}, false},
                {"var a *T; type T struct{}; func (*T) f() {}", true, []int{0}, true}, // TODO(gri) should this report indirect = false?
 
+               // method lookups on a generic type
+               {"var a T[int]; type T[P any] struct{}; func (T[P]) f() {}", true, []int{0}, false},
+               {"var a *T[int]; type T[P any] struct{}; func (T[P]) f() {}", true, []int{0}, true},
+               {"var a T[int]; type T[P any] struct{}; func (*T[P]) f() {}", true, []int{0}, false},
+               {"var a *T[int]; type T[P any] struct{}; func (*T[P]) f() {}", true, []int{0}, true}, // TODO(gri) should this report indirect = false?
+
                // collisions
                {"type ( E1 struct{ f int }; E2 struct{ f int }; x struct{ E1; *E2 })", false, []int{1, 0}, false},
                {"type ( E1 struct{ f int }; E2 struct{}; x struct{ E1; *E2 }); func (E2) f() {}", false, []int{1, 0}, false},
 
+               // collisions on a generic type
+               {"type ( E1[P any] struct{ f P }; E2[P any] struct{ f P }; x struct{ E1[int]; *E2[int] })", false, []int{1, 0}, false},
+               {"type ( E1[P any] struct{ f P }; E2[P any] struct{}; x struct{ E1[int]; *E2[int] }); func (E2[P]) f() {}", false, []int{1, 0}, false},
+
                // outside methodset
                // (*T).f method exists, but value of type T is not addressable
                {"var x T; type T struct{}; func (*T) f() {}", false, nil, true},
+
+               // outside method set of a generic type
+               {"var x T[int]; type T[P any] struct{}; func (*T[P]) f() {}", false, nil, true},
+
+               // recursive generic types; see golang/go#52715
+               {"var a T[int]; type ( T[P any] struct { *N[P] }; N[P any] struct { *T[P] } ); func (N[P]) f() {}", true, []int{0, 0}, true},
+               {"var a T[int]; type ( T[P any] struct { *N[P] }; N[P any] struct { *T[P] } ); func (T[P]) f() {}", true, []int{0}, false},
        }
 
        for _, test := range tests {
-               pkg, err := pkgFor("test", "package p;"+test.src, nil)
+               pkg, err := pkgForMode("test", "package p;"+test.src, nil, 0)
                if err != nil {
                        t.Errorf("%s: incorrect test case: %s", test.src, err)
                        continue
@@ -1677,6 +1699,38 @@ func TestLookupFieldOrMethod(t *testing.T) {
        }
 }
 
+// Test for golang/go#52715
+func TestLookupFieldOrMethod_RecursiveGeneric(t *testing.T) {
+       const src = `
+package pkg
+
+type Tree[T any] struct {
+       *Node[T]
+}
+
+func (*Tree[R]) N(r R) R { return r }
+
+type Node[T any] struct {
+       *Tree[T]
+}
+
+type Instance = *Tree[int]
+`
+
+       fset := token.NewFileSet()
+       f, err := parser.ParseFile(fset, "foo.go", src, 0)
+       if err != nil {
+               panic(err)
+       }
+       pkg := NewPackage("pkg", f.Name.Name)
+       if err := NewChecker(nil, fset, pkg, nil).Files([]*ast.File{f}); err != nil {
+               panic(err)
+       }
+
+       T := pkg.Scope().Lookup("Instance").Type()
+       _, _, _ = LookupFieldOrMethod(T, false, pkg, "M") // verify that LookupFieldOrMethod terminates
+}
+
 func sameSlice(a, b []int) bool {
        if len(a) != len(b) {
                return false
index 70e211d0828c7a63cafe7f24ebd976c199585b88..22a62055d315a27c7b36ead38e7c6b7c57577a75 100644 (file)
@@ -81,11 +81,6 @@ func LookupFieldOrMethod(T Type, addressable bool, pkg *Package, name string) (o
        return
 }
 
-// TODO(gri) The named type consolidation and seen maps below must be
-// indexed by unique keys for a given type. Verify that named
-// types always have only one representation (even when imported
-// indirectly via different packages.)
-
 // lookupFieldOrMethod should only be called by LookupFieldOrMethod and missingMethod.
 // If foldCase is true, the lookup for methods will include looking for any method
 // which case-folds to the same as 'name' (used for giving helpful error messages).
@@ -110,14 +105,12 @@ func lookupFieldOrMethod(T Type, addressable bool, pkg *Package, name string, fo
        // Start with typ as single entry at shallowest depth.
        current := []embeddedType{{typ, nil, isPtr, false}}
 
-       // Named types that we have seen already, allocated lazily.
+       // seen tracks named types that we have seen already, allocated lazily.
        // Used to avoid endless searches in case of recursive types.
-       // Since only Named types can be used for recursive types, we
-       // only need to track those.
-       // (If we ever allow type aliases to construct recursive types,
-       // we must use type identity rather than pointer equality for
-       // the map key comparison, as we do in consolidateMultiples.)
-       var seen map[*Named]bool
+       //
+       // We must use a lookup on identity rather than a simple map[*Named]bool as
+       // instantiated types may be identical but not equal.
+       var seen instanceLookup
 
        // search current depth
        for len(current) > 0 {
@@ -130,7 +123,7 @@ func lookupFieldOrMethod(T Type, addressable bool, pkg *Package, name string, fo
                        // If we have a named type, we may have associated methods.
                        // Look for those first.
                        if named, _ := typ.(*Named); named != nil {
-                               if seen[named] {
+                               if alt := seen.lookup(named); alt != nil {
                                        // We have seen this type before, at a more shallow depth
                                        // (note that multiples of this type at the current depth
                                        // were consolidated before). The type at that depth shadows
@@ -138,10 +131,7 @@ func lookupFieldOrMethod(T Type, addressable bool, pkg *Package, name string, fo
                                        // this one.
                                        continue
                                }
-                               if seen == nil {
-                                       seen = make(map[*Named]bool)
-                               }
-                               seen[named] = true
+                               seen.add(named)
 
                                // look for a matching attached method
                                named.resolve(nil)
@@ -271,6 +261,27 @@ func lookupType(m map[Type]int, typ Type) (int, bool) {
        return 0, false
 }
 
+type instanceLookup struct {
+       m map[*Named][]*Named
+}
+
+func (l *instanceLookup) lookup(inst *Named) *Named {
+       for _, t := range l.m[inst.Origin()] {
+               if Identical(inst, t) {
+                       return t
+               }
+       }
+       return nil
+}
+
+func (l *instanceLookup) add(inst *Named) {
+       if l.m == nil {
+               l.m = make(map[*Named][]*Named)
+       }
+       insts := l.m[inst.Origin()]
+       l.m[inst.Origin()] = append(insts, inst)
+}
+
 // MissingMethod returns (nil, false) if V implements T, otherwise it
 // returns a missing method required by T and whether it is missing or
 // just has the wrong type.
index c1d1e93e593ea38639d64a8eebe3e890db13e6b9..2bf30286153a5d3d79c1d517fcdd8ad61e74c90f 100644 (file)
@@ -89,14 +89,12 @@ func NewMethodSet(T Type) *MethodSet {
        // Start with typ as single entry at shallowest depth.
        current := []embeddedType{{typ, nil, isPtr, false}}
 
-       // Named types that we have seen already, allocated lazily.
+       // seen tracks named types that we have seen already, allocated lazily.
        // Used to avoid endless searches in case of recursive types.
-       // Since only Named types can be used for recursive types, we
-       // only need to track those.
-       // (If we ever allow type aliases to construct recursive types,
-       // we must use type identity rather than pointer equality for
-       // the map key comparison, as we do in consolidateMultiples.)
-       var seen map[*Named]bool
+       //
+       // We must use a lookup on identity rather than a simple map[*Named]bool as
+       // instantiated types may be identical but not equal.
+       var seen instanceLookup
 
        // collect methods at current depth
        for len(current) > 0 {
@@ -112,7 +110,7 @@ func NewMethodSet(T Type) *MethodSet {
                        // If we have a named type, we may have associated methods.
                        // Look for those first.
                        if named, _ := typ.(*Named); named != nil {
-                               if seen[named] {
+                               if alt := seen.lookup(named); alt != nil {
                                        // We have seen this type before, at a more shallow depth
                                        // (note that multiples of this type at the current depth
                                        // were consolidated before). The type at that depth shadows
@@ -120,10 +118,7 @@ func NewMethodSet(T Type) *MethodSet {
                                        // this one.
                                        continue
                                }
-                               if seen == nil {
-                                       seen = make(map[*Named]bool)
-                               }
-                               seen[named] = true
+                               seen.add(named)
 
                                for i := 0; i < named.NumMethods(); i++ {
                                        mset = mset.addOne(named.Method(i), concat(e.index, i), e.indirect, e.multiples)
index 73a8442f21417de6c3c3f01019551b2ae825fc0d..ee3ad0dbebf15fc557ec79cac87d951bdde84b24 100644 (file)
@@ -7,6 +7,9 @@ package types_test
 import (
        "testing"
 
+       "go/ast"
+       "go/parser"
+       "go/token"
        . "go/types"
 )
 
@@ -26,11 +29,22 @@ func TestNewMethodSet(t *testing.T) {
                "var a T; type T struct{}; func (*T) f() {}":  {},
                "var a *T; type T struct{}; func (*T) f() {}": {{"f", []int{0}, true}},
 
+               // Generic named types
+               "var a T[int]; type T[P any] struct{}; func (T[P]) f() {}":   {{"f", []int{0}, false}},
+               "var a *T[int]; type T[P any] struct{}; func (T[P]) f() {}":  {{"f", []int{0}, true}},
+               "var a T[int]; type T[P any] struct{}; func (*T[P]) f() {}":  {},
+               "var a *T[int]; type T[P any] struct{}; func (*T[P]) f() {}": {{"f", []int{0}, true}},
+
                // Interfaces
                "var a T; type T interface{ f() }":                           {{"f", []int{0}, true}},
                "var a T1; type ( T1 T2; T2 interface{ f() } )":              {{"f", []int{0}, true}},
                "var a T1; type ( T1 interface{ T2 }; T2 interface{ f() } )": {{"f", []int{0}, true}},
 
+               // Genric interfaces
+               "var a T[int]; type T[P any] interface{ f() }":                                     {{"f", []int{0}, true}},
+               "var a T1[int]; type ( T1[P any] T2[P]; T2[P any] interface{ f() } )":              {{"f", []int{0}, true}},
+               "var a T1[int]; type ( T1[P any] interface{ T2[P] }; T2[P any] interface{ f() } )": {{"f", []int{0}, true}},
+
                // Embedding
                "var a struct{ E }; type E interface{ f() }":            {{"f", []int{0, 0}, true}},
                "var a *struct{ E }; type E interface{ f() }":           {{"f", []int{0, 0}, true}},
@@ -39,12 +53,24 @@ func TestNewMethodSet(t *testing.T) {
                "var a struct{ E }; type E struct{}; func (*E) f() {}":  {},
                "var a struct{ *E }; type E struct{}; func (*E) f() {}": {{"f", []int{0, 0}, true}},
 
+               // Embedding of generic types
+               "var a struct{ E[int] }; type E[P any] interface{ f() }":               {{"f", []int{0, 0}, true}},
+               "var a *struct{ E[int] }; type E[P any] interface{ f() }":              {{"f", []int{0, 0}, true}},
+               "var a struct{ E[int] }; type E[P any] struct{}; func (E[P]) f() {}":   {{"f", []int{0, 0}, false}},
+               "var a struct{ *E[int] }; type E[P any] struct{}; func (E[P]) f() {}":  {{"f", []int{0, 0}, true}},
+               "var a struct{ E[int] }; type E[P any] struct{}; func (*E[P]) f() {}":  {},
+               "var a struct{ *E[int] }; type E[P any] struct{}; func (*E[P]) f() {}": {{"f", []int{0, 0}, true}},
+
                // collisions
                "var a struct{ E1; *E2 }; type ( E1 interface{ f() }; E2 struct{ f int })":            {},
                "var a struct{ E1; *E2 }; type ( E1 struct{ f int }; E2 struct{} ); func (E2) f() {}": {},
+
+               // recursive generic types; see golang/go#52715
+               "var a T[int]; type ( T[P any] struct { *N[P] }; N[P any] struct { *T[P] } ); func (N[P]) m() {}": {{"m", []int{0, 0}, true}},
+               "var a T[int]; type ( T[P any] struct { *N[P] }; N[P any] struct { *T[P] } ); func (T[P]) m() {}": {{"m", []int{0}, false}},
        }
 
-       genericTests := map[string][]method{
+       tParamTests := map[string][]method{
                // By convention, look up a in the scope of "g"
                "type C interface{ f() }; func g[T C](a T){}":               {{"f", []int{0}, true}},
                "type C interface{ f() }; func g[T C]() { var a T; _ = a }": {{"f", []int{0}, true}},
@@ -58,12 +84,7 @@ func TestNewMethodSet(t *testing.T) {
        }
 
        check := func(src string, methods []method, generic bool) {
-               pkgName := "p"
-               if generic {
-                       // The generic_ prefix causes pkgFor to allow generic code.
-                       pkgName = "generic_p"
-               }
-               pkg, err := pkgFor("test", "package "+pkgName+";"+src, nil)
+               pkg, err := pkgForMode("test", "package p;"+src, nil, 0)
                if err != nil {
                        t.Errorf("%s: incorrect test case: %s", src, err)
                        return
@@ -103,7 +124,37 @@ func TestNewMethodSet(t *testing.T) {
                check(src, methods, false)
        }
 
-       for src, methods := range genericTests {
+       for src, methods := range tParamTests {
                check(src, methods, true)
        }
 }
+
+// Test for golang/go#52715
+func TestNewMethodSet_RecursiveGeneric(t *testing.T) {
+       const src = `
+package pkg
+
+type Tree[T any] struct {
+       *Node[T]
+}
+
+type Node[T any] struct {
+       *Tree[T]
+}
+
+type Instance = *Tree[int]
+`
+
+       fset := token.NewFileSet()
+       f, err := parser.ParseFile(fset, "foo.go", src, 0)
+       if err != nil {
+               panic(err)
+       }
+       pkg := NewPackage("pkg", f.Name.Name)
+       if err := NewChecker(nil, fset, pkg, nil).Files([]*ast.File{f}); err != nil {
+               panic(err)
+       }
+
+       T := pkg.Scope().Lookup("Instance").Type()
+       _ = NewMethodSet(T) // verify that NewMethodSet terminates
+}