]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: get untyped constants working in generic functions
authorDan Scales <danscales@google.com>
Thu, 18 Mar 2021 18:56:46 +0000 (11:56 -0700)
committerDan Scales <danscales@google.com>
Thu, 18 Mar 2021 22:18:32 +0000 (22:18 +0000)
types2 will give us a constant with a type T, if an untyped constant is
used with another operand of type T (in a provably correct way). When we
substitute in the type args during stenciling, we now know the real type
of the constant. We may then need to change the BasicLit.val to be the
correct type (e.g. convert an int64Val constant to a floatVal constant).
Otherwise, later parts of the compiler will be confused.

Updated tests list.go and double.go with uses of untyped constants.

Change-Id: I9966bbb0dea3a7de1c5a6420f8ad8af9ca84a33e
Reviewed-on: https://go-review.googlesource.com/c/go/+/303089
Run-TryBot: Dan Scales <danscales@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Dan Scales <danscales@google.com>
Trust: Robert Griesemer <gri@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
src/cmd/compile/internal/noder/stencil.go
test/typeparam/double.go
test/typeparam/list.go

index 55aee9b6ffbd794ef573f3676417fb03d1af6be6..51ef46c7e75491eb67111a36c58b66c9ae64543d 100644 (file)
@@ -367,7 +367,24 @@ func (subst *subster) node(n ir.Node) ir.Node {
                }
                ir.EditChildren(m, edit)
 
