]> Cypherpunks repositories - gostls13.git/commitdiff
go/types, types2: preserve parent scope when substituting receivers
authorRobert Findley <rfindley@google.com>
Thu, 24 Mar 2022 17:29:03 +0000 (13:29 -0400)
committerRobert Findley <rfindley@google.com>
Tue, 29 Mar 2022 10:35:56 +0000 (10:35 +0000)
Fixes #51920

Change-Id: I29e44a52dabee5c09e1761f9ec8fb2e8795f8901
Reviewed-on: https://go-review.googlesource.com/c/go/+/395538
Trust: Robert Findley <rfindley@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
Reviewed-by: Robert Griesemer <gri@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>

src/cmd/compile/internal/types2/api_test.go
src/cmd/compile/internal/types2/named.go
src/cmd/compile/internal/types2/subst.go
src/go/types/api_test.go
src/go/types/named.go
src/go/types/subst.go

index d433ed1bdf7de389b9cd3acebc62e6949de114a3..d8f7fb5eda1848c660dddceeb3fd29a776ae8353 100644 (file)
@@ -2305,6 +2305,103 @@ func TestInstanceIdentity(t *testing.T) {
        }
 }
 
+// TestInstantiatedObjects verifies properties of instantiated objects.
+func TestInstantiatedObjects(t *testing.T) {
+       const src = `
+package p
+
+type T[P any] struct {
+       field P
+}
+
+func (recv *T[Q]) concreteMethod() {}
+
+type FT[P any] func(ftp P) (ftrp P)
+
+func F[P any](fp P) (frp P){ return }
+
+type I[P any] interface {
+       interfaceMethod(P)
+}
+
+var (
+       t T[int]
+       ft FT[int]
+       f = F[int]
+       i I[int]
+)
+`
+       info := &Info{
+               Defs: make(map[*syntax.Name]Object),
+       }
+       f, err := parseSrc("p.go", src)
+       if err != nil {
+               t.Fatal(err)
+       }
+       conf := Config{}
+       pkg, err := conf.Check(f.PkgName.Value, []*syntax.File{f}, info)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       lookup := func(name string) Type { return pkg.Scope().Lookup(name).Type() }
+       tests := []struct {
+               ident string
+               obj   Object
+       }{
+               {"field", lookup("t").Underlying().(*Struct).Field(0)},
+               {"concreteMethod", lookup("t").(*Named).Method(0)},
+               {"recv", lookup("t").(*Named).Method(0).Type().(*Signature).Recv()},
+               {"ftp", lookup("ft").Underlying().(*Signature).Params().At(0)},
+               {"ftrp", lookup("ft").Underlying().(*Signature).Results().At(0)},
+               {"fp", lookup("f").(*Signature).Params().At(0)},
+               {"frp", lookup("f").(*Signature).Results().At(0)},
+               {"interfaceMethod", lookup("i").Underlying().(*Interface).Method(0)},
+       }
+
+       // Collect all identifiers by name.
+       idents := make(map[string][]*syntax.Name)
+       syntax.Inspect(f, func(n syntax.Node) bool {
+               if id, ok := n.(*syntax.Name); ok {
+                       idents[id.Value] = append(idents[id.Value], id)
+               }
+               return true
+       })
+
+       for _, test := range tests {
+               test := test
+               t.Run(test.ident, func(t *testing.T) {
+                       if got := len(idents[test.ident]); got != 1 {
+                               t.Fatalf("found %d identifiers named %s, want 1", got, test.ident)
+                       }
+                       ident := idents[test.ident][0]
+                       def := info.Defs[ident]
+                       if def == test.obj {
+                               t.Fatalf("info.Defs[%s] contains the test object", test.ident)
+                       }
+                       if def.Pkg() != test.obj.Pkg() {
+                               t.Errorf("Pkg() = %v, want %v", def.Pkg(), test.obj.Pkg())
+                       }
+                       if def.Name() != test.obj.Name() {
+                               t.Errorf("Name() = %v, want %v", def.Name(), test.obj.Name())
+                       }
+                       if def.Pos() != test.obj.Pos() {
+                               t.Errorf("Pos() = %v, want %v", def.Pos(), test.obj.Pos())
+                       }
+                       if def.Parent() != test.obj.Parent() {
+                               t.Fatalf("Parent() = %v, want %v", def.Parent(), test.obj.Parent())
+                       }
+                       if def.Exported() != test.obj.Exported() {
+                               t.Fatalf("Exported() = %v, want %v", def.Exported(), test.obj.Exported())
+                       }
+                       if def.Id() != test.obj.Id() {
+                               t.Fatalf("Id() = %v, want %v", def.Id(), test.obj.Id())
+                       }
+                       // String and Type are expected to differ.
+               })
+       }
+}
+
 func TestImplements(t *testing.T) {
        const src = `
 package p
index 89d24d4e0b996d6e2ab0c66e1b4b7c7cc94db839..584ee51a136ab121ed03669437e0673ff6a28e3f 100644 (file)
@@ -191,7 +191,7 @@ func (t *Named) instantiateMethod(i int) *Func {
                rtyp = t
        }
 
-       sig.recv = NewParam(origSig.recv.pos, origSig.recv.pkg, origSig.recv.name, rtyp)
+       sig.recv = substVar(origSig.recv, rtyp)
        return NewFunc(origm.pos, origm.pkg, origm.name, sig)
 }
 
index 037f04797b1a56f4bc700bf177be38da2dbd2eeb..c9e8f9676d399cab22686d82c6c787896d9e5b70 100644 (file)
@@ -290,14 +290,18 @@ func (subst *subster) typOrNil(typ Type) Type {
 func (subst *subster) var_(v *Var) *Var {
        if v != nil {
                if typ := subst.typ(v.typ); typ != v.typ {
-                       copy := *v
-                       copy.typ = typ
-                       return &copy
+                       return substVar(v, typ)
                }
        }
        return v
 }
 
+func substVar(v *Var, typ Type) *Var {
+       copy := *v
+       copy.typ = typ
+       return &copy
+}
+
 func (subst *subster) tuple(t *Tuple) *Tuple {
        if t != nil {
                if vars, copied := subst.varList(t.vars); copied {
@@ -410,9 +414,8 @@ func replaceRecvType(in []*Func, old, new Type) (out []*Func, copied bool) {
                                copied = true
                        }
                        newsig := *sig
-                       sig = &newsig
-                       sig.recv = NewVar(sig.recv.pos, sig.recv.pkg, "", new)
-                       out[i] = NewFunc(method.pos, method.pkg, method.name, sig)
+                       newsig.recv = substVar(sig.recv, new)
+                       out[i] = NewFunc(method.pos, method.pkg, method.name, &newsig)
                }
        }
        return
index d25e6116cf4cc062bf755f8c8531e68b64fa9d95..9ed4633b6f00520800730782a815b9ed00c5ab57 100644 (file)
@@ -2305,6 +2305,104 @@ func TestInstanceIdentity(t *testing.T) {
        }
 }
 
+// TestInstantiatedObjects verifies properties of instantiated objects.
+func TestInstantiatedObjects(t *testing.T) {
+       const src = `
+package p
+
+type T[P any] struct {
+       field P
+}
+
+func (recv *T[Q]) concreteMethod() {}
+
+type FT[P any] func(ftp P) (ftrp P)
+
+func F[P any](fp P) (frp P){ return }
+
+type I[P any] interface {
+       interfaceMethod(P)
+}
+
+var (
+       t T[int]
+       ft FT[int]
+       f = F[int]
+       i I[int]
+)
+`
+       info := &Info{
+               Defs: make(map[*ast.Ident]Object),
+       }
+       fset := token.NewFileSet()
+       f, err := parser.ParseFile(fset, "p.go", src, 0)
+       if err != nil {
+               t.Fatal(err)
+       }
+       conf := Config{}
+       pkg, err := conf.Check(f.Name.Name, fset, []*ast.File{f}, info)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       lookup := func(name string) Type { return pkg.Scope().Lookup(name).Type() }
+       tests := []struct {
+               name string
+               obj  Object
+       }{
+               {"field", lookup("t").Underlying().(*Struct).Field(0)},
+               {"concreteMethod", lookup("t").(*Named).Method(0)},
+               {"recv", lookup("t").(*Named).Method(0).Type().(*Signature).Recv()},
+               {"ftp", lookup("ft").Underlying().(*Signature).Params().At(0)},
+               {"ftrp", lookup("ft").Underlying().(*Signature).Results().At(0)},
+               {"fp", lookup("f").(*Signature).Params().At(0)},
+               {"frp", lookup("f").(*Signature).Results().At(0)},
+               {"interfaceMethod", lookup("i").Underlying().(*Interface).Method(0)},
+       }
+
+       // Collect all identifiers by name.
+       idents := make(map[string][]*ast.Ident)
+       ast.Inspect(f, func(n ast.Node) bool {
+               if id, ok := n.(*ast.Ident); ok {
+                       idents[id.Name] = append(idents[id.Name], id)
+               }
+               return true
+       })
+
+       for _, test := range tests {
+               test := test
+               t.Run(test.name, func(t *testing.T) {
+                       if got := len(idents[test.name]); got != 1 {
+                               t.Fatalf("found %d identifiers named %s, want 1", got, test.name)
+                       }
+                       ident := idents[test.name][0]
+                       def := info.Defs[ident]
+                       if def == test.obj {
+                               t.Fatalf("info.Defs[%s] contains the test object", test.name)
+                       }
+                       if def.Pkg() != test.obj.Pkg() {
+                               t.Errorf("Pkg() = %v, want %v", def.Pkg(), test.obj.Pkg())
+                       }
+                       if def.Name() != test.obj.Name() {
+                               t.Errorf("Name() = %v, want %v", def.Name(), test.obj.Name())
+                       }
+                       if def.Pos() != test.obj.Pos() {
+                               t.Errorf("Pos() = %v, want %v", def.Pos(), test.obj.Pos())
+                       }
+                       if def.Parent() != test.obj.Parent() {
+                               t.Fatalf("Parent() = %v, want %v", def.Parent(), test.obj.Parent())
+                       }
+                       if def.Exported() != test.obj.Exported() {
+                               t.Fatalf("Exported() = %v, want %v", def.Exported(), test.obj.Exported())
+                       }
+                       if def.Id() != test.obj.Id() {
+                               t.Fatalf("Id() = %v, want %v", def.Id(), test.obj.Id())
+                       }
+                       // String and Type are expected to differ.
+               })
+       }
+}
+
 func TestImplements(t *testing.T) {
        const src = `
 package p
index a0b94818f5a1057053126dfce75a8253a39afe3d..e4fd96ab646f391672daac3f2a128c62cbf2a7a8 100644 (file)
@@ -193,7 +193,7 @@ func (t *Named) instantiateMethod(i int) *Func {
                rtyp = t
        }
 
-       sig.recv = NewParam(origSig.recv.pos, origSig.recv.pkg, origSig.recv.name, rtyp)
+       sig.recv = substVar(origSig.recv, rtyp)
        return NewFunc(origm.pos, origm.pkg, origm.name, sig)
 }
 
