return &top.node
}
+// CompareAndSwap swaps the old and new values for key
+// if the value stored in the map is equal to old.
+// The value type must be of a comparable type, otherwise CompareAndSwap will panic.
+func (ht *HashTrieMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) {
+ ht.init()
+ if ht.valEqual == nil {
+ panic("called CompareAndSwap when value is not of comparable type")
+ }
+ hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
+
+ // Find a node with the key and compare with it. n != nil if we found the node.
+ i, _, slot, n := ht.find(key, hash, ht.valEqual, old)
+ if i != nil {
+ defer i.mu.Unlock()
+ }
+ if n == nil {
+ return false
+ }
+
+ // Try to swap the entry.
+ e, swapped := n.entry().compareAndSwap(key, old, new, ht.valEqual)
+ if !swapped {
+ // Nothing was actually swapped, which means the node is no longer there.
+ return false
+ }
+ // Store the entry back because it changed.
+ slot.Store(&e.node)
+ return true
+}
+
// CompareAndDelete deletes the entry for key if its value is equal to old.
// The value type must be comparable, otherwise this CompareAndDelete will panic.
//
}
hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
- // Find a node with the key and compare with it. n != nil if we found the node.
- i, hashShift, slot, n := ht.find(key, hash)
+ // Find a node with the key. n != nil if we found the node.
+ i, hashShift, slot, n := ht.find(key, hash, nil, *new(V))
if n == nil {
if i != nil {
i.mu.Unlock()
return true
}
-// compare searches the tree for a node that compares with key (hash must be the hash of key).
+// find searches the tree for a node that contains key (hash must be the hash of key).
+// If valEqual != nil, then it will also enforce that the values are equal as well.
//
// Returns a non-nil node, which will always be an entry, if found.
//
// If i != nil then i.mu is locked, and it is the caller's responsibility to unlock it.
-func (ht *HashTrieMap[K, V]) find(key K, hash uintptr) (i *indirect[K, V], hashShift uint, slot *atomic.Pointer[node[K, V]], n *node[K, V]) {
+func (ht *HashTrieMap[K, V]) find(key K, hash uintptr, valEqual equalFunc, value V) (i *indirect[K, V], hashShift uint, slot *atomic.Pointer[node[K, V]], n *node[K, V]) {
for {
- // Find the key or return when there's nothing to delete.
+ // Find the key or return if it's not there.
i = ht.root
hashShift = 8 * goarch.PtrSize
found := false
}
if n.isEntry {
// We found an entry. Check if it matches.
- if _, ok := n.entry().lookup(key); !ok {
+ if _, ok := n.entry().lookupWithValue(key, value, valEqual); !ok {
// No match, comparison failed.
i = nil
n = nil
return *new(V), false
}
+func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bool) {
+ for e != nil {
+ if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value)))) {
+ return e.value, true
+ }
+ e = e.overflow.Load()
+ }
+ return *new(V), false
+}
+
+// compareAndSwap replaces an entry in the overflow chain if both the key and value compare
+// equal. Returns the new entry chain and whether or not anything was swapped.
+//
+// compareAndSwap must be called under the mutex of the indirect node which e is a child of.
+func (head *entry[K, V]) compareAndSwap(key K, old, new V, valEqual equalFunc) (*entry[K, V], bool) {
+ if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&old))) {
+ // Return the new head of the list.
+ e := newEntryNode(key, new)
+ if chain := head.overflow.Load(); chain != nil {
+ e.overflow.Store(chain)
+ }
+ return e, true
+ }
+ i := &head.overflow
+ e := i.Load()
+ for e != nil {
+ if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&old))) {
+ eNew := newEntryNode(key, new)
+ eNew.overflow.Store(e.overflow.Load())
+ i.Store(eNew)
+ return head, true
+ }
+ i = &e.overflow
+ e = e.overflow.Load()
+ }
+ return head, false
+}
+
// compareAndDelete deletes an entry in the overflow chain if both the key and value compare
// equal. Returns the new entry chain and whether or not anything was deleted.
//
}
}
})
- t.Run("DeleteMultiple", func(t *testing.T) {
+ t.Run("CompareAndDeleteMultiple", func(t *testing.T) {
m := newMap()
for i, s := range testData {
return true
})
})
- t.Run("AllDelete", func(t *testing.T) {
+ t.Run("AllCompareAndDelete", func(t *testing.T) {
m := newMap()
testAll(t, m, testDataMap(testData[:]), func(s string, i int) bool {
}
wg.Wait()
})
- t.Run("ConcurrentDeleteSharedKeys", func(t *testing.T) {
+ t.Run("ConcurrentCompareAndDeleteSharedKeys", func(t *testing.T) {
m := newMap()
// Load up the map.
}
wg.Wait()
})
+ t.Run("CompareAndSwapAll", func(t *testing.T) {
+ m := newMap()
+
+ for i, s := range testData {
+ expectMissing(t, s, 0)(m.Load(s))
+ expectStored(t, s, i)(m.LoadOrStore(s, i))
+ expectPresent(t, s, i)(m.Load(s))
+ expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
+ }
+ for j := range 3 {
+ for i, s := range testData {
+ expectPresent(t, s, i+j)(m.Load(s))
+ expectNotSwapped(t, s, math.MaxInt, i+j+1)(m.CompareAndSwap(s, math.MaxInt, i+j+1))
+ expectSwapped(t, s, i, i+j+1)(m.CompareAndSwap(s, i+j, i+j+1))
+ expectNotSwapped(t, s, i+j, i+j+1)(m.CompareAndSwap(s, i+j, i+j+1))
+ expectPresent(t, s, i+j+1)(m.Load(s))
+ }
+ }
+ for i, s := range testData {
+ expectPresent(t, s, i+3)(m.Load(s))
+ }
+ })
+ t.Run("CompareAndSwapOne", func(t *testing.T) {
+ m := newMap()
+
+ for i, s := range testData {
+ expectMissing(t, s, 0)(m.Load(s))
+ expectStored(t, s, i)(m.LoadOrStore(s, i))
+ expectPresent(t, s, i)(m.Load(s))
+ expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
+ }
+ expectNotSwapped(t, testData[15], math.MaxInt, 16)(m.CompareAndSwap(testData[15], math.MaxInt, 16))
+ expectSwapped(t, testData[15], 15, 16)(m.CompareAndSwap(testData[15], 15, 16))
+ expectNotSwapped(t, testData[15], 15, 16)(m.CompareAndSwap(testData[15], 15, 16))
+ for i, s := range testData {
+ if i == 15 {
+ expectPresent(t, s, 16)(m.Load(s))
+ } else {
+ expectPresent(t, s, i)(m.Load(s))
+ }
+ }
+ })
+ t.Run("CompareAndSwapMultiple", func(t *testing.T) {
+ m := newMap()
+
+ for i, s := range testData {
+ expectMissing(t, s, 0)(m.Load(s))
+ expectStored(t, s, i)(m.LoadOrStore(s, i))
+ expectPresent(t, s, i)(m.Load(s))
+ expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
+ }
+ for _, i := range []int{1, 105, 6, 85} {
+ expectNotSwapped(t, testData[i], math.MaxInt, i+1)(m.CompareAndSwap(testData[i], math.MaxInt, i+1))
+ expectSwapped(t, testData[i], i, i+1)(m.CompareAndSwap(testData[i], i, i+1))
+ expectNotSwapped(t, testData[i], i, i+1)(m.CompareAndSwap(testData[i], i, i+1))
+ }
+ for i, s := range testData {
+ if i == 1 || i == 105 || i == 6 || i == 85 {
+ expectPresent(t, s, i+1)(m.Load(s))
+ } else {
+ expectPresent(t, s, i)(m.Load(s))
+ }
+ }
+ })
+ t.Run("ConcurrentLifecycleSwapUnsharedKeys", func(t *testing.T) {
+ m := newMap()
+
+ gmp := runtime.GOMAXPROCS(-1)
+ var wg sync.WaitGroup
+ for i := range gmp {
+ wg.Add(1)
+ go func(id int) {
+ defer wg.Done()
+
+ makeKey := func(s string) string {
+ return s + "-" + strconv.Itoa(id)
+ }
+ for _, s := range testData {
+ key := makeKey(s)
+ expectMissing(t, key, 0)(m.Load(key))
+ expectStored(t, key, id)(m.LoadOrStore(key, id))
+ expectPresent(t, key, id)(m.Load(key))
+ expectLoaded(t, key, id)(m.LoadOrStore(key, 0))
+ }
+ for _, s := range testData {
+ key := makeKey(s)
+ expectPresent(t, key, id)(m.Load(key))
+ expectSwapped(t, key, id, id+1)(m.CompareAndSwap(key, id, id+1))
+ expectPresent(t, key, id+1)(m.Load(key))
+ }
+ for _, s := range testData {
+ key := makeKey(s)
+ expectPresent(t, key, id+1)(m.Load(key))
+ }
+ }(i)
+ }
+ wg.Wait()
+ })
+ t.Run("ConcurrentLifecycleSwapAndDeleteUnsharedKeys", func(t *testing.T) {
+ m := newMap()
+
+ gmp := runtime.GOMAXPROCS(-1)
+ var wg sync.WaitGroup
+ for i := range gmp {
+ wg.Add(1)
+ go func(id int) {
+ defer wg.Done()
+
+ makeKey := func(s string) string {
+ return s + "-" + strconv.Itoa(id)
+ }
+ for _, s := range testData {
+ key := makeKey(s)
+ expectMissing(t, key, 0)(m.Load(key))
+ expectStored(t, key, id)(m.LoadOrStore(key, id))
+ expectPresent(t, key, id)(m.Load(key))
+ expectLoaded(t, key, id)(m.LoadOrStore(key, 0))
+ }
+ for _, s := range testData {
+ key := makeKey(s)
+ expectPresent(t, key, id)(m.Load(key))
+ expectSwapped(t, key, id, id+1)(m.CompareAndSwap(key, id, id+1))
+ expectPresent(t, key, id+1)(m.Load(key))
+ expectDeleted(t, key, id+1)(m.CompareAndDelete(key, id+1))
+ expectNotSwapped(t, key, id+1, id+2)(m.CompareAndSwap(key, id+1, id+2))
+ expectNotDeleted(t, key, id+1)(m.CompareAndDelete(key, id+1))
+ expectMissing(t, key, 0)(m.Load(key))
+ }
+ for _, s := range testData {
+ key := makeKey(s)
+ expectMissing(t, key, 0)(m.Load(key))
+ }
+ }(i)
+ }
+ wg.Wait()
+ })
+ t.Run("ConcurrentCompareAndSwapSharedKeys", func(t *testing.T) {
+ m := newMap()
+
+ // Load up the map.
+ for i, s := range testData {
+ expectMissing(t, s, 0)(m.Load(s))
+ expectStored(t, s, i)(m.LoadOrStore(s, i))
+ }
+ gmp := runtime.GOMAXPROCS(-1)
+ var wg sync.WaitGroup
+ for i := range gmp {
+ wg.Add(1)
+ go func(id int) {
+ defer wg.Done()
+
+ for i, s := range testData {
+ expectNotSwapped(t, s, math.MaxInt, i+1)(m.CompareAndSwap(s, math.MaxInt, i+1))
+ m.CompareAndSwap(s, i, i+1)
+ expectPresent(t, s, i+1)(m.Load(s))
+ }
+ for i, s := range testData {
+ expectPresent(t, s, i+1)(m.Load(s))
+ }
+ }(i)
+ }
+ wg.Wait()
+ })
}
func testAll[K, V comparable](t *testing.T, m *isync.HashTrieMap[K, V], testData map[K]V, yield func(K, V) bool) {
}
}
+func expectSwapped[K, V comparable](t *testing.T, key K, old, new V) func(swapped bool) {
+ t.Helper()
+ return func(swapped bool) {
+ t.Helper()
+
+ if !swapped {
+ t.Errorf("expected key %v with value %v to be in map and swapped for %v", key, old, new)
+ }
+ }
+}
+
+func expectNotSwapped[K, V comparable](t *testing.T, key K, old, new V) func(swapped bool) {
+ t.Helper()
+ return func(swapped bool) {
+ t.Helper()
+
+ if swapped {
+ t.Errorf("expected key %v with value %v to not be in map or not swapped for %v", key, old, new)
+ }
+ }
+}
+
func testDataMap(data []string) map[string]int {
m := make(map[string]int)
for i, s := range data {