]> Cypherpunks repositories - gostls13.git/commitdiff
go/types, types2: reverse inference of function type arguments
authorRobert Griesemer <gri@golang.org>
Mon, 13 Mar 2023 23:38:14 +0000 (16:38 -0700)
committerGopher Robot <gobot@golang.org>
Wed, 29 Mar 2023 20:53:08 +0000 (20:53 +0000)
This CL implements type inference for generic functions used in
assignments: variable init expressions, regular assignments, and
return statements, but (not yet) function arguments passed to
functions. For instance, given a generic function

        func f[P any](x P)

and a variable of function type

        var v func(x int)

the assignment

        v = f

is valid w/o explicit instantiation of f, and the missing type
argument for f is inferred from the type of v. More generally,
the function f may have multiple type arguments, and it may be
partially instantiated.

This new form of inference is not enabled by default (it needs
to go through the proposal process first). It can be enabled
by setting Config.EnableReverseTypeInference.

The mechanism is implemented as follows:

- The various expression evaluation functions take an additional
  (first) argument T, which is the target type for the expression.
  If not nil, it is the type of the LHS in an assignment.

- The method Checker.funcInst is changed such that it uses both,
  provided type arguments (if any), and a target type (if any)
  to augment type inference.

Change-Id: Idfde61078e1ee4f22abcca894a4c84d681734ff6
Reviewed-on: https://go-review.googlesource.com/c/go/+/476075
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Robert Griesemer <gri@google.com>
Reviewed-by: Robert Findley <rfindley@google.com>
Reviewed-by: Robert Griesemer <gri@google.com>
Run-TryBot: Robert Griesemer <gri@google.com>

23 files changed:
src/cmd/compile/internal/types2/api.go
src/cmd/compile/internal/types2/assignments.go
src/cmd/compile/internal/types2/builtins.go
src/cmd/compile/internal/types2/call.go
src/cmd/compile/internal/types2/check_test.go
src/cmd/compile/internal/types2/decl.go
src/cmd/compile/internal/types2/expr.go
src/cmd/compile/internal/types2/index.go
src/cmd/compile/internal/types2/stmt.go
src/cmd/compile/internal/types2/typexpr.go
src/go/types/api.go
src/go/types/assignments.go
src/go/types/builtins.go
src/go/types/call.go
src/go/types/check_test.go
src/go/types/decl.go
src/go/types/eval.go
src/go/types/expr.go
src/go/types/index.go
src/go/types/stmt.go
src/go/types/typexpr.go
src/internal/types/testdata/examples/inference.go
src/internal/types/testdata/examples/inference2.go [new file with mode: 0644]

