From 90b1ed1602ad8db08a5b1bde4aae5f8569b21915 Mon Sep 17 00:00:00 2001 From: Dan Scales Date: Thu, 18 Mar 2021 11:56:46 -0700 Subject: [PATCH] cmd/compile: get untyped constants working in generic functions types2 will give us a constant with a type T, if an untyped constant is used with another operand of type T (in a provably correct way). When we substitute in the type args during stenciling, we now know the real type of the constant. We may then need to change the BasicLit.val to be the correct type (e.g. convert an int64Val constant to a floatVal constant). Otherwise, later parts of the compiler will be confused. Updated tests list.go and double.go with uses of untyped constants. Change-Id: I9966bbb0dea3a7de1c5a6420f8ad8af9ca84a33e Reviewed-on: https://go-review.googlesource.com/c/go/+/303089 Run-TryBot: Dan Scales TryBot-Result: Go Bot Trust: Dan Scales Trust: Robert Griesemer Reviewed-by: Robert Griesemer --- src/cmd/compile/internal/noder/stencil.go | 26 ++++++-- test/typeparam/double.go | 22 +++++++ test/typeparam/list.go | 78 +++++++++++++++++------ 3 files changed, 101 insertions(+), 25 deletions(-) diff --git a/src/cmd/compile/internal/noder/stencil.go b/src/cmd/compile/internal/noder/stencil.go index 55aee9b6ff..51ef46c7e7 100644 --- a/src/cmd/compile/internal/noder/stencil.go +++ b/src/cmd/compile/internal/noder/stencil.go @@ -367,7 +367,24 @@ func (subst *subster) node(n ir.Node) ir.Node { } ir.EditChildren(m, edit) - if x.Op() == ir.OXDOT { + switch x.Op() { + case ir.OLITERAL: + t := m.Type() + if t != x.Type() { + // types2 will give us a constant with a type T, + // if an untyped constant is used with another + // operand of type T (in a provably correct way). + // When we substitute in the type args during + // stenciling, we now know the real type of the + // constant. We may then need to change the + // BasicLit.val to be the correct type (e.g. + // convert an int64Val constant to a floatVal + // constant). + m.SetType(types.UntypedInt) // use any untyped type for DefaultLit to work + m = typecheck.DefaultLit(m, t) + } + + case ir.OXDOT: // A method value/call via a type param will have been left as an // OXDOT. When we see this during stenciling, finish the // typechecking, now that we have the instantiated receiver type. @@ -377,8 +394,8 @@ func (subst *subster) node(n ir.Node) ir.Node { m.SetTypecheck(0) // m will transform to an OCALLPART typecheck.Expr(m) - } - if x.Op() == ir.OCALL { + + case ir.OCALL: call := m.(*ir.CallExpr) if call.X.Op() == ir.OTYPE { // Do typechecking on a conversion, now that we @@ -419,9 +436,8 @@ func (subst *subster) node(n ir.Node) ir.Node { // instantiation to be called. base.FatalfAt(call.Pos(), "Expecting OCALLPART or OTYPE or OFUNCINST or builtin with CALL") } - } - if x.Op() == ir.OCLOSURE { + case ir.OCLOSURE: x := x.(*ir.ClosureExpr) // Need to save/duplicate x.Func.Nname, // x.Func.Nname.Ntype, x.Func.Dcl, x.Func.ClosureVars, and diff --git a/test/typeparam/double.go b/test/typeparam/double.go index 1f7a26c7f4..ce78ec9748 100644 --- a/test/typeparam/double.go +++ b/test/typeparam/double.go @@ -16,6 +16,7 @@ type Number interface { } type MySlice []int +type MyFloatSlice []float64 type _SliceOf[E any] interface { type []E @@ -29,6 +30,15 @@ func _DoubleElems[S _SliceOf[E], E Number](s S) S { return r } +// Test use of untyped constant in an expression with a generically-typed parameter +func _DoubleElems2[S _SliceOf[E], E Number](s S) S { + r := make(S, len(s)) + for i, v := range s { + r[i] = v * 2 + } + return r +} + func main() { arg := MySlice{1, 2, 3} want := MySlice{2, 4, 6} @@ -47,4 +57,16 @@ func main() { if !reflect.DeepEqual(got, want) { panic(fmt.Sprintf("got %s, want %s", got, want)) } + + farg := MyFloatSlice{1.2, 2.0, 3.5} + fwant := MyFloatSlice{2.4, 4.0, 7.0} + fgot := _DoubleElems(farg) + if !reflect.DeepEqual(fgot, fwant) { + panic(fmt.Sprintf("got %s, want %s", fgot, fwant)) + } + + fgot = _DoubleElems2(farg) + if !reflect.DeepEqual(fgot, fwant) { + panic(fmt.Sprintf("got %s, want %s", fgot, fwant)) + } } diff --git a/test/typeparam/list.go b/test/typeparam/list.go index 64230060de..579078f02f 100644 --- a/test/typeparam/list.go +++ b/test/typeparam/list.go @@ -17,13 +17,13 @@ type Ordered interface { string } -// List is a linked list of ordered values of type T. -type list[T Ordered] struct { - next *list[T] +// _List is a linked list of ordered values of type T. +type _List[T Ordered] struct { + next *_List[T] val T } -func (l *list[T]) largest() T { +func (l *_List[T]) Largest() T { var max T for p := l; p != nil; p = p.next { if p.val > max { @@ -33,33 +33,71 @@ func (l *list[T]) largest() T { return max } +type OrderedNum interface { + type int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, uintptr, + float32, float64 +} + +// _ListNum is a linked _List of ordered numeric values of type T. +type _ListNum[T OrderedNum] struct { + next *_ListNum[T] + val T +} + +const Clip = 5 + +// clippedLargest returns the largest in the list of OrderNums, but a max of 5. +// Test use of untyped constant in an expression with a generically-typed parameter +func (l *_ListNum[T]) ClippedLargest() T { + var max T + for p := l; p != nil; p = p.next { + if p.val > max && p.val < Clip { + max = p.val + } + } + return max +} func main() { - i3 := &list[int]{nil, 1} - i2 := &list[int]{i3, 3} - i1 := &list[int]{i2, 2} - if got, want := i1.largest(), 3; got != want { + i3 := &_List[int]{nil, 1} + i2 := &_List[int]{i3, 3} + i1 := &_List[int]{i2, 2} + if got, want := i1.Largest(), 3; got != want { panic(fmt.Sprintf("got %d, want %d", got, want)) } - b3 := &list[byte]{nil, byte(1)} - b2 := &list[byte]{b3, byte(3)} - b1 := &list[byte]{b2, byte(2)} - if got, want := b1.largest(), byte(3); got != want { + b3 := &_List[byte]{nil, byte(1)} + b2 := &_List[byte]{b3, byte(3)} + b1 := &_List[byte]{b2, byte(2)} + if got, want := b1.Largest(), byte(3); got != want { panic(fmt.Sprintf("got %d, want %d", got, want)) } - f3 := &list[float64]{nil, 13.5} - f2 := &list[float64]{f3, 1.2} - f1 := &list[float64]{f2, 4.5} - if got, want := f1.largest(), 13.5; got != want { + f3 := &_List[float64]{nil, 13.5} + f2 := &_List[float64]{f3, 1.2} + f1 := &_List[float64]{f2, 4.5} + if got, want := f1.Largest(), 13.5; got != want { panic(fmt.Sprintf("got %f, want %f", got, want)) } - s3 := &list[string]{nil, "dd"} - s2 := &list[string]{s3, "aa"} - s1 := &list[string]{s2, "bb"} - if got, want := s1.largest(), "dd"; got != want { + s3 := &_List[string]{nil, "dd"} + s2 := &_List[string]{s3, "aa"} + s1 := &_List[string]{s2, "bb"} + if got, want := s1.Largest(), "dd"; got != want { panic(fmt.Sprintf("got %s, want %s", got, want)) } + + j3 := &_ListNum[int]{nil, 1} + j2 := &_ListNum[int]{j3, 32} + j1 := &_ListNum[int]{j2, 2} + if got, want := j1.ClippedLargest(), 2; got != want { + panic(fmt.Sprintf("got %d, want %d", got, want)) + } + g3 := &_ListNum[float64]{nil, 13.5} + g2 := &_ListNum[float64]{g3, 1.2} + g1 := &_ListNum[float64]{g2, 4.5} + if got, want := g1.ClippedLargest(), 4.5; got != want { + panic(fmt.Sprintf("got %f, want %f", got, want)) + } } -- 2.50.0