]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add more tests of transport connection pool
authorDamien Neil <dneil@google.com>
Wed, 25 Sep 2024 18:48:17 +0000 (11:48 -0700)
committerGopher Robot <gobot@golang.org>
Fri, 26 Sep 2025 22:06:09 +0000 (15:06 -0700)
Add a variety of addtional tests exercising client connection pooling,
in particular HTTP/2 connection behavior.

Change-Id: I7609d36db5865f1b95c903cfadb0c3233e046c09
Reviewed-on: https://go-review.googlesource.com/c/go/+/615896
Reviewed-by: Nicholas Husin <husin@google.com>
Reviewed-by: Nicholas Husin <nsh@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Damien Neil <dneil@google.com>

src/net/http/clientserver_test.go
src/net/http/http2_test.go [new file with mode: 0644]
src/net/http/transport_dial_test.go

index c3cf3984ef1f6a8e5dfc0d624a327048a69f12ac..8665bae38ad0f26ccfd622fe9dbbce45b23847cf 100644 (file)
@@ -207,6 +207,8 @@ func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *c
                        transportFuncs = append(transportFuncs, opt)
                case func(*httptest.Server):
                        opt(cst.ts)
+               case func(*Server):
+                       opt(cst.ts.Config)
                default:
                        t.Fatalf("unhandled option type %T", opt)
                }
diff --git a/src/net/http/http2_test.go b/src/net/http/http2_test.go
new file mode 100644 (file)
index 0000000..7841b05
--- /dev/null
@@ -0,0 +1,12 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !nethttpomithttp2
+
+package http
+
+func init() {
+       // Disable HTTP/2 internal channel pooling which interferes with synctest.
+       http2inTests = true
+}
index 39e35cec55620e3edb9fcd3ad57fb6f0dbb4384a..d7f7a3d539434fa5faa150bd494ae44bc6b75abe 100644 (file)
@@ -6,73 +6,323 @@ package http_test
 
 import (
        "context"
+       "crypto/tls"
+       "errors"
        "io"
        "net"
        "net/http"
        "net/http/httptrace"
+       "strings"
+       "sync"
        "testing"
+       "testing/synctest"
 )
 
+// Successive requests use the same HTTP/1 connection.
 func TestTransportPoolConnReusePriorConnection(t *testing.T) {
-       dt := newTransportDialTester(t, http1Mode)
+       synctest.Test(t, func(t *testing.T) {
+               dt := newTransportDialTester(t, http1Mode)
 
-       // First request creates a new connection.
-       rt1 := dt.roundTrip()
-       c1 := dt.wantDial()
-       c1.finish(nil)
-       rt1.wantDone(c1)
-       rt1.finish()
+               // First request creates a new connection.
+               rt1 := dt.roundTrip()
+               c1 := dt.wantDial()
+               c1.finish(nil)
+               rt1.wantDone(c1, "HTTP/1.1")
+               rt1.finish()
 
-       // Second request reuses the first connection.
-       rt2 := dt.roundTrip()
-       rt2.wantDone(c1)
-       rt2.finish()
+               // Second request reuses the first connection.
+               rt2 := dt.roundTrip()
+               rt2.wantDone(c1, "HTTP/1.1")
+               rt2.finish()
+       })
 }
 