index 4b4a0f4ad670945acea046f3e8f2de1790b68891..ef5593b71f40be94da51fe5308795e6fac1690bf 100644 (file)
@@ -290,14 +290,18 @@ func (subst *subster) typOrNil(typ Type) Type {
 func (subst *subster) var_(v *Var) *Var {
        if v != nil {
                if typ := subst.typ(v.typ); typ != v.typ {
-                       copy := *v
-                       copy.typ = typ
-                       return &copy
+                       return substVar(v, typ)
                }
        }
        return v
 }
 
+func substVar(v *Var, typ Type) *Var {
+       copy := *v
+       copy.typ = typ
+       return &copy
+}
+
 func (subst *subster) tuple(t *Tuple) *Tuple {
        if t != nil {
                if vars, copied := subst.varList(t.vars); copied {
@@ -410,9 +414,8 @@ func replaceRecvType(in []*Func, old, new Type) (out []*Func, copied bool) {
                                copied = true
                        }
                        newsig := *sig
-                       sig = &newsig
-                       sig.recv = NewVar(sig.recv.pos, sig.recv.pkg, "", new)
-                       out[i] = NewFunc(method.pos, method.pkg, method.name, sig)
+                       newsig.recv = substVar(sig.recv, new)
+                       out[i] = NewFunc(method.pos, method.pkg, method.name, &newsig)
                }
        }
        return