From 85fa418fd5bec603dd87254178e64afd1782bd35 Mon Sep 17 00:00:00 2001 From: Michael Anthony Knyszek Date: Fri, 21 Jun 2024 20:42:45 +0000 Subject: [PATCH] internal/sync: add CompareAndSwap to HashTrieMap This change adds the CompareAndSwap operation (with the same semantics as sync.Map's CompareAndSwap) to HashTrieMap. Change-Id: I86153799fc47304784333d17f0c6a7ad7682f04a Reviewed-on: https://go-review.googlesource.com/c/go/+/594063 Reviewed-by: David Chase LUCI-TryBot-Result: Go LUCI Auto-Submit: Michael Knyszek --- src/internal/sync/hashtriemap.go | 81 ++++++++++- src/internal/sync/hashtriemap_test.go | 191 +++++++++++++++++++++++++- 2 files changed, 263 insertions(+), 9 deletions(-) diff --git a/src/internal/sync/hashtriemap.go b/src/internal/sync/hashtriemap.go index 082aecacba..29c88b055e 100644 --- a/src/internal/sync/hashtriemap.go +++ b/src/internal/sync/hashtriemap.go @@ -195,6 +195,36 @@ func (ht *HashTrieMap[K, V]) expand(oldEntry, newEntry *entry[K, V], newHash uin return &top.node } +// CompareAndSwap swaps the old and new values for key +// if the value stored in the map is equal to old. +// The value type must be of a comparable type, otherwise CompareAndSwap will panic. +func (ht *HashTrieMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) { + ht.init() + if ht.valEqual == nil { + panic("called CompareAndSwap when value is not of comparable type") + } + hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed) + + // Find a node with the key and compare with it. n != nil if we found the node. + i, _, slot, n := ht.find(key, hash, ht.valEqual, old) + if i != nil { + defer i.mu.Unlock() + } + if n == nil { + return false + } + + // Try to swap the entry. + e, swapped := n.entry().compareAndSwap(key, old, new, ht.valEqual) + if !swapped { + // Nothing was actually swapped, which means the node is no longer there. + return false + } + // Store the entry back because it changed. + slot.Store(&e.node) + return true +} + // CompareAndDelete deletes the entry for key if its value is equal to old. // The value type must be comparable, otherwise this CompareAndDelete will panic. // @@ -207,8 +237,8 @@ func (ht *HashTrieMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) { } hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed) - // Find a node with the key and compare with it. n != nil if we found the node. - i, hashShift, slot, n := ht.find(key, hash) + // Find a node with the key. n != nil if we found the node. + i, hashShift, slot, n := ht.find(key, hash, nil, *new(V)) if n == nil { if i != nil { i.mu.Unlock() @@ -252,14 +282,15 @@ func (ht *HashTrieMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) { return true } -// compare searches the tree for a node that compares with key (hash must be the hash of key). +// find searches the tree for a node that contains key (hash must be the hash of key). +// If valEqual != nil, then it will also enforce that the values are equal as well. // // Returns a non-nil node, which will always be an entry, if found. // // If i != nil then i.mu is locked, and it is the caller's responsibility to unlock it. -func (ht *HashTrieMap[K, V]) find(key K, hash uintptr) (i *indirect[K, V], hashShift uint, slot *atomic.Pointer[node[K, V]], n *node[K, V]) { +func (ht *HashTrieMap[K, V]) find(key K, hash uintptr, valEqual equalFunc, value V) (i *indirect[K, V], hashShift uint, slot *atomic.Pointer[node[K, V]], n *node[K, V]) { for { - // Find the key or return when there's nothing to delete. + // Find the key or return if it's not there. i = ht.root hashShift = 8 * goarch.PtrSize found := false @@ -275,7 +306,7 @@ func (ht *HashTrieMap[K, V]) find(key K, hash uintptr) (i *indirect[K, V], hashS } if n.isEntry { // We found an entry. Check if it matches. - if _, ok := n.entry().lookup(key); !ok { + if _, ok := n.entry().lookupWithValue(key, value, valEqual); !ok { // No match, comparison failed. i = nil n = nil @@ -398,6 +429,44 @@ func (e *entry[K, V]) lookup(key K) (V, bool) { return *new(V), false } +func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bool) { + for e != nil { + if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value)))) { + return e.value, true + } + e = e.overflow.Load() + } + return *new(V), false +} + +// compareAndSwap replaces an entry in the overflow chain if both the key and value compare +// equal. Returns the new entry chain and whether or not anything was swapped. +// +// compareAndSwap must be called under the mutex of the indirect node which e is a child of. +func (head *entry[K, V]) compareAndSwap(key K, old, new V, valEqual equalFunc) (*entry[K, V], bool) { + if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&old))) { + // Return the new head of the list. + e := newEntryNode(key, new) + if chain := head.overflow.Load(); chain != nil { + e.overflow.Store(chain) + } + return e, true + } + i := &head.overflow + e := i.Load() + for e != nil { + if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&old))) { + eNew := newEntryNode(key, new) + eNew.overflow.Store(e.overflow.Load()) + i.Store(eNew) + return head, true + } + i = &e.overflow + e = e.overflow.Load() + } + return head, false +} + // compareAndDelete deletes an entry in the overflow chain if both the key and value compare // equal. Returns the new entry chain and whether or not anything was deleted. // diff --git a/src/internal/sync/hashtriemap_test.go b/src/internal/sync/hashtriemap_test.go index 9ab11d4126..b34350c15b 100644 --- a/src/internal/sync/hashtriemap_test.go +++ b/src/internal/sync/hashtriemap_test.go @@ -101,7 +101,7 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int] } } }) - t.Run("DeleteMultiple", func(t *testing.T) { + t.Run("CompareAndDeleteMultiple", func(t *testing.T) { m := newMap() for i, s := range testData { @@ -130,7 +130,7 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int] return true }) }) - t.Run("AllDelete", func(t *testing.T) { + t.Run("AllCompareAndDelete", func(t *testing.T) { m := newMap() testAll(t, m, testDataMap(testData[:]), func(s string, i int) bool { @@ -175,7 +175,7 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int] } wg.Wait() }) - t.Run("ConcurrentDeleteSharedKeys", func(t *testing.T) { + t.Run("ConcurrentCompareAndDeleteSharedKeys", func(t *testing.T) { m := newMap() // Load up the map. @@ -202,6 +202,169 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int] } wg.Wait() }) + t.Run("CompareAndSwapAll", func(t *testing.T) { + m := newMap() + + for i, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + expectStored(t, s, i)(m.LoadOrStore(s, i)) + expectPresent(t, s, i)(m.Load(s)) + expectLoaded(t, s, i)(m.LoadOrStore(s, 0)) + } + for j := range 3 { + for i, s := range testData { + expectPresent(t, s, i+j)(m.Load(s)) + expectNotSwapped(t, s, math.MaxInt, i+j+1)(m.CompareAndSwap(s, math.MaxInt, i+j+1)) + expectSwapped(t, s, i, i+j+1)(m.CompareAndSwap(s, i+j, i+j+1)) + expectNotSwapped(t, s, i+j, i+j+1)(m.CompareAndSwap(s, i+j, i+j+1)) + expectPresent(t, s, i+j+1)(m.Load(s)) + } + } + for i, s := range testData { + expectPresent(t, s, i+3)(m.Load(s)) + } + }) + t.Run("CompareAndSwapOne", func(t *testing.T) { + m := newMap() + + for i, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + expectStored(t, s, i)(m.LoadOrStore(s, i)) + expectPresent(t, s, i)(m.Load(s)) + expectLoaded(t, s, i)(m.LoadOrStore(s, 0)) + } + expectNotSwapped(t, testData[15], math.MaxInt, 16)(m.CompareAndSwap(testData[15], math.MaxInt, 16)) + expectSwapped(t, testData[15], 15, 16)(m.CompareAndSwap(testData[15], 15, 16)) + expectNotSwapped(t, testData[15], 15, 16)(m.CompareAndSwap(testData[15], 15, 16)) + for i, s := range testData { + if i == 15 { + expectPresent(t, s, 16)(m.Load(s)) + } else { + expectPresent(t, s, i)(m.Load(s)) + } + } + }) + t.Run("CompareAndSwapMultiple", func(t *testing.T) { + m := newMap() + + for i, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + expectStored(t, s, i)(m.LoadOrStore(s, i)) + expectPresent(t, s, i)(m.Load(s)) + expectLoaded(t, s, i)(m.LoadOrStore(s, 0)) + } + for _, i := range []int{1, 105, 6, 85} { + expectNotSwapped(t, testData[i], math.MaxInt, i+1)(m.CompareAndSwap(testData[i], math.MaxInt, i+1)) + expectSwapped(t, testData[i], i, i+1)(m.CompareAndSwap(testData[i], i, i+1)) + expectNotSwapped(t, testData[i], i, i+1)(m.CompareAndSwap(testData[i], i, i+1)) + } + for i, s := range testData { + if i == 1 || i == 105 || i == 6 || i == 85 { + expectPresent(t, s, i+1)(m.Load(s)) + } else { + expectPresent(t, s, i)(m.Load(s)) + } + } + }) + t.Run("ConcurrentLifecycleSwapUnsharedKeys", func(t *testing.T) { + m := newMap() + + gmp := runtime.GOMAXPROCS(-1) + var wg sync.WaitGroup + for i := range gmp { + wg.Add(1) + go func(id int) { + defer wg.Done() + + makeKey := func(s string) string { + return s + "-" + strconv.Itoa(id) + } + for _, s := range testData { + key := makeKey(s) + expectMissing(t, key, 0)(m.Load(key)) + expectStored(t, key, id)(m.LoadOrStore(key, id)) + expectPresent(t, key, id)(m.Load(key)) + expectLoaded(t, key, id)(m.LoadOrStore(key, 0)) + } + for _, s := range testData { + key := makeKey(s) + expectPresent(t, key, id)(m.Load(key)) + expectSwapped(t, key, id, id+1)(m.CompareAndSwap(key, id, id+1)) + expectPresent(t, key, id+1)(m.Load(key)) + } + for _, s := range testData { + key := makeKey(s) + expectPresent(t, key, id+1)(m.Load(key)) + } + }(i) + } + wg.Wait() + }) + t.Run("ConcurrentLifecycleSwapAndDeleteUnsharedKeys", func(t *testing.T) { + m := newMap() + + gmp := runtime.GOMAXPROCS(-1) + var wg sync.WaitGroup + for i := range gmp { + wg.Add(1) + go func(id int) { + defer wg.Done() + + makeKey := func(s string) string { + return s + "-" + strconv.Itoa(id) + } + for _, s := range testData { + key := makeKey(s) + expectMissing(t, key, 0)(m.Load(key)) + expectStored(t, key, id)(m.LoadOrStore(key, id)) + expectPresent(t, key, id)(m.Load(key)) + expectLoaded(t, key, id)(m.LoadOrStore(key, 0)) + } + for _, s := range testData { + key := makeKey(s) + expectPresent(t, key, id)(m.Load(key)) + expectSwapped(t, key, id, id+1)(m.CompareAndSwap(key, id, id+1)) + expectPresent(t, key, id+1)(m.Load(key)) + expectDeleted(t, key, id+1)(m.CompareAndDelete(key, id+1)) + expectNotSwapped(t, key, id+1, id+2)(m.CompareAndSwap(key, id+1, id+2)) + expectNotDeleted(t, key, id+1)(m.CompareAndDelete(key, id+1)) + expectMissing(t, key, 0)(m.Load(key)) + } + for _, s := range testData { + key := makeKey(s) + expectMissing(t, key, 0)(m.Load(key)) + } + }(i) + } + wg.Wait() + }) + t.Run("ConcurrentCompareAndSwapSharedKeys", func(t *testing.T) { + m := newMap() + + // Load up the map. + for i, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + expectStored(t, s, i)(m.LoadOrStore(s, i)) + } + gmp := runtime.GOMAXPROCS(-1) + var wg sync.WaitGroup + for i := range gmp { + wg.Add(1) + go func(id int) { + defer wg.Done() + + for i, s := range testData { + expectNotSwapped(t, s, math.MaxInt, i+1)(m.CompareAndSwap(s, math.MaxInt, i+1)) + m.CompareAndSwap(s, i, i+1) + expectPresent(t, s, i+1)(m.Load(s)) + } + for i, s := range testData { + expectPresent(t, s, i+1)(m.Load(s)) + } + }(i) + } + wg.Wait() + }) } func testAll[K, V comparable](t *testing.T, m *isync.HashTrieMap[K, V], testData map[K]V, yield func(K, V) bool) { @@ -312,6 +475,28 @@ func expectNotDeleted[K, V comparable](t *testing.T, key K, old V) func(deleted } } +func expectSwapped[K, V comparable](t *testing.T, key K, old, new V) func(swapped bool) { + t.Helper() + return func(swapped bool) { + t.Helper() + + if !swapped { + t.Errorf("expected key %v with value %v to be in map and swapped for %v", key, old, new) + } + } +} + +func expectNotSwapped[K, V comparable](t *testing.T, key K, old, new V) func(swapped bool) { + t.Helper() + return func(swapped bool) { + t.Helper() + + if swapped { + t.Errorf("expected key %v with value %v to not be in map or not swapped for %v", key, old, new) + } + } +} + func testDataMap(data []string) map[string]int { m := make(map[string]int) for i, s := range data { -- 2.48.1