]> Cypherpunks repositories - gostls13.git/commitdiff
iter: don't iterate if stop is called before next on Pull
authorMichael Anthony Knyszek <mknyszek@google.com>
Mon, 3 Jun 2024 21:00:51 +0000 (21:00 +0000)
committerMichael Knyszek <mknyszek@google.com>
Fri, 7 Jun 2024 19:09:28 +0000 (19:09 +0000)
Consider the following code snippet:

    next, stop := iter.Pull(seq)
    stop()

Today, seq will iterate exactly once before it notices that its
iteration is invalid to begin with. This effect is observable in a
variety of ways. For example, if the iterator panics, since that panic
must propagate to the caller of stop. But if the iterator is stateful in
anyway, then it may update some state.

This is somewhat unexpected and because it's observable, can be depended
upon. This behavior does not align well with other possible
implementations of Pull, like CPS performed by the compiler. It's also
just odd to let even one iteration happen, precisely because of
unexpected state modification.

Fix this by not iterating at all of the done flag is set before entering
the iterator.

For #67712.

Change-Id: I18162e29df45a2e8968f68379450d92e1de47c4d
Reviewed-on: https://go-review.googlesource.com/c/go/+/590075
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Austin Clements <austin@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/iter/iter.go
src/iter/pull_test.go

