From 21b4e0146a40a95687c4d3e36939eef991f03c1e Mon Sep 17 00:00:00 2001 From: Robert Griesemer Date: Thu, 26 Jan 2023 13:20:34 -0800 Subject: [PATCH] go/types, types2: simplify unifier The unifier was written such that it was possible to specify a different set of type parameters (declared by different generic declarations) for each type x, y being unified, to allow for what is called "bidirectional unification" in the documentation (comments). However, in the current implementation, this mechanism is not used: - For function type inference, we only consider the type parameter list of the generic function (type parameters that appear in the arguments are considered stand-alone types). We use type parameter renaming to avoid any problems in case of recursive generic calls that rely on type inference. - For constraint type inference, the type parameters for the types x and y (i.e., the type parameter and its constraint) are the same and had to be explicitly set to be identical. This CL removes the ability to set separate type parameter lists. Instead a single type parameter list is used during unification and is provided when we initialize a unifier. As a consequence, we don't need to maintain the separate tparamsList data structure: since we have a single list of type parameters we can keep it directly in the unifier. Adjust all the unifier code accordingly and update comments. As an aside, remove the `exact` flag from the unifier as it was never set. However, keep the functionality for now and use a constant (exactUnification) instead. This makes it easy to find the respectice code without incurring any cost. Change-Id: I969ba6dbbed2d65d06ba4e20b97bdc362c806772 Reviewed-on: https://go-review.googlesource.com/c/go/+/463223 Reviewed-by: Robert Griesemer Run-TryBot: Robert Griesemer Reviewed-by: Robert Findley Auto-Submit: Robert Griesemer TryBot-Result: Gopher Robot --- src/cmd/compile/internal/types2/infer.go | 32 ++-- src/cmd/compile/internal/types2/unify.go | 207 ++++++++++------------- src/go/types/infer.go | 32 ++-- src/go/types/unify.go | 207 ++++++++++------------- 4 files changed, 206 insertions(+), 272 deletions(-) diff --git a/src/cmd/compile/internal/types2/infer.go b/src/cmd/compile/internal/types2/infer.go index 84279afedf..24a71367c5 100644 --- a/src/cmd/compile/internal/types2/infer.go +++ b/src/cmd/compile/internal/types2/infer.go @@ -135,19 +135,18 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type, // Unify parameter and argument types for generic parameters with typed arguments // and collect the indices of generic parameters with untyped arguments. // Terminology: generic parameter = function parameter with a type-parameterized type - u := newUnifier(false) - u.x.init(tparams) + u := newUnifier(tparams) // Set the type arguments which we know already. for i, targ := range targs { if targ != nil { - u.x.set(i, targ) + u.set(i, targ) } } errorf := func(kind string, tpar, targ Type, arg *operand) { // provide a better error message if we can - targs, index := u.x.types() + targs, index := u.inferred() if index == 0 { // The first type parameter couldn't be inferred. // If none of them could be inferred, don't try @@ -213,7 +212,7 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type, // If we've got all type arguments, we're done. var index int - targs, index = u.x.types() + targs, index = u.inferred() if index < 0 { return targs } @@ -249,7 +248,7 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type, } // If we've got all type arguments, we're done. - targs, index = u.x.types() + targs, index = u.inferred() if index < 0 { return targs } @@ -462,16 +461,13 @@ func (check *Checker) inferB(tparams []*TypeParam, targs []Type) (types []Type, }() } - // Setup bidirectional unification between constraints - // and the corresponding type arguments (which may be nil!). - u := newUnifier(false) - u.x.init(tparams) - u.y = u.x // type parameters between LHS and RHS of unification are identical + // Unify type parameters with their constraints. + u := newUnifier(tparams) // Set the type arguments which we know already. for i, targ := range targs { if targ != nil { - u.x.set(i, targ) + u.set(i, targ) } } @@ -490,7 +486,7 @@ func (check *Checker) inferB(tparams []*TypeParam, targs []Type) (types []Type, // here could handle the respective type parameters only, // but that will come at a cost of extra complexity which // may not be worth it.) - for n := u.x.unknowns(); n > 0; { + for n := u.unknowns(); n > 0; { nn := n for i, tpar := range tparams { @@ -501,7 +497,7 @@ func (check *Checker) inferB(tparams []*TypeParam, targs []Type) (types []Type, u.tracef("core(%s) = %s (single = %v)", tpar, core, single) } // A type parameter can be unified with its core type in two cases. - tx := u.x.at(i) + tx := u.at(i) switch { case tx != nil: // The corresponding type argument tx is known. @@ -534,7 +530,7 @@ func (check *Checker) inferB(tparams []*TypeParam, targs []Type) (types []Type, // The corresponding type argument tx is unknown and there's a single // specific type and no tilde. // In this case the type argument must be that single type; set it. - u.x.set(i, core.typ) + u.set(i, core.typ) default: // Unification is not possible and no progress was made. @@ -542,7 +538,7 @@ func (check *Checker) inferB(tparams []*TypeParam, targs []Type) (types []Type, } // The number of known type arguments may have changed. - nn = u.x.unknowns() + nn = u.unknowns() if nn == 0 { break // all type arguments are known } @@ -560,14 +556,14 @@ func (check *Checker) inferB(tparams []*TypeParam, targs []Type) (types []Type, n = nn } - // u.x.types() now contains the incoming type arguments plus any additional type + // u.inferred() now contains the incoming type arguments plus any additional type // arguments which were inferred from core terms. The newly inferred non-nil // entries may still contain references to other type parameters. // For instance, for [A any, B interface{ []C }, C interface{ *A }], if A == int // was given, unification produced the type list [int, []C, *A]. We eliminate the // remaining type parameters by substituting the type parameters in this type list // until nothing changes anymore. - types, _ = u.x.types() + types, _ = u.inferred() if debug { for i, targ := range targs { assert(targ == nil || types[i] == targ) diff --git a/src/cmd/compile/internal/types2/unify.go b/src/cmd/compile/internal/types2/unify.go index 381093c574..bca7231bbb 100644 --- a/src/cmd/compile/internal/types2/unify.go +++ b/src/cmd/compile/internal/types2/unify.go @@ -12,20 +12,6 @@ import ( "strings" ) -// The unifier maintains two separate sets of type parameters x and y -// which are used to resolve type parameters in the x and y arguments -// provided to the unify call. For unidirectional unification, only -// one of these sets (say x) is provided, and then type parameters are -// only resolved for the x argument passed to unify, not the y argument -// (even if that also contains possibly the same type parameters). -// -// For bidirectional unification, both sets are provided. This enables -// unification to go from argument to parameter type and vice versa. -// For constraint type inference, we use bidirectional unification -// where both the x and y type parameters are identical. This is done -// by setting up one of them (using init) and then assigning its value -// to the other. - const ( // Upper limit for recursion depth. Used to catch infinite recursions // due to implementation issues (e.g., see issues #48619, #48656). @@ -48,32 +34,47 @@ const ( // x ≢ y types x and y cannot be unified // [p, q, ...] ➞ [x, y, ...] mapping from type parameters to types traceInference = false + + // If exactUnification is set, unification requires (named) types + // to match exactly. If it is not set, the underlying types are + // considered when unification is known to fail otherwise. + exactUnification = false ) -// A unifier maintains the current type parameters for x and y -// and the respective types inferred for each type parameter. +// A unifier maintains a list of type parameters and +// corresponding types inferred for each type parameter. // A unifier is created by calling newUnifier. type unifier struct { - exact bool - x, y tparamsList // x and y must initialized via tparamsList.init - types []Type // inferred types, shared by x and y - depth int // recursion depth during unification + tparams []*TypeParam + // For each tparams element, there is a corresponding type slot index in indices. + // index < 0: unifier.types[-index-1] == nil + // index == 0: no type slot allocated yet + // index > 0: unifier.types[index-1] == typ + // Joined tparams elements share the same type slot and thus have the same index. + // By using a negative index for nil types we don't need to check unifier.types + // to see if we have a type or not. + indices []int // len(indices) == len(tparams) + types []Type // inferred types, shared by x and y + depth int // recursion depth during unification } -// newUnifier returns a new unifier. -// If exact is set, unification requires unified types to match -// exactly. If exact is not set, a named type's underlying type -// is considered if unification would fail otherwise, and the -// direction of channels is ignored. -// TODO(gri) exact is not set anymore by a caller. Consider removing it. -func newUnifier(exact bool) *unifier { - u := &unifier{exact: exact} - u.x.unifier = u - u.y.unifier = u - return u +// newUnifier returns a new unifier initialized with the given type parameters. +// The type parameters must be in the order in which they appear in their declaration +// (this ensures that the tparams indices match the respective type parameter index). +func newUnifier(tparams []*TypeParam) *unifier { + if debug { + for i, tpar := range tparams { + assert(i == tpar.index) + } + } + return &unifier{ + tparams: tparams, + indices: make([]int, len(tparams)), + } } // unify attempts to unify x and y and reports whether it succeeded. +// As a side-effect, types may be inferred for type parameters. func (u *unifier) unify(x, y Type) bool { return u.nify(x, y, nil) } @@ -82,75 +83,46 @@ func (u *unifier) tracef(format string, args ...interface{}) { fmt.Println(strings.Repeat(". ", u.depth) + sprintf(nil, true, format, args...)) } -// A tparamsList describes a list of type parameters and the types inferred for them. -type tparamsList struct { - unifier *unifier - tparams []*TypeParam - // For each tparams element, there is a corresponding type slot index in indices. - // index < 0: unifier.types[-index-1] == nil - // index == 0: no type slot allocated yet - // index > 0: unifier.types[index-1] == typ - // Joined tparams elements share the same type slot and thus have the same index. - // By using a negative index for nil types we don't need to check unifier.types - // to see if we have a type or not. - indices []int // len(d.indices) == len(d.tparams) -} - -// String returns a string representation for a tparamsList. For debugging. -func (d *tparamsList) String() string { +// String returns a string representation of the mapping from +// type parameters to types. +func (u *unifier) String() string { var buf bytes.Buffer w := newTypeWriter(&buf, nil) w.byte('[') - for i, tpar := range d.tparams { + for i, tpar := range u.tparams { if i > 0 { w.string(", ") } w.typ(tpar) w.string(": ") - w.typ(d.at(i)) + w.typ(u.at(i)) } w.byte(']') return buf.String() } -// init initializes d with the given type parameters. -// The type parameters must be in the order in which they appear in their declaration -// (this ensures that the tparams indices match the respective type parameter index). -func (d *tparamsList) init(tparams []*TypeParam) { - if len(tparams) == 0 { - return - } - if debug { - for i, tpar := range tparams { - assert(i == tpar.index) - } - } - d.tparams = tparams - d.indices = make([]int, len(tparams)) -} - -// join unifies the i'th type parameter of x with the j'th type parameter of y. -// If both type parameters already have a type associated with them and they are -// not joined, join fails and returns false. +// join unifies the i'th type parameter with the j'th type parameter. +// If both type parameters already have a type associated with them +// and they are not joined, join fails and returns false. func (u *unifier) join(i, j int) bool { if traceInference { - u.tracef("%s ⇄ %s", u.x.tparams[i], u.y.tparams[j]) + u.tracef("%s ⇄ %s", u.tparams[i], u.tparams[j]) } - ti := u.x.indices[i] - tj := u.y.indices[j] + ti := u.indices[i] + tj := u.indices[j] switch { case ti == 0 && tj == 0: // Neither type parameter has a type slot associated with them. // Allocate a new joined nil type slot (negative index). u.types = append(u.types, nil) - u.x.indices[i] = -len(u.types) - u.y.indices[j] = -len(u.types) + u.indices[i] = -len(u.types) + u.indices[j] = -len(u.types) case ti == 0: - // The type parameter for x has no type slot yet. Use slot of y. - u.x.indices[i] = tj + // The type parameter (with index) i has no type slot yet. Use slot of j. + u.indices[i] = tj case tj == 0: - // The type parameter for y has no type slot yet. Use slot of x. - u.y.indices[j] = ti + // The type parameter (with index) j has no type slot yet. Use slot of i. + u.indices[j] = ti // Both type parameters have a slot: ti != 0 && tj != 0. case ti == tj: @@ -161,25 +133,25 @@ func (u *unifier) join(i, j int) bool { // TODO(gri) Should we check if types are identical? Investigate. return false case ti > 0: - // Only the type parameter for x has an inferred type. Use x slot for y. - u.y.setIndex(j, ti) + // Only the type parameter (with index) i has an inferred type. Use i slot for j. + u.setIndex(j, ti) // This case is handled like the default case. // case tj > 0: // // Only the type parameter for y has an inferred type. Use y slot for x. - // u.x.setIndex(i, tj) + // u.setIndex(i, tj) default: - // Neither type parameter has an inferred type. Use y slot for x - // (or x slot for y, it doesn't matter). - u.x.setIndex(i, tj) + // Neither type parameter has an inferred type. Use j slot for i + // (or i slot for j, it doesn't matter). + u.setIndex(i, tj) } return true } -// If typ is a type parameter of d, index returns the type parameter index. +// If typ is a type parameter recorded with u, index returns the type parameter index. // Otherwise, the result is < 0. -func (d *tparamsList) index(typ Type) int { +func (u *unifier) index(typ Type) int { if tpar, ok := typ.(*TypeParam); ok { - return tparamIndex(d.tparams, tpar) + return tparamIndex(u.tparams, tpar) } return -1 } @@ -202,48 +174,47 @@ func tparamIndex(list []*TypeParam, tpar *TypeParam) int { // setIndex sets the type slot index for the i'th type parameter // (and all its joined parameters) to tj. The type parameter // must have a (possibly nil) type slot associated with it. -func (d *tparamsList) setIndex(i, tj int) { - ti := d.indices[i] +func (u *unifier) setIndex(i, tj int) { + ti := u.indices[i] assert(ti != 0 && tj != 0) - for k, tk := range d.indices { + for k, tk := range u.indices { if tk == ti { - d.indices[k] = tj + u.indices[k] = tj } } } // at returns the type set for the i'th type parameter; or nil. -func (d *tparamsList) at(i int) Type { - if ti := d.indices[i]; ti > 0 { - return d.unifier.types[ti-1] +func (u *unifier) at(i int) Type { + if ti := u.indices[i]; ti > 0 { + return u.types[ti-1] } return nil } // set sets the type typ for the i'th type parameter; // typ must not be nil and it must not have been set before. -func (d *tparamsList) set(i int, typ Type) { +func (u *unifier) set(i int, typ Type) { assert(typ != nil) - u := d.unifier if traceInference { - u.tracef("%s ➞ %s", d.tparams[i], typ) + u.tracef("%s ➞ %s", u.tparams[i], typ) } - switch ti := d.indices[i]; { + switch ti := u.indices[i]; { case ti < 0: u.types[-ti-1] = typ - d.setIndex(i, -ti) + u.setIndex(i, -ti) case ti == 0: u.types = append(u.types, typ) - d.indices[i] = len(u.types) + u.indices[i] = len(u.types) default: panic("type already set") } } // unknowns returns the number of type parameters for which no type has been set yet. -func (d *tparamsList) unknowns() int { +func (u *unifier) unknowns() int { n := 0 - for _, ti := range d.indices { + for _, ti := range u.indices { if ti <= 0 { n++ } @@ -251,15 +222,15 @@ func (d *tparamsList) unknowns() int { return n } -// types returns the list of inferred types (via unification) for the type parameters -// described by d, and an index. If all types were inferred, the returned index is < 0. +// inferred returns the list of inferred types (via unification) for the type parameters +// recorded with u, and an index. If all types were inferred, the returned index is < 0. // Otherwise, it is the index of the first type parameter which couldn't be inferred; // i.e., for which list[index] is nil. -func (d *tparamsList) types() (list []Type, index int) { - list = make([]Type, len(d.tparams)) +func (u *unifier) inferred() (list []Type, index int) { + list = make([]Type, len(u.tparams)) index = -1 - for i := range d.tparams { - t := d.at(i) + for i := range u.tparams { + t := u.at(i) list[i] = t if index < 0 && t == nil { index = i @@ -299,7 +270,7 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { } }() - if !u.exact { + if !exactUnification { // If exact unification is known to fail because we attempt to // match a type name against an unnamed type literal, consider // the underlying type of the named type. @@ -319,44 +290,44 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { } // Cases where at least one of x or y is a type parameter. - switch i, j := u.x.index(x), u.y.index(y); { + switch i, j := u.index(x), u.index(y); { case i >= 0 && j >= 0: // both x and y are type parameters if u.join(i, j) { return true } // both x and y have an inferred type - they must match - return u.nifyEq(u.x.at(i), u.y.at(j), p) + return u.nifyEq(u.at(i), u.at(j), p) case i >= 0: // x is a type parameter, y is not - if tx := u.x.at(i); tx != nil { + if tx := u.at(i); tx != nil { return u.nifyEq(tx, y, p) } // otherwise, infer type from y - u.x.set(i, y) + u.set(i, y) return true case j >= 0: // y is a type parameter, x is not - if ty := u.y.at(j); ty != nil { + if ty := u.at(j); ty != nil { return u.nifyEq(x, ty, p) } // otherwise, infer type from x - u.y.set(j, x) + u.set(j, x) return true } // If we get here and x or y is a type parameter, they are type parameters // from outside our declaration list. Try to unify their core types, if any // (see go.dev/issue/50755 for a test case). - if enableCoreTypeUnification && !u.exact { + if enableCoreTypeUnification && !exactUnification { if isTypeParam(x) && !hasName(y) { // When considering the type parameter for unification // we look at the adjusted core term (adjusted core type // with tilde information). // If the adjusted core type is a named type N; the - // corresponding core type is under(N). Since !u.exact + // corresponding core type is under(N). Since !exactUnification // and y doesn't have a name, unification will end up // comparing under(N) to y, so we can just use the core // type instead. And we can ignore the tilde because we @@ -532,7 +503,7 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { case *Chan: // Two channel types are identical if they have identical value types. if y, ok := y.(*Chan); ok { - return (!u.exact || x.dir == y.dir) && u.nify(x.elem, y.elem, p) + return (!exactUnification || x.dir == y.dir) && u.nify(x.elem, y.elem, p) } case *Named: @@ -568,7 +539,7 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { // avoid a crash in case of nil type default: - panic(sprintf(nil, true, "u.nify(%s, %s), u.x.tparams = %s", x, y, u.x.tparams)) + panic(sprintf(nil, true, "u.nify(%s, %s), u.tparams = %s", x, y, u.tparams)) } return false diff --git a/src/go/types/infer.go b/src/go/types/infer.go index 5f91b526d9..7b921c3b94 100644 --- a/src/go/types/infer.go +++ b/src/go/types/infer.go @@ -137,19 +137,18 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type, // Unify parameter and argument types for generic parameters with typed arguments // and collect the indices of generic parameters with untyped arguments. // Terminology: generic parameter = function parameter with a type-parameterized type - u := newUnifier(false) - u.x.init(tparams) + u := newUnifier(tparams) // Set the type arguments which we know already. for i, targ := range targs { if targ != nil { - u.x.set(i, targ) + u.set(i, targ) } } errorf := func(kind string, tpar, targ Type, arg *operand) { // provide a better error message if we can - targs, index := u.x.types() + targs, index := u.inferred() if index == 0 { // The first type parameter couldn't be inferred. // If none of them could be inferred, don't try @@ -215,7 +214,7 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type, // If we've got all type arguments, we're done. var index int - targs, index = u.x.types() + targs, index = u.inferred() if index < 0 { return targs } @@ -251,7 +250,7 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type, } // If we've got all type arguments, we're done. - targs, index = u.x.types() + targs, index = u.inferred() if index < 0 { return targs } @@ -464,16 +463,13 @@ func (check *Checker) inferB(tparams []*TypeParam, targs []Type) (types []Type, }() } - // Setup bidirectional unification between constraints - // and the corresponding type arguments (which may be nil!). - u := newUnifier(false) - u.x.init(tparams) - u.y = u.x // type parameters between LHS and RHS of unification are identical + // Unify type parameters with their constraints. + u := newUnifier(tparams) // Set the type arguments which we know already. for i, targ := range targs { if targ != nil { - u.x.set(i, targ) + u.set(i, targ) } } @@ -492,7 +488,7 @@ func (check *Checker) inferB(tparams []*TypeParam, targs []Type) (types []Type, // here could handle the respective type parameters only, // but that will come at a cost of extra complexity which // may not be worth it.) - for n := u.x.unknowns(); n > 0; { + for n := u.unknowns(); n > 0; { nn := n for i, tpar := range tparams { @@ -503,7 +499,7 @@ func (check *Checker) inferB(tparams []*TypeParam, targs []Type) (types []Type, u.tracef("core(%s) = %s (single = %v)", tpar, core, single) } // A type parameter can be unified with its core type in two cases. - tx := u.x.at(i) + tx := u.at(i) switch { case tx != nil: // The corresponding type argument tx is known. @@ -536,7 +532,7 @@ func (check *Checker) inferB(tparams []*TypeParam, targs []Type) (types []Type, // The corresponding type argument tx is unknown and there's a single // specific type and no tilde. // In this case the type argument must be that single type; set it. - u.x.set(i, core.typ) + u.set(i, core.typ) default: // Unification is not possible and no progress was made. @@ -544,7 +540,7 @@ func (check *Checker) inferB(tparams []*TypeParam, targs []Type) (types []Type, } // The number of known type arguments may have changed. - nn = u.x.unknowns() + nn = u.unknowns() if nn == 0 { break // all type arguments are known } @@ -562,14 +558,14 @@ func (check *Checker) inferB(tparams []*TypeParam, targs []Type) (types []Type, n = nn } - // u.x.types() now contains the incoming type arguments plus any additional type + // u.inferred() now contains the incoming type arguments plus any additional type // arguments which were inferred from core terms. The newly inferred non-nil // entries may still contain references to other type parameters. // For instance, for [A any, B interface{ []C }, C interface{ *A }], if A == int // was given, unification produced the type list [int, []C, *A]. We eliminate the // remaining type parameters by substituting the type parameters in this type list // until nothing changes anymore. - types, _ = u.x.types() + types, _ = u.inferred() if debug { for i, targ := range targs { assert(targ == nil || types[i] == targ) diff --git a/src/go/types/unify.go b/src/go/types/unify.go index 206ec69d59..a83757f2a4 100644 --- a/src/go/types/unify.go +++ b/src/go/types/unify.go @@ -14,20 +14,6 @@ import ( "strings" ) -// The unifier maintains two separate sets of type parameters x and y -// which are used to resolve type parameters in the x and y arguments -// provided to the unify call. For unidirectional unification, only -// one of these sets (say x) is provided, and then type parameters are -// only resolved for the x argument passed to unify, not the y argument -// (even if that also contains possibly the same type parameters). -// -// For bidirectional unification, both sets are provided. This enables -// unification to go from argument to parameter type and vice versa. -// For constraint type inference, we use bidirectional unification -// where both the x and y type parameters are identical. This is done -// by setting up one of them (using init) and then assigning its value -// to the other. - const ( // Upper limit for recursion depth. Used to catch infinite recursions // due to implementation issues (e.g., see issues #48619, #48656). @@ -50,32 +36,47 @@ const ( // x ≢ y types x and y cannot be unified // [p, q, ...] ➞ [x, y, ...] mapping from type parameters to types traceInference = false + + // If exactUnification is set, unification requires (named) types + // to match exactly. If it is not set, the underlying types are + // considered when unification is known to fail otherwise. + exactUnification = false ) -// A unifier maintains the current type parameters for x and y -// and the respective types inferred for each type parameter. +// A unifier maintains a list of type parameters and +// corresponding types inferred for each type parameter. // A unifier is created by calling newUnifier. type unifier struct { - exact bool - x, y tparamsList // x and y must initialized via tparamsList.init - types []Type // inferred types, shared by x and y - depth int // recursion depth during unification + tparams []*TypeParam + // For each tparams element, there is a corresponding type slot index in indices. + // index < 0: unifier.types[-index-1] == nil + // index == 0: no type slot allocated yet + // index > 0: unifier.types[index-1] == typ + // Joined tparams elements share the same type slot and thus have the same index. + // By using a negative index for nil types we don't need to check unifier.types + // to see if we have a type or not. + indices []int // len(indices) == len(tparams) + types []Type // inferred types, shared by x and y + depth int // recursion depth during unification } -// newUnifier returns a new unifier. -// If exact is set, unification requires unified types to match -// exactly. If exact is not set, a named type's underlying type -// is considered if unification would fail otherwise, and the -// direction of channels is ignored. -// TODO(gri) exact is not set anymore by a caller. Consider removing it. -func newUnifier(exact bool) *unifier { - u := &unifier{exact: exact} - u.x.unifier = u - u.y.unifier = u - return u +// newUnifier returns a new unifier initialized with the given type parameters. +// The type parameters must be in the order in which they appear in their declaration +// (this ensures that the tparams indices match the respective type parameter index). +func newUnifier(tparams []*TypeParam) *unifier { + if debug { + for i, tpar := range tparams { + assert(i == tpar.index) + } + } + return &unifier{ + tparams: tparams, + indices: make([]int, len(tparams)), + } } // unify attempts to unify x and y and reports whether it succeeded. +// As a side-effect, types may be inferred for type parameters. func (u *unifier) unify(x, y Type) bool { return u.nify(x, y, nil) } @@ -84,75 +85,46 @@ func (u *unifier) tracef(format string, args ...interface{}) { fmt.Println(strings.Repeat(". ", u.depth) + sprintf(nil, nil, true, format, args...)) } -// A tparamsList describes a list of type parameters and the types inferred for them. -type tparamsList struct { - unifier *unifier - tparams []*TypeParam - // For each tparams element, there is a corresponding type slot index in indices. - // index < 0: unifier.types[-index-1] == nil - // index == 0: no type slot allocated yet - // index > 0: unifier.types[index-1] == typ - // Joined tparams elements share the same type slot and thus have the same index. - // By using a negative index for nil types we don't need to check unifier.types - // to see if we have a type or not. - indices []int // len(d.indices) == len(d.tparams) -} - -// String returns a string representation for a tparamsList. For debugging. -func (d *tparamsList) String() string { +// String returns a string representation of the mapping from +// type parameters to types. +func (u *unifier) String() string { var buf bytes.Buffer w := newTypeWriter(&buf, nil) w.byte('[') - for i, tpar := range d.tparams { + for i, tpar := range u.tparams { if i > 0 { w.string(", ") } w.typ(tpar) w.string(": ") - w.typ(d.at(i)) + w.typ(u.at(i)) } w.byte(']') return buf.String() } -// init initializes d with the given type parameters. -// The type parameters must be in the order in which they appear in their declaration -// (this ensures that the tparams indices match the respective type parameter index). -func (d *tparamsList) init(tparams []*TypeParam) { - if len(tparams) == 0 { - return - } - if debug { - for i, tpar := range tparams { - assert(i == tpar.index) - } - } - d.tparams = tparams - d.indices = make([]int, len(tparams)) -} - -// join unifies the i'th type parameter of x with the j'th type parameter of y. -// If both type parameters already have a type associated with them and they are -// not joined, join fails and returns false. +// join unifies the i'th type parameter with the j'th type parameter. +// If both type parameters already have a type associated with them +// and they are not joined, join fails and returns false. func (u *unifier) join(i, j int) bool { if traceInference { - u.tracef("%s ⇄ %s", u.x.tparams[i], u.y.tparams[j]) + u.tracef("%s ⇄ %s", u.tparams[i], u.tparams[j]) } - ti := u.x.indices[i] - tj := u.y.indices[j] + ti := u.indices[i] + tj := u.indices[j] switch { case ti == 0 && tj == 0: // Neither type parameter has a type slot associated with them. // Allocate a new joined nil type slot (negative index). u.types = append(u.types, nil) - u.x.indices[i] = -len(u.types) - u.y.indices[j] = -len(u.types) + u.indices[i] = -len(u.types) + u.indices[j] = -len(u.types) case ti == 0: - // The type parameter for x has no type slot yet. Use slot of y. - u.x.indices[i] = tj + // The type parameter (with index) i has no type slot yet. Use slot of j. + u.indices[i] = tj case tj == 0: - // The type parameter for y has no type slot yet. Use slot of x. - u.y.indices[j] = ti + // The type parameter (with index) j has no type slot yet. Use slot of i. + u.indices[j] = ti // Both type parameters have a slot: ti != 0 && tj != 0. case ti == tj: @@ -163,25 +135,25 @@ func (u *unifier) join(i, j int) bool { // TODO(gri) Should we check if types are identical? Investigate. return false case ti > 0: - // Only the type parameter for x has an inferred type. Use x slot for y. - u.y.setIndex(j, ti) + // Only the type parameter (with index) i has an inferred type. Use i slot for j. + u.setIndex(j, ti) // This case is handled like the default case. // case tj > 0: // // Only the type parameter for y has an inferred type. Use y slot for x. - // u.x.setIndex(i, tj) + // u.setIndex(i, tj) default: - // Neither type parameter has an inferred type. Use y slot for x - // (or x slot for y, it doesn't matter). - u.x.setIndex(i, tj) + // Neither type parameter has an inferred type. Use j slot for i + // (or i slot for j, it doesn't matter). + u.setIndex(i, tj) } return true } -// If typ is a type parameter of d, index returns the type parameter index. +// If typ is a type parameter recorded with u, index returns the type parameter index. // Otherwise, the result is < 0. -func (d *tparamsList) index(typ Type) int { +func (u *unifier) index(typ Type) int { if tpar, ok := typ.(*TypeParam); ok { - return tparamIndex(d.tparams, tpar) + return tparamIndex(u.tparams, tpar) } return -1 } @@ -204,48 +176,47 @@ func tparamIndex(list []*TypeParam, tpar *TypeParam) int { // setIndex sets the type slot index for the i'th type parameter // (and all its joined parameters) to tj. The type parameter // must have a (possibly nil) type slot associated with it. -func (d *tparamsList) setIndex(i, tj int) { - ti := d.indices[i] +func (u *unifier) setIndex(i, tj int) { + ti := u.indices[i] assert(ti != 0 && tj != 0) - for k, tk := range d.indices { + for k, tk := range u.indices { if tk == ti { - d.indices[k] = tj + u.indices[k] = tj } } } // at returns the type set for the i'th type parameter; or nil. -func (d *tparamsList) at(i int) Type { - if ti := d.indices[i]; ti > 0 { - return d.unifier.types[ti-1] +func (u *unifier) at(i int) Type { + if ti := u.indices[i]; ti > 0 { + return u.types[ti-1] } return nil } // set sets the type typ for the i'th type parameter; // typ must not be nil and it must not have been set before. -func (d *tparamsList) set(i int, typ Type) { +func (u *unifier) set(i int, typ Type) { assert(typ != nil) - u := d.unifier if traceInference { - u.tracef("%s ➞ %s", d.tparams[i], typ) + u.tracef("%s ➞ %s", u.tparams[i], typ) } - switch ti := d.indices[i]; { + switch ti := u.indices[i]; { case ti < 0: u.types[-ti-1] = typ - d.setIndex(i, -ti) + u.setIndex(i, -ti) case ti == 0: u.types = append(u.types, typ) - d.indices[i] = len(u.types) + u.indices[i] = len(u.types) default: panic("type already set") } } // unknowns returns the number of type parameters for which no type has been set yet. -func (d *tparamsList) unknowns() int { +func (u *unifier) unknowns() int { n := 0 - for _, ti := range d.indices { + for _, ti := range u.indices { if ti <= 0 { n++ } @@ -253,15 +224,15 @@ func (d *tparamsList) unknowns() int { return n } -// types returns the list of inferred types (via unification) for the type parameters -// described by d, and an index. If all types were inferred, the returned index is < 0. +// inferred returns the list of inferred types (via unification) for the type parameters +// recorded with u, and an index. If all types were inferred, the returned index is < 0. // Otherwise, it is the index of the first type parameter which couldn't be inferred; // i.e., for which list[index] is nil. -func (d *tparamsList) types() (list []Type, index int) { - list = make([]Type, len(d.tparams)) +func (u *unifier) inferred() (list []Type, index int) { + list = make([]Type, len(u.tparams)) index = -1 - for i := range d.tparams { - t := d.at(i) + for i := range u.tparams { + t := u.at(i) list[i] = t if index < 0 && t == nil { index = i @@ -301,7 +272,7 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { } }() - if !u.exact { + if !exactUnification { // If exact unification is known to fail because we attempt to // match a type name against an unnamed type literal, consider // the underlying type of the named type. @@ -321,44 +292,44 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { } // Cases where at least one of x or y is a type parameter. - switch i, j := u.x.index(x), u.y.index(y); { + switch i, j := u.index(x), u.index(y); { case i >= 0 && j >= 0: // both x and y are type parameters if u.join(i, j) { return true } // both x and y have an inferred type - they must match - return u.nifyEq(u.x.at(i), u.y.at(j), p) + return u.nifyEq(u.at(i), u.at(j), p) case i >= 0: // x is a type parameter, y is not - if tx := u.x.at(i); tx != nil { + if tx := u.at(i); tx != nil { return u.nifyEq(tx, y, p) } // otherwise, infer type from y - u.x.set(i, y) + u.set(i, y) return true case j >= 0: // y is a type parameter, x is not - if ty := u.y.at(j); ty != nil { + if ty := u.at(j); ty != nil { return u.nifyEq(x, ty, p) } // otherwise, infer type from x - u.y.set(j, x) + u.set(j, x) return true } // If we get here and x or y is a type parameter, they are type parameters // from outside our declaration list. Try to unify their core types, if any // (see go.dev/issue/50755 for a test case). - if enableCoreTypeUnification && !u.exact { + if enableCoreTypeUnification && !exactUnification { if isTypeParam(x) && !hasName(y) { // When considering the type parameter for unification // we look at the adjusted core term (adjusted core type // with tilde information). // If the adjusted core type is a named type N; the - // corresponding core type is under(N). Since !u.exact + // corresponding core type is under(N). Since !exactUnification // and y doesn't have a name, unification will end up // comparing under(N) to y, so we can just use the core // type instead. And we can ignore the tilde because we @@ -534,7 +505,7 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { case *Chan: // Two channel types are identical if they have identical value types. if y, ok := y.(*Chan); ok { - return (!u.exact || x.dir == y.dir) && u.nify(x.elem, y.elem, p) + return (!exactUnification || x.dir == y.dir) && u.nify(x.elem, y.elem, p) } case *Named: @@ -570,7 +541,7 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { // avoid a crash in case of nil type default: - panic(sprintf(nil, nil, true, "u.nify(%s, %s), u.x.tparams = %s", x, y, u.x.tparams)) + panic(sprintf(nil, nil, true, "u.nify(%s, %s), u.tparams = %s", x, y, u.tparams)) } return false -- 2.48.1