]> Cypherpunks repositories - gostls13.git/commitdiff
runtime: prefer to restart Ps on the same M after STW
authorMichael Pratt <mpratt@google.com>
Fri, 24 Oct 2025 19:14:59 +0000 (15:14 -0400)
committerGopher Robot <gobot@golang.org>
Thu, 13 Nov 2025 15:44:41 +0000 (07:44 -0800)
Today, Ps jump around arbitrarily across STW. Instead, try to keep the P
on the previous M it ran on. In the future, we'll likely want to try to
expand this beyond STW to create a more general affinity for specific
Ms.

For this to be useful, the Ps need to have runnable Gs. Today, STW
preemption goes through goschedImpl, which places the G on the global
run queue. If that was the only G then the P won't have runnable
goroutines anymore.

It makes more sense to keep the G with its P across STW anyway, so add a
special case to goschedImpl for that.

On my machine, this CL reduces the error rate in TestTraceSTW from 99.8%
to 1.9%.

As a nearly 2% error rate shows, there are still cases where this best
effort scheduling doesn't work. The most obvious is that while
procresize assigns Ps back to their original M, startTheWorldWithSema
calls wakep to start a spinning M. The spinning M may steal a goroutine
from another P if that P is too slow to start.

For #65694.

Change-Id: I6a6a636c0969c587d039b68bc68ea16c74ff1fc9
Reviewed-on: https://go-review.googlesource.com/c/go/+/714801
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Auto-Submit: Michael Pratt <mpratt@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/internal/trace/testtrace/helpers.go
src/runtime/proc.go
src/runtime/proc_test.go
src/runtime/runtime2.go
src/runtime/testdata/testprog/stw_mexit.go [new file with mode: 0644]
src/runtime/testdata/testprog/stw_trace.go [new file with mode: 0644]

index c4c404a304b25aa44a31ce818e5707f302cdf3ba..8a64d5c2ee19cfc33158883a230e47a2d151756e 100644 (file)
@@ -16,8 +16,7 @@ import (
        "testing"
 )
 
