]> Cypherpunks repositories - gostls13.git/commitdiff
os/signal: add NotifyContext to cancel context using system signals
authorHenrique Vicente <henriquevicente@gmail.com>
Mon, 17 Feb 2020 01:22:47 +0000 (02:22 +0100)
committerIan Lance Taylor <iant@golang.org>
Tue, 15 Sep 2020 23:14:33 +0000 (23:14 +0000)
Fixes #37255

Change-Id: Ic0fde3498afefed6e4447f8476e4da7c1faa7145
Reviewed-on: https://go-review.googlesource.com/c/go/+/219640
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Giovanni Bajo <rasky@develer.com>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/os/signal/example_unix_test.go [new file with mode: 0644]
src/os/signal/signal.go
src/os/signal/signal_test.go

diff --git a/src/os/signal/example_unix_test.go b/src/os/signal/example_unix_test.go
new file mode 100644 (file)
index 0000000..a0af37a
--- /dev/null
@@ -0,0 +1,47 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
+
+package signal_test
+
+import (
+       "context"
+       "fmt"
+       "log"
+       "os"
+       "os/signal"
+       "time"
+)
+
+// This example passes a context with a signal to tell a blocking function that
+// it should abandon its work after a signal is received.
+func ExampleNotifyContext() {
+       ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
+       defer stop()
+
+       p, err := os.FindProcess(os.Getpid())
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       // On a Unix-like system, pressing Ctrl+C on a keyboard sends a
+       // SIGINT signal to the process of the program in execution.
+       //
+       // This example simulates that by sending a SIGINT signal to itself.
+       if err := p.Signal(os.Interrupt); err != nil {
+               log.Fatal(err)
+       }
+
+       select {
+       case <-time.After(time.Second):
+               fmt.Println("missed signal")
+       case <-ctx.Done():
+               fmt.Println(ctx.Err()) // prints "context canceled"
+               stop()                 // stop receiving signal notifications as soon as possible.
+       }
+
+       // Output:
+       // context canceled
+}
index 8e31aa26278b12be75bc32752c2b818ec902af69..4250a7e0de603959ee1b0037420a29d6e4c8b251 100644 (file)
@@ -5,6 +5,7 @@
 package signal
 
 import (
+       "context"
        "os"
        "sync"
 )
@@ -257,3 +258,77 @@ func process(sig os.Signal) {
                }
        }
 }
