From 28bac5640cfd889a795bd7d7d24e6d788a985ead Mon Sep 17 00:00:00 2001 From: Michael Anthony Knyszek Date: Fri, 21 Jun 2024 21:00:15 +0000 Subject: [PATCH] internal/sync: add Swap to HashTrieMap This change adds the Swap operation (with the same semantics as sync.Map's Swap) to HashTrieMap. Change-Id: I8697a0c8c2eb761e2452a41b868b590ccbfa5c03 Reviewed-on: https://go-review.googlesource.com/c/go/+/594064 Reviewed-by: David Chase LUCI-TryBot-Result: Go LUCI Auto-Submit: Michael Knyszek --- src/internal/sync/hashtriemap.go | 103 +++++++++++++++ src/internal/sync/hashtriemap_test.go | 183 +++++++++++++++++++++++++- 2 files changed, 284 insertions(+), 2 deletions(-) diff --git a/src/internal/sync/hashtriemap.go b/src/internal/sync/hashtriemap.go index 29c88b055e..f2509d6920 100644 --- a/src/internal/sync/hashtriemap.go +++ b/src/internal/sync/hashtriemap.go @@ -195,6 +195,80 @@ func (ht *HashTrieMap[K, V]) expand(oldEntry, newEntry *entry[K, V], newHash uin return &top.node } +// Swap swaps the value for a key and returns the previous value if any. +// The loaded result reports whether the key was present. +func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) { + ht.init() + hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed) + var i *indirect[K, V] + var hashShift uint + var slot *atomic.Pointer[node[K, V]] + var n *node[K, V] + for { + // Find the key or a candidate location for insertion. + i = ht.root + hashShift = 8 * goarch.PtrSize + haveInsertPoint := false + for hashShift != 0 { + hashShift -= nChildrenLog2 + + slot = &i.children[(hash>>hashShift)&nChildrenMask] + n = slot.Load() + if n == nil || n.isEntry { + // We found a nil slot which is a candidate for insertion, + // or an existing entry that we'll replace. + haveInsertPoint = true + break + } + i = n.indirect() + } + if !haveInsertPoint { + panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating") + } + + // Grab the lock and double-check what we saw. + i.mu.Lock() + n = slot.Load() + if (n == nil || n.isEntry) && !i.dead.Load() { + // What we saw is still true, so we can continue with the insert. + break + } + // We have to start over. + i.mu.Unlock() + } + // N.B. This lock is held from when we broke out of the outer loop above. + // We specifically break this out so that we can use defer here safely. + // One option is to break this out into a new function instead, but + // there's so much local iteration state used below that this turns out + // to be cleaner. + defer i.mu.Unlock() + + var zero V + var oldEntry *entry[K, V] + if n != nil { + // Swap if the keys compare. + oldEntry = n.entry() + newEntry, old, swapped := oldEntry.swap(key, new) + if swapped { + slot.Store(&newEntry.node) + return old, true + } + } + // The keys didn't compare, so we're doing an insertion. + newEntry := newEntryNode(key, new) + if oldEntry == nil { + // Easy case: create a new entry and store it. + slot.Store(&newEntry.node) + } else { + // We possibly need to expand the entry already there into one or more new nodes. + // + // Publish the node last, which will make both oldEntry and newEntry visible. We + // don't want readers to be able to observe that oldEntry isn't in the tree. + slot.Store(ht.expand(oldEntry, newEntry, hash, hashShift, i)) + } + return zero, false +} + // CompareAndSwap swaps the old and new values for key // if the value stored in the map is equal to old. // The value type must be of a comparable type, otherwise CompareAndSwap will panic. @@ -439,6 +513,35 @@ func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bo return *new(V), false } +// swap replaces an entry in the overflow chain if keys compare equal. Returns the new entry chain, +// the old value, and whether or not anything was swapped. +// +// swap must be called under the mutex of the indirect node which e is a child of. +func (head *entry[K, V]) swap(key K, new V) (*entry[K, V], V, bool) { + if head.key == key { + // Return the new head of the list. + e := newEntryNode(key, new) + if chain := head.overflow.Load(); chain != nil { + e.overflow.Store(chain) + } + return e, head.value, true + } + i := &head.overflow + e := i.Load() + for e != nil { + if e.key == key { + eNew := newEntryNode(key, new) + eNew.overflow.Store(e.overflow.Load()) + i.Store(eNew) + return head, e.value, true + } + i = &e.overflow + e = e.overflow.Load() + } + var zero V + return head, zero, false +} + // compareAndSwap replaces an entry in the overflow chain if both the key and value compare // equal. Returns the new entry chain and whether or not anything was swapped. // diff --git a/src/internal/sync/hashtriemap_test.go b/src/internal/sync/hashtriemap_test.go index b34350c15b..f3e1073c4f 100644 --- a/src/internal/sync/hashtriemap_test.go +++ b/src/internal/sync/hashtriemap_test.go @@ -266,7 +266,7 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int] } } }) - t.Run("ConcurrentLifecycleSwapUnsharedKeys", func(t *testing.T) { + t.Run("ConcurrentLifecycleCompareAndSwapUnsharedKeys", func(t *testing.T) { m := newMap() gmp := runtime.GOMAXPROCS(-1) @@ -300,7 +300,7 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int] } wg.Wait() }) - t.Run("ConcurrentLifecycleSwapAndDeleteUnsharedKeys", func(t *testing.T) { + t.Run("ConcurrentLifecycleCompareAndSwapAndDeleteUnsharedKeys", func(t *testing.T) { m := newMap() gmp := runtime.GOMAXPROCS(-1) @@ -365,6 +365,161 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int] } wg.Wait() }) + t.Run("SwapAll", func(t *testing.T) { + m := newMap() + + for i, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + expectNotLoadedFromSwap(t, s, i)(m.Swap(s, i)) + expectPresent(t, s, i)(m.Load(s)) + expectLoadedFromSwap(t, s, i, i)(m.Swap(s, i)) + } + for j := range 3 { + for i, s := range testData { + expectPresent(t, s, i+j)(m.Load(s)) + expectLoadedFromSwap(t, s, i+j, i+j+1)(m.Swap(s, i+j+1)) + expectPresent(t, s, i+j+1)(m.Load(s)) + } + } + for i, s := range testData { + expectLoadedFromSwap(t, s, i+3, i+3)(m.Swap(s, i+3)) + } + }) + t.Run("SwapOne", func(t *testing.T) { + m := newMap() + + for i, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + expectNotLoadedFromSwap(t, s, i)(m.Swap(s, i)) + expectPresent(t, s, i)(m.Load(s)) + expectLoadedFromSwap(t, s, i, i)(m.Swap(s, i)) + } + expectLoadedFromSwap(t, testData[15], 15, 16)(m.Swap(testData[15], 16)) + for i, s := range testData { + if i == 15 { + expectPresent(t, s, 16)(m.Load(s)) + } else { + expectPresent(t, s, i)(m.Load(s)) + } + } + }) + t.Run("SwapMultiple", func(t *testing.T) { + m := newMap() + + for i, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + expectNotLoadedFromSwap(t, s, i)(m.Swap(s, i)) + expectPresent(t, s, i)(m.Load(s)) + expectLoadedFromSwap(t, s, i, i)(m.Swap(s, i)) + } + for _, i := range []int{1, 105, 6, 85} { + expectLoadedFromSwap(t, testData[i], i, i+1)(m.Swap(testData[i], i+1)) + } + for i, s := range testData { + if i == 1 || i == 105 || i == 6 || i == 85 { + expectPresent(t, s, i+1)(m.Load(s)) + } else { + expectPresent(t, s, i)(m.Load(s)) + } + } + }) + t.Run("ConcurrentLifecycleSwapUnsharedKeys", func(t *testing.T) { + m := newMap() + + gmp := runtime.GOMAXPROCS(-1) + var wg sync.WaitGroup + for i := range gmp { + wg.Add(1) + go func(id int) { + defer wg.Done() + + makeKey := func(s string) string { + return s + "-" + strconv.Itoa(id) + } + for _, s := range testData { + key := makeKey(s) + expectMissing(t, key, 0)(m.Load(key)) + expectNotLoadedFromSwap(t, key, id)(m.Swap(key, id)) + expectPresent(t, key, id)(m.Load(key)) + expectLoadedFromSwap(t, key, id, id)(m.Swap(key, id)) + } + for _, s := range testData { + key := makeKey(s) + expectPresent(t, key, id)(m.Load(key)) + expectLoadedFromSwap(t, key, id, id+1)(m.Swap(key, id+1)) + expectPresent(t, key, id+1)(m.Load(key)) + } + for _, s := range testData { + key := makeKey(s) + expectPresent(t, key, id+1)(m.Load(key)) + } + }(i) + } + wg.Wait() + }) + t.Run("ConcurrentLifecycleSwapAndDeleteUnsharedKeys", func(t *testing.T) { + m := newMap() + + gmp := runtime.GOMAXPROCS(-1) + var wg sync.WaitGroup + for i := range gmp { + wg.Add(1) + go func(id int) { + defer wg.Done() + + makeKey := func(s string) string { + return s + "-" + strconv.Itoa(id) + } + for _, s := range testData { + key := makeKey(s) + expectMissing(t, key, 0)(m.Load(key)) + expectNotLoadedFromSwap(t, key, id)(m.Swap(key, id)) + expectPresent(t, key, id)(m.Load(key)) + expectLoadedFromSwap(t, key, id, id)(m.Swap(key, id)) + } + for _, s := range testData { + key := makeKey(s) + expectPresent(t, key, id)(m.Load(key)) + expectLoadedFromSwap(t, key, id, id+1)(m.Swap(key, id+1)) + expectPresent(t, key, id+1)(m.Load(key)) + expectDeleted(t, key, id+1)(m.CompareAndDelete(key, id+1)) + expectNotLoadedFromSwap(t, key, id+2)(m.Swap(key, id+2)) + expectPresent(t, key, id+2)(m.Load(key)) + } + for _, s := range testData { + key := makeKey(s) + expectPresent(t, key, id+2)(m.Load(key)) + } + }(i) + } + wg.Wait() + }) + t.Run("ConcurrentSwapSharedKeys", func(t *testing.T) { + m := newMap() + + // Load up the map. + for i, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + expectStored(t, s, i)(m.LoadOrStore(s, i)) + } + gmp := runtime.GOMAXPROCS(-1) + var wg sync.WaitGroup + for i := range gmp { + wg.Add(1) + go func(id int) { + defer wg.Done() + + for i, s := range testData { + m.Swap(s, i+1) + expectPresent(t, s, i+1)(m.Load(s)) + } + for i, s := range testData { + expectPresent(t, s, i+1)(m.Load(s)) + } + }(i) + } + wg.Wait() + }) } func testAll[K, V comparable](t *testing.T, m *isync.HashTrieMap[K, V], testData map[K]V, yield func(K, V) bool) { @@ -497,6 +652,30 @@ func expectNotSwapped[K, V comparable](t *testing.T, key K, old, new V) func(swa } } +func expectLoadedFromSwap[K, V comparable](t *testing.T, key K, want, new V) func(got V, loaded bool) { + t.Helper() + return func(got V, loaded bool) { + t.Helper() + + if !loaded { + t.Errorf("expected key %v to be in map and for %v to have been swapped for %v", key, want, new) + } else if want != got { + t.Errorf("key %v had its value %v swapped for %v, but expected it to have value %v", key, got, new, want) + } + } +} + +func expectNotLoadedFromSwap[K, V comparable](t *testing.T, key K, new V) func(old V, loaded bool) { + t.Helper() + return func(old V, loaded bool) { + t.Helper() + + if loaded { + t.Errorf("expected key %v to not be in map, but found value %v for it", key, old) + } + } +} + func testDataMap(data []string) map[string]int { m := make(map[string]int) for i, s := range data { -- 2.48.1