]> Cypherpunks repositories - gostls13.git/commitdiff
context: use fewer goroutines in WithCancel/WithTimeout
authorRuss Cox <rsc@golang.org>
Thu, 19 Sep 2019 19:33:02 +0000 (15:33 -0400)
committerRuss Cox <rsc@golang.org>
Thu, 26 Sep 2019 16:25:30 +0000 (16:25 +0000)
If the parent context passed to WithCancel or WithTimeout
is a known context implementation (one created by this package),
we attach the child to the parent by editing data structures directly;
otherwise, for unknown parent implementations, we make a
goroutine that watches for the parent to finish and propagates
the cancellation.

A common problem with this scheme, before this CL, is that
users who write custom context implementations to manage
their value sets cause WithCancel/WithTimeout to start
goroutines that would have not been started before.

This CL changes the way we map a parent context back to the
underlying data structure. Instead of walking up through
known context implementations to reach the *cancelCtx,
we look up parent.Value(&cancelCtxKey) to return the
innermost *cancelCtx, which we use if it matches parent.Done().

This way, a custom context implementation wrapping a
*cancelCtx but not changing Done-ness (and not refusing
to return wrapped keys) will not require a goroutine anymore
in WithCancel/WithTimeout.

For #28728.

Change-Id: Idba2f435c81b19fe38d0dbf308458ca87c7381e9
Reviewed-on: https://go-review.googlesource.com/c/go/+/196521
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/context/context.go
src/context/context_test.go
src/context/x_test.go
src/go/build/deps_test.go

index 390f93c078f2dae96008f2745fb39153be7c50a8..b561968f31c7648fc465f694c745271bc8ae7569 100644 (file)
@@ -51,6 +51,7 @@ import (
        "errors"
        "internal/reflectlite"
        "sync"
+       "sync/atomic"
        "time"
 )
 
@@ -239,11 +240,24 @@ func newCancelCtx(parent Context) cancelCtx {
        return cancelCtx{Context: parent}
 }
 
