]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: don't cancel Dials when requests are canceled
authorDamien Neil <dneil@google.com>
Thu, 4 Apr 2024 18:01:28 +0000 (11:01 -0700)
committerDamien Neil <dneil@google.com>
Wed, 17 Apr 2024 21:11:57 +0000 (21:11 +0000)
Currently, when a Transport creates a new connection for a request,
it uses the request's Context to make the Dial. If a request
times out or is canceled before a Dial completes, the Dial is
canceled.

Change this so that the lifetime of a Dial call is not bound
by the request that originated it.

This change avoids a scenario where a Transport can start and
then cancel many Dial calls in rapid succession:

  - Request starts a Dial.
  - A previous request completes, making its connection available.
  - The new request uses the now-idle connection, and completes.
  - The request Context is canceled, and the Dial is aborted.

Fixes #59017

Change-Id: I996ffabc56d3b1b43129cbfd9b3e9ea7d53d263c
Reviewed-on: https://go-review.googlesource.com/c/go/+/576555
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
src/net/http/client_test.go
src/net/http/export_test.go
src/net/http/transport.go
src/net/http/transport_dial_test.go [new file with mode: 0644]
src/net/http/transport_test.go

index 569b58ca6225c9de7d13ac3d0d1bfb0664828a93..33e69467c6a3f467e8c33231d0bdc4600b807b56 100644 (file)
@@ -1938,21 +1938,25 @@ func TestClientCloseIdleConnections(t *testing.T) {
        }
 }
 
+type testRoundTripper func(*Request) (*Response, error)
+
+func (t testRoundTripper) RoundTrip(req *Request) (*Response, error) {
+       return t(req)
+}
+
 func TestClientPropagatesTimeoutToContext(t *testing.T) {
-       errDial := errors.New("not actually dialing")
        c := &Client{
                Timeout: 5 * time.Second,
-               Transport: &Transport{
-                       DialContext: func(ctx context.Context, netw, addr string) (net.Conn, error) {
-                               deadline, ok := ctx.Deadline()
-                               if !ok {
-                                       t.Error("no deadline")
-                               } else {
-                                       t.Logf("deadline in %v", deadline.Sub(time.Now()).Round(time.Second/10))
-                               }
-                               return nil, errDial
-                       },
-               },
+               Transport: testRoundTripper(func(req *Request) (*Response, error) {
+                       ctx := req.Context()
+                       deadline, ok := ctx.Deadline()
+                       if !ok {
+                               t.Error("no deadline")
+                       } else {
+                               t.Logf("deadline in %v", deadline.Sub(time.Now()).Round(time.Second/10))
+                       }
+                       return nil, errors.New("not actually making a request")
+               }),
        }
        c.Get("https://example.tld/")
 }
