]> Cypherpunks repositories - gostls13.git/commitdiff
reflect: add iterative related methods
authorqiulaidongfeng <2645477756@qq.com>
Tue, 7 May 2024 22:53:35 +0000 (22:53 +0000)
committerGopher Robot <gobot@golang.org>
Thu, 9 May 2024 11:54:18 +0000 (11:54 +0000)
Fixes #66056

Change-Id: I1e24636e43e68cd57576c39b014e0826fb6c322c
GitHub-Last-Rev: 319ad8ea7cd5326d23f9fddb9607924326aaf927
GitHub-Pull-Request: golang/go#66824
Reviewed-on: https://go-review.googlesource.com/c/go/+/578815
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
api/next/66056.txt [new file with mode: 0644]
doc/next/6-stdlib/99-minor/reflect/66056.md [new file with mode: 0644]
src/reflect/iter.go [new file with mode: 0644]
src/reflect/iter_test.go [new file with mode: 0644]
src/reflect/type.go
src/reflect/type_test.go

diff --git a/api/next/66056.txt b/api/next/66056.txt
new file mode 100644 (file)
index 0000000..db7065a
--- /dev/null
@@ -0,0 +1,4 @@
+pkg reflect, method (Value) Seq() iter.Seq[Value] #66056
+pkg reflect, method (Value) Seq2() iter.Seq2[Value, Value] #66056
+pkg reflect, type Type interface, CanSeq() bool #66056
+pkg reflect, type Type interface, CanSeq2() bool #66056
diff --git a/doc/next/6-stdlib/99-minor/reflect/66056.md b/doc/next/6-stdlib/99-minor/reflect/66056.md
new file mode 100644 (file)
index 0000000..b5f3934
--- /dev/null
@@ -0,0 +1,4 @@
+The new methods [Value.Seq] and [Value.Seq2] return sequences that iterate over the value
+as though it were used in a for/range loop.
+The new methods [Type.CanSeq] and [Type.CanSeq2] report whether calling
+[Value.Seq] and [Value.Seq2], respectively, will succeed without panicking.
diff --git a/src/reflect/iter.go b/src/reflect/iter.go
new file mode 100644 (file)
index 0000000..539872d
--- /dev/null
@@ -0,0 +1,140 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package reflect
+
+import "iter"
+
+// Seq returns an iter.Seq[reflect.Value] that loops over the elements of v.
+// If v's kind is Func, it must be a function that has no results and
+// that takes a single argument of type func(T) bool for some type T.
+// If v's kind is Pointer, the pointer element type must have kind Array.
+// Otherwise v's kind must be Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Uintptr,
+// Array, Chan, Map, Slice, or String.
+func (v Value) Seq() iter.Seq[Value] {
+       if canRangeFunc(v.typ()) {
+               return func(yield func(Value) bool) {
+                       rf := MakeFunc(v.Type().In(0), func(in []Value) []Value {
+                               return []Value{ValueOf(yield(in[0]))}
+                       })
+                       v.Call([]Value{rf})
+               }
+       }
+       switch v.Kind() {
+       case Int, Int8, Int16, Int32, Int64:
+               return func(yield func(Value) bool) {
+                       for i := range v.Int() {
+                               if !yield(ValueOf(i)) {
+                                       return
+                               }
+                       }
+               }
+       case Uint, Uint8, Uint16, Uint32, Uint64, Uintptr:
+               return func(yield func(Value) bool) {
+                       for i := range v.Uint() {
+                               if !yield(ValueOf(i)) {
+                                       return
+                               }
+                       }
+               }
+       case Pointer:
+               if v.Elem().kind() != Array {
+                       break
+               }
+               return func(yield func(Value) bool) {
+                       v = v.Elem()
+                       for i := range v.Len() {
+                               if !yield(ValueOf(i)) {
+                                       return
+                               }
+                       }
+               }
+       case Array, Slice:
+               return func(yield func(Value) bool) {
+                       for i := range v.Len() {
+                               if !yield(ValueOf(i)) {
+                                       return
+                               }
+                       }
+               }
+       case String:
+               return func(yield func(Value) bool) {
+                       for i := range v.String() {
+                               if !yield(ValueOf(i)) {
+                                       return
+                               }
+                       }
+               }
+       case Map:
+               return func(yield func(Value) bool) {
+                       i := v.MapRange()
+                       for i.Next() {
+                               if !yield(i.Key()) {
+                                       return
+                               }
+                       }
+               }
+       case Chan:
+               return func(yield func(Value) bool) {
+                       for value, ok := v.Recv(); ok; value, ok = v.Recv() {
+                               if !yield(value) {
+                                       return
+                               }
+                       }
+               }
+       }
+       panic("reflect: " + v.Type().String() + " cannot produce iter.Seq[Value]")
+}
+
+// Seq2 is like Seq but for two values.
+func (v Value) Seq2() iter.Seq2[Value, Value] {
+       if canRangeFunc2(v.typ()) {
+               return func(yield func(Value, Value) bool) {
+                       rf := MakeFunc(v.Type().In(0), func(in []Value) []Value {
+                               return []Value{ValueOf(yield(in[0], in[1]))}
+                       })
+                       v.Call([]Value{rf})
+               }
+       }
+       switch v.Kind() {
+       case Pointer:
+               if v.Elem().kind() != Array {
+                       break
+               }
+               return func(yield func(Value, Value) bool) {
+                       v = v.Elem()
+                       for i := range v.Len() {
+                               if !yield(ValueOf(i), v.Index(i)) {
+                                       return
+                               }
+                       }
+               }
+       case Array, Slice:
+               return func(yield func(Value, Value) bool) {
+                       for i := range v.Len() {
+                               if !yield(ValueOf(i), v.Index(i)) {
+                                       return
+                               }
+                       }
+               }
+       case String:
+               return func(yield func(Value, Value) bool) {
+                       for i, v := range v.String() {
+                               if !yield(ValueOf(i), ValueOf(v)) {
+                                       return
+                               }
+                       }
+               }
+       case Map:
+               return func(yield func(Value, Value) bool) {
+                       i := v.MapRange()
+                       for i.Next() {
+                               if !yield(i.Key(), i.Value()) {
+                                       return
+                               }
+                       }
+               }
+       }
+       panic("reflect: " + v.Type().String() + " cannot produce iter.Seq2[Value, Value]")
+}
diff --git a/src/reflect/iter_test.go b/src/reflect/iter_test.go
new file mode 100644 (file)
index 0000000..c4a14e7
--- /dev/null
@@ -0,0 +1,280 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package reflect_test
+
+import (
+       "iter"
+       "maps"
+       . "reflect"
+       "testing"
+)
+
+func TestValueSeq(t *testing.T) {
+       m := map[string]int{
+               "1": 1,
+               "2": 2,
+               "3": 3,
+               "4": 4,
+       }
+       c := make(chan int, 3)
+       for i := range 3 {
+               c <- i
+       }
+       close(c)
+       tests := []struct {
+               name  string
+               val   Value
+               check func(*testing.T, iter.Seq[Value])
+       }{
+               {"int", ValueOf(4), func(t *testing.T, s iter.Seq[Value]) {
+                       i := int64(0)
+                       for v := range s {
+                               if v.Int() != i {
+                                       t.Fatalf("got %d, want %d", v.Int(), i)
+                               }
+                               i++
+                       }
+                       if i != 4 {
+                               t.Fatalf("should loop four times")
+                       }
+               }},
+               {"uint", ValueOf(uint64(4)), func(t *testing.T, s iter.Seq[Value]) {
+                       i := uint64(0)
+                       for v := range s {
+                               if v.Uint() != i {
+                                       t.Fatalf("got %d, want %d", v.Uint(), i)
+                               }
+                               i++
+                       }
+                       if i != 4 {
+                               t.Fatalf("should loop four times")
+                       }
+               }},
+               {"*[4]int", ValueOf(&[4]int{1, 2, 3, 4}), func(t *testing.T, s iter.Seq[Value]) {
+                       i := int64(0)
+                       for v := range s {
+                               if v.Int() != i {
+                                       t.Fatalf("got %d, want %d", v.Int(), i)
+                               }
+                               i++
+                       }
+                       if i != 4 {
+                               t.Fatalf("should loop four times")
+                       }
+               }},
+               {"[4]int", ValueOf([4]int{1, 2, 3, 4}), func(t *testing.T, s iter.Seq[Value]) {
+                       i := int64(0)
+                       for v := range s {
+                               if v.Int() != i {
+                                       t.Fatalf("got %d, want %d", v.Int(), i)
+                               }
+                               i++
+                       }
+                       if i != 4 {
+                               t.Fatalf("should loop four times")
+                       }
+               }},
+               {"[]int", ValueOf([]int{1, 2, 3, 4}), func(t *testing.T, s iter.Seq[Value]) {
+                       i := int64(0)
+                       for v := range s {
+                               if v.Int() != i {
+                                       t.Fatalf("got %d, want %d", v.Int(), i)
+                               }
+                               i++
+                       }
+                       if i != 4 {
+                               t.Fatalf("should loop four times")
+                       }
+               }},
+               {"string", ValueOf("12语言"), func(t *testing.T, s iter.Seq[Value]) {
+                       i := int64(0)
+                       indexs := []int64{0, 1, 2, 5}
+                       for v := range s {
+                               if v.Int() != indexs[i] {
+                                       t.Fatalf("got %d, want %d", v.Int(), indexs[i])
+                               }
+                               i++
+                       }
+                       if i != 4 {
+                               t.Fatalf("should loop four times")
+                       }
+               }},
+               {"map[string]int", ValueOf(m), func(t *testing.T, s iter.Seq[Value]) {
+                       i := int64(0)
+                       copy := maps.Clone(m)
+                       for v := range s {
+                               if _, ok := copy[v.String()]; !ok {
+                                       t.Fatalf("unexpected %v", v.Interface())
+                               }
+                               delete(copy, v.String())
+                               i++
+                       }
+                       if len(copy) != 0 {
+                               t.Fatalf("should loop four times")
+                       }
+               }},
+               {"chan int", ValueOf(c), func(t *testing.T, s iter.Seq[Value]) {
+                       i := 0
+                       m := map[int64]bool{
+                               0: false,
+                               1: false,
+                               2: false,
+                       }
+                       for v := range s {
+                               if b, ok := m[v.Int()]; !ok || b {
+                                       t.Fatalf("unexpected %v", v.Interface())
+                               }
+                               m[v.Int()] = true
+                               i++
+                       }
+                       if i != 3 {
+                               t.Fatalf("should loop three times")
+                       }
+               }},
+               {"func", ValueOf(func(yield func(int) bool) {
+                       for i := range 4 {
+                               if !yield(i) {
+                                       return
+                               }
+                       }
+               }), func(t *testing.T, s iter.Seq[Value]) {
+                       i := int64(0)
+                       for v := range s {
+                               if v.Int() != i {
+                                       t.Fatalf("got %d, want %d", v.Int(), i)
+                               }
+                               i++
+                       }
+                       if i != 4 {
+                               t.Fatalf("should loop four times")
+                       }
+               }},
+       }
+       for _, tc := range tests {
+               seq := tc.val.Seq()
+               tc.check(t, seq)
+       }
+}
+
+func TestValueSeq2(t *testing.T) {
+       m := map[string]int{
+               "1": 1,
+               "2": 2,
+               "3": 3,
+               "4": 4,
+       }
+       tests := []struct {
+               name  string
+               val   Value
+               check func(*testing.T, iter.Seq2[Value, Value])
+       }{
+               {"*[4]int", ValueOf(&[4]int{1, 2, 3, 4}), func(t *testing.T, s iter.Seq2[Value, Value]) {
+                       i := int64(0)
+                       for v1, v2 := range s {
+                               if v1.Int() != i {
+                                       t.Fatalf("got %d, want %d", v1.Int(), i)
+                               }
+                               i++
+                               if v2.Int() != i {
+                                       t.Fatalf("got %d, want %d", v2.Int(), i)
+                               }
+                       }
+                       if i != 4 {
+                               t.Fatalf("should loop four times")
+                       }
+               }},
+               {"[4]int", ValueOf([4]int{1, 2, 3, 4}), func(t *testing.T, s iter.Seq2[Value, Value]) {
+                       i := int64(0)
+                       for v1, v2 := range s {
+                               if v1.Int() != i {
+                                       t.Fatalf("got %d, want %d", v1.Int(), i)
+                               }
+                               i++
+                               if v2.Int() != i {
+                                       t.Fatalf("got %d, want %d", v2.Int(), i)
+                               }
+                       }
+                       if i != 4 {
+                               t.Fatalf("should loop four times")
+                       }
+               }},
+               {"[]int", ValueOf([]int{1, 2, 3, 4}), func(t *testing.T, s iter.Seq2[Value, Value]) {
+                       i := int64(0)
+                       for v1, v2 := range s {
+                               if v1.Int() != i {
+                                       t.Fatalf("got %d, want %d", v1.Int(), i)
+                               }
+                               i++
+                               if v2.Int() != i {
+                                       t.Fatalf("got %d, want %d", v2.Int(), i)
+                               }
+                       }
+                       if i != 4 {
+                               t.Fatalf("should loop four times")
+                       }
+               }},
+               {"string", ValueOf("12语言"), func(t *testing.T, s iter.Seq2[Value, Value]) {
+                       i := int64(0)
+                       str := "12语言"
+                       next, stop := iter.Pull2(s)
+                       defer stop()
+                       for j, s := range str {
+                               v1, v2, ok := next()
+                               if !ok {
+                                       t.Fatalf("should loop four times")
+                               }
+                               if v1.Int() != int64(j) {
+                                       t.Fatalf("got %d, want %d", v1.Int(), j)
+                               }
+                               if v2.Interface() != s {
+                                       t.Fatalf("got %v, want %v", v2.Interface(), s)
+                               }
+                               i++
+                       }
+                       if i != 4 {
+                               t.Fatalf("should loop four times")
+                       }
+               }},
+               {"map[string]int", ValueOf(m), func(t *testing.T, s iter.Seq2[Value, Value]) {
+                       copy := maps.Clone(m)
+                       for v1, v2 := range s {
+                               v, ok := copy[v1.String()]
+                               if !ok {
+                                       t.Fatalf("unexpected %v", v1.String())
+                               }
+                               if v != v2.Interface() {
+                                       t.Fatalf("got %v, want %d", v2.Interface(), v)
+                               }
+                               delete(copy, v1.String())
+                       }
+                       if len(copy) != 0 {
+                               t.Fatalf("should loop four times")
+                       }
+               }},
+               {"func", ValueOf(func(f func(int, int) bool) {
+                       for i := range 4 {
+                               f(i, i+1)
+                       }
+               }), func(t *testing.T, s iter.Seq2[Value, Value]) {
+                       i := int64(0)
+                       for v1, v2 := range s {
+                               if v1.Int() != i {
+                                       t.Fatalf("got %d, want %d", v1.Int(), i)
+                               }
+                               i++
+                               if v2.Int() != i {
+                                       t.Fatalf("got %d, want %d", v2.Int(), i)
+                               }
+                       }
+                       if i != 4 {
+                               t.Fatalf("should loop four times")
+                       }
+               }},
+       }
+       for _, tc := range tests {
+               seq := tc.val.Seq2()
+               tc.check(t, seq)
+       }
+}
index de447c0d1577b0be677f50192ad9d1164a667fdb..5ad74aabfcaf25e5e06702743bb8e034ddbe095e 100644 (file)
@@ -241,6 +241,12 @@ type Type interface {
        // It panics if t's Kind is not Uint, Uintptr, Uint8, Uint16, Uint32, or Uint64.
        OverflowUint(x uint64) bool
 
+       // CanSeq reports whether a [Value] with this type can be iterated over using [Value.Seq].
+       CanSeq() bool
+
+       // CanSeq2 reports whether a [Value] with this type can be iterated over using [Value.Seq2].
+       CanSeq2() bool
+
        common() *abi.Type
        uncommon() *uncommonType
 }