+// goroutines counts the number of goroutines ever created; for testing.
+var goroutines int32
+
 // propagateCancel arranges for child to be canceled when parent is.
 func propagateCancel(parent Context, child canceler) {
-       if parent.Done() == nil {
+       done := parent.Done()
+       if done == nil {
                return // parent is never canceled
        }
+
+       select {
+       case <-done:
+               // parent is already canceled
+               child.cancel(false, parent.Err())
+               return
+       default:
+       }
+
        if p, ok := parentCancelCtx(parent); ok {
                p.mu.Lock()
                if p.err != nil {
@@ -257,6 +271,7 @@ func propagateCancel(parent Context, child canceler) {
                }
                p.mu.Unlock()
        } else {
+               atomic.AddInt32(&goroutines, +1)
                go func() {
                        select {
                        case <-parent.Done():
@@ -267,22 +282,31 @@ func propagateCancel(parent Context, child canceler) {
        }
 }
 
-// parentCancelCtx follows a chain of parent references until it finds a
-// *cancelCtx. This function understands how each of the concrete types in this
-// package represents its parent.
+// &cancelCtxKey is the key that a cancelCtx returns itself for.
+var cancelCtxKey int
+
+// parentCancelCtx returns the underlying *cancelCtx for parent.
+// It does this by looking up parent.Value(&cancelCtxKey) to find
+// the innermost enclosing *cancelCtx and then checking whether
+// parent.Done() matches that *cancelCtx. (If not, the *cancelCtx
+// has been wrapped in a custom implementation providing a
+// different done channel, in which case we should not bypass it.)
 func parentCancelCtx(parent Context) (*cancelCtx, bool) {
-       for {
-               switch c := parent.(type) {
-               case *cancelCtx:
-                       return c, true
-               case *timerCtx:
-                       return &c.cancelCtx, true
-               case *valueCtx:
-                       parent = c.Context
-               default:
-                       return nil, false
-               }
+       done := parent.Done()
+       if done == closedchan || done == nil {
+               return nil, false
+       }
+       p, ok := parent.Value(&cancelCtxKey).(*cancelCtx)
+       if !ok {
+               return nil, false
        }
+       p.mu.Lock()
+       ok = p.done == done
+       p.mu.Unlock()
+       if !ok {
+               return nil, false
+       }
+       return p, true
 }
 
 // removeChild removes a context from its parent.
@@ -323,6 +347,13 @@ type cancelCtx struct {
        err      error                 // set to non-nil by the first cancel call
 }
 
+func (c *cancelCtx) Value(key interface{}) interface{} {
+       if key == &cancelCtxKey {
+               return c
+       }
+       return c.Context.Value(key)
+}
+
 func (c *cancelCtx) Done() <-chan struct{} {
        c.mu.Lock()
        if c.done == nil {
index 0e69e2f6fdefa85a0ea8bc24b78daf0845bb4b1b..869b02c92ee99741d2bc780a3e967993bab8eb07 100644 (file)
@@ -10,6 +10,7 @@ import (
        "runtime"
        "strings"
        "sync"
+       "sync/atomic"
        "time"
 )
 
@@ -21,6 +22,7 @@ type testingT interface {
        Failed() bool
        Fatal(args ...interface{})
        Fatalf(format string, args ...interface{})
+       Helper()
        Log(args ...interface{})
        Logf(format string, args ...interface{})
        Name() string
@@ -401,7 +403,7 @@ func XTestAllocs(t testingT, testingShort func() bool, testingAllocsPerRun func(
                                c, _ := WithTimeout(bg, 15*time.Millisecond)
                                <-c.Done()
                        },
-                       limit:      8,
+                       limit:      12,
                        gccgoLimit: 15,
                },
                {
@@ -648,3 +650,81 @@ func XTestDeadlineExceededSupportsTimeout(t testingT) {
                t.Fatal("wrong value for timeout")
        }
 }
+
+type myCtx struct {
+       Context
+}
+
+type myDoneCtx struct {
+       Context
+}
+
+func (d *myDoneCtx) Done() <-chan struct{} {
+       c := make(chan struct{})
+       return c
+}
+
+func XTestCustomContextGoroutines(t testingT) {
+       g := atomic.LoadInt32(&goroutines)
+       checkNoGoroutine := func() {
+               t.Helper()
+               now := atomic.LoadInt32(&goroutines)
+               if now != g {
+                       t.Fatalf("%d goroutines created", now-g)
+               }
+       }
+       checkCreatedGoroutine := func() {
+               t.Helper()
+               now := atomic.LoadInt32(&goroutines)
+               if now != g+1 {
+                       t.Fatalf("%d goroutines created, want 1", now-g)
+               }
+               g = now
+       }
+
+       _, cancel0 := WithCancel(&myDoneCtx{Background()})
+       cancel0()
+       checkCreatedGoroutine()
+
+       _, cancel0 = WithTimeout(&myDoneCtx{Background()}, 1*time.Hour)
+       cancel0()
+       checkCreatedGoroutine()
+
+       checkNoGoroutine()
+       defer checkNoGoroutine()
+
+       ctx1, cancel1 := WithCancel(Background())
+       defer cancel1()
+       checkNoGoroutine()
+
+       ctx2 := &myCtx{ctx1}
+       ctx3, cancel3 := WithCancel(ctx2)
+       defer cancel3()
+       checkNoGoroutine()
+
+       _, cancel3b := WithCancel(&myDoneCtx{ctx2})
+       defer cancel3b()
+       checkCreatedGoroutine() // ctx1 is not providing Done, must not be used
+
+       ctx4, cancel4 := WithTimeout(ctx3, 1*time.Hour)
+       defer cancel4()
+       checkNoGoroutine()
+
+       ctx5, cancel5 := WithCancel(ctx4)
+       defer cancel5()
+       checkNoGoroutine()
+
+       cancel5()
+       checkNoGoroutine()
+
+       _, cancel6 := WithTimeout(ctx5, 1*time.Hour)
+       defer cancel6()
+       checkNoGoroutine()
+
+       // Check applied to cancelled context.
+       cancel6()
+       cancel1()
+       _, cancel7 := WithCancel(ctx5)
+       defer cancel7()
+       checkNoGoroutine()
+}
index d14b6f1a32b8d52a62adc9a8b3b99720ac8830a1..e85ef2d50e5fd8cffad74f6c17f06669f3eb3935 100644 (file)
@@ -27,3 +27,4 @@ func TestCancelRemoves(t *testing.T)                   { XTestCancelRemoves(t) }
 func TestWithCancelCanceledParent(t *testing.T)        { XTestWithCancelCanceledParent(t) }
 func TestWithValueChecksKey(t *testing.T)              { XTestWithValueChecksKey(t) }
 func TestDeadlineExceededSupportsTimeout(t *testing.T) { XTestDeadlineExceededSupportsTimeout(t) }
+func TestCustomContextGoroutines(t *testing.T)         { XTestCustomContextGoroutines(t) }
index c914d66b4dd232ce6b4fa9c29431e82e1849ae89..cbb0c59127d06470b640c8b014d47b4e53f91f33 100644 (file)
@@ -252,7 +252,7 @@ var pkgDeps = map[string][]string{
        "compress/gzip":                  {"L4", "compress/flate"},
        "compress/lzw":                   {"L4"},
        "compress/zlib":                  {"L4", "compress/flate"},
-       "context":                        {"errors", "internal/reflectlite", "sync", "time"},
+       "context":                        {"errors", "internal/reflectlite", "sync", "sync/atomic", "time"},
        "database/sql":                   {"L4", "container/list", "context", "database/sql/driver", "database/sql/internal"},
        "database/sql/driver":            {"L4", "context", "time", "database/sql/internal"},
        "debug/dwarf":                    {"L4"},