+// Two HTTP/1 requests made at the same time use different connections.
 func TestTransportPoolConnCannotReuseConnectionInUse(t *testing.T) {
-       dt := newTransportDialTester(t, http1Mode)
+       synctest.Test(t, func(t *testing.T) {
+               dt := newTransportDialTester(t, http1Mode)
+
+               // First request creates a new connection.
+               rt1 := dt.roundTrip()
+               c1 := dt.wantDial()
+               c1.finish(nil)
+               rt1.wantDone(c1, "HTTP/1.1")
+
+               // Second request is made while the first request is still using its connection,
+               // so it goes on a new connection.
+               rt2 := dt.roundTrip()
+               c2 := dt.wantDial()
+               c2.finish(nil)
+               rt2.wantDone(c2, "HTTP/1.1")
+       })
+}
+
+// When an HTTP/2 connection is at its stream limit
+// a new request is made on a new connection.
+func TestTransportPoolConnHTTP2OverStreamLimit(t *testing.T) {
+       synctest.Test(t, func(t *testing.T) {
+               dt := newTransportDialTester(t, http2Mode, func(srv *http.Server) {
+                       srv.HTTP2 = &http.HTTP2Config{
+                               MaxConcurrentStreams: 2,
+                       }
+               })
+
+               // First request dials an HTTP/2 connection.
+               rt1 := dt.roundTrip()
+               c1 := dt.wantDial()
+               c1.finish(nil)
+               rt1.wantDone(c1, "HTTP/2.0")
+
+               // Second request uses the existing connection.
+               rt2 := dt.roundTrip()
+               rt2.wantDone(c1, "HTTP/2.0")
+
+               // Third request creates a new connection
+               rt3 := dt.roundTrip()
+               c2 := dt.wantDial()
+               c2.finish(nil)
+               rt3.wantDone(c2, "HTTP/2.0")
 
-       // First request creates a new connection.
-       rt1 := dt.roundTrip()
-       c1 := dt.wantDial()
-       c1.finish(nil)
-       rt1.wantDone(c1)
+               rt1.finish()
+               rt2.finish()
+               rt3.finish()
 
-       // Second request is made while the first request is still using its connection,
-       // so it goes on a new connection.
-       rt2 := dt.roundTrip()
-       c2 := dt.wantDial()
-       c2.finish(nil)
-       rt2.wantDone(c2)
+               // With slots available on both connections, we prefer the oldest.
+               rt4 := dt.roundTrip()
+               rt4.wantDone(c1, "HTTP/2.0")
+               rt5 := dt.roundTrip()
+               rt5.wantDone(c1, "HTTP/2.0")
+               rt6 := dt.roundTrip()
+               rt6.wantDone(c2, "HTTP/2.0")
+               rt4.finish()
+               rt5.finish()
+               rt6.finish()
+       })
 }
 