index 8a6f4f192fb2df4b5d324a0c886f26820a90a690..56ebda180bb08516602a22466f3945cf37c807ec 100644 (file)
@@ -86,6 +86,14 @@ func SetPendingDialHooks(before, after func()) {
 
 func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
 
+func SetTestHookProxyConnectTimeout(t *testing.T, f func(context.Context, time.Duration) (context.Context, context.CancelFunc)) {
+       orig := testHookProxyConnectTimeout
+       t.Cleanup(func() {
+               testHookProxyConnectTimeout = orig
+       })
+       testHookProxyConnectTimeout = f
+}
+
 func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler {
        return &timeoutHandler{
                handler:     handler,
index d97298ecd954dd83d1b17d9c923a5c2249fbfd7b..e6a97a00c63392c2c58f73c56aa7bfbf0db06a31 100644 (file)
@@ -108,6 +108,7 @@ type Transport struct {
        connsPerHostMu   sync.Mutex
        connsPerHost     map[connectMethodKey]int
        connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns
+       dialsInProgress  wantConnQueue
 
        // Proxy specifies a function to return a proxy for a given
        // Request. If the function returns a non-nil error, the
@@ -807,6 +808,13 @@ func (t *Transport) CloseIdleConnections() {
                        pconn.close(errCloseIdleConns)
                }
        }
+       t.connsPerHostMu.Lock()
+       t.dialsInProgress.all(func(w *wantConn) {
+               if w.cancelCtx != nil && !w.waiting() {
+                       w.cancelCtx()
+               }
+       })
+       t.connsPerHostMu.Unlock()
        if t2 := t.h2transport; t2 != nil {
                t2.CloseIdleConnections()
        }
@@ -1116,7 +1124,7 @@ func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) {
                t.idleConnWait = make(map[connectMethodKey]wantConnQueue)
        }
        q := t.idleConnWait[w.key]
-       q.cleanFront()
+       q.cleanFrontNotWaiting()
        q.pushBack(w)
        t.idleConnWait[w.key] = q
        return false
@@ -1230,10 +1238,11 @@ type wantConn struct {
        beforeDial func()
        afterDial  func()
 
-       mu     sync.Mutex       // protects ctx, done and sending of the result
-       ctx    context.Context  // context for dial, cleared after delivered or canceled
-       done   bool             // true after delivered or canceled
-       result chan connOrError // channel to deliver connection or error
+       mu        sync.Mutex      // protects ctx, done and sending of the result
+       ctx       context.Context // context for dial, cleared after delivered or canceled
+       cancelCtx context.CancelFunc
+       done      bool             // true after delivered or canceled
+       result    chan connOrError // channel to deliver connection or error
 }
 
 type connOrError struct {
@@ -1352,9 +1361,9 @@ func (q *wantConnQueue) peekFront() *wantConn {
        return nil
 }
 
-// cleanFront pops any wantConns that are no longer waiting from the head of the
+// cleanFrontNotWaiting pops any wantConns that are no longer waiting from the head of the
 // queue, reporting whether any were popped.
-func (q *wantConnQueue) cleanFront() (cleaned bool) {
+func (q *wantConnQueue) cleanFrontNotWaiting() (cleaned bool) {
        for {
                w := q.peekFront()
                if w == nil || w.waiting() {
@@ -1365,6 +1374,28 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) {
        }
 }
 
+// cleanFrontCanceled pops any wantConns with canceled dials from the head of the queue.
+func (q *wantConnQueue) cleanFrontCanceled() {
+       for {
+               w := q.peekFront()
+               if w == nil || w.cancelCtx != nil {
+                       return
+               }
+               q.popFront()
+       }
+}
+
+// all iterates over all wantConns in the queue.
+// The caller must not modify the queue while iterating.
+func (q *wantConnQueue) all(f func(*wantConn)) {
+       for _, w := range q.head[q.headPos:] {
+               f(w)
+       }
+       for _, w := range q.tail {
+               f(w)
+       }
+}
+
 func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) {
        if t.DialTLSContext != nil {
                conn, err = t.DialTLSContext(ctx, network, addr)
@@ -1389,10 +1420,18 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis
                trace.GetConn(cm.addr())
        }
 
+       // Detach from the request context's cancellation signal.
+       // The dial should proceed even if the request is canceled,
+       // because a future request may be able to make use of the connection.
+       //
+       // We retain the request context's values.
+       dialCtx, dialCancel := context.WithCancel(context.WithoutCancel(ctx))
+
        w := &wantConn{
                cm:         cm,
                key:        cm.key(),
-               ctx:        ctx,
+               ctx:        dialCtx,
+               cancelCtx:  dialCancel,
                result:     make(chan connOrError, 1),
                beforeDial: testHookPrePendingDial,
                afterDial:  testHookPostPendingDial,
@@ -1470,20 +1509,21 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis
 // Once w receives permission to dial, it will do so in a separate goroutine.
 func (t *Transport) queueForDial(w *wantConn) {
        w.beforeDial()
-       if t.MaxConnsPerHost <= 0 {
-               go t.dialConnFor(w)
-               return
-       }
 
        t.connsPerHostMu.Lock()
        defer t.connsPerHostMu.Unlock()
 
+       if t.MaxConnsPerHost <= 0 {
+               t.startDialConnForLocked(w)
+               return
+       }
+
        if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost {
                if t.connsPerHost == nil {
                        t.connsPerHost = make(map[connectMethodKey]int)
                }
                t.connsPerHost[w.key] = n + 1
-               go t.dialConnFor(w)
+               t.startDialConnForLocked(w)
                return
        }
 
@@ -1491,11 +1531,24 @@ func (t *Transport) queueForDial(w *wantConn) {
                t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue)
        }
        q := t.connsPerHostWait[w.key]
-       q.cleanFront()
+       q.cleanFrontNotWaiting()
        q.pushBack(w)
        t.connsPerHostWait[w.key] = q
 }
 
+// startDialConnFor calls dialConn in a new goroutine.
+// t.connsPerHostMu must be held.
+func (t *Transport) startDialConnForLocked(w *wantConn) {
+       t.dialsInProgress.cleanFrontCanceled()
+       t.dialsInProgress.pushBack(w)
+       go func() {
+               t.dialConnFor(w)
+               t.connsPerHostMu.Lock()
+               defer t.connsPerHostMu.Unlock()
+               w.cancelCtx = nil
+       }()
+}
+
 // dialConnFor dials on behalf of w and delivers the result to w.
 // dialConnFor has received permission to dial w.cm and is counted in t.connCount[w.cm.key()].
 // If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()].
@@ -1545,7 +1598,7 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) {
                for q.len() > 0 {
                        w := q.popFront()
                        if w.waiting() {
-                               go t.dialConnFor(w)
+                               t.startDialConnForLocked(w)
                                done = true
                                break
                        }
@@ -1626,6 +1679,8 @@ type erringRoundTripper interface {
        RoundTripErr() error
 }
 
+var testHookProxyConnectTimeout = context.WithTimeout
+
 func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) {
        pconn = &persistConn{
                t:             t,
@@ -1742,17 +1797,11 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
                        Header: hdr,
                }
 
-               // If there's no done channel (no deadline or cancellation
-               // from the caller possible), at least set some (long)
-               // timeout here. This will make sure we don't block forever
-               // and leak a goroutine if the connection stops replying
-               // after the TCP connect.
-               connectCtx := ctx
-               if ctx.Done() == nil {
-                       newCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
-                       defer cancel()
-                       connectCtx = newCtx
-               }
+               // Set a (long) timeout here to make sure we don't block forever
+               // and leak a goroutine if the connection stops replying after
+               // the TCP connect.
+               connectCtx, cancel := testHookProxyConnectTimeout(ctx, 1*time.Minute)
+               defer cancel()
 
                didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails
                var (
diff --git a/src/net/http/transport_dial_test.go b/src/net/http/transport_dial_test.go
new file mode 100644 (file)
index 0000000..39e35ce
--- /dev/null
@@ -0,0 +1,235 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+       "context"
+       "io"
+       "net"
+       "net/http"
+       "net/http/httptrace"
+       "testing"
+)
+
+func TestTransportPoolConnReusePriorConnection(t *testing.T) {
+       dt := newTransportDialTester(t, http1Mode)
+
+       // First request creates a new connection.
+       rt1 := dt.roundTrip()
+       c1 := dt.wantDial()
+       c1.finish(nil)
+       rt1.wantDone(c1)
+       rt1.finish()
+
+       // Second request reuses the first connection.
+       rt2 := dt.roundTrip()
+       rt2.wantDone(c1)
+       rt2.finish()
+}
+
+func TestTransportPoolConnCannotReuseConnectionInUse(t *testing.T) {
+       dt := newTransportDialTester(t, http1Mode)
+
+       // First request creates a new connection.
+       rt1 := dt.roundTrip()
+       c1 := dt.wantDial()
+       c1.finish(nil)
+       rt1.wantDone(c1)
+
+       // Second request is made while the first request is still using its connection,
+       // so it goes on a new connection.
+       rt2 := dt.roundTrip()
+       c2 := dt.wantDial()
+       c2.finish(nil)
+       rt2.wantDone(c2)
+}
+
+func TestTransportPoolConnConnectionBecomesAvailableDuringDial(t *testing.T) {
+       dt := newTransportDialTester(t, http1Mode)
+
+       // First request creates a new connection.
+       rt1 := dt.roundTrip()
+       c1 := dt.wantDial()
+       c1.finish(nil)
+       rt1.wantDone(c1)
+
+       // Second request is made while the first request is still using its connection.
+       // The first connection completes while the second Dial is in progress, so the
+       // second request uses the first connection.
+       rt2 := dt.roundTrip()
+       c2 := dt.wantDial()
+       rt1.finish()
+       rt2.wantDone(c1)
+
+       // This section is a bit overfitted to the current Transport implementation:
+       // A third request starts. We have an in-progress dial that was started by rt2,
+       // but this new request (rt3) is going to ignore it and make a dial of its own.
+       // rt3 will use the first of these dials that completes.
+       rt3 := dt.roundTrip()
+       c3 := dt.wantDial()
+       c2.finish(nil)
+       rt3.wantDone(c2)
+
+       c3.finish(nil)
+}
+
+// A transportDialTester manages a test of a connection's Dials.
+type transportDialTester struct {
+       t   *testing.T
+       cst *clientServerTest
+
+       dials chan *transportDialTesterConn // each new conn is sent to this channel
+
+       roundTripCount int
+       dialCount      int
+}
+
+// A transportDialTesterRoundTrip is a RoundTrip made as part of a dial test.
+type transportDialTesterRoundTrip struct {
+       t *testing.T
+
+       roundTripID int                // distinguishes RoundTrips in logs
+       cancel      context.CancelFunc // cancels the Request context
+       reqBody     io.WriteCloser     // write half of the Request.Body
+       finished    bool
+
+       done chan struct{} // closed when RoundTrip returns:w
+       res  *http.Response
+       err  error
+       conn *transportDialTesterConn
+}
+
+// A transportDialTesterConn is a client connection created by the Transport as
+// part of a dial test.
+type transportDialTesterConn struct {
+       t *testing.T
+
+       connID int        // distinguished Dials in logs
+       ready  chan error // sent on to complete the Dial
+
+       net.Conn
+}
+
+func newTransportDialTester(t *testing.T, mode testMode) *transportDialTester {
+       t.Helper()
+       dt := &transportDialTester{
+               t:     t,
+               dials: make(chan *transportDialTesterConn),
+       }
+       dt.cst = newClientServerTest(t, mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               // Write response headers when we receive a request.
+               http.NewResponseController(w).EnableFullDuplex()
+               w.WriteHeader(200)
+               http.NewResponseController(w).Flush()
+               // Wait for the client to send the request body,
+               // to synchronize with the rest of the test.
+               io.ReadAll(r.Body)
+       }), func(tr *http.Transport) {
+               tr.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
+                       c := &transportDialTesterConn{
+                               t:     t,
+                               ready: make(chan error),
+                       }
+                       // Notify the test that a Dial has started,
+                       // and wait for the test to notify us that it should complete.
+                       dt.dials <- c
+                       if err := <-c.ready; err != nil {
+                               return nil, err
+                       }
+                       nc, err := net.Dial(network, address)
+                       if err != nil {
+                               return nil, err
+                       }
+                       // Use the *transportDialTesterConn as the net.Conn,
+                       // to let tests associate requests with connections.
+                       c.Conn = nc
+                       return c, err
+               }
+       })
+       return dt
+}
+
+// roundTrip starts a RoundTrip.
+// It returns immediately, without waiting for the RoundTrip call to complete.
+func (dt *transportDialTester) roundTrip() *transportDialTesterRoundTrip {
+       dt.t.Helper()
+       ctx, cancel := context.WithCancel(context.Background())
+       pr, pw := io.Pipe()
+       rt := &transportDialTesterRoundTrip{
+               t:           dt.t,
+               roundTripID: dt.roundTripCount,
+               done:        make(chan struct{}),
+               reqBody:     pw,
+               cancel:      cancel,
+       }
+       dt.roundTripCount++
+       dt.t.Logf("RoundTrip %v: started", rt.roundTripID)
+       dt.t.Cleanup(func() {
+               rt.cancel()
+               rt.finish()
+       })
+       go func() {
+               ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
+                       GotConn: func(info httptrace.GotConnInfo) {
+                               rt.conn = info.Conn.(*transportDialTesterConn)
+                       },
+               })
+               req, _ := http.NewRequestWithContext(ctx, "POST", dt.cst.ts.URL, pr)
+               req.Header.Set("Content-Type", "text/plain")
+               rt.res, rt.err = dt.cst.tr.RoundTrip(req)
+               dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err)
+               close(rt.done)
+       }()
+       return rt
+}
+
+// wantDone indicates that a RoundTrip should have returned.
+func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn) {
+       rt.t.Helper()
+       <-rt.done
+       if rt.err != nil {
+               rt.t.Fatalf("RoundTrip %v: want success, got err %v", rt.roundTripID, rt.err)
+       }
+       if rt.conn != c {
+               rt.t.Fatalf("RoundTrip %v: want on conn %v, got conn %v", rt.roundTripID, c.connID, rt.conn.connID)
+       }
+}
+
+// finish completes a RoundTrip by sending the request body, consuming the response body,
+// and closing the response body.
+func (rt *transportDialTesterRoundTrip) finish() {
+       rt.t.Helper()
+
+       if rt.finished {
+               return
+       }
+       rt.finished = true
+
+       <-rt.done
+
+       if rt.err != nil {
+               return
+       }
+       rt.reqBody.Close()
+       io.ReadAll(rt.res.Body)
+       rt.res.Body.Close()
+       rt.t.Logf("RoundTrip %v: closed request body", rt.roundTripID)
+}
+
+// wantDial waits for the Transport to start a Dial.
+func (dt *transportDialTester) wantDial() *transportDialTesterConn {
+       c := <-dt.dials
+       c.connID = dt.dialCount
+       dt.dialCount++
+       dt.t.Logf("Dial %v: started", c.connID)
+       return c
+}
+
+// finish completes a Dial.
+func (c *transportDialTesterConn) finish(err error) {
+       c.t.Logf("Dial %v: finished (err:%v)", c.connID, err)
+       c.ready <- err
+       close(c.ready)
+}
index e8baa486a4fb1edc51c06e0612f6ac49ed42574b..fa147e164ed843181980cbcde07e2e4b6144dff7 100644 (file)
@@ -1626,11 +1626,20 @@ func TestOnProxyConnectResponse(t *testing.T) {
 // Issue 28012: verify that the Transport closes its TCP connection to http proxies
 // when they're slow to reply to HTTPS CONNECT responses.
 func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
-       setParallel(t)
-       defer afterTest(t)
+       cancelc := make(chan struct{})
+       SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
+               ctx, cancel := context.WithCancel(ctx)
+               go func() {
+                       select {
+                       case <-cancelc:
+                       case <-ctx.Done():
+                       }
+                       cancel()
+               }()
+               return ctx, cancel
+       })
 
