]> Cypherpunks repositories - gostls13.git/commitdiff
internal/sync: add CompareAndSwap to HashTrieMap
authorMichael Anthony Knyszek <mknyszek@google.com>
Fri, 21 Jun 2024 20:42:45 +0000 (20:42 +0000)
committerGopher Robot <gobot@golang.org>
Mon, 18 Nov 2024 20:35:11 +0000 (20:35 +0000)
This change adds the CompareAndSwap operation (with the same semantics
as sync.Map's CompareAndSwap) to HashTrieMap.

Change-Id: I86153799fc47304784333d17f0c6a7ad7682f04a
Reviewed-on: https://go-review.googlesource.com/c/go/+/594063
Reviewed-by: David Chase <drchase@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Michael Knyszek <mknyszek@google.com>

src/internal/sync/hashtriemap.go
src/internal/sync/hashtriemap_test.go

index 082aecacba90b10feb005a0442d229e4a11f96ba..29c88b055ed254d1f431f18e051016a220281a7f 100644 (file)
@@ -195,6 +195,36 @@ func (ht *HashTrieMap[K, V]) expand(oldEntry, newEntry *entry[K, V], newHash uin
        return &top.node
 }
 
+// CompareAndSwap swaps the old and new values for key
+// if the value stored in the map is equal to old.
+// The value type must be of a comparable type, otherwise CompareAndSwap will panic.
+func (ht *HashTrieMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) {
+       ht.init()
+       if ht.valEqual == nil {
+               panic("called CompareAndSwap when value is not of comparable type")
+       }
+       hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
+
+       // Find a node with the key and compare with it. n != nil if we found the node.
+       i, _, slot, n := ht.find(key, hash, ht.valEqual, old)
+       if i != nil {
+               defer i.mu.Unlock()
+       }
+       if n == nil {
+               return false
+       }
+
+       // Try to swap the entry.
+       e, swapped := n.entry().compareAndSwap(key, old, new, ht.valEqual)
+       if !swapped {
+               // Nothing was actually swapped, which means the node is no longer there.
+               return false
+       }
+       // Store the entry back because it changed.
+       slot.Store(&e.node)
+       return true
+}
+
 // CompareAndDelete deletes the entry for key if its value is equal to old.
 // The value type must be comparable, otherwise this CompareAndDelete will panic.
 //
@@ -207,8 +237,8 @@ func (ht *HashTrieMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
        }
        hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
 
-       // Find a node with the key and compare with it. n != nil if we found the node.
-       i, hashShift, slot, n := ht.find(key, hash)
+       // Find a node with the key. n != nil if we found the node.
+       i, hashShift, slot, n := ht.find(key, hash, nil, *new(V))
        if n == nil {
                if i != nil {
                        i.mu.Unlock()
@@ -252,14 +282,15 @@ func (ht *HashTrieMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
        return true
 }
 
-// compare searches the tree for a node that compares with key (hash must be the hash of key).
+// find searches the tree for a node that contains key (hash must be the hash of key).
+// If valEqual != nil, then it will also enforce that the values are equal as well.
 //
 // Returns a non-nil node, which will always be an entry, if found.
 //
 // If i != nil then i.mu is locked, and it is the caller's responsibility to unlock it.
-func (ht *HashTrieMap[K, V]) find(key K, hash uintptr) (i *indirect[K, V], hashShift uint, slot *atomic.Pointer[node[K, V]], n *node[K, V]) {
+func (ht *HashTrieMap[K, V]) find(key K, hash uintptr, valEqual equalFunc, value V) (i *indirect[K, V], hashShift uint, slot *atomic.Pointer[node[K, V]], n *node[K, V]) {
        for {
-               // Find the key or return when there's nothing to delete.
+               // Find the key or return if it's not there.
                i = ht.root
                hashShift = 8 * goarch.PtrSize
                found := false
@@ -275,7 +306,7 @@ func (ht *HashTrieMap[K, V]) find(key K, hash uintptr) (i *indirect[K, V], hashS
                        }
                        if n.isEntry {
                                // We found an entry. Check if it matches.
-                               if _, ok := n.entry().lookup(key); !ok {
+                               if _, ok := n.entry().lookupWithValue(key, value, valEqual); !ok {
                                        // No match, comparison failed.
                                        i = nil
                                        n = nil
@@ -398,6 +429,44 @@ func (e *entry[K, V]) lookup(key K) (V, bool) {
        return *new(V), false
 }
 
+func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bool) {
+       for e != nil {
+               if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value)))) {
+                       return e.value, true
+               }
+               e = e.overflow.Load()
+       }
+       return *new(V), false
+}
+
+// compareAndSwap replaces an entry in the overflow chain if both the key and value compare
+// equal. Returns the new entry chain and whether or not anything was swapped.
+//
+// compareAndSwap must be called under the mutex of the indirect node which e is a child of.
+func (head *entry[K, V]) compareAndSwap(key K, old, new V, valEqual equalFunc) (*entry[K, V], bool) {
+       if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&old))) {
+               // Return the new head of the list.
+               e := newEntryNode(key, new)
+               if chain := head.overflow.Load(); chain != nil {
+                       e.overflow.Store(chain)
+               }
+               return e, true
+       }
+       i := &head.overflow
+       e := i.Load()
+       for e != nil {
+               if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&old))) {
+                       eNew := newEntryNode(key, new)
+                       eNew.overflow.Store(e.overflow.Load())
+                       i.Store(eNew)
+                       return head, true
+               }
+               i = &e.overflow
+               e = e.overflow.Load()
+       }
+       return head, false
+}
+
 // compareAndDelete deletes an entry in the overflow chain if both the key and value compare
 // equal. Returns the new entry chain and whether or not anything was deleted.
 //
