]> Cypherpunks repositories - gostls13.git/commitdiff
internal/sync: add Swap to HashTrieMap
authorMichael Anthony Knyszek <mknyszek@google.com>
Fri, 21 Jun 2024 21:00:15 +0000 (21:00 +0000)
committerGopher Robot <gobot@golang.org>
Mon, 18 Nov 2024 20:35:14 +0000 (20:35 +0000)
This change adds the Swap operation (with the same semantics as
sync.Map's Swap) to HashTrieMap.

Change-Id: I8697a0c8c2eb761e2452a41b868b590ccbfa5c03
Reviewed-on: https://go-review.googlesource.com/c/go/+/594064
Reviewed-by: David Chase <drchase@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Michael Knyszek <mknyszek@google.com>

src/internal/sync/hashtriemap.go
src/internal/sync/hashtriemap_test.go

index 29c88b055ed254d1f431f18e051016a220281a7f..f2509d69202cc6a7c5e3b103369c9d22f52b674a 100644 (file)
@@ -195,6 +195,80 @@ func (ht *HashTrieMap[K, V]) expand(oldEntry, newEntry *entry[K, V], newHash uin
        return &top.node
 }
 
+// Swap swaps the value for a key and returns the previous value if any.
+// The loaded result reports whether the key was present.
+func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
+       ht.init()
+       hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
+       var i *indirect[K, V]
+       var hashShift uint
+       var slot *atomic.Pointer[node[K, V]]
+       var n *node[K, V]
+       for {
+               // Find the key or a candidate location for insertion.
+               i = ht.root
+               hashShift = 8 * goarch.PtrSize
+               haveInsertPoint := false
+               for hashShift != 0 {
+                       hashShift -= nChildrenLog2
+
+                       slot = &i.children[(hash>>hashShift)&nChildrenMask]
+                       n = slot.Load()
+                       if n == nil || n.isEntry {
+                               // We found a nil slot which is a candidate for insertion,
+                               // or an existing entry that we'll replace.
+                               haveInsertPoint = true
+                               break
+                       }
+                       i = n.indirect()
+               }
+               if !haveInsertPoint {
+                       panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
+               }
+
+               // Grab the lock and double-check what we saw.
+               i.mu.Lock()
+               n = slot.Load()
+               if (n == nil || n.isEntry) && !i.dead.Load() {
+                       // What we saw is still true, so we can continue with the insert.
+                       break
+               }
+               // We have to start over.
+               i.mu.Unlock()
+       }
+       // N.B. This lock is held from when we broke out of the outer loop above.
+       // We specifically break this out so that we can use defer here safely.
+       // One option is to break this out into a new function instead, but
+       // there's so much local iteration state used below that this turns out
+       // to be cleaner.
+       defer i.mu.Unlock()
+
+       var zero V
+       var oldEntry *entry[K, V]
+       if n != nil {
+               // Swap if the keys compare.
+               oldEntry = n.entry()
+               newEntry, old, swapped := oldEntry.swap(key, new)
+               if swapped {
+                       slot.Store(&newEntry.node)
+                       return old, true
+               }
+       }
+       // The keys didn't compare, so we're doing an insertion.
+       newEntry := newEntryNode(key, new)
+       if oldEntry == nil {
+               // Easy case: create a new entry and store it.
+               slot.Store(&newEntry.node)
+       } else {
+               // We possibly need to expand the entry already there into one or more new nodes.
+               //
+               // Publish the node last, which will make both oldEntry and newEntry visible. We
+               // don't want readers to be able to observe that oldEntry isn't in the tree.
+               slot.Store(ht.expand(oldEntry, newEntry, hash, hashShift, i))
+       }
+       return zero, false
+}
+
 // CompareAndSwap swaps the old and new values for key
 // if the value stored in the map is equal to old.
 // The value type must be of a comparable type, otherwise CompareAndSwap will panic.
@@ -439,6 +513,35 @@ func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bo
        return *new(V), false
 }
 
+// swap replaces an entry in the overflow chain if keys compare equal. Returns the new entry chain,
+// the old value, and whether or not anything was swapped.
+//
+// swap must be called under the mutex of the indirect node which e is a child of.
+func (head *entry[K, V]) swap(key K, new V) (*entry[K, V], V, bool) {
+       if head.key == key {
+               // Return the new head of the list.
+               e := newEntryNode(key, new)
+               if chain := head.overflow.Load(); chain != nil {
+                       e.overflow.Store(chain)
+               }
+               return e, head.value, true
+       }
+       i := &head.overflow
+       e := i.Load()
+       for e != nil {
+               if e.key == key {
+                       eNew := newEntryNode(key, new)
+                       eNew.overflow.Store(e.overflow.Load())
+                       i.Store(eNew)
+                       return head, e.value, true
+               }
+               i = &e.overflow
+               e = e.overflow.Load()
+       }
+       var zero V
+       return head, zero, false
+}
+
 // compareAndSwap replaces an entry in the overflow chain if both the key and value compare
 // equal. Returns the new entry chain and whether or not anything was swapped.
 //