+
+// NotifyContext returns a copy of the parent context that is marked done
+// (its Done channel is closed) when one of the listed signals arrives,
+// when the returned stop function is called, or when the parent context's
+// Done channel is closed, whichever happens first.
+//
+// The stop function unregisters the signal behavior, which, like signal.Reset,
+// may restore the default behavior for a given signal. For example, the default
+// behavior of a Go program receiving os.Interrupt is to exit. Calling
+// NotifyContext(parent, os.Interrupt) will change the behavior to cancel
+// the returned context. Future interrupts received will not trigger the default
+// (exit) behavior until the returned stop function is called.
+//
+// The stop function releases resources associated with it, so code should
+// call stop as soon as the operations running in this Context complete and
+// signals no longer need to be diverted to the context.
+func NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
+       ctx, cancel := context.WithCancel(parent)
+       c := &signalCtx{
+               Context: ctx,
+               cancel:  cancel,
+               signals: signals,
+       }
+       c.ch = make(chan os.Signal, 1)
+       Notify(c.ch, c.signals...)
+       if ctx.Err() == nil {
+               go func() {
+                       select {
+                       case <-c.ch:
+                               c.cancel()
+                       case <-c.Done():
+                       }
+               }()
+       }
+       return c, c.stop
+}
+
+type signalCtx struct {
+       context.Context
+
+       cancel  context.CancelFunc
+       signals []os.Signal
+       ch      chan os.Signal
+}
+
+func (c *signalCtx) stop() {
+       c.cancel()
+       Stop(c.ch)
+}
+
+type stringer interface {
+       String() string
+}
+
+func (c *signalCtx) String() string {
+       var buf []byte
+       // We know that the type of c.Context is context.cancelCtx, and we know that the
+       // String method of cancelCtx returns a string that ends with ".WithCancel".
+       name := c.Context.(stringer).String()
+       name = name[:len(name)-len(".WithCancel")]
+       buf = append(buf, "signal.NotifyContext("+name...)
+       if len(c.signals) != 0 {
+               buf = append(buf, ", ["...)
+               for i, s := range c.signals {
+                       buf = append(buf, s.String()...)
+                       if i != len(c.signals)-1 {
+                               buf = append(buf, ' ')
+                       }
+               }
+               buf = append(buf, ']')
+       }
+       buf = append(buf, ')')
+       return string(buf)
+}
index f0e06b879504ce375282d6e3d57ab276d0605b31..23e33fe82bbf3140446119bd050af7fbe14536f1 100644 (file)
@@ -8,6 +8,7 @@ package signal
 
 import (
        "bytes"
+       "context"
        "flag"
        "fmt"
        "internal/testenv"
@@ -674,3 +675,164 @@ func TestTime(t *testing.T) {
        close(stop)
        <-done
 }
+
+func TestNotifyContext(t *testing.T) {
+       c, stop := NotifyContext(context.Background(), syscall.SIGINT)
+       defer stop()
+
+       if want, got := "signal.NotifyContext(context.Background, [interrupt])", fmt.Sprint(c); want != got {
+               t.Errorf("c.String() = %q, want %q", got, want)
+       }
+
+       syscall.Kill(syscall.Getpid(), syscall.SIGINT)
+       select {
+       case <-c.Done():
+               if got := c.Err(); got != context.Canceled {
+                       t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+               }
+       case <-time.After(time.Second):
+               t.Errorf("timed out waiting for context to be done after SIGINT")
+       }
+}
+
+func TestNotifyContextStop(t *testing.T) {
+       Ignore(syscall.SIGHUP)
+       if !Ignored(syscall.SIGHUP) {
+               t.Errorf("expected SIGHUP to be ignored when explicitly ignoring it.")
+       }
+
+       parent, cancelParent := context.WithCancel(context.Background())
+       defer cancelParent()
+       c, stop := NotifyContext(parent, syscall.SIGHUP)
+       defer stop()
+
+       // If we're being notified, then the signal should not be ignored.
+       if Ignored(syscall.SIGHUP) {
+               t.Errorf("expected SIGHUP to not be ignored.")
+       }
+
+       if want, got := "signal.NotifyContext(context.Background.WithCancel, [hangup])", fmt.Sprint(c); want != got {
+               t.Errorf("c.String() = %q, wanted %q", got, want)
+       }
+
+       stop()
+       select {
+       case <-c.Done():
+               if got := c.Err(); got != context.Canceled {
+                       t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+               }
+       case <-time.After(time.Second):
+               t.Errorf("timed out waiting for context to be done after calling stop")
+       }
+}
+
+func TestNotifyContextCancelParent(t *testing.T) {
+       parent, cancelParent := context.WithCancel(context.Background())
+       defer cancelParent()
+       c, stop := NotifyContext(parent, syscall.SIGINT)
+       defer stop()
+
+       if want, got := "signal.NotifyContext(context.Background.WithCancel, [interrupt])", fmt.Sprint(c); want != got {
+               t.Errorf("c.String() = %q, want %q", got, want)
+       }
+
+       cancelParent()
+       select {
+       case <-c.Done():
+               if got := c.Err(); got != context.Canceled {
+                       t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+               }
+       case <-time.After(time.Second):
+               t.Errorf("timed out waiting for parent context to be canceled")
+       }
+}
+
+func TestNotifyContextPrematureCancelParent(t *testing.T) {
+       parent, cancelParent := context.WithCancel(context.Background())
+       defer cancelParent()
+
+       cancelParent() // Prematurely cancel context before calling NotifyContext.
+       c, stop := NotifyContext(parent, syscall.SIGINT)
+       defer stop()
+
+       if want, got := "signal.NotifyContext(context.Background.WithCancel, [interrupt])", fmt.Sprint(c); want != got {
+               t.Errorf("c.String() = %q, want %q", got, want)
+       }
+
+       select {
+       case <-c.Done():
+               if got := c.Err(); got != context.Canceled {
+                       t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+               }
+       case <-time.After(time.Second):
+               t.Errorf("timed out waiting for parent context to be canceled")
+       }
+}
+
+func TestNotifyContextSimultaneousNotifications(t *testing.T) {
+       c, stop := NotifyContext(context.Background(), syscall.SIGINT)
+       defer stop()
+
+       if want, got := "signal.NotifyContext(context.Background, [interrupt])", fmt.Sprint(c); want != got {
+               t.Errorf("c.String() = %q, want %q", got, want)
+       }
+
+       var wg sync.WaitGroup
+       n := 10
+       wg.Add(n)
+       for i := 0; i < n; i++ {
+               go func() {
+                       syscall.Kill(syscall.Getpid(), syscall.SIGINT)
+                       wg.Done()
+               }()
+       }
+       wg.Wait()
+       select {
+       case <-c.Done():
+               if got := c.Err(); got != context.Canceled {
+                       t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+               }
+       case <-time.After(time.Second):
+               t.Errorf("expected context to be canceled")
+       }
+}
+
+func TestNotifyContextSimultaneousStop(t *testing.T) {
+       c, stop := NotifyContext(context.Background(), syscall.SIGINT)
+       defer stop()
+
+       if want, got := "signal.NotifyContext(context.Background, [interrupt])", fmt.Sprint(c); want != got {
+               t.Errorf("c.String() = %q, want %q", got, want)
+       }
+
+       var wg sync.WaitGroup
+       n := 10
+       wg.Add(n)
+       for i := 0; i < n; i++ {
+               go func() {
+                       stop()
+                       wg.Done()
+               }()
+       }
+       wg.Wait()
+       select {
+       case <-c.Done():
+               if got := c.Err(); got != context.Canceled {
+                       t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
+               }
+       case <-time.After(time.Second):
+               t.Errorf("expected context to be canceled")
+       }
+}
+
+func TestNotifyContextStringer(t *testing.T) {
+       parent, cancelParent := context.WithCancel(context.Background())
+       defer cancelParent()
+       c, stop := NotifyContext(parent, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM)
+       defer stop()
+
+       want := `signal.NotifyContext(context.Background.WithCancel, [hangup interrupt terminated])`
+       if got := fmt.Sprint(c); got != want {
+               t.Errorf("c.String() = %q, want %q", got, want)
+       }
+}