From: Damien Neil Date: Wed, 5 Apr 2023 23:06:36 +0000 (-0700) Subject: context: add AfterFunc X-Git-Tag: go1.21rc1~843 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=54d429999c00f277c80b5698fb78c7b71f91aad9;p=gostls13.git context: add AfterFunc Add an AfterFunc function, which registers a function to run after a context has been canceled. Add support for contexts that implement an AfterFunc method, which can be used to avoid the need to start a new goroutine watching the Done channel when propagating cancellation signals. Fixes #57928 Change-Id: If0b2cdcc4332961276a1ff57311338e74916259c Reviewed-on: https://go-review.googlesource.com/c/go/+/482695 TryBot-Result: Gopher Robot Run-TryBot: Damien Neil Reviewed-by: Sameer Ajmani --- diff --git a/api/next/57928.txt b/api/next/57928.txt new file mode 100644 index 0000000000..5b85e74cdb --- /dev/null +++ b/api/next/57928.txt @@ -0,0 +1 @@ +pkg context, func AfterFunc(Context, func()) func() bool #57928 diff --git a/src/context/afterfunc_test.go b/src/context/afterfunc_test.go new file mode 100644 index 0000000000..71f639a345 --- /dev/null +++ b/src/context/afterfunc_test.go @@ -0,0 +1,141 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package context_test + +import ( + "context" + "sync" + "testing" + "time" +) + +// afterFuncContext is a context that's not one of the types +// defined in context.go, that supports registering AfterFuncs. +type afterFuncContext struct { + mu sync.Mutex + afterFuncs map[*struct{}]func() + done chan struct{} + err error +} + +func newAfterFuncContext() context.Context { + return &afterFuncContext{} +} + +func (c *afterFuncContext) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +func (c *afterFuncContext) Done() <-chan struct{} { + c.mu.Lock() + defer c.mu.Unlock() + if c.done == nil { + c.done = make(chan struct{}) + } + return c.done +} + +func (c *afterFuncContext) Err() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.err +} + +func (c *afterFuncContext) Value(key any) any { + return nil +} + +func (c *afterFuncContext) AfterFunc(f func()) func() bool { + c.mu.Lock() + defer c.mu.Unlock() + k := &struct{}{} + if c.afterFuncs == nil { + c.afterFuncs = make(map[*struct{}]func()) + } + c.afterFuncs[k] = f + return func() bool { + c.mu.Lock() + defer c.mu.Unlock() + _, ok := c.afterFuncs[k] + delete(c.afterFuncs, k) + return ok + } +} + +func (c *afterFuncContext) cancel(err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.err != nil { + return + } + c.err = err + for _, f := range c.afterFuncs { + go f() + } + c.afterFuncs = nil +} + +func TestCustomContextAfterFuncCancel(t *testing.T) { + ctx0 := &afterFuncContext{} + ctx1, cancel := context.WithCancel(ctx0) + defer cancel() + ctx0.cancel(context.Canceled) + <-ctx1.Done() +} + +func TestCustomContextAfterFuncTimeout(t *testing.T) { + ctx0 := &afterFuncContext{} + ctx1, cancel := context.WithTimeout(ctx0, veryLongDuration) + defer cancel() + ctx0.cancel(context.Canceled) + <-ctx1.Done() +} + +func TestCustomContextAfterFuncAfterFunc(t *testing.T) { + ctx0 := &afterFuncContext{} + donec := make(chan struct{}) + stop := context.AfterFunc(ctx0, func() { + close(donec) + }) + defer stop() + ctx0.cancel(context.Canceled) + <-donec +} + +func TestCustomContextAfterFuncUnregisterCancel(t *testing.T) { + ctx0 := &afterFuncContext{} + _, cancel := context.WithCancel(ctx0) + if got, want := len(ctx0.afterFuncs), 1; got != want { + t.Errorf("after WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want) + } + cancel() + if got, want := len(ctx0.afterFuncs), 0; got != want { + t.Errorf("after canceling WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want) + } +} + +func TestCustomContextAfterFuncUnregisterTimeout(t *testing.T) { + ctx0 := &afterFuncContext{} + _, cancel := context.WithTimeout(ctx0, veryLongDuration) + if got, want := len(ctx0.afterFuncs), 1; got != want { + t.Errorf("after WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want) + } + cancel() + if got, want := len(ctx0.afterFuncs), 0; got != want { + t.Errorf("after canceling WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want) + } +} + +func TestCustomContextAfterFuncUnregisterAfterFunc(t *testing.T) { + ctx0 := &afterFuncContext{} + stop := context.AfterFunc(ctx0, func() {}) + if got, want := len(ctx0.afterFuncs), 1; got != want { + t.Errorf("after AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want) + } + stop() + if got, want := len(ctx0.afterFuncs), 0; got != want { + t.Errorf("after stopping AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want) + } +} diff --git a/src/context/context.go b/src/context/context.go index 7227099889..4c0ba7c1d7 100644 --- a/src/context/context.go +++ b/src/context/context.go @@ -269,8 +269,8 @@ func withCancel(parent Context) *cancelCtx { if parent == nil { panic("cannot create context from nil parent") } - c := &cancelCtx{Context: parent} - propagateCancel(parent, c) + c := &cancelCtx{} + c.propagateCancel(parent, c) return c } @@ -289,48 +289,72 @@ func Cause(c Context) error { return nil } -// goroutines counts the number of goroutines ever created; for testing. -var goroutines atomic.Int32 - -// propagateCancel arranges for child to be canceled when parent is. -func propagateCancel(parent Context, child canceler) { - done := parent.Done() - if done == nil { - return // parent is never canceled +// AfterFunc arranges to call f in its own goroutine after ctx is done +// (cancelled or timed out). +// If ctx is already done, AfterFunc calls f immediately in its own goroutine. +// +// Multiple calls to AfterFunc on a context operate independently; +// one does not replace another. +// +// Calling the returned stop function stops the association of ctx with f. +// It returns true if the call stopped f from being run. +// If stop returns false, +// either the context is done and f has been started in its own goroutine; +// or f was already stopped. +// The stop function does not wait for f to complete before returning. +// If the caller needs to know whether f is completed, +// it must coordinate with f explicitly. +// +// If ctx has a "AfterFunc(func()) func() bool" method, +// AfterFunc will use it to schedule the call. +func AfterFunc(ctx Context, f func()) (stop func() bool) { + a := &afterFuncCtx{ + f: f, + } + a.cancelCtx.propagateCancel(ctx, a) + return func() bool { + stopped := false + a.once.Do(func() { + stopped = true + }) + if stopped { + a.cancel(true, Canceled, nil) + } + return stopped } +} - select { - case <-done: - // parent is already canceled - child.cancel(false, parent.Err(), Cause(parent)) - return - default: - } +type afterFuncer interface { + AfterFunc(func()) func() bool +} - if p, ok := parentCancelCtx(parent); ok { - p.mu.Lock() - if p.err != nil { - // parent has already been canceled - child.cancel(false, p.err, p.cause) - } else { - if p.children == nil { - p.children = make(map[canceler]struct{}) - } - p.children[child] = struct{}{} - } - p.mu.Unlock() - } else { - goroutines.Add(1) - go func() { - select { - case <-parent.Done(): - child.cancel(false, parent.Err(), Cause(parent)) - case <-child.Done(): - } - }() +type afterFuncCtx struct { + cancelCtx + once sync.Once // either starts running f or stops f from running + f func() +} + +func (a *afterFuncCtx) cancel(removeFromParent bool, err, cause error) { + a.cancelCtx.cancel(false, err, cause) + if removeFromParent { + removeChild(a.Context, a) } + a.once.Do(func() { + go a.f() + }) } +// A stopCtx is used as the parent context of a cancelCtx when +// an AfterFunc has been registered with the parent. +// It holds the stop function used to unregister the AfterFunc. +type stopCtx struct { + Context + stop func() bool +} + +// goroutines counts the number of goroutines ever created; for testing. +var goroutines atomic.Int32 + // &cancelCtxKey is the key that a cancelCtx returns itself for. var cancelCtxKey int @@ -358,6 +382,10 @@ func parentCancelCtx(parent Context) (*cancelCtx, bool) { // removeChild removes a context from its parent. func removeChild(parent Context, child canceler) { + if s, ok := parent.(stopCtx); ok { + s.stop() + return + } p, ok := parentCancelCtx(parent) if !ok { return @@ -424,6 +452,64 @@ func (c *cancelCtx) Err() error { return err } +// propagateCancel arranges for child to be canceled when parent is. +// It sets the parent context of cancelCtx. +func (c *cancelCtx) propagateCancel(parent Context, child canceler) { + c.Context = parent + + done := parent.Done() + if done == nil { + return // parent is never canceled + } + + select { + case <-done: + // parent is already canceled + child.cancel(false, parent.Err(), Cause(parent)) + return + default: + } + + if p, ok := parentCancelCtx(parent); ok { + // parent is a *cancelCtx, or derives from one. + p.mu.Lock() + if p.err != nil { + // parent has already been canceled + child.cancel(false, p.err, p.cause) + } else { + if p.children == nil { + p.children = make(map[canceler]struct{}) + } + p.children[child] = struct{}{} + } + p.mu.Unlock() + return + } + + if a, ok := parent.(afterFuncer); ok { + // parent implements an AfterFunc method. + c.mu.Lock() + stop := a.AfterFunc(func() { + child.cancel(false, parent.Err(), Cause(parent)) + }) + c.Context = stopCtx{ + Context: parent, + stop: stop, + } + c.mu.Unlock() + return + } + + goroutines.Add(1) + go func() { + select { + case <-parent.Done(): + child.cancel(false, parent.Err(), Cause(parent)) + case <-child.Done(): + } + }() +} + type stringer interface { String() string } @@ -533,10 +619,9 @@ func WithDeadlineCause(parent Context, d time.Time, cause error) (Context, Cance return WithCancel(parent) } c := &timerCtx{ - cancelCtx: cancelCtx{Context: parent}, - deadline: d, + deadline: d, } - propagateCancel(parent, c) + c.cancelCtx.propagateCancel(parent, c) dur := time.Until(d) if dur <= 0 { c.cancel(true, DeadlineExceeded, cause) // deadline has already passed diff --git a/src/context/context_test.go b/src/context/context_test.go index 74738fd316..57066c9685 100644 --- a/src/context/context_test.go +++ b/src/context/context_test.go @@ -44,12 +44,15 @@ func XTestParentFinishesChild(t testingT) { // Context tree: // parent -> cancelChild // parent -> valueChild -> timerChild + // parent -> afterChild parent, cancel := WithCancel(Background()) cancelChild, stop := WithCancel(parent) defer stop() valueChild := WithValue(parent, "key", "value") timerChild, stop := WithTimeout(valueChild, veryLongDuration) defer stop() + afterStop := AfterFunc(parent, func() {}) + defer afterStop() select { case x := <-parent.Done(): @@ -63,13 +66,20 @@ func XTestParentFinishesChild(t testingT) { default: } - // The parent's children should contain the two cancelable children. + // The parent's children should contain the three cancelable children. pc := parent.(*cancelCtx) cc := cancelChild.(*cancelCtx) tc := timerChild.(*timerCtx) pc.mu.Lock() - if len(pc.children) != 2 || !contains(pc.children, cc) || !contains(pc.children, tc) { - t.Errorf("bad linkage: pc.children = %v, want %v and %v", + var ac *afterFuncCtx + for c := range pc.children { + if a, ok := c.(*afterFuncCtx); ok { + ac = a + break + } + } + if len(pc.children) != 3 || !contains(pc.children, cc) || !contains(pc.children, tc) || ac == nil { + t.Errorf("bad linkage: pc.children = %v, want %v, %v, and an afterFunc", pc.children, cc, tc) } pc.mu.Unlock() @@ -80,6 +90,9 @@ func XTestParentFinishesChild(t testingT) { if p, ok := parentCancelCtx(tc.Context); !ok || p != pc { t.Errorf("bad linkage: parentCancelCtx(timerChild.Context) = %v, %v want %v, true", p, ok, pc) } + if p, ok := parentCancelCtx(ac.Context); !ok || p != pc { + t.Errorf("bad linkage: parentCancelCtx(afterChild.Context) = %v, %v want %v, true", p, ok, pc) + } cancel() @@ -197,6 +210,13 @@ func XTestCancelRemoves(t testingT) { checkChildren("with WithTimeout child ", ctx, 1) cancel() checkChildren("after canceling WithTimeout child", ctx, 0) + + ctx, _ = WithCancel(Background()) + checkChildren("after creation", ctx, 0) + stop := AfterFunc(ctx, func() {}) + checkChildren("with AfterFunc child ", ctx, 1) + stop() + checkChildren("after stopping AfterFunc child ", ctx, 0) } type myCtx struct { diff --git a/src/context/example_test.go b/src/context/example_test.go index 8dad0a4220..38549a12de 100644 --- a/src/context/example_test.go +++ b/src/context/example_test.go @@ -6,7 +6,10 @@ package context_test import ( "context" + "errors" "fmt" + "net" + "sync" "time" ) @@ -118,3 +121,104 @@ func ExampleWithValue() { // found value: Go // key not found: color } + +// This example uses AfterFunc to define a function which waits on a sync.Cond, +// stopping the wait when a context is canceled. +func ExampleAfterFunc_cond() { + waitOnCond := func(ctx context.Context, cond *sync.Cond) error { + stopf := context.AfterFunc(ctx, cond.Broadcast) + defer stopf() + cond.Wait() + return ctx.Err() + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + var mu sync.Mutex + cond := sync.NewCond(&mu) + + mu.Lock() + err := waitOnCond(ctx, cond) + fmt.Println(err) + + // Output: + // context deadline exceeded +} + +// This example uses AfterFunc to define a function which reads from a net.Conn, +// stopping the read when a context is canceled. +func ExampleAfterFunc_connection() { + readFromConn := func(ctx context.Context, conn net.Conn, b []byte) (n int, err error) { + stopc := make(chan struct{}) + stop := context.AfterFunc(ctx, func() { + conn.SetReadDeadline(time.Now()) + close(stopc) + }) + n, err = conn.Read(b) + if !stop() { + // The AfterFunc was started. + // Wait for it to complete, and reset the Conn's deadline. + <-stopc + conn.SetReadDeadline(time.Time{}) + return n, ctx.Err() + } + return n, err + } + + listener, err := net.Listen("tcp", ":0") + if err != nil { + fmt.Println(err) + return + } + defer listener.Close() + + conn, err := net.Dial(listener.Addr().Network(), listener.Addr().String()) + if err != nil { + fmt.Println(err) + return + } + defer conn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + b := make([]byte, 1024) + _, err = readFromConn(ctx, conn, b) + fmt.Println(err) + + // Output: + // context deadline exceeded +} + +// This example uses AfterFunc to define a function which combines +// the cancellation signals of two Contexts. +func ExampleAfterFunc_merge() { + // mergeCancel returns a context that contains the values of ctx, + // and which is canceled when either ctx or cancelCtx is canceled. + mergeCancel := func(ctx, cancelCtx context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancelCause(ctx) + stop := context.AfterFunc(cancelCtx, func() { + cancel(context.Cause(cancelCtx)) + }) + return ctx, func() { + stop() + cancel(context.Canceled) + } + } + + ctx1, cancel1 := context.WithCancelCause(context.Background()) + defer cancel1(errors.New("ctx1 canceled")) + + ctx2, cancel2 := context.WithCancelCause(context.Background()) + + mergedCtx, mergedCancel := mergeCancel(ctx1, ctx2) + defer mergedCancel() + + cancel2(errors.New("ctx2 canceled")) + <-mergedCtx.Done() + fmt.Println(context.Cause(mergedCtx)) + + // Output: + // ctx2 canceled +} diff --git a/src/context/x_test.go b/src/context/x_test.go index 957590d01b..bf0af674c1 100644 --- a/src/context/x_test.go +++ b/src/context/x_test.go @@ -502,6 +502,24 @@ func TestWithCancelCanceledParent(t *testing.T) { } } +func TestWithCancelSimultaneouslyCanceledParent(t *testing.T) { + // Cancel the parent goroutine concurrently with creating a child. + for i := 0; i < 100; i++ { + parent, pcancel := WithCancelCause(Background()) + cause := fmt.Errorf("Because!") + go pcancel(cause) + + c, _ := WithCancel(parent) + <-c.Done() + if got, want := c.Err(), Canceled; got != want { + t.Errorf("child not canceled; got = %v, want = %v", got, want) + } + if got, want := Cause(c), cause; got != want { + t.Errorf("child has wrong cause; got = %v, want = %v", got, want) + } + } +} + func TestWithValueChecksKey(t *testing.T) { panicVal := recoveredValue(func() { WithValue(Background(), []byte("foo"), "bar") }) if panicVal == nil { @@ -816,3 +834,123 @@ func TestWithoutCancel(t *testing.T) { t.Errorf("ctx.Value(%q) = %q want %q", key, v, value) } } + +type customDoneContext struct { + Context + donec chan struct{} +} + +func (c *customDoneContext) Done() <-chan struct{} { + return c.donec +} + +func TestCustomContextPropagation(t *testing.T) { + cause := errors.New("TestCustomContextPropagation") + donec := make(chan struct{}) + ctx1, cancel1 := WithCancelCause(Background()) + ctx2 := &customDoneContext{ + Context: ctx1, + donec: donec, + } + ctx3, cancel3 := WithCancel(ctx2) + defer cancel3() + + cancel1(cause) + close(donec) + + <-ctx3.Done() + if got, want := ctx3.Err(), Canceled; got != want { + t.Errorf("child not canceled; got = %v, want = %v", got, want) + } + if got, want := Cause(ctx3), cause; got != want { + t.Errorf("child has wrong cause; got = %v, want = %v", got, want) + } +} + +func TestAfterFuncCalledAfterCancel(t *testing.T) { + ctx, cancel := WithCancel(Background()) + donec := make(chan struct{}) + stop := AfterFunc(ctx, func() { + close(donec) + }) + select { + case <-donec: + t.Fatalf("AfterFunc called before context is done") + case <-time.After(shortDuration): + } + cancel() + select { + case <-donec: + case <-time.After(veryLongDuration): + t.Fatalf("AfterFunc not called after context is canceled") + } + if stop() { + t.Fatalf("stop() = true, want false") + } +} + +func TestAfterFuncCalledAfterTimeout(t *testing.T) { + ctx, cancel := WithTimeout(Background(), shortDuration) + defer cancel() + donec := make(chan struct{}) + AfterFunc(ctx, func() { + close(donec) + }) + select { + case <-donec: + case <-time.After(veryLongDuration): + t.Fatalf("AfterFunc not called after context is canceled") + } +} + +func TestAfterFuncCalledImmediately(t *testing.T) { + ctx, cancel := WithCancel(Background()) + cancel() + donec := make(chan struct{}) + AfterFunc(ctx, func() { + close(donec) + }) + select { + case <-donec: + case <-time.After(veryLongDuration): + t.Fatalf("AfterFunc not called for already-canceled context") + } +} + +func TestAfterFuncNotCalledAfterStop(t *testing.T) { + ctx, cancel := WithCancel(Background()) + donec := make(chan struct{}) + stop := AfterFunc(ctx, func() { + close(donec) + }) + if !stop() { + t.Fatalf("stop() = false, want true") + } + cancel() + select { + case <-donec: + t.Fatalf("AfterFunc called for already-canceled context") + case <-time.After(shortDuration): + } + if stop() { + t.Fatalf("stop() = true, want false") + } +} + +// This test verifies that cancelling a context does not block waiting for AfterFuncs to finish. +func TestAfterFuncCalledAsynchronously(t *testing.T) { + ctx, cancel := WithCancel(Background()) + donec := make(chan struct{}) + stop := AfterFunc(ctx, func() { + // The channel send blocks until donec is read from. + donec <- struct{}{} + }) + defer stop() + cancel() + // After cancel returns, read from donec and unblock the AfterFunc. + select { + case <-donec: + case <-time.After(veryLongDuration): + t.Fatalf("AfterFunc not called after context is canceled") + } +}