+// A new request made while an HTTP/2 dial is in progress will start a second dial.
+func TestTransportPoolConnHTTP2Startup(t *testing.T) {
+       synctest.Test(t, func(t *testing.T) {
+               dt := newTransportDialTester(t, http2Mode, func(srv *http.Server) {})
+
+               // Two requests start.
+               // Since the second request starts before the first dial finishes, it starts a second dial.
+               rt1 := dt.roundTrip()
+               rt2 := dt.roundTrip()
+               c1 := dt.wantDial()
+               c2 := dt.wantDial()
+
+               // Both requests use the conn of the first dial to complete.
+               c1.finish(nil)
+               rt1.wantDone(c1, "HTTP/2.0")
+               rt2.wantDone(c1, "HTTP/2.0")
+
+               rt1.finish()
+               rt2.finish()
+               c2.finish(nil)
+       })
+}
+
+// When a request finishes using an HTTP/1 connection,
+// a pending request attempting to dial a new connection will use the newly-available one.
 func TestTransportPoolConnConnectionBecomesAvailableDuringDial(t *testing.T) {
-       dt := newTransportDialTester(t, http1Mode)
-
-       // First request creates a new connection.
-       rt1 := dt.roundTrip()
-       c1 := dt.wantDial()
-       c1.finish(nil)
-       rt1.wantDone(c1)
-
-       // Second request is made while the first request is still using its connection.
-       // The first connection completes while the second Dial is in progress, so the
-       // second request uses the first connection.
-       rt2 := dt.roundTrip()
-       c2 := dt.wantDial()
-       rt1.finish()
-       rt2.wantDone(c1)
-
-       // This section is a bit overfitted to the current Transport implementation:
-       // A third request starts. We have an in-progress dial that was started by rt2,
-       // but this new request (rt3) is going to ignore it and make a dial of its own.
-       // rt3 will use the first of these dials that completes.
-       rt3 := dt.roundTrip()
-       c3 := dt.wantDial()
-       c2.finish(nil)
-       rt3.wantDone(c2)
-
-       c3.finish(nil)
+       synctest.Test(t, func(t *testing.T) {
+               dt := newTransportDialTester(t, http1Mode)
+
+               // First request creates a new connection.
+               rt1 := dt.roundTrip()
+               c1 := dt.wantDial()
+               c1.finish(nil)
+               rt1.wantDone(c1, "HTTP/1.1")
+
+               // Second request is made while the first request is still using its connection.
+               // The first connection completes while the second Dial is in progress, so the
+               // second request uses the first connection.
+               rt2 := dt.roundTrip()
+               c2 := dt.wantDial()
+               rt1.finish()
+               rt2.wantDone(c1, "HTTP/1.1")
+
+               // This section is a bit overfitted to the current Transport implementation:
+               // A third request starts. We have an in-progress dial that was started by rt2,
+               // but this new request (rt3) is going to ignore it and make a dial of its own.
+               // rt3 will use the first of these dials that completes.
+               rt3 := dt.roundTrip()
+               c3 := dt.wantDial()
+               c2.finish(nil)
+               rt3.wantDone(c2, "HTTP/1.1")
+
+               c3.finish(nil)
+       })
+}
+
+// Connections are not reused when DisableKeepAlives = true.
+func TestTransportPoolDisableKeepAlives(t *testing.T) {
+       synctest.Test(t, func(t *testing.T) {
+               dt := newTransportDialTester(t, http1Mode, func(tr *http.Transport) {
+                       tr.DisableKeepAlives = true
+               })
+
+               // Two requests, each uses a separate connection.
+               for range 2 {
+                       rt := dt.roundTrip()
+                       c := dt.wantDial()
+                       c.finish(nil)
+                       rt.wantDone(c, "HTTP/1.1")
+                       rt.finish()
+               }
+       })
+}
+
+// Canceling a request before its connection is created returns the conn to the pool.
+func TestTransportPoolCancelRequestReusesConn(t *testing.T) {
+       synctest.Test(t, func(t *testing.T) {
+               dt := newTransportDialTester(t, http1Mode)
+
+               // First request is canceled before its connection is created.
+               rt1 := dt.roundTrip()
+               c1 := dt.wantDial()
+               rt1.cancel()
+               rt1.wantError()
+
+               // Second request uses the first connection.
+               rt2 := dt.roundTrip()
+               c2 := dt.wantDial()
+               c1.finish(nil) // first dial finishes
+               rt2.wantDone(c1, "HTTP/1.1")
+               rt2.finish()
+
+               c2.finish(nil) // second dial finishes
+       })
+}
+
+// Connections are not reused when DisableKeepAlives = true.
+func TestTransportPoolCancelRequestWithDisableKeepAlives(t *testing.T) {
+       synctest.Test(t, func(t *testing.T) {
+               dt := newTransportDialTester(t, http1Mode, func(tr *http.Transport) {
+                       tr.DisableKeepAlives = true
+               })
+
+               // First request is canceled before its connection is created.
+               rt1 := dt.roundTrip()
+               c1 := dt.wantDial()
+               rt1.cancel()
+               rt1.wantError()
+
+               // Dial finishes. DisableKeepAlives = true, so we discard the connection.
+               c1.finish(nil)
+
+               // Second request is made on a new connection.
+               rt2 := dt.roundTrip()
+               c2 := dt.wantDial()
+               c2.finish(nil)
+               rt2.wantDone(c2, "HTTP/1.1")
+               rt2.finish()
+       })
+}
+
+// Connections are not reused after an error.
+func TestTransportPoolConnectionBroken(t *testing.T) {
+       synctest.Test(t, func(t *testing.T) {
+               dt := newTransportDialTester(t, http1Mode)
+
+               // First request creates a new connection.
+               // The connection breaks while sending the response.
+               rt1 := dt.roundTrip()
+               c1 := dt.wantDial()
+               c1.finish(nil)
+               rt1.wantDone(c1, "HTTP/1.1")
+               c1.fakeNetConn.Close() // break the connection
+               rt1.finish()
+
+               // Second request is made on a new connection, since the first is broken.
+               rt2 := dt.roundTrip()
+               c2 := dt.wantDial()
+               c2.finish(nil)
+               rt2.wantDone(c2, "HTTP/1.1")
+               c2.fakeNetConn.Close()
+               rt2.finish()
+       })
+}
+
+// MaxIdleConnsPerHost limits the number of idle connections.
+func TestTransportPoolClosesConnsPastMaxIdleConnsPerHost(t *testing.T) {
+       synctest.Test(t, func(t *testing.T) {
+               dt := newTransportDialTester(t, http1Mode, func(tr *http.Transport) {
+                       tr.MaxIdleConnsPerHost = 1
+               })
+
+               // First request creates a new connection.
+               rt1 := dt.roundTrip("host1.fake.tld")
+               c1 := dt.wantDial()
+               c1.finish(nil)
+               rt1.wantDone(c1, "HTTP/1.1")
+
+               // Second request also creates a new connection.
+               rt2 := dt.roundTrip("host1.fake.tld")
+               c2 := dt.wantDial()
+               c2.finish(nil)
+               rt2.wantDone(c2, "HTTP/1.1")
+
+               // Third request is to a different host.
+               rt3 := dt.roundTrip("host2.fake.tld")
+               c3 := dt.wantDial()
+               c3.finish(nil)
+               rt3.wantDone(c3, "HTTP/1.1")
+
+               // All requests finish. One conn is in excess of MaxIdleConnsPerHost, and is closed.
+               rt3.finish()
+               rt2.finish()
+               rt1.finish()
+               c1.wantClosed()
+
+               // Additional requests reuse the remaining connections.
+               rt4 := dt.roundTrip("host1.fake.tld")
+               rt4.wantDone(c2, "HTTP/1.1")
+               rt4.finish()
+               rt5 := dt.roundTrip("host2.fake.tld")
+               rt5.wantDone(c3, "HTTP/1.1")
+               rt5.finish()
+       })
+}
+
+// Current (but probably wrong) behavior:
+// MaxIdleConnsPerHost doesn't apply to HTTP/2 connections.
+func TestTransportPoolMaxIdleConnsPerHostHTTP2(t *testing.T) {
+       synctest.Test(t, func(t *testing.T) {
+               t.Skip("skipped until h2_bundle.go includes support for MaxConcurrentStreams")
+
+               dt := newTransportDialTester(t, http2Mode, func(srv *http.Server) {
+                       srv.HTTP2 = &http.HTTP2Config{
+                               MaxConcurrentStreams: 1,
+                       }
+               }, func(tr *http.Transport) {
+                       tr.MaxIdleConnsPerHost = 1
+               })
+
+               // First request creates a new connection.
+               rt1 := dt.roundTrip()
+               c1 := dt.wantDial()
+               c1.finish(nil)
+               rt1.wantDone(c1, "HTTP/2.0")
+
+               // Second request also creates a new connection.
+               rt2 := dt.roundTrip()
+               c2 := dt.wantDial()
+               c2.finish(nil)
+               rt2.wantDone(c2, "HTTP/2.0")
+
+               // Both requests finish.
+               // We have two idle conns for this host, but we keep them both.
+               rt1.finish()
+               rt2.finish()
+
+               // Two new requests use the existing connections.
+               rt3 := dt.roundTrip()
+               rt3.wantDone(c1, "HTTP/2.0")
+               rt4 := dt.roundTrip()
+               rt4.wantDone(c2, "HTTP/2.0")
+       })
 }
 
 // A transportDialTester manages a test of a connection's Dials.
