]> Cypherpunks repositories - gostls13.git/commitdiff
iter: propagate panics from the iterator passed to Pull
authorMichael Anthony Knyszek <mknyszek@google.com>
Fri, 31 May 2024 20:10:09 +0000 (20:10 +0000)
committerMichael Knyszek <mknyszek@google.com>
Fri, 7 Jun 2024 19:09:09 +0000 (19:09 +0000)
This change propagates panics from the iterator passed to Pull through
next and stop. Once the panic occurs, next and stop become no-ops (the
iterator is invalidated).

For #67712.

Change-Id: I05e45601d4d10acdf51b53e3164bd891c1b324ac
Reviewed-on: https://go-review.googlesource.com/c/go/+/589136
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Austin Clements <austin@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/iter/iter.go
src/iter/pull_test.go

index af360aeb079de4d5d53319bc3753895bb86492c5..3e93f3bdb7bf1b5d1ae00098feaf22fc946ef7c5 100644 (file)
@@ -50,11 +50,12 @@ func coroswitch(*coro)
 // simultaneously.
 func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
        var (
-               v         V
-               ok        bool
-               done      bool
-               yieldNext bool
-               racer     int
+               v          V
+               ok         bool
+               done       bool
+               yieldNext  bool
+               racer      int
+               panicValue any
        )
        c := newcoro(func(c *coro) {
                race.Acquire(unsafe.Pointer(&racer))
@@ -72,14 +73,22 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
                        race.Acquire(unsafe.Pointer(&racer))
                        return !done
                }
+               // Recover and propagate panics from seq.
+               defer func() {
+                       if p := recover(); p != nil {
+                               done = true // Invalidate iterator.
+                               panicValue = p
+                       }
+                       race.Release(unsafe.Pointer(&racer))
+               }()
                seq(yield)
                var v0 V
                v, ok = v0, false
                done = true
-               race.Release(unsafe.Pointer(&racer))
        })
        next = func() (v1 V, ok1 bool) {
                race.Write(unsafe.Pointer(&racer)) // detect races
+
                if done {
                        return
                }
@@ -90,15 +99,26 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
                race.Release(unsafe.Pointer(&racer))
                coroswitch(c)
                race.Acquire(unsafe.Pointer(&racer))
+
+               // Propagate panics from seq.
+               if panicValue != nil {
+                       panic(panicValue)
+               }
                return v, ok
        }
        stop = func() {
                race.Write(unsafe.Pointer(&racer)) // detect races
+
                if !done {
                        done = true
                        race.Release(unsafe.Pointer(&racer))
                        coroswitch(c)
                        race.Acquire(unsafe.Pointer(&racer))
+
+                       // Propagate panics from seq.
+                       if panicValue != nil {
+                               panic(panicValue)
+                       }
                }
        }
        return next, stop
@@ -125,12 +145,13 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
 // simultaneously.
 func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
        var (
-               k         K
-               v         V
-               ok        bool
-               done      bool
-               yieldNext bool
-               racer     int
+               k          K
+               v          V
+               ok         bool
+               done       bool
+               yieldNext  bool
+               racer      int
+               panicValue any
        )
        c := newcoro(func(c *coro) {
                race.Acquire(unsafe.Pointer(&racer))
@@ -148,15 +169,23 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
                        race.Acquire(unsafe.Pointer(&racer))
                        return !done
                }
+               // Recover and propagate panics from seq.
+               defer func() {
+                       if p := recover(); p != nil {
+                               done = true // Invalidate iterator.
+                               panicValue = p
+                       }
+                       race.Release(unsafe.Pointer(&racer))
+               }()
                seq(yield)
                var k0 K
                var v0 V
                k, v, ok = k0, v0, false
                done = true
-               race.Release(unsafe.Pointer(&racer))
        })
        next = func() (k1 K, v1 V, ok1 bool) {
                race.Write(unsafe.Pointer(&racer)) // detect races
+
                if done {
                        return
                }
@@ -167,15 +196,26 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
                race.Release(unsafe.Pointer(&racer))
                coroswitch(c)
                race.Acquire(unsafe.Pointer(&racer))
+
+               // Propagate panics from seq.
+               if panicValue != nil {
+                       panic(panicValue)
+               }
                return k, v, ok
        }
        stop = func() {
                race.Write(unsafe.Pointer(&racer)) // detect races
+
                if !done {
                        done = true
                        race.Release(unsafe.Pointer(&racer))
                        coroswitch(c)
                        race.Acquire(unsafe.Pointer(&racer))
+
+                       // Propagate panics from seq.
+                       if panicValue != nil {
+                               panic(panicValue)
+                       }
                }
        }
        return next, stop
index c39574b959608b1bb082a1bc8519748c1013a5f4..09f2270fa1d5f73958c692792bde2de26be0bb4b 100644 (file)
@@ -241,3 +241,82 @@ func storeYield2() Seq2[int, int] {
 }
 
 var yieldSlot2 func(int, int) bool
+
+func TestPullPanic(t *testing.T) {
+       t.Run("next", func(t *testing.T) {
+               next, stop := Pull(panicSeq())
+               if !panicsWith("boom", func() { next() }) {
+                       t.Fatal("failed to propagate panic on first next")
+               }
+               // Make sure we don't panic again if we try to call next or stop.
+               if _, ok := next(); ok {
+                       t.Fatal("next returned true after iterator panicked")
+               }
+               // Calling stop again should be a no-op.
+               stop()
+       })
+       t.Run("stop", func(t *testing.T) {
+               next, stop := Pull(panicSeq())
+               if !panicsWith("boom", func() { stop() }) {
+                       t.Fatal("failed to propagate panic on first stop")
+               }
+               // Make sure we don't panic again if we try to call next or stop.
+               if _, ok := next(); ok {
+                       t.Fatal("next returned true after iterator panicked")
+               }
+               // Calling stop again should be a no-op.
+               stop()
+       })
+}
+
+func panicSeq() Seq[int] {
+       return func(yield func(int) bool) {
+               panic("boom")
+       }
+}
+
+func TestPull2Panic(t *testing.T) {
+       t.Run("next", func(t *testing.T) {
+               next, stop := Pull2(panicSeq2())
+               if !panicsWith("boom", func() { next() }) {
+                       t.Fatal("failed to propagate panic on first next")
+               }
+               // Make sure we don't panic again if we try to call next or stop.
+               if _, _, ok := next(); ok {
+                       t.Fatal("next returned true after iterator panicked")
+               }
+               // Calling stop again should be a no-op.
+               stop()
+       })
+       t.Run("stop", func(t *testing.T) {
+               next, stop := Pull2(panicSeq2())
+               if !panicsWith("boom", func() { stop() }) {
+                       t.Fatal("failed to propagate panic on first stop")
+               }
+               // Make sure we don't panic again if we try to call next or stop.
+               if _, _, ok := next(); ok {
+                       t.Fatal("next returned true after iterator panicked")
+               }
+               // Calling stop again should be a no-op.
+               stop()
+       })
+}
+
+func panicSeq2() Seq2[int, int] {
+       return func(yield func(int, int) bool) {
+               panic("boom")
+       }
+}
+
+func panicsWith(v any, f func()) (panicked bool) {
+       defer func() {
+               if r := recover(); r != nil {
+                       if r != v {
+                               panic(r)
+                       }
+                       panicked = true
+               }
+       }()
+       f()
+       return
+}