]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: use httptest.Server Client in tests
authorJohan Brandhorst <johan.brandhorst@gmail.com>
Sat, 4 Mar 2017 18:24:44 +0000 (18:24 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 8 Mar 2017 15:51:48 +0000 (15:51 +0000)
After merging https://go-review.googlesource.com/c/34639/,
it was pointed out to me that a lot of tests under net/http
could use the new functionality to simplify and unify testing.

Using the httptest.Server provided Client removes the need to
call CloseIdleConnections() on all Transports created, as it
is automatically called on the Transport associated with the
client when Server.Close() is called.

Change the transport used by the non-TLS
httptest.Server to a new *http.Transport rather than using
http.DefaultTransport implicitly. The TLS version already
used its own *http.Transport. This change is to prevent
concurrency problems with using DefaultTransport implicitly
across several httptest.Server's.

Add tests to ensure the httptest.Server.Client().Transport
RoundTripper interface is implemented by a *http.Transport,
as is now assumed across large parts of net/http tests.

Change-Id: I9f9d15f59d72893deead5678d314388718c91821
Reviewed-on: https://go-review.googlesource.com/37771
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/http/client_test.go
src/net/http/fs_test.go
src/net/http/httptest/server.go
src/net/http/httptest/server_test.go
src/net/http/httputil/reverseproxy_test.go
src/net/http/main_test.go
src/net/http/npn_test.go
src/net/http/serve_test.go
src/net/http/transport_test.go

index c75456ae5378dc894ebe1b5c83421ca9d8029895..73f22212f6ede79e33d5385aec24b1bea517e35a 100644 (file)
@@ -10,7 +10,6 @@ import (
        "bytes"
        "context"
        "crypto/tls"
-       "crypto/x509"
        "encoding/base64"
        "errors"
        "fmt"
@@ -73,7 +72,7 @@ func TestClient(t *testing.T) {
        ts := httptest.NewServer(robotsTxtHandler)
        defer ts.Close()
 
-       c := &Client{Transport: &Transport{DisableKeepAlives: true}}
+       c := ts.Client()
        r, err := c.Get(ts.URL)
        var b []byte
        if err == nil {
@@ -220,10 +219,7 @@ func TestClientRedirects(t *testing.T) {
        }))
        defer ts.Close()
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-
-       c := &Client{Transport: tr}
+       c := ts.Client()
        _, err := c.Get(ts.URL)
        if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
                t.Errorf("with default client Get, expected error %q, got %q", e, g)
@@ -252,13 +248,10 @@ func TestClientRedirects(t *testing.T) {
        var checkErr error
        var lastVia []*Request
        var lastReq *Request
-       c = &Client{
-               Transport: tr,
-               CheckRedirect: func(req *Request, via []*Request) error {
-                       lastReq = req
-                       lastVia = via
-                       return checkErr
-               },
+       c.CheckRedirect = func(req *Request, via []*Request) error {
+               lastReq = req
+               lastVia = via
+               return checkErr
        }
        res, err := c.Get(ts.URL)
        if err != nil {
@@ -313,21 +306,16 @@ func TestClientRedirectContext(t *testing.T) {
        }))
        defer ts.Close()
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-
        ctx, cancel := context.WithCancel(context.Background())
-       c := &Client{
-               Transport: tr,
-               CheckRedirect: func(req *Request, via []*Request) error {
-                       cancel()
-                       select {
-                       case <-req.Context().Done():
-                               return nil
-                       case <-time.After(5 * time.Second):
-                               return errors.New("redirected request's context never expired after root request canceled")
-                       }
-               },
+       c := ts.Client()
+       c.CheckRedirect = func(req *Request, via []*Request) error {
+               cancel()
+               select {
+               case <-req.Context().Done():
+                       return nil
+               case <-time.After(5 * time.Second):
+                       return errors.New("redirected request's context never expired after root request canceled")
+               }
        }
        req, _ := NewRequest("GET", ts.URL, nil)
        req = req.WithContext(ctx)
@@ -461,11 +449,12 @@ func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, wa
        }))
        defer ts.Close()
 
+       c := ts.Client()
        for _, tt := range table {
                content := tt.redirectBody
                req, _ := NewRequest(method, ts.URL+tt.suffix, strings.NewReader(content))
                req.GetBody = func() (io.ReadCloser, error) { return ioutil.NopCloser(strings.NewReader(content)), nil }
-               res, err := DefaultClient.Do(req)
+               res, err := c.Do(req)
 
                if err != nil {
                        t.Fatal(err)
@@ -519,17 +508,12 @@ func TestClientRedirectUseResponse(t *testing.T) {
        }))
        defer ts.Close()
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-
-       c := &Client{
-               Transport: tr,
-               CheckRedirect: func(req *Request, via []*Request) error {
-                       if req.Response == nil {
-                               t.Error("expected non-nil Request.Response")
-                       }
-                       return ErrUseLastResponse
-               },
+       c := ts.Client()
+       c.CheckRedirect = func(req *Request, via []*Request) error {
+               if req.Response == nil {
+                       t.Error("expected non-nil Request.Response")
+               }
+               return ErrUseLastResponse
        }
        res, err := c.Get(ts.URL)
        if err != nil {
@@ -558,7 +542,7 @@ func TestClientRedirect308NoLocation(t *testing.T) {
                w.WriteHeader(308)
        }))
        defer ts.Close()
-       c := &Client{Transport: &Transport{DisableKeepAlives: true}}
+       c := ts.Client()
        res, err := c.Get(ts.URL)
        if err != nil {
                t.Fatal(err)
@@ -586,7 +570,7 @@ func TestClientRedirect308NoGetBody(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
-       c := &Client{Transport: &Transport{DisableKeepAlives: true}}
+       c := ts.Client()
        req.GetBody = nil // so it can't rewind.
        res, err := c.Do(req)
        if err != nil {
@@ -678,12 +662,8 @@ func TestRedirectCookiesJar(t *testing.T) {
        var ts *httptest.Server
        ts = httptest.NewServer(echoCookiesRedirectHandler)
        defer ts.Close()
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{
-               Transport: tr,
-               Jar:       new(TestJar),
-       }
+       c := ts.Client()
+       c.Jar = new(TestJar)
        u, _ := url.Parse(ts.URL)
        c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]})
        resp, err := c.Get(ts.URL)
@@ -727,13 +707,10 @@ func TestJarCalls(t *testing.T) {
        }))
        defer ts.Close()
        jar := new(RecordingJar)