-// MustHaveSyscallEvents skips the current test if the current
-// platform does not support true system call events.
+// Dump saves the trace to a file or the test log.
 func Dump(t *testing.T, testName string, traceBytes []byte, forceToFile bool) {
        onBuilder := testenv.Builder() != ""
        onOldBuilder := !strings.Contains(testenv.Builder(), "gotip") && !strings.Contains(testenv.Builder(), "go1")
index 30d2a6862603f88897d1feb89f47c3e0958d8d0b..44b64913a56ae3550f303e85f995a39c74df4c8d 100644 (file)
@@ -1009,6 +1009,8 @@ func mcommoninit(mp *m, id int64) {
                mp.id = mReserveID()
        }
 
+       mp.self = newMWeakPointer(mp)
+
        mrandinit(mp)
 
        mpreinit(mp)
@@ -2018,6 +2020,10 @@ func mexit(osStack bool) {
        // Free vgetrandom state.
        vgetrandomDestroy(mp)
 
+       // Clear the self pointer so Ps don't access this M after it is freed,
+       // or keep it alive.
+       mp.self.clear()
+
        // Remove m from allm.
        lock(&sched.lock)
        for pprev := &allm; *pprev != nil; pprev = &(*pprev).alllink {
@@ -4259,6 +4265,7 @@ func park_m(gp *g) {
 }
 
 func goschedImpl(gp *g, preempted bool) {
+       pp := gp.m.p.ptr()
        trace := traceAcquire()
        status := readgstatus(gp)
        if status&^_Gscan != _Grunning {
@@ -4281,9 +4288,15 @@ func goschedImpl(gp *g, preempted bool) {
        }
 
        dropg()
-       lock(&sched.lock)
-       globrunqput(gp)
-       unlock(&sched.lock)
+       if preempted && sched.gcwaiting.Load() {
+               // If preempted for STW, keep the G on the local P in runnext
+               // so it can keep running immediately after the STW.
+               runqput(pp, gp, true)
+       } else {
+               lock(&sched.lock)
+               globrunqput(gp)
+               unlock(&sched.lock)
+       }
 
        if mainStarted {
                wakep()
@@ -6013,6 +6026,7 @@ func procresize(nprocs int32) *p {
        }
 
        var runnablePs *p
+       var runnablePsNeedM *p
        for i := nprocs - 1; i >= 0; i-- {
                pp := allp[i]
                if gp.m.p.ptr() == pp {
@@ -6021,12 +6035,41 @@ func procresize(nprocs int32) *p {
                pp.status = _Pidle
                if runqempty(pp) {
                        pidleput(pp, now)
-               } else {
-                       pp.m.set(mget())
-                       pp.link.set(runnablePs)
-                       runnablePs = pp
+                       continue
                }
+
+               // Prefer to run on the most recent M if it is
+               // available.
+               //
+               // Ps with no oldm (or for which oldm is already taken
+               // by an earlier P), we delay until all oldm Ps are
+               // handled. Otherwise, mget may return an M that a
+               // later P has in oldm.
+               var mp *m
+               if oldm := pp.oldm.get(); oldm != nil {
+                       // Returns nil if oldm is not idle.
+                       mp = mgetSpecific(oldm)
+               }
+               if mp == nil {
+                       // Call mget later.
+                       pp.link.set(runnablePsNeedM)
+                       runnablePsNeedM = pp
+                       continue
+               }
+               pp.m.set(mp)
+               pp.link.set(runnablePs)
+               runnablePs = pp
        }
+       for runnablePsNeedM != nil {
+               pp := runnablePsNeedM
+               runnablePsNeedM = pp.link.ptr()
+
+               mp := mget()
+               pp.m.set(mp)
+               pp.link.set(runnablePs)
+               runnablePs = pp
+       }
+
        stealOrder.reset(uint32(nprocs))
        var int32p *int32 = &gomaxprocs // make compiler check that gomaxprocs is an int32
        atomic.Store((*uint32)(unsafe.Pointer(int32p)), uint32(nprocs))
@@ -6064,6 +6107,11 @@ func acquirepNoTrace(pp *p) {
 
        // Have p; write barriers now allowed.
 
+       // The M we're associating with will be the old M after the next
+       // releasep. We must set this here because write barriers are not
+       // allowed in releasep.
+       pp.oldm = pp.m.ptr().self
+
        // Perform deferred mcache flush before this P can allocate
        // from a potentially stale mcache.
        pp.mcache.prepareForSweep()
@@ -6998,6 +7046,27 @@ func mget() *m {
        return mp
 }
 
+// Try to get a specific m from midle list. Returns nil if it isn't on the
+// midle list.
+//
+// sched.lock must be held.
+// May run during STW, so write barriers are not allowed.
+//
+//go:nowritebarrierrec
+func mgetSpecific(mp *m) *m {
+       assertLockHeld(&sched.lock)
+
+       if mp.idleNode.prev == 0 && mp.idleNode.next == 0 {
+               // Not on the list.
+               return nil
+       }
+
+       sched.midle.remove(unsafe.Pointer(mp))
+       sched.nmidle--
+
+       return mp
+}
+
 // Put gp on the global runnable queue.
 // sched.lock must be held.
 // May run during STW, so write barriers are not allowed.
index d10d4a1fc931ccefe7a47309a1baf972d8ec143e..b3084f4895fe4fb18c3d614675b3e2f9be9af652 100644 (file)
@@ -5,13 +5,18 @@
 package runtime_test
 
 import (
+       "bytes"
        "fmt"
        "internal/race"
        "internal/testenv"
+       "internal/trace"
+       "internal/trace/testtrace"
+       "io"
        "math"
        "net"
        "runtime"
        "runtime/debug"
+       "slices"
        "strings"
        "sync"
        "sync/atomic"
@@ -1168,3 +1173,364 @@ func TestBigGOMAXPROCS(t *testing.T) {
                t.Errorf("output:\n%s\nwanted:\nunknown function: NonexistentTest", output)
        }
 }
+
+type goroutineState struct {
+       G trace.GoID      // This goroutine.
+       P trace.ProcID    // Most recent P this goroutine ran on.
+       M trace.ThreadID  // Most recent M this goroutine ran on.
+}
+
+func newGoroutineState(g trace.GoID) *goroutineState {
+       return &goroutineState{
+               G: g,
+               P: trace.NoProc,
+               M: trace.NoThread,
+       }
+}
+
+// TestTraceSTW verifies that goroutines continue running on the same M and P
+// after a STW.
+func TestTraceSTW(t *testing.T) {
+       // Across STW, the runtime attempts to keep goroutines running on the
+       // same P and the P running on the same M. It does this by keeping
+       // goroutines in the P's local runq, and remembering which M the P ran
+       // on before STW and preferring that M when restarting.
+       //
+       // This test verifies that affinity by analyzing a trace of testprog
+       // TraceSTW.
+       //
+       // The affinity across STW is best-effort, so have to allow some
+       // failure rate, thus we test many times and ensure the error rate is
+       // low.
+       //
+       // The expected affinity can fail for a variety of reasons. The most
+       // obvious is that while procresize assigns Ps back to their original
+       // M, startTheWorldWithSema calls wakep to start a spinning M. The
+       // spinning M may steal a goroutine from another P if that P is too
+       // slow to start.
+
+       if testing.Short() {
+               t.Skip("skipping in -short mode")
+       }
+
+       if runtime.NumCPU() < 4 {
+               t.Skip("This test sets GOMAXPROCS=4 and wants to avoid thread descheduling as much as possible. Skip on machines with less than 4 CPUs")
+       }
+
+       const runs = 50
+
+       var errors int
+       for i := range runs {
+               err := runTestTracesSTW(t, i)
+               if err != nil {
+                       t.Logf("Run %d failed: %v", i, err)
+                       errors++
+               }
+       }
+
+       pct := float64(errors)/float64(runs)
+       t.Logf("Errors: %d/%d = %f%%", errors, runs, 100*pct)
+       if pct > 0.25 {
+               t.Errorf("Error rate too high")
+       }
+}
+
+func runTestTracesSTW(t *testing.T, run int) (err error) {
+       t.Logf("Run %d", run)
+
+       // By default, TSAN sleeps for 1s at exit to allow background
+       // goroutines to race. This slows down execution for this test far too
+       // much, since we are running 50 iterations, so disable the sleep.
+       //
+       // Outside of race mode, GORACE does nothing.
+       buf := []byte(runTestProg(t, "testprog", "TraceSTW", "GORACE=atexit_sleep_ms=0"))
+
+       // We locally "fail" the run (return an error) if the trace exhibits
+       // unwanted scheduling. i.e., the target goroutines did not remain on
+       // the same P/M.
+       //
+       // We fail the entire test (t.Fatal) for other cases that should never
+       // occur, such as a trace parse error.
+       defer func() {
+               if err != nil || t.Failed() {
+                       testtrace.Dump(t, fmt.Sprintf("TestTraceSTW-run%d", run), []byte(buf), false)
+               }
+       }()
+
+       br, err := trace.NewReader(bytes.NewReader(buf))
+       if err != nil {
+               t.Fatalf("NewReader got err %v want nil", err)
+       }
+
+       var targetGoroutines []*goroutineState
+       findGoroutine := func(goid trace.GoID) *goroutineState {
+               for _, gs := range targetGoroutines {
+                       if gs.G == goid {
+                               return gs
+                       }
+               }
+               return nil
+       }
+       findProc := func(pid trace.ProcID) *goroutineState {
+               for _, gs := range targetGoroutines {
+                       if gs.P == pid {
+                               return gs
+                       }
+               }
+               return nil
+       }
+
+       // 1. Find the goroutine IDs for the target goroutines. This will be in
+       // the StateTransition from NotExist.
+       //
+       // 2. Once found, track which M and P the target goroutines run on until...
+       //
+       // 3. Look for the "TraceSTW" "start" log message, where we commit the
+       // target goroutines' "before" M and P.
+       //
+       // N.B. We must do (1) and (2) together because the first target
+       // goroutine may start running before the second is created.
+findStart:
+       for {
+               ev, err := br.ReadEvent()
+               if err == io.EOF {
+                       // Reached the end of the trace without finding case (3).
+                       t.Fatalf("Trace missing start log message")
+               }
+               if err != nil {
+                       t.Fatalf("ReadEvent got err %v want nil", err)
+               }
+               t.Logf("Event: %s", ev.String())
+
+               switch ev.Kind() {
+               case trace.EventStateTransition:
+                       st := ev.StateTransition()
+                       if st.Resource.Kind != trace.ResourceGoroutine {
+                               continue
+                       }
+
+                       goid := st.Resource.Goroutine()
+                       from, to := st.Goroutine()
+
+                       // Potentially case (1): Goroutine creation.
+                       if from == trace.GoNotExist {
+                               for sf := range st.Stack.Frames() {
+                                       if sf.Func == "main.traceSTWTarget" {
+                                               targetGoroutines = append(targetGoroutines, newGoroutineState(goid))
+                                               t.Logf("Identified target goroutine id %d", goid)
+                                       }
+
+                                       // Always break, the goroutine entrypoint is always the
+                                       // first frame.
+                                       break
+                               }
+                       }
+
+                       // Potentially case (2): Goroutine running.
+                       if to == trace.GoRunning {
+                               gs := findGoroutine(goid)
+                               if gs == nil {
+                                       continue
+                               }
+                               gs.P = ev.Proc()
+                               gs.M = ev.Thread()
+                               t.Logf("G %d running on P %d M %d", gs.G, gs.P, gs.M)
+                       }
+               case trace.EventLog:
+                       // Potentially case (3): Start log event.
+                       log := ev.Log()
+                       if log.Category != "TraceSTW" {
+                               continue
+                       }
+                       if log.Message != "start" {
+                               t.Fatalf("Log message got %s want start", log.Message)
+                       }
+
+                       // Found start point, move on to next stage.
+                       t.Logf("Found start message")
+                       break findStart
+               }
+       }
+
+       t.Log("Target goroutines:")
+       for _, gs := range targetGoroutines {
+               t.Logf("%+v", gs)
+       }
+
+       if len(targetGoroutines) != 2 {
+               t.Fatalf("len(targetGoroutines) got %d want 2", len(targetGoroutines))
+       }
+
+       for _, gs := range targetGoroutines {
+               if gs.P == trace.NoProc {
+                       t.Fatalf("Goroutine %+v not running on a P", gs)
+               }
+               if gs.M == trace.NoThread {
+                       t.Fatalf("Goroutine %+v not running on an M", gs)
+               }
+       }
+
+       // The test continues until we see the "end" log message.
+       //
+       // What we want to observe is that the target goroutines run only on
+       // the original P and M.
+       //
+       // They will be stopped by STW [1], but should resume on the original P
+       // and M.
+       //
+       // However, this is best effort. For example, startTheWorld wakep's a
+       // spinning M. If the original M is slow to restart (e.g., due to poor
+       // kernel scheduling), the spinning M may legally steal the goroutine
+       // and run it instead.
+       //
+       // In practice, we see this occur frequently on builders, likely
+       // because they are overcommitted on CPU. Thus, we instead check
+       // slightly more constrained properties:
+       // - The original P must run on the original M (if it runs at all).
+       // - The original P must run the original G before anything else,
+       //   unless that G has already run elsewhere.
+       //
+       // This allows a spinning M to steal the G from a slow-to-start M, but
+       // does not allow the original P to just flat out run something
+       // completely different from expected.
+       //
+       // Note this is still somewhat racy: the spinning M may steal the
+       // target G, but before it marks the target G as running, the original
+       // P runs an alternative G. This test will fail that case, even though
+       // it is legitimate. We allow that failure because such a race should
+       // be very rare, particularly because the test process usually has no
+       // other runnable goroutines.
+       //
+       // [1] This is slightly fragile because there is a small window between
+       // the "start" log and actual STW during which the target goroutines
+       // could legitimately migrate.
+       var stwSeen bool
+       var pRunning []trace.ProcID
+       var gRunning []trace.GoID
+findEnd:
+       for {
+               ev, err := br.ReadEvent()
+               if err == io.EOF {
+                       break
+               }
+               if err != nil {
+                       t.Fatalf("ReadEvent got err %v want nil", err)
+               }
+               t.Logf("Event: %s", ev.String())
+
+               switch ev.Kind() {
+               case trace.EventStateTransition:
+                       st := ev.StateTransition()
+                       switch st.Resource.Kind {
+                       case trace.ResourceProc:
+                               p := st.Resource.Proc()
+                               _, to := st.Proc()
+
+                               // Proc running. Ensure it didn't migrate.
+                               if to == trace.ProcRunning {
+                                       gs := findProc(p)
+                                       if gs == nil {
+                                               continue
+                                       }
+
+                                       if slices.Contains(pRunning, p) {
+                                               // Only check the first
+                                               // transition to running.
+                                               // Afterwards it is free to
+                                               // migrate anywhere.
+                                               continue
+                                       }
+                                       pRunning = append(pRunning, p)
+
+                                       m := ev.Thread()
+                                       if m != gs.M {
+                                               t.Logf("Proc %d running on M %d want M %d", p, m, gs.M)
+                                               return fmt.Errorf("P did not remain on M")
+                                       }
+                               }
+                       case trace.ResourceGoroutine:
+                               goid := st.Resource.Goroutine()
+                               _, to := st.Goroutine()
+
+                               // Goroutine running. Ensure it didn't migrate.
+                               if to == trace.GoRunning {
+                                       p := ev.Proc()
+                                       m := ev.Thread()
+
+                                       gs := findGoroutine(goid)
+                                       if gs == nil {
+                                               // This isn't a target
+                                               // goroutine. Is it a target P?
+                                               // That shouldn't run anything
+                                               // other than the target G.
+                                               gs = findProc(p)
+                                               if gs == nil {
+                                                       continue
+                                               }
+
+                                               if slices.Contains(gRunning, gs.G) {
+                                                       // This P's target G ran elsewhere. This probably
+                                                       // means that this P was slow to start, so
+                                                       // another P stole it. That isn't ideal, but
+                                                       // we'll allow it.
+                                                       continue
+                                               }
+
+                                               t.Logf("Goroutine %d running on P %d M %d want this P to run G %d", goid, p, m, gs.G)
+                                               return fmt.Errorf("P ran incorrect goroutine")
+                                       }
+
+                                       if !slices.Contains(gRunning, goid) {
+                                               gRunning = append(gRunning, goid)
+                                       }
+
+                                       if p != gs.P || m != gs.M {
+                                               t.Logf("Goroutine %d running on P %d M %d want P %d M %d", goid, p, m, gs.P, gs.M)
+                                               // We don't want this to occur,
+                                               // but allow it for cases of
+                                               // bad kernel scheduling. See
+                                               // "The test continues" comment
+                                               // above.
+                                       }
+                               }
+                       }
+               case trace.EventLog:
+                       // Potentially end log event.
+                       log := ev.Log()
+                       if log.Category != "TraceSTW" {
+                               continue
+                       }
+                       if log.Message != "end" {
+                               t.Fatalf("Log message got %s want end", log.Message)
+                       }
+
+                       // Found end point.
+                       t.Logf("Found end message")
+                       break findEnd
+               case trace.EventRangeBegin:
+                       r := ev.Range()
+                       if r.Name == "stop-the-world (read mem stats)" {
+                               // Note when we see the STW begin. This is not
+                               // load bearing; it's purpose is simply to fail
+                               // the test if we manage to remove the STW from
+                               // ReadMemStat, so we remember to change this
+                               // test to add some new source of STW.
+                               stwSeen = true
+                       }
+               }
+       }
+
+       if !stwSeen {
+               t.Fatal("No STW in the test trace")
+       }
+
+       return nil
+}
+
+func TestMexitSTW(t *testing.T) {
+       got := runTestProg(t, "testprog", "mexitSTW")
+       want := "OK\n"
+       if got != want {
+               t.Fatalf("expected %q, but got:\n%s", want, got)
+       }
+}
index 85a9693ace02b8ffa6fbff5f7395e73bdd816fca..6c955460d4f9993a8f6159561efec10dc957160b 100644 (file)
@@ -716,6 +716,9 @@ type m struct {
        // Up to 10 locks held by this m, maintained by the lock ranking code.
        locksHeldLen int
        locksHeld    [10]heldLockInfo
+
+       // self points this M until mexit clears it to return nil.
+       self mWeakPointer
 }
 
 const mRedZoneSize = (16 << 3) * asanenabledBit // redZoneSize(2048)
@@ -730,6 +733,37 @@ type mPadded struct {
        _ [(1 - goarch.IsWasm) * (2048 - mallocHeaderSize - mRedZoneSize - unsafe.Sizeof(m{}))]byte
 }
 
+// mWeakPointer is a "weak" pointer to an M. A weak pointer for each M is
+// available as m.self. Users may copy mWeakPointer arbitrarily, and get will
+// return the M if it is still live, or nil after mexit.
+//
+// The zero value is treated as a nil pointer.
+//
+// Note that get may race with M exit. A successful get will keep the m object
+// alive, but the M itself may be exited and thus not actually usable.
+type mWeakPointer struct {
+       m *atomic.Pointer[m]
+}
+
+func newMWeakPointer(mp *m) mWeakPointer {
+       w := mWeakPointer{m: new(atomic.Pointer[m])}
+       w.m.Store(mp)
+       return w
+}
+
+func (w mWeakPointer) get() *m {
+       if w.m == nil {
+               return nil
+       }
+       return w.m.Load()
+}
+
+// clear sets the weak pointer to nil. It cannot be used on zero value
+// mWeakPointers.
+func (w mWeakPointer) clear() {
+       w.m.Store(nil)
+}
+
 type p struct {
        id          int32
        status      uint32 // one of pidle/prunning/...
@@ -742,6 +776,17 @@ type p struct {
        pcache      pageCache
        raceprocctx uintptr
 
+       // oldm is the previous m this p ran on.
+       //
+       // We are not assosciated with this m, so we have no control over its
+       // lifecycle. This value is an m.self object which points to the m
+       // until the m exits.
+       //
+       // Note that this m may be idle, running, or exiting. It should only be
+       // used with mgetSpecific, which will take ownership of the m only if
+       // it is idle.
+       oldm mWeakPointer
+
        deferpool    []*_defer // pool of available defer structs (see panic.go)
        deferpoolbuf [32]*_defer
 
diff --git a/src/runtime/testdata/testprog/stw_mexit.go b/src/runtime/testdata/testprog/stw_mexit.go
new file mode 100644 (file)
index 0000000..b022ef4
--- /dev/null
@@ -0,0 +1,69 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+       "runtime"
+)
+
+func init() {
+       register("mexitSTW", mexitSTW)
+}
+
+// Stress test for pp.oldm pointing to an exited M.
+//
+// If pp.oldm points to an exited M it should be ignored and another M used
+// instead. To stress:
+//
+// 1. Start and exit many threads (thus setting oldm on some P).
+// 2. Meanwhile, frequently stop the world.
+//
+// If procresize incorrect attempts to assign a P to an exited M, likely
+// failure modes are:
+//
+// 1. Crash in startTheWorldWithSema attempting to access the M, if it is nil.
+//
+// 2. Memory corruption elsewhere after startTheWorldWithSema writes to the M,
+// if it is not nil, but is freed and reused for another allocation.
+//
+// 3. Hang on a subsequent stop the world waiting for the P to stop, if the M
+// object is valid, but the M is exited, because startTheWorldWithSema didn't
+// actually wake anything to run the P. The P is _Pidle, but not in the pidle
+// list, thus startTheWorldWithSema will wake for it to actively stop.
+//
+// For this to go wrong, an exited M must fail to clear mp.self and must leave
+// the M on the sched.midle list.
+//
+// Similar to TraceSTW.
+func mexitSTW() {
+       // Ensure we have multiple Ps, but not too many, as we want the
+       // runnable goroutines likely to run on Ps with oldm set.
+       runtime.GOMAXPROCS(4)
+
+       // Background busy work so there is always something runnable.
+       for i := range 2 {
+               go traceSTWTarget(i)
+       }
+
+       // Wait for children to start running.
+       ping.Store(1)
+       for pong[0].Load() != 1 {}
+       for pong[1].Load() != 1 {}
+
+       for range 100 {
+               // Exit a thread. The last P to run this will have it in oldm.
+               go func() {
+                       runtime.LockOSThread()
+               }()
+
+               // STW
+               var ms runtime.MemStats
+               runtime.ReadMemStats(&ms)
+       }
+
+       stop.Store(true)
+
+       println("OK")
+}
diff --git a/src/runtime/testdata/testprog/stw_trace.go b/src/runtime/testdata/testprog/stw_trace.go
new file mode 100644 (file)
index 0000000..0fed55b
--- /dev/null
@@ -0,0 +1,99 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+       "context"
+       "log"
+       "os"
+       "runtime"
+       "runtime/debug"
+       "runtime/trace"
+       "sync/atomic"
+)
+
+func init() {
+       register("TraceSTW", TraceSTW)
+}
+
+// The parent writes to ping and waits for the children to write back
+// via pong to show that they are running.
+var ping atomic.Uint32
+var pong [2]atomic.Uint32
+
+// Tell runners to stop.
+var stop atomic.Bool
+
+func traceSTWTarget(i int) {
+       for !stop.Load() {
+               // Async preemption often takes 100ms+ to preempt this loop on
+               // windows-386. This makes the test flaky, as the traceReadCPU
+               // timer often fires by the time STW finishes, jumbling the
+               // goroutine scheduling. As a workaround, ensure we have a
+               // morestack call for prompt preemption.
+               ensureMorestack()
+
+               pong[i].Store(ping.Load())
+       }
+}
+
+func TraceSTW() {
+       ctx := context.Background()
+
+       // The idea here is to have 2 target goroutines that are constantly
+       // running. When the world restarts after STW, we expect these
+       // goroutines to continue execution on the same M and P.
+       //
+       // Set GOMAXPROCS=4 to make room for the 2 target goroutines, 1 parent,
+       // and 1 slack for potential misscheduling.
+       //
+       // Disable the GC because GC STW generally moves goroutines (see
+       // https://go.dev/issue/65694). Alternatively, we could just ignore the
+       // trace if the GC runs.
+       runtime.GOMAXPROCS(4)
+       debug.SetGCPercent(0)
+
+       if err := trace.Start(os.Stdout); err != nil {
+               log.Fatalf("failed to start tracing: %v", err)
+       }
+       defer trace.Stop()
+
+       for i := range 2 {
+               go traceSTWTarget(i)
+       }
+
+       // Wait for children to start running.
+       ping.Store(1)
+       for pong[0].Load() != 1 {}
+       for pong[1].Load() != 1 {}
+
+       trace.Log(ctx, "TraceSTW", "start")
+
+       // STW
+       var ms runtime.MemStats
+       runtime.ReadMemStats(&ms)
+
+       // Make sure to run long enough for the children to schedule again
+       // after STW.
+       ping.Store(2)
+       for pong[0].Load() != 2 {}
+       for pong[1].Load() != 2 {}
+
+       trace.Log(ctx, "TraceSTW", "end")
+
+       stop.Store(true)
+}
+
+// Manually insert a morestack call. Leaf functions can omit morestack, but
+// non-leaf functions should include them.
+
+//go:noinline
+func ensureMorestack() {
+       ensureMorestack1()
+}
+
+//go:noinline
+func ensureMorestack1() {
+}