]> Cypherpunks repositories - gostls13.git/commitdiff
runtime: make goroutines inherit DIT state, don't lock to OS thread
authorRoland Shoemaker <roland@golang.org>
Wed, 3 Dec 2025 22:19:25 +0000 (14:19 -0800)
committerRoland Shoemaker <roland@golang.org>
Thu, 11 Dec 2025 16:21:53 +0000 (08:21 -0800)
When we first implemented DIT (crypto/subtle.WithDataIndependentTiming),
we made it so that enabling DIT on a goroutine would lock that goroutine
to its current OS thread. This was done to ensure that the DIT state
(which is per-thread) would not leak to other goroutines. We also did
not make goroutines inherit the DIT state.

This change makes goroutines inherit the DIT state from their parent
at creation time. It also removes the OS thread locking when enabling
DIT on a goroutine. Instead, we now set the DIT state on the OS thread
in the scheduler whenever we switch to a goroutine that has DIT enabled,
and we unset it when switching to a goroutine that has DIT disabled.

We add a new field to G and M, ditEnabled, to track whether the G wants
DIT enabled, and whether the M currently has DIT enabled, respectively.
When the scheduler executes a goroutine, it checks these fields and
enables/disables DIT on the thread as needed.

Additionally, cgocallbackg is updated to check if DIT is enabled when
being called from C, and sets the G and M fields accordingly. This
ensures that if DIT was enabled/disabled in C, the correct state will be
reflected in the Go runtime.

The behavior as it currently stands is as follows:
- The function passed to crypto/subtle.WithDataIndependentTiming
  will have DIT enabled.
- Any goroutine created within that function will inherit DIT enabled
  for its lifetime. Any goroutine created from subquent goroutines will
  also inherit DIT enabled for their lifetimes.
- Calling into a C function within from a goroutine with DIT enabled
  will have DIT enabled.
- If the C code disables DIT, the goroutine will have DIT re-enabled
  when returning to Go.
- If the C code enables DIT, the goroutine will have DIT disabled
  when returning to Go if it was not previously enabled.
- Calling back into Go code from C will have DIT enabled if it was
  enabled when calling into C, or if the C code enabled it.

Change-Id: I8e91e6df13bb88e56e1036e0e0e5f04efd8eebd3
Reviewed-on: https://go-review.googlesource.com/c/go/+/726382
Reviewed-by: Michael Pratt <mpratt@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: David Chase <drchase@google.com>
src/cmd/cgo/internal/test/cgo_test.go
src/cmd/cgo/internal/test/test.go
src/cmd/cgo/internal/test/testx.go
src/cmd/link/internal/loader/loader.go
src/crypto/subtle/dit.go
src/crypto/subtle/dit_test.go
src/runtime/cgocall.go
src/runtime/dit.go [new file with mode: 0644]
src/runtime/proc.go
src/runtime/runtime2.go
src/runtime/sizeof_test.go

index 04e06cf95ec55086ec3a606b4f9fe4e4b716f83f..17f5f3531c02c5bbeb55cfdfe2c895ed4c50a983 100644 (file)
@@ -12,101 +12,105 @@ import "testing"
 // so that they can use cgo (import "C").
 // These wrappers are here for gotest to find.
 