-       c := &Client{
-               Jar: jar,
-               Transport: &Transport{
-                       Dial: func(_ string, _ string) (net.Conn, error) {
-                               return net.Dial("tcp", ts.Listener.Addr().String())
-                       },
-               },
+       c := ts.Client()
+       c.Jar = jar
+       c.Transport.(*Transport).Dial = func(_ string, _ string) (net.Conn, error) {
+               return net.Dial("tcp", ts.Listener.Addr().String())
        }
        _, err := c.Get("http://firsthost.fake/")
        if err != nil {
@@ -845,7 +822,8 @@ func TestClientWrites(t *testing.T) {
                }
                return c, err
        }
-       c := &Client{Transport: &Transport{Dial: dialer}}
+       c := ts.Client()
+       c.Transport.(*Transport).Dial = dialer
 
        _, err := c.Get(ts.URL)
        if err != nil {
@@ -878,14 +856,11 @@ func TestClientInsecureTransport(t *testing.T) {
        // TODO(bradfitz): add tests for skipping hostname checks too?
        // would require a new cert for testing, and probably
        // redundant with these tests.
+       c := ts.Client()
        for _, insecure := range []bool{true, false} {
-               tr := &Transport{
-                       TLSClientConfig: &tls.Config{
-                               InsecureSkipVerify: insecure,
-                       },
+               c.Transport.(*Transport).TLSClientConfig = &tls.Config{
+                       InsecureSkipVerify: insecure,
                }
-               defer tr.CloseIdleConnections()
-               c := &Client{Transport: tr}
                res, err := c.Get(ts.URL)
                if (err == nil) != insecure {
                        t.Errorf("insecure=%v: got unexpected err=%v", insecure, err)
@@ -919,22 +894,6 @@ func TestClientErrorWithRequestURI(t *testing.T) {
        }
 }
 
-func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport {
-       certs := x509.NewCertPool()
-       for _, c := range ts.TLS.Certificates {
-               roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
-               if err != nil {
-                       t.Fatalf("error parsing server's root cert: %v", err)
-               }
-               for _, root := range roots {
-                       certs.AddCert(root)
-               }
-       }
-       return &Transport{
-               TLSClientConfig: &tls.Config{RootCAs: certs},
-       }
-}
-
 func TestClientWithCorrectTLSServerName(t *testing.T) {
        defer afterTest(t)
 
@@ -946,9 +905,8 @@ func TestClientWithCorrectTLSServerName(t *testing.T) {
        }))
        defer ts.Close()
 
-       trans := newTLSTransport(t, ts)
-       trans.TLSClientConfig.ServerName = serverName
-       c := &Client{Transport: trans}
+       c := ts.Client()
+       c.Transport.(*Transport).TLSClientConfig.ServerName = serverName
        if _, err := c.Get(ts.URL); err != nil {
                t.Fatalf("expected successful TLS connection, got error: %v", err)
        }
@@ -961,9 +919,8 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) {
        errc := make(chanWriter, 10) // but only expecting 1
        ts.Config.ErrorLog = log.New(errc, "", 0)
 
-       trans := newTLSTransport(t, ts)
-       trans.TLSClientConfig.ServerName = "badserver"
-       c := &Client{Transport: trans}
+       c := ts.Client()
+       c.Transport.(*Transport).TLSClientConfig.ServerName = "badserver"
        _, err := c.Get(ts.URL)
        if err == nil {
                t.Fatalf("expected an error")
@@ -997,13 +954,12 @@ func TestTransportUsesTLSConfigServerName(t *testing.T) {
        }))
        defer ts.Close()
 
-       tr := newTLSTransport(t, ts)
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
        tr.TLSClientConfig.ServerName = "example.com" // one of httptest's Server cert names
        tr.Dial = func(netw, addr string) (net.Conn, error) {
                return net.Dial(netw, ts.Listener.Addr().String())
        }
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
        res, err := c.Get("https://some-other-host.tld/")
        if err != nil {
                t.Fatal(err)
@@ -1018,13 +974,12 @@ func TestResponseSetsTLSConnectionState(t *testing.T) {
        }))
        defer ts.Close()
 
-       tr := newTLSTransport(t, ts)
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
        tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA}
        tr.Dial = func(netw, addr string) (net.Conn, error) {
                return net.Dial(netw, ts.Listener.Addr().String())
        }
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
        res, err := c.Get("https://example.com/")
        if err != nil {
                t.Fatal(err)
@@ -1119,14 +1074,12 @@ func TestEmptyPasswordAuth(t *testing.T) {
                }
        }))
        defer ts.Close()
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
        req, err := NewRequest("GET", ts.URL, nil)
        if err != nil {
                t.Fatal(err)
        }
        req.URL.User = url.User(gopher)
+       c := ts.Client()
        resp, err := c.Do(req)
        if err != nil {
                t.Fatal(err)
@@ -1503,21 +1456,17 @@ func TestClientCopyHeadersOnRedirect(t *testing.T) {
        defer ts2.Close()
        ts2URL = ts2.URL
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{
-               Transport: tr,
-               CheckRedirect: func(r *Request, via []*Request) error {
-                       want := Header{
-                               "User-Agent": []string{ua},
-                               "X-Foo":      []string{xfoo},
-                               "Referer":    []string{ts2URL},
-                       }
-                       if !reflect.DeepEqual(r.Header, want) {
-                               t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want)
-                       }
-                       return nil
-               },
+       c := ts1.Client()
+       c.CheckRedirect = func(r *Request, via []*Request) error {
+               want := Header{
+                       "User-Agent": []string{ua},
+                       "X-Foo":      []string{xfoo},
+                       "Referer":    []string{ts2URL},
+               }
+               if !reflect.DeepEqual(r.Header, want) {
+                       t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want)
+               }
+               return nil
        }
 
        req, _ := NewRequest("GET", ts2.URL, nil)
@@ -1606,13 +1555,9 @@ func TestClientAltersCookiesOnRedirect(t *testing.T) {
        }))
        defer ts.Close()
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
        jar, _ := cookiejar.New(nil)
-       c := &Client{
-               Transport: tr,
-               Jar:       jar,
-       }
+       c := ts.Client()
+       c.Jar = jar
 
        u, _ := url.Parse(ts.URL)
        req, _ := NewRequest("GET", ts.URL, nil)
@@ -1730,9 +1675,7 @@ func TestClientRedirectTypes(t *testing.T) {
        }))
        defer ts.Close()
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-
+       c := ts.Client()
        for i, tt := range tests {
                handlerc <- func(w ResponseWriter, r *Request) {
                        w.Header().Set("Location", ts.URL)
@@ -1745,7 +1688,6 @@ func TestClientRedirectTypes(t *testing.T) {
                        continue
                }
 
-               c := &Client{Transport: tr}
                c.CheckRedirect = func(req *Request, via []*Request) error {
                        if got, want := req.Method, tt.wantMethod; got != want {
                                return fmt.Errorf("#%d: got next method %q; want %q", i, got, want)
@@ -1799,9 +1741,8 @@ func TestTransportBodyReadError(t *testing.T) {
                w.Header().Set("X-Body-Read", fmt.Sprintf("%v, %v", n, err))
        }))
        defer ts.Close()
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
 
        // Do one initial successful request to create an idle TCP connection
        // for the subsequent request to reuse. (The Transport only retries
index 1de1cd53d098b5397a893e86661f6d31724520d5..e12350efd735d8283449e558dcd717130a40ef87 100644 (file)
@@ -74,6 +74,7 @@ func TestServeFile(t *testing.T) {
                ServeFile(w, r, "testdata/file")
        }))
        defer ts.Close()
+       c := ts.Client()
 
        var err error
 
@@ -91,7 +92,7 @@ func TestServeFile(t *testing.T) {
        req.Method = "GET"
 
        // straight GET
-       _, body := getBody(t, "straight get", req)
+       _, body := getBody(t, "straight get", req, c)
        if !bytes.Equal(body, file) {
                t.Fatalf("body mismatch: got %q, want %q", body, file)
        }
@@ -102,7 +103,7 @@ Cases:
                if rt.r != "" {
                        req.Header.Set("Range", rt.r)
                }
-               resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req)
+               resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req, c)
                if resp.StatusCode != rt.code {
                        t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code)
                }
