]> Cypherpunks repositories - gostls13.git/commitdiff
iter: detect and reject double next and double yield in Pull, Pull2
authorMichael Anthony Knyszek <mknyszek@google.com>
Tue, 7 May 2024 00:49:03 +0000 (00:49 +0000)
committerMichael Knyszek <mknyszek@google.com>
Wed, 8 May 2024 18:03:22 +0000 (18:03 +0000)
Currently it's possible for next and yield to be called out of sequence,
which will result in surprising behavior due to the implementation.
Because we blindly coroswitch between goroutines, calling next from the
iterator, or yield from the calling goroutine, will actually switch back
to the other goroutine. In the case of next, we'll switch back with a
stale (or zero) value: the results are basically garbage. In the case of
yield, we're switching back to the *same* goroutine, which will crash in
the runtime.

This change adds a single bool to ensure that next and yield are always
called in sequence. That is, every next must always be paired with a
yield before continuing. This restricts what can be done with Pull, but
prevents observing some truly strange behaviors that the user of Pull
likely did not intend, or can't easily predict.

Change-Id: I6f72461f49c5635d6914bc5b968ad6970cd3c734
Reviewed-on: https://go-review.googlesource.com/c/go/+/583676
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: David Chase <drchase@google.com>
Auto-Submit: Michael Knyszek <mknyszek@google.com>
Reviewed-by: Russ Cox <rsc@golang.org>
src/iter/iter.go
src/iter/pull_test.go

index 4d9cfad73bb7ce499cc212e860652bb9c7ff0db5..af360aeb079de4d5d53319bc3753895bb86492c5 100644 (file)
@@ -50,10 +50,11 @@ func coroswitch(*coro)
 // simultaneously.
 func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
        var (
-               v     V
-               ok    bool
-               done  bool
-               racer int
+               v         V
+               ok        bool
+               done      bool
+               yieldNext bool
+               racer     int
        )
        c := newcoro(func(c *coro) {
                race.Acquire(unsafe.Pointer(&racer))
@@ -61,6 +62,10 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
                        if done {
                                return false
                        }
+                       if !yieldNext {
+                               panic("iter.Pull: yield called again before next")
+                       }
+                       yieldNext = false
                        v, ok = v1, true
                        race.Release(unsafe.Pointer(&racer))
                        coroswitch(c)
@@ -78,6 +83,10 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
                if done {
                        return
                }
+               if yieldNext {
+                       panic("iter.Pull: next called again before yield")
+               }
+               yieldNext = true
                race.Release(unsafe.Pointer(&racer))
                coroswitch(c)
                race.Acquire(unsafe.Pointer(&racer))
@@ -116,11 +125,12 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
 // simultaneously.
 func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
        var (
-               k     K
-               v     V
-               ok    bool
-               done  bool
-               racer int
+               k         K
+               v         V
+               ok        bool
+               done      bool
+               yieldNext bool
+               racer     int
        )
        c := newcoro(func(c *coro) {
                race.Acquire(unsafe.Pointer(&racer))
@@ -128,6 +138,10 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
                        if done {
                                return false
                        }
+                       if !yieldNext {
+                               panic("iter.Pull2: yield called again before next")
+                       }
+                       yieldNext = false
                        k, v, ok = k1, v1, true
                        race.Release(unsafe.Pointer(&racer))
                        coroswitch(c)
@@ -146,6 +160,10 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
                if done {
                        return
                }
+               if yieldNext {
+                       panic("iter.Pull2: next called again before yield")
+               }
+               yieldNext = true
                race.Release(unsafe.Pointer(&racer))
                coroswitch(c)
                race.Acquire(unsafe.Pointer(&racer))
index 4a9510a8045f5ab791c21236d0a0c2bf90ca345b..21db1029af1cbb693d305529cf7be0f75cf13286 100644 (file)
@@ -114,3 +114,97 @@ func TestPull2(t *testing.T) {
                })
        }
 }
+
+func TestPullDoubleNext(t *testing.T) {
+       next, _ := Pull(doDoubleNext())
+       nextSlot = next
+       next()
+       if nextSlot != nil {
+               t.Fatal("double next did not fail")
+       }
+}
+
+var nextSlot func() (int, bool)
+
+func doDoubleNext() Seq[int] {
+       return func(_ func(int) bool) {
+               defer func() {
+                       if recover() != nil {
+                               nextSlot = nil
+                       }
+               }()
+               nextSlot()
+       }
+}
+
+func TestPullDoubleNext2(t *testing.T) {
+       next, _ := Pull2(doDoubleNext2())
+       nextSlot2 = next
+       next()
+       if nextSlot2 != nil {
+               t.Fatal("double next did not fail")
+       }
+}
+
+var nextSlot2 func() (int, int, bool)
+
+func doDoubleNext2() Seq2[int, int] {
+       return func(_ func(int, int) bool) {
+               defer func() {
+                       if recover() != nil {
+                               nextSlot2 = nil
+                       }
+               }()
+               nextSlot2()
+       }
+}
+
+func TestPullDoubleYield(t *testing.T) {
+       _, stop := Pull(storeYield())
+       defer func() {
+               if recover() != nil {
+                       yieldSlot = nil
+               }
+               stop()
+       }()
+       yieldSlot(5)
+       if yieldSlot != nil {
+               t.Fatal("double yield did not fail")
+       }
+}
+
+func storeYield() Seq[int] {
+       return func(yield func(int) bool) {
+               yieldSlot = yield
+               if !yield(5) {
+                       return
+               }
+       }
+}
+
+var yieldSlot func(int) bool
+
+func TestPullDoubleYield2(t *testing.T) {
+       _, stop := Pull2(storeYield2())
+       defer func() {
+               if recover() != nil {
+                       yieldSlot2 = nil
+               }
+               stop()
+       }()
+       yieldSlot2(23, 77)
+       if yieldSlot2 != nil {
+               t.Fatal("double yield did not fail")
+       }
+}
+
+func storeYield2() Seq2[int, int] {
+       return func(yield func(int, int) bool) {
+               yieldSlot2 = yield
+               if !yield(23, 77) {
+                       return
+               }
+       }
+}
+
+var yieldSlot2 func(int, int) bool