index 2ce129bb49e0a42ce8eaeafcc0fd6e63e2203057..30f65f7e4895424f5d33502bfdca24d8fc72ee4a 100644 (file)
@@ -61,6 +61,10 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
        )
        c := newcoro(func(c *coro) {
                race.Acquire(unsafe.Pointer(&racer))
+               if done {
+                       race.Release(unsafe.Pointer(&racer))
+                       return
+               }
                yield := func(v1 V) bool {
                        if done {
                                return false
@@ -170,6 +174,10 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
        )
        c := newcoro(func(c *coro) {
                race.Acquire(unsafe.Pointer(&racer))
+               if done {
+                       race.Release(unsafe.Pointer(&racer))
+                       return
+               }
                yield := func(k1 K, v1 V) bool {
                        if done {
                                return false
index 0d3f5ab26b9eb5def9989fc2e4dfdaae8bf657a2..449edee031dcd7c43354faf92892cd7ea53bdb6e 100644 (file)
@@ -256,9 +256,13 @@ func TestPullPanic(t *testing.T) {
                stop()
        })
        t.Run("stop", func(t *testing.T) {
-               next, stop := Pull(panicSeq())
+               next, stop := Pull(panicCleanupSeq())
+               x, ok := next()
+               if !ok || x != 55 {
+                       t.Fatalf("expected (55, true) from next, got (%d, %t)", x, ok)
+               }
                if !panicsWith("boom", func() { stop() }) {
-                       t.Fatal("failed to propagate panic on first stop")
+                       t.Fatal("failed to propagate panic on stop")
                }
                // Make sure we don't panic again if we try to call next or stop.
                if _, ok := next(); ok {
@@ -275,6 +279,16 @@ func panicSeq() Seq[int] {
        }
 }
 
+func panicCleanupSeq() Seq[int] {
+       return func(yield func(int) bool) {
+               for {
+                       if !yield(55) {
+                               panic("boom")
+                       }
+               }
+       }
+}
+
 func TestPull2Panic(t *testing.T) {
        t.Run("next", func(t *testing.T) {
                next, stop := Pull2(panicSeq2())
@@ -289,9 +303,13 @@ func TestPull2Panic(t *testing.T) {
                stop()
        })
        t.Run("stop", func(t *testing.T) {
-               next, stop := Pull2(panicSeq2())
+               next, stop := Pull2(panicCleanupSeq2())
+               x, y, ok := next()
+               if !ok || x != 55 || y != 100 {
+                       t.Fatalf("expected (55, 100, true) from next, got (%d, %d, %t)", x, y, ok)
+               }
                if !panicsWith("boom", func() { stop() }) {
-                       t.Fatal("failed to propagate panic on first stop")
+                       t.Fatal("failed to propagate panic on stop")
                }
                // Make sure we don't panic again if we try to call next or stop.
                if _, _, ok := next(); ok {
@@ -308,6 +326,16 @@ func panicSeq2() Seq2[int, int] {
        }
 }
 
+func panicCleanupSeq2() Seq2[int, int] {
+       return func(yield func(int, int) bool) {
+               for {
+                       if !yield(55, 100) {
+                               panic("boom")
+                       }
+               }
+       }
+}
+
 func panicsWith(v any, f func()) (panicked bool) {
        defer func() {
                if r := recover(); r != nil {
@@ -332,22 +360,26 @@ func TestPullGoexit(t *testing.T) {
                        t.Fatal("failed to Goexit from next")
                }
                if x, ok := next(); x != 0 || ok {
-                       t.Fatal("iterator returned valid value after Goexit")
+                       t.Fatal("iterator returned valid value after iterator Goexited")
                }
                stop()
        })
        t.Run("stop", func(t *testing.T) {
-               var next func() (int, bool)
-               var stop func()
+               next, stop := Pull(goexitCleanupSeq())
+               x, ok := next()
+               if !ok || x != 55 {
+                       t.Fatalf("expected (55, true) from next, got (%d, %t)", x, ok)
+               }
                if !goexits(t, func() {
-                       next, stop = Pull(goexitSeq())
                        stop()
                }) {
                        t.Fatal("failed to Goexit from stop")
                }
+               // Make sure we don't panic again if we try to call next or stop.
                if x, ok := next(); x != 0 || ok {
-                       t.Fatal("iterator returned valid value after Goexit")
+                       t.Fatal("next returned true or non-zero value after iterator Goexited")
                }
+               // Calling stop again should be a no-op.
                stop()
        })
 }
@@ -358,6 +390,16 @@ func goexitSeq() Seq[int] {
        }
 }
 
+func goexitCleanupSeq() Seq[int] {
+       return func(yield func(int) bool) {
+               for {
+                       if !yield(55) {
+                               runtime.Goexit()
+                       }
+               }
+       }
+}
+
 func TestPull2Goexit(t *testing.T) {
        t.Run("next", func(t *testing.T) {
                var next func() (int, int, bool)
@@ -369,22 +411,26 @@ func TestPull2Goexit(t *testing.T) {
                        t.Fatal("failed to Goexit from next")
                }
                if x, y, ok := next(); x != 0 || y != 0 || ok {
-                       t.Fatal("iterator returned valid value after Goexit")
+                       t.Fatal("iterator returned valid value after iterator Goexited")
                }
                stop()
        })
        t.Run("stop", func(t *testing.T) {
-               var next func() (int, int, bool)
-               var stop func()
+               next, stop := Pull2(goexitCleanupSeq2())
+               x, y, ok := next()
+               if !ok || x != 55 || y != 100 {
+                       t.Fatalf("expected (55, 100, true) from next, got (%d, %d, %t)", x, y, ok)
+               }
                if !goexits(t, func() {
-                       next, stop = Pull2(goexitSeq2())
                        stop()
                }) {
                        t.Fatal("failed to Goexit from stop")
                }
+               // Make sure we don't panic again if we try to call next or stop.
                if x, y, ok := next(); x != 0 || y != 0 || ok {
-                       t.Fatal("iterator returned valid value after Goexit")
+                       t.Fatal("next returned true or non-zero after iterator Goexited")
                }
+               // Calling stop again should be a no-op.
                stop()
        })
 }
@@ -395,6 +441,16 @@ func goexitSeq2() Seq2[int, int] {
        }
 }
 
+func goexitCleanupSeq2() Seq2[int, int] {
+       return func(yield func(int, int) bool) {
+               for {
+                       if !yield(55, 100) {
+                               runtime.Goexit()
+                       }
+               }
+       }
+}
+
 func goexits(t *testing.T, f func()) bool {
        t.Helper()
 
@@ -409,3 +465,21 @@ func goexits(t *testing.T, f func()) bool {
        }()
        return <-exit
 }
+
+func TestPullImmediateStop(t *testing.T) {
+       next, stop := Pull(panicSeq())
+       stop()
+       // Make sure we don't panic if we try to call next or stop.
+       if _, ok := next(); ok {
+               t.Fatal("next returned true after iterator was stopped")
+       }
+}
+
+func TestPull2ImmediateStop(t *testing.T) {
+       next, stop := Pull2(panicSeq2())
+       stop()
+       // Make sure we don't panic if we try to call next or stop.
+       if _, _, ok := next(); ok {
+               t.Fatal("next returned true after iterator was stopped")
+       }
+}