@@ -704,7 +705,8 @@ func TestDirectoryIfNotModified(t *testing.T) {
        req, _ := NewRequest("GET", ts.URL, nil)
        req.Header.Set("If-Modified-Since", lastMod)
 
-       res, err = DefaultClient.Do(req)
+       c := ts.Client()
+       res, err = c.Do(req)
        if err != nil {
                t.Fatal(err)
        }
@@ -716,7 +718,7 @@ func TestDirectoryIfNotModified(t *testing.T) {
        // Advance the index.html file's modtime, but not the directory's.
        indexFile.modtime = indexFile.modtime.Add(1 * time.Hour)
 
-       res, err = DefaultClient.Do(req)
+       res, err = c.Do(req)
        if err != nil {
                t.Fatal(err)
        }
@@ -995,7 +997,9 @@ func TestServeContent(t *testing.T) {
                for k, v := range tt.reqHeader {
                        req.Header.Set(k, v)
                }
-               res, err := DefaultClient.Do(req)
+
+               c := ts.Client()
+               res, err := c.Do(req)
                if err != nil {
                        t.Fatal(err)
                }
@@ -1050,8 +1054,9 @@ func TestServeContentErrorMessages(t *testing.T) {
        }
        ts := httptest.NewServer(FileServer(fs))
        defer ts.Close()
+       c := ts.Client()
        for _, code := range []int{403, 404, 500} {
-               res, err := DefaultClient.Get(fmt.Sprintf("%s/%d", ts.URL, code))
+               res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, code))
                if err != nil {
                        t.Errorf("Error fetching /%d: %v", code, err)
                        continue
@@ -1125,8 +1130,8 @@ func TestLinuxSendfile(t *testing.T) {
        }
 }
 
-func getBody(t *testing.T, testName string, req Request) (*Response, []byte) {
-       r, err := DefaultClient.Do(&req)
+func getBody(t *testing.T, testName string, req Request, client *Client) (*Response, []byte) {
+       r, err := client.Do(&req)
        if err != nil {
                t.Fatalf("%s: for URL %q, send error: %v", testName, req.URL.String(), err)
        }
index 56ad18ee9ba8b8cbac8381b4c929a02af287c4ed..b5b18c747d9b2a1597e9979b73af7fe21a0e9fc4 100644 (file)
@@ -93,7 +93,9 @@ func NewUnstartedServer(handler http.Handler) *Server {
        return &Server{
                Listener: newLocalListener(),
                Config:   &http.Server{Handler: handler},
-               client:   &http.Client{},
+               client: &http.Client{
+                       Transport: &http.Transport{},
+               },
        }
 }
 
index 7d80fa15dd8cd21e50ac6c7fd59112c71fd62b07..62846de02cd582a329b267c9655bc5a293be6f1f 100644 (file)
@@ -121,3 +121,27 @@ func TestServerClient(t *testing.T) {
                t.Errorf("got %q, want hello", string(got))
        }
 }
+
+// Tests that the Server.Client.Transport interface is implemented
+// by a *http.Transport.
+func TestServerClientTransportType(t *testing.T) {
+       ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+       }))
+       defer ts.Close()
+       client := ts.Client()
+       if _, ok := client.Transport.(*http.Transport); !ok {
+               t.Errorf("got %T, want *http.Transport", client.Transport)
+       }
+}
+
+// Tests that the TLS Server.Client.Transport interface is implemented
+// by a *http.Transport.
+func TestTLSServerClientTransportType(t *testing.T) {
+       ts := NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+       }))
+       defer ts.Close()
+       client := ts.Client()
+       if _, ok := client.Transport.(*http.Transport); !ok {
+               t.Errorf("got %T, want *http.Transport", client.Transport)
+       }
+}
index 9153508ef4b2f954c18e2aae6026dbda9d1baa87..008e4e717fbad0a2eac460514f4325a22970b04b 100644 (file)
@@ -79,6 +79,7 @@ func TestReverseProxy(t *testing.T) {
        proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
        frontend := httptest.NewServer(proxyHandler)
        defer frontend.Close()
+       frontendClient := frontend.Client()
 
        getReq, _ := http.NewRequest("GET", frontend.URL, nil)
        getReq.Host = "some-name"
@@ -86,7 +87,7 @@ func TestReverseProxy(t *testing.T) {
        getReq.Header.Set("Proxy-Connection", "should be deleted")
        getReq.Header.Set("Upgrade", "foo")
        getReq.Close = true
-       res, err := http.DefaultClient.Do(getReq)
+       res, err := frontendClient.Do(getReq)
        if err != nil {
                t.Fatalf("Get: %v", err)
        }
@@ -126,7 +127,7 @@ func TestReverseProxy(t *testing.T) {
        // a response results in a StatusBadGateway.
        getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
        getReq.Close = true
-       res, err = http.DefaultClient.Do(getReq)
+       res, err = frontendClient.Do(getReq)
        if err != nil {
                t.Fatal(err)
        }
@@ -172,7 +173,7 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
        getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken)
        getReq.Header.Set("Upgrade", "original value")
        getReq.Header.Set(fakeConnectionToken, "should be deleted")
-       res, err := http.DefaultClient.Do(getReq)
+       res, err := frontend.Client().Do(getReq)
        if err != nil {
                t.Fatalf("Get: %v", err)
        }
@@ -220,7 +221,7 @@ func TestXForwardedFor(t *testing.T) {
        getReq.Header.Set("Connection", "close")
        getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
        getReq.Close = true
-       res, err := http.DefaultClient.Do(getReq)
+       res, err := frontend.Client().Do(getReq)
        if err != nil {
                t.Fatalf("Get: %v", err)
        }
@@ -259,7 +260,7 @@ func TestReverseProxyQuery(t *testing.T) {
                frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
                req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
                req.Close = true
-               res, err := http.DefaultClient.Do(req)
+               res, err := frontend.Client().Do(req)
                if err != nil {
                        t.Fatalf("%d. Get: %v", i, err)
                }
@@ -295,7 +296,7 @@ func TestReverseProxyFlushInterval(t *testing.T) {
 
        req, _ := http.NewRequest("GET", frontend.URL, nil)
        req.Close = true
-       res, err := http.DefaultClient.Do(req)
+       res, err := frontend.Client().Do(req)
        if err != nil {
                t.Fatalf("Get: %v", err)
        }
@@ -349,13 +350,14 @@ func TestReverseProxyCancelation(t *testing.T) {
 
        frontend := httptest.NewServer(proxyHandler)
        defer frontend.Close()
+       frontendClient := frontend.Client()
 
        getReq, _ := http.NewRequest("GET", frontend.URL, nil)
        go func() {
                <-reqInFlight
-               http.DefaultTransport.(*http.Transport).CancelRequest(getReq)
+               frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
        }()
-       res, err := http.DefaultClient.Do(getReq)
+       res, err := frontendClient.Do(getReq)
        if res != nil {
                t.Errorf("got response %v; want nil", res.Status)
        }
@@ -363,7 +365,7 @@ func TestReverseProxyCancelation(t *testing.T) {
                // This should be an error like:
                // Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079:
                //    use of closed network connection
-               t.Error("DefaultClient.Do() returned nil error; want non-nil error")
+               t.Error("Server.Client().Do() returned nil error; want non-nil error")
        }
 }
 
@@ -428,11 +430,12 @@ func TestUserAgentHeader(t *testing.T) {
        proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
        frontend := httptest.NewServer(proxyHandler)
        defer frontend.Close()
+       frontendClient := frontend.Client()
 
        getReq, _ := http.NewRequest("GET", frontend.URL, nil)
        getReq.Header.Set("User-Agent", explicitUA)
        getReq.Close = true
-       res, err := http.DefaultClient.Do(getReq)
+       res, err := frontendClient.Do(getReq)
        if err != nil {
                t.Fatalf("Get: %v", err)
        }
@@ -441,7 +444,7 @@ func TestUserAgentHeader(t *testing.T) {
        getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil)
        getReq.Header.Set("User-Agent", "")
        getReq.Close = true
-       res, err = http.DefaultClient.Do(getReq)
+       res, err = frontendClient.Do(getReq)
        if err != nil {
                t.Fatalf("Get: %v", err)
        }
@@ -493,7 +496,7 @@ func TestReverseProxyGetPutBuffer(t *testing.T) {
 
        req, _ := http.NewRequest("GET", frontend.URL, nil)
        req.Close = true
-       res, err := http.DefaultClient.Do(req)
+       res, err := frontend.Client().Do(req)
        if err != nil {
                t.Fatalf("Get: %v", err)
        }
@@ -540,7 +543,7 @@ func TestReverseProxy_Post(t *testing.T) {
        defer frontend.Close()
 
        postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
-       res, err := http.DefaultClient.Do(postReq)
+       res, err := frontend.Client().Do(postReq)
        if err != nil {
                t.Fatalf("Do: %v", err)
        }
@@ -573,7 +576,7 @@ func TestReverseProxy_NilBody(t *testing.T) {
        frontend := httptest.NewServer(proxyHandler)
        defer frontend.Close()
 
-       res, err := http.DefaultClient.Get(frontend.URL)
+       res, err := frontend.Client().Get(frontend.URL)
        if err != nil {
                t.Fatal(err)
        }
index 438bd2e58fd89e7c5f405c5de88fe772212da464..fc0437e21155b7d8c540c8269b5ea69b44467c39 100644 (file)
@@ -151,7 +151,3 @@ func waitErrCondition(waitFor, checkEvery time.Duration, fn func() error) error
        }
        return err
 }
-
-func closeClient(c *http.Client) {
-       c.Transport.(*http.Transport).CloseIdleConnections()
-}
index 4c1f6b573dfbeebb319fdb616b6fc531eabba11f..618bdbe54a6926b016fb9a57c5f430f82a30fa83 100644 (file)
@@ -8,6 +8,7 @@ import (
        "bufio"
        "bytes"
        "crypto/tls"
+       "crypto/x509"
        "fmt"
        "io"
        "io/ioutil"
@@ -43,10 +44,7 @@ func TestNextProtoUpgrade(t *testing.T) {
 
        // Normal request, without NPN.
        {
-               tr := newTLSTransport(t, ts)
-               defer tr.CloseIdleConnections()
-               c := &Client{Transport: tr}
-
+               c := ts.Client()
                res, err := c.Get(ts.URL)
                if err != nil {
                        t.Fatal(err)
@@ -63,11 +61,18 @@ func TestNextProtoUpgrade(t *testing.T) {
        // Request to an advertised but unhandled NPN protocol.
        // Server will hang up.
        {
-               tr := newTLSTransport(t, ts)
-               tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"}
+               certPool := x509.NewCertPool()
+               certPool.AddCert(ts.Certificate())
+               tr := &Transport{
+                       TLSClientConfig: &tls.Config{
+                               RootCAs:    certPool,
+                               NextProtos: []string{"unhandled-proto"},
+                       },
+               }
                defer tr.CloseIdleConnections()
-               c := &Client{Transport: tr}
-
+               c := &Client{
+                       Transport: tr,
+               }
                res, err := c.Get(ts.URL)
                if err == nil {
                        defer res.Body.Close()
@@ -80,7 +85,8 @@ func TestNextProtoUpgrade(t *testing.T) {
        // Request using the "tls-0.9" protocol, which we register here.
        // It is HTTP/0.9 over TLS.
        {
-               tlsConfig := newTLSTransport(t, ts).TLSClientConfig
+               c := ts.Client()
+               tlsConfig := c.Transport.(*Transport).TLSClientConfig
                tlsConfig.NextProtos = []string{"tls-0.9"}
                conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
                if err != nil {
index 8092cc1bcbead9a5ae98c0fb4a8d626c28bf2a9a..d301d15eb1b49bf018bca88bbc583c2f2e01fe4d 100644 (file)
@@ -474,9 +474,7 @@ func TestServerTimeouts(t *testing.T) {
        defer ts.Close()
 
        // Hit the HTTP server successfully.
-       tr := &Transport{DisableKeepAlives: true} // they interfere with this test
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
        r, err := c.Get(ts.URL)
        if err != nil {
                t.Fatalf("http Get #1: %v", err)
@@ -548,12 +546,10 @@ func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) {
        ts.StartTLS()
        defer ts.Close()
 
-       tr := newTLSTransport(t, ts)
-       defer tr.CloseIdleConnections()
-       if err := ExportHttp2ConfigureTransport(tr); err != nil {
+       c := ts.Client()
+       if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil {
                t.Fatal(err)
        }
-       c := &Client{Transport: tr}
 
        for i := 1; i <= 3; i++ {
                req, err := NewRequest("GET", ts.URL, nil)
@@ -608,9 +604,7 @@ func TestOnlyWriteTimeout(t *testing.T) {
        ts.Start()
        defer ts.Close()
 
-       tr := &Transport{DisableKeepAlives: false}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
 
        errc := make(chan error)
        go func() {
@@ -671,8 +665,7 @@ func TestIdentityResponse(t *testing.T) {
        ts := httptest.NewServer(handler)
        defer ts.Close()
 
-       c := &Client{Transport: new(Transport)}
-       defer closeClient(c)
+       c := ts.Client()
 
        // Note: this relies on the assumption (which is true) that
        // Get sends HTTP/1.1 or greater requests. Otherwise the
@@ -949,9 +942,8 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
        ts.Start()
        defer ts.Close()
 
-       tr := &Transport{DisableKeepAlives: true}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr, Timeout: time.Second}
+       c := ts.Client()
+       c.Timeout = time.Second
 
        fetch := func(num int, response chan<- string) {
                resp, err := c.Get(ts.URL)
@@ -1022,9 +1014,7 @@ func TestIdentityResponseHeaders(t *testing.T) {
        }))
        defer ts.Close()
 
-       c := &Client{Transport: new(Transport)}
-       defer closeClient(c)
-
+       c := ts.Client()
        res, err := c.Get(ts.URL)
        if err != nil {
                t.Fatalf("Get error: %v", err)
@@ -1145,12 +1135,7 @@ func TestTLSServer(t *testing.T) {
                        t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
                        return
                }
-               noVerifyTransport := &Transport{
-                       TLSClientConfig: &tls.Config{
-                               InsecureSkipVerify: true,
-                       },
-               }
-               client := &Client{Transport: noVerifyTransport}
+               client := ts.Client()
                res, err := client.Get(ts.URL)
                if err != nil {
                        t.Error(err)
@@ -1967,8 +1952,7 @@ func TestTimeoutHandlerRace(t *testing.T) {
        ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, ""))
        defer ts.Close()
 
-       c := &Client{Transport: new(Transport)}
-       defer closeClient(c)
+       c := ts.Client()
 
        var wg sync.WaitGroup
        gate := make(chan bool, 10)
@@ -2011,8 +1995,8 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) {
        if testing.Short() {
                n = 10
        }
-       c := &Client{Transport: new(Transport)}
-       defer closeClient(c)
+
+       c := ts.Client()
        for i := 0; i < n; i++ {
                gate <- true
                wg.Add(1)
@@ -2099,8 +2083,7 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
        ts := httptest.NewServer(TimeoutHandler(handler, timeout, ""))
        defer ts.Close()
 
-       c := &Client{Transport: new(Transport)}
-       defer closeClient(c)
+       c := ts.Client()
 
        // Issue was caused by the timeout handler starting the timer when
        // was created, not when the request. So wait for more than the timeout
@@ -2127,8 +2110,7 @@ func TestTimeoutHandlerEmptyResponse(t *testing.T) {
        ts := httptest.NewServer(TimeoutHandler(handler, timeout, ""))
        defer ts.Close()
 
-       c := &Client{Transport: new(Transport)}
-       defer closeClient(c)
+       c := ts.Client()
 
        res, err := c.Get(ts.URL)
        if err != nil {
@@ -2364,9 +2346,7 @@ func TestServerWriteHijackZeroBytes(t *testing.T) {
        ts.Start()
        defer ts.Close()
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
        res, err := c.Get(ts.URL)
        if err != nil {
                t.Fatal(err)
@@ -2411,8 +2391,7 @@ func TestStripPrefix(t *testing.T) {
        ts := httptest.NewServer(StripPrefix("/foo", h))
        defer ts.Close()
 
-       c := &Client{Transport: new(Transport)}
-       defer closeClient(c)
+       c := ts.Client()
 
        res, err := c.Get(ts.URL + "/foo/bar")
        if err != nil {
@@ -3654,9 +3633,7 @@ func TestServerConnState(t *testing.T) {
        }
        ts.Start()
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
 
        mustGet := func(url string, headers ...string) {
                req, err := NewRequest("GET", url, nil)
@@ -4491,15 +4468,9 @@ func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) {
        b.ResetTimer()
        b.SetParallelism(parallelism)
        b.RunParallel(func(pb *testing.PB) {
-               noVerifyTransport := &Transport{
-                       TLSClientConfig: &tls.Config{
-                               InsecureSkipVerify: true,
-                       },
-               }
-               defer noVerifyTransport.CloseIdleConnections()
-               client := &Client{Transport: noVerifyTransport}
+               c := ts.Client()
                for pb.Next() {
-                       res, err := client.Get(ts.URL)
+                       res, err := c.Get(ts.URL)
                        if err != nil {
                                b.Logf("Get: %v", err)
                                continue
@@ -4934,10 +4905,7 @@ func TestServerIdleTimeout(t *testing.T) {
        ts.Config.IdleTimeout = 2 * time.Second
        ts.Start()
        defer ts.Close()
-
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
 
        get := func() string {
                res, err := c.Get(ts.URL)
@@ -4998,9 +4966,8 @@ func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
        }))
        defer ts.Close()
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
 
        get := func() string { return get(t, c, ts.URL) }
 
@@ -5119,9 +5086,7 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) {
        ts.Start()
        defer ts.Close()
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
 
        res, err := c.Get(ts.URL)
        if err != nil {
index cb315f14f464e32c6a91d8c742172a1559e4975c..09bfef4b10df8d357b224f17ac3c7672b6c12a0a 100644 (file)
@@ -131,11 +131,9 @@ func TestTransportKeepAlives(t *testing.T) {
        ts := httptest.NewServer(hostPortHandler)
        defer ts.Close()
 
+       c := ts.Client()
        for _, disableKeepAlive := range []bool{false, true} {
-               tr := &Transport{DisableKeepAlives: disableKeepAlive}
-               defer tr.CloseIdleConnections()
-               c := &Client{Transport: tr}
-
+               c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
                fetch := func(n int) string {
                        res, err := c.Get(ts.URL)
                        if err != nil {
@@ -166,12 +164,11 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) {
 
        connSet, testDial := makeTestDial(t)
 
-       for _, connectionClose := range []bool{false, true} {
-               tr := &Transport{
-                       Dial: testDial,
-               }
-               c := &Client{Transport: tr}
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
+       tr.Dial = testDial
 
+       for _, connectionClose := range []bool{false, true} {
                fetch := func(n int) string {
                        req := new(Request)
                        var err error
@@ -217,12 +214,10 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) {
 
        connSet, testDial := makeTestDial(t)
 
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
+       tr.Dial = testDial
        for _, connectionClose := range []bool{false, true} {
-               tr := &Transport{
-                       Dial: testDial,
-               }
-               c := &Client{Transport: tr}
-
                fetch := func(n int) string {
                        req := new(Request)
                        var err error
@@ -273,10 +268,9 @@ func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
        ts := httptest.NewServer(hostPortHandler)
        defer ts.Close()
 
-       tr := &Transport{
-               DisableKeepAlives: true,
-       }
-       c := &Client{Transport: tr}
+       c := ts.Client()
+       c.Transport.(*Transport).DisableKeepAlives = true
+
        res, err := c.Get(ts.URL)
        if err != nil {
                t.Fatal(err)
@@ -291,9 +285,8 @@ func TestTransportIdleCacheKeys(t *testing.T) {
        defer afterTest(t)
        ts := httptest.NewServer(hostPortHandler)
        defer ts.Close()
-
-       tr := &Transport{DisableKeepAlives: false}
-       c := &Client{Transport: tr}
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
 
        if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
                t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
@@ -385,9 +378,11 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
                }
        }))
        defer ts.Close()
+
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
        maxIdleConnsPerHost := 2
-       tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConnsPerHost}
-       c := &Client{Transport: tr}
+       tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
 
        // Start 3 outstanding requests and wait for the server to get them.
        // Their responses will hang until we write to resch, though.
@@ -450,9 +445,8 @@ func TestTransportRemovesDeadIdleConnections(t *testing.T) {
        }))
        defer ts.Close()
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
 
        doReq := func(name string) string {
                // Do a POST instead of a GET to prevent the Transport's
@@ -496,9 +490,7 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) {
        defer afterTest(t)
        ts := httptest.NewServer(hostPortHandler)
        defer ts.Close()
-
-       tr := &Transport{}
-       c := &Client{Transport: tr}
+       c := ts.Client()
 
        fetch := func(n, retries int) string {
                condFatalf := func(format string, arg ...interface{}) {
@@ -564,10 +556,7 @@ func TestStressSurpriseServerCloses(t *testing.T) {
                conn.Close()
        }))
        defer ts.Close()
-
-       tr := &Transport{DisableKeepAlives: false}
-       c := &Client{Transport: tr}
-       defer tr.CloseIdleConnections()
+       c := ts.Client()
 
        // Do a bunch of traffic from different goroutines. Send to activityc
        // after each request completes, regardless of whether it failed.
@@ -620,9 +609,8 @@ func TestTransportHeadResponses(t *testing.T) {
                w.WriteHeader(200)
        }))
        defer ts.Close()
+       c := ts.Client()
 
-       tr := &Transport{DisableKeepAlives: false}
-       c := &Client{Transport: tr}
        for i := 0; i < 2; i++ {
                res, err := c.Head(ts.URL)
                if err != nil {
@@ -656,10 +644,7 @@ func TestTransportHeadChunkedResponse(t *testing.T) {
                w.WriteHeader(200)
        }))
        defer ts.Close()
-
-       tr := &Transport{DisableKeepAlives: false}
-       c := &Client{Transport: tr}
-       defer tr.CloseIdleConnections()
+       c := ts.Client()
 
        // Ensure that we wait for the readLoop to complete before
        // calling Head again
@@ -720,6 +705,7 @@ func TestRoundTripGzip(t *testing.T) {
                }
        }))
        defer ts.Close()
+       tr := ts.Client().Transport.(*Transport)
 
        for i, test := range roundTripTests {
                // Test basic request (no accept-encoding)
@@ -727,7 +713,7 @@ func TestRoundTripGzip(t *testing.T) {
                if test.accept != "" {
                        req.Header.Set("Accept-Encoding", test.accept)
                }
-               res, err := DefaultTransport.RoundTrip(req)
+               res, err := tr.RoundTrip(req)
                var body []byte
                if test.compressed {
                        var r *gzip.Reader
@@ -792,10 +778,9 @@ func TestTransportGzip(t *testing.T) {
                gz.Close()
        }))
        defer ts.Close()
+       c := ts.Client()
 
        for _, chunked := range []string{"1", "0"} {
-               c := &Client{Transport: &Transport{}}
-
                // First fetch something large, but only read some of it.
                res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
                if err != nil {
@@ -845,7 +830,6 @@ func TestTransportGzip(t *testing.T) {
        }
 
        // And a HEAD request too, because they're always weird.
-       c := &Client{Transport: &Transport{}}
        res, err := c.Head(ts.URL)
        if err != nil {
                t.Fatalf("Head: %v", err)
@@ -915,11 +899,13 @@ func TestTransportExpect100Continue(t *testing.T) {
                {path: "/timeout", body: []byte("hello"), sent: 5, status: 200},   // Timeout exceeded and entire body is sent.
        }
 
+       c := ts.Client()
        for i, v := range tests {
-               tr := &Transport{ExpectContinueTimeout: 2 * time.Second}
+               tr := &Transport{
+                       ExpectContinueTimeout: 2 * time.Second,
+               }
                defer tr.CloseIdleConnections()
-               c := &Client{Transport: tr}
-
+               c.Transport = tr
                body := bytes.NewReader(v.body)
                req, err := NewRequest("PUT", ts.URL+v.path, body)
                if err != nil {
@@ -1016,7 +1002,8 @@ func TestSocks5Proxy(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
-       c := &Client{Transport: &Transport{Proxy: ProxyURL(pu)}}
+       c := ts.Client()
+       c.Transport.(*Transport).Proxy = ProxyURL(pu)
        if _, err := c.Head(ts.URL); err != nil {
                t.Error(err)
        }
@@ -1052,7 +1039,8 @@ func TestTransportProxy(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
-       c := &Client{Transport: &Transport{Proxy: ProxyURL(pu)}}
+       c := ts.Client()
+       c.Transport.(*Transport).Proxy = ProxyURL(pu)
        if _, err := c.Head(ts.URL); err != nil {
                t.Error(err)
        }
@@ -1122,9 +1110,7 @@ func TestTransportGzipRecursive(t *testing.T) {
        }))
        defer ts.Close()
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
        res, err := c.Get(ts.URL)
        if err != nil {
                t.Fatal(err)
@@ -1152,9 +1138,7 @@ func TestTransportGzipShort(t *testing.T) {
        }))
        defer ts.Close()
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
        res, err := c.Get(ts.URL)
        if err != nil {
                t.Fatal(err)
@@ -1195,9 +1179,8 @@ func TestTransportPersistConnLeak(t *testing.T) {
                w.WriteHeader(204)
        }))
        defer ts.Close()
-
-       tr := &Transport{}
-       c := &Client{Transport: tr}
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
 
        n0 := runtime.NumGoroutine()
 
@@ -1260,9 +1243,8 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) {
        ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
        }))
        defer ts.Close()
-
-       tr := &Transport{}
-       c := &Client{Transport: tr}
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
 
        n0 := runtime.NumGoroutine()
        body := []byte("Hello")
@@ -1294,8 +1276,7 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) {
 // This used to crash; https://golang.org/issue/3266
 func TestTransportIdleConnCrash(t *testing.T) {
        defer afterTest(t)
-       tr := &Transport{}
-       c := &Client{Transport: tr}
+       var tr *Transport
 
        unblockCh := make(chan bool, 1)
        ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -1303,6 +1284,8 @@ func TestTransportIdleConnCrash(t *testing.T) {
                tr.CloseIdleConnections()
        }))
        defer ts.Close()
+       c := ts.Client()
+       tr = c.Transport.(*Transport)
 
        didreq := make(chan bool)
        go func() {
@@ -1332,8 +1315,7 @@ func TestIssue3644(t *testing.T) {
                }
        }))
        defer ts.Close()
-       tr := &Transport{}
-       c := &Client{Transport: tr}
+       c := ts.Client()
        res, err := c.Get(ts.URL)
        if err != nil {
                t.Fatal(err)
@@ -1358,8 +1340,7 @@ func TestIssue3595(t *testing.T) {
                Error(w, deniedMsg, StatusUnauthorized)
        }))
        defer ts.Close()
-       tr := &Transport{}
-       c := &Client{Transport: tr}
+       c := ts.Client()
        res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
        if err != nil {
                t.Errorf("Post: %v", err)
@@ -1383,8 +1364,8 @@ func TestChunkedNoContent(t *testing.T) {
        }))
        defer ts.Close()
 
+       c := ts.Client()
        for _, closeBody := range []bool{true, false} {
-               c := &Client{Transport: &Transport{}}
                const n = 4
                for i := 1; i <= n; i++ {
                        res, err := c.Get(ts.URL)
@@ -1424,10 +1405,7 @@ func TestTransportConcurrency(t *testing.T) {
        SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
        defer SetPendingDialHooks(nil, nil)
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-
-       c := &Client{Transport: tr}
+       c := ts.Client()
        reqs := make(chan string)
        defer close(reqs)
 
@@ -1469,23 +1447,20 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
                io.Copy(w, neverEnding('a'))
        })
        ts := httptest.NewServer(mux)
+       defer ts.Close()
        timeout := 100 * time.Millisecond
 
-       client := &Client{
-               Transport: &Transport{
-                       Dial: func(n, addr string) (net.Conn, error) {
-                               conn, err := net.Dial(n, addr)
-                               if err != nil {
-                                       return nil, err
-                               }
-                               conn.SetDeadline(time.Now().Add(timeout))
-                               if debug {
-                                       conn = NewLoggingConn("client", conn)
-                               }
-                               return conn, nil
-                       },
-                       DisableKeepAlives: true,
-               },
+       c := ts.Client()
+       c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
+               conn, err := net.Dial(n, addr)
+               if err != nil {
+                       return nil, err
+               }
+               conn.SetDeadline(time.Now().Add(timeout))
+               if debug {
+                       conn = NewLoggingConn("client", conn)
+               }
+               return conn, nil
        }
 
        getFailed := false
@@ -1497,7 +1472,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
                if debug {
                        println("run", i+1, "of", nRuns)
                }
-               sres, err := client.Get(ts.URL + "/get")
+               sres, err := c.Get(ts.URL + "/get")
                if err != nil {
                        if !getFailed {
                                // Make the timeout longer, once.
@@ -1519,7 +1494,6 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
        if debug {
                println("tests complete; waiting for handlers to finish")
        }
-       ts.Close()
 }
 
 func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
@@ -1537,21 +1511,17 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
        ts := httptest.NewServer(mux)
        timeout := 100 * time.Millisecond
 
-       client := &Client{
-               Transport: &Transport{
-                       Dial: func(n, addr string) (net.Conn, error) {
-                               conn, err := net.Dial(n, addr)
-                               if err != nil {
-                                       return nil, err
-                               }
-                               conn.SetDeadline(time.Now().Add(timeout))
-                               if debug {
-                                       conn = NewLoggingConn("client", conn)
-                               }
-                               return conn, nil
-                       },
-                       DisableKeepAlives: true,
-               },
+       c := ts.Client()
+       c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
+               conn, err := net.Dial(n, addr)
+               if err != nil {
+                       return nil, err
+               }
+               conn.SetDeadline(time.Now().Add(timeout))
+               if debug {
+                       conn = NewLoggingConn("client", conn)
+               }
+               return conn, nil
        }
 
        getFailed := false
@@ -1563,7 +1533,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
                if debug {
                        println("run", i+1, "of", nRuns)
                }
-               sres, err := client.Get(ts.URL + "/get")
+               sres, err := c.Get(ts.URL + "/get")
                if err != nil {
                        if !getFailed {
                                // Make the timeout longer, once.
@@ -1577,7 +1547,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
                        break
                }
                req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
-               _, err = client.Do(req)
+               _, err = c.Do(req)
                if err == nil {
                        sres.Body.Close()
                        t.Errorf("Unexpected successful PUT")
@@ -1609,11 +1579,8 @@ func TestTransportResponseHeaderTimeout(t *testing.T) {
        ts := httptest.NewServer(mux)
        defer ts.Close()
 
-       tr := &Transport{
-               ResponseHeaderTimeout: 500 * time.Millisecond,
-       }
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
+       c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond
 
        tests := []struct {
                path    string
@@ -1680,9 +1647,8 @@ func TestTransportCancelRequest(t *testing.T) {
        defer ts.Close()
        defer close(unblockc)
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
 
        req, _ := NewRequest("GET", ts.URL, nil)
        res, err := c.Do(req)
@@ -1790,9 +1756,8 @@ func TestCancelRequestWithChannel(t *testing.T) {
        defer ts.Close()
        defer close(unblockc)
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
 
        req, _ := NewRequest("GET", ts.URL, nil)
        ch := make(chan struct{})
@@ -1849,9 +1814,7 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) {
        defer ts.Close()
        defer close(unblockc)
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
 
        req, _ := NewRequest("GET", ts.URL, nil)
        if withCtx {
@@ -1939,9 +1902,8 @@ func TestTransportCloseResponseBody(t *testing.T) {
        }))
        defer ts.Close()
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
 
        req, _ := NewRequest("GET", ts.URL, nil)
        defer tr.CancelRequest(req)
@@ -2061,18 +2023,12 @@ func TestTransportSocketLateBinding(t *testing.T) {
        defer ts.Close()
 
        dialGate := make(chan bool, 1)
-       tr := &Transport{
-               Dial: func(n, addr string) (net.Conn, error) {
-                       if <-dialGate {
-                               return net.Dial(n, addr)
-                       }
-                       return nil, errors.New("manually closed")
-               },
-               DisableKeepAlives: false,
-       }
-       defer tr.CloseIdleConnections()
-       c := &Client{
-               Transport: tr,
+       c := ts.Client()
+       c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
+               if <-dialGate {
+                       return net.Dial(n, addr)
+               }
+               return nil, errors.New("manually closed")
        }
 
        dialGate <- true // only allow one dial
@@ -2326,14 +2282,11 @@ func TestIdleConnChannelLeak(t *testing.T) {
        SetReadLoopBeforeNextReadHook(func() { didRead <- true })
        defer SetReadLoopBeforeNextReadHook(nil)
 
-       tr := &Transport{
-               Dial: func(netw, addr string) (net.Conn, error) {
-                       return net.Dial(netw, ts.Listener.Addr().String())
-               },
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
+       tr.Dial = func(netw, addr string) (net.Conn, error) {
+               return net.Dial(netw, ts.Listener.Addr().String())
        }
-       defer tr.CloseIdleConnections()
-
-       c := &Client{Transport: tr}
 
        // First, without keep-alives.
        for _, disableKeep := range []bool{true, false} {
@@ -2376,13 +2329,11 @@ func TestTransportClosesRequestBody(t *testing.T) {
        }))
        defer ts.Close()
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       cl := &Client{Transport: tr}
+       c := ts.Client()
 
        closes := 0
 
-       res, err := cl.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
+       res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
        if err != nil {
                t.Fatal(err)
        }
@@ -2468,20 +2419,16 @@ func TestTLSServerClosesConnection(t *testing.T) {
                fmt.Fprintf(w, "hello")
        }))
        defer ts.Close()
-       tr := &Transport{
-               TLSClientConfig: &tls.Config{
-                       InsecureSkipVerify: true,
-               },
-       }
-       defer tr.CloseIdleConnections()
-       client := &Client{Transport: tr}
+
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
 
        var nSuccess = 0
        var errs []error
        const trials = 20
        for i := 0; i < trials; i++ {
                tr.CloseIdleConnections()
-               res, err := client.Get(ts.URL + "/keep-alive-then-die")
+               res, err := c.Get(ts.URL + "/keep-alive-then-die")
                if err != nil {
                        t.Fatal(err)
                }
@@ -2496,7 +2443,7 @@ func TestTLSServerClosesConnection(t *testing.T) {
 
                // Now try again and see if we successfully
                // pick a new connection.
-               res, err = client.Get(ts.URL + "/")
+               res, err = c.Get(ts.URL + "/")
                if err != nil {
                        errs = append(errs, err)
                        continue
@@ -2575,22 +2522,20 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
                go io.Copy(ioutil.Discard, conn)
        }))
        defer ts.Close()
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       client := &Client{Transport: tr}
+       c := ts.Client()
 
        const bodySize = 256 << 10
        finalBit := make(byteFromChanReader, 1)
        req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
        req.ContentLength = bodySize
-       res, err := client.Do(req)
+       res, err := c.Do(req)
        if err := wantBody(res, err, "foo"); err != nil {
                t.Errorf("POST response: %v", err)
        }
        donec := make(chan bool)
        go func() {
                defer close(donec)
-               res, err = client.Get(ts.URL)
+               res, err = c.Get(ts.URL)
                if err := wantBody(res, err, "bar"); err != nil {
                        t.Errorf("GET response: %v", err)
                        return
@@ -2622,10 +2567,9 @@ func TestTransportIssue10457(t *testing.T) {
                conn.Close()
        }))
        defer ts.Close()
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       cl := &Client{Transport: tr}
-       res, err := cl.Get(ts.URL)
+       c := ts.Client()
+
+       res, err := c.Get(ts.URL)
        if err != nil {
                t.Fatalf("Get: %v", err)
        }
@@ -2686,29 +2630,26 @@ func TestRetryIdempotentRequestsOnError(t *testing.T) {
        defer ts.Close()
 
        var writeNumAtomic int32
-       tr := &Transport{
-               Dial: func(network, addr string) (net.Conn, error) {
-                       logf("Dial")
-                       c, err := net.Dial(network, ts.Listener.Addr().String())
-                       if err != nil {
-                               logf("Dial error: %v", err)
-                               return nil, err
-                       }
-                       return &writerFuncConn{
-                               Conn: c,
-                               write: func(p []byte) (n int, err error) {
-                                       if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
-                                               logf("intentional write failure")
-                                               return 0, errors.New("second write fails")
-                                       }
-                                       logf("Write(%q)", p)
-                                       return c.Write(p)
-                               },
-                       }, nil
-               },
+       c := ts.Client()
+       c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
+               logf("Dial")
+               c, err := net.Dial(network, ts.Listener.Addr().String())
+               if err != nil {
+                       logf("Dial error: %v", err)
+                       return nil, err
+               }
+               return &writerFuncConn{
+                       Conn: c,
+                       write: func(p []byte) (n int, err error) {
+                               if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
+                                       logf("intentional write failure")
+                                       return 0, errors.New("second write fails")
+                               }
+                               logf("Write(%q)", p)
+                               return c.Write(p)
+                       },
+               }, nil
        }
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
 
        SetRoundTripRetried(func() {
                logf("Retried.")
@@ -2752,6 +2693,7 @@ func TestTransportClosesBodyOnError(t *testing.T) {
                readBody <- err
        }))
        defer ts.Close()
+       c := ts.Client()
        fakeErr := errors.New("fake error")
        didClose := make(chan bool, 1)
        req, _ := NewRequest("POST", ts.URL, struct {
@@ -2767,7 +2709,7 @@ func TestTransportClosesBodyOnError(t *testing.T) {
                        return nil
                }),
        })
-       res, err := DefaultClient.Do(req)
+       res, err := c.Do(req)
        if res != nil {
                defer res.Body.Close()
        }
@@ -2801,23 +2743,19 @@ func TestTransportDialTLS(t *testing.T) {
                mu.Unlock()
        }))
        defer ts.Close()
-       tr := &Transport{
-               DialTLS: func(netw, addr string) (net.Conn, error) {
-                       mu.Lock()
-                       didDial = true
-                       mu.Unlock()
-                       c, err := tls.Dial(netw, addr, &tls.Config{
-                               InsecureSkipVerify: true,
-                       })
-                       if err != nil {
-                               return nil, err
-                       }
-                       return c, c.Handshake()
-               },
+       c := ts.Client()
+       c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
+               mu.Lock()
+               didDial = true
+               mu.Unlock()
+               c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
+               if err != nil {
+                       return nil, err
+               }
+               return c, c.Handshake()
        }
-       defer tr.CloseIdleConnections()
-       client := &Client{Transport: tr}
-       res, err := client.Get(ts.URL)
+
+       res, err := c.Get(ts.URL)
        if err != nil {
                t.Fatal(err)
        }
@@ -2899,10 +2837,11 @@ func TestTransportRangeAndGzip(t *testing.T) {
                reqc <- r
        }))
        defer ts.Close()
+       c := ts.Client()
 
        req, _ := NewRequest("GET", ts.URL, nil)
        req.Header.Set("Range", "bytes=7-11")
-       res, err := DefaultClient.Do(req)
+       res, err := c.Do(req)
        if err != nil {
                t.Fatal(err)
        }
@@ -2931,9 +2870,7 @@ func TestTransportResponseCancelRace(t *testing.T) {
                w.Write(b[:])
        }))
        defer ts.Close()
-
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
+       tr := ts.Client().Transport.(*Transport)
 
        req, err := NewRequest("GET", ts.URL, nil)
        if err != nil {
@@ -2967,9 +2904,7 @@ func TestTransportDialCancelRace(t *testing.T) {
 
        ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
        defer ts.Close()
-
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
+       tr := ts.Client().Transport.(*Transport)
 
        req, err := NewRequest("GET", ts.URL, nil)
        if err != nil {
@@ -3096,6 +3031,7 @@ func TestTransportPrefersResponseOverWriteError(t *testing.T) {
                w.WriteHeader(StatusOK)
        }))
        defer ts.Close()
+       c := ts.Client()
 
        fail := 0
        count := 100
@@ -3105,10 +3041,7 @@ func TestTransportPrefersResponseOverWriteError(t *testing.T) {
                if err != nil {
                        t.Fatal(err)
                }
-               tr := new(Transport)
-               defer tr.CloseIdleConnections()
-               client := &Client{Transport: tr}
-               resp, err := client.Do(req)
+               resp, err := c.Do(req)
                if err != nil {
                        fail++
                        t.Logf("%d = %#v", i, err)
@@ -3321,10 +3254,8 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) {
                w.Write(rgz) // arbitrary gzip response
        }))
        defer ts.Close()
+       c := ts.Client()
 
-       tr := &Transport{}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
        for i := 0; i < 2; i++ {
                res, err := c.Get(ts.URL)
                if err != nil {
@@ -3353,12 +3284,9 @@ func TestTransportResponseHeaderLength(t *testing.T) {
                }
        }))
        defer ts.Close()
+       c := ts.Client()
+       c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
 
-       tr := &Transport{
-               MaxResponseHeaderBytes: 512 << 10,
-       }
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
        if res, err := c.Get(ts.URL); err != nil {
                t.Fatal(err)
        } else {
@@ -3619,8 +3547,8 @@ func TestTransportRejectsAlphaPort(t *testing.T) {
 // connections. The http2 test is done in TestTransportEventTrace_h2
 func TestTLSHandshakeTrace(t *testing.T) {
        defer afterTest(t)
-       s := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
-       defer s.Close()
+       ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+       defer ts.Close()
 
        var mu sync.Mutex
        var start, done bool
@@ -3640,10 +3568,8 @@ func TestTLSHandshakeTrace(t *testing.T) {
                },
        }
 
-       tr := &Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
-       req, err := NewRequest("GET", s.URL, nil)
+       c := ts.Client()
+       req, err := NewRequest("GET", ts.URL, nil)
        if err != nil {
                t.Fatal("Unable to construct test request:", err)
        }
@@ -3670,16 +3596,14 @@ func TestTransportMaxIdleConns(t *testing.T) {
                // No body for convenience.
        }))
        defer ts.Close()
-       tr := &Transport{
-               MaxIdleConns: 4,
-       }
-       defer tr.CloseIdleConnections()
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
+       tr.MaxIdleConns = 4
 
        ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
        if err != nil {
                t.Fatal(err)
        }
-       c := &Client{Transport: tr}
        ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) {
                return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
        })
@@ -3975,17 +3899,16 @@ func TestTransportProxyConnectHeader(t *testing.T) {
                c.Close()
        }))
        defer ts.Close()
-       tr := &Transport{
-               ProxyConnectHeader: Header{
-                       "User-Agent": {"foo"},
-                       "Other":      {"bar"},
-               },
-               Proxy: func(r *Request) (*url.URL, error) {
-                       return url.Parse(ts.URL)
-               },
+
+       c := ts.Client()
+       c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
+               return url.Parse(ts.URL)
        }
-       defer tr.CloseIdleConnections()
-       c := &Client{Transport: tr}
+       c.Transport.(*Transport).ProxyConnectHeader = Header{
+               "User-Agent": {"foo"},
+               "Other":      {"bar"},
+       }
+
        res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
        if err == nil {
                res.Body.Close()