index 56fb5789431433703d631a251b4d6e26a3be00b6..e027b9a7e2405eecfd14355d40b4ca1ba8d79f66 100644 (file)
@@ -169,6 +169,13 @@ type Config struct {
        // If DisableUnusedImportCheck is set, packages are not checked
        // for unused imports.
        DisableUnusedImportCheck bool
+
+       // If EnableReverseTypeInference is set, uninstantiated and
+       // partially instantiated generic functions may be assigned
+       // (incl. returned) to variables of function type and type
+       // inference will attempt to infer the missing type arguments.
+       // Experimental. Needs a proposal.
+       EnableReverseTypeInference bool
 }
 
 func srcimporter_setUsesCgo(conf *Config) {
index 3ca6bebd3169655efbe66e1fba82118847b33084..5a51b3de1e2bb29bcd4ef90f4aa0356b1d1ca6a1 100644 (file)
@@ -189,7 +189,7 @@ func (check *Checker) lhsVar(lhs syntax.Expr) Type {
        }
 
        var x operand
-       check.expr(&x, lhs)
+       check.expr(nil, &x, lhs)
 
        if v != nil {
                v.used = v_used // restore v.used
@@ -205,7 +205,7 @@ func (check *Checker) lhsVar(lhs syntax.Expr) Type {
        default:
                if sel, ok := x.expr.(*syntax.SelectorExpr); ok {
                        var op operand
-                       check.expr(&op, sel.X)
+                       check.expr(nil, &op, sel.X)
                        if op.mode == mapindex {
                                check.errorf(&x, UnaddressableFieldAssign, "cannot assign to struct field %s in map", syntax.String(x.expr))
                                return Typ[Invalid]
@@ -218,15 +218,20 @@ func (check *Checker) lhsVar(lhs syntax.Expr) Type {
        return x.typ
 }
 
-// assignVar checks the assignment lhs = x.
-func (check *Checker) assignVar(lhs syntax.Expr, x *operand) {
-       if x.mode == invalid {
-               check.useLHS(lhs)
+// assignVar checks the assignment lhs = rhs (if x == nil), or lhs = x (if x != nil).
+// If x != nil, it must be the evaluation of rhs (and rhs will be ignored).
+func (check *Checker) assignVar(lhs, rhs syntax.Expr, x *operand) {
+       T := check.lhsVar(lhs) // nil if lhs is _
+       if T == Typ[Invalid] {
+               check.use(rhs)
                return
        }
 
-       T := check.lhsVar(lhs) // nil if lhs is _
-       if T == Typ[Invalid] {
+       if x == nil {
+               x = new(operand)
+               check.expr(T, x, rhs)
+       }
+       if x.mode == invalid {
                return
        }
 
@@ -351,7 +356,7 @@ func (check *Checker) initVars(lhs []*Var, orig_rhs []syntax.Expr, returnStmt sy
        if l == r && !isCall {
                var x operand
                for i, lhs := range lhs {
-                       check.expr(&x, orig_rhs[i])
+                       check.expr(lhs.typ, &x, orig_rhs[i])
                        check.initVar(lhs, &x, context)
                }
                return
@@ -423,9 +428,7 @@ func (check *Checker) assignVars(lhs, orig_rhs []syntax.Expr) {
        // each value can be assigned to its corresponding variable.
        if l == r && !isCall {
                for i, lhs := range lhs {
-                       var x operand
-                       check.expr(&x, orig_rhs[i])
-                       check.assignVar(lhs, &x)
+                       check.assignVar(lhs, orig_rhs[i], nil)
                }
                return
        }
@@ -446,7 +449,7 @@ func (check *Checker) assignVars(lhs, orig_rhs []syntax.Expr) {
        r = len(rhs)
        if l == r {
                for i, lhs := range lhs {
-                       check.assignVar(lhs, rhs[i])
+                       check.assignVar(lhs, nil, rhs[i])
                }
                if commaOk {
                        check.recordCommaOkTypes(orig_rhs[0], rhs)
index e35dab81402daa57909d7a83b8bf2ccc3cd85826..67aa37e401351c6468091faff9d6439769d1c66b 100644 (file)
@@ -678,7 +678,7 @@ func (check *Checker) builtin(x *operand, call *syntax.CallExpr, id builtinId) (
                        return
                }
 
-               check.expr(x, selx.X)
+               check.expr(nil, x, selx.X)
                if x.mode == invalid {
                        return
                }
@@ -878,7 +878,7 @@ func (check *Checker) builtin(x *operand, call *syntax.CallExpr, id builtinId) (
                var t operand
                x1 := x
                for _, arg := range call.ArgList {
-                       check.rawExpr(x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T))
+                       check.rawExpr(nil, x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T))
                        check.dump("%v: %s", posFor(x1), x1)
                        x1 = &t // use incoming x only for first argument
                }
index 72608dea26152ab156e0c0473d8aa2fa07964a29..bb82c2464e4ba3fa95d393238c471e180a492921 100644 (file)
@@ -8,31 +8,54 @@ package types2
 
 import (
        "cmd/compile/internal/syntax"
+       "fmt"
        . "internal/types/errors"
        "strings"
        "unicode"
 )
 
-// funcInst type-checks a function instantiation inst and returns the result in x.
-// The operand x must be the evaluation of inst.X and its type must be a signature.
-func (check *Checker) funcInst(x *operand, inst *syntax.IndexExpr) {
+// funcInst type-checks a function instantiation and returns the result in x.
+// The incoming x must be an uninstantiated generic function. If inst != 0,
+// it provides (some or all of) the type arguments (inst.Index) for the
+// instantiation. If the target type T != nil and is a (non-generic) function
+// signature, the signature's parameter types are used to infer additional
+// missing type arguments of x, if any.
+// At least one of inst or T must be provided.
+func (check *Checker) funcInst(T Type, pos syntax.Pos, x *operand, inst *syntax.IndexExpr) {
        if !check.allowVersion(check.pkg, 1, 18) {
                check.versionErrorf(inst.Pos(), "go1.18", "function instantiation")
        }
 
-       xlist := unpackExpr(inst.Index)
-       targs := check.typeList(xlist)
-       if targs == nil {
-               x.mode = invalid
-               x.expr = inst
-               return
+       // tsig is the (assignment) target function signature, or nil.
+       // TODO(gri) refactor and pass in tsig to funcInst instead
+       var tsig *Signature
+       if check.conf.EnableReverseTypeInference && T != nil {
+               tsig, _ = under(T).(*Signature)
        }
-       assert(len(targs) == len(xlist))
 
-       // check number of type arguments (got) vs number of type parameters (want)
+       // targs and xlist are the type arguments and corresponding type expressions, or nil.
+       var targs []Type
+       var xlist []syntax.Expr
+       if inst != nil {
+               xlist = unpackExpr(inst.Index)
+               targs = check.typeList(xlist)
+               if targs == nil {
+                       x.mode = invalid
+                       x.expr = inst
+                       return
+               }
+               assert(len(targs) == len(xlist))
+       }
+
+       assert(tsig != nil || targs != nil)
+
+       // Check the number of type arguments (got) vs number of type parameters (want).
+       // Note that x is a function value, not a type expression, so we don't need to
+       // call under below.
        sig := x.typ.(*Signature)
        got, want := len(targs), sig.TypeParams().Len()
        if got > want {
+               // Providing too many type arguments is always an error.
                check.errorf(xlist[got-1], WrongTypeArgCount, "got %d type arguments but want %d", got, want)
                x.mode = invalid
                x.expr = inst
@@ -40,7 +63,37 @@ func (check *Checker) funcInst(x *operand, inst *syntax.IndexExpr) {
        }
 
        if got < want {
-               targs = check.infer(inst.Pos(), sig.TypeParams().list(), targs, nil, nil)
+               // If the uninstantiated or partially instantiated function x is used in an
+               // assignment (tsig != nil), use the respective function parameter and result
+               // types to infer additional type arguments.
+               var args []*operand
+               var params []*Var
+               if tsig != nil && sig.tparams != nil && tsig.params.Len() == sig.params.Len() && tsig.results.Len() == sig.results.Len() {
+                       // x is a generic function and the signature arity matches the target function.
+                       // To infer x's missing type arguments, treat the function assignment as a call
+                       // of a synthetic function f where f's parameters are the parameters and results
+                       // of x and where the arguments to the call of f are values of the parameter and
+                       // result types of x.
+                       n := tsig.params.Len()
+                       m := tsig.results.Len()
+                       args = make([]*operand, n+m)
+                       params = make([]*Var, n+m)
+                       for i := 0; i < n; i++ {
+                               lvar := tsig.params.At(i)
+                               lname := syntax.NewName(x.Pos(), paramName(lvar.name, i, "parameter"))
+                               args[i] = &operand{mode: value, expr: lname, typ: lvar.typ}
+                               params[i] = sig.params.At(i)
+                       }
+                       for i := 0; i < m; i++ {
+                               lvar := tsig.results.At(i)
+                               lname := syntax.NewName(x.Pos(), paramName(lvar.name, i, "result parameter"))
+                               args[n+i] = &operand{mode: value, expr: lname, typ: lvar.typ}
+                               params[n+i] = sig.results.At(i)
+                       }
+               }
+
+               // Note that NewTuple(params...) below is nil if len(params) == 0, as desired.
+               targs = check.infer(pos, sig.TypeParams().list(), targs, NewTuple(params...), args)
                if targs == nil {
                        // error was already reported
                        x.mode = invalid
@@ -54,10 +107,33 @@ func (check *Checker) funcInst(x *operand, inst *syntax.IndexExpr) {
        // instantiate function signature
        sig = check.instantiateSignature(x.Pos(), sig, targs, xlist)
        assert(sig.TypeParams().Len() == 0) // signature is not generic anymore
-       check.recordInstance(inst.X, targs, sig)
+
        x.typ = sig
        x.mode = value
-       x.expr = inst
+       // If we don't have an index expression, keep the existing expression of x.
+       if inst != nil {
+               x.expr = inst
+       }
+       check.recordInstance(x.expr, targs, sig)
+}
+
+func paramName(name string, i int, kind string) string {
+       if name != "" {
+               return name
+       }
+       return nth(i+1) + " " + kind
+}
+
+func nth(n int) string {
+       switch n {
+       case 1:
+               return "1st"
+       case 2:
+               return "2nd"
+       case 3:
+               return "3rd"
+       }
+       return fmt.Sprintf("%dth", n)
 }
 
 func (check *Checker) instantiateSignature(pos syntax.Pos, typ *Signature, targs []Type, xlist []syntax.Expr) (res *Signature) {
@@ -119,7 +195,7 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind {
 
        case typexpr:
                // conversion
-               check.nonGeneric(x)
+               check.nonGeneric(nil, x)
                if x.mode == invalid {
                        return conversion
                }
@@ -129,7 +205,7 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind {
                case 0:
                        check.errorf(call, WrongArgCount, "missing argument in conversion to %s", T)
                case 1:
-                       check.expr(x, call.ArgList[0])
+                       check.expr(nil, x, call.ArgList[0])
                        if x.mode != invalid {
                                if t, _ := under(T).(*Interface); t != nil && !isTypeParam(T) {
                                        if !t.IsMethodSet() {
@@ -272,7 +348,7 @@ func (check *Checker) exprList(elist []syntax.Expr) (xlist []*operand) {
                xlist = make([]*operand, len(elist))
                for i, e := range elist {
                        var x operand
-                       check.expr(&x, e)
+                       check.expr(nil, &x, e)
                        xlist[i] = &x
                }
        }
@@ -744,14 +820,14 @@ func (check *Checker) use1(e syntax.Expr, lhs bool) bool {
                                }
                        }
                }
-               check.rawExpr(&x, n, nil, true)
+               check.rawExpr(nil, &x, n, nil, true)
                if v != nil {
                        v.used = v_used // restore v.used
                }
        case *syntax.ListExpr:
                return check.useN(n.ElemList, lhs)
        default:
-               check.rawExpr(&x, e, nil, true)
+               check.rawExpr(nil, &x, e, nil, true)
        }
        return x.mode != invalid
 }
index 26bb1aed9e6e8c5eaed271fbc6af2336a1dce87e..382d1ad19e66af8a3418b845c77d0ac3963dbcc0 100644 (file)
@@ -133,6 +133,7 @@ func testFiles(t *testing.T, filenames []string, colDelta uint, manual bool) {
        flags := flag.NewFlagSet("", flag.PanicOnError)
        flags.StringVar(&conf.GoVersion, "lang", "", "")
        flags.BoolVar(&conf.FakeImportC, "fakeImportC", false, "")
+       flags.BoolVar(&conf.EnableReverseTypeInference, "reverseTypeInference", false, "")
        if err := parseFlags(filenames[0], nil, flags); err != nil {
                t.Fatal(err)
        }
index 0ac0f6196af3c6298b8882fed4df73097ed4e754..afa32c1a5f7ea04b7274d8e8d649d9f5859a16a0 100644 (file)
@@ -408,7 +408,7 @@ func (check *Checker) constDecl(obj *Const, typ, init syntax.Expr, inherited boo
                        // (see issues go.dev/issue/42991, go.dev/issue/42992).
                        check.errpos = obj.pos
                }
-               check.expr(&x, init)
+               check.expr(nil, &x, init)
        }
        check.initConst(obj, &x)
 }
@@ -455,7 +455,7 @@ func (check *Checker) varDecl(obj *Var, lhs []*Var, typ, init syntax.Expr) {
        if lhs == nil || len(lhs) == 1 {
                assert(lhs == nil || lhs[0] == obj)
                var x operand
-               check.expr(&x, init)
+               check.expr(obj.typ, &x, init)
                check.initVar(obj, &x, "variable declaration")
                return
        }
index fdc7bdbef018361b4f45968e4a04024ae7ead91e..bab52b253b69181adf5b294f802ca48c36d1d4fa 100644 (file)
@@ -173,7 +173,7 @@ func underIs(typ Type, f func(Type) bool) bool {
 }
 
 func (check *Checker) unary(x *operand, e *syntax.Operation) {
-       check.expr(x, e.X)
+       check.expr(nil, x, e.X)
        if x.mode == invalid {
                return
        }
@@ -1097,8 +1097,8 @@ func init() {
 func (check *Checker) binary(x *operand, e syntax.Expr, lhs, rhs syntax.Expr, op syntax.Operator) {
        var y operand
 
-       check.expr(x, lhs)
-       check.expr(&y, rhs)
+       check.expr(nil, x, lhs)
+       check.expr(nil, &y, rhs)
 
        if x.mode == invalid {
                return
@@ -1245,12 +1245,18 @@ const (
        statement
 )
 
+// TODO(gri) In rawExpr below, consider using T instead of hint and
+//           some sort of "operation mode" instead of allowGeneric.
+//           May be clearer and less error-prone.
+
 // rawExpr typechecks expression e and initializes x with the expression
 // value or type. If an error occurred, x.mode is set to invalid.
+// If a non-nil target type T is given and e is a generic function
+// or function call, T is used to infer the type arguments for e.
 // If hint != nil, it is the type of a composite literal element.
 // If allowGeneric is set, the operand type may be an uninstantiated
 // parameterized type or function value.
-func (check *Checker) rawExpr(x *operand, e syntax.Expr, hint Type, allowGeneric bool) exprKind {
+func (check *Checker) rawExpr(T Type, x *operand, e syntax.Expr, hint Type, allowGeneric bool) exprKind {
        if check.conf.Trace {
                check.trace(e.Pos(), "-- expr %s", e)
                check.indent++
@@ -1260,10 +1266,10 @@ func (check *Checker) rawExpr(x *operand, e syntax.Expr, hint Type, allowGeneric
                }()
        }
 
-       kind := check.exprInternal(x, e, hint)
+       kind := check.exprInternal(T, x, e, hint)
 
        if !allowGeneric {
-               check.nonGeneric(x)
+               check.nonGeneric(T, x)
        }
 
        check.record(x)
@@ -1271,9 +1277,10 @@ func (check *Checker) rawExpr(x *operand, e syntax.Expr, hint Type, allowGeneric
        return kind
 }
 
-// If x is a generic function or type, nonGeneric reports an error and invalidates x.mode and x.typ.
+// If x is a generic type, or a generic function whose type arguments cannot be inferred
+// from a non-nil target type T, nonGeneric reports an error and invalidates x.mode and x.typ.
 // Otherwise it leaves x alone.
-func (check *Checker) nonGeneric(x *operand) {
+func (check *Checker) nonGeneric(T Type, x *operand) {
        if x.mode == invalid || x.mode == novalue {
                return
        }
@@ -1285,6 +1292,12 @@ func (check *Checker) nonGeneric(x *operand) {
                }
        case *Signature:
                if t.tparams != nil {
+                       if check.conf.EnableReverseTypeInference && T != nil {
+                               if _, ok := under(T).(*Signature); ok {
+                                       check.funcInst(T, x.Pos(), x, nil)
+                                       return
+                               }
+                       }
                        what = "function"
                }
        }
@@ -1297,7 +1310,8 @@ func (check *Checker) nonGeneric(x *operand) {
 
 // exprInternal contains the core of type checking of expressions.
 // Must only be called by rawExpr.
-func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKind {
+// (See rawExpr for an explanation of the parameters.)
+func (check *Checker) exprInternal(T Type, x *operand, e syntax.Expr, hint Type) exprKind {
        // make sure x has a valid state in case of bailout
        // (was go.dev/issue/5770)
        x.mode = invalid
@@ -1438,7 +1452,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin
                                        key, _ := kv.Key.(*syntax.Name)
                                        // do all possible checks early (before exiting due to errors)
                                        // so we don't drop information on the floor
-                                       check.expr(x, kv.Value)
+                                       check.expr(nil, x, kv.Value)
                                        if key == nil {
                                                check.errorf(kv, InvalidLitField, "invalid field name %s in struct literal", kv.Key)
                                                continue
@@ -1466,7 +1480,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin
                                                check.error(kv, MixedStructLit, "mixture of field:value and value elements in struct literal")
                                                continue
                                        }
-                                       check.expr(x, e)
+                                       check.expr(nil, x, e)
                                        if i >= len(fields) {
                                                check.errorf(x, InvalidStructLit, "too many values in struct literal of type %s", base)
                                                break // cannot continue
@@ -1593,7 +1607,8 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin
                x.typ = typ
 
        case *syntax.ParenExpr:
-               kind := check.rawExpr(x, e.X, nil, false)
+               // type inference doesn't go past parentheses (targe type T = nil)
+               kind := check.rawExpr(nil, x, e.X, nil, false)
                x.expr = e
                return kind
 
@@ -1602,7 +1617,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin
 
        case *syntax.IndexExpr:
                if check.indexExpr(x, e) {
-                       check.funcInst(x, e)
+                       check.funcInst(T, e.Pos(), x, e)
                }
                if x.mode == invalid {
                        goto Error
@@ -1615,7 +1630,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin
                }
 
        case *syntax.AssertExpr:
-               check.expr(x, e.X)
+               check.expr(nil, x, e.X)
                if x.mode == invalid {
                        goto Error
                }
@@ -1624,7 +1639,6 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin
                        check.error(e, InvalidSyntaxTree, "invalid use of AssertExpr")
                        goto Error
                }
-               // TODO(gri) we may want to permit type assertions on type parameter values at some point
                if isTypeParam(x.typ) {
                        check.errorf(x, InvalidAssert, invalidOp+"cannot use type assertion on type parameter value %s", x)
                        goto Error
@@ -1814,10 +1828,19 @@ func (check *Checker) typeAssertion(e syntax.Expr, x *operand, T Type, typeSwitc
 }
 
 // expr typechecks expression e and initializes x with the expression value.
+// If a non-nil target type T is given and e is a generic function
+// or function call, T is used to infer the type arguments for e.
 // The result must be a single value.
 // If an error occurred, x.mode is set to invalid.
-func (check *Checker) expr(x *operand, e syntax.Expr) {
-       check.rawExpr(x, e, nil, false)
+func (check *Checker) expr(T Type, x *operand, e syntax.Expr) {
+       check.rawExpr(T, x, e, nil, false)
+       check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
+       check.singleValue(x)
+}
+
+// genericExpr is like expr but the result may also be generic.
+func (check *Checker) genericExpr(x *operand, e syntax.Expr) {
+       check.rawExpr(nil, x, e, nil, true)
        check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
        check.singleValue(x)
 }
@@ -1829,7 +1852,7 @@ func (check *Checker) expr(x *operand, e syntax.Expr) {
 // If an error occurred, list[0] is not valid.
 func (check *Checker) multiExpr(e syntax.Expr, allowCommaOk bool) (list []*operand, commaOk bool) {
        var x operand
-       check.rawExpr(&x, e, nil, false)
+       check.rawExpr(nil, &x, e, nil, false)
        check.exclude(&x, 1<<novalue|1<<builtin|1<<typexpr)
 
        if t, ok := x.typ.(*Tuple); ok && x.mode != invalid {
@@ -1860,7 +1883,7 @@ func (check *Checker) multiExpr(e syntax.Expr, allowCommaOk bool) (list []*opera
 // If an error occurred, x.mode is set to invalid.
 func (check *Checker) exprWithHint(x *operand, e syntax.Expr, hint Type) {
        assert(hint != nil)
-       check.rawExpr(x, e, hint, false)
+       check.rawExpr(nil, x, e, hint, false)
        check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
        check.singleValue(x)
 }
@@ -1870,7 +1893,7 @@ func (check *Checker) exprWithHint(x *operand, e syntax.Expr, hint Type) {
 // value.
 // If an error occurred, x.mode is set to invalid.
 func (check *Checker) exprOrType(x *operand, e syntax.Expr, allowGeneric bool) {
-       check.rawExpr(x, e, nil, allowGeneric)
+       check.rawExpr(nil, x, e, nil, allowGeneric)
        check.exclude(x, 1<<novalue)
        check.singleValue(x)
 }
index 38134ec2cc3197788f55d236c779cadfe87bc5fe..4fbe064da60d4a1fe96e4d90a10672b50e4ba3e6 100644 (file)
@@ -42,7 +42,7 @@ func (check *Checker) indexExpr(x *operand, e *syntax.IndexExpr) (isFuncInst boo
        }
 
        // x should not be generic at this point, but be safe and check
-       check.nonGeneric(x)
+       check.nonGeneric(nil, x)
        if x.mode == invalid {
                return false
        }
@@ -92,7 +92,7 @@ func (check *Checker) indexExpr(x *operand, e *syntax.IndexExpr) (isFuncInst boo
                        return false
                }
                var key operand
-               check.expr(&key, index)
+               check.expr(nil, &key, index)
                check.assignment(&key, typ.key, "map index")
                // ok to continue even if indexing failed - map element type is known
                x.mode = mapindex
@@ -166,7 +166,7 @@ func (check *Checker) indexExpr(x *operand, e *syntax.IndexExpr) (isFuncInst boo
                                        return false
                                }
                                var k operand
-                               check.expr(&k, index)
+                               check.expr(nil, &k, index)
                                check.assignment(&k, key, "map index")
                                // ok to continue even if indexing failed - map element type is known
                                x.mode = mapindex
@@ -206,7 +206,7 @@ func (check *Checker) indexExpr(x *operand, e *syntax.IndexExpr) (isFuncInst boo
 }
 
 func (check *Checker) sliceExpr(x *operand, e *syntax.SliceExpr) {
-       check.expr(x, e.X)
+       check.expr(nil, x, e.X)
        if x.mode == invalid {
                check.use(e.Index[:]...)
                return
@@ -353,7 +353,7 @@ func (check *Checker) index(index syntax.Expr, max int64) (typ Type, val int64)
        val = -1
 
        var x operand
-       check.expr(&x, index)
+       check.expr(nil, &x, index)
        if !check.isValidIndex(&x, InvalidIndex, "index", false) {
                return
        }
index 3e5c9cb6cd988690b51bbbe8954e5ffbe4d1a6f0..f13ab69830b81ea6cc9fe658ffde94604b9d6e5e 100644 (file)
@@ -180,7 +180,7 @@ func (check *Checker) suspendedCall(keyword string, call syntax.Expr) {
 
        var x operand
        var msg string
-       switch check.rawExpr(&x, call, nil, false) {
+       switch check.rawExpr(nil, &x, call, nil, false) {
        case conversion:
                msg = "requires function call, not conversion"
        case expression:
@@ -240,7 +240,7 @@ func (check *Checker) caseValues(x *operand, values []syntax.Expr, seen valueMap
 L:
        for _, e := range values {
                var v operand
-               check.expr(&v, e)
+               check.expr(nil, &v, e)
                if x.mode == invalid || v.mode == invalid {
                        continue L
                }
@@ -294,7 +294,7 @@ L:
                // The spec allows the value nil instead of a type.
                if check.isNil(e) {
                        T = nil
-                       check.expr(&dummy, e) // run e through expr so we get the usual Info recordings
+                       check.expr(nil, &dummy, e) // run e through expr so we get the usual Info recordings
                } else {
                        T = check.varType(e)
                        if T == Typ[Invalid] {
@@ -336,7 +336,7 @@ L:
 //             // The spec allows the value nil instead of a type.
 //             var hash string
 //             if check.isNil(e) {
-//                     check.expr(&dummy, e) // run e through expr so we get the usual Info recordings
+//                     check.expr(nil, &dummy, e) // run e through expr so we get the usual Info recordings
 //                     T = nil
 //                     hash = "<nil>" // avoid collision with a type named nil
 //             } else {
@@ -403,7 +403,7 @@ func (check *Checker) stmt(ctxt stmtContext, s syntax.Stmt) {
                // function and method calls and receive operations can appear
                // in statement context. Such statements may be parenthesized."
                var x operand
-               kind := check.rawExpr(&x, s.X, nil, false)
+               kind := check.rawExpr(nil, &x, s.X, nil, false)
                var msg string
                var code Code
                switch x.mode {
@@ -424,8 +424,8 @@ func (check *Checker) stmt(ctxt stmtContext, s syntax.Stmt) {
 
        case *syntax.SendStmt:
                var ch, val operand
-               check.expr(&ch, s.Chan)
-               check.expr(&val, s.Value)
+               check.expr(nil, &ch, s.Chan)
+               check.expr(nil, &val, s.Value)
                if ch.mode == invalid || val.mode == invalid {
                        return
                }
@@ -450,7 +450,7 @@ func (check *Checker) stmt(ctxt stmtContext, s syntax.Stmt) {
                        // x++ or x--
                        // (no need to call unpackExpr as s.Lhs must be single-valued)
                        var x operand
-                       check.expr(&x, s.Lhs)
+                       check.expr(nil, &x, s.Lhs)
                        if x.mode == invalid {
                                return
                        }
@@ -458,7 +458,7 @@ func (check *Checker) stmt(ctxt stmtContext, s syntax.Stmt) {
                                check.errorf(s.Lhs, NonNumericIncDec, invalidOp+"%s%s%s (non-numeric type %s)", s.Lhs, s.Op, s.Op, x.typ)
                                return
                        }
-                       check.assignVar(s.Lhs, &x)
+                       check.assignVar(s.Lhs, nil, &x)
                        return
                }
 
@@ -481,7 +481,7 @@ func (check *Checker) stmt(ctxt stmtContext, s syntax.Stmt) {
 
                var x operand
                check.binary(&x, nil, lhs[0], rhs[0], s.Op)
-               check.assignVar(lhs[0], &x)
+               check.assignVar(lhs[0], nil, &x)
 
        case *syntax.CallStmt:
                kind := "go"
@@ -566,7 +566,7 @@ func (check *Checker) stmt(ctxt stmtContext, s syntax.Stmt) {
 
                check.simpleStmt(s.Init)
                var x operand
-               check.expr(&x, s.Cond)
+               check.expr(nil, &x, s.Cond)
                if x.mode != invalid && !allBoolean(x.typ) {
                        check.error(s.Cond, InvalidCond, "non-boolean condition in if statement")
                }
@@ -656,7 +656,7 @@ func (check *Checker) stmt(ctxt stmtContext, s syntax.Stmt) {
                check.simpleStmt(s.Init)
                if s.Cond != nil {
                        var x operand
-                       check.expr(&x, s.Cond)
+                       check.expr(nil, &x, s.Cond)
                        if x.mode != invalid && !allBoolean(x.typ) {
                                check.error(s.Cond, InvalidCond, "non-boolean condition in for statement")
                        }
@@ -680,7 +680,7 @@ func (check *Checker) switchStmt(inner stmtContext, s *syntax.SwitchStmt) {
 
        var x operand
        if s.Tag != nil {
-               check.expr(&x, s.Tag)
+               check.expr(nil, &x, s.Tag)
                // By checking assignment of x to an invisible temporary
                // (as a compiler would), we get all the relevant checks.
                check.assignment(&x, nil, "switch expression")
@@ -747,7 +747,7 @@ func (check *Checker) typeSwitchStmt(inner stmtContext, s *syntax.SwitchStmt, gu
 
        // check rhs
        var x operand
-       check.expr(&x, guard.X)
+       check.expr(nil, &x, guard.X)
        if x.mode == invalid {
                return
        }
@@ -847,7 +847,7 @@ func (check *Checker) rangeStmt(inner stmtContext, s *syntax.ForStmt, rclause *s
 
        // check expression to iterate over
        var x operand
-       check.expr(&x, rclause.X)
+       check.expr(nil, &x, rclause.X)
 
        // determine key/value types
        var key, val Type
@@ -950,7 +950,7 @@ func (check *Checker) rangeStmt(inner stmtContext, s *syntax.ForStmt, rclause *s
                                x.mode = value
                                x.expr = lhs // we don't have a better rhs expression to use here
                                x.typ = typ
-                               check.assignVar(lhs, &x)
+                               check.assignVar(lhs, nil, &x)
                        }
                }
        }
index d85e7beedd07ca2d9aff7aada5da4422ac7413f5..03b2a8488ee6049ede040dd5d51901fc273a6093 100644 (file)
@@ -489,7 +489,7 @@ func (check *Checker) arrayLength(e syntax.Expr) int64 {
        }
 
        var x operand
-       check.expr(&x, e)
+       check.expr(nil, &x, e)
        if x.mode != constant_ {
                if x.mode != invalid {
                        check.errorf(&x, InvalidArrayLen, "array length %s must be constant", &x)
index b87330804ca3623f2173b9942038906511bde376..06bdb5616d81142bfee77549dca48334d4d05cba 100644 (file)
@@ -170,6 +170,13 @@ type Config struct {
        // If DisableUnusedImportCheck is set, packages are not checked
        // for unused imports.
        DisableUnusedImportCheck bool
+
+       // If _EnableReverseTypeInference is set, uninstantiated and
+       // partially instantiated generic functions may be assigned
+       // (incl. returned) to variables of function type and type
+       // inference will attempt to infer the missing type arguments.
+       // Experimental. Needs a proposal.
+       _EnableReverseTypeInference bool
 }
 
 func srcimporter_setUsesCgo(conf *Config) {
index a73e4515bc48cb9c056716ed74403abb05506920..5eca569b56840b286c555021a8f2f2f8c1bfae87 100644 (file)
@@ -187,7 +187,7 @@ func (check *Checker) lhsVar(lhs ast.Expr) Type {
        }
 
        var x operand
-       check.expr(&x, lhs)
+       check.expr(nil, &x, lhs)
 
        if v != nil {
                v.used = v_used // restore v.used
@@ -203,7 +203,7 @@ func (check *Checker) lhsVar(lhs ast.Expr) Type {
        default:
                if sel, ok := x.expr.(*ast.SelectorExpr); ok {
                        var op operand
-                       check.expr(&op, sel.X)
+                       check.expr(nil, &op, sel.X)
                        if op.mode == mapindex {
                                check.errorf(&x, UnaddressableFieldAssign, "cannot assign to struct field %s in map", ExprString(x.expr))
                                return Typ[Invalid]
@@ -216,15 +216,20 @@ func (check *Checker) lhsVar(lhs ast.Expr) Type {
        return x.typ
 }
 
-// assignVar checks the assignment lhs = x.
-func (check *Checker) assignVar(lhs ast.Expr, x *operand) {
-       if x.mode == invalid {
-               check.useLHS(lhs)
+// assignVar checks the assignment lhs = rhs (if x == nil), or lhs = x (if x != nil).
+// If x != nil, it must be the evaluation of rhs (and rhs will be ignored).
+func (check *Checker) assignVar(lhs, rhs ast.Expr, x *operand) {
+       T := check.lhsVar(lhs) // nil if lhs is _
+       if T == Typ[Invalid] {
+               check.use(rhs)
                return
        }
 
-       T := check.lhsVar(lhs) // nil if lhs is _
-       if T == Typ[Invalid] {
+       if x == nil {
+               x = new(operand)
+               check.expr(T, x, rhs)
+       }
+       if x.mode == invalid {
                return
        }
 
@@ -349,7 +354,7 @@ func (check *Checker) initVars(lhs []*Var, orig_rhs []ast.Expr, returnStmt ast.S
        if l == r && !isCall {
                var x operand
                for i, lhs := range lhs {
-                       check.expr(&x, orig_rhs[i])
+                       check.expr(lhs.typ, &x, orig_rhs[i])
                        check.initVar(lhs, &x, context)
                }
                return
@@ -421,9 +426,7 @@ func (check *Checker) assignVars(lhs, orig_rhs []ast.Expr) {
        // each value can be assigned to its corresponding variable.
        if l == r && !isCall {
                for i, lhs := range lhs {
-                       var x operand
-                       check.expr(&x, orig_rhs[i])
-                       check.assignVar(lhs, &x)
+                       check.assignVar(lhs, orig_rhs[i], nil)
                }
                return
        }
@@ -444,7 +447,7 @@ func (check *Checker) assignVars(lhs, orig_rhs []ast.Expr) {
        r = len(rhs)
        if l == r {
                for i, lhs := range lhs {
-                       check.assignVar(lhs, rhs[i])
+                       check.assignVar(lhs, nil, rhs[i])
                }
                if commaOk {
                        check.recordCommaOkTypes(orig_rhs[0], rhs)
index 9659a7ccb111662f523ad2bc87e4cc0c0d50c337..a7a3af2725eb09c2f0f221a9b0f915801b175373 100644 (file)
@@ -679,7 +679,7 @@ func (check *Checker) builtin(x *operand, call *ast.CallExpr, id builtinId) (_ b
                        return
                }
 
-               check.expr(x, selx.X)
+               check.expr(nil, x, selx.X)
                if x.mode == invalid {
                        return
                }
@@ -879,7 +879,7 @@ func (check *Checker) builtin(x *operand, call *ast.CallExpr, id builtinId) (_ b
                var t operand
                x1 := x
                for _, arg := range call.Args {
-                       check.rawExpr(x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T))
+                       check.rawExpr(nil, x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T))
                        check.dump("%v: %s", x1.Pos(), x1)
                        x1 = &t // use incoming x only for first argument
                }
index e5968c7cfce31f7b7931a4fd39f182a1848f5e6c..f75043d5dcfdf44dab588861ace83fecb042db28 100644 (file)
@@ -7,6 +7,7 @@
 package types
 
 import (
+       "fmt"
        "go/ast"
        "go/internal/typeparams"
        "go/token"
@@ -15,25 +16,48 @@ import (
        "unicode"
 )
 
-// funcInst type-checks a function instantiation inst and returns the result in x.
-// The operand x must be the evaluation of inst.X and its type must be a signature.
-func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) {
+// funcInst type-checks a function instantiation and returns the result in x.
+// The incoming x must be an uninstantiated generic function. If ix != 0,
+// it provides (some or all of) the type arguments (ix.Indices) for the
+// instantiation. If the target type T != nil and is a (non-generic) function
+// signature, the signature's parameter types are used to infer additional
+// missing type arguments of x, if any.
+// At least one of inst or T must be provided.
+func (check *Checker) funcInst(T Type, pos token.Pos, x *operand, ix *typeparams.IndexExpr) {
        if !check.allowVersion(check.pkg, 1, 18) {
                check.softErrorf(inNode(ix.Orig, ix.Lbrack), UnsupportedFeature, "function instantiation requires go1.18 or later")
        }
 
-       targs := check.typeList(ix.Indices)
-       if targs == nil {
-               x.mode = invalid
-               x.expr = ix.Orig
-               return
+       // tsig is the (assignment) target function signature, or nil.
+       // TODO(gri) refactor and pass in tsig to funcInst instead
+       var tsig *Signature
+       if check.conf._EnableReverseTypeInference && T != nil {
+               tsig, _ = under(T).(*Signature)
+       }
+
+       // targs and xlist are the type arguments and corresponding type expressions, or nil.
+       var targs []Type
+       var xlist []ast.Expr
+       if ix != nil {
+               xlist = ix.Indices
+               targs = check.typeList(xlist)
+               if targs == nil {
+                       x.mode = invalid
+                       x.expr = ix
+                       return
+               }
+               assert(len(targs) == len(xlist))
        }
-       assert(len(targs) == len(ix.Indices))
 
-       // check number of type arguments (got) vs number of type parameters (want)
+       assert(tsig != nil || targs != nil)
+
+       // Check the number of type arguments (got) vs number of type parameters (want).
+       // Note that x is a function value, not a type expression, so we don't need to
+       // call under below.
        sig := x.typ.(*Signature)
        got, want := len(targs), sig.TypeParams().Len()
        if got > want {
+               // Providing too many type arguments is always an error.
                check.errorf(ix.Indices[got-1], WrongTypeArgCount, "got %d type arguments but want %d", got, want)
                x.mode = invalid
                x.expr = ix.Orig
@@ -41,11 +65,43 @@ func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) {
        }
 
        if got < want {
-               targs = check.infer(ix.Orig, sig.TypeParams().list(), targs, nil, nil)
+               // If the uninstantiated or partially instantiated function x is used in an
+               // assignment (tsig != nil), use the respective function parameter and result
+               // types to infer additional type arguments.
+               var args []*operand
+               var params []*Var
+               if tsig != nil && sig.tparams != nil && tsig.params.Len() == sig.params.Len() && tsig.results.Len() == sig.results.Len() {
+                       // x is a generic function and the signature arity matches the target function.
+                       // To infer x's missing type arguments, treat the function assignment as a call
+                       // of a synthetic function f where f's parameters are the parameters and results
+                       // of x and where the arguments to the call of f are values of the parameter and
+                       // result types of x.
+                       n := tsig.params.Len()
+                       m := tsig.results.Len()
+                       args = make([]*operand, n+m)
+                       params = make([]*Var, n+m)
+                       for i := 0; i < n; i++ {
+                               lvar := tsig.params.At(i)
+                               lname := ast.NewIdent(paramName(lvar.name, i, "parameter"))
+                               lname.NamePos = x.Pos() // correct position
+                               args[i] = &operand{mode: value, expr: lname, typ: lvar.typ}
+                               params[i] = sig.params.At(i)
+                       }
+                       for i := 0; i < m; i++ {
+                               lvar := tsig.results.At(i)
+                               lname := ast.NewIdent(paramName(lvar.name, i, "result parameter"))
+                               lname.NamePos = x.Pos() // correct position
+                               args[n+i] = &operand{mode: value, expr: lname, typ: lvar.typ}
+                               params[n+i] = sig.results.At(i)
+                       }
+               }
+
+               // Note that NewTuple(params...) below is nil if len(params) == 0, as desired.
+               targs = check.infer(atPos(pos), sig.TypeParams().list(), targs, NewTuple(params...), args)
                if targs == nil {
                        // error was already reported
                        x.mode = invalid
-                       x.expr = ix.Orig
+                       x.expr = ix // TODO(gri) is this correct?
                        return
                }
                got = len(targs)
@@ -53,12 +109,35 @@ func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) {
        assert(got == want)
 
        // instantiate function signature
-       sig = check.instantiateSignature(x.Pos(), sig, targs, ix.Indices)
+       sig = check.instantiateSignature(x.Pos(), sig, targs, xlist)
        assert(sig.TypeParams().Len() == 0) // signature is not generic anymore
-       check.recordInstance(ix.Orig, targs, sig)
+
        x.typ = sig
        x.mode = value
-       x.expr = ix.Orig
+       // If we don't have an index expression, keep the existing expression of x.
+       if ix != nil {
+               x.expr = ix.Orig
+       }
+       check.recordInstance(x.expr, targs, sig)
+}
+
+func paramName(name string, i int, kind string) string {
+       if name != "" {
+               return name
+       }
+       return nth(i+1) + " " + kind
+}
+
+func nth(n int) string {
+       switch n {
+       case 1:
+               return "1st"
+       case 2:
+               return "2nd"
+       case 3:
+               return "3rd"
+       }
+       return fmt.Sprintf("%dth", n)
 }
 
 func (check *Checker) instantiateSignature(pos token.Pos, typ *Signature, targs []Type, xlist []ast.Expr) (res *Signature) {
@@ -121,7 +200,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
 
        case typexpr:
                // conversion
-               check.nonGeneric(x)
+               check.nonGeneric(nil, x)
                if x.mode == invalid {
                        return conversion
                }
@@ -131,7 +210,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
                case 0:
                        check.errorf(inNode(call, call.Rparen), WrongArgCount, "missing argument in conversion to %s", T)
                case 1:
-                       check.expr(x, call.Args[0])
+                       check.expr(nil, x, call.Args[0])
                        if x.mode != invalid {
                                if call.Ellipsis.IsValid() {
                                        check.errorf(call.Args[0], BadDotDotDotSyntax, "invalid use of ... in conversion to %s", T)
@@ -274,7 +353,7 @@ func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) {
                xlist = make([]*operand, len(elist))
                for i, e := range elist {
                        var x operand
-                       check.expr(&x, e)
+                       check.expr(nil, &x, e)
                        xlist[i] = &x
                }
        }
@@ -791,12 +870,12 @@ func (check *Checker) use1(e ast.Expr, lhs bool) bool {
                                }
                        }
                }
-               check.rawExpr(&x, n, nil, true)
+               check.rawExpr(nil, &x, n, nil, true)
                if v != nil {
                        v.used = v_used // restore v.used
                }
        default:
-               check.rawExpr(&x, e, nil, true)
+               check.rawExpr(nil, &x, e, nil, true)
        }
        return x.mode != invalid
 }
index 36809838c7c988cb75ce080e3086731788792345..0f4c320a47bc4d5369441014bfa61a2cb3f57e8f 100644 (file)
@@ -145,6 +145,7 @@ func testFiles(t *testing.T, sizes Sizes, filenames []string, srcs [][]byte, man
        flags := flag.NewFlagSet("", flag.PanicOnError)
        flags.StringVar(&conf.GoVersion, "lang", "", "")
        flags.BoolVar(&conf.FakeImportC, "fakeImportC", false, "")
+       flags.BoolVar(boolFieldAddr(&conf, "_EnableReverseTypeInference"), "reverseTypeInference", false, "")
        if err := parseFlags(filenames[0], srcs[0], flags); err != nil {
                t.Fatal(err)
        }
index 393d8f34e2301eecb7a3c2fb3e5eb629a97ae10f..3065da2e8ebeace92cdc15ef187b63084b2ddb79 100644 (file)
@@ -477,7 +477,7 @@ func (check *Checker) constDecl(obj *Const, typ, init ast.Expr, inherited bool)
                        // (see issues go.dev/issue/42991, go.dev/issue/42992).
                        check.errpos = atPos(obj.pos)
                }
-               check.expr(&x, init)
+               check.expr(nil, &x, init)
        }
        check.initConst(obj, &x)
 }
@@ -510,7 +510,7 @@ func (check *Checker) varDecl(obj *Var, lhs []*Var, typ, init ast.Expr) {
        if lhs == nil || len(lhs) == 1 {
                assert(lhs == nil || lhs[0] == obj)
                var x operand
-               check.expr(&x, init)
+               check.expr(obj.typ, &x, init)
                check.initVar(obj, &x, "variable declaration")
                return
        }
index 1e4d64fe96b2c3019989243d437dd6f8981a169a..1655a8bd2737d51b86c143f047222864e93270c3 100644 (file)
@@ -91,8 +91,8 @@ func CheckExpr(fset *token.FileSet, pkg *Package, pos token.Pos, expr ast.Expr,
 
        // evaluate node
        var x operand
-       check.rawExpr(&x, expr, nil, true) // allow generic expressions
-       check.processDelayed(0)            // incl. all functions
+       check.rawExpr(nil, &x, expr, nil, true) // allow generic expressions
+       check.processDelayed(0)                 // incl. all functions
        check.recordUntyped()
 
        return nil
index 1abf963b7f7c1f9ce5a6d619f9e29268de07d49a..219a392b88af40b2fdf197fe41ccf7cf83cd9db5 100644 (file)
@@ -160,7 +160,7 @@ func underIs(typ Type, f func(Type) bool) bool {
 
 // The unary expression e may be nil. It's passed in for better error messages only.
 func (check *Checker) unary(x *operand, e *ast.UnaryExpr) {
-       check.expr(x, e.X)
+       check.expr(nil, x, e.X)
        if x.mode == invalid {
                return
        }
@@ -1079,8 +1079,8 @@ func init() {
 func (check *Checker) binary(x *operand, e ast.Expr, lhs, rhs ast.Expr, op token.Token, opPos token.Pos) {
        var y operand
 
-       check.expr(x, lhs)
-       check.expr(&y, rhs)
+       check.expr(nil, x, lhs)
+       check.expr(nil, &y, rhs)
 
        if x.mode == invalid {
                return
@@ -1230,12 +1230,18 @@ const (
        statement
 )
 
+// TODO(gri) In rawExpr below, consider using T instead of hint and
+//           some sort of "operation mode" instead of allowGeneric.
+//           May be clearer and less error-prone.
+
 // rawExpr typechecks expression e and initializes x with the expression
 // value or type. If an error occurred, x.mode is set to invalid.
+// If a non-nil target type T is given and e is a generic function
+// or function call, T is used to infer the type arguments for e.
 // If hint != nil, it is the type of a composite literal element.
 // If allowGeneric is set, the operand type may be an uninstantiated
 // parameterized type or function value.
-func (check *Checker) rawExpr(x *operand, e ast.Expr, hint Type, allowGeneric bool) exprKind {
+func (check *Checker) rawExpr(T Type, x *operand, e ast.Expr, hint Type, allowGeneric bool) exprKind {
        if check.conf._Trace {
                check.trace(e.Pos(), "-- expr %s", e)
                check.indent++
@@ -1245,10 +1251,10 @@ func (check *Checker) rawExpr(x *operand, e ast.Expr, hint Type, allowGeneric bo
                }()
        }
 
-       kind := check.exprInternal(x, e, hint)
+       kind := check.exprInternal(T, x, e, hint)
 
        if !allowGeneric {
-               check.nonGeneric(x)
+               check.nonGeneric(T, x)
        }
 
        check.record(x)
@@ -1256,9 +1262,10 @@ func (check *Checker) rawExpr(x *operand, e ast.Expr, hint Type, allowGeneric bo
        return kind
 }
 
-// If x is a generic function or type, nonGeneric reports an error and invalidates x.mode and x.typ.
+// If x is a generic type, or a generic function whose type arguments cannot be inferred
+// from a non-nil target type T, nonGeneric reports an error and invalidates x.mode and x.typ.
 // Otherwise it leaves x alone.
-func (check *Checker) nonGeneric(x *operand) {
+func (check *Checker) nonGeneric(T Type, x *operand) {
        if x.mode == invalid || x.mode == novalue {
                return
        }
@@ -1270,6 +1277,12 @@ func (check *Checker) nonGeneric(x *operand) {
                }
        case *Signature:
                if t.tparams != nil {
+                       if check.conf._EnableReverseTypeInference && T != nil {
+                               if _, ok := under(T).(*Signature); ok {
+                                       check.funcInst(T, x.Pos(), x, nil)
+                                       return
+                               }
+                       }
                        what = "function"
                }
        }
@@ -1282,7 +1295,8 @@ func (check *Checker) nonGeneric(x *operand) {
 
 // exprInternal contains the core of type checking of expressions.
 // Must only be called by rawExpr.
-func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
+// (See rawExpr for an explanation of the parameters.)
+func (check *Checker) exprInternal(T Type, x *operand, e ast.Expr, hint Type) exprKind {
        // make sure x has a valid state in case of bailout
        // (was go.dev/issue/5770)
        x.mode = invalid
@@ -1418,7 +1432,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
                                        key, _ := kv.Key.(*ast.Ident)
                                        // do all possible checks early (before exiting due to errors)
                                        // so we don't drop information on the floor
-                                       check.expr(x, kv.Value)
+                                       check.expr(nil, x, kv.Value)
                                        if key == nil {
                                                check.errorf(kv, InvalidLitField, "invalid field name %s in struct literal", kv.Key)
                                                continue
@@ -1446,7 +1460,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
                                                check.error(kv, MixedStructLit, "mixture of field:value and value elements in struct literal")
                                                continue
                                        }
-                                       check.expr(x, e)
+                                       check.expr(nil, x, e)
                                        if i >= len(fields) {
                                                check.errorf(x, InvalidStructLit, "too many values in struct literal of type %s", base)
                                                break // cannot continue
@@ -1575,7 +1589,8 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
                x.typ = typ
 
        case *ast.ParenExpr:
-               kind := check.rawExpr(x, e.X, nil, false)
+               // type inference doesn't go past parentheses (targe type T = nil)
+               kind := check.rawExpr(nil, x, e.X, nil, false)
                x.expr = e
                return kind
 
@@ -1585,7 +1600,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
        case *ast.IndexExpr, *ast.IndexListExpr:
                ix := typeparams.UnpackIndexExpr(e)
                if check.indexExpr(x, ix) {
-                       check.funcInst(x, ix)
+                       check.funcInst(T, e.Pos(), x, ix)
                }
                if x.mode == invalid {
                        goto Error
@@ -1598,7 +1613,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
                }
 
        case *ast.TypeAssertExpr:
-               check.expr(x, e.X)
+               check.expr(nil, x, e.X)
                if x.mode == invalid {
                        goto Error
                }
@@ -1609,7 +1624,6 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
                        check.error(e, BadTypeKeyword, "use of .(type) outside type switch")
                        goto Error
                }
-               // TODO(gri) we may want to permit type assertions on type parameter values at some point
                if isTypeParam(x.typ) {
                        check.errorf(x, InvalidAssert, invalidOp+"cannot use type assertion on type parameter value %s", x)
                        goto Error
@@ -1761,10 +1775,19 @@ func (check *Checker) typeAssertion(e ast.Expr, x *operand, T Type, typeSwitch b
 }
 
 // expr typechecks expression e and initializes x with the expression value.
+// If a non-nil target type T is given and e is a generic function
+// or function call, T is used to infer the type arguments for e.
 // The result must be a single value.
 // If an error occurred, x.mode is set to invalid.
-func (check *Checker) expr(x *operand, e ast.Expr) {
-       check.rawExpr(x, e, nil, false)
+func (check *Checker) expr(T Type, x *operand, e ast.Expr) {
+       check.rawExpr(T, x, e, nil, false)
+       check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
+       check.singleValue(x)
+}
+
+// genericExpr is like expr but the result may also be generic.
+func (check *Checker) genericExpr(x *operand, e ast.Expr) {
+       check.rawExpr(nil, x, e, nil, true)
        check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
        check.singleValue(x)
 }
@@ -1776,7 +1799,7 @@ func (check *Checker) expr(x *operand, e ast.Expr) {
 // If an error occurred, list[0] is not valid.
 func (check *Checker) multiExpr(e ast.Expr, allowCommaOk bool) (list []*operand, commaOk bool) {
        var x operand
-       check.rawExpr(&x, e, nil, false)
+       check.rawExpr(nil, &x, e, nil, false)
        check.exclude(&x, 1<<novalue|1<<builtin|1<<typexpr)
 
        if t, ok := x.typ.(*Tuple); ok && x.mode != invalid {
@@ -1807,7 +1830,7 @@ func (check *Checker) multiExpr(e ast.Expr, allowCommaOk bool) (list []*operand,
 // If an error occurred, x.mode is set to invalid.
 func (check *Checker) exprWithHint(x *operand, e ast.Expr, hint Type) {
        assert(hint != nil)
-       check.rawExpr(x, e, hint, false)
+       check.rawExpr(nil, x, e, hint, false)
        check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
        check.singleValue(x)
 }
@@ -1817,7 +1840,7 @@ func (check *Checker) exprWithHint(x *operand, e ast.Expr, hint Type) {
 // value.
 // If an error occurred, x.mode is set to invalid.
 func (check *Checker) exprOrType(x *operand, e ast.Expr, allowGeneric bool) {
-       check.rawExpr(x, e, nil, allowGeneric)
+       check.rawExpr(nil, x, e, nil, allowGeneric)
        check.exclude(x, 1<<novalue)
        check.singleValue(x)
 }
index 2fcc3f3492fc89561716b88a2cd9621c43478253..1bcfb38feb257ac24402d2e3a00e174232fb8013 100644 (file)
@@ -43,7 +43,7 @@ func (check *Checker) indexExpr(x *operand, e *typeparams.IndexExpr) (isFuncInst
        }
 
        // x should not be generic at this point, but be safe and check
-       check.nonGeneric(x)
+       check.nonGeneric(nil, x)
        if x.mode == invalid {
                return false
        }
@@ -93,7 +93,7 @@ func (check *Checker) indexExpr(x *operand, e *typeparams.IndexExpr) (isFuncInst
                        return false
                }
                var key operand
-               check.expr(&key, index)
+               check.expr(nil, &key, index)
                check.assignment(&key, typ.key, "map index")
                // ok to continue even if indexing failed - map element type is known
                x.mode = mapindex
@@ -167,7 +167,7 @@ func (check *Checker) indexExpr(x *operand, e *typeparams.IndexExpr) (isFuncInst
                                        return false
                                }
                                var k operand
-                               check.expr(&k, index)
+                               check.expr(nil, &k, index)
                                check.assignment(&k, key, "map index")
                                // ok to continue even if indexing failed - map element type is known
                                x.mode = mapindex
@@ -208,7 +208,7 @@ func (check *Checker) indexExpr(x *operand, e *typeparams.IndexExpr) (isFuncInst
 }
 
 func (check *Checker) sliceExpr(x *operand, e *ast.SliceExpr) {
-       check.expr(x, e.X)
+       check.expr(nil, x, e.X)
        if x.mode == invalid {
                check.use(e.Low, e.High, e.Max)
                return
@@ -350,7 +350,7 @@ func (check *Checker) index(index ast.Expr, max int64) (typ Type, val int64) {
        val = -1
 
        var x operand
-       check.expr(&x, index)
+       check.expr(nil, &x, index)
        if !check.isValidIndex(&x, InvalidIndex, "index", false) {
                return
        }
index 3571ca02451d71baddef7f530c440ef20fde9b88..7869f37077f3a13196fcf28810f62fd18d12184c 100644 (file)
@@ -173,7 +173,7 @@ func (check *Checker) suspendedCall(keyword string, call *ast.CallExpr) {
        var x operand
        var msg string
        var code Code
-       switch check.rawExpr(&x, call, nil, false) {
+       switch check.rawExpr(nil, &x, call, nil, false) {
        case conversion:
                msg = "requires function call, not conversion"
                code = InvalidDefer
@@ -237,7 +237,7 @@ func (check *Checker) caseValues(x *operand, values []ast.Expr, seen valueMap) {
 L:
        for _, e := range values {
                var v operand
-               check.expr(&v, e)
+               check.expr(nil, &v, e)
                if x.mode == invalid || v.mode == invalid {
                        continue L
                }
@@ -288,7 +288,7 @@ L:
                // The spec allows the value nil instead of a type.
                if check.isNil(e) {
                        T = nil
-                       check.expr(&dummy, e) // run e through expr so we get the usual Info recordings
+                       check.expr(nil, &dummy, e) // run e through expr so we get the usual Info recordings
                } else {
                        T = check.varType(e)
                        if T == Typ[Invalid] {
@@ -327,7 +327,7 @@ L:
 //             // The spec allows the value nil instead of a type.
 //             var hash string
 //             if check.isNil(e) {
-//                     check.expr(&dummy, e) // run e through expr so we get the usual Info recordings
+//                     check.expr(nil, &dummy, e) // run e through expr so we get the usual Info recordings
 //                     T = nil
 //                     hash = "<nil>" // avoid collision with a type named nil
 //             } else {
@@ -394,7 +394,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
                // function and method calls and receive operations can appear
                // in statement context. Such statements may be parenthesized."
                var x operand
-               kind := check.rawExpr(&x, s.X, nil, false)
+               kind := check.rawExpr(nil, &x, s.X, nil, false)
                var msg string
                var code Code
                switch x.mode {
@@ -415,8 +415,8 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
 
        case *ast.SendStmt:
                var ch, val operand
-               check.expr(&ch, s.Chan)
-               check.expr(&val, s.Value)
+               check.expr(nil, &ch, s.Chan)
+               check.expr(nil, &val, s.Value)
                if ch.mode == invalid || val.mode == invalid {
                        return
                }
@@ -449,7 +449,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
                }
 
                var x operand
-               check.expr(&x, s.X)
+               check.expr(nil, &x, s.X)
                if x.mode == invalid {
                        return
                }
@@ -463,7 +463,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
                if x.mode == invalid {
                        return
                }
-               check.assignVar(s.X, &x)
+               check.assignVar(s.X, nil, &x)
 
        case *ast.AssignStmt:
                switch s.Tok {
@@ -495,7 +495,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
                        if x.mode == invalid {
                                return
                        }
-                       check.assignVar(s.Lhs[0], &x)
+                       check.assignVar(s.Lhs[0], nil, &x)
                }
 
        case *ast.GoStmt:
@@ -570,7 +570,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
 
                check.simpleStmt(s.Init)
                var x operand
-               check.expr(&x, s.Cond)
+               check.expr(nil, &x, s.Cond)
                if x.mode != invalid && !allBoolean(x.typ) {
                        check.error(s.Cond, InvalidCond, "non-boolean condition in if statement")
                }
@@ -594,7 +594,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
                check.simpleStmt(s.Init)
                var x operand
                if s.Tag != nil {
-                       check.expr(&x, s.Tag)
+                       check.expr(nil, &x, s.Tag)
                        // By checking assignment of x to an invisible temporary
                        // (as a compiler would), we get all the relevant checks.
                        check.assignment(&x, nil, "switch expression")
@@ -686,7 +686,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
                        return
                }
                var x operand
-               check.expr(&x, expr.X)
+               check.expr(nil, &x, expr.X)
                if x.mode == invalid {
                        return
                }
@@ -808,7 +808,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
                check.simpleStmt(s.Init)
                if s.Cond != nil {
                        var x operand
-                       check.expr(&x, s.Cond)
+                       check.expr(nil, &x, s.Cond)
                        if x.mode != invalid && !allBoolean(x.typ) {
                                check.error(s.Cond, InvalidCond, "non-boolean condition in for statement")
                        }
@@ -830,7 +830,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
 
                // check expression to iterate over
                var x operand
-               check.expr(&x, s.X)
+               check.expr(nil, &x, s.X)
 
                // determine key/value types
                var key, val Type
@@ -928,7 +928,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
                                        x.mode = value
                                        x.expr = lhs // we don't have a better rhs expression to use here
                                        x.typ = typ
-                                       check.assignVar(lhs, &x)
+                                       check.assignVar(lhs, nil, &x)
                                }
                        }
                }
index b619eccf0f405c138ade18ef1ffa9b2f5e9b1827..35f27ddcc5ae86703dc196c9baeaec8222622f06 100644 (file)
@@ -480,7 +480,7 @@ func (check *Checker) arrayLength(e ast.Expr) int64 {
        }
 
        var x operand
-       check.expr(&x, e)
+       check.expr(nil, &x, e)
        if x.mode != constant_ {
                if x.mode != invalid {
                        check.errorf(&x, InvalidArrayLen, "array length %s must be constant", &x)
index c9e3605c9e44e036e423a0f8005deb38a345e4a6..2dc122c41368a7f5e2a04acf87569f59cff3b5ea 100644 (file)
@@ -148,3 +148,10 @@ func _() {
        wantsMethods /* ERROR "any does not satisfy interface{m1(Q); m2() R} (missing method m1)" */ (any(nil))
        wantsMethods /* ERROR "hasMethods4 does not satisfy interface{m1(Q); m2() R} (wrong type for method m1)" */ (hasMethods4(nil))
 }
+
+// "Reverse" type inference is not yet permitted.
+
+func f[P any](P) {}
+
+// This must not crash.
+var _ func(int) = f // ERROR "cannot use generic function f without instantiation"
diff --git a/src/internal/types/testdata/examples/inference2.go b/src/internal/types/testdata/examples/inference2.go
new file mode 100644 (file)
index 0000000..d309a00
--- /dev/null
@@ -0,0 +1,73 @@
+// -reverseTypeInference
+
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file shows some examples of "reverse" type inference
+// where the type arguments for generic functions are determined
+// from assigning the functions.
+
+package p
+
+func f1[P any](P)      {}
+func f2[P any]() P     { var x P; return x }
+func f3[P, Q any](P) Q { var x Q; return x }
+func f4[P any](P, P)   {}
+func f5[P any](P) []P  { return nil }
+
+// initialization expressions
+var (
+       v1           = f1 // ERROR "cannot use generic function f1 without instantiation"
+       v2 func(int) = f2 // ERROR "cannot infer P"
+
+       v3 func(int)     = f1
+       v4 func() int    = f2
+       v5 func(int) int = f3
+       _  func(int) int = f3[int]
+
+       v6 func(int, int)     = f4
+       v7 func(int, string)  = f4 // ERROR "type string of 2nd parameter does not match inferred type int for P"
+       v8 func(int) []int    = f5
+       v9 func(string) []int = f5 // ERROR "type []int of 1st result parameter does not match inferred type []string for []P"
+
+       _, _ func(int) = f1, f1
+       _, _ func(int) = f1, f2 // ERROR "cannot infer P"
+)
+
+// Regular assignments
+func _() {
+       v1 = f1 // no error here because v1 is invalid (we don't know its type) due to the error above
+       var v1_ func() int
+       _ = v1_
+       v1_ = f1 // ERROR "cannot infer P"
+       v2 = f2  // ERROR "cannot infer P"
+
+       v3 = f1
+       v4 = f2
+       v5 = f3
+       v5 = f3[int]
+
+       v6 = f4
+       v7 = f4 // ERROR "type string of 2nd parameter does not match inferred type int for P"
+       v8 = f5
+       v9 = f5 // ERROR "type []int of 1st result parameter does not match inferred type []string for []P"
+}
+
+// Return statements
+func _() func(int)     { return f1 }
+func _() func() int    { return f2 }
+func _() func(int) int { return f3 }
+func _() func(int) int { return f3[int] }
+
+func _() func(int, int) { return f4 }
+func _() func(int, string) {
+       return f4 /* ERROR "type string of 2nd parameter does not match inferred type int for P" */
+}
+func _() func(int) []int { return f5 }
+func _() func(string) []int {
+       return f5 /* ERROR "type []int of 1st result parameter does not match inferred type []string for []P" */
+}
+
+func _() (_, _ func(int)) { return f1, f1 }
+func _() (_, _ func(int)) { return f1, f2 /* ERROR "cannot infer P" */ }