index 9ab11d41267f953840a3dc79ccda7200aa991760..b34350c15b14b8e6d1356f91ab442b892e6d293f 100644 (file)
@@ -101,7 +101,7 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int]
                        }
                }
        })
-       t.Run("DeleteMultiple", func(t *testing.T) {
+       t.Run("CompareAndDeleteMultiple", func(t *testing.T) {
                m := newMap()
 
                for i, s := range testData {
@@ -130,7 +130,7 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int]
                        return true
                })
        })
-       t.Run("AllDelete", func(t *testing.T) {
+       t.Run("AllCompareAndDelete", func(t *testing.T) {
                m := newMap()
 
                testAll(t, m, testDataMap(testData[:]), func(s string, i int) bool {
@@ -175,7 +175,7 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int]
                }
                wg.Wait()
        })
-       t.Run("ConcurrentDeleteSharedKeys", func(t *testing.T) {
+       t.Run("ConcurrentCompareAndDeleteSharedKeys", func(t *testing.T) {
                m := newMap()
 
                // Load up the map.
@@ -202,6 +202,169 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int]
                }
                wg.Wait()
        })
+       t.Run("CompareAndSwapAll", func(t *testing.T) {
+               m := newMap()
+
+               for i, s := range testData {
+                       expectMissing(t, s, 0)(m.Load(s))
+                       expectStored(t, s, i)(m.LoadOrStore(s, i))
+                       expectPresent(t, s, i)(m.Load(s))
+                       expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
+               }
+               for j := range 3 {
+                       for i, s := range testData {
+                               expectPresent(t, s, i+j)(m.Load(s))
+                               expectNotSwapped(t, s, math.MaxInt, i+j+1)(m.CompareAndSwap(s, math.MaxInt, i+j+1))
+                               expectSwapped(t, s, i, i+j+1)(m.CompareAndSwap(s, i+j, i+j+1))
+                               expectNotSwapped(t, s, i+j, i+j+1)(m.CompareAndSwap(s, i+j, i+j+1))
+                               expectPresent(t, s, i+j+1)(m.Load(s))
+                       }
+               }
+               for i, s := range testData {
+                       expectPresent(t, s, i+3)(m.Load(s))
+               }
+       })
+       t.Run("CompareAndSwapOne", func(t *testing.T) {
+               m := newMap()
+
+               for i, s := range testData {
+                       expectMissing(t, s, 0)(m.Load(s))
+                       expectStored(t, s, i)(m.LoadOrStore(s, i))
+                       expectPresent(t, s, i)(m.Load(s))
+                       expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
+               }
+               expectNotSwapped(t, testData[15], math.MaxInt, 16)(m.CompareAndSwap(testData[15], math.MaxInt, 16))
+               expectSwapped(t, testData[15], 15, 16)(m.CompareAndSwap(testData[15], 15, 16))
+               expectNotSwapped(t, testData[15], 15, 16)(m.CompareAndSwap(testData[15], 15, 16))
+               for i, s := range testData {
+                       if i == 15 {
+                               expectPresent(t, s, 16)(m.Load(s))
+                       } else {
+                               expectPresent(t, s, i)(m.Load(s))
+                       }
+               }
+       })
+       t.Run("CompareAndSwapMultiple", func(t *testing.T) {
+               m := newMap()
+
+               for i, s := range testData {
+                       expectMissing(t, s, 0)(m.Load(s))
+                       expectStored(t, s, i)(m.LoadOrStore(s, i))
+                       expectPresent(t, s, i)(m.Load(s))
+                       expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
+               }
+               for _, i := range []int{1, 105, 6, 85} {
+                       expectNotSwapped(t, testData[i], math.MaxInt, i+1)(m.CompareAndSwap(testData[i], math.MaxInt, i+1))
+                       expectSwapped(t, testData[i], i, i+1)(m.CompareAndSwap(testData[i], i, i+1))
+                       expectNotSwapped(t, testData[i], i, i+1)(m.CompareAndSwap(testData[i], i, i+1))
+               }
+               for i, s := range testData {
+                       if i == 1 || i == 105 || i == 6 || i == 85 {
+                               expectPresent(t, s, i+1)(m.Load(s))
+                       } else {
+                               expectPresent(t, s, i)(m.Load(s))
+                       }
+               }
+       })
+       t.Run("ConcurrentLifecycleSwapUnsharedKeys", func(t *testing.T) {
+               m := newMap()
+
+               gmp := runtime.GOMAXPROCS(-1)
+               var wg sync.WaitGroup
+               for i := range gmp {
+                       wg.Add(1)
+                       go func(id int) {
+                               defer wg.Done()
+
+                               makeKey := func(s string) string {
+                                       return s + "-" + strconv.Itoa(id)
+                               }
+                               for _, s := range testData {
+                                       key := makeKey(s)
+                                       expectMissing(t, key, 0)(m.Load(key))
+                                       expectStored(t, key, id)(m.LoadOrStore(key, id))
+                                       expectPresent(t, key, id)(m.Load(key))
+                                       expectLoaded(t, key, id)(m.LoadOrStore(key, 0))
+                               }
+                               for _, s := range testData {
+                                       key := makeKey(s)
+                                       expectPresent(t, key, id)(m.Load(key))
+                                       expectSwapped(t, key, id, id+1)(m.CompareAndSwap(key, id, id+1))
+                                       expectPresent(t, key, id+1)(m.Load(key))
+                               }
+                               for _, s := range testData {
+                                       key := makeKey(s)
+                                       expectPresent(t, key, id+1)(m.Load(key))
+                               }
+                       }(i)
+               }
+               wg.Wait()
+       })
+       t.Run("ConcurrentLifecycleSwapAndDeleteUnsharedKeys", func(t *testing.T) {
+               m := newMap()
+
+               gmp := runtime.GOMAXPROCS(-1)
+               var wg sync.WaitGroup
+               for i := range gmp {
+                       wg.Add(1)
+                       go func(id int) {
+                               defer wg.Done()
+
+                               makeKey := func(s string) string {
+                                       return s + "-" + strconv.Itoa(id)
+                               }
+                               for _, s := range testData {
+                                       key := makeKey(s)
+                                       expectMissing(t, key, 0)(m.Load(key))
+                                       expectStored(t, key, id)(m.LoadOrStore(key, id))
+                                       expectPresent(t, key, id)(m.Load(key))
+                                       expectLoaded(t, key, id)(m.LoadOrStore(key, 0))
+                               }
+                               for _, s := range testData {
+                                       key := makeKey(s)
+                                       expectPresent(t, key, id)(m.Load(key))
+                                       expectSwapped(t, key, id, id+1)(m.CompareAndSwap(key, id, id+1))
+                                       expectPresent(t, key, id+1)(m.Load(key))
+                                       expectDeleted(t, key, id+1)(m.CompareAndDelete(key, id+1))
+                                       expectNotSwapped(t, key, id+1, id+2)(m.CompareAndSwap(key, id+1, id+2))
+                                       expectNotDeleted(t, key, id+1)(m.CompareAndDelete(key, id+1))
+                                       expectMissing(t, key, 0)(m.Load(key))
+                               }
+                               for _, s := range testData {
+                                       key := makeKey(s)
+                                       expectMissing(t, key, 0)(m.Load(key))
+                               }
+                       }(i)
+               }
+               wg.Wait()
+       })
+       t.Run("ConcurrentCompareAndSwapSharedKeys", func(t *testing.T) {
+               m := newMap()
+
+               // Load up the map.
+               for i, s := range testData {
+                       expectMissing(t, s, 0)(m.Load(s))
+                       expectStored(t, s, i)(m.LoadOrStore(s, i))
+               }
+               gmp := runtime.GOMAXPROCS(-1)
+               var wg sync.WaitGroup
+               for i := range gmp {
+                       wg.Add(1)
+                       go func(id int) {
+                               defer wg.Done()
+
+                               for i, s := range testData {
+                                       expectNotSwapped(t, s, math.MaxInt, i+1)(m.CompareAndSwap(s, math.MaxInt, i+1))
+                                       m.CompareAndSwap(s, i, i+1)
+                                       expectPresent(t, s, i+1)(m.Load(s))
+                               }
+                               for i, s := range testData {
+                                       expectPresent(t, s, i+1)(m.Load(s))
+                               }
+                       }(i)
+               }
+               wg.Wait()
+       })
 }
 
 func testAll[K, V comparable](t *testing.T, m *isync.HashTrieMap[K, V], testData map[K]V, yield func(K, V) bool) {
@@ -312,6 +475,28 @@ func expectNotDeleted[K, V comparable](t *testing.T, key K, old V) func(deleted
        }
 }
 
+func expectSwapped[K, V comparable](t *testing.T, key K, old, new V) func(swapped bool) {
+       t.Helper()
+       return func(swapped bool) {
+               t.Helper()
+
+               if !swapped {
+                       t.Errorf("expected key %v with value %v to be in map and swapped for %v", key, old, new)
+               }
+       }
+}
+
+func expectNotSwapped[K, V comparable](t *testing.T, key K, old, new V) func(swapped bool) {
+       t.Helper()
+       return func(swapped bool) {
+               t.Helper()
+
+               if swapped {
+                       t.Errorf("expected key %v with value %v to not be in map or not swapped for %v", key, old, new)
+               }
+       }
+}
+
 func testDataMap(data []string) map[string]int {
        m := make(map[string]int)
        for i, s := range data {