]> Cypherpunks repositories - gostls13.git/commitdiff
testing: add T.Context method
authorBrad Fitzpatrick <bradfitz@golang.org>
Sat, 22 Oct 2016 14:25:21 +0000 (07:25 -0700)
committerRuss Cox <rsc@golang.org>
Thu, 3 Nov 2016 21:14:30 +0000 (21:14 +0000)
From the doc comment:

Context returns the context for the current test or benchmark.
The context is cancelled when the test or benchmark finishes.
A goroutine started during a test or benchmark can wait for the
context's Done channel to become readable as a signal that the
test or benchmark is over, so that the goroutine can exit.

Fixes #16221.
Fixes #17552.

Change-Id: I657df946be2c90048cc74615436c77c7d9d1226c
Reviewed-on: https://go-review.googlesource.com/31724
Reviewed-by: Rob Pike <r@golang.org>
src/go/build/deps_test.go
src/testing/benchmark.go
src/testing/sub_test.go
src/testing/testing.go
src/testing/testing_test.go

index 17ddf13c9012174eb46055c6db6d36ea5b4f8ba2..d4877f7aeb5f6643164166eee88e5d0c43cd83e4 100644 (file)
@@ -178,7 +178,7 @@ var pkgDeps = map[string][]string{
        "runtime/trace":                {"L0"},
        "text/tabwriter":               {"L2"},
 
-       "testing":          {"L2", "flag", "fmt", "internal/race", "os", "runtime/debug", "runtime/pprof", "runtime/trace", "time"},
+       "testing":          {"L2", "context", "flag", "fmt", "internal/race", "os", "runtime/debug", "runtime/pprof", "runtime/trace", "time"},
        "testing/iotest":   {"L2", "log"},
        "testing/quick":    {"L2", "flag", "fmt", "reflect"},
        "internal/testenv": {"L2", "OS", "flag", "testing", "syscall"},
index c033ce5fecb1286965089adde2a4919ff2a4da84..b1c6d2eff045b8c347ff2bc832fd6f25587148b1 100644 (file)
@@ -5,6 +5,7 @@
 package testing
 
 import (
+       "context"
        "flag"
        "fmt"
        "internal/race"
@@ -127,6 +128,9 @@ func (b *B) nsPerOp() int64 {
 
 // runN runs a single benchmark for the specified number of iterations.
 func (b *B) runN(n int) {
+       b.ctx, b.cancel = context.WithCancel(b.parentContext())
+       defer b.cancel()
+
        benchmarkLock.Lock()
        defer benchmarkLock.Unlock()
        // Try to get a comparable environment for each run
index 2a24aaacfd72a031deb98832f89fcf99d1c0d099..563e8656c60cd087a9a13904c7885dae706c64b7 100644 (file)
@@ -6,6 +6,7 @@ package testing
 
 import (
        "bytes"
+       "context"
        "regexp"
        "strings"
        "sync/atomic"
@@ -277,28 +278,33 @@ func TestTRun(t *T) {
                ok:     true,
                maxPar: 4,
                f: func(t *T) {
-                       t.Parallel()
-                       for i := 0; i < 12; i++ {
-                               t.Run("a", func(t *T) {
-                                       t.Parallel()
-                                       time.Sleep(time.Nanosecond)
-                                       for i := 0; i < 12; i++ {
-                                               t.Run("b", func(t *T) {
-                                                       time.Sleep(time.Nanosecond)
-                                                       for i := 0; i < 12; i++ {
-                                                               t.Run("c", func(t *T) {
-                                                                       t.Parallel()
-                                                                       time.Sleep(time.Nanosecond)
-                                                                       t.Run("d1", func(t *T) {})
-                                                                       t.Run("d2", func(t *T) {})
-                                                                       t.Run("d3", func(t *T) {})
-                                                                       t.Run("d4", func(t *T) {})
-                                                               })
-                                                       }
-                                               })
-                                       }
-                               })
-                       }
+                       // t.Parallel doesn't work in the pseudo-T we start with:
+                       // it leaks a goroutine.
+                       // Call t.Run to get a real one.
+                       t.Run("X", func(t *T) {
+                               t.Parallel()
+                               for i := 0; i < 12; i++ {
+                                       t.Run("a", func(t *T) {
+                                               t.Parallel()
+                                               time.Sleep(time.Nanosecond)
+                                               for i := 0; i < 12; i++ {
+                                                       t.Run("b", func(t *T) {
+                                                               time.Sleep(time.Nanosecond)
+                                                               for i := 0; i < 12; i++ {
+                                                                       t.Run("c", func(t *T) {
+                                                                               t.Parallel()
+                                                                               time.Sleep(time.Nanosecond)
+                                                                               t.Run("d1", func(t *T) {})
+                                                                               t.Run("d2", func(t *T) {})
+                                                                               t.Run("d3", func(t *T) {})
+                                                                               t.Run("d4", func(t *T) {})
+                                                                       })
+                                                               }
+                                                       })
+                                               }
+                                       })
+                               }
+                       })
                },
        }, {
                desc:   "skip output",
@@ -341,6 +347,7 @@ func TestTRun(t *T) {
                        },
                        context: ctx,
                }
+               root.ctx, root.cancel = context.WithCancel(context.Background())
                ok := root.Run(tc.desc, tc.f)
                ctx.release()
 
index 31290aaec04fab82c139ec185a5e58e35872b837..01f5da31d7179fd7a38c07be641749fd17321b01 100644 (file)
@@ -204,6 +204,7 @@ package testing
 
 import (
        "bytes"
+       "context"
        "errors"
        "flag"
        "fmt"
@@ -261,12 +262,14 @@ type common struct {
        mu         sync.RWMutex // guards output, failed, and done.
        output     []byte       // Output generated by test or benchmark.
        w          io.Writer    // For flushToParent.
-       chatty     bool         // A copy of the chatty flag.
-       ran        bool         // Test or benchmark (or one of its subtests) was executed.
-       failed     bool         // Test or benchmark has failed.
-       skipped    bool         // Test of benchmark has been skipped.
-       finished   bool         // Test function has completed.
-       done       bool         // Test is finished and all subtests have completed.
+       ctx        context.Context
+       cancel     context.CancelFunc
+       chatty     bool // A copy of the chatty flag.
+       ran        bool // Test or benchmark (or one of its subtests) was executed.
+       failed     bool // Test or benchmark has failed.
+       skipped    bool // Test of benchmark has been skipped.
+       finished   bool // Test function has completed.
+       done       bool // Test is finished and all subtests have completed.
        hasSub     bool
        raceErrors int // number of races detected during test
 
@@ -280,6 +283,13 @@ type common struct {
        sub      []*T      // Queue of subtests to be run in parallel.
 }
 
+func (c *common) parentContext() context.Context {
+       if c == nil || c.parent == nil || c.parent.ctx == nil {
+               return context.Background()
+       }
+       return c.parent.ctx
+}
+
 // Short reports whether the -test.short flag is set.
 func Short() bool {
        return *short
@@ -376,6 +386,7 @@ func fmtDuration(d time.Duration) string {
 
 // TB is the interface common to T and B.
 type TB interface {
+       Context() context.Context
        Error(args ...interface{})
        Errorf(format string, args ...interface{})
        Fail()
@@ -423,6 +434,15 @@ func (c *common) Name() string {
        return c.name
 }
 
+// Context returns the context for the current test or benchmark.
+// The context is cancelled when the test or benchmark finishes.
+// A goroutine started during a test or benchmark can wait for the
+// context's Done channel to become readable as a signal that the
+// test or benchmark is over, so that the goroutine can exit.
+func (c *common) Context() context.Context {
+       return c.ctx
+}
+
 func (c *common) setRan() {
        if c.parent != nil {
                c.parent.setRan()
@@ -599,6 +619,9 @@ type InternalTest struct {
 }
 
 func tRunner(t *T, fn func(t *T)) {
+       t.ctx, t.cancel = context.WithCancel(t.parentContext())
+       defer t.cancel()
+
        // When this goroutine is done, either because fn(t)
        // returned normally or because a test failure triggered
        // a call to runtime.Goexit, record the duration and send
index 45e44683b43c017cfb4512a5d8b81e19308b9d80..9954f9af8cda59b5b338edff38e88e12948cb16c 100644 (file)
@@ -5,14 +5,42 @@
 package testing_test
 
 import (
+       "fmt"
        "os"
+       "runtime"
        "testing"
+       "time"
 )
 
-// This is exactly what a test would do without a TestMain.
-// It's here only so that there is at least one package in the
-// standard library with a TestMain, so that code is executed.
-
 func TestMain(m *testing.M) {
-       os.Exit(m.Run())
+       g0 := runtime.NumGoroutine()
+
+       code := m.Run()
+       if code != 0 {
+               os.Exit(code)
+       }
+
+       // Check that there are no goroutines left behind.
+       t0 := time.Now()
+       stacks := make([]byte, 1<<20)
+       for {
+               g1 := runtime.NumGoroutine()
+               if g1 == g0 {
+                       return
+               }
+               stacks = stacks[:runtime.Stack(stacks, true)]
+               time.Sleep(50 * time.Millisecond)
+               if time.Since(t0) > 2*time.Second {
+                       fmt.Fprintf(os.Stderr, "Unexpected leftover goroutines detected: %v -> %v\n%s\n", g0, g1, stacks)
+                       os.Exit(1)
+               }
+       }
+}
+
+func TestContextCancel(t *testing.T) {
+       ctx := t.Context()
+       // Tests we don't leak this goroutine:
+       go func() {
+               <-ctx.Done()
+       }()
 }