// runs and calls Done when finished. At the same time,
// Wait can be used to block until all goroutines have finished.
type WaitGroup struct {
- m Mutex
- counter int32
- waiters int32
- sema *uint32
+ // 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
+ // 64-bit atomic operations require 64-bit alignment, but 32-bit
+ // compilers do not ensure it. So we allocate 12 bytes and then use
+ // the aligned 8 bytes in them as state.
+ state1 [12]byte
+ sema uint32
}
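As an aside (not part of the change itself), the packing described in the comment above can be illustrated with a small standalone sketch; packState and the printed values are made up for illustration. The counter occupies the high 32 bits and the waiter count the low 32 bits, so a single 64-bit atomic add of delta<<32 adjusts the counter without touching the waiters:

package main

import "fmt"

// packState mirrors the WaitGroup layout: counter in the high 32 bits,
// waiter count in the low 32 bits. Hypothetical helper, illustration only.
func packState(counter int32, waiters uint32) uint64 {
	return uint64(uint32(counter))<<32 | uint64(waiters)
}

func main() {
	state := packState(3, 2)
	state += 1 << 32                             // what Add(1) does to the word
	fmt.Println(int32(state>>32), uint32(state)) // prints: 4 2
}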
-// WaitGroup creates a new semaphore each time the old semaphore
-// is released. This is to avoid the following race:
-//
-// G1: Add(1)
-// G1: go G2()
-// G1: Wait() // Context switch after Unlock() and before Semacquire().
-// G2: Done() // Release semaphore: sema == 1, waiters == 0. G1 doesn't run yet.
-// G3: Wait() // Finds counter == 0, waiters == 0, doesn't block.
-// G3: Add(1) // Makes counter == 1, waiters == 0.
-// G3: go G4()
-// G3: Wait() // G1 still hasn't run, G3 finds sema == 1, unblocked! Bug.
+func (wg *WaitGroup) state() *uint64 {
+ if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
+ return (*uint64)(unsafe.Pointer(&wg.state1))
+ } else {
+ return (*uint64)(unsafe.Pointer(&wg.state1[4]))
+ }
+}
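To make the alignment trick above concrete, here is a minimal, self-contained sketch of the same selection logic under the same assumption (the 12-byte array is at least 4-byte aligned because the struct also contains a uint32). The demo type and names are hypothetical, not the library code:

package main

import (
	"fmt"
	"sync/atomic"
	"unsafe"
)

// demo copies the WaitGroup field layout for illustration.
type demo struct {
	state1 [12]byte
	sema   uint32
}

// alignedState returns whichever 8 of the 12 bytes are 8-byte aligned.
// Since &d.state1 is at least 4-byte aligned, exactly one of offsets 0 and 4 works.
func alignedState(d *demo) *uint64 {
	if uintptr(unsafe.Pointer(&d.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&d.state1))
	}
	return (*uint64)(unsafe.Pointer(&d.state1[4]))
}

func main() {
	var d demo
	p := alignedState(&d)
	atomic.AddUint64(p, 1<<32) // safe even on 32-bit: p is 8-byte aligned
	fmt.Println(*p >> 32)      // prints: 1
}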
// Add adds delta, which may be negative, to the WaitGroup counter.
// If the counter becomes zero, all goroutines blocked on Wait are released.
// If the counter goes negative, Add panics.
//
// Note that calls with a positive delta that occur when the counter is zero
// must happen before a Wait. Calls with a negative delta, or calls with a
// positive delta that start when the counter is greater than zero, may happen
// at any time.
// Typically this means the calls to Add should execute before the statement
// creating the goroutine or other event to be waited for.
+// If a WaitGroup is reused to wait for several independent sets of events,
+// new Add calls must happen after all previous Wait calls have returned.
// See the WaitGroup example.
func (wg *WaitGroup) Add(delta int) {
+ statep := wg.state()
if raceenabled {
- _ = wg.m.state // trigger nil deref early
+ _ = *statep // trigger nil deref early
if delta < 0 {
// Synchronize decrements with Wait.
raceReleaseMerge(unsafe.Pointer(wg))
raceDisable()
defer raceEnable()
}
- v := atomic.AddInt32(&wg.counter, int32(delta))
+ state := atomic.AddUint64(statep, uint64(delta)<<32)
+ v := int32(state >> 32)
+ w := uint32(state)
if raceenabled {
if delta > 0 && v == int32(delta) {
// The first increment must be synchronized with Wait.
if v < 0 {
panic("sync: negative WaitGroup counter")
}
- if v > 0 || atomic.LoadInt32(&wg.waiters) == 0 {
+ if w != 0 && delta > 0 && v == int32(delta) {
+ panic("sync: WaitGroup misuse: Add called concurrently with Wait")
+ }
+ if v > 0 || w == 0 {
return
}
- wg.m.Lock()
- if atomic.LoadInt32(&wg.counter) == 0 {
- for i := int32(0); i < wg.waiters; i++ {
- runtime_Semrelease(wg.sema)
- }
- wg.waiters = 0
- wg.sema = nil
+ // This goroutine has set counter to 0 when waiters > 0.
+ // Now there can't be concurrent mutations of state:
+ // - Adds must not happen concurrently with Wait,
+ // - Wait does not increment waiters if it sees counter == 0.
+ // Still do a cheap sanity check to detect WaitGroup misuse.
+ if *statep != state {
+ panic("sync: WaitGroup misuse: Add called concurrently with Wait")
+ }
+ // Reset waiters count to 0.
+ *statep = 0
+ for ; w != 0; w-- {
+ runtime_Semrelease(&wg.sema)
}
- wg.m.Unlock()
}
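For reference (not part of the diff), the reuse rule added to Add's doc comment above allows a pattern like the following sketch, where each new batch of Add calls begins only after the previous Wait has returned:

package main

import "sync"

func main() {
	var wg sync.WaitGroup
	for round := 0; round < 3; round++ {
		// Independent sets of events: the next round's Add calls happen
		// only after this round's Wait has returned, as the doc requires.
		for i := 0; i < 4; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				// ... do one unit of work ...
			}()
		}
		wg.Wait()
	}
}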
// Done decrements the WaitGroup counter.
// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
+ statep := wg.state()
if raceenabled {
- _ = wg.m.state // trigger nil deref early
+ _ = *statep // trigger nil deref early
raceDisable()
}
- if atomic.LoadInt32(&wg.counter) == 0 {
- if raceenabled {
- raceEnable()
- raceAcquire(unsafe.Pointer(wg))
- }
- return
- }
- wg.m.Lock()
- w := atomic.AddInt32(&wg.waiters, 1)
- // This code is racing with the unlocked path in Add above.
- // The code above modifies counter and then reads waiters.
- // We must modify waiters and then read counter (the opposite order)
- // to avoid missing an Add.
- if atomic.LoadInt32(&wg.counter) == 0 {
- atomic.AddInt32(&wg.waiters, -1)
- if raceenabled {
- raceEnable()
- raceAcquire(unsafe.Pointer(wg))
- raceDisable()
+ for {
+ state := atomic.LoadUint64(statep)
+ v := int32(state >> 32)
+ w := uint32(state)
+ if v == 0 {
+ // Counter is 0, no need to wait.
+ if raceenabled {
+ raceEnable()
+ raceAcquire(unsafe.Pointer(wg))
+ }
+ return
}
- wg.m.Unlock()
- if raceenabled {
- raceEnable()
+ // Increment waiters count.
+ if atomic.CompareAndSwapUint64(statep, state, state+1) {
+ if raceenabled && w == 0 {
+ // Wait must be synchronized with the first Add.
+ // Need to model this as a write to race with the read in Add.
+ // As a consequence, can do the write only for the first waiter,
+ // otherwise concurrent Waits will race with each other.
+ raceWrite(unsafe.Pointer(&wg.sema))
+ }
+ runtime_Semacquire(&wg.sema)
+ if *statep != 0 {
+ panic("sync: WaitGroup is reused before previous Wait has returned")
+ }
+ if raceenabled {
+ raceEnable()
+ raceAcquire(unsafe.Pointer(wg))
+ }
+ return
}
- return
- }
- if raceenabled && w == 1 {
- // Wait must be synchronized with the first Add.
- // Need to model this is as a write to race with the read in Add.
- // As a consequence, can do the write only for the first waiter,
- // otherwise concurrent Waits will race with each other.
- raceWrite(unsafe.Pointer(&wg.sema))
- }
- if wg.sema == nil {
- wg.sema = new(uint32)
- }
- s := wg.sema
- wg.m.Unlock()
- runtime_Semacquire(s)
- if raceenabled {
- raceEnable()
- raceAcquire(unsafe.Pointer(wg))
}
}
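The waiter registration in Wait above is a standard load/compare-and-swap retry loop on the packed word. Below is a minimal standalone sketch of that pattern, with hypothetical names and a global (and therefore 64-bit-aligned) state word, assuming the same counter-high/waiters-low layout:

package main

import (
	"fmt"
	"sync/atomic"
)

var state uint64 = 1 << 32 // counter=1, waiters=0

// tryAddWaiter mimics the loop in Wait: reload the packed state and retry the
// CAS until it succeeds, or give up once the counter (high 32 bits) is zero.
func tryAddWaiter(p *uint64) bool {
	for {
		old := atomic.LoadUint64(p)
		if int32(old>>32) == 0 {
			return false // counter is 0, nothing to wait for
		}
		if atomic.CompareAndSwapUint64(p, old, old+1) {
			return true // registered as a waiter
		}
		// CAS lost a race with a concurrent Add/Done/Wait; retry.
	}
}

func main() {
	fmt.Println(tryAddWaiter(&state), uint32(atomic.LoadUint64(&state))) // true 1
}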
package sync_test
import (
+ "runtime"
. "sync"
"sync/atomic"
"testing"
t.Fatal("Should panic")
}
+func TestWaitGroupMisuse2(t *testing.T) {
+ if runtime.NumCPU() <= 2 {
+ t.Skip("NumCPU<=2, skipping: this test requires parallelism")
+ }
+ defer func() {
+ err := recover()
+ if err != "sync: negative WaitGroup counter" &&
+ err != "sync: WaitGroup misuse: Add called concurrently with Wait" &&
+ err != "sync: WaitGroup is reused before previous Wait has returned" {
+ t.Fatalf("Unexpected panic: %#v", err)
+ }
+ }()
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4))
+ done := make(chan interface{}, 2)
+ // The detection is opportunistic, so we want it to panic
+ // in at least one run out of a million.
+ for i := 0; i < 1e6; i++ {
+ var wg WaitGroup
+ wg.Add(1)
+ go func() {
+ defer func() {
+ done <- recover()
+ }()
+ wg.Wait()
+ }()
+ go func() {
+ defer func() {
+ done <- recover()
+ }()
+ wg.Add(1) // This is the bad guy.
+ wg.Done()
+ }()
+ wg.Done()
+ for j := 0; j < 2; j++ {
+ if err := <-done; err != nil {
+ panic(err)
+ }
+ }
+ }
+ t.Fatal("Should panic")
+}
+
+func TestWaitGroupMisuse3(t *testing.T) {
+ if runtime.NumCPU() <= 1 {
+ t.Skip("NumCPU==1, skipping: this test requires parallelism")
+ }
+ defer func() {
+ err := recover()
+ if err != "sync: negative WaitGroup counter" &&
+ err != "sync: WaitGroup misuse: Add called concurrently with Wait" &&
+ err != "sync: WaitGroup is reused before previous Wait has returned" {
+ t.Fatalf("Unexpected panic: %#v", err)
+ }
+ }()
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4))
+ done := make(chan interface{}, 1)
+ // The detection is opportunistic, so we want it to panic
+ // in at least one run out of a million.
+ for i := 0; i < 1e6; i++ {
+ var wg WaitGroup
+ wg.Add(1)
+ go func() {
+ wg.Done()
+ }()
+ go func() {
+ defer func() {
+ done <- recover()
+ }()
+ wg.Wait()
+ // Start reusing the wg before waiting for the Wait below to return.
+ wg.Add(1)
+ go func() {
+ wg.Done()
+ }()
+ wg.Wait()
+ }()
+ wg.Wait()
+ if err := <-done; err != nil {
+ panic(err)
+ }
+ }
+ t.Fatal("Should panic")
+}
+
func TestWaitGroupRace(t *testing.T) {
// Run this test for about 1ms.
for i := 0; i < 1000; i++ {
}
}
+func TestWaitGroupAlign(t *testing.T) {
+ type X struct {
+ x byte
+ wg WaitGroup
+ }
+ var x X
+ x.wg.Add(1)
+ go func(x *X) {
+ x.wg.Done()
+ }(&x)
+ x.wg.Wait()
+}
+
func BenchmarkWaitGroupUncontended(b *testing.B) {
type PaddedWaitGroup struct {
WaitGroup
func BenchmarkWaitGroupWaitWork(b *testing.B) {
benchmarkWaitGroupWait(b, 100)
}
+
+func BenchmarkWaitGroupActuallyWait(b *testing.B) {
+ b.ReportAllocs()
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ var wg WaitGroup
+ wg.Add(1)
+ go func() {
+ wg.Done()
+ }()
+ wg.Wait()
+ }
+ })
+}