From: Michael Anthony Knyszek Date: Mon, 3 Jun 2024 21:00:51 +0000 (+0000) Subject: iter: don't iterate if stop is called before next on Pull X-Git-Tag: go1.23rc1~59 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=1634fde4f918223614fd8893db8dd7ca4ebcda01;p=gostls13.git iter: don't iterate if stop is called before next on Pull Consider the following code snippet: next, stop := iter.Pull(seq) stop() Today, seq will iterate exactly once before it notices that its iteration is invalid to begin with. This effect is observable in a variety of ways. For example, if the iterator panics, since that panic must propagate to the caller of stop. But if the iterator is stateful in anyway, then it may update some state. This is somewhat unexpected and because it's observable, can be depended upon. This behavior does not align well with other possible implementations of Pull, like CPS performed by the compiler. It's also just odd to let even one iteration happen, precisely because of unexpected state modification. Fix this by not iterating at all of the done flag is set before entering the iterator. For #67712. Change-Id: I18162e29df45a2e8968f68379450d92e1de47c4d Reviewed-on: https://go-review.googlesource.com/c/go/+/590075 Reviewed-by: David Chase Reviewed-by: Austin Clements LUCI-TryBot-Result: Go LUCI --- diff --git a/src/iter/iter.go b/src/iter/iter.go index 2ce129bb49..30f65f7e48 100644 --- a/src/iter/iter.go +++ b/src/iter/iter.go @@ -61,6 +61,10 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { ) c := newcoro(func(c *coro) { race.Acquire(unsafe.Pointer(&racer)) + if done { + race.Release(unsafe.Pointer(&racer)) + return + } yield := func(v1 V) bool { if done { return false @@ -170,6 +174,10 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) { ) c := newcoro(func(c *coro) { race.Acquire(unsafe.Pointer(&racer)) + if done { + race.Release(unsafe.Pointer(&racer)) + return + } yield := func(k1 K, v1 V) bool { if done { return false diff --git a/src/iter/pull_test.go b/src/iter/pull_test.go index 0d3f5ab26b..449edee031 100644 --- a/src/iter/pull_test.go +++ b/src/iter/pull_test.go @@ -256,9 +256,13 @@ func TestPullPanic(t *testing.T) { stop() }) t.Run("stop", func(t *testing.T) { - next, stop := Pull(panicSeq()) + next, stop := Pull(panicCleanupSeq()) + x, ok := next() + if !ok || x != 55 { + t.Fatalf("expected (55, true) from next, got (%d, %t)", x, ok) + } if !panicsWith("boom", func() { stop() }) { - t.Fatal("failed to propagate panic on first stop") + t.Fatal("failed to propagate panic on stop") } // Make sure we don't panic again if we try to call next or stop. if _, ok := next(); ok { @@ -275,6 +279,16 @@ func panicSeq() Seq[int] { } } +func panicCleanupSeq() Seq[int] { + return func(yield func(int) bool) { + for { + if !yield(55) { + panic("boom") + } + } + } +} + func TestPull2Panic(t *testing.T) { t.Run("next", func(t *testing.T) { next, stop := Pull2(panicSeq2()) @@ -289,9 +303,13 @@ func TestPull2Panic(t *testing.T) { stop() }) t.Run("stop", func(t *testing.T) { - next, stop := Pull2(panicSeq2()) + next, stop := Pull2(panicCleanupSeq2()) + x, y, ok := next() + if !ok || x != 55 || y != 100 { + t.Fatalf("expected (55, 100, true) from next, got (%d, %d, %t)", x, y, ok) + } if !panicsWith("boom", func() { stop() }) { - t.Fatal("failed to propagate panic on first stop") + t.Fatal("failed to propagate panic on stop") } // Make sure we don't panic again if we try to call next or stop. if _, _, ok := next(); ok { @@ -308,6 +326,16 @@ func panicSeq2() Seq2[int, int] { } } +func panicCleanupSeq2() Seq2[int, int] { + return func(yield func(int, int) bool) { + for { + if !yield(55, 100) { + panic("boom") + } + } + } +} + func panicsWith(v any, f func()) (panicked bool) { defer func() { if r := recover(); r != nil { @@ -332,22 +360,26 @@ func TestPullGoexit(t *testing.T) { t.Fatal("failed to Goexit from next") } if x, ok := next(); x != 0 || ok { - t.Fatal("iterator returned valid value after Goexit") + t.Fatal("iterator returned valid value after iterator Goexited") } stop() }) t.Run("stop", func(t *testing.T) { - var next func() (int, bool) - var stop func() + next, stop := Pull(goexitCleanupSeq()) + x, ok := next() + if !ok || x != 55 { + t.Fatalf("expected (55, true) from next, got (%d, %t)", x, ok) + } if !goexits(t, func() { - next, stop = Pull(goexitSeq()) stop() }) { t.Fatal("failed to Goexit from stop") } + // Make sure we don't panic again if we try to call next or stop. if x, ok := next(); x != 0 || ok { - t.Fatal("iterator returned valid value after Goexit") + t.Fatal("next returned true or non-zero value after iterator Goexited") } + // Calling stop again should be a no-op. stop() }) } @@ -358,6 +390,16 @@ func goexitSeq() Seq[int] { } } +func goexitCleanupSeq() Seq[int] { + return func(yield func(int) bool) { + for { + if !yield(55) { + runtime.Goexit() + } + } + } +} + func TestPull2Goexit(t *testing.T) { t.Run("next", func(t *testing.T) { var next func() (int, int, bool) @@ -369,22 +411,26 @@ func TestPull2Goexit(t *testing.T) { t.Fatal("failed to Goexit from next") } if x, y, ok := next(); x != 0 || y != 0 || ok { - t.Fatal("iterator returned valid value after Goexit") + t.Fatal("iterator returned valid value after iterator Goexited") } stop() }) t.Run("stop", func(t *testing.T) { - var next func() (int, int, bool) - var stop func() + next, stop := Pull2(goexitCleanupSeq2()) + x, y, ok := next() + if !ok || x != 55 || y != 100 { + t.Fatalf("expected (55, 100, true) from next, got (%d, %d, %t)", x, y, ok) + } if !goexits(t, func() { - next, stop = Pull2(goexitSeq2()) stop() }) { t.Fatal("failed to Goexit from stop") } + // Make sure we don't panic again if we try to call next or stop. if x, y, ok := next(); x != 0 || y != 0 || ok { - t.Fatal("iterator returned valid value after Goexit") + t.Fatal("next returned true or non-zero after iterator Goexited") } + // Calling stop again should be a no-op. stop() }) } @@ -395,6 +441,16 @@ func goexitSeq2() Seq2[int, int] { } } +func goexitCleanupSeq2() Seq2[int, int] { + return func(yield func(int, int) bool) { + for { + if !yield(55, 100) { + runtime.Goexit() + } + } + } +} + func goexits(t *testing.T, f func()) bool { t.Helper() @@ -409,3 +465,21 @@ func goexits(t *testing.T, f func()) bool { }() return <-exit } + +func TestPullImmediateStop(t *testing.T) { + next, stop := Pull(panicSeq()) + stop() + // Make sure we don't panic if we try to call next or stop. + if _, ok := next(); ok { + t.Fatal("next returned true after iterator was stopped") + } +} + +func TestPull2ImmediateStop(t *testing.T) { + next, stop := Pull2(panicSeq2()) + stop() + // Make sure we don't panic if we try to call next or stop. + if _, _, ok := next(); ok { + t.Fatal("next returned true after iterator was stopped") + } +}