-               if x.Op() == ir.OXDOT {
+               switch x.Op() {
+               case ir.OLITERAL:
+                       t := m.Type()
+                       if t != x.Type() {
+                               // types2 will give us a constant with a type T,
+                               // if an untyped constant is used with another
+                               // operand of type T (in a provably correct way).
+                               // When we substitute in the type args during
+                               // stenciling, we now know the real type of the
+                               // constant. We may then need to change the
+                               // BasicLit.val to be the correct type (e.g.
+                               // convert an int64Val constant to a floatVal
+                               // constant).
+                               m.SetType(types.UntypedInt) // use any untyped type for DefaultLit to work
+                               m = typecheck.DefaultLit(m, t)
+                       }
+
+               case ir.OXDOT:
                        // A method value/call via a type param will have been left as an
                        // OXDOT. When we see this during stenciling, finish the
                        // typechecking, now that we have the instantiated receiver type.
@@ -377,8 +394,8 @@ func (subst *subster) node(n ir.Node) ir.Node {
                        m.SetTypecheck(0)
                        // m will transform to an OCALLPART
                        typecheck.Expr(m)
-               }
-               if x.Op() == ir.OCALL {
+
+               case ir.OCALL:
                        call := m.(*ir.CallExpr)
                        if call.X.Op() == ir.OTYPE {
                                // Do typechecking on a conversion, now that we
@@ -419,9 +436,8 @@ func (subst *subster) node(n ir.Node) ir.Node {
                                // instantiation to be called.
                                base.FatalfAt(call.Pos(), "Expecting OCALLPART or OTYPE or OFUNCINST or builtin with CALL")
                        }
-               }
 
-               if x.Op() == ir.OCLOSURE {
+               case ir.OCLOSURE:
                        x := x.(*ir.ClosureExpr)
                        // Need to save/duplicate x.Func.Nname,
                        // x.Func.Nname.Ntype, x.Func.Dcl, x.Func.ClosureVars, and
index 1f7a26c7f4ded73eae646e503446f33cfd3d545e..ce78ec9748319ca580451630955a3428b30e09cd 100644 (file)
@@ -16,6 +16,7 @@ type Number interface {
 }
 
 type MySlice []int
+type MyFloatSlice []float64
 
 type _SliceOf[E any] interface {
        type []E
@@ -29,6 +30,15 @@ func _DoubleElems[S _SliceOf[E], E Number](s S) S {
        return r
 }
 
+// Test use of untyped constant in an expression with a generically-typed parameter
+func _DoubleElems2[S _SliceOf[E], E Number](s S) S {
+       r := make(S, len(s))
+       for i, v := range s {
+               r[i] = v * 2
+       }
+       return r
+}
+
 func main() {
        arg := MySlice{1, 2, 3}
        want := MySlice{2, 4, 6}
@@ -47,4 +57,16 @@ func main() {
        if !reflect.DeepEqual(got, want) {
                 panic(fmt.Sprintf("got %s, want %s", got, want))
        }
+
+       farg := MyFloatSlice{1.2, 2.0, 3.5}
+       fwant := MyFloatSlice{2.4, 4.0, 7.0}
+       fgot := _DoubleElems(farg)
+       if !reflect.DeepEqual(fgot, fwant) {
+                panic(fmt.Sprintf("got %s, want %s", fgot, fwant))
+       }
+
+       fgot = _DoubleElems2(farg)
+       if !reflect.DeepEqual(fgot, fwant) {
+                panic(fmt.Sprintf("got %s, want %s", fgot, fwant))
+       }
 }
index 64230060de7f3478bcba960570b2929752fc17e7..579078f02f7b1397cd5de521def0c82f086d6dae 100644 (file)
@@ -17,13 +17,13 @@ type Ordered interface {
                 string
 }
 
-// List is a linked list of ordered values of type T.
-type list[T Ordered] struct {
-       next *list[T]
+// _List is a linked list of ordered values of type T.
+type _List[T Ordered] struct {
+       next *_List[T]
        val  T
 }
 
-func (l *list[T]) largest() T {
+func (l *_List[T]) Largest() T {
        var max T
        for p := l; p != nil; p = p.next {
                if p.val > max {
@@ -33,33 +33,71 @@ func (l *list[T]) largest() T {
        return max
 }
 
+type OrderedNum interface {
+        type int, int8, int16, int32, int64,
+                uint, uint8, uint16, uint32, uint64, uintptr,
+                float32, float64
+}
+
+// _ListNum is a linked _List of ordered numeric values of type T.
+type _ListNum[T OrderedNum] struct {
+       next *_ListNum[T]
+       val  T
+}
+
+const Clip = 5
+
+// clippedLargest returns the largest in the list of OrderNums, but a max of 5.
+// Test use of untyped constant in an expression with a generically-typed parameter
+func (l *_ListNum[T]) ClippedLargest() T {
+       var max T
+       for p := l; p != nil; p = p.next {
+               if p.val > max && p.val < Clip {
+                       max = p.val
+               }
+       }
+       return max
+}
 
 func main() {
-       i3 := &list[int]{nil, 1}
-       i2 := &list[int]{i3, 3}
-       i1 := &list[int]{i2, 2}
-       if got, want := i1.largest(), 3; got != want {
+       i3 := &_List[int]{nil, 1}
+       i2 := &_List[int]{i3, 3}
+       i1 := &_List[int]{i2, 2}
+       if got, want := i1.Largest(), 3; got != want {
                 panic(fmt.Sprintf("got %d, want %d", got, want))
        }
 
-       b3 := &list[byte]{nil, byte(1)}
-       b2 := &list[byte]{b3, byte(3)}
-       b1 := &list[byte]{b2, byte(2)}
-       if got, want := b1.largest(), byte(3); got != want {
+       b3 := &_List[byte]{nil, byte(1)}
+       b2 := &_List[byte]{b3, byte(3)}
+       b1 := &_List[byte]{b2, byte(2)}
+       if got, want := b1.Largest(), byte(3); got != want {
                 panic(fmt.Sprintf("got %d, want %d", got, want))
        }
 
-       f3 := &list[float64]{nil, 13.5}
-       f2 := &list[float64]{f3, 1.2}
-       f1 := &list[float64]{f2, 4.5}
-       if got, want := f1.largest(), 13.5; got != want {
+       f3 := &_List[float64]{nil, 13.5}
+       f2 := &_List[float64]{f3, 1.2}
+       f1 := &_List[float64]{f2, 4.5}
+       if got, want := f1.Largest(), 13.5; got != want {
                 panic(fmt.Sprintf("got %f, want %f", got, want))
        }
 
-       s3 := &list[string]{nil, "dd"}
-       s2 := &list[string]{s3, "aa"}
-       s1 := &list[string]{s2, "bb"}
-       if got, want := s1.largest(), "dd"; got != want {
+       s3 := &_List[string]{nil, "dd"}
+       s2 := &_List[string]{s3, "aa"}
+       s1 := &_List[string]{s2, "bb"}
+       if got, want := s1.Largest(), "dd"; got != want {
                 panic(fmt.Sprintf("got %s, want %s", got, want))
        }
+
+       j3 := &_ListNum[int]{nil, 1}
+       j2 := &_ListNum[int]{j3, 32}
+       j1 := &_ListNum[int]{j2, 2}
+       if got, want := j1.ClippedLargest(), 2; got != want {
+                panic(fmt.Sprintf("got %d, want %d", got, want))
+       }
+       g3 := &_ListNum[float64]{nil, 13.5}
+       g2 := &_ListNum[float64]{g3, 1.2}
+       g1 := &_ListNum[float64]{g2, 4.5}
+       if got, want := g1.ClippedLargest(), 4.5; got != want {
+                panic(fmt.Sprintf("got %f, want %f", got, want))
+       }
 }