]> Cypherpunks repositories - gostls13.git/commitdiff
iter: propagate runtime.Goexit from iterator passed to Pull
authorMichael Anthony Knyszek <mknyszek@google.com>
Fri, 31 May 2024 20:22:32 +0000 (20:22 +0000)
committerMichael Knyszek <mknyszek@google.com>
Fri, 7 Jun 2024 19:09:18 +0000 (19:09 +0000)
This change propagates a runtime.Goexit initiated by the iterator into
the caller of next and/or stop.

Fixes #67712.

Change-Id: I5bb8d22f749fce39ce4f587148c5fc71aee2af65
Reviewed-on: https://go-review.googlesource.com/c/go/+/589137
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Austin Clements <austin@google.com>
Reviewed-by: David Chase <drchase@google.com>
src/iter/iter.go
src/iter/pull_test.go
src/runtime/coro.go

index 3e93f3bdb7bf1b5d1ae00098feaf22fc946ef7c5..2ce129bb49e0a42ce8eaeafcc0fd6e63e2203057 100644 (file)
@@ -8,6 +8,7 @@ package iter
 
 import (
        "internal/race"
+       "runtime"
        "unsafe"
 )
 
@@ -56,6 +57,7 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
                yieldNext  bool
                racer      int
                panicValue any
+               seqDone    bool // to detect Goexit
        )
        c := newcoro(func(c *coro) {
                race.Acquire(unsafe.Pointer(&racer))
@@ -76,15 +78,17 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
                // Recover and propagate panics from seq.
                defer func() {
                        if p := recover(); p != nil {
-                               done = true // Invalidate iterator.
                                panicValue = p
+                       } else if !seqDone {
+                               panicValue = goexitPanicValue
                        }
+                       done = true // Invalidate iterator
                        race.Release(unsafe.Pointer(&racer))
                }()
                seq(yield)
                var v0 V
                v, ok = v0, false
-               done = true
+               seqDone = true
        })
        next = func() (v1 V, ok1 bool) {
                race.Write(unsafe.Pointer(&racer)) // detect races
@@ -100,9 +104,14 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
                coroswitch(c)
                race.Acquire(unsafe.Pointer(&racer))
 
-               // Propagate panics from seq.
+               // Propagate panics and goexits from seq.
                if panicValue != nil {
-                       panic(panicValue)
+                       if panicValue == goexitPanicValue {
+                               // Propagate runtime.Goexit from seq.
+                               runtime.Goexit()
+                       } else {
+                               panic(panicValue)
+                       }
                }
                return v, ok
        }
@@ -115,9 +124,14 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
                        coroswitch(c)
                        race.Acquire(unsafe.Pointer(&racer))
 
-                       // Propagate panics from seq.
+                       // Propagate panics and goexits from seq.
                        if panicValue != nil {
-                               panic(panicValue)
+                               if panicValue == goexitPanicValue {
+                                       // Propagate runtime.Goexit from seq.
+                                       runtime.Goexit()
+                               } else {
+                                       panic(panicValue)
+                               }
                        }
                }
        }
@@ -152,6 +166,7 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
                yieldNext  bool
                racer      int
                panicValue any
+               seqDone    bool
        )
        c := newcoro(func(c *coro) {
                race.Acquire(unsafe.Pointer(&racer))
@@ -172,16 +187,18 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
                // Recover and propagate panics from seq.
                defer func() {
                        if p := recover(); p != nil {
-                               done = true // Invalidate iterator.
                                panicValue = p
+                       } else if !seqDone {
+                               panicValue = goexitPanicValue
                        }
+                       done = true // Invalidate iterator.
                        race.Release(unsafe.Pointer(&racer))
                }()
                seq(yield)
                var k0 K
                var v0 V
                k, v, ok = k0, v0, false
-               done = true
+               seqDone = true
        })
        next = func() (k1 K, v1 V, ok1 bool) {
                race.Write(unsafe.Pointer(&racer)) // detect races
@@ -197,9 +214,14 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
                coroswitch(c)
                race.Acquire(unsafe.Pointer(&racer))
 
-               // Propagate panics from seq.
+               // Propagate panics and goexits from seq.
                if panicValue != nil {
-                       panic(panicValue)
+                       if panicValue == goexitPanicValue {
+                               // Propagate runtime.Goexit from seq.
+                               runtime.Goexit()
+                       } else {
+                               panic(panicValue)
+                       }
                }
                return k, v, ok
        }
