]> Cypherpunks repositories - gostls13.git/commitdiff
[dev.typeparams] cmd/compile: always generate (*T).M wrappers for instantiated methods
authorDan Scales <danscales@google.com>
Sun, 16 May 2021 21:48:05 +0000 (14:48 -0700)
committerDan Scales <danscales@google.com>
Wed, 26 May 2021 21:39:54 +0000 (21:39 +0000)
Always generate (*T).M wrappers for instantiated methods, even when the
instantiated method is being generated for another package (its source
package)

Added new function t.IsInstantiated() to check for fully-instantiated
types (generic type instantiated with concrete types, hence concrete
themselves). This function helps hide the representation of instantiated
types outside of the types package.

Added new export/import test setsimp.go that needs this change.

Change-Id: Ifb700db8c9494e1684c93735edb20f4709be5f7f
Reviewed-on: https://go-review.googlesource.com/c/go/+/322193
Trust: Dan Scales <danscales@google.com>
Trust: Robert Griesemer <gri@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
src/cmd/compile/internal/reflectdata/reflect.go
src/cmd/compile/internal/types/type.go
test/typeparam/setsimp.dir/a.go [new file with mode: 0644]
test/typeparam/setsimp.dir/main.go [new file with mode: 0644]
test/typeparam/setsimp.go [new file with mode: 0644]

