]> Cypherpunks repositories - gostls13.git/commitdiff
go/types: combine all type inference in a single function
authorRob Findley <rfindley@google.com>
Wed, 21 Apr 2021 02:59:59 +0000 (22:59 -0400)
committerRobert Findley <rfindley@google.com>
Wed, 21 Apr 2021 23:21:55 +0000 (23:21 +0000)
This is a port of CL 306170 to go/types, adjusted for the different
positioning API.

Some of the error positions in tests had to be adjusted, but I think the
new locations are better.

Change-Id: Ib157fbb47d7483e3c6302bd57f5070bd74602a36
Reviewed-on: https://go-review.googlesource.com/c/go/+/312191
Trust: Robert Findley <rfindley@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
src/cmd/compile/internal/types2/fixedbugs/issue40056.go2
src/go/types/call.go
src/go/types/fixedbugs/issue39634.go2
src/go/types/fixedbugs/issue40056.go2
src/go/types/infer.go
src/go/types/testdata/typeparams.go2

index 0c78c3f289a0d78c2b023c1a630fde8dc85b4b7e..98ded7c49a186f8adc2d9851f6e6169db3229090 100644 (file)
@@ -13,4 +13,4 @@ type S struct {}
 
 func NewS[T any]() *S
 
-func (_ *S /* ERROR S is not a generic type */ [T]) M()
\ No newline at end of file
+func (_ *S /* ERROR S is not a generic type */ [T]) M()
index 8fd0f2dd2b1236a90032da07b9b29fba8a8f4dac..4834bd02c12a94c2e2722e5d089e681884acdd62 100644 (file)
@@ -26,54 +26,39 @@ func (check *Checker) funcInst(x *operand, inst *ast.IndexExpr) {
        }
        assert(len(targs) == len(xlist))
 
-       // check number of type arguments
-       n := len(targs)
+       // check number of type arguments (got) vs number of type parameters (want)
        sig := x.typ.(*Signature)
