]> Cypherpunks repositories - gostls13.git/commitdiff
internal/sync: optimize CompareAndSwap and Swap
authorMichael Anthony Knyszek <mknyszek@google.com>
Fri, 16 Aug 2024 15:13:52 +0000 (15:13 +0000)
committerGopher Robot <gobot@golang.org>
Mon, 18 Nov 2024 20:35:39 +0000 (20:35 +0000)
We observe the CompareAndSwap and Swap can both be substantially faster
if the value in each entry node is mutable. This change modifies the
map entry node to store the value indirectly, allowing us to perform
swaps for existing nodes and compare-and-swaps without taking the
parent node's lock.

Change-Id: I371343aa81a843d3a7e6bc5ac87b8a96c12ca3a8
Reviewed-on: https://go-review.googlesource.com/c/go/+/606462
Auto-Submit: Michael Knyszek <mknyszek@google.com>
Reviewed-by: David Chase <drchase@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/internal/sync/hashtriemap.go

index defcd0b793947f1025694c584defe10f66898604..d31d81df39ca193c40a492c2950d3cc9b1f08722 100644 (file)
@@ -219,12 +219,22 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
 
                        slot = &i.children[(hash>>hashShift)&nChildrenMask]
                        n = slot.Load()
-                       if n == nil || n.isEntry {
+                       if n == nil {
                                // We found a nil slot which is a candidate for insertion,
                                // or an existing entry that we'll replace.
                                haveInsertPoint = true
                                break
                        }
+                       if n.isEntry {
+                               // Swap if the keys compare.
+                               old, swapped := n.entry().swap(key, new)
+                               if swapped {
+                                       return old, true
+                               }
+                               // If we fail, that means we should try to insert.
+                               haveInsertPoint = true
+                               break
+                       }
                        i = n.indirect()
                }
                if !haveInsertPoint {
@@ -251,11 +261,10 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
        var zero V
        var oldEntry *entry[K, V]
        if n != nil {
-               // Swap if the keys compare.
+               // Between before and now, something got inserted. Swap if the keys compare.
                oldEntry = n.entry()
-               newEntry, old, swapped := oldEntry.swap(key, new)
+               old, swapped := oldEntry.swap(key, new)
                if swapped {
-                       slot.Store(&newEntry.node)
                        return old, true
                }
        }
@@ -283,25 +292,30 @@ func (ht *HashTrieMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) {
                panic("called CompareAndSwap when value is not of comparable type")
        }
        hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
+       for {
+               // Find the key or return if it's not there.
+               i := ht.root.Load()
+               hashShift := 8 * goarch.PtrSize
+               found := false
+               for hashShift != 0 {
+                       hashShift -= nChildrenLog2
 
-       // Find a node with the key and compare with it. n != nil if we found the node.
-       i, _, slot, n := ht.find(key, hash, ht.valEqual, old)
-       if i != nil {
-               defer i.mu.Unlock()
-       }
-       if n == nil {
-               return false
-       }
-
-       // Try to swap the entry.
-       e, swapped := n.entry().compareAndSwap(key, old, new, ht.valEqual)
-       if !swapped {
-               // Nothing was actually swapped, which means the node is no longer there.
-               return false
+                       slot := &i.children[(hash>>hashShift)&nChildrenMask]
+                       n := slot.Load()
+                       if n == nil {
+                               // Nothing to compare with. Give up.
+                               return false
+                       }
+                       if n.isEntry {
+                               // We found an entry. Try to compare and swap directly.
+                               return n.entry().compareAndSwap(key, old, new, ht.valEqual)
+                       }
+                       i = n.indirect()
+               }
+               if !found {
+                       panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
+               }
        }
-       // Store the entry back because it changed.
-       slot.Store(&e.node)
-       return true
 }
 
 // LoadAndDelete deletes the value for a key, returning the previous value if any.
@@ -509,7 +523,7 @@ func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V)
                }
                e := n.entry()
                for e != nil {
-                       if !yield(e.key, e.value) {
+                       if !yield(e.key, *e.value.Load()) {
                                return false
                        }
                        e = e.overflow.Load()
@@ -565,21 +579,22 @@ type entry[K comparable, V any] struct {
        node[K, V]
        overflow atomic.Pointer[entry[K, V]] // Overflow for hash collisions.
        key      K
-       value    V
+       value    atomic.Pointer[V]
 }
 
 func newEntryNode[K comparable, V any](key K, value V) *entry[K, V] {
-       return &entry[K, V]{
-               node:  node[K, V]{isEntry: true},
-               key:   key,
-               value: value,
+       e := &entry[K, V]{
+               node: node[K, V]{isEntry: true},
+               key:  key,
        }
+       e.value.Store(&value)
+       return e
 }
 
 func (e *entry[K, V]) lookup(key K) (V, bool) {
        for e != nil {
                if e.key == key {
-                       return e.value, true
+                       return *e.value.Load(), true
                }
                e = e.overflow.Load()
        }
@@ -588,69 +603,87 @@ func (e *entry[K, V]) lookup(key K) (V, bool) {
 
 func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bool) {
        for e != nil {
-               if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value)))) {
-                       return e.value, true
+               oldp := e.value.Load()
+               if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(oldp), abi.NoEscape(unsafe.Pointer(&value)))) {
+                       return *oldp, true
                }
                e = e.overflow.Load()
        }
        return *new(V), false
 }
 