@@ -866,6 +872,62 @@ func (t *rtype) OverflowUint(x uint64) bool {
        panic("reflect: OverflowUint of non-uint type " + t.String())
 }
 
+func (t *rtype) CanSeq() bool {
+       switch t.Kind() {
+       case Int8, Int16, Int32, Int64, Int, Uint8, Uint16, Uint32, Uint64, Uint, Uintptr, Array, Slice, Chan, String, Map:
+               return true
+       case Func:
+               return canRangeFunc(&t.t)
+       case Pointer:
+               return t.Elem().Kind() == Array
+       }
+       return false
+}
+
+func canRangeFunc(t *abi.Type) bool {
+       if t.Kind() != abi.Func {
+               return false
+       }
+       f := t.FuncType()
+       if f.InCount != 1 || f.OutCount != 0 {
+               return false
+       }
+       y := f.In(0)
+       if y.Kind() != abi.Func {
+               return false
+       }
+       yield := y.FuncType()
+       return yield.InCount == 1 && yield.OutCount == 1 && yield.Out(0).Kind() == abi.Bool
+}
+
+func (t *rtype) CanSeq2() bool {
+       switch t.Kind() {
+       case Array, Slice, String, Map:
+               return true
+       case Func:
+               return canRangeFunc2(&t.t)
+       case Pointer:
+               return t.Elem().Kind() == Array
+       }
+       return false
+}
+
+func canRangeFunc2(t *abi.Type) bool {
+       if t.Kind() != abi.Func {
+               return false
+       }
+       f := t.FuncType()
+       if f.InCount != 1 || f.OutCount != 0 {
+               return false
+       }
+       y := f.In(0)
+       if y.Kind() != abi.Func {
+               return false
+       }
+       yield := y.FuncType()
+       return yield.InCount == 2 && yield.OutCount == 1 && yield.Out(0).Kind() == abi.Bool
+}
+
 // add returns p+x.
 //
 // The whySafe string is ignored, so that the function still inlines