index b34350c15b14b8e6d1356f91ab442b892e6d293f..f3e1073c4f10a50af8e225ca4602a1dadbf9c9d4 100644 (file)
@@ -266,7 +266,7 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int]
                        }
                }
        })
-       t.Run("ConcurrentLifecycleSwapUnsharedKeys", func(t *testing.T) {
+       t.Run("ConcurrentLifecycleCompareAndSwapUnsharedKeys", func(t *testing.T) {
                m := newMap()
 
                gmp := runtime.GOMAXPROCS(-1)
@@ -300,7 +300,7 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int]
                }
                wg.Wait()
        })
-       t.Run("ConcurrentLifecycleSwapAndDeleteUnsharedKeys", func(t *testing.T) {
+       t.Run("ConcurrentLifecycleCompareAndSwapAndDeleteUnsharedKeys", func(t *testing.T) {
                m := newMap()
 
                gmp := runtime.GOMAXPROCS(-1)
@@ -365,6 +365,161 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int]
                }
                wg.Wait()
        })
+       t.Run("SwapAll", func(t *testing.T) {
+               m := newMap()
+
+               for i, s := range testData {
+                       expectMissing(t, s, 0)(m.Load(s))
+                       expectNotLoadedFromSwap(t, s, i)(m.Swap(s, i))
+                       expectPresent(t, s, i)(m.Load(s))
+                       expectLoadedFromSwap(t, s, i, i)(m.Swap(s, i))
+               }
+               for j := range 3 {
+                       for i, s := range testData {
+                               expectPresent(t, s, i+j)(m.Load(s))
+                               expectLoadedFromSwap(t, s, i+j, i+j+1)(m.Swap(s, i+j+1))
+                               expectPresent(t, s, i+j+1)(m.Load(s))
+                       }
+               }
+               for i, s := range testData {
+                       expectLoadedFromSwap(t, s, i+3, i+3)(m.Swap(s, i+3))
+               }
+       })
+       t.Run("SwapOne", func(t *testing.T) {
+               m := newMap()
+
+               for i, s := range testData {
+                       expectMissing(t, s, 0)(m.Load(s))
+                       expectNotLoadedFromSwap(t, s, i)(m.Swap(s, i))
+                       expectPresent(t, s, i)(m.Load(s))
+                       expectLoadedFromSwap(t, s, i, i)(m.Swap(s, i))
+               }
+               expectLoadedFromSwap(t, testData[15], 15, 16)(m.Swap(testData[15], 16))
+               for i, s := range testData {
+                       if i == 15 {
+                               expectPresent(t, s, 16)(m.Load(s))
+                       } else {
+                               expectPresent(t, s, i)(m.Load(s))
+                       }
+               }
+       })
+       t.Run("SwapMultiple", func(t *testing.T) {
+               m := newMap()
+
+               for i, s := range testData {
+                       expectMissing(t, s, 0)(m.Load(s))
+                       expectNotLoadedFromSwap(t, s, i)(m.Swap(s, i))
+                       expectPresent(t, s, i)(m.Load(s))
+                       expectLoadedFromSwap(t, s, i, i)(m.Swap(s, i))
+               }
+               for _, i := range []int{1, 105, 6, 85} {
+                       expectLoadedFromSwap(t, testData[i], i, i+1)(m.Swap(testData[i], i+1))
+               }
+               for i, s := range testData {
+                       if i == 1 || i == 105 || i == 6 || i == 85 {
+                               expectPresent(t, s, i+1)(m.Load(s))
+                       } else {
+                               expectPresent(t, s, i)(m.Load(s))
+                       }
+               }
+       })
+       t.Run("ConcurrentLifecycleSwapUnsharedKeys", func(t *testing.T) {
+               m := newMap()
+
+               gmp := runtime.GOMAXPROCS(-1)
+               var wg sync.WaitGroup
+               for i := range gmp {
+                       wg.Add(1)
+                       go func(id int) {
+                               defer wg.Done()
+
+                               makeKey := func(s string) string {
+                                       return s + "-" + strconv.Itoa(id)
+                               }
+                               for _, s := range testData {
+                                       key := makeKey(s)
+                                       expectMissing(t, key, 0)(m.Load(key))
+                                       expectNotLoadedFromSwap(t, key, id)(m.Swap(key, id))
+                                       expectPresent(t, key, id)(m.Load(key))
+                                       expectLoadedFromSwap(t, key, id, id)(m.Swap(key, id))
+                               }
+                               for _, s := range testData {
+                                       key := makeKey(s)
+                                       expectPresent(t, key, id)(m.Load(key))
+                                       expectLoadedFromSwap(t, key, id, id+1)(m.Swap(key, id+1))
+                                       expectPresent(t, key, id+1)(m.Load(key))
+                               }
+                               for _, s := range testData {
+                                       key := makeKey(s)
+                                       expectPresent(t, key, id+1)(m.Load(key))
+                               }
+                       }(i)
+               }
+               wg.Wait()
+       })
+       t.Run("ConcurrentLifecycleSwapAndDeleteUnsharedKeys", func(t *testing.T) {
+               m := newMap()
+
+               gmp := runtime.GOMAXPROCS(-1)
+               var wg sync.WaitGroup
+               for i := range gmp {
+                       wg.Add(1)
+                       go func(id int) {
+                               defer wg.Done()
+
+                               makeKey := func(s string) string {
+                                       return s + "-" + strconv.Itoa(id)
+                               }
+                               for _, s := range testData {
+                                       key := makeKey(s)
+                                       expectMissing(t, key, 0)(m.Load(key))
+                                       expectNotLoadedFromSwap(t, key, id)(m.Swap(key, id))
+                                       expectPresent(t, key, id)(m.Load(key))
+                                       expectLoadedFromSwap(t, key, id, id)(m.Swap(key, id))
+                               }
+                               for _, s := range testData {
+                                       key := makeKey(s)
+                                       expectPresent(t, key, id)(m.Load(key))
+                                       expectLoadedFromSwap(t, key, id, id+1)(m.Swap(key, id+1))
+                                       expectPresent(t, key, id+1)(m.Load(key))
+                                       expectDeleted(t, key, id+1)(m.CompareAndDelete(key, id+1))
+                                       expectNotLoadedFromSwap(t, key, id+2)(m.Swap(key, id+2))
+                                       expectPresent(t, key, id+2)(m.Load(key))
+                               }
+                               for _, s := range testData {
+                                       key := makeKey(s)
+                                       expectPresent(t, key, id+2)(m.Load(key))
+                               }
+                       }(i)
+               }
+               wg.Wait()
+       })
+       t.Run("ConcurrentSwapSharedKeys", func(t *testing.T) {
+               m := newMap()
+
+               // Load up the map.
+               for i, s := range testData {
+                       expectMissing(t, s, 0)(m.Load(s))
+                       expectStored(t, s, i)(m.LoadOrStore(s, i))
+               }
+               gmp := runtime.GOMAXPROCS(-1)
+               var wg sync.WaitGroup
+               for i := range gmp {
+                       wg.Add(1)
+                       go func(id int) {
+                               defer wg.Done()
+
+                               for i, s := range testData {
+                                       m.Swap(s, i+1)
+                                       expectPresent(t, s, i+1)(m.Load(s))
+                               }
+                               for i, s := range testData {
+                                       expectPresent(t, s, i+1)(m.Load(s))
+                               }
+                       }(i)
+               }
+               wg.Wait()
+       })
 }
 
 func testAll[K, V comparable](t *testing.T, m *isync.HashTrieMap[K, V], testData map[K]V, yield func(K, V) bool) {
@@ -497,6 +652,30 @@ func expectNotSwapped[K, V comparable](t *testing.T, key K, old, new V) func(swa
        }
 }
 
+func expectLoadedFromSwap[K, V comparable](t *testing.T, key K, want, new V) func(got V, loaded bool) {
+       t.Helper()
+       return func(got V, loaded bool) {
+               t.Helper()
+
+               if !loaded {
+                       t.Errorf("expected key %v to be in map and for %v to have been swapped for %v", key, want, new)
+               } else if want != got {
+                       t.Errorf("key %v had its value %v swapped for %v, but expected it to have value %v", key, got, new, want)
+               }
+       }
+}
+
+func expectNotLoadedFromSwap[K, V comparable](t *testing.T, key K, new V) func(old V, loaded bool) {
+       t.Helper()
+       return func(old V, loaded bool) {
+               t.Helper()
+
+               if loaded {
+                       t.Errorf("expected key %v to not be in map, but found value %v for it", key, old)
+               }
+       }
+}
+
 func testDataMap(data []string) map[string]int {
        m := make(map[string]int)
        for i, s := range data {