@@ -80,7 +330,8 @@ type transportDialTester struct {
        t   *testing.T
        cst *clientServerTest
 
-       dials chan *transportDialTesterConn // each new conn is sent to this channel
+       dialsMu sync.Mutex
+       dials   []*transportDialTesterConn
 
        roundTripCount int
        dialCount      int
@@ -90,12 +341,12 @@ type transportDialTester struct {
 type transportDialTesterRoundTrip struct {
        t *testing.T
 
-       roundTripID int                // distinguishes RoundTrips in logs
-       cancel      context.CancelFunc // cancels the Request context
-       reqBody     io.WriteCloser     // write half of the Request.Body
-       finished    bool
+       roundTripID    int                // distinguishes RoundTrips in logs
+       cancel         context.CancelFunc // cancels the Request context
+       reqBody        io.WriteCloser     // write half of the Request.Body
+       respBodyClosed bool               // set when the user calls Response.Body.Close
+       returned       bool               // set when RoundTrip returns
 
-       done chan struct{} // closed when RoundTrip returns:w
        res  *http.Response
        err  error
        conn *transportDialTesterConn
@@ -108,15 +359,46 @@ type transportDialTesterConn struct {
 
        connID int        // distinguished Dials in logs
        ready  chan error // sent on to complete the Dial
+       protos []string
+       closed chan struct{}
 
-       net.Conn
+       *fakeNetConn
 }
 
-func newTransportDialTester(t *testing.T, mode testMode) *transportDialTester {
+func newTransportDialTester(t *testing.T, mode testMode, opts ...any) *transportDialTester {
        t.Helper()
        dt := &transportDialTester{
-               t:     t,
-               dials: make(chan *transportDialTesterConn),
+               t: t,
+       }
+       dialContext := func(_ context.Context, network, address string) (*transportDialTesterConn, error) {
+               c := &transportDialTesterConn{
+                       t:      t,
+                       ready:  make(chan error),
+                       closed: make(chan struct{}),
+               }
+               // Notify the test that a Dial has started,
+               // and wait for the test to notify us that it should complete.
+               dt.dialsMu.Lock()
+               dt.dials = append(dt.dials, c)
+               dt.dialsMu.Unlock()
+
+               select {
+               case err := <-c.ready:
+                       if err != nil {
+                               return nil, err
+                       }
+               case <-t.Context().Done():
+                       t.Errorf("test finished with dial in progress")
+                       return nil, errors.New("test finished")
+               }
+
+               c.fakeNetConn = dt.cst.li.connect()
+               t.Cleanup(func() {
+                       c.fakeNetConn.Close()
+               })
+               // Use the *transportDialTesterConn as the net.Conn,
+               // to let tests associate requests with connections.
+               return c, nil
        }
        dt.cst = newClientServerTest(t, mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
                // Write response headers when we receive a request.
@@ -126,45 +408,54 @@ func newTransportDialTester(t *testing.T, mode testMode) *transportDialTester {
                // Wait for the client to send the request body,
                // to synchronize with the rest of the test.
                io.ReadAll(r.Body)
-       }), func(tr *http.Transport) {
-               tr.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
-                       c := &transportDialTesterConn{
-                               t:     t,
-                               ready: make(chan error),
-                       }
-                       // Notify the test that a Dial has started,
-                       // and wait for the test to notify us that it should complete.
-                       dt.dials <- c
-                       if err := <-c.ready; err != nil {
+       }), append([]any{optFakeNet, func(tr *http.Transport) {
+               tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+                       return dialContext(ctx, network, dt.cst.ts.Listener.Addr().String())
+               }
+               tr.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+                       conn, err := dialContext(ctx, network, dt.cst.ts.Listener.Addr().String())
+                       if err != nil {
                                return nil, err
                        }
-                       nc, err := net.Dial(network, address)
-                       if err != nil {
+                       config := &tls.Config{
+                               InsecureSkipVerify: true,
+                               NextProtos:         []string{"h2", "http/1.1"},
+                       }
+                       if conn.protos != nil {
+                               config.NextProtos = conn.protos
+                       }
+                       tc := tls.Client(conn, config)
+                       if err := tc.Handshake(); err != nil {
                                return nil, err
                        }
-                       // Use the *transportDialTesterConn as the net.Conn,
-                       // to let tests associate requests with connections.
-                       c.Conn = nc
-                       return c, err
+                       return tc, nil
                }
-       })
+       }}, opts...)...)
        return dt
 }
 
 // roundTrip starts a RoundTrip.
 // It returns immediately, without waiting for the RoundTrip call to complete.