index 200ecf6eca66b627ace3f38759bcbf30fca49411..40ae7131c3ae45d93a4e1be0787fdc529281d0df 100644 (file)
@@ -117,3 +117,53 @@ func BenchmarkTypeForError(b *testing.B) {
                sinkType = reflect.TypeFor[error]()
        }
 }
+
+func Test_Type_CanSeq(t *testing.T) {
+       tests := []struct {
+               name string
+               tr   reflect.Type
+               want bool
+       }{
+               {"func(func(int) bool)", reflect.TypeOf(func(func(int) bool) {}), true},
+               {"func(func(int))", reflect.TypeOf(func(func(int)) {}), false},
+               {"int64", reflect.TypeOf(int64(1)), true},
+               {"uint64", reflect.TypeOf(uint64(1)), true},
+               {"*[4]int", reflect.TypeOf(&[4]int{}), true},
+               {"chan int64", reflect.TypeOf(make(chan int64)), true},
+               {"map[int]int", reflect.TypeOf(make(map[int]int)), true},
+               {"string", reflect.TypeOf(""), true},
+               {"[]int", reflect.TypeOf([]int{}), true},
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       if got := tt.tr.CanSeq(); got != tt.want {
+                               t.Errorf("Type.CanSeq() = %v, want %v", got, tt.want)
+                       }
+               })
+       }
+}
+
+func Test_Type_CanSeq2(t *testing.T) {
+       tests := []struct {
+               name string
+               tr   reflect.Type
+               want bool
+       }{
+               {"func(func(int, int) bool)", reflect.TypeOf(func(func(int, int) bool) {}), true},
+               {"func(func(int, int))", reflect.TypeOf(func(func(int, int)) {}), false},
+               {"int64", reflect.TypeOf(int64(1)), false},
+               {"uint64", reflect.TypeOf(uint64(1)), false},
+               {"*[4]int", reflect.TypeOf(&[4]int{}), true},
+               {"chan int64", reflect.TypeOf(make(chan int64)), false},
+               {"map[int]int", reflect.TypeOf(make(map[int]int)), true},
+               {"string", reflect.TypeOf(""), true},
+               {"[]int", reflect.TypeOf([]int{}), true},
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       if got := tt.tr.CanSeq2(); got != tt.want {
+                               t.Errorf("Type.CanSeq2() = %v, want %v", got, tt.want)
+                       }
+               })
+       }
+}