// the returned context. Future interrupts received will not trigger the default
// (exit) behavior until the returned stop function is called.
//
+// If a signal causes the returned context to be canceled, calling
+// [context.Cause] on it will return an error describing the signal.
+//
// The stop function releases resources associated with it, so code should
// call stop as soon as the operations running in this Context complete and
// signals no longer need to be diverted to the context.
func NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
- ctx, cancel := context.WithCancel(parent)
+ ctx, cancel := context.WithCancelCause(parent)
c := &signalCtx{
Context: ctx,
cancel: cancel,
if ctx.Err() == nil {
go func() {
select {
- case <-c.ch:
- c.cancel()
+ case s := <-c.ch:
+ c.cancel(signalError(s.String() + " signal received"))
case <-c.Done():
}
}()
type signalCtx struct {
context.Context
- cancel context.CancelFunc
+ cancel context.CancelCauseFunc
signals []os.Signal
ch chan os.Signal
}
func (c *signalCtx) stop() {
- c.cancel()
+ c.cancel(nil)
Stop(c.ch)
}
buf = append(buf, ')')
return string(buf)
}
+
+type signalError string
+
+func (s signalError) Error() string {
+ return string(s)
+}
import (
"bytes"
"context"
+ "errors"
"flag"
"fmt"
"internal/testenv"
}
wg.Wait()
<-ctx.Done()
+ if got, want := context.Cause(ctx).Error(), "interrupt signal received"; got != want {
+ t.Errorf("context.Cause(ctx) = %q, want %q", got, want)
+ }
fmt.Println("received SIGINT")
// Sleep to give time to simultaneous signals to reach the process.
// These signals must be ignored given stop() is not called on this code.
if got := c.Err(); got != context.Canceled {
t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
}
+ if got := context.Cause(c); got != context.Canceled {
+ t.Errorf("context.Cause(c.Err()) = %q, want %q", got, context.Canceled)
+ }
}
func TestNotifyContextCancelParent(t *testing.T) {
- parent, cancelParent := context.WithCancel(context.Background())
- defer cancelParent()
+ parent, cancelParent := context.WithCancelCause(context.Background())
+ parentCause := errors.New("parent canceled")
+ defer cancelParent(parentCause)
c, stop := NotifyContext(parent, syscall.SIGINT)
defer stop()
t.Errorf("c.String() = %q, want %q", got, want)
}
- cancelParent()
+ cancelParent(parentCause)
<-c.Done()
if got := c.Err(); got != context.Canceled {
t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
}
+ if got := context.Cause(c); got != parentCause {
+ t.Errorf("context.Cause(c) = %q, want %q", got, parentCause)
+ }
}
func TestNotifyContextPrematureCancelParent(t *testing.T) {
- parent, cancelParent := context.WithCancel(context.Background())
- defer cancelParent()
+ parent, cancelParent := context.WithCancelCause(context.Background())
+ parentCause := errors.New("parent canceled")
+ defer cancelParent(parentCause)
- cancelParent() // Prematurely cancel context before calling NotifyContext.
+ cancelParent(parentCause) // Prematurely cancel context before calling NotifyContext.
c, stop := NotifyContext(parent, syscall.SIGINT)
defer stop()
if got := c.Err(); got != context.Canceled {
t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
}
+ if got := context.Cause(c); got != parentCause {
+ t.Errorf("context.Cause(c) = %q, want %q", got, parentCause)
+ }
}
func TestNotifyContextSimultaneousStop(t *testing.T) {