-       if n > len(sig.tparams) {
-               check.errorf(xlist[n-1], _Todo, "got %d type arguments but want %d", n, len(sig.tparams))
+       got, want := len(targs), len(sig.tparams)
+       if got > want {
+               check.errorf(xlist[got-1], _Todo, "got %d type arguments but want %d", got, want)
                x.mode = invalid
                x.expr = inst
                return
        }
 
-       // determine argument positions (for error reporting)
-       // TODO(rFindley) use a positioner here? instantiate would need to be
-       //                updated accordingly.
-       poslist := make([]token.Pos, n)
-       for i, x := range xlist {
-               poslist[i] = x.Pos()
-       }
+       // if we don't have enough type arguments, try type inference
+       inferred := false
 
-       // if we don't have enough type arguments, use constraint type inference
-       var inferred bool
-       if n < len(sig.tparams) {
-               var failed int
-               targs, failed = check.inferB(sig.tparams, targs)
+       if got < want {
+               targs = check.infer(inst, sig.tparams, targs, nil, nil, true)
                if targs == nil {
                        // error was already reported
                        x.mode = invalid
                        x.expr = inst
                        return
                }
-               if failed >= 0 {
-                       // at least one type argument couldn't be inferred
-                       assert(targs[failed] == nil)
-                       tpar := sig.tparams[failed]
-                       check.errorf(inNode(inst, inst.Rbrack), 0, "cannot infer %s (%v) (%s)", tpar.name, tpar.pos, targs)
-                       x.mode = invalid
-                       x.expr = inst
-                       return
-               }
-               // all type arguments were inferred successfully
-               if debug {
-                       for _, targ := range targs {
-                               assert(targ != nil)
-                       }
-               }
-               n = len(targs)
+               got = len(targs)
                inferred = true
        }
-       assert(n == len(sig.tparams))
+       assert(got == want)
+
+       // determine argument positions (for error reporting)
+       // TODO(rFindley) use a positioner here? instantiate would need to be
+       //                updated accordingly.
+       poslist := make([]token.Pos, len(xlist))
+       for i, x := range xlist {
+               poslist[i] = x.Pos()
+       }
 
        // instantiate function signature
        res := check.instantiate(x.Pos(), sig, targs, poslist).(*Signature)
@@ -307,32 +292,10 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, args []*oper
        if len(sig.tparams) > 0 {
                // TODO(gri) provide position information for targs so we can feed
                //           it to the instantiate call for better error reporting
-               targs, failed := check.infer(sig.tparams, sigParams, args)
+               targs := check.infer(call, sig.tparams, nil, sigParams, args, true)
                if targs == nil {
                        return // error already reported
                }
-               if failed >= 0 {
-                       // Some type arguments couldn't be inferred. Use
-                       // bounds type inference to try to make progress.
-                       targs, failed = check.inferB(sig.tparams, targs)
-                       if targs == nil {
-                               return // error already reported
-                       }
-                       if failed >= 0 {
-                               // at least one type argument couldn't be inferred
-                               assert(targs[failed] == nil)
-                               tpar := sig.tparams[failed]
-                               ppos := check.fset.Position(tpar.pos).String()
-                               check.errorf(inNode(call, call.Rparen), _Todo, "cannot infer %s (%s) (%s)", tpar.name, ppos, targs)
-                               return
-                       }
-               }
-               // all type arguments were inferred successfully
-               if debug {
-                       for _, targ := range targs {
-                               assert(targ != nil)
-                       }
-               }
 
                // compute result signature
                rsig = check.instantiate(call.Pos(), sig, targs, nil).(*Signature)
@@ -544,13 +507,13 @@ func (check *Checker) selector(x *operand, e *ast.SelectorExpr) {
                                        recv = recv.(*Pointer).base
                                }
                        }
+                       // Disable reporting of errors during inference below. If we're unable to infer
+                       // the receiver type arguments here, the receiver must be be otherwise invalid
+                       // and an error has been reported elsewhere.
                        arg := operand{mode: variable, expr: x.expr, typ: recv}
-                       targs, failed := check.infer(sig.rparams, NewTuple(sig.recv), []*operand{&arg})
-                       if failed >= 0 {
+                       targs := check.infer(m, sig.rparams, nil, NewTuple(sig.recv), []*operand{&arg}, false /* no error reporting */)
+                       if targs == nil {
                                // We may reach here if there were other errors (see issue #40056).
-                               // check.infer will report a follow-up error.
-                               // TODO(gri) avoid the follow-up error as it is confusing
-                               //           (there's no inference in the source code)
                                goto Error
                        }
                        // Don't modify m. Instead - for now - make a copy of m and use that instead.
index f8585755c99b7ca5d176d97f86448a1a5a8bcdf1..af1f1e44c5a8a525c27fdedcd8967642e3c4cfb8 100644 (file)
@@ -89,4 +89,4 @@ func F26[Z any]() T26 { return F26[] /* ERROR operand */ }
 
 // crash 27
 func e27[T any]() interface{ x27 /* ERROR not a type */ }
-func x27() { e27() /* ERROR cannot infer T */ }
+func x27() { e27 /* ERROR cannot infer T */ () }
index 71074be67e20d1324e6dd80cb7da986192990f79..f587691e3d746a58bc22a2bfd6d9382230859d3c 100644 (file)
@@ -5,11 +5,11 @@
 package p
 
 func _() {
-       NewS() /* ERROR cannot infer T */ .M()
+       NewS /* ERROR cannot infer T */ ().M()
 }
 
 type S struct {}
 
 func NewS[T any]() *S
 
-func (_ *S /* ERROR S is not a generic type */ [T]) M()
\ No newline at end of file
+func (_ *S /* ERROR S is not a generic type */ [T]) M()
index 4b20836f885267f67e73f677fb4aa85194fef62b..9aae1bb2484f10ce1425105ce5358b6173a0566a 100644 (file)
@@ -12,22 +12,104 @@ import (
        "strings"
 )
 
-// infer returns the list of actual type arguments for the given list of type parameters tparams
-// by inferring them from the actual arguments args for the parameters params. If type inference
-// is impossible because unification fails, an error is reported and the resulting types list is
-// nil, and index is 0. Otherwise, types is the list of inferred type arguments, and index is
-// the index of the first type argument in that list that couldn't be inferred (and thus is nil).
-// If all type arguments were inferred successfully, index is < 0.
-func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand) (types []Type, index int) {
+// infer attempts to infer the complete set of type arguments for generic function instantiation/call
+// based on the given type parameters tparams, type arguments targs, function parameters params, and
+// function arguments args, if any. There must be at least one type parameter, no more type arguments
+// than type parameters, and params and args must match in number (incl. zero).
+// If successful, infer returns the complete list of type arguments, one for each type parameter.
+// Otherwise the result is nil and appropriate errors will be reported unless report is set to false.
+//
+// Inference proceeds in 3 steps:
+//
+//   1) Start with given type arguments.
+//   2) Infer type arguments from typed function arguments.
+//   3) Infer type arguments from untyped function arguments.
+//
+// Constraint type inference is used after each step to expand the set of type arguments.
+//
+func (check *Checker) infer(posn positioner, tparams []*TypeName, targs []Type, params *Tuple, args []*operand, report bool) (result []Type) {
+       if debug {
+               defer func() {
+                       assert(result == nil || len(result) == len(tparams))
+                       for _, targ := range result {
+                               assert(targ != nil)
+                       }
+                       //check.dump("### inferred targs = %s", result)
+               }()
+       }
+
+       // There must be at least one type parameter, and no more type arguments than type parameters.
+       n := len(tparams)
+       assert(n > 0 && len(targs) <= n)
+
+       // Function parameters and arguments must match in number.
        assert(params.Len() == len(args))
 
+       // --- 0 ---
+       // If we already have all type arguments, we're done.
+       if len(targs) == n {
+               return targs
+       }
+       // len(targs) < n
+
+       // --- 1 ---
+       // Explicitly provided type arguments take precedence over any inferred types;
+       // and types inferred via constraint type inference take precedence over types
+       // inferred from function arguments.
+       // If we have type arguments, see how far we get with constraint type inference.
+       if len(targs) > 0 {
+               var index int
+               targs, index = check.inferB(tparams, targs, report)
+               if targs == nil || index < 0 {
+                       return targs
+               }
+       }
+
+       // Continue with the type arguments we have now. Avoid matching generic
+       // parameters that already have type arguments against function arguments:
+       // It may fail because matching uses type identity while parameter passing
+       // uses assignment rules. Instantiate the parameter list with the type
+       // arguments we have, and continue with that parameter list.
+
+       // First, make sure we have a "full" list of type arguments, so of which
+       // may be nil (unknown).
+       if len(targs) < n {
+               targs2 := make([]Type, n)
+               copy(targs2, targs)
+               targs = targs2
+       }
+       // len(targs) == n
+
+       // Substitute type arguments for their respective type parameters in params,
+       // if any. Note that nil targs entries are ignored by check.subst.
+       // TODO(gri) Can we avoid this (we're setting known type argumemts below,
+       //           but that doesn't impact the isParameterized check for now).
+       if params.Len() > 0 {
+               smap := makeSubstMap(tparams, targs)
+               params = check.subst(token.NoPos, params, smap).(*Tuple)
+       }
+
+       // --- 2 ---
+       // Unify parameter and argument types for generic parameters with typed arguments
+       // and collect the indices of generic parameters with untyped arguments.
+       // Terminology: generic parameter = function parameter with a type-parameterized type
        u := newUnifier(check, false)
        u.x.init(tparams)
 
+       // Set the type arguments which we know already.
+       for i, targ := range targs {
+               if targ != nil {
+                       u.x.set(i, targ)
+               }
+       }
+
        errorf := func(kind string, tpar, targ Type, arg *operand) {
+               if !report {
+                       return
+               }
                // provide a better error message if we can
-               targs, failed := u.x.types()
-               if failed == 0 {
+               targs, index := u.x.types()
+               if index == 0 {
                        // The first type parameter couldn't be inferred.
                        // If none of them could be inferred, don't try
                        // to provide the inferred type in the error msg.
@@ -53,16 +135,13 @@ func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand)
                }
        }
 
-       // Terminology: generic parameter = function parameter with a type-parameterized type
-
-       // 1st pass: Unify parameter and argument types for generic parameters with typed arguments
-       //           and collect the indices of generic parameters with untyped arguments.
+       // indices of the generic parameters with untyped arguments - save for later
        var indices []int
        for i, arg := range args {
                par := params.At(i)
                // If we permit bidirectional unification, this conditional code needs to be
                // executed even if par.typ is not parameterized since the argument may be a
-               // generic function (for which we want to infer // its type arguments).
+               // generic function (for which we want to infer its type arguments).
                if isParameterized(tparams, par.typ) {
                        if arg.mode == invalid {
                                // An error was reported earlier. Ignore this targ
@@ -76,7 +155,7 @@ func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand)
                                // the respective type parameters of targ.
                                if !u.unify(par.typ, targ) {
                                        errorf("type", par.typ, targ, arg)
-                                       return nil, 0
+                                       return nil
                                }
                        } else {
                                indices = append(indices, i)
@@ -84,12 +163,25 @@ func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand)
                }
        }
 
-       // Some generic parameters with untyped arguments may have been given a type
-       // indirectly through another generic parameter with a typed argument; we can
-       // ignore those now. (This only means that we know the types for those generic
-       // parameters; it doesn't mean untyped arguments can be passed safely. We still
-       // need to verify that assignment of those arguments is valid when we check
-       // function parameter passing external to infer.)
+       // If we've got all type arguments, we're done.
+       var index int
+       targs, index = u.x.types()
+       if index < 0 {
+               return targs
+       }
+
+       // See how far we get with constraint type inference.
+       // Note that even if we don't have any type arguments, constraint type inference
+       // may produce results for constraints that explicitly specify a type.
+       targs, index = check.inferB(tparams, targs, report)
+       if targs == nil || index < 0 {
+               return targs
+       }
+
+       // --- 3 ---
+       // Use any untyped arguments to infer additional type arguments.
+       // Some generic parameters with untyped arguments may have been given
+       // a type by now, we can ignore them.
        j := 0
        for _, i := range indices {
                par := params.At(i)
@@ -98,14 +190,15 @@ func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand)
                // only parameter type it can possibly match against is a *TypeParam.
                // Thus, only keep the indices of generic parameters that are not of
                // composite types and which don't have a type inferred yet.
-               if tpar, _ := par.typ.(*_TypeParam); tpar != nil && u.x.at(tpar.index) == nil {
+               if tpar, _ := par.typ.(*_TypeParam); tpar != nil && targs[tpar.index] == nil {
                        indices[j] = i
                        j++
                }
        }
        indices = indices[:j]
 
-       // 2nd pass: Unify parameter and default argument types for remaining generic parameters.
+       // Unify parameter and default argument types for remaining generic parameters.
+       // TODO(gri) Rather than iterating again, combine this code with the loop above.
        for _, i := range indices {
                par := params.At(i)
                arg := args[i]
@@ -115,11 +208,29 @@ func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand)
                // nil by making sure all default argument types are typed.
                if isTyped(targ) && !u.unify(par.typ, targ) {
                        errorf("default type", par.typ, targ, arg)
-                       return nil, 0
+                       return nil
                }
        }
 
-       return u.x.types()
+       // If we've got all type arguments, we're done.
+       targs, index = u.x.types()
+       if index < 0 {
+               return targs
+       }
+
+       // Again, follow up with constraint type inference.
+       targs, index = check.inferB(tparams, targs, report)
+       if targs == nil || index < 0 {
+               return targs
+       }
+
+       // At least one type argument couldn't be inferred.
+       assert(targs != nil && index >= 0 && targs[index] == nil)
+       tpar := tparams[index]
+       if report {
+               check.errorf(posn, _Todo, "cannot infer %s (%v) (%v)", tpar.name, tpar.pos, targs)
+       }
+       return nil
 }
 
 // typeNamesString produces a string containing all the
@@ -268,12 +379,13 @@ func (w *tpWalker) isParameterizedList(list []Type) bool {
 
 // inferB returns the list of actual type arguments inferred from the type parameters'
 // bounds and an initial set of type arguments. If type inference is impossible because
-// unification fails, an error is reported, the resulting types list is nil, and index is 0.
+// unification fails, an error is reported if report is set to true, the resulting types
+// list is nil, and index is 0.
 // Otherwise, types is the list of inferred type arguments, and index is the index of the
 // first type argument in that list that couldn't be inferred (and thus is nil). If all
-// type arguments where inferred successfully, index is < 0. The number of type arguments
+// type arguments were inferred successfully, index is < 0. The number of type arguments
 // provided may be less than the number of type parameters, but there must be at least one.
-func (check *Checker) inferB(tparams []*TypeName, targs []Type) (types []Type, index int) {
+func (check *Checker) inferB(tparams []*TypeName, targs []Type, report bool) (types []Type, index int) {
        assert(len(tparams) >= len(targs) && len(targs) > 0)
 
        // Setup bidirectional unification between those structural bounds
@@ -295,7 +407,9 @@ func (check *Checker) inferB(tparams []*TypeName, targs []Type) (types []Type, i
                sbound := check.structuralType(typ.bound)
                if sbound != nil {
                        if !u.unify(typ, sbound) {
-                               check.errorf(tpar, 0, "%s does not match %s", tpar, sbound)
+                               if report {
+                                       check.errorf(tpar, 0, "%s does not match %s", tpar, sbound)
+                               }
                                return nil, 0
                        }
                }
index 1577ad6f8da62519ce7593b6e955cef88bc4ffa1..d95e02e44322443a3f31acc951991b71e8909f76 100644 (file)
@@ -178,17 +178,17 @@ func _[T interface{ type string, chan<-int }](x T) {
 
 // type inference checks
 
-var _ = new() /* ERROR cannot infer T */
+var _ = new /* ERROR cannot infer T */ ()
 
 func f4[A, B, C any](A, B) C
 
-var _ = f4(1, 2) /* ERROR cannot infer C */
+var _ = f4 /* ERROR cannot infer C */ (1, 2)
 var _ = f4[int, float32, complex128](1, 2)
 
 func f5[A, B, C any](A, []*B, struct{f []C}) int
 
 var _ = f5[int, float32, complex128](0, nil, struct{f []complex128}{})
-var _ = f5(0, nil, struct{f []complex128}{}) // ERROR cannot infer
+var _ = f5 /* ERROR cannot infer */ (0, nil, struct{f []complex128}{})
 var _ = f5(0, []*float32{new[float32]()}, struct{f []complex128}{})
 
 func f6[A any](A, []A) int
@@ -197,13 +197,13 @@ var _ = f6(0, nil)
 
 func f6nil[A any](A) int
 
-var _ = f6nil(nil) // ERROR cannot infer
+var _ = f6nil /* ERROR cannot infer */ (nil)
 
 // type inference with variadic functions
 
 func f7[T any](...T) T
 
-var _ int = f7() /* ERROR cannot infer T */
+var _ int = f7 /* ERROR cannot infer T */ ()
 var _ int = f7(1)
 var _ int = f7(1, 2)
 var _ int = f7([]int{}...)