From: Brad Fitzpatrick Date: Mon, 3 Mar 2014 04:39:20 +0000 (-0800) Subject: net/http: add Client.Timeout for end-to-end timeouts X-Git-Tag: go1.3beta1~514 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=2ad72ecf341d553fe9b698343f9bae6d26344619;p=gostls13.git net/http: add Client.Timeout for end-to-end timeouts Fixes #3362 LGTM=josharian R=golang-codereviews, josharian CC=adg, dsymonds, golang-codereviews, n13m3y3r https://golang.org/cl/70120045 --- diff --git a/src/pkg/net/http/client.go b/src/pkg/net/http/client.go index 22f2e865cf..952799a1be 100644 --- a/src/pkg/net/http/client.go +++ b/src/pkg/net/http/client.go @@ -17,6 +17,8 @@ import ( "log" "net/url" "strings" + "sync" + "time" ) // A Client is an HTTP client. Its zero value (DefaultClient) is a @@ -52,6 +54,21 @@ type Client struct { // If Jar is nil, cookies are not sent in requests and ignored // in responses. Jar CookieJar + + // Timeout specifies the end-to-end timeout for requests made + // via this Client. The timeout includes connection time, any + // redirects, and reading the response body. The timeout + // remains running once Get, Head, Post, or Do returns and + // will interrupt the read of the Response.Body if EOF hasn't + // been reached. + // + // A Timeout of zero means no timeout. + // + // The Client's Transport must support the CancelRequest + // method or Client will return errors when attempting to make + // a request with Get, Head, Post, or Do. Client's default + // Transport (DefaultTransport) supports CancelRequest. + Timeout time.Duration } // DefaultClient is the default Client and is used by Get, Head, and Post. @@ -97,7 +114,7 @@ func (c *Client) send(req *Request) (*Response, error) { req.AddCookie(cookie) } } - resp, err := send(req, c.Transport) + resp, err := send(req, c.transport()) if err != nil { return nil, err } @@ -134,15 +151,18 @@ func (c *Client) Do(req *Request) (resp *Response, err error) { return c.send(req) } +func (c *Client) transport() RoundTripper { + if c.Transport != nil { + return c.Transport + } + return DefaultTransport +} + // send issues an HTTP request. // Caller should close resp.Body when done reading from it. func send(req *Request, t RoundTripper) (resp *Response, err error) { if t == nil { - t = DefaultTransport - if t == nil { - err = errors.New("http: no Client.Transport or DefaultTransport") - return - } + return nil, errors.New("http: no Client.Transport or DefaultTransport") } if req.URL == nil { @@ -260,18 +280,36 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo return nil, errors.New("http: nil Request.URL") } + var reqmu sync.Mutex // guards req req := ireq + + var timer *time.Timer + if c.Timeout > 0 { + type canceler interface { + CancelRequest(*Request) + } + tr, ok := c.transport().(canceler) + if !ok { + return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport()) + } + timer = time.AfterFunc(c.Timeout, func() { + reqmu.Lock() + defer reqmu.Unlock() + tr.CancelRequest(req) + }) + } + urlStr := "" // next relative or absolute URL to fetch (after first request) redirectFailed := false for redirect := 0; ; redirect++ { if redirect != 0 { - req = new(Request) - req.Method = ireq.Method + nreq := new(Request) + nreq.Method = ireq.Method if ireq.Method == "POST" || ireq.Method == "PUT" { - req.Method = "GET" + nreq.Method = "GET" } - req.Header = make(Header) - req.URL, err = base.Parse(urlStr) + nreq.Header = make(Header) + nreq.URL, err = base.Parse(urlStr) if err != nil { break } @@ -279,15 +317,18 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo // Add the Referer header. lastReq := via[len(via)-1] if lastReq.URL.Scheme != "https" { - req.Header.Set("Referer", lastReq.URL.String()) + nreq.Header.Set("Referer", lastReq.URL.String()) } - err = redirectChecker(req, via) + err = redirectChecker(nreq, via) if err != nil { redirectFailed = true break } } + reqmu.Lock() + req = nreq + reqmu.Unlock() } urlStr = req.URL.String() @@ -305,7 +346,10 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo via = append(via, req) continue } - return + if timer != nil { + resp.Body = &cancelTimerBody{timer, resp.Body} + } + return resp, nil } method := ireq.Method @@ -408,3 +452,22 @@ func (c *Client) Head(url string) (resp *Response, err error) { } return c.doFollowingRedirects(req, shouldRedirectGet) } + +type cancelTimerBody struct { + t *time.Timer + rc io.ReadCloser +} + +func (b *cancelTimerBody) Read(p []byte) (n int, err error) { + n, err = b.rc.Read(p) + if err == io.EOF { + b.t.Stop() + } + return +} + +func (b *cancelTimerBody) Close() error { + err := b.rc.Close() + b.t.Stop() + return err +} diff --git a/src/pkg/net/http/client_test.go b/src/pkg/net/http/client_test.go index b81af1a479..f44fb199dc 100644 --- a/src/pkg/net/http/client_test.go +++ b/src/pkg/net/http/client_test.go @@ -812,3 +812,70 @@ func TestBasicAuth(t *testing.T) { t.Errorf("Invalid auth %q", auth) } } + +func TestClientTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + defer afterTest(t) + sawRoot := make(chan bool, 1) + sawSlow := make(chan bool, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/" { + sawRoot <- true + Redirect(w, r, "/slow", StatusFound) + return + } + if r.URL.Path == "/slow" { + w.Write([]byte("Hello")) + w.(Flusher).Flush() + sawSlow <- true + time.Sleep(2 * time.Second) + return + } + })) + defer ts.Close() + const timeout = 500 * time.Millisecond + c := &Client{ + Timeout: timeout, + } + + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + select { + case <-sawRoot: + // good. + default: + t.Fatal("handler never got / request") + } + + select { + case <-sawSlow: + // good. + default: + t.Fatal("handler never got /slow request") + } + + var all []byte + errc := make(chan error, 1) + go func() { + var err error + all, err = ioutil.ReadAll(res.Body) + errc <- err + res.Body.Close() + }() + + const failTime = timeout * 2 + select { + case err := <-errc: + if err == nil { + t.Error("expected error from ReadAll") + } + t.Logf("Got expected ReadAll error of %v after reading body %q", err, all) + case <-time.After(failTime): + t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout) + } +}