]> Cypherpunks repositories - gostls13.git/commitdiff
os/signal: avoid race between Stop and receiving on channel
authorIan Lance Taylor <iant@golang.org>
Fri, 16 Jun 2017 16:29:44 +0000 (09:29 -0700)
committerIan Lance Taylor <iant@golang.org>
Sat, 24 Jun 2017 00:54:01 +0000 (00:54 +0000)
When Stop is called on a channel, wait until all signals have been
delivered to the channel before returning.

Use atomic operations in sigqueue to communicate more reliably between
the os/signal goroutine and the signal handler.

Fixes #14571

Change-Id: I6c5a9eea1cff85e37a34dffe96f4bb2699e12c6e
Reviewed-on: https://go-review.googlesource.com/46003
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Austin Clements <austin@google.com>
src/os/signal/signal.go
src/os/signal/signal_test.go
src/runtime/sigqueue.go
src/runtime/sigqueue_plan9.go

index c1376daaea6e28ccb85dd7e518aaae878249a1c4..e5a21e8532754d7f2383ca38989e252dc1179172 100644 (file)
@@ -11,8 +11,21 @@ import (
 
 var handlers struct {
        sync.Mutex
-       m   map[chan<- os.Signal]*handler
+       // Map a channel to the signals that should be sent to it.
+       m map[chan<- os.Signal]*handler
+       // Map a signal to the number of channels receiving it.
        ref [numSig]int64
+       // Map channels to signals while the channel is being stopped.
+       // Not a map because entries live here only very briefly.
+       // We need a separate container because we need m to correspond to ref
+       // at all times, and we also need to keep track of the *handler
+       // value for a channel being stopped. See the Stop function.
+       stopping []stopping
+}
+
+type stopping struct {
+       c chan<- os.Signal
+       h *handler
 }
 
 type handler struct {
@@ -142,10 +155,10 @@ func Reset(sig ...os.Signal) {
 // When Stop returns, it is guaranteed that c will receive no more signals.
 func Stop(c chan<- os.Signal) {
        handlers.Lock()
-       defer handlers.Unlock()
 
        h := handlers.m[c]
        if h == nil {
+               handlers.Unlock()
                return
        }
        delete(handlers.m, c)
@@ -158,8 +171,40 @@ func Stop(c chan<- os.Signal) {
                        }
                }
        }
+
+       // Signals will no longer be delivered to the channel.
+       // We want to avoid a race for a signal such as SIGINT:
+       // it should be either delivered to the channel,
+       // or the program should take the default action (that is, exit).
+       // To avoid the possibility that the signal is delivered,
+       // and the signal handler invoked, and then Stop deregisters
+       // the channel before the process function below has a chance
+       // to send it on the channel, put the channel on a list of
+       // channels being stopped and wait for signal delivery to
+       // quiesce before fully removing it.
+
+       handlers.stopping = append(handlers.stopping, stopping{c, h})
+
+       handlers.Unlock()
+
+       signalWaitUntilIdle()
+
+       handlers.Lock()
+
+       for i, s := range handlers.stopping {
+               if s.c == c {
+                       handlers.stopping = append(handlers.stopping[:i], handlers.stopping[i+1:]...)
+                       break
+               }
+       }
+
+       handlers.Unlock()
 }
 
+// Wait until there are no more signals waiting to be delivered.
+// Defined by the runtime package.
+func signalWaitUntilIdle()
+
 func process(sig os.Signal) {
        n := signum(sig)
        if n < 0 {
@@ -178,4 +223,14 @@ func process(sig os.Signal) {
                        }
                }
        }
+
+       // Avoid the race mentioned in Stop.
+       for _, d := range handlers.stopping {
+               if d.h.want(n) {
+                       select {
+                       case d.c <- sig:
+                       default:
+                       }
+               }
+       }
 }
