]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Client.Timeout for end-to-end timeouts
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 3 Mar 2014 04:39:20 +0000 (20:39 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 3 Mar 2014 04:39:20 +0000 (20:39 -0800)
Fixes #3362

LGTM=josharian
R=golang-codereviews, josharian
CC=adg, dsymonds, golang-codereviews, n13m3y3r
https://golang.org/cl/70120045

src/pkg/net/http/client.go
src/pkg/net/http/client_test.go

index 22f2e865cf7d69ef3fc0e10f9bd9e621d517c692..952799a1be613fc80206d976b594bf963b6b4c37 100644 (file)
@@ -17,6 +17,8 @@ import (
        "log"
        "net/url"
        "strings"
+       "sync"
+       "time"
 )
 
 // A Client is an HTTP client. Its zero value (DefaultClient) is a
@@ -52,6 +54,21 @@ type Client struct {
        // If Jar is nil, cookies are not sent in requests and ignored
        // in responses.
        Jar CookieJar
+
+       // Timeout specifies the end-to-end timeout for requests made
+       // via this Client. The timeout includes connection time, any
+       // redirects, and reading the response body. The timeout
+       // remains running once Get, Head, Post, or Do returns and
+       // will interrupt the read of the Response.Body if EOF hasn't
+       // been reached.
+       //
+       // A Timeout of zero means no timeout.
+       //
+       // The Client's Transport must support the CancelRequest
+       // method or Client will return errors when attempting to make
+       // a request with Get, Head, Post, or Do. Client's default
+       // Transport (DefaultTransport) supports CancelRequest.
+       Timeout time.Duration
 }
 
 // DefaultClient is the default Client and is used by Get, Head, and Post.
@@ -97,7 +114,7 @@ func (c *Client) send(req *Request) (*Response, error) {
                        req.AddCookie(cookie)
                }
        }
-       resp, err := send(req, c.Transport)
+       resp, err := send(req, c.transport())
        if err != nil {
                return nil, err
        }
@@ -134,15 +151,18 @@ func (c *Client) Do(req *Request) (resp *Response, err error) {
        return c.send(req)
 }
 
+func (c *Client) transport() RoundTripper {
+       if c.Transport != nil {
+               return c.Transport
+       }
+       return DefaultTransport
+}
+
 // send issues an HTTP request.
 // Caller should close resp.Body when done reading from it.
 func send(req *Request, t RoundTripper) (resp *Response, err error) {
        if t == nil {
-               t = DefaultTransport
-               if t == nil {
-                       err = errors.New("http: no Client.Transport or DefaultTransport")
-                       return
-               }
+               return nil, errors.New("http: no Client.Transport or DefaultTransport")
        }
 
        if req.URL == nil {
@@ -260,18 +280,36 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
                return nil, errors.New("http: nil Request.URL")
        }
 
+       var reqmu sync.Mutex // guards req
        req := ireq
+
+       var timer *time.Timer
+       if c.Timeout > 0 {
+               type canceler interface {
+                       CancelRequest(*Request)
+               }
+               tr, ok := c.transport().(canceler)
+               if !ok {
+                       return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport())
+               }
+               timer = time.AfterFunc(c.Timeout, func() {
+                       reqmu.Lock()
+                       defer reqmu.Unlock()
+                       tr.CancelRequest(req)
+               })
+       }
+
        urlStr := "" // next relative or absolute URL to fetch (after first request)
        redirectFailed := false
        for redirect := 0; ; redirect++ {
                if redirect != 0 {
-                       req = new(Request)
-                       req.Method = ireq.Method
+                       nreq := new(Request)
+                       nreq.Method = ireq.Method
                        if ireq.Method == "POST" || ireq.Method == "PUT" {
-                               req.Method = "GET"
+                               nreq.Method = "GET"
                        }
-                       req.Header = make(Header)
-                       req.URL, err = base.Parse(urlStr)
+                       nreq.Header = make(Header)
+                       nreq.URL, err = base.Parse(urlStr)
                        if err != nil {
                                break
                        }
@@ -279,15 +317,18 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
                                // Add the Referer header.
                                lastReq := via[len(via)-1]
                                if lastReq.URL.Scheme != "https" {
-                                       req.Header.Set("Referer", lastReq.URL.String())
+                                       nreq.Header.Set("Referer", lastReq.URL.String())
                                }
 
-                               err = redirectChecker(req, via)
+                               err = redirectChecker(nreq, via)
                                if err != nil {
                                        redirectFailed = true
                                        break
                                }
                        }
+                       reqmu.Lock()
+                       req = nreq
+                       reqmu.Unlock()
                }
 
                urlStr = req.URL.String()
@@ -305,7 +346,10 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
                        via = append(via, req)
                        continue
                }
-               return
+               if timer != nil {
+                       resp.Body = &cancelTimerBody{timer, resp.Body}
+               }
+               return resp, nil
        }
 
        method := ireq.Method
@@ -408,3 +452,22 @@ func (c *Client) Head(url string) (resp *Response, err error) {
        }
        return c.doFollowingRedirects(req, shouldRedirectGet)
 }
+
+type cancelTimerBody struct {
+       t  *time.Timer
+       rc io.ReadCloser
+}
+
+func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
+       n, err = b.rc.Read(p)
+       if err == io.EOF {
+               b.t.Stop()
+       }
+       return
+}
+
+func (b *cancelTimerBody) Close() error {
+       err := b.rc.Close()
+       b.t.Stop()
+       return err
+}
index b81af1a479feede344953fb0da09121785c706bd..f44fb199dceccc4231508042d8fec926c48abccf 100644 (file)
@@ -812,3 +812,70 @@ func TestBasicAuth(t *testing.T) {
                t.Errorf("Invalid auth %q", auth)
        }
 }
+
+func TestClientTimeout(t *testing.T) {
+       if testing.Short() {
+               t.Skip("skipping in short mode")
+       }
+       defer afterTest(t)
+       sawRoot := make(chan bool, 1)
+       sawSlow := make(chan bool, 1)
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               if r.URL.Path == "/" {
+                       sawRoot <- true
+                       Redirect(w, r, "/slow", StatusFound)
+                       return
+               }
+               if r.URL.Path == "/slow" {
+                       w.Write([]byte("Hello"))
+                       w.(Flusher).Flush()
+                       sawSlow <- true
+                       time.Sleep(2 * time.Second)
+                       return
+               }
+       }))
+       defer ts.Close()
+       const timeout = 500 * time.Millisecond
+       c := &Client{
+               Timeout: timeout,
+       }
+
+       res, err := c.Get(ts.URL)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       select {
+       case <-sawRoot:
+               // good.
+       default:
+               t.Fatal("handler never got / request")
+       }
+
+       select {
+       case <-sawSlow:
+               // good.
+       default:
+               t.Fatal("handler never got /slow request")
+       }
+
+       var all []byte
+       errc := make(chan error, 1)
+       go func() {
+               var err error
+               all, err = ioutil.ReadAll(res.Body)
+               errc <- err
+               res.Body.Close()
+       }()
+
+       const failTime = timeout * 2
+       select {
+       case err := <-errc:
+               if err == nil {
+                       t.Error("expected error from ReadAll")
+               }
+               t.Logf("Got expected ReadAll error of %v after reading body %q", err, all)
+       case <-time.After(failTime):
+               t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout)
+       }
+}