From 9d2aeae72d34880510c7221f35ab61171cec1ffd Mon Sep 17 00:00:00 2001 From: Michael Anthony Knyszek Date: Fri, 31 May 2024 20:22:32 +0000 Subject: [PATCH] iter: propagate runtime.Goexit from iterator passed to Pull This change propagates a runtime.Goexit initiated by the iterator into the caller of next and/or stop. Fixes #67712. Change-Id: I5bb8d22f749fce39ce4f587148c5fc71aee2af65 Reviewed-on: https://go-review.googlesource.com/c/go/+/589137 LUCI-TryBot-Result: Go LUCI Reviewed-by: Austin Clements Reviewed-by: David Chase --- src/iter/iter.go | 55 ++++++++++++++++++++------ src/iter/pull_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++ src/runtime/coro.go | 2 +- 3 files changed, 133 insertions(+), 13 deletions(-) diff --git a/src/iter/iter.go b/src/iter/iter.go index 3e93f3bdb7..2ce129bb49 100644 --- a/src/iter/iter.go +++ b/src/iter/iter.go @@ -8,6 +8,7 @@ package iter import ( "internal/race" + "runtime" "unsafe" ) @@ -56,6 +57,7 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { yieldNext bool racer int panicValue any + seqDone bool // to detect Goexit ) c := newcoro(func(c *coro) { race.Acquire(unsafe.Pointer(&racer)) @@ -76,15 +78,17 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { // Recover and propagate panics from seq. defer func() { if p := recover(); p != nil { - done = true // Invalidate iterator. panicValue = p + } else if !seqDone { + panicValue = goexitPanicValue } + done = true // Invalidate iterator race.Release(unsafe.Pointer(&racer)) }() seq(yield) var v0 V v, ok = v0, false - done = true + seqDone = true }) next = func() (v1 V, ok1 bool) { race.Write(unsafe.Pointer(&racer)) // detect races @@ -100,9 +104,14 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { coroswitch(c) race.Acquire(unsafe.Pointer(&racer)) - // Propagate panics from seq. + // Propagate panics and goexits from seq. if panicValue != nil { - panic(panicValue) + if panicValue == goexitPanicValue { + // Propagate runtime.Goexit from seq. + runtime.Goexit() + } else { + panic(panicValue) + } } return v, ok } @@ -115,9 +124,14 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { coroswitch(c) race.Acquire(unsafe.Pointer(&racer)) - // Propagate panics from seq. + // Propagate panics and goexits from seq. if panicValue != nil { - panic(panicValue) + if panicValue == goexitPanicValue { + // Propagate runtime.Goexit from seq. + runtime.Goexit() + } else { + panic(panicValue) + } } } } @@ -152,6 +166,7 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) { yieldNext bool racer int panicValue any + seqDone bool ) c := newcoro(func(c *coro) { race.Acquire(unsafe.Pointer(&racer)) @@ -172,16 +187,18 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) { // Recover and propagate panics from seq. defer func() { if p := recover(); p != nil { - done = true // Invalidate iterator. panicValue = p + } else if !seqDone { + panicValue = goexitPanicValue } + done = true // Invalidate iterator. race.Release(unsafe.Pointer(&racer)) }() seq(yield) var k0 K var v0 V k, v, ok = k0, v0, false - done = true + seqDone = true }) next = func() (k1 K, v1 V, ok1 bool) { race.Write(unsafe.Pointer(&racer)) // detect races @@ -197,9 +214,14 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) { coroswitch(c) race.Acquire(unsafe.Pointer(&racer)) - // Propagate panics from seq. + // Propagate panics and goexits from seq. if panicValue != nil { - panic(panicValue) + if panicValue == goexitPanicValue { + // Propagate runtime.Goexit from seq. + runtime.Goexit() + } else { + panic(panicValue) + } } return k, v, ok } @@ -212,11 +234,20 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) { coroswitch(c) race.Acquire(unsafe.Pointer(&racer)) - // Propagate panics from seq. + // Propagate panics and goexits from seq. if panicValue != nil { - panic(panicValue) + if panicValue == goexitPanicValue { + // Propagate runtime.Goexit from seq. + runtime.Goexit() + } else { + panic(panicValue) + } } } } return next, stop } + +// goexitPanicValue is a sentinel value indicating that an iterator +// exited via runtime.Goexit. +var goexitPanicValue any = new(int) diff --git a/src/iter/pull_test.go b/src/iter/pull_test.go index 09f2270fa1..0d3f5ab26b 100644 --- a/src/iter/pull_test.go +++ b/src/iter/pull_test.go @@ -320,3 +320,92 @@ func panicsWith(v any, f func()) (panicked bool) { f() return } + +func TestPullGoexit(t *testing.T) { + t.Run("next", func(t *testing.T) { + var next func() (int, bool) + var stop func() + if !goexits(t, func() { + next, stop = Pull(goexitSeq()) + next() + }) { + t.Fatal("failed to Goexit from next") + } + if x, ok := next(); x != 0 || ok { + t.Fatal("iterator returned valid value after Goexit") + } + stop() + }) + t.Run("stop", func(t *testing.T) { + var next func() (int, bool) + var stop func() + if !goexits(t, func() { + next, stop = Pull(goexitSeq()) + stop() + }) { + t.Fatal("failed to Goexit from stop") + } + if x, ok := next(); x != 0 || ok { + t.Fatal("iterator returned valid value after Goexit") + } + stop() + }) +} + +func goexitSeq() Seq[int] { + return func(yield func(int) bool) { + runtime.Goexit() + } +} + +func TestPull2Goexit(t *testing.T) { + t.Run("next", func(t *testing.T) { + var next func() (int, int, bool) + var stop func() + if !goexits(t, func() { + next, stop = Pull2(goexitSeq2()) + next() + }) { + t.Fatal("failed to Goexit from next") + } + if x, y, ok := next(); x != 0 || y != 0 || ok { + t.Fatal("iterator returned valid value after Goexit") + } + stop() + }) + t.Run("stop", func(t *testing.T) { + var next func() (int, int, bool) + var stop func() + if !goexits(t, func() { + next, stop = Pull2(goexitSeq2()) + stop() + }) { + t.Fatal("failed to Goexit from stop") + } + if x, y, ok := next(); x != 0 || y != 0 || ok { + t.Fatal("iterator returned valid value after Goexit") + } + stop() + }) +} + +func goexitSeq2() Seq2[int, int] { + return func(yield func(int, int) bool) { + runtime.Goexit() + } +} + +func goexits(t *testing.T, f func()) bool { + t.Helper() + + exit := make(chan bool) + go func() { + cleanExit := false + defer func() { + exit <- recover() == nil && !cleanExit + }() + f() + cleanExit = true + }() + return <-exit +} diff --git a/src/runtime/coro.go b/src/runtime/coro.go index 3d39d13493..30ada455e4 100644 --- a/src/runtime/coro.go +++ b/src/runtime/coro.go @@ -68,8 +68,8 @@ func corostart() { c := gp.coroarg gp.coroarg = nil + defer coroexit(c) c.f(c) - coroexit(c) } // coroexit is like coroswitch but closes the coro -- 2.48.1