package signal
import (
+ "context"
"os"
"sync"
)
}
}
}
+
+// NotifyContext returns a copy of the parent context that is marked done
+// (its Done channel is closed) when one of the listed signals arrives,
+// when the returned stop function is called, or when the parent context's
+// Done channel is closed, whichever happens first.
+//
+// The stop function unregisters the signal behavior, which, like signal.Reset,
+// may restore the default behavior for a given signal. For example, the default
+// behavior of a Go program receiving os.Interrupt is to exit. Calling
+// NotifyContext(parent, os.Interrupt) will change the behavior to cancel
+// the returned context. Future interrupts received will not trigger the default
+// (exit) behavior until the returned stop function is called.
+//
+// The stop function releases resources associated with it, so code should
+// call stop as soon as the operations running in this Context complete and
+// signals no longer need to be diverted to the context.
+func NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
+ ctx, cancel := context.WithCancel(parent)
+ c := &signalCtx{
+ Context: ctx,
+ cancel: cancel,
+ signals: signals,
+ }
+ c.ch = make(chan os.Signal, 1)
+ Notify(c.ch, c.signals...)
+ if ctx.Err() == nil {
+ go func() {
+ select {
+ case <-c.ch:
+ c.cancel()
+ case <-c.Done():
+ }
+ }()
+ }
+ return c, c.stop
+}
+
+type signalCtx struct {
+ context.Context
+
+ cancel context.CancelFunc
+ signals []os.Signal
+ ch chan os.Signal
+}
+
+func (c *signalCtx) stop() {
+ c.cancel()
+ Stop(c.ch)
+}
+
+type stringer interface {
+ String() string
+}
+
+func (c *signalCtx) String() string {
+ var buf []byte
+ // We know that the type of c.Context is context.cancelCtx, and we know that the
+ // String method of cancelCtx returns a string that ends with ".WithCancel".
+ name := c.Context.(stringer).String()
+ name = name[:len(name)-len(".WithCancel")]
+ buf = append(buf, "signal.NotifyContext("+name...)
+ if len(c.signals) != 0 {
+ buf = append(buf, ", ["...)
+ for i, s := range c.signals {
+ buf = append(buf, s.String()...)
+ if i != len(c.signals)-1 {
+ buf = append(buf, ' ')
+ }
+ }
+ buf = append(buf, ']')
+ }
+ buf = append(buf, ')')
+ return string(buf)
+}
import (
"bytes"
+ "context"
"flag"
"fmt"
"internal/testenv"
close(stop)
<-done
}
+
+func TestNotifyContext(t *testing.T) {
+ c, stop := NotifyContext(context.Background(), syscall.SIGINT)
+ defer stop()
+
+ if want, got := "signal.NotifyContext(context.Background, [interrupt])", fmt.Sprint(c); want != got {
+ t.Errorf("c.String() = %q, want %q", got, want)
+ }
+
+ syscall.Kill(syscall.Getpid(), syscall.SIGINT)
+ select {
+ case <-c.Done():
+ if got := c.Err(); got != context.Canceled {
+ t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+ }
+ case <-time.After(time.Second):
+ t.Errorf("timed out waiting for context to be done after SIGINT")
+ }
+}
+
+func TestNotifyContextStop(t *testing.T) {
+ Ignore(syscall.SIGHUP)
+ if !Ignored(syscall.SIGHUP) {
+ t.Errorf("expected SIGHUP to be ignored when explicitly ignoring it.")
+ }
+
+ parent, cancelParent := context.WithCancel(context.Background())
+ defer cancelParent()
+ c, stop := NotifyContext(parent, syscall.SIGHUP)
+ defer stop()
+
+ // If we're being notified, then the signal should not be ignored.
+ if Ignored(syscall.SIGHUP) {
+ t.Errorf("expected SIGHUP to not be ignored.")
+ }
+
+ if want, got := "signal.NotifyContext(context.Background.WithCancel, [hangup])", fmt.Sprint(c); want != got {
+ t.Errorf("c.String() = %q, wanted %q", got, want)
+ }
+
+ stop()
+ select {
+ case <-c.Done():
+ if got := c.Err(); got != context.Canceled {
+ t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+ }
+ case <-time.After(time.Second):
+ t.Errorf("timed out waiting for context to be done after calling stop")
+ }
+}
+
+func TestNotifyContextCancelParent(t *testing.T) {
+ parent, cancelParent := context.WithCancel(context.Background())
+ defer cancelParent()
+ c, stop := NotifyContext(parent, syscall.SIGINT)
+ defer stop()
+
+ if want, got := "signal.NotifyContext(context.Background.WithCancel, [interrupt])", fmt.Sprint(c); want != got {
+ t.Errorf("c.String() = %q, want %q", got, want)
+ }
+
+ cancelParent()
+ select {
+ case <-c.Done():
+ if got := c.Err(); got != context.Canceled {
+ t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+ }
+ case <-time.After(time.Second):
+ t.Errorf("timed out waiting for parent context to be canceled")
+ }
+}
+
+func TestNotifyContextPrematureCancelParent(t *testing.T) {
+ parent, cancelParent := context.WithCancel(context.Background())
+ defer cancelParent()
+
+ cancelParent() // Prematurely cancel context before calling NotifyContext.
+ c, stop := NotifyContext(parent, syscall.SIGINT)
+ defer stop()
+
+ if want, got := "signal.NotifyContext(context.Background.WithCancel, [interrupt])", fmt.Sprint(c); want != got {
+ t.Errorf("c.String() = %q, want %q", got, want)
+ }
+
+ select {
+ case <-c.Done():
+ if got := c.Err(); got != context.Canceled {
+ t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+ }
+ case <-time.After(time.Second):
+ t.Errorf("timed out waiting for parent context to be canceled")
+ }
+}
+
+func TestNotifyContextSimultaneousNotifications(t *testing.T) {
+ c, stop := NotifyContext(context.Background(), syscall.SIGINT)
+ defer stop()
+
+ if want, got := "signal.NotifyContext(context.Background, [interrupt])", fmt.Sprint(c); want != got {
+ t.Errorf("c.String() = %q, want %q", got, want)
+ }
+
+ var wg sync.WaitGroup
+ n := 10
+ wg.Add(n)
+ for i := 0; i < n; i++ {
+ go func() {
+ syscall.Kill(syscall.Getpid(), syscall.SIGINT)
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+ select {
+ case <-c.Done():
+ if got := c.Err(); got != context.Canceled {
+ t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+ }
+ case <-time.After(time.Second):
+ t.Errorf("expected context to be canceled")
+ }
+}
+
+func TestNotifyContextSimultaneousStop(t *testing.T) {
+ c, stop := NotifyContext(context.Background(), syscall.SIGINT)
+ defer stop()
+
+ if want, got := "signal.NotifyContext(context.Background, [interrupt])", fmt.Sprint(c); want != got {
+ t.Errorf("c.String() = %q, want %q", got, want)
+ }
+
+ var wg sync.WaitGroup
+ n := 10
+ wg.Add(n)
+ for i := 0; i < n; i++ {
+ go func() {
+ stop()
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+ select {
+ case <-c.Done():
+ if got := c.Err(); got != context.Canceled {
+ t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+ }
+ case <-time.After(time.Second):
+ t.Errorf("expected context to be canceled")
+ }
+}
+
+func TestNotifyContextStringer(t *testing.T) {
+ parent, cancelParent := context.WithCancel(context.Background())
+ defer cancelParent()
+ c, stop := NotifyContext(parent, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM)
+ defer stop()
+
+ want := `signal.NotifyContext(context.Background.WithCancel, [hangup interrupt terminated])`
+ if got := fmt.Sprint(c); got != want {
+ t.Errorf("c.String() = %q, want %q", got, want)
+ }
+}