]> Cypherpunks repositories - gostls13.git/commitdiff
context: avoid key collisions in test afterfunc map
authorDamien Neil <dneil@google.com>
Thu, 7 Sep 2023 16:27:50 +0000 (09:27 -0700)
committerDamien Neil <dneil@google.com>
Mon, 18 Sep 2023 16:58:52 +0000 (16:58 +0000)
The afterFuncContext type, used only in tests, contains a
set of registered afterfuncs indexed by an arbitrary unique key.

That key is currently a *struct{}. Unfortunately, all
*struct{} pointers are equal to each other, so all registered
funcs share the same key. Fortunately, the tests using this
type never register more than one afterfunc.

Change the key to a *byte.

Change-Id: Icadf7d6f258e328f6e3375846d29ce0f98b60924
Reviewed-on: https://go-review.googlesource.com/c/go/+/526655
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
src/context/afterfunc_test.go

index 71f639a345bbc71f1dcf02aaaf51aa2e4caec9d2..7b75295eb4dab03363efee1f56fbfa693edeca78 100644 (file)
@@ -15,7 +15,7 @@ import (
 // defined in context.go, that supports registering AfterFuncs.
 type afterFuncContext struct {
        mu         sync.Mutex
-       afterFuncs map[*struct{}]func()
+       afterFuncs map[*byte]func()
        done       chan struct{}
        err        error
 }
@@ -50,9 +50,9 @@ func (c *afterFuncContext) Value(key any) any {
 func (c *afterFuncContext) AfterFunc(f func()) func() bool {
        c.mu.Lock()
        defer c.mu.Unlock()
-       k := &struct{}{}
+       k := new(byte)
        if c.afterFuncs == nil {
-               c.afterFuncs = make(map[*struct{}]func())
+               c.afterFuncs = make(map[*byte]func())
        }
        c.afterFuncs[k] = f
        return func() bool {
@@ -106,11 +106,13 @@ func TestCustomContextAfterFuncAfterFunc(t *testing.T) {
 
 func TestCustomContextAfterFuncUnregisterCancel(t *testing.T) {
        ctx0 := &afterFuncContext{}
-       _, cancel := context.WithCancel(ctx0)
-       if got, want := len(ctx0.afterFuncs), 1; got != want {
+       _, cancel1 := context.WithCancel(ctx0)
+       _, cancel2 := context.WithCancel(ctx0)
+       if got, want := len(ctx0.afterFuncs), 2; got != want {
                t.Errorf("after WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
        }
-       cancel()
+       cancel1()
+       cancel2()
        if got, want := len(ctx0.afterFuncs), 0; got != want {
                t.Errorf("after canceling WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
        }