-func Test1328(t *testing.T)                  { test1328(t) }
-func Test1560(t *testing.T)                  { test1560(t) }
-func Test1635(t *testing.T)                  { test1635(t) }
-func Test3250(t *testing.T)                  { test3250(t) }
-func Test3729(t *testing.T)                  { test3729(t) }
-func Test3775(t *testing.T)                  { test3775(t) }
-func Test4029(t *testing.T)                  { test4029(t) }
-func Test4339(t *testing.T)                  { test4339(t) }
-func Test5227(t *testing.T)                  { test5227(t) }
-func Test5242(t *testing.T)                  { test5242(t) }
-func Test5337(t *testing.T)                  { test5337(t) }
-func Test5548(t *testing.T)                  { test5548(t) }
-func Test5603(t *testing.T)                  { test5603(t) }
-func Test5986(t *testing.T)                  { test5986(t) }
-func Test6390(t *testing.T)                  { test6390(t) }
-func Test6833(t *testing.T)                  { test6833(t) }
-func Test6907(t *testing.T)                  { test6907(t) }
-func Test6907Go(t *testing.T)                { test6907Go(t) }
-func Test7560(t *testing.T)                  { test7560(t) }
-func Test7665(t *testing.T)                  { test7665(t) }
-func Test7978(t *testing.T)                  { test7978(t) }
-func Test8092(t *testing.T)                  { test8092(t) }
-func Test8517(t *testing.T)                  { test8517(t) }
-func Test8694(t *testing.T)                  { test8694(t) }
-func Test8756(t *testing.T)                  { test8756(t) }
-func Test8811(t *testing.T)                  { test8811(t) }
-func Test9026(t *testing.T)                  { test9026(t) }
-func Test9510(t *testing.T)                  { test9510(t) }
-func Test9557(t *testing.T)                  { test9557(t) }
-func Test10303(t *testing.T)                 { test10303(t, 10) }
-func Test11925(t *testing.T)                 { test11925(t) }
-func Test12030(t *testing.T)                 { test12030(t) }
-func Test14838(t *testing.T)                 { test14838(t) }
-func Test17065(t *testing.T)                 { test17065(t) }
-func Test17537(t *testing.T)                 { test17537(t) }
-func Test18126(t *testing.T)                 { test18126(t) }
-func Test18720(t *testing.T)                 { test18720(t) }
-func Test20129(t *testing.T)                 { test20129(t) }
-func Test20266(t *testing.T)                 { test20266(t) }
-func Test20369(t *testing.T)                 { test20369(t) }
-func Test20910(t *testing.T)                 { test20910(t) }
-func Test21708(t *testing.T)                 { test21708(t) }
-func Test21809(t *testing.T)                 { test21809(t) }
-func Test21897(t *testing.T)                 { test21897(t) }
-func Test22906(t *testing.T)                 { test22906(t) }
-func Test23356(t *testing.T)                 { test23356(t) }
-func Test24206(t *testing.T)                 { test24206(t) }
-func Test25143(t *testing.T)                 { test25143(t) }
-func Test26066(t *testing.T)                 { test26066(t) }
-func Test26213(t *testing.T)                 { test26213(t) }
-func Test27660(t *testing.T)                 { test27660(t) }
-func Test28896(t *testing.T)                 { test28896(t) }
-func Test30065(t *testing.T)                 { test30065(t) }
-func Test32579(t *testing.T)                 { test32579(t) }
-func Test31891(t *testing.T)                 { test31891(t) }
-func Test42018(t *testing.T)                 { test42018(t) }
-func Test45451(t *testing.T)                 { test45451(t) }
-func Test49633(t *testing.T)                 { test49633(t) }
-func Test69086(t *testing.T)                 { test69086(t) }
-func TestAlign(t *testing.T)                 { testAlign(t) }
-func TestAtol(t *testing.T)                  { testAtol(t) }
-func TestBlocking(t *testing.T)              { testBlocking(t) }
-func TestBoolAlign(t *testing.T)             { testBoolAlign(t) }
-func TestCallGoWithString(t *testing.T)      { testCallGoWithString(t) }
-func TestCallback(t *testing.T)              { testCallback(t) }
-func TestCallbackCallers(t *testing.T)       { testCallbackCallers(t) }
-func TestCallbackGC(t *testing.T)            { testCallbackGC(t) }
-func TestCallbackPanic(t *testing.T)         { testCallbackPanic(t) }
-func TestCallbackPanicLocked(t *testing.T)   { testCallbackPanicLocked(t) }
-func TestCallbackPanicLoop(t *testing.T)     { testCallbackPanicLoop(t) }
-func TestCallbackStack(t *testing.T)         { testCallbackStack(t) }
-func TestCflags(t *testing.T)                { testCflags(t) }
-func TestCheckConst(t *testing.T)            { testCheckConst(t) }
-func TestConst(t *testing.T)                 { testConst(t) }
-func TestCthread(t *testing.T)               { testCthread(t) }
-func TestEnum(t *testing.T)                  { testEnum(t) }
-func TestNamedEnum(t *testing.T)             { testNamedEnum(t) }
-func TestCastToEnum(t *testing.T)            { testCastToEnum(t) }
-func TestErrno(t *testing.T)                 { testErrno(t) }
-func TestFpVar(t *testing.T)                 { testFpVar(t) }
-func TestGCC68255(t *testing.T)              { testGCC68255(t) }
-func TestHandle(t *testing.T)                { testHandle(t) }
-func TestHelpers(t *testing.T)               { testHelpers(t) }
-func TestLibgcc(t *testing.T)                { testLibgcc(t) }
-func TestMultipleAssign(t *testing.T)        { testMultipleAssign(t) }
-func TestNaming(t *testing.T)                { testNaming(t) }
-func TestPanicFromC(t *testing.T)            { testPanicFromC(t) }
-func TestPrintf(t *testing.T)                { testPrintf(t) }
-func TestReturnAfterGrow(t *testing.T)       { testReturnAfterGrow(t) }
-func TestReturnAfterGrowFromGo(t *testing.T) { testReturnAfterGrowFromGo(t) }
-func TestSetEnv(t *testing.T)                { testSetEnv(t) }
-func TestThreadLock(t *testing.T)            { testThreadLockFunc(t) }
-func TestUnsignedInt(t *testing.T)           { testUnsignedInt(t) }
-func TestZeroArgCallback(t *testing.T)       { testZeroArgCallback(t) }
-func Test76340(t *testing.T)                 { test76340(t) }
+func Test1328(t *testing.T)                     { test1328(t) }
+func Test1560(t *testing.T)                     { test1560(t) }
+func Test1635(t *testing.T)                     { test1635(t) }
+func Test3250(t *testing.T)                     { test3250(t) }
+func Test3729(t *testing.T)                     { test3729(t) }
+func Test3775(t *testing.T)                     { test3775(t) }
+func Test4029(t *testing.T)                     { test4029(t) }
+func Test4339(t *testing.T)                     { test4339(t) }
+func Test5227(t *testing.T)                     { test5227(t) }
+func Test5242(t *testing.T)                     { test5242(t) }
+func Test5337(t *testing.T)                     { test5337(t) }
+func Test5548(t *testing.T)                     { test5548(t) }
+func Test5603(t *testing.T)                     { test5603(t) }
+func Test5986(t *testing.T)                     { test5986(t) }
+func Test6390(t *testing.T)                     { test6390(t) }
+func Test6833(t *testing.T)                     { test6833(t) }
+func Test6907(t *testing.T)                     { test6907(t) }
+func Test6907Go(t *testing.T)                   { test6907Go(t) }
+func Test7560(t *testing.T)                     { test7560(t) }
+func Test7665(t *testing.T)                     { test7665(t) }
+func Test7978(t *testing.T)                     { test7978(t) }
+func Test8092(t *testing.T)                     { test8092(t) }
+func Test8517(t *testing.T)                     { test8517(t) }
+func Test8694(t *testing.T)                     { test8694(t) }
+func Test8756(t *testing.T)                     { test8756(t) }
+func Test8811(t *testing.T)                     { test8811(t) }
+func Test9026(t *testing.T)                     { test9026(t) }
+func Test9510(t *testing.T)                     { test9510(t) }
+func Test9557(t *testing.T)                     { test9557(t) }
+func Test10303(t *testing.T)                    { test10303(t, 10) }
+func Test11925(t *testing.T)                    { test11925(t) }
+func Test12030(t *testing.T)                    { test12030(t) }
+func Test14838(t *testing.T)                    { test14838(t) }
+func Test17065(t *testing.T)                    { test17065(t) }
+func Test17537(t *testing.T)                    { test17537(t) }
+func Test18126(t *testing.T)                    { test18126(t) }
+func Test18720(t *testing.T)                    { test18720(t) }
+func Test20129(t *testing.T)                    { test20129(t) }
+func Test20266(t *testing.T)                    { test20266(t) }
+func Test20369(t *testing.T)                    { test20369(t) }
+func Test20910(t *testing.T)                    { test20910(t) }
+func Test21708(t *testing.T)                    { test21708(t) }
+func Test21809(t *testing.T)                    { test21809(t) }
+func Test21897(t *testing.T)                    { test21897(t) }
+func Test22906(t *testing.T)                    { test22906(t) }
+func Test23356(t *testing.T)                    { test23356(t) }
+func Test24206(t *testing.T)                    { test24206(t) }
+func Test25143(t *testing.T)                    { test25143(t) }
+func Test26066(t *testing.T)                    { test26066(t) }
+func Test26213(t *testing.T)                    { test26213(t) }
+func Test27660(t *testing.T)                    { test27660(t) }
+func Test28896(t *testing.T)                    { test28896(t) }
+func Test30065(t *testing.T)                    { test30065(t) }
+func Test32579(t *testing.T)                    { test32579(t) }
+func Test31891(t *testing.T)                    { test31891(t) }
+func Test42018(t *testing.T)                    { test42018(t) }
+func Test45451(t *testing.T)                    { test45451(t) }
+func Test49633(t *testing.T)                    { test49633(t) }
+func Test69086(t *testing.T)                    { test69086(t) }
+func TestAlign(t *testing.T)                    { testAlign(t) }
+func TestAtol(t *testing.T)                     { testAtol(t) }
+func TestBlocking(t *testing.T)                 { testBlocking(t) }
+func TestBoolAlign(t *testing.T)                { testBoolAlign(t) }
+func TestCallGoWithString(t *testing.T)         { testCallGoWithString(t) }
+func TestCallback(t *testing.T)                 { testCallback(t) }
+func TestCallbackCallers(t *testing.T)          { testCallbackCallers(t) }
+func TestCallbackGC(t *testing.T)               { testCallbackGC(t) }
+func TestCallbackPanic(t *testing.T)            { testCallbackPanic(t) }
+func TestCallbackPanicLocked(t *testing.T)      { testCallbackPanicLocked(t) }
+func TestCallbackPanicLoop(t *testing.T)        { testCallbackPanicLoop(t) }
+func TestCallbackStack(t *testing.T)            { testCallbackStack(t) }
+func TestCflags(t *testing.T)                   { testCflags(t) }
+func TestCheckConst(t *testing.T)               { testCheckConst(t) }
+func TestConst(t *testing.T)                    { testConst(t) }
+func TestCthread(t *testing.T)                  { testCthread(t) }
+func TestEnum(t *testing.T)                     { testEnum(t) }
+func TestNamedEnum(t *testing.T)                { testNamedEnum(t) }
+func TestCastToEnum(t *testing.T)               { testCastToEnum(t) }
+func TestErrno(t *testing.T)                    { testErrno(t) }
+func TestFpVar(t *testing.T)                    { testFpVar(t) }
+func TestGCC68255(t *testing.T)                 { testGCC68255(t) }
+func TestHandle(t *testing.T)                   { testHandle(t) }
+func TestHelpers(t *testing.T)                  { testHelpers(t) }
+func TestLibgcc(t *testing.T)                   { testLibgcc(t) }
+func TestMultipleAssign(t *testing.T)           { testMultipleAssign(t) }
+func TestNaming(t *testing.T)                   { testNaming(t) }
+func TestPanicFromC(t *testing.T)               { testPanicFromC(t) }
+func TestPrintf(t *testing.T)                   { testPrintf(t) }
+func TestReturnAfterGrow(t *testing.T)          { testReturnAfterGrow(t) }
+func TestReturnAfterGrowFromGo(t *testing.T)    { testReturnAfterGrowFromGo(t) }
+func TestSetEnv(t *testing.T)                   { testSetEnv(t) }
+func TestThreadLock(t *testing.T)               { testThreadLockFunc(t) }
+func TestUnsignedInt(t *testing.T)              { testUnsignedInt(t) }
+func TestZeroArgCallback(t *testing.T)          { testZeroArgCallback(t) }
+func Test76340(t *testing.T)                    { test76340(t) }
+func TestDITCgo(t *testing.T)                   { testDITCgo(t) }
+func TestDITCgoCallback(t *testing.T)           { testDITCgoCallback(t) }
+func TestDITCgoCallbackEnableDIT(t *testing.T)  { testDITCgoCallbackEnableDIT(t) }
+func TestDITCgoCallbackDisableDIT(t *testing.T) { testDITCgoCallbackDisableDIT(t) }
 
 func BenchmarkCgoCall(b *testing.B)      { benchCgoCall(b) }
 func BenchmarkGoString(b *testing.B)     { benchGoString(b) }