-func (dt *transportDialTester) roundTrip() *transportDialTesterRoundTrip {
+func (dt *transportDialTester) roundTrip(opts ...any) *transportDialTesterRoundTrip {
        dt.t.Helper()
+       host := "fake.tld"
+       for _, o := range opts {
+               switch o := o.(type) {
+               case string:
+                       host = o
+               default:
+                       dt.t.Fatalf("unknown option type %T", o)
+               }
+       }
        ctx, cancel := context.WithCancel(context.Background())
        pr, pw := io.Pipe()
+       dt.roundTripCount++
        rt := &transportDialTesterRoundTrip{
                t:           dt.t,
                roundTripID: dt.roundTripCount,
-               done:        make(chan struct{}),
                reqBody:     pw,
                cancel:      cancel,
        }
-       dt.roundTripCount++
        dt.t.Logf("RoundTrip %v: started", rt.roundTripID)
        dt.t.Cleanup(func() {
                rt.cancel()
@@ -173,28 +464,54 @@ func (dt *transportDialTester) roundTrip() *transportDialTesterRoundTrip {
        go func() {
                ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
                        GotConn: func(info httptrace.GotConnInfo) {
-                               rt.conn = info.Conn.(*transportDialTesterConn)
+                               c := info.Conn
+                               if tlsConn, ok := c.(*tls.Conn); ok {
+                                       c = tlsConn.NetConn()
+                               }
+                               rt.conn = c.(*transportDialTesterConn)
                        },
                })
-               req, _ := http.NewRequestWithContext(ctx, "POST", dt.cst.ts.URL, pr)
+               proto, _, _ := strings.Cut(dt.cst.ts.URL, ":")
+               req, _ := http.NewRequestWithContext(ctx, "POST", proto+"://"+host, pr)
                req.Header.Set("Content-Type", "text/plain")
                rt.res, rt.err = dt.cst.tr.RoundTrip(req)
                dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err)
-               close(rt.done)
+               rt.returned = true
        }()
        return rt
 }
 
 // wantDone indicates that a RoundTrip should have returned.
-func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn) {
+func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn, wantProto string) {
        rt.t.Helper()
-       <-rt.done
+       synctest.Wait()
+       if !rt.returned {
+               rt.t.Fatalf("RoundTrip %v: still running, want to have returned", rt.roundTripID)
+       }
        if rt.err != nil {
                rt.t.Fatalf("RoundTrip %v: want success, got err %v", rt.roundTripID, rt.err)
        }
        if rt.conn != c {
                rt.t.Fatalf("RoundTrip %v: want on conn %v, got conn %v", rt.roundTripID, c.connID, rt.conn.connID)
        }
+       if got, want := rt.conn, c; got != want {
+               rt.t.Fatalf("RoundTrip %v: sent on conn %v, want conn %v", rt.roundTripID, got.connID, want.connID)
+       }
+       if got, want := rt.res.Proto, wantProto; got != want {
+               rt.t.Fatalf("RoundTrip %v: got protocol %q, want %q", rt.roundTripID, got, want)
+       }
+}
+
+// wantError indicates that a RoundTrip should have returned with an error.
+func (rt *transportDialTesterRoundTrip) wantError() {
+       rt.t.Helper()
+       synctest.Wait()
+       if !rt.returned {
+               rt.t.Fatalf("RoundTrip %v: still running, want to have returned", rt.roundTripID)
+       }
+       if rt.err == nil {
+               rt.t.Fatalf("RoundTrip %v: success, want error", rt.roundTripID)
+       }
 }
 
 // finish completes a RoundTrip by sending the request body, consuming the response body,
