]> Cypherpunks repositories - gostls13.git/commitdiff
go/types: add check that code is monomorphizable
authorMatthew Dempsky <mdempsky@google.com>
Wed, 20 Oct 2021 20:17:07 +0000 (13:17 -0700)
committerMatthew Dempsky <mdempsky@google.com>
Tue, 2 Nov 2021 17:59:53 +0000 (17:59 +0000)
This CL adds a check to ensure that generic Go code doesn't involve
any unbounded recursive instantiation, which are incompatible with an
implementation that uses static instantiation (i.e., monomorphization
or compile-time dictionary construction).

Updates #48098.

Change-Id: I9d051f0f9369ab881592a361a5d0e2a716788a6b
Reviewed-on: https://go-review.googlesource.com/c/go/+/357449
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
Trust: Matthew Dempsky <mdempsky@google.com>

src/go/types/call.go
src/go/types/check.go
src/go/types/errorcodes.go
src/go/types/mono.go [new file with mode: 0644]
src/go/types/mono_test.go [new file with mode: 0644]
src/go/types/signature.go
src/go/types/testdata/fixedbugs/issue48974.go2
src/go/types/typexpr.go

index 6894f1c1821e0184ccbf3888e3848decd9f6380e..36086891b5f60dd075e5d03049faf9de6acbc381 100644 (file)
@@ -91,6 +91,8 @@ func (check *Checker) instantiateSignature(pos token.Pos, typ *Signature, targs
                        pos = posList[i]
                }
                check.softErrorf(atPos(pos), _Todo, err.Error())
+       } else {
+               check.mono.recordInstance(check.pkg, pos, tparams, targs, posList)
        }
 
        return inst