index 4dd14facb50e6b88b1ab124b81209d03d2fb1b62..3f0732b428b2fcb8612b75e67efbe0a725c9cdfd 100644 (file)
@@ -971,13 +971,43 @@ int issue76340testFromC(GoInterface obj) {
 GoInterface issue76340returnFromC(int val) {
        return exportAny76340Return(val);
 }
+
+static void enableDIT() {
+       #ifdef __arm64__
+       __asm__ __volatile__("msr dit, #1");
+       #endif
+}
+
+static void disableDIT() {
+       #ifdef __arm64__
+       __asm__ __volatile__("msr dit, #0");
+       #endif
+}
+
+extern uint8_t ditCallback();
+
+static uint8_t ditCallbackTest() {
+       return ditCallback();
+}
+
+static void ditCallbackEnableDIT() {
+       enableDIT();
+       ditCallback();
+}
+
+static void ditCallbackDisableDIT() {
+       disableDIT();
+       ditCallback();
+}
 */
 import "C"
 
 import (
        "context"
+       "crypto/subtle"
        "fmt"
        "internal/asan"
+       "internal/runtime/sys"
        "math"
        "math/rand"
        "os"
@@ -2438,3 +2468,76 @@ func test76340(t *testing.T) {
                t.Errorf("issue76340returnFromC(0) returned non-nil interface: got %v, want nil", r3)
        }
 }
+
+func testDITCgo(t *testing.T) {
+       if !sys.DITSupported {
+               t.Skip("CPU does not support DIT")
+       }
+
+       ditAlreadyEnabled := sys.DITEnabled()
+       C.enableDIT()
+
+       if ditAlreadyEnabled != sys.DITEnabled() {
+               t.Fatalf("DIT state not preserved across cgo call: before %t, after %t", ditAlreadyEnabled, sys.DITEnabled())
+       }
+
+       subtle.WithDataIndependentTiming(func() {
+               C.disableDIT()
+
+               if !sys.DITEnabled() {
+                       t.Fatal("DIT disabled after disabling in cgo call")
+               }
+       })
+}
+
+func testDITCgoCallback(t *testing.T) {
+       if !sys.DITSupported {
+               t.Skip("CPU does not support DIT")
+       }
+
+       ditAlreadyEnabled := sys.DITEnabled()
+
+       subtle.WithDataIndependentTiming(func() {
+               if C.ditCallbackTest() != 1 {
+                       t.Fatal("DIT not enabled in cgo callback within WithDataIndependentTiming")
+               }
+       })
+
+       if ditAlreadyEnabled != sys.DITEnabled() {
+               t.Fatalf("DIT state not preserved across cgo callback: before %t, after %t", ditAlreadyEnabled, sys.DITEnabled())
+       }
+}
+
+func testDITCgoCallbackEnableDIT(t *testing.T) {
+       if !sys.DITSupported {
+               t.Skip("CPU does not support DIT")
+       }
+
+       ditAlreadyEnabled := sys.DITEnabled()
+
+       C.ditCallbackEnableDIT()
+
+       if ditAlreadyEnabled != sys.DITEnabled() {
+               t.Fatalf("DIT state not preserved across cgo callback: before %t, after %t", ditAlreadyEnabled, sys.DITEnabled())
+       }
+}
+
+func testDITCgoCallbackDisableDIT(t *testing.T) {
+       if !sys.DITSupported {
+               t.Skip("CPU does not support DIT")
+       }
+
+       ditAlreadyEnabled := sys.DITEnabled()
+
+       subtle.WithDataIndependentTiming(func() {
+               C.ditCallbackDisableDIT()
+
+               if !sys.DITEnabled() {
+                       t.Fatal("DIT disabled after disabling in cgo call")
+               }
+       })
+
+       if ditAlreadyEnabled != sys.DITEnabled() {
+               t.Fatalf("DIT state not preserved across cgo callback: before %t, after %t", ditAlreadyEnabled, sys.DITEnabled())
+       }
+}
index 21ba52260ef893e2e5f5a620de6de63e61b3eb0a..5a6f42e44e6021b1da9e010762bd28fc76ba3202 100644 (file)
@@ -11,6 +11,7 @@
 package cgotest
 
 import (
+       "internal/runtime/sys"
        "runtime"
        "runtime/cgo"
        "runtime/debug"
@@ -613,3 +614,11 @@ func exportAny76340Return(val C.int) any {
 
        return int(val)
 }
+
+//export ditCallback
+func ditCallback() uint8 {
+       if sys.DITEnabled() {
+               return 1
+       }
+       return 0
+}
index 9ab55643f6313d457f2a11332413496f2cc9fc81..57029a1300ce90d38f82b1c4858d5c453ce82205 100644 (file)
@@ -2449,6 +2449,8 @@ var blockedLinknames = map[string][]string{
        "runtime.mapdelete_faststr":  {"runtime"},
        // New internal linknames in Go 1.25
        // Pushed from runtime
+       "crypto/subtle.setDITEnabled":                    {"crypto/subtle"},
+       "crypto/subtle.setDITDisabled":                   {"crypto/subtle"},
        "internal/cpu.riscvHWProbe":                      {"internal/cpu"},
        "internal/runtime/cgroup.throw":                  {"internal/runtime/cgroup"},
        "internal/runtime/maps.typeString":               {"internal/runtime/maps"},
index c23df971f0bd71fca0c82030e6b654f114111380..733261c3b09d2f99c7d85af19ec812d75bd27f5c 100644 (file)
@@ -6,19 +6,29 @@ package subtle
 
 import (
        "internal/runtime/sys"
-       "runtime"
+       _ "unsafe"
 )
 
 // WithDataIndependentTiming enables architecture specific features which ensure
 // that the timing of specific instructions is independent of their inputs
 // before executing f. On f returning it disables these features.
 //
+// Any goroutine spawned by f will also have data independent timing enabled for
+// its lifetime, as well as any of their descendant goroutines.
+//
+// Any C code called via cgo from within f, or from a goroutine spawned by f, will
+// also have data independent timing enabled for the duration of the call. If the
+// C code disables data independent timing, it will be re-enabled on return to Go.
+//
+// If C code called via cgo, from f or elsewhere, enables or disables data
+// independent timing then calling into Go will preserve that state for the
+// duration of the call.
+//
 // WithDataIndependentTiming should only be used when f is written to make use
 // of constant-time operations. WithDataIndependentTiming does not make
 // variable-time code constant-time.
 //
-// WithDataIndependentTiming may lock the current goroutine to the OS thread for
-// the duration of f. Calls to WithDataIndependentTiming may be nested.
+// Calls to WithDataIndependentTiming may be nested.
 //
 // On Arm64 processors with FEAT_DIT, WithDataIndependentTiming enables
 // PSTATE.DIT. See https://developer.arm.com/documentation/ka005181/1-0/?lang=en.
@@ -33,18 +43,21 @@ func WithDataIndependentTiming(f func()) {
                return
        }
 
-       runtime.LockOSThread()
-       defer runtime.UnlockOSThread()
-
-       alreadyEnabled := sys.EnableDIT()
+       alreadyEnabled := setDITEnabled()
 
        // disableDIT is called in a deferred function so that if f panics we will
        // still disable DIT, in case the panic is recovered further up the stack.
        defer func() {
                if !alreadyEnabled {
-                       sys.DisableDIT()
+                       setDITDisabled()
                }
        }()
 
        f()
 }
+
+//go:linkname setDITEnabled
+func setDITEnabled() bool
+
+//go:linkname setDITDisabled
+func setDITDisabled()
index 29779683b57c06b2c5cd86721b84a2ac74afbf03..952a18db1f66cd3b0a4bd53cf49e070dbf1fde57 100644 (file)
@@ -63,3 +63,25 @@ func TestDITPanic(t *testing.T) {
                panic("bad")
        })
 }
