]> Cypherpunks repositories - gostls13.git/commitdiff
Revert "internal/sync: optimize CompareAndSwap and Swap"
authorMichael Knyszek <mknyszek@google.com>
Tue, 24 Dec 2024 01:59:28 +0000 (17:59 -0800)
committerGopher Robot <gobot@golang.org>
Mon, 6 Jan 2025 17:57:18 +0000 (09:57 -0800)
This reverts CL 606462.

Reason for revert: Breaks atomicity between operations. See #70970.

Change-Id: I1a899f2784da5a0f9da3193e3267275c23aea661
Reviewed-on: https://go-review.googlesource.com/c/go/+/638615
Auto-Submit: Michael Knyszek <mknyszek@google.com>
Commit-Queue: Michael Knyszek <mknyszek@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: David Chase <drchase@google.com>
src/internal/sync/hashtriemap.go

index d31d81df39ca193c40a492c2950d3cc9b1f08722..defcd0b793947f1025694c584defe10f66898604 100644 (file)
@@ -219,22 +219,12 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
 
                        slot = &i.children[(hash>>hashShift)&nChildrenMask]
                        n = slot.Load()
-                       if n == nil {
+                       if n == nil || n.isEntry {
                                // We found a nil slot which is a candidate for insertion,
                                // or an existing entry that we'll replace.
                                haveInsertPoint = true
                                break
                        }
-                       if n.isEntry {
-                               // Swap if the keys compare.
-                               old, swapped := n.entry().swap(key, new)
-                               if swapped {
-                                       return old, true
-                               }
-                               // If we fail, that means we should try to insert.
-                               haveInsertPoint = true
-                               break
-                       }
                        i = n.indirect()
                }
                if !haveInsertPoint {
@@ -261,10 +251,11 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
        var zero V
        var oldEntry *entry[K, V]
        if n != nil {
-               // Between before and now, something got inserted. Swap if the keys compare.
+               // Swap if the keys compare.
                oldEntry = n.entry()
-               old, swapped := oldEntry.swap(key, new)
+               newEntry, old, swapped := oldEntry.swap(key, new)
                if swapped {
+                       slot.Store(&newEntry.node)
                        return old, true
                }
        }
@@ -292,30 +283,25 @@ func (ht *HashTrieMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) {
                panic("called CompareAndSwap when value is not of comparable type")
        }
        hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
-       for {
-               // Find the key or return if it's not there.
-               i := ht.root.Load()
-               hashShift := 8 * goarch.PtrSize
-               found := false
-               for hashShift != 0 {
-                       hashShift -= nChildrenLog2
 
-                       slot := &i.children[(hash>>hashShift)&nChildrenMask]
-                       n := slot.Load()
-                       if n == nil {
-                               // Nothing to compare with. Give up.
-                               return false
-                       }
-                       if n.isEntry {
-                               // We found an entry. Try to compare and swap directly.
-                               return n.entry().compareAndSwap(key, old, new, ht.valEqual)
-                       }
-                       i = n.indirect()
-               }
-               if !found {
-                       panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
-               }
+       // Find a node with the key and compare with it. n != nil if we found the node.
+       i, _, slot, n := ht.find(key, hash, ht.valEqual, old)
+       if i != nil {
+               defer i.mu.Unlock()
        }
+       if n == nil {
+               return false
+       }
+
+       // Try to swap the entry.
+       e, swapped := n.entry().compareAndSwap(key, old, new, ht.valEqual)
+       if !swapped {
+               // Nothing was actually swapped, which means the node is no longer there.
+               return false
+       }
+       // Store the entry back because it changed.
+       slot.Store(&e.node)
+       return true
 }
 
 // LoadAndDelete deletes the value for a key, returning the previous value if any.
@@ -523,7 +509,7 @@ func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V)
                }
                e := n.entry()
                for e != nil {
-                       if !yield(e.key, *e.value.Load()) {
+                       if !yield(e.key, e.value) {
                                return false
                        }
                        e = e.overflow.Load()
@@ -579,22 +565,21 @@ type entry[K comparable, V any] struct {
        node[K, V]
        overflow atomic.Pointer[entry[K, V]] // Overflow for hash collisions.
        key      K
-       value    atomic.Pointer[V]
+       value    V
 }
 
 func newEntryNode[K comparable, V any](key K, value V) *entry[K, V] {
-       e := &entry[K, V]{
-               node: node[K, V]{isEntry: true},
-               key:  key,
+       return &entry[K, V]{
+               node:  node[K, V]{isEntry: true},
+               key:   key,
+               value: value,
        }
-       e.value.Store(&value)
-       return e
 }
 
 func (e *entry[K, V]) lookup(key K) (V, bool) {
        for e != nil {
                if e.key == key {
-                       return *e.value.Load(), true
+                       return e.value, true
                }
                e = e.overflow.Load()
        }
@@ -603,87 +588,69 @@ func (e *entry[K, V]) lookup(key K) (V, bool) {
 
 func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bool) {
        for e != nil {
-               oldp := e.value.Load()
-               if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(oldp), abi.NoEscape(unsafe.Pointer(&value)))) {
-                       return *oldp, true
+               if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value)))) {
+                       return e.value, true
                }
                e = e.overflow.Load()
        }
        return *new(V), false
 }
 
