return subst.targs[i].Type()
}
}
- return t
+ // If t is a simple typeparam T, then t has the name/symbol 'T'
+ // and t.Underlying() == t.
+ //
+ // However, consider the type definition: 'type P[T any] T'. We
+ // might use this definition so we can have a variant of type T
+ // that we can add new methods to. Suppose t is a reference to
+ // P[T]. t has the name 'P[T]', but its kind is TTYPEPARAM,
+ // because P[T] is defined as T. If we look at t.Underlying(), it
+ // is different, because the name of t.Underlying() is 'T' rather
+ // than 'P[T]'. But the kind of t.Underlying() is also TTYPEPARAM.
+ // In this case, we do the needed recursive substitution in the
+ // case statement below.
+ if t.Underlying() == t {
+ // t is a simple typeparam that didn't match anything in tparam
+ return t
+ }
+ // t is a more complex typeparam (e.g. P[T], as above, whose
+ // definition is just T).
+ assert(t.Sym() != nil)
}
var newsym *types.Sym
var newt *types.Type
switch t.Kind() {
+ case types.TTYPEPARAM:
+ if t.Sym() == newsym {
+ // The substitution did not change the type.
+ return t
+ }
+ // Substitute the underlying typeparam (e.g. T in P[T], see
+ // the example describing type P[T] above).
+ newt = subst.typ(t.Underlying())
+ assert(newt != t)
case types.TARRAY:
elem := t.Elem()
--- /dev/null
+// run -gcflags=-G=3
+
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "fmt"
+ //"math"
+)
+
+type Numeric interface {
+ type int, int8, int16, int32, int64,
+ uint, uint8, uint16, uint32, uint64, uintptr,
+ float32, float64,
+ complex64, complex128
+}
+
+// numericAbs matches numeric types with an Abs method.
+type numericAbs[T any] interface {
+ Numeric
+ Abs() T
+}
+
+// AbsDifference computes the absolute value of the difference of
+// a and b, where the absolute value is determined by the Abs method.
+func absDifference[T numericAbs[T]](a, b T) T {
+ d := a - b
+ return d.Abs()
+}
+
+// orderedNumeric matches numeric types that support the < operator.
+type orderedNumeric interface {
+ type int, int8, int16, int32, int64,
+ uint, uint8, uint16, uint32, uint64, uintptr,
+ float32, float64
+}
+
+// Complex matches the two complex types, which do not have a < operator.
+type Complex interface {
+ type complex64, complex128
+}
+
+// orderedAbs is a helper type that defines an Abs method for
+// ordered numeric types.
+type orderedAbs[T orderedNumeric] T
+
+func (a orderedAbs[T]) Abs() orderedAbs[T] {
+ // TODO(danscales): orderedAbs[T] conversion shouldn't be needed
+ if a < orderedAbs[T](0) {
+ return -a
+ }
+ return a
+}
+
+// complexAbs is a helper type that defines an Abs method for
+// complex types.
+// type complexAbs[T Complex] T
+
+// func (a complexAbs[T]) Abs() complexAbs[T] {
+// r := float64(real(a))
+// i := float64(imag(a))
+// d := math.Sqrt(r * r + i * i)
+// return complexAbs[T](complex(d, 0))
+// }
+
+// OrderedAbsDifference returns the absolute value of the difference
+// between a and b, where a and b are of an ordered type.
+func orderedAbsDifference[T orderedNumeric](a, b T) T {
+ return T(absDifference(orderedAbs[T](a), orderedAbs[T](b)))
+}
+
+// ComplexAbsDifference returns the absolute value of the difference
+// between a and b, where a and b are of a complex type.
+// func complexAbsDifference[T Complex](a, b T) T {
+// return T(absDifference(complexAbs[T](a), complexAbs[T](b)))
+// }
+
+func main() {
+ if got, want := orderedAbsDifference(1.0, -2.0), 3.0; got != want {
+ panic(fmt.Sprintf("got = %v, want = %v", got, want))
+ }
+ if got, want := orderedAbsDifference(-1.0, 2.0), 3.0; got != want {
+ panic(fmt.Sprintf("got = %v, want = %v", got, want))
+ }
+ if got, want := orderedAbsDifference(-20, 15), 35; got != want {
+ panic(fmt.Sprintf("got = %v, want = %v", got, want))
+ }
+
+ // Still have to handle built-ins real/abs to make this work
+ // if got, want := complexAbsDifference(5.0 + 2.0i, 2.0 - 2.0i), 5; got != want {
+ // panic(fmt.Sprintf("got = %v, want = %v", got, want)
+ // }
+ // if got, want := complexAbsDifference(2.0 - 2.0i, 5.0 + 2.0i), 5; got != want {
+ // panic(fmt.Sprintf("got = %v, want = %v", got, want)
+ // }
+}