@@ -202,16 +519,18 @@ func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn) {
 func (rt *transportDialTesterRoundTrip) finish() {
        rt.t.Helper()
 
-       if rt.finished {
+       synctest.Wait()
+       if !rt.returned {
+               rt.t.Fatalf("RoundTrip %v: still running, want to have returned", rt.roundTripID)
+       }
+       if rt.err != nil {
                return
        }
-       rt.finished = true
 
-       <-rt.done
-
-       if rt.err != nil {
+       if rt.respBodyClosed {
                return
        }
+       rt.respBodyClosed = true
        rt.reqBody.Close()
        io.ReadAll(rt.res.Body)
        rt.res.Body.Close()
@@ -220,16 +539,40 @@ func (rt *transportDialTesterRoundTrip) finish() {
 
 // wantDial waits for the Transport to start a Dial.
 func (dt *transportDialTester) wantDial() *transportDialTesterConn {
-       c := <-dt.dials
-       c.connID = dt.dialCount
+       dt.t.Helper()
+       synctest.Wait()
+       dt.dialsMu.Lock()
+       defer dt.dialsMu.Unlock()
+       if len(dt.dials) == 0 {
+               dt.t.Fatalf("no dial started, want one")
+       }
+       c := dt.dials[0]
+       dt.dials = dt.dials[1:]
        dt.dialCount++
+       c.connID = dt.dialCount
        dt.t.Logf("Dial %v: started", c.connID)
        return c
 }
 
 // finish completes a Dial.
 func (c *transportDialTesterConn) finish(err error) {
+       c.t.Helper()
        c.t.Logf("Dial %v: finished (err:%v)", c.connID, err)
        c.ready <- err
        close(c.ready)
 }
+
+func (c *transportDialTesterConn) wantClosed() {
+       c.t.Helper()
+       <-c.closed
+}
+
+func (c *transportDialTesterConn) Close() error {
+       select {
+       case <-c.closed:
+       default:
+               c.t.Logf("Conn %v: closed", c.connID)
+               close(c.closed)
+       }
+       return nil
+}