index 2b8ef9f061321371c1a6f5a712e7a96a67777d31..3a0e4a6a234ee564e9ff1af1b0fa4686391836d0 100644 (file)
@@ -129,6 +129,7 @@ type Checker struct {
        imports       []*PkgName                // list of imported packages
        dotImportMap  map[dotImportKey]*PkgName // maps dot-imported objects to the package they were dot-imported through
        recvTParamMap map[*ast.Ident]*TypeParam // maps blank receiver type parameters to their type
+       mono          monoGraph                 // graph for detecting non-monomorphizable instantiation loops
 
        firstErr error                 // first error encountered
        methods  map[*TypeName][]*Func // maps package scope type names to associated non-blank (non-interface) methods
@@ -306,6 +307,11 @@ func (check *Checker) checkFiles(files []*ast.File) (err error) {
 
        check.recordUntyped()
 
+       if check.firstErr == nil {
+               // TODO(mdempsky): Ensure monomorph is safe when errors exist.
+               check.monomorph()
+       }
+
        check.pkg.complete = true
 
        // no longer needed - release memory
index 49c6a74c20d5487d7e398a4a6b278f0573a2428f..88dd0fda2f8aefbd03f42729592aeac3bdb1204c 100644 (file)
@@ -1301,6 +1301,13 @@ const (
        //  var _ = unsafe.Slice(&x, uint64(1) << 63)
        _InvalidUnsafeSlice
 
+       // _InvalidInstanceCycle occurs when an invalid cycle is detected
+       // within the instantiation graph.
+       //
+       // Example:
+       //  func f[T any]() { f[*T]() }
+       _InvalidInstanceCycle
+
        // _Todo is a placeholder for error codes that have not been decided.
        // TODO(rFindley) remove this error code after deciding on errors for generics code.
        _Todo
diff --git a/src/go/types/mono.go b/src/go/types/mono.go
new file mode 100644 (file)
index 0000000..fb1127e
--- /dev/null
@@ -0,0 +1,331 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package types
+
+import (
+       "go/token"
+)
+
+// This file implements a check to validate that a Go package doesn't
+// have unbounded recursive instantiation, which is not compatible
+// with compilers using static instantiation (such as
+// monomorphization).
+//
+// It implements a sort of "type flow" analysis by detecting which
+// type parameters are instantiated with other type parameters (or
+// types derived thereof). A package cannot be statically instantiated
+// if the graph has any cycles involving at least one derived type.
+//
+// Concretely, we construct a directed, weighted graph. Vertices are
+// used to represent type parameters as well as some defined
+// types. Edges are used to represent how types depend on each other:
+//
+// * Everywhere a type-parameterized function or type is instantiated,
+//   we add edges to each type parameter from the vertices (if any)
+//   representing each type parameter or defined type referenced by
+//   the type argument. If the type argument is just the referenced
+//   type itself, then the edge has weight 0, otherwise 1.
+//
+// * For every defined type declared within a type-parameterized
+//   function or method, we add an edge of weight 1 to the defined
+//   type from each ambient type parameter.
+//
+// For example, given:
+//
+//     func f[A, B any]() {
+//             type T int
+//             f[T, map[A]B]()
+//     }
+//
+// we construct vertices representing types A, B, and T. Because of
+// declaration "type T int", we construct edges T<-A and T<-B with
+// weight 1; and because of instantiation "f[T, map[A]B]" we construct
+// edges A<-T with weight 0, and B<-A and B<-B with weight 1.
+//
+// Finally, we look for any positive-weight cycles. Zero-weight cycles
+// are allowed because static instantiation will reach a fixed point.
+
+type monoGraph struct {
+       vertices []monoVertex
+       edges    []monoEdge
+
+       // canon maps method receiver type parameters to their respective
+       // receiver type's type parameters.
+       canon map[*TypeParam]*TypeParam
+
+       // nameIdx maps a defined type or (canonical) type parameter to its
+       // vertex index.
+       nameIdx map[*TypeName]int
+}
+
+type monoVertex struct {
+       weight int // weight of heaviest known path to this vertex
+       pre    int // previous edge (if any) in the above path
+       len    int // length of the above path
+
+       // obj is the defined type or type parameter represented by this
+       // vertex.
+       obj *TypeName
+}
+
+type monoEdge struct {
+       dst    int
+       src    int
+       weight int
+
+       // report emits an error describing why this edge exists.
+       //
+       // TODO(mdempsky): Avoid requiring a function closure for each edge.
+       report func(check *Checker)
+}
+
+func (check *Checker) monomorph() {
+       // We detect unbounded instantiation cycles using a variant of
+       // Bellman-Ford's algorithm. Namely, instead of always running |V|
+       // iterations, we run until we either reach a fixed point or we've
+       // found a path of length |V|. This allows us to terminate earlier
+       // when there are no cycles, which should be the common case.
+
+       again := true
+       for again {
+               again = false
+
+               for i, edge := range check.mono.edges {
+                       src := &check.mono.vertices[edge.src]
+                       dst := &check.mono.vertices[edge.dst]
+
+                       // N.B., we're looking for the greatest weight paths, unlike
+                       // typical Bellman-Ford.
+                       w := src.weight + edge.weight
+                       if w <= dst.weight {
+                               continue
+                       }
+
+                       dst.pre = i
+                       dst.len = src.len + 1
+                       if dst.len == len(check.mono.vertices) {
+                               check.reportInstanceLoop(edge.dst)
+                               return
+                       }
+
+                       dst.weight = w
+                       again = true
+               }
+       }
+}
+
+func (check *Checker) reportInstanceLoop(v int) {
+       var stack []int
+       seen := make([]bool, len(check.mono.vertices))
+
+       // We have a path that contains a cycle and ends at v, but v may
+       // only be reachable from the cycle, not on the cycle itself. We
+       // start by walking backwards along the path until we find a vertex
+       // that appears twice.
+       for !seen[v] {
+               stack = append(stack, v)
+               seen[v] = true
+               v = check.mono.edges[check.mono.vertices[v].pre].src
+       }
+
+       // Trim any vertices we visited before visiting v the first
+       // time. Since v is the first vertex we found within the cycle, any
+       // vertices we visited earlier cannot be part of the cycle.
+       for stack[0] != v {
+               stack = stack[1:]
+       }
+
+       // TODO(mdempsky): Pivot stack so we report the cycle from the top?
+
+       obj := check.mono.vertices[v].obj
+       check.errorf(obj, _InvalidInstanceCycle, "instantiation cycle:")
+
+       for _, v := range stack {
+               edge := check.mono.edges[check.mono.vertices[v].pre]
+               edge.report(check)
+       }
+}
+
+// recordCanon records that tpar is the canonical type parameter
+// corresponding to method type parameter mpar.
+func (w *monoGraph) recordCanon(mpar, tpar *TypeParam) {
+       if w.canon == nil {
+               w.canon = make(map[*TypeParam]*TypeParam)
+       }
+       w.canon[mpar] = tpar
+}
+
+// recordInstance records that the given type parameters were
+// instantiated with the corresponding type arguments.
+func (w *monoGraph) recordInstance(pkg *Package, pos token.Pos, tparams []*TypeParam, targs []Type, posList []token.Pos) {
+       for i, tpar := range tparams {
+               pos := pos
+               if i < len(posList) {
+                       pos = posList[i]
+               }
+               w.assign(pkg, pos, tpar, targs[i])
+       }
+}
+
+// assign records that tpar was instantiated as targ at pos.
+func (w *monoGraph) assign(pkg *Package, pos token.Pos, tpar *TypeParam, targ Type) {
+       // Go generics do not have an analog to C++`s template-templates,
+       // where a template parameter can itself be an instantiable
+       // template. So any instantiation cycles must occur within a single
+       // package. Accordingly, we can ignore instantiations of imported
+       // type parameters.
+       //
+       // TODO(mdempsky): Push this check up into recordInstance? All type
+       // parameters in a list will appear in the same package.
+       if tpar.Obj().Pkg() != pkg {
+               return
+       }
+
+       // flow adds an edge from vertex src representing that typ flows to tpar.
+       flow := func(src int, typ Type) {
+               weight := 1
+               if typ == targ {
+                       weight = 0
+               }
+
+               w.addEdge(w.typeParamVertex(tpar), src, weight, func(check *Checker) {
+                       qf := RelativeTo(check.pkg)
+                       check.errorf(atPos(pos), _InvalidInstanceCycle, "\t%s instantiated as %s", tpar.Obj().Name(), TypeString(targ, qf)) // secondary error, \t indented
+               })
+       }
+
+       // Recursively walk the type argument to find any defined types or
+       // type parameters.
+       var do func(typ Type)
+       do = func(typ Type) {
+               switch typ := typ.(type) {
+               default:
+                       panic("unexpected type")
+
+               case *TypeParam:
+                       assert(typ.Obj().Pkg() == pkg)
+                       flow(w.typeParamVertex(typ), typ)
+
+               case *Named:
+                       if src := w.localNamedVertex(pkg, typ.Origin()); src >= 0 {
+                               flow(src, typ)
+                       }
+
+                       targs := typ.TypeArgs()
+                       for i := 0; i < targs.Len(); i++ {
+                               do(targs.At(i))
+                       }
+
+               case *Array:
+                       do(typ.Elem())
+               case *Basic:
+                       // ok
+               case *Chan:
+                       do(typ.Elem())
+               case *Map:
+                       do(typ.Key())
+                       do(typ.Elem())
+               case *Pointer:
+                       do(typ.Elem())
+               case *Slice:
+                       do(typ.Elem())
+
+               case *Interface:
+                       for i := 0; i < typ.NumMethods(); i++ {
+                               do(typ.Method(i).Type())
+                       }
+               case *Signature:
+                       tuple := func(tup *Tuple) {
+                               for i := 0; i < tup.Len(); i++ {
+                                       do(tup.At(i).Type())
+                               }
+                       }
+                       tuple(typ.Params())
+                       tuple(typ.Results())
+               case *Struct:
+                       for i := 0; i < typ.NumFields(); i++ {
+                               do(typ.Field(i).Type())
+                       }
+               }
+       }
+       do(targ)
+}
+
+// localNamedVertex returns the index of the vertex representing
+// named, or -1 if named doesn't need representation.
+func (w *monoGraph) localNamedVertex(pkg *Package, named *Named) int {
+       obj := named.Obj()
+       if obj.Pkg() != pkg {
+               return -1 // imported type
+       }
+
+       root := pkg.Scope()
+       if obj.Parent() == root {
+               return -1 // package scope, no ambient type parameters
+       }
+
+       if idx, ok := w.nameIdx[obj]; ok {
+               return idx
+       }
+
+       idx := -1
+
+       // Walk the type definition's scope to find any ambient type
+       // parameters that it's implicitly parameterized by.
+       for scope := obj.Parent(); scope != root; scope = scope.Parent() {
+               for _, elem := range scope.elems {
+                       if elem, ok := elem.(*TypeName); ok && !elem.IsAlias() && elem.Pos() < obj.Pos() {
+                               if tpar, ok := elem.Type().(*TypeParam); ok {
+                                       if idx < 0 {
+                                               idx = len(w.vertices)
+                                               w.vertices = append(w.vertices, monoVertex{obj: obj})
+                                       }
+
+                                       w.addEdge(idx, w.typeParamVertex(tpar), 1, func(check *Checker) {
+                                               check.errorf(obj, _InvalidInstanceCycle, "\t%s implicitly parameterized by %s", obj.Name(), elem.Name())
+                                       })
+                               }
+                       }
+               }
+       }
+
+       if w.nameIdx == nil {
+               w.nameIdx = make(map[*TypeName]int)
+       }
+       w.nameIdx[obj] = idx
+       return idx
+}
+
+// typeParamVertex returns the index of the vertex representing tpar.
+func (w *monoGraph) typeParamVertex(tpar *TypeParam) int {
+       if x, ok := w.canon[tpar]; ok {
+               tpar = x
+       }
+
+       obj := tpar.Obj()
+
+       if idx, ok := w.nameIdx[obj]; ok {
+               return idx
+       }
+
+       if w.nameIdx == nil {
+               w.nameIdx = make(map[*TypeName]int)
+       }
+
+       idx := len(w.vertices)
+       w.vertices = append(w.vertices, monoVertex{obj: obj})
+       w.nameIdx[obj] = idx
+       return idx
+}
+
+func (w *monoGraph) addEdge(dst, src, weight int, report func(check *Checker)) {
+       // TODO(mdempsky): Deduplicate redundant edges?
+       w.edges = append(w.edges, monoEdge{
+               dst:    dst,
+               src:    src,
+               weight: weight,
+               report: report,
+       })
+}
diff --git a/src/go/types/mono_test.go b/src/go/types/mono_test.go
new file mode 100644 (file)
index 0000000..c4c5282
--- /dev/null
@@ -0,0 +1,92 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package types_test
+
+import (
+       "bytes"
+       "errors"
+       "fmt"
+       "go/ast"
+       "go/importer"
+       "go/parser"
+       "go/token"
+       "go/types"
+       "strings"
+       "testing"
+)
+
+func checkMono(t *testing.T, body string) error {
+       fset := token.NewFileSet()
+       file, err := parser.ParseFile(fset, "x.go", "package x; import `unsafe`; var _ unsafe.Pointer;\n"+body, 0)
+       if err != nil {
+               t.Fatal(err)
+       }
+       files := []*ast.File{file}
+
+       var buf bytes.Buffer
+       conf := types.Config{
+               Error:    func(err error) { fmt.Fprintln(&buf, err) },
+               Importer: importer.Default(),
+       }
+       conf.Check("x", fset, files, nil)
+       if buf.Len() == 0 {
+               return nil
+       }
+       return errors.New(strings.TrimRight(buf.String(), "\n"))
+}
+
+func TestMonoGood(t *testing.T) {
+       for i, good := range goods {
+               if err := checkMono(t, good); err != nil {
+                       t.Errorf("%d: unexpected failure: %v", i, err)
+               }
+       }
+}
+
+func TestMonoBad(t *testing.T) {
+       for i, bad := range bads {
+               if err := checkMono(t, bad); err == nil {
+                       t.Errorf("%d: unexpected success", i)
+               } else {
+                       t.Log(err)
+               }
+       }
+}
+
+var goods = []string{
+       "func F[T any](x T) { F(x) }",
+       "func F[T, U, V any]() { F[U, V, T](); F[V, T, U]() }",
+       "type Ring[A, B, C any] struct { L *Ring[B, C, A]; R *Ring[C, A, B] }",
+       "func F[T any]() { type U[T any] [unsafe.Sizeof(F[*T])]byte }",
+       "func F[T any]() { type U[T any] [unsafe.Sizeof(F[*T])]byte; var _ U[int] }",
+       "type U[T any] [unsafe.Sizeof(F[*T])]byte; func F[T any]() { var _ U[U[int]] }",
+       "func F[T any]() { type A = int; F[A]() }",
+}
+
+// TODO(mdempsky): Validate specific error messages and positioning.
+
+var bads = []string{
+       "func F[T any](x T) { F(&x) }",
+       "func F[T any]() { F[*T]() }",
+       "func F[T any]() { F[[]T]() }",
+       "func F[T any]() { F[[1]T]() }",
+       "func F[T any]() { F[chan T]() }",
+       "func F[T any]() { F[map[*T]int]() }",
+       "func F[T any]() { F[map[error]T]() }",
+       "func F[T any]() { F[func(T)]() }",
+       "func F[T any]() { F[func() T]() }",
+       "func F[T any]() { F[struct{ t T }]() }",
+       "func F[T any]() { F[interface{ t() T }]() }",
+       "type U[_ any] int; func F[T any]() { F[U[T]]() }",
+       "func F[T any]() { type U int; F[U]() }",
+       "func F[T any]() { type U int; F[*U]() }",
+       "type U[T any] int; func (U[T]) m() { var _ U[*T] }",
+       "type U[T any] int; func (*U[T]) m() { var _ U[*T] }",
+       "type U[T any] [unsafe.Sizeof(F[*T])]byte; func F[T any]() { var _ U[T] }",
+       "func F[A, B, C, D, E any]() { F[B, C, D, E, *A]() }",
+       "type U[_ any] int; const X = unsafe.Sizeof(func() { type A[T any] U[A[*T]] })",
+       "func F[T any]() { type A = *T; F[A]() }",
+       "type A[T any] struct { _ A[*T] }",
+}
index c83bf090320ca1ed735d70a5e1525a73e5283bbe..ad69c95d1288fd88ab104fc369bd9de827da69c8 100644 (file)
@@ -148,6 +148,7 @@ func (check *Checker) funcType(sig *Signature, recvPar *ast.FieldList, ftyp *ast
                                list := make([]Type, sig.RecvTypeParams().Len())
                                for i, t := range sig.RecvTypeParams().list() {
                                        list[i] = t
+                                       check.mono.recordCanon(t, recvTParams[i])
                                }
                                smap := makeSubstMap(recvTParams, list)
                                for i, tpar := range sig.RecvTypeParams().list() {
index ca4b6d932163b45eada0e2eb291961dfb56cfce8..d8ff7c8cf46d1b957f72bd8c1a856f693a30c9c3 100644 (file)
@@ -8,7 +8,7 @@ type Fooer interface {
        Foo()
 }
 
-type Fooable[F Fooer] struct {
+type Fooable[F /* ERROR instantiation cycle */ Fooer] struct {
        ptr F
 }
 
index 092e355b38794eaf71b87e1a5495a6cae867c55b..3636c8556a1f113c17fae381a69dd6dee77cdf62 100644 (file)
@@ -457,6 +457,8 @@ func (check *Checker) instantiatedType(x ast.Expr, targsx []ast.Expr, def *Named
                                        pos = posList[i]
                                }
                                check.softErrorf(atPos(pos), _Todo, err.Error())
+                       } else {
+                               check.mono.recordInstance(check.pkg, x.Pos(), inst.tparams.list(), inst.targs.list(), posList)
                        }
                }