// If DisableUnusedImportCheck is set, packages are not checked
// for unused imports.
DisableUnusedImportCheck bool
+
+ // If EnableReverseTypeInference is set, uninstantiated and
+ // partially instantiated generic functions may be assigned
+ // (incl. returned) to variables of function type and type
+ // inference will attempt to infer the missing type arguments.
+ // Experimental. Needs a proposal.
+ EnableReverseTypeInference bool
}
func srcimporter_setUsesCgo(conf *Config) {
}
var x operand
- check.expr(&x, lhs)
+ check.expr(nil, &x, lhs)
if v != nil {
v.used = v_used // restore v.used
default:
if sel, ok := x.expr.(*syntax.SelectorExpr); ok {
var op operand
- check.expr(&op, sel.X)
+ check.expr(nil, &op, sel.X)
if op.mode == mapindex {
check.errorf(&x, UnaddressableFieldAssign, "cannot assign to struct field %s in map", syntax.String(x.expr))
return Typ[Invalid]
return x.typ
}
-// assignVar checks the assignment lhs = x.
-func (check *Checker) assignVar(lhs syntax.Expr, x *operand) {
- if x.mode == invalid {
- check.useLHS(lhs)
+// assignVar checks the assignment lhs = rhs (if x == nil), or lhs = x (if x != nil).
+// If x != nil, it must be the evaluation of rhs (and rhs will be ignored).
+func (check *Checker) assignVar(lhs, rhs syntax.Expr, x *operand) {
+ T := check.lhsVar(lhs) // nil if lhs is _
+ if T == Typ[Invalid] {
+ check.use(rhs)
return
}
- T := check.lhsVar(lhs) // nil if lhs is _
- if T == Typ[Invalid] {
+ if x == nil {
+ x = new(operand)
+ check.expr(T, x, rhs)
+ }
+ if x.mode == invalid {
return
}
if l == r && !isCall {
var x operand
for i, lhs := range lhs {
- check.expr(&x, orig_rhs[i])
+ check.expr(lhs.typ, &x, orig_rhs[i])
check.initVar(lhs, &x, context)
}
return
// each value can be assigned to its corresponding variable.
if l == r && !isCall {
for i, lhs := range lhs {
- var x operand
- check.expr(&x, orig_rhs[i])
- check.assignVar(lhs, &x)
+ check.assignVar(lhs, orig_rhs[i], nil)
}
return
}
r = len(rhs)
if l == r {
for i, lhs := range lhs {
- check.assignVar(lhs, rhs[i])
+ check.assignVar(lhs, nil, rhs[i])
}
if commaOk {
check.recordCommaOkTypes(orig_rhs[0], rhs)
return
}
- check.expr(x, selx.X)
+ check.expr(nil, x, selx.X)
if x.mode == invalid {
return
}
var t operand
x1 := x
for _, arg := range call.ArgList {
- check.rawExpr(x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T))
+ check.rawExpr(nil, x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T))
check.dump("%v: %s", posFor(x1), x1)
x1 = &t // use incoming x only for first argument
}
import (
"cmd/compile/internal/syntax"
+ "fmt"
. "internal/types/errors"
"strings"
"unicode"
)
-// funcInst type-checks a function instantiation inst and returns the result in x.
-// The operand x must be the evaluation of inst.X and its type must be a signature.
-func (check *Checker) funcInst(x *operand, inst *syntax.IndexExpr) {
+// funcInst type-checks a function instantiation and returns the result in x.
+// The incoming x must be an uninstantiated generic function. If inst != 0,
+// it provides (some or all of) the type arguments (inst.Index) for the
+// instantiation. If the target type T != nil and is a (non-generic) function
+// signature, the signature's parameter types are used to infer additional
+// missing type arguments of x, if any.
+// At least one of inst or T must be provided.
+func (check *Checker) funcInst(T Type, pos syntax.Pos, x *operand, inst *syntax.IndexExpr) {
if !check.allowVersion(check.pkg, 1, 18) {
check.versionErrorf(inst.Pos(), "go1.18", "function instantiation")
}
- xlist := unpackExpr(inst.Index)
- targs := check.typeList(xlist)
- if targs == nil {
- x.mode = invalid
- x.expr = inst
- return
+ // tsig is the (assignment) target function signature, or nil.
+ // TODO(gri) refactor and pass in tsig to funcInst instead
+ var tsig *Signature
+ if check.conf.EnableReverseTypeInference && T != nil {
+ tsig, _ = under(T).(*Signature)
}
- assert(len(targs) == len(xlist))
- // check number of type arguments (got) vs number of type parameters (want)
+ // targs and xlist are the type arguments and corresponding type expressions, or nil.
+ var targs []Type
+ var xlist []syntax.Expr
+ if inst != nil {
+ xlist = unpackExpr(inst.Index)
+ targs = check.typeList(xlist)
+ if targs == nil {
+ x.mode = invalid
+ x.expr = inst
+ return
+ }
+ assert(len(targs) == len(xlist))
+ }
+
+ assert(tsig != nil || targs != nil)
+
+ // Check the number of type arguments (got) vs number of type parameters (want).
+ // Note that x is a function value, not a type expression, so we don't need to
+ // call under below.
sig := x.typ.(*Signature)
got, want := len(targs), sig.TypeParams().Len()
if got > want {
+ // Providing too many type arguments is always an error.
check.errorf(xlist[got-1], WrongTypeArgCount, "got %d type arguments but want %d", got, want)
x.mode = invalid
x.expr = inst
}
if got < want {
- targs = check.infer(inst.Pos(), sig.TypeParams().list(), targs, nil, nil)
+ // If the uninstantiated or partially instantiated function x is used in an
+ // assignment (tsig != nil), use the respective function parameter and result
+ // types to infer additional type arguments.
+ var args []*operand
+ var params []*Var
+ if tsig != nil && sig.tparams != nil && tsig.params.Len() == sig.params.Len() && tsig.results.Len() == sig.results.Len() {
+ // x is a generic function and the signature arity matches the target function.
+ // To infer x's missing type arguments, treat the function assignment as a call
+ // of a synthetic function f where f's parameters are the parameters and results
+ // of x and where the arguments to the call of f are values of the parameter and
+ // result types of x.
+ n := tsig.params.Len()
+ m := tsig.results.Len()
+ args = make([]*operand, n+m)
+ params = make([]*Var, n+m)
+ for i := 0; i < n; i++ {
+ lvar := tsig.params.At(i)
+ lname := syntax.NewName(x.Pos(), paramName(lvar.name, i, "parameter"))
+ args[i] = &operand{mode: value, expr: lname, typ: lvar.typ}
+ params[i] = sig.params.At(i)
+ }
+ for i := 0; i < m; i++ {
+ lvar := tsig.results.At(i)
+ lname := syntax.NewName(x.Pos(), paramName(lvar.name, i, "result parameter"))
+ args[n+i] = &operand{mode: value, expr: lname, typ: lvar.typ}
+ params[n+i] = sig.results.At(i)
+ }
+ }
+
+ // Note that NewTuple(params...) below is nil if len(params) == 0, as desired.
+ targs = check.infer(pos, sig.TypeParams().list(), targs, NewTuple(params...), args)
if targs == nil {
// error was already reported
x.mode = invalid
// instantiate function signature
sig = check.instantiateSignature(x.Pos(), sig, targs, xlist)
assert(sig.TypeParams().Len() == 0) // signature is not generic anymore
- check.recordInstance(inst.X, targs, sig)
+
x.typ = sig
x.mode = value
- x.expr = inst
+ // If we don't have an index expression, keep the existing expression of x.
+ if inst != nil {
+ x.expr = inst
+ }
+ check.recordInstance(x.expr, targs, sig)
+}
+
+func paramName(name string, i int, kind string) string {
+ if name != "" {
+ return name
+ }
+ return nth(i+1) + " " + kind
+}
+
+func nth(n int) string {
+ switch n {
+ case 1:
+ return "1st"
+ case 2:
+ return "2nd"
+ case 3:
+ return "3rd"
+ }
+ return fmt.Sprintf("%dth", n)
}
func (check *Checker) instantiateSignature(pos syntax.Pos, typ *Signature, targs []Type, xlist []syntax.Expr) (res *Signature) {
case typexpr:
// conversion
- check.nonGeneric(x)
+ check.nonGeneric(nil, x)
if x.mode == invalid {
return conversion
}
case 0:
check.errorf(call, WrongArgCount, "missing argument in conversion to %s", T)
case 1:
- check.expr(x, call.ArgList[0])
+ check.expr(nil, x, call.ArgList[0])
if x.mode != invalid {
if t, _ := under(T).(*Interface); t != nil && !isTypeParam(T) {
if !t.IsMethodSet() {
xlist = make([]*operand, len(elist))
for i, e := range elist {
var x operand
- check.expr(&x, e)
+ check.expr(nil, &x, e)
xlist[i] = &x
}
}
}
}
}
- check.rawExpr(&x, n, nil, true)
+ check.rawExpr(nil, &x, n, nil, true)
if v != nil {
v.used = v_used // restore v.used
}
case *syntax.ListExpr:
return check.useN(n.ElemList, lhs)
default:
- check.rawExpr(&x, e, nil, true)
+ check.rawExpr(nil, &x, e, nil, true)
}
return x.mode != invalid
}
flags := flag.NewFlagSet("", flag.PanicOnError)
flags.StringVar(&conf.GoVersion, "lang", "", "")
flags.BoolVar(&conf.FakeImportC, "fakeImportC", false, "")
+ flags.BoolVar(&conf.EnableReverseTypeInference, "reverseTypeInference", false, "")
if err := parseFlags(filenames[0], nil, flags); err != nil {
t.Fatal(err)
}
// (see issues go.dev/issue/42991, go.dev/issue/42992).
check.errpos = obj.pos
}
- check.expr(&x, init)
+ check.expr(nil, &x, init)
}
check.initConst(obj, &x)
}
if lhs == nil || len(lhs) == 1 {
assert(lhs == nil || lhs[0] == obj)
var x operand
- check.expr(&x, init)
+ check.expr(obj.typ, &x, init)
check.initVar(obj, &x, "variable declaration")
return
}
}
func (check *Checker) unary(x *operand, e *syntax.Operation) {
- check.expr(x, e.X)
+ check.expr(nil, x, e.X)
if x.mode == invalid {
return
}
func (check *Checker) binary(x *operand, e syntax.Expr, lhs, rhs syntax.Expr, op syntax.Operator) {
var y operand
- check.expr(x, lhs)
- check.expr(&y, rhs)
+ check.expr(nil, x, lhs)
+ check.expr(nil, &y, rhs)
if x.mode == invalid {
return
statement
)
+// TODO(gri) In rawExpr below, consider using T instead of hint and
+// some sort of "operation mode" instead of allowGeneric.
+// May be clearer and less error-prone.
+
// rawExpr typechecks expression e and initializes x with the expression
// value or type. If an error occurred, x.mode is set to invalid.
+// If a non-nil target type T is given and e is a generic function
+// or function call, T is used to infer the type arguments for e.
// If hint != nil, it is the type of a composite literal element.
// If allowGeneric is set, the operand type may be an uninstantiated
// parameterized type or function value.
-func (check *Checker) rawExpr(x *operand, e syntax.Expr, hint Type, allowGeneric bool) exprKind {
+func (check *Checker) rawExpr(T Type, x *operand, e syntax.Expr, hint Type, allowGeneric bool) exprKind {
if check.conf.Trace {
check.trace(e.Pos(), "-- expr %s", e)
check.indent++
}()
}
- kind := check.exprInternal(x, e, hint)
+ kind := check.exprInternal(T, x, e, hint)
if !allowGeneric {
- check.nonGeneric(x)
+ check.nonGeneric(T, x)
}
check.record(x)
return kind
}
-// If x is a generic function or type, nonGeneric reports an error and invalidates x.mode and x.typ.
+// If x is a generic type, or a generic function whose type arguments cannot be inferred
+// from a non-nil target type T, nonGeneric reports an error and invalidates x.mode and x.typ.
// Otherwise it leaves x alone.
-func (check *Checker) nonGeneric(x *operand) {
+func (check *Checker) nonGeneric(T Type, x *operand) {
if x.mode == invalid || x.mode == novalue {
return
}
}
case *Signature:
if t.tparams != nil {
+ if check.conf.EnableReverseTypeInference && T != nil {
+ if _, ok := under(T).(*Signature); ok {
+ check.funcInst(T, x.Pos(), x, nil)
+ return
+ }
+ }
what = "function"
}
}
// exprInternal contains the core of type checking of expressions.
// Must only be called by rawExpr.
-func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKind {
+// (See rawExpr for an explanation of the parameters.)
+func (check *Checker) exprInternal(T Type, x *operand, e syntax.Expr, hint Type) exprKind {
// make sure x has a valid state in case of bailout
// (was go.dev/issue/5770)
x.mode = invalid
key, _ := kv.Key.(*syntax.Name)
// do all possible checks early (before exiting due to errors)
// so we don't drop information on the floor
- check.expr(x, kv.Value)
+ check.expr(nil, x, kv.Value)
if key == nil {
check.errorf(kv, InvalidLitField, "invalid field name %s in struct literal", kv.Key)
continue
check.error(kv, MixedStructLit, "mixture of field:value and value elements in struct literal")
continue
}
- check.expr(x, e)
+ check.expr(nil, x, e)
if i >= len(fields) {
check.errorf(x, InvalidStructLit, "too many values in struct literal of type %s", base)
break // cannot continue
x.typ = typ
case *syntax.ParenExpr:
- kind := check.rawExpr(x, e.X, nil, false)
+ // type inference doesn't go past parentheses (targe type T = nil)
+ kind := check.rawExpr(nil, x, e.X, nil, false)
x.expr = e
return kind
case *syntax.IndexExpr:
if check.indexExpr(x, e) {
- check.funcInst(x, e)
+ check.funcInst(T, e.Pos(), x, e)
}
if x.mode == invalid {
goto Error
}
case *syntax.AssertExpr:
- check.expr(x, e.X)
+ check.expr(nil, x, e.X)
if x.mode == invalid {
goto Error
}
check.error(e, InvalidSyntaxTree, "invalid use of AssertExpr")
goto Error
}
- // TODO(gri) we may want to permit type assertions on type parameter values at some point
if isTypeParam(x.typ) {
check.errorf(x, InvalidAssert, invalidOp+"cannot use type assertion on type parameter value %s", x)
goto Error
}
// expr typechecks expression e and initializes x with the expression value.
+// If a non-nil target type T is given and e is a generic function
+// or function call, T is used to infer the type arguments for e.
// The result must be a single value.
// If an error occurred, x.mode is set to invalid.
-func (check *Checker) expr(x *operand, e syntax.Expr) {
- check.rawExpr(x, e, nil, false)
+func (check *Checker) expr(T Type, x *operand, e syntax.Expr) {
+ check.rawExpr(T, x, e, nil, false)
+ check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
+ check.singleValue(x)
+}
+
+// genericExpr is like expr but the result may also be generic.
+func (check *Checker) genericExpr(x *operand, e syntax.Expr) {
+ check.rawExpr(nil, x, e, nil, true)
check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
check.singleValue(x)
}
// If an error occurred, list[0] is not valid.
func (check *Checker) multiExpr(e syntax.Expr, allowCommaOk bool) (list []*operand, commaOk bool) {
var x operand
- check.rawExpr(&x, e, nil, false)
+ check.rawExpr(nil, &x, e, nil, false)
check.exclude(&x, 1<<novalue|1<<builtin|1<<typexpr)
if t, ok := x.typ.(*Tuple); ok && x.mode != invalid {
// If an error occurred, x.mode is set to invalid.
func (check *Checker) exprWithHint(x *operand, e syntax.Expr, hint Type) {
assert(hint != nil)
- check.rawExpr(x, e, hint, false)
+ check.rawExpr(nil, x, e, hint, false)
check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
check.singleValue(x)
}
// value.
// If an error occurred, x.mode is set to invalid.
func (check *Checker) exprOrType(x *operand, e syntax.Expr, allowGeneric bool) {
- check.rawExpr(x, e, nil, allowGeneric)
+ check.rawExpr(nil, x, e, nil, allowGeneric)
check.exclude(x, 1<<novalue)
check.singleValue(x)
}
}
// x should not be generic at this point, but be safe and check
- check.nonGeneric(x)
+ check.nonGeneric(nil, x)
if x.mode == invalid {
return false
}
return false
}
var key operand
- check.expr(&key, index)
+ check.expr(nil, &key, index)
check.assignment(&key, typ.key, "map index")
// ok to continue even if indexing failed - map element type is known
x.mode = mapindex
return false
}
var k operand
- check.expr(&k, index)
+ check.expr(nil, &k, index)
check.assignment(&k, key, "map index")
// ok to continue even if indexing failed - map element type is known
x.mode = mapindex
}
func (check *Checker) sliceExpr(x *operand, e *syntax.SliceExpr) {
- check.expr(x, e.X)
+ check.expr(nil, x, e.X)
if x.mode == invalid {
check.use(e.Index[:]...)
return
val = -1
var x operand
- check.expr(&x, index)
+ check.expr(nil, &x, index)
if !check.isValidIndex(&x, InvalidIndex, "index", false) {
return
}
var x operand
var msg string
- switch check.rawExpr(&x, call, nil, false) {
+ switch check.rawExpr(nil, &x, call, nil, false) {
case conversion:
msg = "requires function call, not conversion"
case expression:
L:
for _, e := range values {
var v operand
- check.expr(&v, e)
+ check.expr(nil, &v, e)
if x.mode == invalid || v.mode == invalid {
continue L
}
// The spec allows the value nil instead of a type.
if check.isNil(e) {
T = nil
- check.expr(&dummy, e) // run e through expr so we get the usual Info recordings
+ check.expr(nil, &dummy, e) // run e through expr so we get the usual Info recordings
} else {
T = check.varType(e)
if T == Typ[Invalid] {
// // The spec allows the value nil instead of a type.
// var hash string
// if check.isNil(e) {
-// check.expr(&dummy, e) // run e through expr so we get the usual Info recordings
+// check.expr(nil, &dummy, e) // run e through expr so we get the usual Info recordings
// T = nil
// hash = "<nil>" // avoid collision with a type named nil
// } else {
// function and method calls and receive operations can appear
// in statement context. Such statements may be parenthesized."
var x operand
- kind := check.rawExpr(&x, s.X, nil, false)
+ kind := check.rawExpr(nil, &x, s.X, nil, false)
var msg string
var code Code
switch x.mode {
case *syntax.SendStmt:
var ch, val operand
- check.expr(&ch, s.Chan)
- check.expr(&val, s.Value)
+ check.expr(nil, &ch, s.Chan)
+ check.expr(nil, &val, s.Value)
if ch.mode == invalid || val.mode == invalid {
return
}
// x++ or x--
// (no need to call unpackExpr as s.Lhs must be single-valued)
var x operand
- check.expr(&x, s.Lhs)
+ check.expr(nil, &x, s.Lhs)
if x.mode == invalid {
return
}
check.errorf(s.Lhs, NonNumericIncDec, invalidOp+"%s%s%s (non-numeric type %s)", s.Lhs, s.Op, s.Op, x.typ)
return
}
- check.assignVar(s.Lhs, &x)
+ check.assignVar(s.Lhs, nil, &x)
return
}
var x operand
check.binary(&x, nil, lhs[0], rhs[0], s.Op)
- check.assignVar(lhs[0], &x)
+ check.assignVar(lhs[0], nil, &x)
case *syntax.CallStmt:
kind := "go"
check.simpleStmt(s.Init)
var x operand
- check.expr(&x, s.Cond)
+ check.expr(nil, &x, s.Cond)
if x.mode != invalid && !allBoolean(x.typ) {
check.error(s.Cond, InvalidCond, "non-boolean condition in if statement")
}
check.simpleStmt(s.Init)
if s.Cond != nil {
var x operand
- check.expr(&x, s.Cond)
+ check.expr(nil, &x, s.Cond)
if x.mode != invalid && !allBoolean(x.typ) {
check.error(s.Cond, InvalidCond, "non-boolean condition in for statement")
}
var x operand
if s.Tag != nil {
- check.expr(&x, s.Tag)
+ check.expr(nil, &x, s.Tag)
// By checking assignment of x to an invisible temporary
// (as a compiler would), we get all the relevant checks.
check.assignment(&x, nil, "switch expression")
// check rhs
var x operand
- check.expr(&x, guard.X)
+ check.expr(nil, &x, guard.X)
if x.mode == invalid {
return
}
// check expression to iterate over
var x operand
- check.expr(&x, rclause.X)
+ check.expr(nil, &x, rclause.X)
// determine key/value types
var key, val Type
x.mode = value
x.expr = lhs // we don't have a better rhs expression to use here
x.typ = typ
- check.assignVar(lhs, &x)
+ check.assignVar(lhs, nil, &x)
}
}
}
}
var x operand
- check.expr(&x, e)
+ check.expr(nil, &x, e)
if x.mode != constant_ {
if x.mode != invalid {
check.errorf(&x, InvalidArrayLen, "array length %s must be constant", &x)
// If DisableUnusedImportCheck is set, packages are not checked
// for unused imports.
DisableUnusedImportCheck bool
+
+ // If _EnableReverseTypeInference is set, uninstantiated and
+ // partially instantiated generic functions may be assigned
+ // (incl. returned) to variables of function type and type
+ // inference will attempt to infer the missing type arguments.
+ // Experimental. Needs a proposal.
+ _EnableReverseTypeInference bool
}
func srcimporter_setUsesCgo(conf *Config) {
}
var x operand
- check.expr(&x, lhs)
+ check.expr(nil, &x, lhs)
if v != nil {
v.used = v_used // restore v.used
default:
if sel, ok := x.expr.(*ast.SelectorExpr); ok {
var op operand
- check.expr(&op, sel.X)
+ check.expr(nil, &op, sel.X)
if op.mode == mapindex {
check.errorf(&x, UnaddressableFieldAssign, "cannot assign to struct field %s in map", ExprString(x.expr))
return Typ[Invalid]
return x.typ
}
-// assignVar checks the assignment lhs = x.
-func (check *Checker) assignVar(lhs ast.Expr, x *operand) {
- if x.mode == invalid {
- check.useLHS(lhs)
+// assignVar checks the assignment lhs = rhs (if x == nil), or lhs = x (if x != nil).
+// If x != nil, it must be the evaluation of rhs (and rhs will be ignored).
+func (check *Checker) assignVar(lhs, rhs ast.Expr, x *operand) {
+ T := check.lhsVar(lhs) // nil if lhs is _
+ if T == Typ[Invalid] {
+ check.use(rhs)
return
}
- T := check.lhsVar(lhs) // nil if lhs is _
- if T == Typ[Invalid] {
+ if x == nil {
+ x = new(operand)
+ check.expr(T, x, rhs)
+ }
+ if x.mode == invalid {
return
}
if l == r && !isCall {
var x operand
for i, lhs := range lhs {
- check.expr(&x, orig_rhs[i])
+ check.expr(lhs.typ, &x, orig_rhs[i])
check.initVar(lhs, &x, context)
}
return
// each value can be assigned to its corresponding variable.
if l == r && !isCall {
for i, lhs := range lhs {
- var x operand
- check.expr(&x, orig_rhs[i])
- check.assignVar(lhs, &x)
+ check.assignVar(lhs, orig_rhs[i], nil)
}
return
}
r = len(rhs)
if l == r {
for i, lhs := range lhs {
- check.assignVar(lhs, rhs[i])
+ check.assignVar(lhs, nil, rhs[i])
}
if commaOk {
check.recordCommaOkTypes(orig_rhs[0], rhs)
return
}
- check.expr(x, selx.X)
+ check.expr(nil, x, selx.X)
if x.mode == invalid {
return
}
var t operand
x1 := x
for _, arg := range call.Args {
- check.rawExpr(x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T))
+ check.rawExpr(nil, x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T))
check.dump("%v: %s", x1.Pos(), x1)
x1 = &t // use incoming x only for first argument
}
package types
import (
+ "fmt"
"go/ast"
"go/internal/typeparams"
"go/token"
"unicode"
)
-// funcInst type-checks a function instantiation inst and returns the result in x.
-// The operand x must be the evaluation of inst.X and its type must be a signature.
-func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) {
+// funcInst type-checks a function instantiation and returns the result in x.
+// The incoming x must be an uninstantiated generic function. If ix != 0,
+// it provides (some or all of) the type arguments (ix.Indices) for the
+// instantiation. If the target type T != nil and is a (non-generic) function
+// signature, the signature's parameter types are used to infer additional
+// missing type arguments of x, if any.
+// At least one of inst or T must be provided.
+func (check *Checker) funcInst(T Type, pos token.Pos, x *operand, ix *typeparams.IndexExpr) {
if !check.allowVersion(check.pkg, 1, 18) {
check.softErrorf(inNode(ix.Orig, ix.Lbrack), UnsupportedFeature, "function instantiation requires go1.18 or later")
}
- targs := check.typeList(ix.Indices)
- if targs == nil {
- x.mode = invalid
- x.expr = ix.Orig
- return
+ // tsig is the (assignment) target function signature, or nil.
+ // TODO(gri) refactor and pass in tsig to funcInst instead
+ var tsig *Signature
+ if check.conf._EnableReverseTypeInference && T != nil {
+ tsig, _ = under(T).(*Signature)
+ }
+
+ // targs and xlist are the type arguments and corresponding type expressions, or nil.
+ var targs []Type
+ var xlist []ast.Expr
+ if ix != nil {
+ xlist = ix.Indices
+ targs = check.typeList(xlist)
+ if targs == nil {
+ x.mode = invalid
+ x.expr = ix
+ return
+ }
+ assert(len(targs) == len(xlist))
}
- assert(len(targs) == len(ix.Indices))
- // check number of type arguments (got) vs number of type parameters (want)
+ assert(tsig != nil || targs != nil)
+
+ // Check the number of type arguments (got) vs number of type parameters (want).
+ // Note that x is a function value, not a type expression, so we don't need to
+ // call under below.
sig := x.typ.(*Signature)
got, want := len(targs), sig.TypeParams().Len()
if got > want {
+ // Providing too many type arguments is always an error.
check.errorf(ix.Indices[got-1], WrongTypeArgCount, "got %d type arguments but want %d", got, want)
x.mode = invalid
x.expr = ix.Orig
}
if got < want {
- targs = check.infer(ix.Orig, sig.TypeParams().list(), targs, nil, nil)
+ // If the uninstantiated or partially instantiated function x is used in an
+ // assignment (tsig != nil), use the respective function parameter and result
+ // types to infer additional type arguments.
+ var args []*operand
+ var params []*Var
+ if tsig != nil && sig.tparams != nil && tsig.params.Len() == sig.params.Len() && tsig.results.Len() == sig.results.Len() {
+ // x is a generic function and the signature arity matches the target function.
+ // To infer x's missing type arguments, treat the function assignment as a call
+ // of a synthetic function f where f's parameters are the parameters and results
+ // of x and where the arguments to the call of f are values of the parameter and
+ // result types of x.
+ n := tsig.params.Len()
+ m := tsig.results.Len()
+ args = make([]*operand, n+m)
+ params = make([]*Var, n+m)
+ for i := 0; i < n; i++ {
+ lvar := tsig.params.At(i)
+ lname := ast.NewIdent(paramName(lvar.name, i, "parameter"))
+ lname.NamePos = x.Pos() // correct position
+ args[i] = &operand{mode: value, expr: lname, typ: lvar.typ}
+ params[i] = sig.params.At(i)
+ }
+ for i := 0; i < m; i++ {
+ lvar := tsig.results.At(i)
+ lname := ast.NewIdent(paramName(lvar.name, i, "result parameter"))
+ lname.NamePos = x.Pos() // correct position
+ args[n+i] = &operand{mode: value, expr: lname, typ: lvar.typ}
+ params[n+i] = sig.results.At(i)
+ }
+ }
+
+ // Note that NewTuple(params...) below is nil if len(params) == 0, as desired.
+ targs = check.infer(atPos(pos), sig.TypeParams().list(), targs, NewTuple(params...), args)
if targs == nil {
// error was already reported
x.mode = invalid
- x.expr = ix.Orig
+ x.expr = ix // TODO(gri) is this correct?
return
}
got = len(targs)
assert(got == want)
// instantiate function signature
- sig = check.instantiateSignature(x.Pos(), sig, targs, ix.Indices)
+ sig = check.instantiateSignature(x.Pos(), sig, targs, xlist)
assert(sig.TypeParams().Len() == 0) // signature is not generic anymore
- check.recordInstance(ix.Orig, targs, sig)
+
x.typ = sig
x.mode = value
- x.expr = ix.Orig
+ // If we don't have an index expression, keep the existing expression of x.
+ if ix != nil {
+ x.expr = ix.Orig
+ }
+ check.recordInstance(x.expr, targs, sig)
+}
+
+func paramName(name string, i int, kind string) string {
+ if name != "" {
+ return name
+ }
+ return nth(i+1) + " " + kind
+}
+
+func nth(n int) string {
+ switch n {
+ case 1:
+ return "1st"
+ case 2:
+ return "2nd"
+ case 3:
+ return "3rd"
+ }
+ return fmt.Sprintf("%dth", n)
}
func (check *Checker) instantiateSignature(pos token.Pos, typ *Signature, targs []Type, xlist []ast.Expr) (res *Signature) {
case typexpr:
// conversion
- check.nonGeneric(x)
+ check.nonGeneric(nil, x)
if x.mode == invalid {
return conversion
}
case 0:
check.errorf(inNode(call, call.Rparen), WrongArgCount, "missing argument in conversion to %s", T)
case 1:
- check.expr(x, call.Args[0])
+ check.expr(nil, x, call.Args[0])
if x.mode != invalid {
if call.Ellipsis.IsValid() {
check.errorf(call.Args[0], BadDotDotDotSyntax, "invalid use of ... in conversion to %s", T)
xlist = make([]*operand, len(elist))
for i, e := range elist {
var x operand
- check.expr(&x, e)
+ check.expr(nil, &x, e)
xlist[i] = &x
}
}
}
}
}
- check.rawExpr(&x, n, nil, true)
+ check.rawExpr(nil, &x, n, nil, true)
if v != nil {
v.used = v_used // restore v.used
}
default:
- check.rawExpr(&x, e, nil, true)
+ check.rawExpr(nil, &x, e, nil, true)
}
return x.mode != invalid
}
flags := flag.NewFlagSet("", flag.PanicOnError)
flags.StringVar(&conf.GoVersion, "lang", "", "")
flags.BoolVar(&conf.FakeImportC, "fakeImportC", false, "")
+ flags.BoolVar(boolFieldAddr(&conf, "_EnableReverseTypeInference"), "reverseTypeInference", false, "")
if err := parseFlags(filenames[0], srcs[0], flags); err != nil {
t.Fatal(err)
}
// (see issues go.dev/issue/42991, go.dev/issue/42992).
check.errpos = atPos(obj.pos)
}
- check.expr(&x, init)
+ check.expr(nil, &x, init)
}
check.initConst(obj, &x)
}
if lhs == nil || len(lhs) == 1 {
assert(lhs == nil || lhs[0] == obj)
var x operand
- check.expr(&x, init)
+ check.expr(obj.typ, &x, init)
check.initVar(obj, &x, "variable declaration")
return
}
// evaluate node
var x operand
- check.rawExpr(&x, expr, nil, true) // allow generic expressions
- check.processDelayed(0) // incl. all functions
+ check.rawExpr(nil, &x, expr, nil, true) // allow generic expressions
+ check.processDelayed(0) // incl. all functions
check.recordUntyped()
return nil
// The unary expression e may be nil. It's passed in for better error messages only.
func (check *Checker) unary(x *operand, e *ast.UnaryExpr) {
- check.expr(x, e.X)
+ check.expr(nil, x, e.X)
if x.mode == invalid {
return
}
func (check *Checker) binary(x *operand, e ast.Expr, lhs, rhs ast.Expr, op token.Token, opPos token.Pos) {
var y operand
- check.expr(x, lhs)
- check.expr(&y, rhs)
+ check.expr(nil, x, lhs)
+ check.expr(nil, &y, rhs)
if x.mode == invalid {
return
statement
)
+// TODO(gri) In rawExpr below, consider using T instead of hint and
+// some sort of "operation mode" instead of allowGeneric.
+// May be clearer and less error-prone.
+
// rawExpr typechecks expression e and initializes x with the expression
// value or type. If an error occurred, x.mode is set to invalid.
+// If a non-nil target type T is given and e is a generic function
+// or function call, T is used to infer the type arguments for e.
// If hint != nil, it is the type of a composite literal element.
// If allowGeneric is set, the operand type may be an uninstantiated
// parameterized type or function value.
-func (check *Checker) rawExpr(x *operand, e ast.Expr, hint Type, allowGeneric bool) exprKind {
+func (check *Checker) rawExpr(T Type, x *operand, e ast.Expr, hint Type, allowGeneric bool) exprKind {
if check.conf._Trace {
check.trace(e.Pos(), "-- expr %s", e)
check.indent++
}()
}
- kind := check.exprInternal(x, e, hint)
+ kind := check.exprInternal(T, x, e, hint)
if !allowGeneric {
- check.nonGeneric(x)
+ check.nonGeneric(T, x)
}
check.record(x)
return kind
}
-// If x is a generic function or type, nonGeneric reports an error and invalidates x.mode and x.typ.
+// If x is a generic type, or a generic function whose type arguments cannot be inferred
+// from a non-nil target type T, nonGeneric reports an error and invalidates x.mode and x.typ.
// Otherwise it leaves x alone.
-func (check *Checker) nonGeneric(x *operand) {
+func (check *Checker) nonGeneric(T Type, x *operand) {
if x.mode == invalid || x.mode == novalue {
return
}
}
case *Signature:
if t.tparams != nil {
+ if check.conf._EnableReverseTypeInference && T != nil {
+ if _, ok := under(T).(*Signature); ok {
+ check.funcInst(T, x.Pos(), x, nil)
+ return
+ }
+ }
what = "function"
}
}
// exprInternal contains the core of type checking of expressions.
// Must only be called by rawExpr.
-func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
+// (See rawExpr for an explanation of the parameters.)
+func (check *Checker) exprInternal(T Type, x *operand, e ast.Expr, hint Type) exprKind {
// make sure x has a valid state in case of bailout
// (was go.dev/issue/5770)
x.mode = invalid
key, _ := kv.Key.(*ast.Ident)
// do all possible checks early (before exiting due to errors)
// so we don't drop information on the floor
- check.expr(x, kv.Value)
+ check.expr(nil, x, kv.Value)
if key == nil {
check.errorf(kv, InvalidLitField, "invalid field name %s in struct literal", kv.Key)
continue
check.error(kv, MixedStructLit, "mixture of field:value and value elements in struct literal")
continue
}
- check.expr(x, e)
+ check.expr(nil, x, e)
if i >= len(fields) {
check.errorf(x, InvalidStructLit, "too many values in struct literal of type %s", base)
break // cannot continue
x.typ = typ
case *ast.ParenExpr:
- kind := check.rawExpr(x, e.X, nil, false)
+ // type inference doesn't go past parentheses (targe type T = nil)
+ kind := check.rawExpr(nil, x, e.X, nil, false)
x.expr = e
return kind
case *ast.IndexExpr, *ast.IndexListExpr:
ix := typeparams.UnpackIndexExpr(e)
if check.indexExpr(x, ix) {
- check.funcInst(x, ix)
+ check.funcInst(T, e.Pos(), x, ix)
}
if x.mode == invalid {
goto Error
}
case *ast.TypeAssertExpr:
- check.expr(x, e.X)
+ check.expr(nil, x, e.X)
if x.mode == invalid {
goto Error
}
check.error(e, BadTypeKeyword, "use of .(type) outside type switch")
goto Error
}
- // TODO(gri) we may want to permit type assertions on type parameter values at some point
if isTypeParam(x.typ) {
check.errorf(x, InvalidAssert, invalidOp+"cannot use type assertion on type parameter value %s", x)
goto Error
}
// expr typechecks expression e and initializes x with the expression value.
+// If a non-nil target type T is given and e is a generic function
+// or function call, T is used to infer the type arguments for e.
// The result must be a single value.
// If an error occurred, x.mode is set to invalid.
-func (check *Checker) expr(x *operand, e ast.Expr) {
- check.rawExpr(x, e, nil, false)
+func (check *Checker) expr(T Type, x *operand, e ast.Expr) {
+ check.rawExpr(T, x, e, nil, false)
+ check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
+ check.singleValue(x)
+}
+
+// genericExpr is like expr but the result may also be generic.
+func (check *Checker) genericExpr(x *operand, e ast.Expr) {
+ check.rawExpr(nil, x, e, nil, true)
check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
check.singleValue(x)
}
// If an error occurred, list[0] is not valid.
func (check *Checker) multiExpr(e ast.Expr, allowCommaOk bool) (list []*operand, commaOk bool) {
var x operand
- check.rawExpr(&x, e, nil, false)
+ check.rawExpr(nil, &x, e, nil, false)
check.exclude(&x, 1<<novalue|1<<builtin|1<<typexpr)
if t, ok := x.typ.(*Tuple); ok && x.mode != invalid {
// If an error occurred, x.mode is set to invalid.
func (check *Checker) exprWithHint(x *operand, e ast.Expr, hint Type) {
assert(hint != nil)
- check.rawExpr(x, e, hint, false)
+ check.rawExpr(nil, x, e, hint, false)
check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
check.singleValue(x)
}
// value.
// If an error occurred, x.mode is set to invalid.
func (check *Checker) exprOrType(x *operand, e ast.Expr, allowGeneric bool) {
- check.rawExpr(x, e, nil, allowGeneric)
+ check.rawExpr(nil, x, e, nil, allowGeneric)
check.exclude(x, 1<<novalue)
check.singleValue(x)
}
}
// x should not be generic at this point, but be safe and check
- check.nonGeneric(x)
+ check.nonGeneric(nil, x)
if x.mode == invalid {
return false
}
return false
}
var key operand
- check.expr(&key, index)
+ check.expr(nil, &key, index)
check.assignment(&key, typ.key, "map index")
// ok to continue even if indexing failed - map element type is known
x.mode = mapindex
return false
}
var k operand
- check.expr(&k, index)
+ check.expr(nil, &k, index)
check.assignment(&k, key, "map index")
// ok to continue even if indexing failed - map element type is known
x.mode = mapindex
}
func (check *Checker) sliceExpr(x *operand, e *ast.SliceExpr) {
- check.expr(x, e.X)
+ check.expr(nil, x, e.X)
if x.mode == invalid {
check.use(e.Low, e.High, e.Max)
return
val = -1
var x operand
- check.expr(&x, index)
+ check.expr(nil, &x, index)
if !check.isValidIndex(&x, InvalidIndex, "index", false) {
return
}
var x operand
var msg string
var code Code
- switch check.rawExpr(&x, call, nil, false) {
+ switch check.rawExpr(nil, &x, call, nil, false) {
case conversion:
msg = "requires function call, not conversion"
code = InvalidDefer
L:
for _, e := range values {
var v operand
- check.expr(&v, e)
+ check.expr(nil, &v, e)
if x.mode == invalid || v.mode == invalid {
continue L
}
// The spec allows the value nil instead of a type.
if check.isNil(e) {
T = nil
- check.expr(&dummy, e) // run e through expr so we get the usual Info recordings
+ check.expr(nil, &dummy, e) // run e through expr so we get the usual Info recordings
} else {
T = check.varType(e)
if T == Typ[Invalid] {
// // The spec allows the value nil instead of a type.
// var hash string
// if check.isNil(e) {
-// check.expr(&dummy, e) // run e through expr so we get the usual Info recordings
+// check.expr(nil, &dummy, e) // run e through expr so we get the usual Info recordings
// T = nil
// hash = "<nil>" // avoid collision with a type named nil
// } else {
// function and method calls and receive operations can appear
// in statement context. Such statements may be parenthesized."
var x operand
- kind := check.rawExpr(&x, s.X, nil, false)
+ kind := check.rawExpr(nil, &x, s.X, nil, false)
var msg string
var code Code
switch x.mode {
case *ast.SendStmt:
var ch, val operand
- check.expr(&ch, s.Chan)
- check.expr(&val, s.Value)
+ check.expr(nil, &ch, s.Chan)
+ check.expr(nil, &val, s.Value)
if ch.mode == invalid || val.mode == invalid {
return
}
}
var x operand
- check.expr(&x, s.X)
+ check.expr(nil, &x, s.X)
if x.mode == invalid {
return
}
if x.mode == invalid {
return
}
- check.assignVar(s.X, &x)
+ check.assignVar(s.X, nil, &x)
case *ast.AssignStmt:
switch s.Tok {
if x.mode == invalid {
return
}
- check.assignVar(s.Lhs[0], &x)
+ check.assignVar(s.Lhs[0], nil, &x)
}
case *ast.GoStmt:
check.simpleStmt(s.Init)
var x operand
- check.expr(&x, s.Cond)
+ check.expr(nil, &x, s.Cond)
if x.mode != invalid && !allBoolean(x.typ) {
check.error(s.Cond, InvalidCond, "non-boolean condition in if statement")
}
check.simpleStmt(s.Init)
var x operand
if s.Tag != nil {
- check.expr(&x, s.Tag)
+ check.expr(nil, &x, s.Tag)
// By checking assignment of x to an invisible temporary
// (as a compiler would), we get all the relevant checks.
check.assignment(&x, nil, "switch expression")
return
}
var x operand
- check.expr(&x, expr.X)
+ check.expr(nil, &x, expr.X)
if x.mode == invalid {
return
}
check.simpleStmt(s.Init)
if s.Cond != nil {
var x operand
- check.expr(&x, s.Cond)
+ check.expr(nil, &x, s.Cond)
if x.mode != invalid && !allBoolean(x.typ) {
check.error(s.Cond, InvalidCond, "non-boolean condition in for statement")
}
// check expression to iterate over
var x operand
- check.expr(&x, s.X)
+ check.expr(nil, &x, s.X)
// determine key/value types
var key, val Type
x.mode = value
x.expr = lhs // we don't have a better rhs expression to use here
x.typ = typ
- check.assignVar(lhs, &x)
+ check.assignVar(lhs, nil, &x)
}
}
}
}
var x operand
- check.expr(&x, e)
+ check.expr(nil, &x, e)
if x.mode != constant_ {
if x.mode != invalid {
check.errorf(&x, InvalidArrayLen, "array length %s must be constant", &x)
wantsMethods /* ERROR "any does not satisfy interface{m1(Q); m2() R} (missing method m1)" */ (any(nil))
wantsMethods /* ERROR "hasMethods4 does not satisfy interface{m1(Q); m2() R} (wrong type for method m1)" */ (hasMethods4(nil))
}
+
+// "Reverse" type inference is not yet permitted.
+
+func f[P any](P) {}
+
+// This must not crash.
+var _ func(int) = f // ERROR "cannot use generic function f without instantiation"
--- /dev/null
+// -reverseTypeInference
+
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file shows some examples of "reverse" type inference
+// where the type arguments for generic functions are determined
+// from assigning the functions.
+
+package p
+
+func f1[P any](P) {}
+func f2[P any]() P { var x P; return x }
+func f3[P, Q any](P) Q { var x Q; return x }
+func f4[P any](P, P) {}
+func f5[P any](P) []P { return nil }
+
+// initialization expressions
+var (
+ v1 = f1 // ERROR "cannot use generic function f1 without instantiation"
+ v2 func(int) = f2 // ERROR "cannot infer P"
+
+ v3 func(int) = f1
+ v4 func() int = f2
+ v5 func(int) int = f3
+ _ func(int) int = f3[int]
+
+ v6 func(int, int) = f4
+ v7 func(int, string) = f4 // ERROR "type string of 2nd parameter does not match inferred type int for P"
+ v8 func(int) []int = f5
+ v9 func(string) []int = f5 // ERROR "type []int of 1st result parameter does not match inferred type []string for []P"
+
+ _, _ func(int) = f1, f1
+ _, _ func(int) = f1, f2 // ERROR "cannot infer P"
+)
+
+// Regular assignments
+func _() {
+ v1 = f1 // no error here because v1 is invalid (we don't know its type) due to the error above
+ var v1_ func() int
+ _ = v1_
+ v1_ = f1 // ERROR "cannot infer P"
+ v2 = f2 // ERROR "cannot infer P"
+
+ v3 = f1
+ v4 = f2
+ v5 = f3
+ v5 = f3[int]
+
+ v6 = f4
+ v7 = f4 // ERROR "type string of 2nd parameter does not match inferred type int for P"
+ v8 = f5
+ v9 = f5 // ERROR "type []int of 1st result parameter does not match inferred type []string for []P"
+}
+
+// Return statements
+func _() func(int) { return f1 }
+func _() func() int { return f2 }
+func _() func(int) int { return f3 }
+func _() func(int) int { return f3[int] }
+
+func _() func(int, int) { return f4 }
+func _() func(int, string) {
+ return f4 /* ERROR "type string of 2nd parameter does not match inferred type int for P" */
+}
+func _() func(int) []int { return f5 }
+func _() func(string) []int {
+ return f5 /* ERROR "type []int of 1st result parameter does not match inferred type []string for []P" */
+}
+
+func _() (_, _ func(int)) { return f1, f1 }
+func _() (_, _ func(int)) { return f1, f2 /* ERROR "cannot infer P" */ }