+
+func TestDITGoroutineInheritance(t *testing.T) {
+       if !cpu.ARM64.HasDIT {
+               t.Skip("CPU does not support DIT")
+       }
+
+       ditAlreadyEnabled := sys.DITEnabled()
+
+       WithDataIndependentTiming(func() {
+               done := make(chan struct{})
+               go func() {
+                       if !sys.DITEnabled() {
+                               t.Error("DIT not enabled in new goroutine")
+                       }
+                       close(done)
+               }()
+               <-done
+               if !ditAlreadyEnabled && !sys.DITEnabled() {
+                       t.Fatal("dit unset after returning from goroutine started in WithDataIndependentTiming closure")
+               }
+       })
+}
index 55e7bdbdb55f7956e8ae5bd4cc24d8c3b9edda02..b36da2f12b7264baf48f5efca70d79236e8eda75 100644 (file)
@@ -205,6 +205,18 @@ func cgocall(fn, arg unsafe.Pointer) int32 {
                raceacquire(unsafe.Pointer(&racecgosync))
        }
 
+       if sys.DITSupported {
+               // C code may have enabled or disabled DIT on this thread, restore
+               // our state to the expected one.
+               ditEnabled := sys.DITEnabled()
+               gp := getg()
+               if !gp.ditWanted && ditEnabled {
+                       sys.DisableDIT()
+               } else if gp.ditWanted && !ditEnabled {
+                       sys.EnableDIT()
+               }
+       }
+
        // From the garbage collector's perspective, time can move
        // backwards in the sequence above. If there's a callback into
        // Go code, GC will see this function at the call to