-// swap replaces a value in the overflow chain if keys compare equal.
-// Returns the old value, and whether or not anything was swapped.
+// swap replaces an entry in the overflow chain if keys compare equal. Returns the new entry chain,
+// the old value, and whether or not anything was swapped.
 //
 // swap must be called under the mutex of the indirect node which e is a child of.
-func (head *entry[K, V]) swap(key K, newv V) (V, bool) {
+func (head *entry[K, V]) swap(key K, new V) (*entry[K, V], V, bool) {
        if head.key == key {
-               vp := new(V)
-               *vp = newv
-               oldp := head.value.Swap(vp)
-               return *oldp, true
+               // Return the new head of the list.
+               e := newEntryNode(key, new)
+               if chain := head.overflow.Load(); chain != nil {
+                       e.overflow.Store(chain)
+               }
+               return e, head.value, true
        }
        i := &head.overflow
        e := i.Load()
        for e != nil {
                if e.key == key {
-                       vp := new(V)
-                       *vp = newv
-                       oldp := e.value.Swap(vp)
-                       return *oldp, true
+                       eNew := newEntryNode(key, new)
+                       eNew.overflow.Store(e.overflow.Load())
+                       i.Store(eNew)
+                       return head, e.value, true
                }
                i = &e.overflow
                e = e.overflow.Load()
        }
        var zero V
-       return zero, false
+       return head, zero, false
 }
 
-// compareAndSwap replaces a value for a matching key and existing value in the overflow chain.
-// Returns whether or not anything was swapped.
+// compareAndSwap replaces an entry in the overflow chain if both the key and value compare
+// equal. Returns the new entry chain and whether or not anything was swapped.
 //
 // compareAndSwap must be called under the mutex of the indirect node which e is a child of.
-func (head *entry[K, V]) compareAndSwap(key K, oldv, newv V, valEqual equalFunc) bool {
-       var vbox *V
-outerLoop:
-       for {
-               oldvp := head.value.Load()
-               if head.key == key && valEqual(unsafe.Pointer(oldvp), abi.NoEscape(unsafe.Pointer(&oldv))) {
-                       // Return the new head of the list.
-                       if vbox == nil {
-                               // Delay explicit creation of a new value to hold newv. If we just pass &newv
-                               // to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails.
-                               vbox = new(V)
-                               *vbox = newv
-                       }
-                       if head.value.CompareAndSwap(oldvp, vbox) {
-                               return true
-                       }
-                       // We need to restart from the head of the overflow list in case, due to a removal, a node
-                       // is moved up the list and we miss it.
-                       continue outerLoop
+func (head *entry[K, V]) compareAndSwap(key K, old, new V, valEqual equalFunc) (*entry[K, V], bool) {
+       if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&old))) {
+               // Return the new head of the list.
+               e := newEntryNode(key, new)
+               if chain := head.overflow.Load(); chain != nil {
+                       e.overflow.Store(chain)
                }
-               i := &head.overflow
-               e := i.Load()
-               for e != nil {
-                       oldvp := e.value.Load()
-                       if e.key == key && valEqual(unsafe.Pointer(oldvp), abi.NoEscape(unsafe.Pointer(&oldv))) {
-                               if vbox == nil {
-                                       // Delay explicit creation of a new value to hold newv. If we just pass &newv
-                                       // to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails.
-                                       vbox = new(V)
-                                       *vbox = newv
-                               }
-                               if e.value.CompareAndSwap(oldvp, vbox) {
-                                       return true
-                               }
-                               continue outerLoop
-                       }
-                       i = &e.overflow
-                       e = e.overflow.Load()
+               return e, true
+       }
+       i := &head.overflow
+       e := i.Load()
+       for e != nil {
+               if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&old))) {
+                       eNew := newEntryNode(key, new)
+                       eNew.overflow.Store(e.overflow.Load())
+                       i.Store(eNew)
+                       return head, true
                }
-               return false
+               i = &e.overflow
+               e = e.overflow.Load()
        }
+       return head, false
 }
 
 // loadAndDelete deletes an entry in the overflow chain by key. Returns the value for the key, the new
@@ -693,14 +660,14 @@ outerLoop:
 func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) {
        if head.key == key {
                // Drop the head of the list.
-               return *head.value.Load(), head.overflow.Load(), true
+               return head.value, head.overflow.Load(), true
        }
        i := &head.overflow
        e := i.Load()
        for e != nil {
                if e.key == key {
                        i.Store(e.overflow.Load())
-                       return *e.value.Load(), head, true
+                       return e.value, head, true
                }
                i = &e.overflow
                e = e.overflow.Load()
@@ -713,14 +680,14 @@ func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) {
 //
 // compareAndDelete must be called under the mutex of the indirect node which e is a child of.
 func (head *entry[K, V]) compareAndDelete(key K, value V, valEqual equalFunc) (*entry[K, V], bool) {
-       if head.key == key && valEqual(unsafe.Pointer(head.value.Load()), abi.NoEscape(unsafe.Pointer(&value))) {
+       if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&value))) {
                // Drop the head of the list.
                return head.overflow.Load(), true
        }
        i := &head.overflow
        e := i.Load()
        for e != nil {
-               if e.key == key && valEqual(unsafe.Pointer(e.value.Load()), abi.NoEscape(unsafe.Pointer(&value))) {
+               if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value))) {
                        i.Store(e.overflow.Load())
                        return head, true
                }