type coro struct {
gp guintptr
f func(*coro)
+
+ // State for validating thread-lock interactions.
+ mp *m
+ lockedExt uint32 // mp's external LockOSThread counter at coro creation time.
+ lockedInt uint32 // mp's internal lockOSThread counter at coro creation time.
}
//go:linkname newcoro
pc := getcallerpc()
gp := getg()
systemstack(func() {
+ mp := gp.m
start := corostart
startfv := *(**funcval)(unsafe.Pointer(&start))
gp = newproc1(startfv, gp, pc, true, waitReasonCoroutine)
+
+ // Scribble down locked thread state if needed and/or donate
+ // thread-lock state to the new goroutine.
+ if mp.lockedExt+mp.lockedInt != 0 {
+ c.mp = mp
+ c.lockedExt = mp.lockedExt
+ c.lockedInt = mp.lockedInt
+ }
})
gp.coroarg = c
c.gp.set(gp)
// It is important not to add more atomic operations or other
// expensive operations to the fast path.
func coroswitch_m(gp *g) {
- // TODO(go.dev/issue/65889): Something really nasty will happen if either
- // goroutine in this handoff tries to lock itself to an OS thread.
- // There's an explicit multiplexing going on here that needs to be
- // disabled if either the consumer or the iterator ends up in such
- // a state.
c := gp.coroarg
gp.coroarg = nil
exit := gp.coroexit
gp.coroexit = false
mp := gp.m
+ // Track and validate thread-lock interactions.
+ //
+ // The rules with thread-lock interactions are simple. When a coro goroutine is switched to,
+ // the same thread must be used, and the locked state must match with the thread-lock state of
+ // the goroutine which called newcoro. Thread-lock state consists of the thread and the number
+ // of internal (cgo callback, etc.) and external (LockOSThread) thread locks.
+ locked := gp.lockedm != 0
+ if c.mp != nil || locked {
+ if mp != c.mp || mp.lockedInt != c.lockedInt || mp.lockedExt != c.lockedExt {
+ print("coro: got thread ", unsafe.Pointer(mp), ", want ", unsafe.Pointer(c.mp), "\n")
+ print("coro: got lock internal ", mp.lockedInt, ", want ", c.lockedInt, "\n")
+ print("coro: got lock external ", mp.lockedExt, ", want ", c.lockedExt, "\n")
+ throw("coro: OS thread locking must match locking at coroutine creation")
+ }
+ }
+
// Acquire tracer for writing for the duration of this call.
//
// There's a lot of state manipulation performed with shortcuts
// emitting an event for every single transition.
trace := traceAcquire()
+ if locked {
+ // Detach the goroutine from the thread; we'll attach to the goroutine we're
+ // switching to before returning.
+ gp.lockedm.set(nil)
+ }
+
if exit {
- // TODO(65889): If we're locked to the current OS thread and
- // we exit here while tracing is enabled, we're going to end up
- // in a really bad place (traceAcquire also calls acquirem; there's
- // no releasem before the thread exits).
+ // The M might have a non-zero OS thread lock count when we get here, gdestroy
+ // will avoid destroying the M if the G isn't explicitly locked to it via lockedm,
+ // which we cleared above. It's fine to gdestroy here also, even when locked to
+ // the thread, because we'll be switching back to another goroutine anyway, which
+ // will take back its thread-lock state before returning.
gdestroy(gp)
gp = nil
} else {
}
}
+ // Check if we're switching to ourselves. This case is able to break our
+ // thread-lock invariants and an unbuffered channel implementation of
+ // coroswitch would deadlock. It's clear that this case should just not
+ // work.
+ if gnext == gp {
+ throw("coroswitch of a goroutine to itself")
+ }
+
// Emit the trace event after getting gnext but before changing curg.
// GoSwitch expects that the current G is running and that we haven't
// switched yet for correct status emission.
casgstatus(gnext, _Grunnable, _Grunning)
}
+ // Donate locked state.
+ if locked {
+ mp.lockedg.set(gnext)
+ gnext.lockedm.set(mp)
+ }
+
// Release the trace locker. We've completed all the necessary transitions..
if trace.ok() {
traceRelease(trace)
--- /dev/null
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package runtime_test
+
+import (
+ "runtime"
+ "strings"
+ "testing"
+)
+
+func TestCoroLockOSThread(t *testing.T) {
+ for _, test := range []string{
+ "CoroLockOSThreadIterLock",
+ "CoroLockOSThreadIterLockYield",
+ "CoroLockOSThreadLock",
+ "CoroLockOSThreadLockIterNested",
+ "CoroLockOSThreadLockIterLock",
+ "CoroLockOSThreadLockIterLockYield",
+ "CoroLockOSThreadLockIterYieldNewG",
+ "CoroLockOSThreadLockAfterPull",
+ "CoroLockOSThreadStopLocked",
+ "CoroLockOSThreadStopLockedIterNested",
+ } {
+ t.Run(test, func(t *testing.T) {
+ checkCoroTestProgOutput(t, runTestProg(t, "testprog", test))
+ })
+ }
+}
+
+func TestCoroCgoCallback(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("coro cgo callback tests not supported on Windows")
+ }
+ for _, test := range []string{
+ "CoroCgoIterCallback",
+ "CoroCgoIterCallbackYield",
+ "CoroCgoCallback",
+ "CoroCgoCallbackIterNested",
+ "CoroCgoCallbackIterCallback",
+ "CoroCgoCallbackIterCallbackYield",
+ "CoroCgoCallbackAfterPull",
+ "CoroCgoStopCallback",
+ "CoroCgoStopCallbackIterNested",
+ } {
+ t.Run(test, func(t *testing.T) {
+ checkCoroTestProgOutput(t, runTestProg(t, "testprogcgo", test))
+ })
+ }
+}
+
+func checkCoroTestProgOutput(t *testing.T, output string) {
+ t.Helper()
+
+ c := strings.SplitN(output, "\n", 2)
+ if len(c) == 1 {
+ t.Fatalf("expected at least one complete line in the output, got:\n%s", output)
+ }
+ expect, ok := strings.CutPrefix(c[0], "expect: ")
+ if !ok {
+ t.Fatalf("expected first line of output to start with \"expect: \", got: %q", c[0])
+ }
+ rest := c[1]
+ if expect == "OK" && rest != "OK\n" {
+ t.Fatalf("expected just 'OK' in the output, got:\n%s", rest)
+ }
+ if !strings.Contains(rest, expect) {
+ t.Fatalf("expected %q in the output, got:\n%s", expect, rest)
+ }
+}
cmd := exec.Command(testenv.GoToolPath(t), append([]string{"build", "-o", exe}, flags...)...)
t.Logf("running %v", cmd)
cmd.Dir = "testdata/" + binary
- out, err := testenv.CleanCmdEnv(cmd).CombinedOutput()
+ cmd = testenv.CleanCmdEnv(cmd)
+
+ // Add the rangefunc GOEXPERIMENT unconditionally since some tests depend on it.
+ // TODO(61405): Remove this once it's enabled by default.
+ edited := false
+ for i := range cmd.Env {
+ e := cmd.Env[i]
+ if _, vars, ok := strings.Cut(e, "GOEXPERIMENT="); ok {
+ cmd.Env[i] = "GOEXPERIMENT=" + vars + ",rangefunc"
+ edited = true
+ }
+ }
+ if !edited {
+ cmd.Env = append(cmd.Env, "GOEXPERIMENT=rangefunc")
+ }
+
+ out, err := cmd.CombinedOutput()
if err != nil {
target.err = fmt.Errorf("building %s %v: %v\n%s", binary, flags, err, out)
} else {
return
}
- if mp.lockedInt != 0 {
- print("invalid m->lockedInt = ", mp.lockedInt, "\n")
- throw("internal lockOSThread error")
+ if locked && mp.lockedInt != 0 {
+ print("runtime: mp.lockedInt = ", mp.lockedInt, "\n")
+ throw("exited a goroutine internally locked to the OS thread")
}
gfput(pp, gp)
if locked {
--- /dev/null
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build goexperiment.rangefunc
+
+package main
+
+import (
+ "fmt"
+ "iter"
+ "runtime"
+)
+
+func init() {
+ register("CoroLockOSThreadIterLock", func() {
+ println("expect: OK")
+ CoroLockOSThread(callerExhaust, iterLock)
+ })
+ register("CoroLockOSThreadIterLockYield", func() {
+ println("expect: OS thread locking must match")
+ CoroLockOSThread(callerExhaust, iterLockYield)
+ })
+ register("CoroLockOSThreadLock", func() {
+ println("expect: OK")
+ CoroLockOSThread(callerExhaustLocked, iterSimple)
+ })
+ register("CoroLockOSThreadLockIterNested", func() {
+ println("expect: OK")
+ CoroLockOSThread(callerExhaustLocked, iterNested)
+ })
+ register("CoroLockOSThreadLockIterLock", func() {
+ println("expect: OK")
+ CoroLockOSThread(callerExhaustLocked, iterLock)
+ })
+ register("CoroLockOSThreadLockIterLockYield", func() {
+ println("expect: OS thread locking must match")
+ CoroLockOSThread(callerExhaustLocked, iterLockYield)
+ })
+ register("CoroLockOSThreadLockIterYieldNewG", func() {
+ println("expect: OS thread locking must match")
+ CoroLockOSThread(callerExhaustLocked, iterYieldNewG)
+ })
+ register("CoroLockOSThreadLockAfterPull", func() {
+ println("expect: OS thread locking must match")
+ CoroLockOSThread(callerLockAfterPull, iterSimple)
+ })
+ register("CoroLockOSThreadStopLocked", func() {
+ println("expect: OK")
+ CoroLockOSThread(callerStopLocked, iterSimple)
+ })
+ register("CoroLockOSThreadStopLockedIterNested", func() {
+ println("expect: OK")
+ CoroLockOSThread(callerStopLocked, iterNested)
+ })
+}
+
+func CoroLockOSThread(driver func(iter.Seq[int]) error, seq iter.Seq[int]) {
+ if err := driver(seq); err != nil {
+ println("error:", err.Error())
+ return
+ }
+ println("OK")
+}
+
+func callerExhaust(i iter.Seq[int]) error {
+ next, _ := iter.Pull(i)
+ for {
+ v, ok := next()
+ if !ok {
+ break
+ }
+ if v != 5 {
+ return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
+ }
+ }
+ return nil
+}
+
+func callerExhaustLocked(i iter.Seq[int]) error {
+ runtime.LockOSThread()
+ next, _ := iter.Pull(i)
+ for {
+ v, ok := next()
+ if !ok {
+ break
+ }
+ if v != 5 {
+ return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
+ }
+ }
+ runtime.UnlockOSThread()
+ return nil
+}
+
+func callerLockAfterPull(i iter.Seq[int]) error {
+ n := 0
+ next, _ := iter.Pull(i)
+ for {
+ runtime.LockOSThread()
+ n++
+ v, ok := next()
+ if !ok {
+ break
+ }
+ if v != 5 {
+ return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
+ }
+ }
+ for range n {
+ runtime.UnlockOSThread()
+ }
+ return nil
+}
+
+func callerStopLocked(i iter.Seq[int]) error {
+ runtime.LockOSThread()
+ next, stop := iter.Pull(i)
+ v, _ := next()
+ stop()
+ if v != 5 {
+ return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
+ }
+ runtime.UnlockOSThread()
+ return nil
+}
+
+func iterSimple(yield func(int) bool) {
+ for range 3 {
+ if !yield(5) {
+ return
+ }
+ }
+}
+
+func iterNested(yield func(int) bool) {
+ next, stop := iter.Pull(iterSimple)
+ for {
+ v, ok := next()
+ if ok {
+ if !yield(v) {
+ stop()
+ }
+ } else {
+ return
+ }
+ }
+}
+
+func iterLock(yield func(int) bool) {
+ for range 3 {
+ runtime.LockOSThread()
+ runtime.UnlockOSThread()
+
+ if !yield(5) {
+ return
+ }
+ }
+}
+
+func iterLockYield(yield func(int) bool) {
+ for range 3 {
+ runtime.LockOSThread()
+ ok := yield(5)
+ runtime.UnlockOSThread()
+ if !ok {
+ return
+ }
+ }
+}
+
+func iterYieldNewG(yield func(int) bool) {
+ for range 3 {
+ done := make(chan struct{})
+ var ok bool
+ go func() {
+ ok = yield(5)
+ done <- struct{}{}
+ }()
+ <-done
+ if !ok {
+ return
+ }
+ }
+}
--- /dev/null
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build goexperiment.rangefunc && !windows
+
+package main
+
+/*
+#include <stdint.h> // for uintptr_t
+
+void go_callback_coro(uintptr_t handle);
+
+static void call_go(uintptr_t handle) {
+ go_callback_coro(handle);
+}
+*/
+import "C"
+
+import (
+ "fmt"
+ "iter"
+ "runtime/cgo"
+)
+
+func init() {
+ register("CoroCgoIterCallback", func() {
+ println("expect: OK")
+ CoroCgo(callerExhaust, iterCallback)
+ })
+ register("CoroCgoIterCallbackYield", func() {
+ println("expect: OS thread locking must match")
+ CoroCgo(callerExhaust, iterCallbackYield)
+ })
+ register("CoroCgoCallback", func() {
+ println("expect: OK")
+ CoroCgo(callerExhaustCallback, iterSimple)
+ })
+ register("CoroCgoCallbackIterNested", func() {
+ println("expect: OK")
+ CoroCgo(callerExhaustCallback, iterNested)
+ })
+ register("CoroCgoCallbackIterCallback", func() {
+ println("expect: OK")
+ CoroCgo(callerExhaustCallback, iterCallback)
+ })
+ register("CoroCgoCallbackIterCallbackYield", func() {
+ println("expect: OS thread locking must match")
+ CoroCgo(callerExhaustCallback, iterCallbackYield)
+ })
+ register("CoroCgoCallbackAfterPull", func() {
+ println("expect: OS thread locking must match")
+ CoroCgo(callerCallbackAfterPull, iterSimple)
+ })
+ register("CoroCgoStopCallback", func() {
+ println("expect: OK")
+ CoroCgo(callerStopCallback, iterSimple)
+ })
+ register("CoroCgoStopCallbackIterNested", func() {
+ println("expect: OK")
+ CoroCgo(callerStopCallback, iterNested)
+ })
+}
+
+var toCall func()
+
+//export go_callback_coro
+func go_callback_coro(handle C.uintptr_t) {
+ h := cgo.Handle(handle)
+ h.Value().(func())()
+ h.Delete()
+}
+
+func callFromC(f func()) {
+ C.call_go(C.uintptr_t(cgo.NewHandle(f)))
+}
+
+func CoroCgo(driver func(iter.Seq[int]) error, seq iter.Seq[int]) {
+ if err := driver(seq); err != nil {
+ println("error:", err.Error())
+ return
+ }
+ println("OK")
+}
+
+func callerExhaust(i iter.Seq[int]) error {
+ next, _ := iter.Pull(i)
+ for {
+ v, ok := next()
+ if !ok {
+ break
+ }
+ if v != 5 {
+ return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
+ }
+ }
+ return nil
+}
+
+func callerExhaustCallback(i iter.Seq[int]) (err error) {
+ callFromC(func() {
+ next, _ := iter.Pull(i)
+ for {
+ v, ok := next()
+ if !ok {
+ break
+ }
+ if v != 5 {
+ err = fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
+ }
+ }
+ })
+ return err
+}
+
+func callerStopCallback(i iter.Seq[int]) (err error) {
+ callFromC(func() {
+ next, stop := iter.Pull(i)
+ v, _ := next()
+ stop()
+ if v != 5 {
+ err = fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
+ }
+ })
+ return err
+}
+
+func callerCallbackAfterPull(i iter.Seq[int]) (err error) {
+ next, _ := iter.Pull(i)
+ callFromC(func() {
+ for {
+ v, ok := next()
+ if !ok {
+ break
+ }
+ if v != 5 {
+ err = fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
+ }
+ }
+ })
+ return err
+}
+
+func iterSimple(yield func(int) bool) {
+ for range 3 {
+ if !yield(5) {
+ return
+ }
+ }
+}
+
+func iterNested(yield func(int) bool) {
+ next, stop := iter.Pull(iterSimple)
+ for {
+ v, ok := next()
+ if ok {
+ if !yield(v) {
+ stop()
+ }
+ } else {
+ return
+ }
+ }
+}
+
+func iterCallback(yield func(int) bool) {
+ for range 3 {
+ callFromC(func() {})
+ if !yield(5) {
+ return
+ }
+ }
+}
+
+func iterCallbackYield(yield func(int) bool) {
+ for range 3 {
+ var ok bool
+ callFromC(func() {
+ ok = yield(5)
+ })
+ if !ok {
+ return
+ }
+ }
+}