-       ctx, cancel := context.WithCancel(context.Background())
-       defer cancel()
+       defer afterTest(t)
 
        ln := newLocalListener(t)
        defer ln.Close()
@@ -1658,7 +1667,7 @@ func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
                // Now hang and never write a response; instead, cancel the request and wait
                // for the client to close.
                // (Prior to Issue 28012 being fixed, we never closed.)
-               cancel()
+               close(cancelc)
                var buf [1]byte
                _, err = br.Read(buf[:])
                if err != io.EOF {
@@ -1674,7 +1683,7 @@ func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
                        },
                },
        }
-       req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil)
+       req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
        if err != nil {
                t.Fatal(err)
        }
@@ -3927,9 +3936,13 @@ func testTransportDialTLS(t *testing.T, mode testMode) {
 
 func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
 func testTransportDialContext(t *testing.T, mode testMode) {
-       var mu sync.Mutex // guards following
-       var gotReq bool
-       var receivedContext context.Context
+       ctxKey := "some-key"
+       ctxValue := "some-value"
+       var (
+               mu          sync.Mutex // guards following
+               gotReq      bool
+               gotCtxValue any
+       )
 
        ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
                mu.Lock()
@@ -3939,7 +3952,7 @@ func testTransportDialContext(t *testing.T, mode testMode) {
        c := ts.Client()
        c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
                mu.Lock()
-               receivedContext = ctx
+               gotCtxValue = ctx.Value(ctxKey)
                mu.Unlock()
                return net.Dial(netw, addr)
        }
@@ -3948,7 +3961,7 @@ func testTransportDialContext(t *testing.T, mode testMode) {
        if err != nil {
                t.Fatal(err)
        }
-       ctx := context.WithValue(context.Background(), "some-key", "some-value")
+       ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
        res, err := c.Do(req.WithContext(ctx))
        if err != nil {
                t.Fatal(err)
@@ -3958,8 +3971,8 @@ func testTransportDialContext(t *testing.T, mode testMode) {
        if !gotReq {
                t.Error("didn't get request")
        }
-       if receivedContext != ctx {
-               t.Error("didn't receive correct context")
+       if got, want := gotCtxValue, ctxValue; got != want {
+               t.Errorf("got context with value %v, want %v", got, want)
        }
 }
 
@@ -3967,9 +3980,13 @@ func TestTransportDialTLSContext(t *testing.T) {
        run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
 }
 func testTransportDialTLSContext(t *testing.T, mode testMode) {
-       var mu sync.Mutex // guards following
-       var gotReq bool
-       var receivedContext context.Context
+       ctxKey := "some-key"
+       ctxValue := "some-value"
+       var (
+               mu          sync.Mutex // guards following
+               gotReq      bool
+               gotCtxValue any
+       )
 
        ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
                mu.Lock()
@@ -3979,7 +3996,7 @@ func testTransportDialTLSContext(t *testing.T, mode testMode) {
        c := ts.Client()
        c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
                mu.Lock()
-               receivedContext = ctx
+               gotCtxValue = ctx.Value(ctxKey)
                mu.Unlock()
                c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
                if err != nil {
@@ -3992,7 +4009,7 @@ func testTransportDialTLSContext(t *testing.T, mode testMode) {
        if err != nil {
                t.Fatal(err)
        }
-       ctx := context.WithValue(context.Background(), "some-key", "some-value")
+       ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
        res, err := c.Do(req.WithContext(ctx))
        if err != nil {
                t.Fatal(err)
@@ -4002,8 +4019,8 @@ func testTransportDialTLSContext(t *testing.T, mode testMode) {
        if !gotReq {
                t.Error("didn't get request")
        }
-       if receivedContext != ctx {
-               t.Error("didn't receive correct context")
+       if got, want := gotCtxValue, ctxValue; got != want {
+               t.Errorf("got context with value %v, want %v", got, want)
        }
 }