index 146dc813a439ae48dacbc8683b660fd27a197341..7866aae3c47c2debeb1c645b439ec91fa4428d42 100644 (file)
@@ -7,13 +7,16 @@
 package signal_test
 
 import (
+       "bytes"
        "flag"
+       "fmt"
        "io/ioutil"
        "os"
        "os/exec"
        . "os/signal"
        "runtime"
        "strconv"
+       "sync"
        "syscall"
        "testing"
        "time"
@@ -302,3 +305,88 @@ func TestSIGCONT(t *testing.T) {
        syscall.Kill(syscall.Getpid(), syscall.SIGCONT)
        waitSig(t, c, syscall.SIGCONT)
 }
+
+// Test race between stopping and receiving a signal (issue 14571).
+func TestAtomicStop(t *testing.T) {
+       if os.Getenv("GO_TEST_ATOMIC_STOP") != "" {
+               atomicStopTestProgram()
+               t.Fatal("atomicStopTestProgram returned")
+       }
+
+       const execs = 10
+       for i := 0; i < execs; i++ {
+               cmd := exec.Command(os.Args[0], "-test.run=TestAtomicStop")
+               cmd.Env = append(os.Environ(), "GO_TEST_ATOMIC_STOP=1")
+               out, err := cmd.CombinedOutput()
+               if err == nil {
+                       t.Logf("iteration %d: output %s", i, out)
+               } else {
+                       t.Logf("iteration %d: exit status %q: output: %s", i, err, out)
+               }
+
+               lost := bytes.Contains(out, []byte("lost signal"))
+               if lost {
+                       t.Errorf("iteration %d: lost signal", i)
+               }
+
+               // The program should either die due to SIGINT,
+               // or exit with success without printing "lost signal".
+               if err == nil {
+                       if len(out) > 0 && !lost {
+                               t.Errorf("iteration %d: unexpected output", i)
+                       }
+               } else {
+                       if ee, ok := err.(*exec.ExitError); !ok {
+                               t.Errorf("iteration %d: error (%v) has type %T; expected exec.ExitError", i, err, err)
+                       } else if ws, ok := ee.Sys().(syscall.WaitStatus); !ok {
+                               t.Errorf("iteration %d: error.Sys (%v) has type %T; expected syscall.WaitStatus", i, ee.Sys(), ee.Sys())
+                       } else if !ws.Signaled() || ws.Signal() != syscall.SIGINT {
+                               t.Errorf("iteration %d: got exit status %v; expected SIGINT", i, ee)
+                       }
+               }
+       }
+}
+
+// atomicStopTestProgram is run in a subprocess by TestAtomicStop.
+// It tries to trigger a signal delivery race. This function should
+// either catch a signal or die from it.
+func atomicStopTestProgram() {
+       const tries = 10
+       pid := syscall.Getpid()
+       printed := false
+       for i := 0; i < tries; i++ {
+               cs := make(chan os.Signal, 1)
+               Notify(cs, syscall.SIGINT)
+
+               var wg sync.WaitGroup
+               wg.Add(1)
+               go func() {
+                       defer wg.Done()
+                       Stop(cs)
+               }()
+
+               syscall.Kill(pid, syscall.SIGINT)
+
+               // At this point we should either die from SIGINT or
+               // get a notification on cs. If neither happens, we
+               // dropped the signal. Give it a second to deliver,
+               // which is far far longer than it should require.
+
+               select {
+               case <-cs:
+               case <-time.After(1 * time.Second):
+                       if !printed {
+                               fmt.Print("lost signal on iterations:")
+                               printed = true
+                       }
+                       fmt.Printf(" %d", i)
+               }
+
+               wg.Wait()
+       }
+       if printed {
+               fmt.Print("\n")
+       }
+
+       os.Exit(0)
+}
index 0162ffa04fbb6ecdc9497778e91461a6ef30e8a1..236bb29929672fabf79b2b23b2dbd849a8cd0b22 100644 (file)
@@ -33,6 +33,17 @@ import (
        _ "unsafe" // for go:linkname
 )
 
+// sig handles communication between the signal handler and os/signal.
+// Other than the inuse and recv fields, the fields are accessed atomically.
+//
+// The wanted and ignored fields are only written by one goroutine at
+// a time; access is controlled by the handlers Mutex in os/signal.
+// The fields are only read by that one goroutine and by the signal handler.
+// We access them atomically to minimize the race between setting them
+// in the goroutine calling os/signal and the signal handler,
+// which may be running in a different thread. That race is unavoidable,
+// as there is no connection between handling a signal and receiving one,
+// but atomic instructions should minimize it.
 var sig struct {
        note    note
        mask    [(_NSIG + 31) / 32]uint32
@@ -53,7 +64,11 @@ const (
 // Reports whether the signal was sent. If not, the caller typically crashes the program.
 func sigsend(s uint32) bool {
        bit := uint32(1) << uint(s&31)
-       if !sig.inuse || s >= uint32(32*len(sig.wanted)) || sig.wanted[s/32]&bit == 0 {
+       if !sig.inuse || s >= uint32(32*len(sig.wanted)) {
+               return false
+       }
+
+       if w := atomic.Load(&sig.wanted[s/32]); w&bit == 0 {
                return false
        }
 
@@ -131,6 +146,23 @@ func signal_recv() uint32 {
        }
 }
 
+// signalWaitUntilIdle waits until the signal delivery mechanism is idle.
+// This is used to ensure that we do not drop a signal notification due
+// to a race between disabling a signal and receiving a signal.
+// This assumes that signal delivery has already been disabled for
+// the signal(s) in question, and here we are just waiting to make sure
+// that all the signals have been delivered to the user channels
+// by the os/signal package.
+//go:linkname signalWaitUntilIdle os/signal.signalWaitUntilIdle
+func signalWaitUntilIdle() {
+       // Although WaitUntilIdle seems like the right name for this
+       // function, the state we are looking for is sigReceiving, not
+       // sigIdle.  The sigIdle state is really more like sigProcessing.
+       for atomic.Load(&sig.state) != sigReceiving {
+               Gosched()
+       }
+}
+
 // Must only be called from a single goroutine at a time.
 //go:linkname signal_enable os/signal.signal_enable
 func signal_enable(s uint32) {
@@ -146,8 +178,15 @@ func signal_enable(s uint32) {
        if s >= uint32(len(sig.wanted)*32) {
                return
        }
-       sig.wanted[s/32] |= 1 << (s & 31)
-       sig.ignored[s/32] &^= 1 << (s & 31)
+
+       w := sig.wanted[s/32]
+       w |= 1 << (s & 31)
+       atomic.Store(&sig.wanted[s/32], w)
+
+       i := sig.ignored[s/32]
+       i &^= 1 << (s & 31)
+       atomic.Store(&sig.ignored[s/32], i)
+
        sigenable(s)
 }
 
@@ -157,8 +196,11 @@ func signal_disable(s uint32) {
        if s >= uint32(len(sig.wanted)*32) {
                return
        }
-       sig.wanted[s/32] &^= 1 << (s & 31)
        sigdisable(s)
+
+       w := sig.wanted[s/32]
+       w &^= 1 << (s & 31)
+       atomic.Store(&sig.wanted[s/32], w)
 }
 
 // Must only be called from a single goroutine at a time.
@@ -167,12 +209,19 @@ func signal_ignore(s uint32) {
        if s >= uint32(len(sig.wanted)*32) {
                return
        }
-       sig.wanted[s/32] &^= 1 << (s & 31)
-       sig.ignored[s/32] |= 1 << (s & 31)
        sigignore(s)
+
+       w := sig.wanted[s/32]
+       w &^= 1 << (s & 31)
+       atomic.Store(&sig.wanted[s/32], w)
+
+       i := sig.ignored[s/32]
+       i |= 1 << (s & 31)
+       atomic.Store(&sig.ignored[s/32], i)
 }
 
 // Checked by signal handlers.
 func signal_ignored(s uint32) bool {
-       return sig.ignored[s/32]&(1<<(s&31)) != 0
+       i := atomic.Load(&sig.ignored[s/32])
+       return i&(1<<(s&31)) != 0
 }
index 575d26afb4e78664587fd97fc9c24f31d7938620..76668045a86c59b6e304af999205315ee1ab6fc3 100644 (file)
@@ -110,6 +110,26 @@ func signal_recv() string {
        }
 }
 
+// signalWaitUntilIdle waits until the signal delivery mechanism is idle.
+// This is used to ensure that we do not drop a signal notification due
+// to a race between disabling a signal and receiving a signal.
+// This assumes that signal delivery has already been disabled for
+// the signal(s) in question, and here we are just waiting to make sure
+// that all the signals have been delivered to the user channels
+// by the os/signal package.
+//go:linkname signalWaitUntilIdle os/signal.signalWaitUntilIdle
+func signalWaitUntilIdle() {
+       for {
+               lock(&sig.lock)
+               sleeping := sig.sleeping
+               unlock(&sig.lock)
+               if sleeping {
+                       return
+               }
+               Gosched()
+       }
+}
+
 // Must only be called from a single goroutine at a time.
 //go:linkname signal_enable os/signal.signal_enable
 func signal_enable(s uint32) {