-// swap replaces an entry in the overflow chain if keys compare equal. Returns the new entry chain,
-// the old value, and whether or not anything was swapped.
+// swap replaces a value in the overflow chain if keys compare equal.
+// Returns the old value, and whether or not anything was swapped.
 //
 // swap must be called under the mutex of the indirect node which e is a child of.
-func (head *entry[K, V]) swap(key K, new V) (*entry[K, V], V, bool) {
+func (head *entry[K, V]) swap(key K, newv V) (V, bool) {
        if head.key == key {
-               // Return the new head of the list.
-               e := newEntryNode(key, new)
-               if chain := head.overflow.Load(); chain != nil {
-                       e.overflow.Store(chain)
-               }
-               return e, head.value, true
+               vp := new(V)
+               *vp = newv
+               oldp := head.value.Swap(vp)
+               return *oldp, true
        }
        i := &head.overflow
        e := i.Load()
        for e != nil {
                if e.key == key {
-                       eNew := newEntryNode(key, new)
-                       eNew.overflow.Store(e.overflow.Load())
-                       i.Store(eNew)
-                       return head, e.value, true
+                       vp := new(V)
+                       *vp = newv
+                       oldp := e.value.Swap(vp)
+                       return *oldp, true
                }
                i = &e.overflow
                e = e.overflow.Load()
        }
        var zero V
-       return head, zero, false
+       return zero, false
 }
 
-// compareAndSwap replaces an entry in the overflow chain if both the key and value compare
-// equal. Returns the new entry chain and whether or not anything was swapped.
+// compareAndSwap replaces a value for a matching key and existing value in the overflow chain.
+// Returns whether or not anything was swapped.
 //
 // compareAndSwap must be called under the mutex of the indirect node which e is a child of.
-func (head *entry[K, V]) compareAndSwap(key K, old, new V, valEqual equalFunc) (*entry[K, V], bool) {
-       if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&old))) {
-               // Return the new head of the list.
-               e := newEntryNode(key, new)
-               if chain := head.overflow.Load(); chain != nil {
-                       e.overflow.Store(chain)
+func (head *entry[K, V]) compareAndSwap(key K, oldv, newv V, valEqual equalFunc) bool {
+       var vbox *V
+outerLoop:
+       for {
+               oldvp := head.value.Load()
+               if head.key == key && valEqual(unsafe.Pointer(oldvp), abi.NoEscape(unsafe.Pointer(&oldv))) {
+                       // Return the new head of the list.
+                       if vbox == nil {
+                               // Delay explicit creation of a new value to hold newv. If we just pass &newv
+                               // to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails.
+                               vbox = new(V)
+                               *vbox = newv
+                       }
+                       if head.value.CompareAndSwap(oldvp, vbox) {
+                               return true
+                       }
+                       // We need to restart from the head of the overflow list in case, due to a removal, a node
+                       // is moved up the list and we miss it.
+                       continue outerLoop
                }
-               return e, true
-       }
-       i := &head.overflow
-       e := i.Load()
-       for e != nil {
-               if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&old))) {
-                       eNew := newEntryNode(key, new)
-                       eNew.overflow.Store(e.overflow.Load())
-                       i.Store(eNew)
-                       return head, true
+               i := &head.overflow
+               e := i.Load()
+               for e != nil {
+                       oldvp := e.value.Load()
+                       if e.key == key && valEqual(unsafe.Pointer(oldvp), abi.NoEscape(unsafe.Pointer(&oldv))) {
+                               if vbox == nil {
+                                       // Delay explicit creation of a new value to hold newv. If we just pass &newv
+                                       // to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails.
+                                       vbox = new(V)
+                                       *vbox = newv
+                               }
+                               if e.value.CompareAndSwap(oldvp, vbox) {
+                                       return true
+                               }
+                               continue outerLoop
+                       }
+                       i = &e.overflow
+                       e = e.overflow.Load()
                }
-               i = &e.overflow
-               e = e.overflow.Load()
+               return false
        }
-       return head, false
 }
 
 // loadAndDelete deletes an entry in the overflow chain by key. Returns the value for the key, the new
@@ -660,14 +693,14 @@ func (head *entry[K, V]) compareAndSwap(key K, old, new V, valEqual equalFunc) (
 func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) {
        if head.key == key {
                // Drop the head of the list.
-               return head.value, head.overflow.Load(), true
+               return *head.value.Load(), head.overflow.Load(), true
        }
        i := &head.overflow
        e := i.Load()
        for e != nil {
                if e.key == key {
                        i.Store(e.overflow.Load())
-                       return e.value, head, true
+                       return *e.value.Load(), head, true
                }
                i = &e.overflow
                e = e.overflow.Load()
@@ -680,14 +713,14 @@ func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) {
 //
 // compareAndDelete must be called under the mutex of the indirect node which e is a child of.
 func (head *entry[K, V]) compareAndDelete(key K, value V, valEqual equalFunc) (*entry[K, V], bool) {
-       if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&value))) {
+       if head.key == key && valEqual(unsafe.Pointer(head.value.Load()), abi.NoEscape(unsafe.Pointer(&value))) {
                // Drop the head of the list.
                return head.overflow.Load(), true
        }
        i := &head.overflow
        e := i.Load()
        for e != nil {
-               if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value))) {
+               if e.key == key && valEqual(unsafe.Pointer(e.value.Load()), abi.NoEscape(unsafe.Pointer(&value))) {
                        i.Store(e.overflow.Load())
                        return head, true
                }