]> Cypherpunks repositories - gostls13.git/commitdiff
sync/atomic: add (*Value).Swap and (*Value).CompareAndSwap
authorColin Arnott <colin@urandom.co.uk>
Thu, 9 Jul 2020 07:06:46 +0000 (07:06 +0000)
committerKeith Randall <khr@golang.org>
Tue, 4 May 2021 00:15:27 +0000 (00:15 +0000)
The functions SwapPointer and CompareAndSwapPointer can be used to
interact with unsafe.Pointer, however generally it is prefered to work
with Value, due to its safer interface. As such, they have been added
along with glue logic to maintain invariants Value guarantees.

To meet these guarantees, the current implementation duplicates much of
the Store function. Some of this is due to inexperience with concurrency
and desire for correctness, but the lack of generic programming
functionality does not help.

Fixes #39351

Change-Id: I1aa394b1e70944736ac1e19de49fe861e1e46fba
Reviewed-on: https://go-review.googlesource.com/c/go/+/241678
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Keith Randall <khr@golang.org>
src/sync/atomic/value.go
src/sync/atomic/value_test.go

index eab7e70c9b0f3ed5155cbf3c6a9e73cb4154c2c8..61f81d8fd377b06e49e49b70bf88918589766998 100644 (file)
@@ -25,7 +25,7 @@ type ifaceWords struct {
 
 // Load returns the value set by the most recent Store.
 // It returns nil if there has been no call to Store for this Value.
-func (v *Value) Load() (x interface{}) {
+func (v *Value) Load() (val interface{}) {
        vp := (*ifaceWords)(unsafe.Pointer(v))
        typ := LoadPointer(&vp.typ)
        if typ == nil || uintptr(typ) == ^uintptr(0) {
@@ -33,21 +33,21 @@ func (v *Value) Load() (x interface{}) {
                return nil
        }
        data := LoadPointer(&vp.data)
-       xp := (*ifaceWords)(unsafe.Pointer(&x))
-       xp.typ = typ
-       xp.data = data
+       vlp := (*ifaceWords)(unsafe.Pointer(&val))
+       vlp.typ = typ
+       vlp.data = data
        return
 }
 
 // Store sets the value of the Value to x.
 // All calls to Store for a given Value must use values of the same concrete type.
 // Store of an inconsistent type panics, as does Store(nil).
-func (v *Value) Store(x interface{}) {
-       if x == nil {
+func (v *Value) Store(val interface{}) {
+       if val == nil {
                panic("sync/atomic: store of nil value into Value")
        }
        vp := (*ifaceWords)(unsafe.Pointer(v))
-       xp := (*ifaceWords)(unsafe.Pointer(&x))
+       vlp := (*ifaceWords)(unsafe.Pointer(&val))
        for {
                typ := LoadPointer(&vp.typ)
                if typ == nil {
@@ -61,8 +61,8 @@ func (v *Value) Store(x interface{}) {
                                continue
                        }
                        // Complete first store.
-                       StorePointer(&vp.data, xp.data)
-                       StorePointer(&vp.typ, xp.typ)
+                       StorePointer(&vp.data, vlp.data)
+                       StorePointer(&vp.typ, vlp.typ)
                        runtime_procUnpin()
                        return
                }
@@ -73,14 +73,121 @@ func (v *Value) Store(x interface{}) {
                        continue
                }
                // First store completed. Check type and overwrite data.
-               if typ != xp.typ {
+               if typ != vlp.typ {
                        panic("sync/atomic: store of inconsistently typed value into Value")
                }
-               StorePointer(&vp.data, xp.data)
+               StorePointer(&vp.data, vlp.data)
                return
        }
 }
 
+// Swap stores new into Value and returns the previous value. It returns nil if
+// the Value is empty.
+//
+// All calls to Swap for a given Value must use values of the same concrete
+// type. Swap of an inconsistent type panics, as does Swap(nil).
+func (v *Value) Swap(new interface{}) (old interface{}) {
+       if new == nil {
+               panic("sync/atomic: swap of nil value into Value")
+       }
+       vp := (*ifaceWords)(unsafe.Pointer(v))
+       np := (*ifaceWords)(unsafe.Pointer(&new))
+       for {
+               typ := LoadPointer(&vp.typ)
+               if typ == nil {
+                       // Attempt to start first store.
+                       // Disable preemption so that other goroutines can use
+                       // active spin wait to wait for completion; and so that
+                       // GC does not see the fake type accidentally.
+                       runtime_procPin()
+                       if !CompareAndSwapPointer(&vp.typ, nil, unsafe.Pointer(^uintptr(0))) {
+                               runtime_procUnpin()
+                               continue
+                       }
+                       // Complete first store.
+                       StorePointer(&vp.data, np.data)
+                       StorePointer(&vp.typ, np.typ)
+                       runtime_procUnpin()
+                       return nil
+               }
+               if uintptr(typ) == ^uintptr(0) {
+                       // First store in progress. Wait.
+                       // Since we disable preemption around the first store,
+                       // we can wait with active spinning.
+                       continue
+               }
+               // First store completed. Check type and overwrite data.
+               if typ != np.typ {
+                       panic("sync/atomic: swap of inconsistently typed value into Value")
+               }
+               op := (*ifaceWords)(unsafe.Pointer(&old))
+               op.typ, op.data = np.typ, SwapPointer(&vp.data, np.data)
+               return old
+       }
+}
+
+// CompareAndSwapPointer executes the compare-and-swap operation for the Value.
+//
+// All calls to CompareAndSwap for a given Value must use values of the same
+// concrete type. CompareAndSwap of an inconsistent type panics, as does
+// CompareAndSwap(old, nil).
+func (v *Value) CompareAndSwap(old, new interface{}) (swapped bool) {
+       if new == nil {
+               panic("sync/atomic: compare and swap of nil value into Value")
+       }
+       vp := (*ifaceWords)(unsafe.Pointer(v))
+       np := (*ifaceWords)(unsafe.Pointer(&new))
+       op := (*ifaceWords)(unsafe.Pointer(&old))
+       if op.typ != nil && np.typ != op.typ {
+               panic("sync/atomic: compare and swap of inconsistently typed values")
+       }
+       for {
+               typ := LoadPointer(&vp.typ)
+               if typ == nil {
+                       if old != nil {
+                               return false
+                       }
+                       // Attempt to start first store.
+                       // Disable preemption so that other goroutines can use
+                       // active spin wait to wait for completion; and so that
+                       // GC does not see the fake type accidentally.
+                       runtime_procPin()
+                       if !CompareAndSwapPointer(&vp.typ, nil, unsafe.Pointer(^uintptr(0))) {
+                               runtime_procUnpin()
+                               continue
+                       }
+                       // Complete first store.
+                       StorePointer(&vp.data, np.data)
+                       StorePointer(&vp.typ, np.typ)
+                       runtime_procUnpin()
+                       return true
+               }
+               if uintptr(typ) == ^uintptr(0) {
+                       // First store in progress. Wait.
+                       // Since we disable preemption around the first store,
+                       // we can wait with active spinning.
+                       continue
+               }
+               // First store completed. Check type and overwrite data.
+               if typ != np.typ {
+                       panic("sync/atomic: compare and swap of inconsistently typed value into Value")
+               }
+               // Compare old and current via runtime equality check.
+               // This allows value types to be compared, something
+               // not offered by the package functions.
+               // CompareAndSwapPointer below only ensures vp.data
+               // has not changed since LoadPointer.
+               data := LoadPointer(&vp.data)
+               var i interface{}
+               (*ifaceWords)(unsafe.Pointer(&i)).typ = typ
+               (*ifaceWords)(unsafe.Pointer(&i)).data = data
+               if i != old {
+                       return false
+               }
+               return CompareAndSwapPointer(&vp.data, data, np.data)
+       }
+}
+
 // Disable/enable preemption, implemented in runtime.
 func runtime_procPin()
 func runtime_procUnpin()
index f289766340416a388dab8a5d8ee0277258988197..a5e717d6e0ccd1e3e5ff0c2165e54716df6a1c0d 100644 (file)
@@ -7,6 +7,9 @@ package atomic_test
 import (
        "math/rand"
        "runtime"
+       "strconv"
+       "sync"
+       "sync/atomic"
        . "sync/atomic"
        "testing"
 )
@@ -133,3 +136,139 @@ func BenchmarkValueRead(b *testing.B) {
                }
        })
 }
+
+var Value_SwapTests = []struct {
+       init interface{}
+       new  interface{}
+       want interface{}
+       err  interface{}
+}{
+       {init: nil, new: nil, err: "sync/atomic: swap of nil value into Value"},
+       {init: nil, new: true, want: nil, err: nil},
+       {init: true, new: "", err: "sync/atomic: swap of inconsistently typed value into Value"},
+       {init: true, new: false, want: true, err: nil},
+}
+
+func TestValue_Swap(t *testing.T) {
+       for i, tt := range Value_SwapTests {
+               t.Run(strconv.Itoa(i), func(t *testing.T) {
+                       var v Value
+                       if tt.init != nil {
+                               v.Store(tt.init)
+                       }
+                       defer func() {
+                               err := recover()
+                               switch {
+                               case tt.err == nil && err != nil:
+                                       t.Errorf("should not panic, got %v", err)
+                               case tt.err != nil && err == nil:
+                                       t.Errorf("should panic %v, got <nil>", tt.err)
+                               }
+                       }()
+                       if got := v.Swap(tt.new); got != tt.want {
+                               t.Errorf("got %v, want %v", got, tt.want)
+                       }
+                       if got := v.Load(); got != tt.new {
+                               t.Errorf("got %v, want %v", got, tt.new)
+                       }
+               })
+       }
+}
+
+func TestValueSwapConcurrent(t *testing.T) {
+       var v Value
+       var count uint64
+       var g sync.WaitGroup
+       var m, n uint64 = 10000, 10000
+       if testing.Short() {
+               m = 1000
+               n = 1000
+       }
+       for i := uint64(0); i < m*n; i += n {
+               i := i
+               g.Add(1)
+               go func() {
+                       var c uint64
+                       for new := i; new < i+n; new++ {
+                               if old := v.Swap(new); old != nil {
+                                       c += old.(uint64)
+                               }
+                       }
+                       atomic.AddUint64(&count, c)
+                       g.Done()
+               }()
+       }
+       g.Wait()
+       if want, got := (m*n-1)*(m*n)/2, count+v.Load().(uint64); got != want {
+               t.Errorf("sum from 0 to %d was %d, want %v", m*n-1, got, want)
+       }
+}
+
+var heapA, heapB = struct{ uint }{0}, struct{ uint }{0}
+
+var Value_CompareAndSwapTests = []struct {
+       init interface{}
+       new  interface{}
+       old  interface{}
+       want bool
+       err  interface{}
+}{
+       {init: nil, new: nil, old: nil, err: "sync/atomic: compare and swap of nil value into Value"},
+       {init: nil, new: true, old: "", err: "sync/atomic: compare and swap of inconsistently typed values into Value"},
+       {init: nil, new: true, old: true, want: false, err: nil},
+       {init: nil, new: true, old: nil, want: true, err: nil},
+       {init: true, new: "", err: "sync/atomic: compare and swap of inconsistently typed value into Value"},
+       {init: true, new: true, old: false, want: false, err: nil},
+       {init: true, new: true, old: true, want: true, err: nil},
+       {init: heapA, new: struct{ uint }{1}, old: heapB, want: true, err: nil},
+}
+
+func TestValue_CompareAndSwap(t *testing.T) {
+       for i, tt := range Value_CompareAndSwapTests {
+               t.Run(strconv.Itoa(i), func(t *testing.T) {
+                       var v Value
+                       if tt.init != nil {
+                               v.Store(tt.init)
+                       }
+                       defer func() {
+                               err := recover()
+                               switch {
+                               case tt.err == nil && err != nil:
+                                       t.Errorf("got %v, wanted no panic", err)
+                               case tt.err != nil && err == nil:
+                                       t.Errorf("did not panic, want %v", tt.err)
+                               }
+                       }()
+                       if got := v.CompareAndSwap(tt.old, tt.new); got != tt.want {
+                               t.Errorf("got %v, want %v", got, tt.want)
+                       }
+               })
+       }
+}
+
+func TestValueCompareAndSwapConcurrent(t *testing.T) {
+       var v Value
+       var w sync.WaitGroup
+       v.Store(0)
+       m, n := 1000, 100
+       if testing.Short() {
+               m = 100
+               n = 100
+       }
+       for i := 0; i < m; i++ {
+               i := i
+               w.Add(1)
+               go func() {
+                       for j := i; j < m*n; runtime.Gosched() {
+                               if v.CompareAndSwap(j, j+1) {
+                                       j += m
+                               }
+                       }
+                       w.Done()
+               }()
+       }
+       w.Wait()
+       if stop := v.Load().(int); stop != m*n {
+               t.Errorf("did not get to %v, stopped at %v", m*n, stop)
+       }
+}