From 06e6b8efa4c9fbe4b8ceec8c655011117a50279a Mon Sep 17 00:00:00 2001 From: Michael Anthony Knyszek Date: Fri, 16 Aug 2024 15:13:52 +0000 Subject: [PATCH] internal/sync: optimize CompareAndSwap and Swap We observe that CompareAndSwap and Swap can both be substantially faster if the value in each entry node is mutable. This change modifies the map entry node to store the value indirectly, allowing us to perform swaps for existing nodes and compare-and-swaps without taking the parent node's lock. Change-Id: I371343aa81a843d3a7e6bc5ac87b8a96c12ca3a8 Reviewed-on: https://go-review.googlesource.com/c/go/+/606462 Auto-Submit: Michael Knyszek Reviewed-by: David Chase LUCI-TryBot-Result: Go LUCI --- src/internal/sync/hashtriemap.go | 171 ++++++++++++++++++------------- 1 file changed, 102 insertions(+), 69 deletions(-) diff --git a/src/internal/sync/hashtriemap.go b/src/internal/sync/hashtriemap.go index defcd0b793..d31d81df39 100644 --- a/src/internal/sync/hashtriemap.go +++ b/src/internal/sync/hashtriemap.go @@ -219,12 +219,22 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) { slot = &i.children[(hash>>hashShift)&nChildrenMask] n = slot.Load() - if n == nil || n.isEntry { + if n == nil { // We found a nil slot which is a candidate for insertion, // or an existing entry that we'll replace. haveInsertPoint = true break } + if n.isEntry { + // Swap if the keys compare. + old, swapped := n.entry().swap(key, new) + if swapped { + return old, true + } + // If we fail, that means we should try to insert. + haveInsertPoint = true + break + } i = n.indirect() } if !haveInsertPoint { @@ -251,11 +261,10 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) { var zero V var oldEntry *entry[K, V] if n != nil { - // Swap if the keys compare. + // Between before and now, something got inserted. Swap if the keys compare.
oldEntry = n.entry() - newEntry, old, swapped := oldEntry.swap(key, new) + old, swapped := oldEntry.swap(key, new) if swapped { - slot.Store(&newEntry.node) return old, true } } @@ -283,25 +292,30 @@ func (ht *HashTrieMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) { panic("called CompareAndSwap when value is not of comparable type") } hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed) + for { + // Find the key or return if it's not there. + i := ht.root.Load() + hashShift := 8 * goarch.PtrSize + found := false + for hashShift != 0 { + hashShift -= nChildrenLog2 - // Find a node with the key and compare with it. n != nil if we found the node. - i, _, slot, n := ht.find(key, hash, ht.valEqual, old) - if i != nil { - defer i.mu.Unlock() - } - if n == nil { - return false - } - - // Try to swap the entry. - e, swapped := n.entry().compareAndSwap(key, old, new, ht.valEqual) - if !swapped { - // Nothing was actually swapped, which means the node is no longer there. - return false + slot := &i.children[(hash>>hashShift)&nChildrenMask] + n := slot.Load() + if n == nil { + // Nothing to compare with. Give up. + return false + } + if n.isEntry { + // We found an entry. Try to compare and swap directly. + return n.entry().compareAndSwap(key, old, new, ht.valEqual) + } + i = n.indirect() + } + if !found { + panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating") + } } - // Store the entry back because it changed. - slot.Store(&e.node) - return true } // LoadAndDelete deletes the value for a key, returning the previous value if any. @@ -509,7 +523,7 @@ func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V) } e := n.entry() for e != nil { - if !yield(e.key, e.value) { + if !yield(e.key, *e.value.Load()) { return false } e = e.overflow.Load() @@ -565,21 +579,22 @@ type entry[K comparable, V any] struct { node[K, V] overflow atomic.Pointer[entry[K, V]] // Overflow for hash collisions. 
key K - value V + value atomic.Pointer[V] } func newEntryNode[K comparable, V any](key K, value V) *entry[K, V] { - return &entry[K, V]{ - node: node[K, V]{isEntry: true}, - key: key, - value: value, + e := &entry[K, V]{ + node: node[K, V]{isEntry: true}, + key: key, } + e.value.Store(&value) + return e } func (e *entry[K, V]) lookup(key K) (V, bool) { for e != nil { if e.key == key { - return e.value, true + return *e.value.Load(), true } e = e.overflow.Load() } @@ -588,69 +603,87 @@ func (e *entry[K, V]) lookup(key K) (V, bool) { func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bool) { for e != nil { - if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value)))) { - return e.value, true + oldp := e.value.Load() + if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(oldp), abi.NoEscape(unsafe.Pointer(&value)))) { + return *oldp, true } e = e.overflow.Load() } return *new(V), false } -// swap replaces an entry in the overflow chain if keys compare equal. Returns the new entry chain, -// the old value, and whether or not anything was swapped. +// swap replaces a value in the overflow chain if keys compare equal. +// Returns the old value, and whether or not anything was swapped. // // swap must be called under the mutex of the indirect node which e is a child of. -func (head *entry[K, V]) swap(key K, new V) (*entry[K, V], V, bool) { +func (head *entry[K, V]) swap(key K, newv V) (V, bool) { if head.key == key { - // Return the new head of the list. 
- e := newEntryNode(key, new) - if chain := head.overflow.Load(); chain != nil { - e.overflow.Store(chain) - } - return e, head.value, true + vp := new(V) + *vp = newv + oldp := head.value.Swap(vp) + return *oldp, true } i := &head.overflow e := i.Load() for e != nil { if e.key == key { - eNew := newEntryNode(key, new) - eNew.overflow.Store(e.overflow.Load()) - i.Store(eNew) - return head, e.value, true + vp := new(V) + *vp = newv + oldp := e.value.Swap(vp) + return *oldp, true } i = &e.overflow e = e.overflow.Load() } var zero V - return head, zero, false + return zero, false } -// compareAndSwap replaces an entry in the overflow chain if both the key and value compare -// equal. Returns the new entry chain and whether or not anything was swapped. +// compareAndSwap replaces a value for a matching key and existing value in the overflow chain. +// Returns whether or not anything was swapped. // // compareAndSwap must be called under the mutex of the indirect node which e is a child of. -func (head *entry[K, V]) compareAndSwap(key K, old, new V, valEqual equalFunc) (*entry[K, V], bool) { - if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&old))) { - // Return the new head of the list. - e := newEntryNode(key, new) - if chain := head.overflow.Load(); chain != nil { - e.overflow.Store(chain) +func (head *entry[K, V]) compareAndSwap(key K, oldv, newv V, valEqual equalFunc) bool { + var vbox *V +outerLoop: + for { + oldvp := head.value.Load() + if head.key == key && valEqual(unsafe.Pointer(oldvp), abi.NoEscape(unsafe.Pointer(&oldv))) { + // The head entry matches; try to swap its value in place. + if vbox == nil { + // Delay explicit creation of a new value to hold newv. If we just pass &newv + // to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails.
+ vbox = new(V) + *vbox = newv + } + if head.value.CompareAndSwap(oldvp, vbox) { + return true + } + // We need to restart from the head of the overflow list in case, due to a removal, a node + // is moved up the list and we miss it. + continue outerLoop } - return e, true - } - i := &head.overflow - e := i.Load() - for e != nil { - if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&old))) { - eNew := newEntryNode(key, new) - eNew.overflow.Store(e.overflow.Load()) - i.Store(eNew) - return head, true + i := &head.overflow + e := i.Load() + for e != nil { + oldvp := e.value.Load() + if e.key == key && valEqual(unsafe.Pointer(oldvp), abi.NoEscape(unsafe.Pointer(&oldv))) { + if vbox == nil { + // Delay explicit creation of a new value to hold newv. If we just pass &newv + // to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails. + vbox = new(V) + *vbox = newv + } + if e.value.CompareAndSwap(oldvp, vbox) { + return true + } + continue outerLoop + } + i = &e.overflow + e = e.overflow.Load() } - i = &e.overflow - e = e.overflow.Load() + return false } - return head, false } // loadAndDelete deletes an entry in the overflow chain by key. Returns the value for the key, the new @@ -660,14 +693,14 @@ func (head *entry[K, V]) compareAndSwap(key K, old, new V, valEqual equalFunc) ( func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) { if head.key == key { // Drop the head of the list. - return head.value, head.overflow.Load(), true + return *head.value.Load(), head.overflow.Load(), true } i := &head.overflow e := i.Load() for e != nil { if e.key == key { i.Store(e.overflow.Load()) - return e.value, head, true + return *e.value.Load(), head, true } i = &e.overflow e = e.overflow.Load() @@ -680,14 +713,14 @@ func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) { // // compareAndDelete must be called under the mutex of the indirect node which e is a child of. 
func (head *entry[K, V]) compareAndDelete(key K, value V, valEqual equalFunc) (*entry[K, V], bool) { - if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&value))) { + if head.key == key && valEqual(unsafe.Pointer(head.value.Load()), abi.NoEscape(unsafe.Pointer(&value))) { // Drop the head of the list. return head.overflow.Load(), true } i := &head.overflow e := i.Load() for e != nil { - if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value))) { + if e.key == key && valEqual(unsafe.Pointer(e.value.Load()), abi.NoEscape(unsafe.Pointer(&value))) { i.Store(e.overflow.Load()) return head, true } -- 2.48.1