index 3576a23db9f34f07c57acec1cfd4af698bdfb936..d452d4f19420cfb146d950e25e9aa680484c8e91 100644 (file)
@@ -956,7 +956,7 @@ func writeType(t *types.Type) *obj.LSym {
                // in the local package, even if they may be marked as part of
                // another package (the package of their base generic type).
                if tbase.Sym() != nil && tbase.Sym().Pkg != types.LocalPkg &&
-                       len(tbase.RParams()) == 0 {
+                       !tbase.IsInstantiated() {
                        if i := typecheck.BaseTypeIndex(t); i >= 0 {
                                lsym.Pkg = tbase.Sym().Pkg.Prefix
                                lsym.SymIdx = int32(i)
@@ -1777,9 +1777,11 @@ func methodWrapper(rcvr *types.Type, method *types.Field) *obj.LSym {
                return lsym
        }
 
-       // Only generate (*T).M wrappers for T.M in T's own package.
+       // Only generate (*T).M wrappers for T.M in T's own package, except for
+       // instantiated methods.
        if rcvr.IsPtr() && rcvr.Elem() == method.Type.Recv().Type &&
-               rcvr.Elem().Sym() != nil && rcvr.Elem().Sym().Pkg != types.LocalPkg {
+               rcvr.Elem().Sym() != nil && rcvr.Elem().Sym().Pkg != types.LocalPkg &&
+               !rcvr.Elem().IsInstantiated() {
                return lsym
        }
 
index e7831121bff1e19fe4b52b3fb802a59b7ef6393e..08855f518ca3c9dbde86fd1aab497f3a5c06cd3b 100644 (file)
@@ -279,6 +279,13 @@ func (t *Type) SetRParams(rparams []*Type) {
        }
 }
 
+// IsInstantiated reports whether t is a fully instantiated generic type; i.e. an
+// instantiated generic type where all type arguments are non-generic or fully
+// instantiated generic types.
+func (t *Type) IsInstantiated() bool {
+       return len(t.RParams()) > 0 && !t.HasTParam()
+}
+
 // NoPkg is a nil *Pkg value for clarity.
 // It's intended for use when constructing types that aren't exported
 // and thus don't need to be associated with any package.
diff --git a/test/typeparam/setsimp.dir/a.go b/test/typeparam/setsimp.dir/a.go
new file mode 100644 (file)
index 0000000..92449ce
--- /dev/null
@@ -0,0 +1,128 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package a
+
+// SliceEqual reports whether two slices are equal: the same length and all
+// elements equal. All floating point NaNs are considered equal.
+func SliceEqual[Elem comparable](s1, s2 []Elem) bool {
+       if len(s1) != len(s2) {
+               return false
+       }
+       for i, v1 := range s1 {
+               v2 := s2[i]
+               if v1 != v2 {
+                       isNaN := func(f Elem) bool { return f != f }
+                       if !isNaN(v1) || !isNaN(v2) {
+                               return false
+                       }
+               }
+       }
+       return true
+}
+
+// A Set is a set of elements of some type.
+type Set[Elem comparable] struct {
+       m map[Elem]struct{}
+}
+
+// Make makes a new set.
+func Make[Elem comparable]() Set[Elem] {
+       return Set[Elem]{m: make(map[Elem]struct{})}
+}
+
+// Add adds an element to a set.
+func (s Set[Elem]) Add(v Elem) {
+       s.m[v] = struct{}{}
+}
+
+// Delete removes an element from a set. If the element is not present
+// in the set, this does nothing.
+func (s Set[Elem]) Delete(v Elem) {
+       delete(s.m, v)
+}
+
+// Contains reports whether v is in the set.
+func (s Set[Elem]) Contains(v Elem) bool {
+       _, ok := s.m[v]
+       return ok
+}
+
+// Len returns the number of elements in the set.
+func (s Set[Elem]) Len() int {
+       return len(s.m)
+}
+
+// Values returns the values in the set.
+// The values will be in an indeterminate order.
+func (s Set[Elem]) Values() []Elem {
+       r := make([]Elem, 0, len(s.m))
+       for v := range s.m {
+               r = append(r, v)
+       }
+       return r
+}
+
+// Equal reports whether two sets contain the same elements.
+func Equal[Elem comparable](s1, s2 Set[Elem]) bool {
+       if len(s1.m) != len(s2.m) {
+               return false
+       }
+       for v1 := range s1.m {
+               if !s2.Contains(v1) {
+                       return false
+               }
+       }
+       return true
+}
+
+// Copy returns a copy of s.
+func (s Set[Elem]) Copy() Set[Elem] {
+       r := Set[Elem]{m: make(map[Elem]struct{}, len(s.m))}
+       for v := range s.m {
+               r.m[v] = struct{}{}
+       }
+       return r
+}
+
+// AddSet adds all the elements of s2 to s.
+func (s Set[Elem]) AddSet(s2 Set[Elem]) {
+       for v := range s2.m {
+               s.m[v] = struct{}{}
+       }
+}
+
+// SubSet removes all elements in s2 from s.
+// Values in s2 that are not in s are ignored.
+func (s Set[Elem]) SubSet(s2 Set[Elem]) {
+       for v := range s2.m {
+               delete(s.m, v)
+       }
+}
+
+// Intersect removes all elements from s that are not present in s2.
+// Values in s2 that are not in s are ignored.
+func (s Set[Elem]) Intersect(s2 Set[Elem]) {
+       for v := range s.m {
+               if !s2.Contains(v) {
+                       delete(s.m, v)
+               }
+       }
+}
+
+// Iterate calls f on every element in the set.
+func (s Set[Elem]) Iterate(f func(Elem)) {
+       for v := range s.m {
+               f(v)
+       }
+}
+
+// Filter deletes any elements from s for which f returns false.
+func (s Set[Elem]) Filter(f func(Elem) bool) {
+       for v := range s.m {
+               if !f(v) {
+                       delete(s.m, v)
+               }
+       }
+}
diff --git a/test/typeparam/setsimp.dir/main.go b/test/typeparam/setsimp.dir/main.go
new file mode 100644 (file)
index 0000000..8fd1657
--- /dev/null
@@ -0,0 +1,156 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+       "a"
+       "fmt"
+       "sort"
+)
+
+func TestSet() {
+       s1 := a.Make[int]()
+       if got := s1.Len(); got != 0 {
+               panic(fmt.Sprintf("Len of empty set = %d, want 0", got))
+       }
+       s1.Add(1)
+       s1.Add(1)
+       s1.Add(1)
+       if got := s1.Len(); got != 1 {
+               panic(fmt.Sprintf("(%v).Len() == %d, want 1", s1, got))
+       }
+       s1.Add(2)
+       s1.Add(3)
+       s1.Add(4)
+       if got := s1.Len(); got != 4 {
+               panic(fmt.Sprintf("(%v).Len() == %d, want 4", s1, got))
+       }
+       if !s1.Contains(1) {
+               panic(fmt.Sprintf("(%v).Contains(1) == false, want true", s1))
+       }
+       if s1.Contains(5) {
+               panic(fmt.Sprintf("(%v).Contains(5) == true, want false", s1))
+       }
+       vals := s1.Values()
+       sort.Ints(vals)
+       w1 := []int{1, 2, 3, 4}
+       if !a.SliceEqual(vals, w1) {
+               panic(fmt.Sprintf("(%v).Values() == %v, want %v", s1, vals, w1))
+       }
+}
+
+func TestEqual() {
+       s1 := a.Make[string]()
+       s2 := a.Make[string]()
+       if !a.Equal(s1, s2) {
+               panic(fmt.Sprintf("a.Equal(%v, %v) = false, want true", s1, s2))
+       }
+       s1.Add("hello")
+       s1.Add("world")
+       if got := s1.Len(); got != 2 {
+               panic(fmt.Sprintf("(%v).Len() == %d, want 2", s1, got))
+       }
+       if a.Equal(s1, s2) {
+               panic(fmt.Sprintf("a.Equal(%v, %v) = true, want false", s1, s2))
+       }
+}
+
+func TestCopy() {
+       s1 := a.Make[float64]()
+       s1.Add(0)
+       s2 := s1.Copy()
+       if !a.Equal(s1, s2) {
+               panic(fmt.Sprintf("a.Equal(%v, %v) = false, want true", s1, s2))
+       }
+       s1.Add(1)
+       if a.Equal(s1, s2) {
+               panic(fmt.Sprintf("a.Equal(%v, %v) = true, want false", s1, s2))
+       }
+}
+
+func TestAddSet() {
+       s1 := a.Make[int]()
+       s1.Add(1)
+       s1.Add(2)
+       s2 := a.Make[int]()
+       s2.Add(2)
+       s2.Add(3)
+       s1.AddSet(s2)
+       if got := s1.Len(); got != 3 {
+               panic(fmt.Sprintf("(%v).Len() == %d, want 3", s1, got))
+       }
+       s2.Add(1)
+       if !a.Equal(s1, s2) {
+               panic(fmt.Sprintf("a.Equal(%v, %v) = false, want true", s1, s2))
+       }
+}
+
+func TestSubSet() {
+       s1 := a.Make[int]()
+       s1.Add(1)
+       s1.Add(2)
+       s2 := a.Make[int]()
+       s2.Add(2)
+       s2.Add(3)
+       s1.SubSet(s2)
+       if got := s1.Len(); got != 1 {
+               panic(fmt.Sprintf("(%v).Len() == %d, want 1", s1, got))
+       }
+       if vals, want := s1.Values(), []int{1}; !a.SliceEqual(vals, want) {
+               panic(fmt.Sprintf("after SubSet got %v, want %v", vals, want))
+       }
+}
+
+func TestIntersect() {
+       s1 := a.Make[int]()
+       s1.Add(1)
+       s1.Add(2)
+       s2 := a.Make[int]()
+       s2.Add(2)
+       s2.Add(3)
+       s1.Intersect(s2)
+       if got := s1.Len(); got != 1 {
+               panic(fmt.Sprintf("(%v).Len() == %d, want 1", s1, got))
+       }
+       if vals, want := s1.Values(), []int{2}; !a.SliceEqual(vals, want) {
+               panic(fmt.Sprintf("after Intersect got %v, want %v", vals, want))
+       }
+}
+
+func TestIterate() {
+       s1 := a.Make[int]()
+       s1.Add(1)
+       s1.Add(2)
+       s1.Add(3)
+       s1.Add(4)
+       tot := 0
+       s1.Iterate(func(i int) { tot += i })
+       if tot != 10 {
+               panic(fmt.Sprintf("total of %v == %d, want 10", s1, tot))
+       }
+}
+
+func TestFilter() {
+       s1 := a.Make[int]()
+       s1.Add(1)
+       s1.Add(2)
+       s1.Add(3)
+       s1.Filter(func(v int) bool { return v%2 == 0 })
+       if vals, want := s1.Values(), []int{2}; !a.SliceEqual(vals, want) {
+               panic(fmt.Sprintf("after Filter got %v, want %v", vals, want))
+       }
+
+}
+
+func main() {
+       TestSet()
+       TestEqual()
+       TestCopy()
+       TestAddSet()
+       TestSubSet()
+       TestIntersect()
+       TestIterate()
+       TestFilter()
+}
diff --git a/test/typeparam/setsimp.go b/test/typeparam/setsimp.go
new file mode 100644 (file)
index 0000000..76930e5
--- /dev/null
@@ -0,0 +1,7 @@
+// rundir -G=3
+
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ignored