From c112c0af1328ef0aae989ae20d27359a18f72543 Mon Sep 17 00:00:00 2001 From: Michael Knyszek Date: Mon, 23 Dec 2024 17:59:28 -0800 Subject: [PATCH] Revert "internal/sync: optimize CompareAndSwap and Swap" This reverts CL 606462. Reason for revert: Breaks atomicity between operations. See #70970. Change-Id: I1a899f2784da5a0f9da3193e3267275c23aea661 Reviewed-on: https://go-review.googlesource.com/c/go/+/638615 Auto-Submit: Michael Knyszek Commit-Queue: Michael Knyszek LUCI-TryBot-Result: Go LUCI Reviewed-by: David Chase --- src/internal/sync/hashtriemap.go | 171 +++++++++++++------------------ 1 file changed, 69 insertions(+), 102 deletions(-) diff --git a/src/internal/sync/hashtriemap.go b/src/internal/sync/hashtriemap.go index d31d81df39..defcd0b793 100644 --- a/src/internal/sync/hashtriemap.go +++ b/src/internal/sync/hashtriemap.go @@ -219,22 +219,12 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) { slot = &i.children[(hash>>hashShift)&nChildrenMask] n = slot.Load() - if n == nil { + if n == nil || n.isEntry { // We found a nil slot which is a candidate for insertion, // or an existing entry that we'll replace. haveInsertPoint = true break } - if n.isEntry { - // Swap if the keys compare. - old, swapped := n.entry().swap(key, new) - if swapped { - return old, true - } - // If we fail, that means we should try to insert. - haveInsertPoint = true - break - } i = n.indirect() } if !haveInsertPoint { @@ -261,10 +251,11 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) { var zero V var oldEntry *entry[K, V] if n != nil { - // Between before and now, something got inserted. Swap if the keys compare. + // Swap if the keys compare. oldEntry = n.entry() - old, swapped := oldEntry.swap(key, new) + newEntry, old, swapped := oldEntry.swap(key, new) if swapped { + slot.Store(&newEntry.node) return old, true } } @@ -292,30 +283,25 @@ func (ht *HashTrieMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) { panic("called CompareAndSwap when value is not of comparable type") } hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed) - for { - // Find the key or return if it's not there. - i := ht.root.Load() - hashShift := 8 * goarch.PtrSize - found := false - for hashShift != 0 { - hashShift -= nChildrenLog2 - slot := &i.children[(hash>>hashShift)&nChildrenMask] - n := slot.Load() - if n == nil { - // Nothing to compare with. Give up. - return false - } - if n.isEntry { - // We found an entry. Try to compare and swap directly. - return n.entry().compareAndSwap(key, old, new, ht.valEqual) - } - i = n.indirect() - } - if !found { - panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating") - } + // Find a node with the key and compare with it. n != nil if we found the node. + i, _, slot, n := ht.find(key, hash, ht.valEqual, old) + if i != nil { + defer i.mu.Unlock() } + if n == nil { + return false + } + + // Try to swap the entry. + e, swapped := n.entry().compareAndSwap(key, old, new, ht.valEqual) + if !swapped { + // Nothing was actually swapped, which means the node is no longer there. + return false + } + // Store the entry back because it changed. + slot.Store(&e.node) + return true } // LoadAndDelete deletes the value for a key, returning the previous value if any. @@ -523,7 +509,7 @@ func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V) } e := n.entry() for e != nil { - if !yield(e.key, *e.value.Load()) { + if !yield(e.key, e.value) { return false } e = e.overflow.Load() @@ -579,22 +565,21 @@ type entry[K comparable, V any] struct { node[K, V] overflow atomic.Pointer[entry[K, V]] // Overflow for hash collisions. key K - value atomic.Pointer[V] + value V } func newEntryNode[K comparable, V any](key K, value V) *entry[K, V] { - e := &entry[K, V]{ - node: node[K, V]{isEntry: true}, - key: key, + return &entry[K, V]{ + node: node[K, V]{isEntry: true}, + key: key, + value: value, } - e.value.Store(&value) - return e } func (e *entry[K, V]) lookup(key K) (V, bool) { for e != nil { if e.key == key { - return *e.value.Load(), true + return e.value, true } e = e.overflow.Load() } @@ -603,87 +588,69 @@ func (e *entry[K, V]) lookup(key K) (V, bool) { func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bool) { for e != nil { - oldp := e.value.Load() - if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(oldp), abi.NoEscape(unsafe.Pointer(&value)))) { - return *oldp, true + if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value)))) { + return e.value, true } e = e.overflow.Load() } return *new(V), false } -// swap replaces a value in the overflow chain if keys compare equal. -// Returns the old value, and whether or not anything was swapped. +// swap replaces an entry in the overflow chain if keys compare equal. Returns the new entry chain, +// the old value, and whether or not anything was swapped. // // swap must be called under the mutex of the indirect node which e is a child of. -func (head *entry[K, V]) swap(key K, newv V) (V, bool) { +func (head *entry[K, V]) swap(key K, new V) (*entry[K, V], V, bool) { if head.key == key { - vp := new(V) - *vp = newv - oldp := head.value.Swap(vp) - return *oldp, true + // Return the new head of the list. + e := newEntryNode(key, new) + if chain := head.overflow.Load(); chain != nil { + e.overflow.Store(chain) + } + return e, head.value, true } i := &head.overflow e := i.Load() for e != nil { if e.key == key { - vp := new(V) - *vp = newv - oldp := e.value.Swap(vp) - return *oldp, true + eNew := newEntryNode(key, new) + eNew.overflow.Store(e.overflow.Load()) + i.Store(eNew) + return head, e.value, true } i = &e.overflow e = e.overflow.Load() } var zero V - return zero, false + return head, zero, false } -// compareAndSwap replaces a value for a matching key and existing value in the overflow chain. -// Returns whether or not anything was swapped. +// compareAndSwap replaces an entry in the overflow chain if both the key and value compare +// equal. Returns the new entry chain and whether or not anything was swapped. // // compareAndSwap must be called under the mutex of the indirect node which e is a child of. -func (head *entry[K, V]) compareAndSwap(key K, oldv, newv V, valEqual equalFunc) bool { - var vbox *V -outerLoop: - for { - oldvp := head.value.Load() - if head.key == key && valEqual(unsafe.Pointer(oldvp), abi.NoEscape(unsafe.Pointer(&oldv))) { - // Return the new head of the list. - if vbox == nil { - // Delay explicit creation of a new value to hold newv. If we just pass &newv - // to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails. - vbox = new(V) - *vbox = newv - } - if head.value.CompareAndSwap(oldvp, vbox) { - return true - } - // We need to restart from the head of the overflow list in case, due to a removal, a node - // is moved up the list and we miss it. - continue outerLoop +func (head *entry[K, V]) compareAndSwap(key K, old, new V, valEqual equalFunc) (*entry[K, V], bool) { + if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&old))) { + // Return the new head of the list. + e := newEntryNode(key, new) + if chain := head.overflow.Load(); chain != nil { + e.overflow.Store(chain) } - i := &head.overflow - e := i.Load() - for e != nil { - oldvp := e.value.Load() - if e.key == key && valEqual(unsafe.Pointer(oldvp), abi.NoEscape(unsafe.Pointer(&oldv))) { - if vbox == nil { - // Delay explicit creation of a new value to hold newv. If we just pass &newv - // to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails. - vbox = new(V) - *vbox = newv - } - if e.value.CompareAndSwap(oldvp, vbox) { - return true - } - continue outerLoop - } - i = &e.overflow - e = e.overflow.Load() + return e, true + } + i := &head.overflow + e := i.Load() + for e != nil { + if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&old))) { + eNew := newEntryNode(key, new) + eNew.overflow.Store(e.overflow.Load()) + i.Store(eNew) + return head, true } - return false + i = &e.overflow + e = e.overflow.Load() } + return head, false } // loadAndDelete deletes an entry in the overflow chain by key. Returns the value for the key, the new @@ -693,14 +660,14 @@ outerLoop: func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) { if head.key == key { // Drop the head of the list. - return *head.value.Load(), head.overflow.Load(), true + return head.value, head.overflow.Load(), true } i := &head.overflow e := i.Load() for e != nil { if e.key == key { i.Store(e.overflow.Load()) - return *e.value.Load(), head, true + return e.value, head, true } i = &e.overflow e = e.overflow.Load() @@ -713,14 +680,14 @@ func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) { // // compareAndDelete must be called under the mutex of the indirect node which e is a child of. func (head *entry[K, V]) compareAndDelete(key K, value V, valEqual equalFunc) (*entry[K, V], bool) { - if head.key == key && valEqual(unsafe.Pointer(head.value.Load()), abi.NoEscape(unsafe.Pointer(&value))) { + if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&value))) { // Drop the head of the list. return head.overflow.Load(), true } i := &head.overflow e := i.Load() for e != nil { - if e.key == key && valEqual(unsafe.Pointer(e.value.Load()), abi.NoEscape(unsafe.Pointer(&value))) { + if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value))) { i.Store(e.overflow.Load()) return head, true } -- 2.48.1