@@ -427,11 +439,19 @@ func cgocallbackg1(fn, frame unsafe.Pointer, ctxt uintptr) {
        restore := true
        defer unwindm(&restore)
 
-       var ditAlreadySet bool
+       var ditStateM, ditStateG bool
        if debug.dataindependenttiming == 1 && gp.m.isextra {
                // We only need to enable DIT for threads that were created by C, as it
                // should already by enabled on threads that were created by Go.
-               ditAlreadySet = sys.EnableDIT()
+               ditStateM = sys.EnableDIT()
+       } else if sys.DITSupported && debug.dataindependenttiming != 1 {
+               // C code may have enabled or disabled DIT on this thread. Set the flag
+               // on the M and G accordingly, saving their previous state to restore
+               // on return from the callback.
+               ditStateM, ditStateG = gp.m.ditEnabled, gp.ditWanted
+               ditEnabled := sys.DITEnabled()
+               gp.ditWanted = ditEnabled
+               gp.m.ditEnabled = ditEnabled
        }
 
        if raceenabled {
@@ -449,9 +469,16 @@ func cgocallbackg1(fn, frame unsafe.Pointer, ctxt uintptr) {
                racereleasemerge(unsafe.Pointer(&racecgosync))
        }
 
-       if debug.dataindependenttiming == 1 && !ditAlreadySet {
+       if debug.dataindependenttiming == 1 && !ditStateM {
                // Only unset DIT if it wasn't already enabled when cgocallback was called.
                sys.DisableDIT()
+       } else if sys.DITSupported && debug.dataindependenttiming != 1 {
+               // Restore DIT state on M and G.
+               gp.ditWanted = ditStateG
+               gp.m.ditEnabled = ditStateM
+               if !ditStateM {
+                       sys.DisableDIT()
+               }
        }
 
        // Do not unwind m->g0->sched.sp.
diff --git a/src/runtime/dit.go b/src/runtime/dit.go
new file mode 100644 (file)
index 0000000..c234b02
--- /dev/null
@@ -0,0 +1,26 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package runtime
+
+import (
+       "internal/runtime/sys"
+       _ "unsafe"
+)
+
+//go:linkname dit_setEnabled crypto/subtle.setDITEnabled
+func dit_setEnabled() bool {
+       g := getg()
+       g.ditWanted = true
+       g.m.ditEnabled = true
+       return sys.EnableDIT()
+}
+
+//go:linkname dit_setDisabled crypto/subtle.setDITDisabled
+func dit_setDisabled() {
+       g := getg()
+       g.ditWanted = false
+       g.m.ditEnabled = false
+       sys.DisableDIT()
+}
index 52def488ffca42f8a9cf3547020a709c551050f3..4bca9f1347d45bf6f9c38c2ad9eb6bf63b0c6583 100644 (file)
@@ -3342,6 +3342,23 @@ func execute(gp *g, inheritTime bool) {
                mp.p.ptr().schedtick++
        }
 
+       if sys.DITSupported && debug.dataindependenttiming != 1 {
+               if gp.ditWanted && !mp.ditEnabled {
+                       // The current M doesn't have DIT enabled, but the goroutine we're
+                       // executing does need it, so turn it on.
+                       sys.EnableDIT()
+                       mp.ditEnabled = true
+               } else if !gp.ditWanted && mp.ditEnabled {
+                       // The current M has DIT enabled, but the goroutine we're executing does
+                       // not need it, so turn it off.
+                       // NOTE: turning off DIT here means that the scheduler will have DIT enabled
+                       // when it runs after this goroutine yields or is preempted. This may have
+                       // a minor performance impact on the scheduler.
+                       sys.DisableDIT()
+                       mp.ditEnabled = false
+               }
+       }
+
        // Check whether the profiler needs to be turned on or off.
        hz := sched.profilehz
        if mp.profilehz != hz {
@@ -5378,6 +5395,9 @@ func newproc1(fn *funcval, callergp *g, callerpc uintptr, parked bool, waitreaso
        // fips140 bubble
        newg.fipsOnlyBypass = callergp.fipsOnlyBypass
 
+       // dit bubble
+       newg.ditWanted = callergp.ditWanted
+
        // Set up race context.
        if raceenabled {
                newg.racectx = racegostart(callerpc)
index fde378ff25ce540dfecd06510a225c93e3adb919..be33932b24d16010920a8f9abb97e87c121ec49f 100644 (file)
@@ -546,6 +546,7 @@ type g struct {
        lockedm         muintptr
        fipsIndicator   uint8
        fipsOnlyBypass  bool
+       ditWanted       bool // set if g wants to be executed with DIT enabled
        syncSafePoint   bool // set if g is stopped at a synchronous safe point.
        runningCleanups atomic.Bool
        sig             uint32
@@ -674,6 +675,7 @@ type m struct {
        lockedExt       uint32      // tracking for external LockOSThread
        lockedInt       uint32      // tracking for internal lockOSThread
        mWaitList       mWaitList   // list of runtime lock waiters
+       ditEnabled      bool        // set if DIT is currently enabled on this M
 
        mLockProfile mLockProfile // fields relating to runtime.lock contention
        profStack    []uintptr    // used for memory/block/mutex stack traces
index 9dde0da9636a48eab297bf3b5b9c20818e69eba4..219bcb92598c3a0a0e0ee80ca7b547552ee56fa9 100644 (file)
@@ -21,7 +21,7 @@ func TestSizeof(t *testing.T) {
                _32bit uintptr // size on 32bit platforms
                _64bit uintptr // size on 64bit platforms
        }{
-               {runtime.G{}, 284 + xreg, 448 + xreg}, // g, but exported for testing
+               {runtime.G{}, 288 + xreg, 448 + xreg}, // g, but exported for testing
                {runtime.Sudog{}, 64, 104},            // sudog, but exported for testing
        }