@@ -212,11 +234,20 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
                        coroswitch(c)
                        race.Acquire(unsafe.Pointer(&racer))
 
-                       // Propagate panics from seq.
+                       // Propagate panics and goexits from seq.
                        if panicValue != nil {
-                               panic(panicValue)
+                               if panicValue == goexitPanicValue {
+                                       // Propagate runtime.Goexit from seq.
+                                       runtime.Goexit()
+                               } else {
+                                       panic(panicValue)
+                               }
                        }
                }
        }
        return next, stop
 }
+
+// goexitPanicValue is a sentinel value indicating that an iterator
+// exited via runtime.Goexit.
+var goexitPanicValue any = new(int)
index 09f2270fa1d5f73958c692792bde2de26be0bb4b..0d3f5ab26b9eb5def9989fc2e4dfdaae8bf657a2 100644 (file)
@@ -320,3 +320,92 @@ func panicsWith(v any, f func()) (panicked bool) {
        f()
        return
 }
+
+func TestPullGoexit(t *testing.T) {
+       t.Run("next", func(t *testing.T) {
+               var next func() (int, bool)
+               var stop func()
+               if !goexits(t, func() {
+                       next, stop = Pull(goexitSeq())
+                       next()
+               }) {
+                       t.Fatal("failed to Goexit from next")
+               }
+               if x, ok := next(); x != 0 || ok {
+                       t.Fatal("iterator returned valid value after Goexit")
+               }
+               stop()
+       })
+       t.Run("stop", func(t *testing.T) {
+               var next func() (int, bool)
+               var stop func()
+               if !goexits(t, func() {
+                       next, stop = Pull(goexitSeq())
+                       stop()
+               }) {
+                       t.Fatal("failed to Goexit from stop")
+               }
+               if x, ok := next(); x != 0 || ok {
+                       t.Fatal("iterator returned valid value after Goexit")
+               }
+               stop()
+       })
+}
+
+func goexitSeq() Seq[int] {
+       return func(yield func(int) bool) {
+               runtime.Goexit()
+       }
+}
+
+func TestPull2Goexit(t *testing.T) {
+       t.Run("next", func(t *testing.T) {
+               var next func() (int, int, bool)
+               var stop func()
+               if !goexits(t, func() {
+                       next, stop = Pull2(goexitSeq2())
+                       next()
+               }) {
+                       t.Fatal("failed to Goexit from next")
+               }
+               if x, y, ok := next(); x != 0 || y != 0 || ok {
+                       t.Fatal("iterator returned valid value after Goexit")
+               }
+               stop()
+       })
+       t.Run("stop", func(t *testing.T) {
+               var next func() (int, int, bool)
+               var stop func()
+               if !goexits(t, func() {
+                       next, stop = Pull2(goexitSeq2())
+                       stop()
+               }) {
+                       t.Fatal("failed to Goexit from stop")
+               }
+               if x, y, ok := next(); x != 0 || y != 0 || ok {
+                       t.Fatal("iterator returned valid value after Goexit")
+               }
+               stop()
+       })
+}
+
+func goexitSeq2() Seq2[int, int] {
+       return func(yield func(int, int) bool) {
+               runtime.Goexit()
+       }
+}
+
+func goexits(t *testing.T, f func()) bool {
+       t.Helper()
+
+       exit := make(chan bool)
+       go func() {
+               cleanExit := false
+               defer func() {
+                       exit <- recover() == nil && !cleanExit
+               }()
+               f()
+               cleanExit = true
+       }()
+       return <-exit
+}
index 3d39d134939c2d180a789d6c5ad5e7f600ba2c3d..30ada455e4985e1c92071aaa34bb3e890bbc2afc 100644 (file)
@@ -68,8 +68,8 @@ func corostart() {
        c := gp.coroarg
        gp.coroarg = nil
 
+       defer coroexit(c)
        c.f(c)
-       coroexit(c)
 }
 
 // coroexit is like coroswitch but closes the coro