sync: simplify WaitGroup
author    Dmitry Vyukov <dvyukov@google.com>    Sun, 8 Feb 2015 15:11:35 +0000 (18:11 +0300)
committer Russ Cox <rsc@golang.org>    Fri, 26 Jun 2015 18:48:29 +0000 (18:48 +0000)
A comment in waitgroup.go describes the following scenario
as the reason to have dynamically created semaphores:

// G1: Add(1)
// G1: go G2()
// G1: Wait() // Context switch after Unlock() and before Semacquire().
// G2: Done() // Release semaphore: sema == 1, waiters == 0. G1 doesn't run yet.
// G3: Wait() // Finds counter == 0, waiters == 0, doesn't block.
// G3: Add(1) // Makes counter == 1, waiters == 0.
// G3: go G4()
// G3: Wait() // G1 still hasn't run, G3 finds sema == 1, unblocked! Bug.

However, the scenario is incorrect:
G3: Add(1) happens concurrently with G1: Wait(),
so the program has no reasonable behavior
(G1: Wait() may or may not wait for G3: Add(1),
and neither outcome can be the intended behavior).
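
To make the invalid interleaving concrete, here is a minimal sketch
(not part of this commit) of the disallowed pattern; with this change
such a program can now panic with "sync: WaitGroup misuse: Add called
concurrently with Wait", though the detection is opportunistic:

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		wg.Wait() // G1: may or may not observe the Add below.
	}()
	go func() {
		wg.Add(1) // Races with the Wait above; no well-defined behavior.
		wg.Done()
	}()
	wg.Done()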

With this conclusion we can:
1. Remove dynamic allocation of semaphores.
2. Remove the mutex entirely and instead pack counter and waiters
   into single uint64.

This makes the logic significantly simpler: both Add and Wait
do only a single atomic RMW to update the state.
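
In the new encoding (sketched below, matching the code in this change,
where statep points at the aligned 64-bit state word) the counter
occupies the high 32 bits and the waiter count the low 32 bits, so Add
updates the counter and observes the waiter count in one step:

	state := atomic.AddUint64(statep, uint64(delta)<<32)
	v := int32(state >> 32) // counter (high 32 bits)
	w := uint32(state)      // number of waiters (low 32 bits)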

benchmark                            old ns/op     new ns/op     delta
BenchmarkWaitGroupUncontended        30.6          32.7          +6.86%
BenchmarkWaitGroupActuallyWait       722           595           -17.59%
BenchmarkWaitGroupActuallyWait-2     396           319           -19.44%
BenchmarkWaitGroupActuallyWait-4     224           183           -18.30%
BenchmarkWaitGroupActuallyWait-8     134           106           -20.90%

benchmark                          old allocs     new allocs     delta
BenchmarkWaitGroupActuallyWait     2              1              -50.00%

benchmark                          old bytes     new bytes     delta
BenchmarkWaitGroupActuallyWait     48            16            -66.67%
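
For reference, the numbers above come from the sync package benchmarks;
something like the following should reproduce them (the allocation
columns require -benchmem):

	go test sync -run=NONE -bench=WaitGroup -benchmem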

Change-Id: I28911f3243aa16544e99ac8f1f5af31944c7ea3a
Reviewed-on: https://go-review.googlesource.com/4117
Run-TryBot: Dmitry Vyukov <dvyukov@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Russ Cox <rsc@golang.org>
src/sync/waitgroup.go
src/sync/waitgroup_test.go

diff --git a/src/sync/waitgroup.go b/src/sync/waitgroup.go
index 92cc57d2cc87eaf9490203c3e3a45d5ca39081a6..de399e64ebbfc83974d00b7fbc09f2333a51f29a 100644
@@ -15,23 +15,21 @@ import (
 // runs and calls Done when finished.  At the same time,
 // Wait can be used to block until all goroutines have finished.
 type WaitGroup struct {
-       m       Mutex
-       counter int32
-       waiters int32
-       sema    *uint32
+       // 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
+       // 64-bit atomic operations require 64-bit alignment, but 32-bit
+       // compilers do not ensure it. So we allocate 12 bytes and then use
+       // the aligned 8 bytes in them as state.
+       state1 [12]byte
+       sema   uint32
 }
 
-// WaitGroup creates a new semaphore each time the old semaphore
-// is released. This is to avoid the following race:
-//
-// G1: Add(1)
-// G1: go G2()
-// G1: Wait() // Context switch after Unlock() and before Semacquire().
-// G2: Done() // Release semaphore: sema == 1, waiters == 0. G1 doesn't run yet.
-// G3: Wait() // Finds counter == 0, waiters == 0, doesn't block.
-// G3: Add(1) // Makes counter == 1, waiters == 0.
-// G3: go G4()
-// G3: Wait() // G1 still hasn't run, G3 finds sema == 1, unblocked! Bug.
+func (wg *WaitGroup) state() *uint64 {
+       if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
+               return (*uint64)(unsafe.Pointer(&wg.state1))
+       } else {
+               return (*uint64)(unsafe.Pointer(&wg.state1[4]))
+       }
+}
 
 // Add adds delta, which may be negative, to the WaitGroup counter.
 // If the counter becomes zero, all goroutines blocked on Wait are released.
@@ -43,10 +41,13 @@ type WaitGroup struct {
 // at any time.
 // Typically this means the calls to Add should execute before the statement
 // creating the goroutine or other event to be waited for.
+// If a WaitGroup is reused to wait for several independent sets of events,
+// new Add calls must happen after all previous Wait calls have returned.
 // See the WaitGroup example.
 func (wg *WaitGroup) Add(delta int) {
+       statep := wg.state()
        if raceenabled {
-               _ = wg.m.state // trigger nil deref early
+               _ = *statep // trigger nil deref early
                if delta < 0 {
                        // Synchronize decrements with Wait.
                        raceReleaseMerge(unsafe.Pointer(wg))
@@ -54,7 +55,9 @@ func (wg *WaitGroup) Add(delta int) {
                raceDisable()
                defer raceEnable()
        }
-       v := atomic.AddInt32(&wg.counter, int32(delta))
+       state := atomic.AddUint64(statep, uint64(delta)<<32)
+       v := int32(state >> 32)
+       w := uint32(state)
        if raceenabled {
                if delta > 0 && v == int32(delta) {
                        // The first increment must be synchronized with Wait.
@@ -66,18 +69,25 @@ func (wg *WaitGroup) Add(delta int) {
        if v < 0 {
                panic("sync: negative WaitGroup counter")
        }
-       if v > 0 || atomic.LoadInt32(&wg.waiters) == 0 {
+       if w != 0 && delta > 0 && v == int32(delta) {
+               panic("sync: WaitGroup misuse: Add called concurrently with Wait")
+       }
+       if v > 0 || w == 0 {
                return
        }
-       wg.m.Lock()
-       if atomic.LoadInt32(&wg.counter) == 0 {
-               for i := int32(0); i < wg.waiters; i++ {
-                       runtime_Semrelease(wg.sema)
-               }
-               wg.waiters = 0
-               wg.sema = nil
+       // This goroutine has set counter to 0 when waiters > 0.
+       // Now there can't be concurrent mutations of state:
+       // - Adds must not happen concurrently with Wait,
+       // - Wait does not increment waiters if it sees counter == 0.
+       // Still do a cheap sanity check to detect WaitGroup misuse.
+       if *statep != state {
+               panic("sync: WaitGroup misuse: Add called concurrently with Wait")
+       }
+       // Reset waiters count to 0.
+       *statep = 0
+       for ; w != 0; w-- {
+               runtime_Semrelease(&wg.sema)
        }
-       wg.m.Unlock()
 }
 
 // Done decrements the WaitGroup counter.
@@ -87,51 +97,41 @@ func (wg *WaitGroup) Done() {
 
 // Wait blocks until the WaitGroup counter is zero.
 func (wg *WaitGroup) Wait() {
+       statep := wg.state()
        if raceenabled {
-               _ = wg.m.state // trigger nil deref early
+               _ = *statep // trigger nil deref early
                raceDisable()
        }
-       if atomic.LoadInt32(&wg.counter) == 0 {
-               if raceenabled {
-                       raceEnable()
-                       raceAcquire(unsafe.Pointer(wg))
-               }
-               return
-       }
-       wg.m.Lock()
-       w := atomic.AddInt32(&wg.waiters, 1)
-       // This code is racing with the unlocked path in Add above.
-       // The code above modifies counter and then reads waiters.
-       // We must modify waiters and then read counter (the opposite order)
-       // to avoid missing an Add.
-       if atomic.LoadInt32(&wg.counter) == 0 {
-               atomic.AddInt32(&wg.waiters, -1)
-               if raceenabled {
-                       raceEnable()
-                       raceAcquire(unsafe.Pointer(wg))
-                       raceDisable()
+       for {
+               state := atomic.LoadUint64(statep)
+               v := int32(state >> 32)
+               w := uint32(state)
+               if v == 0 {
+                       // Counter is 0, no need to wait.
+                       if raceenabled {
+                               raceEnable()
+                               raceAcquire(unsafe.Pointer(wg))
+                       }
+                       return
                }
-               wg.m.Unlock()
-               if raceenabled {
-                       raceEnable()
+               // Increment waiters count.
+               if atomic.CompareAndSwapUint64(statep, state, state+1) {
+                       if raceenabled && w == 0 {
+                               // Wait must be synchronized with the first Add.
+                       // Need to model this as a write to race with the read in Add.
+                               // As a consequence, can do the write only for the first waiter,
+                               // otherwise concurrent Waits will race with each other.
+                               raceWrite(unsafe.Pointer(&wg.sema))
+                       }
+                       runtime_Semacquire(&wg.sema)
+                       if *statep != 0 {
+                               panic("sync: WaitGroup is reused before previous Wait has returned")
+                       }
+                       if raceenabled {
+                               raceEnable()
+                               raceAcquire(unsafe.Pointer(wg))
+                       }
+                       return
                }
-               return
-       }
-       if raceenabled && w == 1 {
-               // Wait must be synchronized with the first Add.
-               // Need to model this is as a write to race with the read in Add.
-               // As a consequence, can do the write only for the first waiter,
-               // otherwise concurrent Waits will race with each other.
-               raceWrite(unsafe.Pointer(&wg.sema))
-       }
-       if wg.sema == nil {
-               wg.sema = new(uint32)
-       }
-       s := wg.sema
-       wg.m.Unlock()
-       runtime_Semacquire(s)
-       if raceenabled {
-               raceEnable()
-               raceAcquire(unsafe.Pointer(wg))
        }
 }
diff --git a/src/sync/waitgroup_test.go b/src/sync/waitgroup_test.go
index 4c0a043c01ee7d6ce058c59032e4ade62075e6b0..06a77798d088d3d11e9a45a23c85dc20b150bdbd 100644
@@ -5,6 +5,7 @@
 package sync_test
 
 import (
+       "runtime"
        . "sync"
        "sync/atomic"
        "testing"
@@ -60,6 +61,90 @@ func TestWaitGroupMisuse(t *testing.T) {
        t.Fatal("Should panic")
 }
 
+func TestWaitGroupMisuse2(t *testing.T) {
+       if runtime.NumCPU() <= 2 {
+               t.Skip("NumCPU<=2, skipping: this test requires parallelism")
+       }
+       defer func() {
+               err := recover()
+               if err != "sync: negative WaitGroup counter" &&
+                       err != "sync: WaitGroup misuse: Add called concurrently with Wait" &&
+                       err != "sync: WaitGroup is reused before previous Wait has returned" {
+                       t.Fatalf("Unexpected panic: %#v", err)
+               }
+       }()
+       defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4))
+       done := make(chan interface{}, 2)
+       // The detection is opportunistic, so we want it to panic
+       // in at least one run out of a million.
+       for i := 0; i < 1e6; i++ {
+               var wg WaitGroup
+               wg.Add(1)
+               go func() {
+                       defer func() {
+                               done <- recover()
+                       }()
+                       wg.Wait()
+               }()
+               go func() {
+                       defer func() {
+                               done <- recover()
+                       }()
+                       wg.Add(1) // This is the bad guy.
+                       wg.Done()
+               }()
+               wg.Done()
+               for j := 0; j < 2; j++ {
+                       if err := <-done; err != nil {
+                               panic(err)
+                       }
+               }
+       }
+       t.Fatal("Should panic")
+}
+
+func TestWaitGroupMisuse3(t *testing.T) {
+       if runtime.NumCPU() <= 1 {
+               t.Skip("NumCPU==1, skipping: this test requires parallelism")
+       }
+       defer func() {
+               err := recover()
+               if err != "sync: negative WaitGroup counter" &&
+                       err != "sync: WaitGroup misuse: Add called concurrently with Wait" &&
+                       err != "sync: WaitGroup is reused before previous Wait has returned" {
+                       t.Fatalf("Unexpected panic: %#v", err)
+               }
+       }()
+       defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4))
+       done := make(chan interface{}, 1)
+       // The detection is opportunistic, so we want it to panic
+       // in at least one run out of a million.
+       for i := 0; i < 1e6; i++ {
+               var wg WaitGroup
+               wg.Add(1)
+               go func() {
+                       wg.Done()
+               }()
+               go func() {
+                       defer func() {
+                               done <- recover()
+                       }()
+                       wg.Wait()
+                       // Start reusing the wg before waiting for the Wait below to return.
+                       wg.Add(1)
+                       go func() {
+                               wg.Done()
+                       }()
+                       wg.Wait()
+               }()
+               wg.Wait()
+               if err := <-done; err != nil {
+                       panic(err)
+               }
+       }
+       t.Fatal("Should panic")
+}
+
 func TestWaitGroupRace(t *testing.T) {
        // Run this test for about 1ms.
        for i := 0; i < 1000; i++ {
@@ -85,6 +170,19 @@ func TestWaitGroupRace(t *testing.T) {
        }
 }
 
+func TestWaitGroupAlign(t *testing.T) {
+       type X struct {
+               x  byte
+               wg WaitGroup
+       }
+       var x X
+       x.wg.Add(1)
+       go func(x *X) {
+               x.wg.Done()
+       }(&x)
+       x.wg.Wait()
+}
+
 func BenchmarkWaitGroupUncontended(b *testing.B) {
        type PaddedWaitGroup struct {
                WaitGroup
@@ -146,3 +244,17 @@ func BenchmarkWaitGroupWait(b *testing.B) {
 func BenchmarkWaitGroupWaitWork(b *testing.B) {
        benchmarkWaitGroupWait(b, 100)
 }
+
+func BenchmarkWaitGroupActuallyWait(b *testing.B) {
+       b.ReportAllocs()
+       b.RunParallel(func(pb *testing.PB) {
+               for pb.Next() {
+                       var wg WaitGroup
+                       wg.Add(1)
+                       go func() {
+                               wg.Done()
+                       }()
+                       wg.Wait()
+               }
+       })
+}