From ff743ce862440f332f76a8a24333a90b7afc9fa6 Mon Sep 17 00:00:00 2001 From: Michael Anthony Knyszek Date: Tue, 7 May 2024 00:49:03 +0000 Subject: [PATCH] iter: detect and reject double next and double yield in Pull, Pull2 Currently it's possible for next and yield to be called out of sequence, which will result in surprising behavior due to the implementation. Because we blindly coroswitch between goroutines, calling next from the iterator, or yield from the calling goroutine, will actually switch back to the other goroutine. In the case of next, we'll switch back with a stale (or zero) value: the results are basically garbage. In the case of yield, we're switching back to the *same* goroutine, which will crash in the runtime. This change adds a single bool to ensure that next and yield are always called in sequence. That is, every next must always be paired with a yield before continuing. This restricts what can be done with Pull, but prevents observing some truly strange behaviors that the user of Pull likely did not intend, or can't easily predict. Change-Id: I6f72461f49c5635d6914bc5b968ad6970cd3c734 Reviewed-on: https://go-review.googlesource.com/c/go/+/583676 LUCI-TryBot-Result: Go LUCI Reviewed-by: David Chase Auto-Submit: Michael Knyszek Reviewed-by: Russ Cox --- src/iter/iter.go | 36 ++++++++++++----- src/iter/pull_test.go | 94 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 9 deletions(-) diff --git a/src/iter/iter.go b/src/iter/iter.go index 4d9cfad73b..af360aeb07 100644 --- a/src/iter/iter.go +++ b/src/iter/iter.go @@ -50,10 +50,11 @@ func coroswitch(*coro) // simultaneously. func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { var ( - v V - ok bool - done bool - racer int + v V + ok bool + done bool + yieldNext bool + racer int ) c := newcoro(func(c *coro) { race.Acquire(unsafe.Pointer(&racer)) @@ -61,6 +62,10 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { if done { return false } + if !yieldNext { + panic("iter.Pull: yield called again before next") + } + yieldNext = false v, ok = v1, true race.Release(unsafe.Pointer(&racer)) coroswitch(c) @@ -78,6 +83,10 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { if done { return } + if yieldNext { + panic("iter.Pull: next called again before yield") + } + yieldNext = true race.Release(unsafe.Pointer(&racer)) coroswitch(c) race.Acquire(unsafe.Pointer(&racer)) @@ -116,11 +125,12 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { // simultaneously. func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) { var ( - k K - v V - ok bool - done bool - racer int + k K + v V + ok bool + done bool + yieldNext bool + racer int ) c := newcoro(func(c *coro) { race.Acquire(unsafe.Pointer(&racer)) @@ -128,6 +138,10 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) { if done { return false } + if !yieldNext { + panic("iter.Pull2: yield called again before next") + } + yieldNext = false k, v, ok = k1, v1, true race.Release(unsafe.Pointer(&racer)) coroswitch(c) @@ -146,6 +160,10 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) { if done { return } + if yieldNext { + panic("iter.Pull2: next called again before yield") + } + yieldNext = true race.Release(unsafe.Pointer(&racer)) coroswitch(c) race.Acquire(unsafe.Pointer(&racer)) diff --git a/src/iter/pull_test.go b/src/iter/pull_test.go index 4a9510a804..21db1029af 100644 --- a/src/iter/pull_test.go +++ b/src/iter/pull_test.go @@ -114,3 +114,97 @@ func TestPull2(t *testing.T) { }) } } + +func TestPullDoubleNext(t *testing.T) { + next, _ := Pull(doDoubleNext()) + nextSlot = next + next() + if nextSlot != nil { + t.Fatal("double next did not fail") + } +} + +var nextSlot func() (int, bool) + +func doDoubleNext() Seq[int] { + return func(_ func(int) bool) { + defer func() { + if recover() != nil { + nextSlot = nil + } + }() + nextSlot() + } +} + +func TestPullDoubleNext2(t *testing.T) { + next, _ := Pull2(doDoubleNext2()) + nextSlot2 = next + next() + if nextSlot2 != nil { + t.Fatal("double next did not fail") + } +} + +var nextSlot2 func() (int, int, bool) + +func doDoubleNext2() Seq2[int, int] { + return func(_ func(int, int) bool) { + defer func() { + if recover() != nil { + nextSlot2 = nil + } + }() + nextSlot2() + } +} + +func TestPullDoubleYield(t *testing.T) { + _, stop := Pull(storeYield()) + defer func() { + if recover() != nil { + yieldSlot = nil + } + stop() + }() + yieldSlot(5) + if yieldSlot != nil { + t.Fatal("double yield did not fail") + } +} + +func storeYield() Seq[int] { + return func(yield func(int) bool) { + yieldSlot = yield + if !yield(5) { + return + } + } +} + +var yieldSlot func(int) bool + +func TestPullDoubleYield2(t *testing.T) { + _, stop := Pull2(storeYield2()) + defer func() { + if recover() != nil { + yieldSlot2 = nil + } + stop() + }() + yieldSlot2(23, 77) + if yieldSlot2 != nil { + t.Fatal("double yield did not fail") + } +} + +func storeYield2() Seq2[int, int] { + return func(yield func(int, int) bool) { + yieldSlot2 = yield + if !yield(23, 77) { + return + } + } +} + +var yieldSlot2 func(int, int) bool -- 2.48.1