From 0f1227c507c62b635a3b4b85f5b7a967df72b59f Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 7 Sep 2023 09:27:50 -0700 Subject: [PATCH] context: avoid key collisions in test afterfunc map The afterFuncContext type, used only in tests, contains a set of registered afterfuncs indexed by an arbitrary unique key. That key is currently a *struct{}. Unfortunately, all *struct{} pointers are equal to each other, so all registered funcs share the same key. Fortunately, the tests using this type never register more than one afterfunc. Change the key to a *byte. Change-Id: Icadf7d6f258e328f6e3375846d29ce0f98b60924 Reviewed-on: https://go-review.googlesource.com/c/go/+/526655 LUCI-TryBot-Result: Go LUCI Reviewed-by: Bryan Mills --- src/context/afterfunc_test.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/context/afterfunc_test.go b/src/context/afterfunc_test.go index 71f639a345..7b75295eb4 100644 --- a/src/context/afterfunc_test.go +++ b/src/context/afterfunc_test.go @@ -15,7 +15,7 @@ import ( // defined in context.go, that supports registering AfterFuncs. type afterFuncContext struct { mu sync.Mutex - afterFuncs map[*struct{}]func() + afterFuncs map[*byte]func() done chan struct{} err error } @@ -50,9 +50,9 @@ func (c *afterFuncContext) Value(key any) any { func (c *afterFuncContext) AfterFunc(f func()) func() bool { c.mu.Lock() defer c.mu.Unlock() - k := &struct{}{} + k := new(byte) if c.afterFuncs == nil { - c.afterFuncs = make(map[*struct{}]func()) + c.afterFuncs = make(map[*byte]func()) } c.afterFuncs[k] = f return func() bool { @@ -106,11 +106,13 @@ func TestCustomContextAfterFuncAfterFunc(t *testing.T) { func TestCustomContextAfterFuncUnregisterCancel(t *testing.T) { ctx0 := &afterFuncContext{} - _, cancel := context.WithCancel(ctx0) - if got, want := len(ctx0.afterFuncs), 1; got != want { + _, cancel1 := context.WithCancel(ctx0) + _, cancel2 := context.WithCancel(ctx0) + if got, want := len(ctx0.afterFuncs), 2; got != want { t.Errorf("after WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want) } - cancel() + cancel1() + cancel2() if got, want := len(ctx0.afterFuncs), 0; got != want { t.Errorf("after canceling WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want) } -- 2.50.0