"log"
"net/url"
"strings"
+ "sync"
+ "time"
)
// A Client is an HTTP client. Its zero value (DefaultClient) is a
// If Jar is nil, cookies are not sent in requests and ignored
// in responses.
Jar CookieJar
+
+ // Timeout specifies the end-to-end timeout for requests made
+ // via this Client. The timeout includes connection time, any
+ // redirects, and reading the response body. The timeout
+ // remains running once Get, Head, Post, or Do returns and
+ // will interrupt the read of the Response.Body if EOF hasn't
+ // been reached.
+ //
+ // A Timeout of zero means no timeout.
+ //
+ // The Client's Transport must support the CancelRequest
+ // method or Client will return errors when attempting to make
+ // a request with Get, Head, Post, or Do. Client's default
+ // Transport (DefaultTransport) supports CancelRequest.
+ Timeout time.Duration
}
// DefaultClient is the default Client and is used by Get, Head, and Post.
req.AddCookie(cookie)
}
}
- resp, err := send(req, c.Transport)
+ resp, err := send(req, c.transport())
if err != nil {
return nil, err
}
return c.send(req)
}
+func (c *Client) transport() RoundTripper {
+ if c.Transport != nil {
+ return c.Transport
+ }
+ return DefaultTransport
+}
+
// send issues an HTTP request.
// Caller should close resp.Body when done reading from it.
func send(req *Request, t RoundTripper) (resp *Response, err error) {
if t == nil {
- t = DefaultTransport
- if t == nil {
- err = errors.New("http: no Client.Transport or DefaultTransport")
- return
- }
+ return nil, errors.New("http: no Client.Transport or DefaultTransport")
}
if req.URL == nil {
return nil, errors.New("http: nil Request.URL")
}
+ var reqmu sync.Mutex // guards req
req := ireq
+
+ var timer *time.Timer
+ if c.Timeout > 0 {
+ type canceler interface {
+ CancelRequest(*Request)
+ }
+ tr, ok := c.transport().(canceler)
+ if !ok {
+ return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport())
+ }
+ timer = time.AfterFunc(c.Timeout, func() {
+ reqmu.Lock()
+ defer reqmu.Unlock()
+ tr.CancelRequest(req)
+ })
+ }
+
urlStr := "" // next relative or absolute URL to fetch (after first request)
redirectFailed := false
for redirect := 0; ; redirect++ {
if redirect != 0 {
- req = new(Request)
- req.Method = ireq.Method
+ nreq := new(Request)
+ nreq.Method = ireq.Method
if ireq.Method == "POST" || ireq.Method == "PUT" {
- req.Method = "GET"
+ nreq.Method = "GET"
}
- req.Header = make(Header)
- req.URL, err = base.Parse(urlStr)
+ nreq.Header = make(Header)
+ nreq.URL, err = base.Parse(urlStr)
if err != nil {
break
}
// Add the Referer header.
lastReq := via[len(via)-1]
if lastReq.URL.Scheme != "https" {
- req.Header.Set("Referer", lastReq.URL.String())
+ nreq.Header.Set("Referer", lastReq.URL.String())
}
- err = redirectChecker(req, via)
+ err = redirectChecker(nreq, via)
if err != nil {
redirectFailed = true
break
}
}
+ reqmu.Lock()
+ req = nreq
+ reqmu.Unlock()
}
urlStr = req.URL.String()
via = append(via, req)
continue
}
- return
+ if timer != nil {
+ resp.Body = &cancelTimerBody{timer, resp.Body}
+ }
+ return resp, nil
}
method := ireq.Method
}
return c.doFollowingRedirects(req, shouldRedirectGet)
}
+
+type cancelTimerBody struct {
+ t *time.Timer
+ rc io.ReadCloser
+}
+
+func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
+ n, err = b.rc.Read(p)
+ if err == io.EOF {
+ b.t.Stop()
+ }
+ return
+}
+
+func (b *cancelTimerBody) Close() error {
+ err := b.rc.Close()
+ b.t.Stop()
+ return err
+}
t.Errorf("Invalid auth %q", auth)
}
}
+
+func TestClientTimeout(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ defer afterTest(t)
+ sawRoot := make(chan bool, 1)
+ sawSlow := make(chan bool, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.URL.Path == "/" {
+ sawRoot <- true
+ Redirect(w, r, "/slow", StatusFound)
+ return
+ }
+ if r.URL.Path == "/slow" {
+ w.Write([]byte("Hello"))
+ w.(Flusher).Flush()
+ sawSlow <- true
+ time.Sleep(2 * time.Second)
+ return
+ }
+ }))
+ defer ts.Close()
+ const timeout = 500 * time.Millisecond
+ c := &Client{
+ Timeout: timeout,
+ }
+
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ select {
+ case <-sawRoot:
+ // good.
+ default:
+ t.Fatal("handler never got / request")
+ }
+
+ select {
+ case <-sawSlow:
+ // good.
+ default:
+ t.Fatal("handler never got /slow request")
+ }
+
+ var all []byte
+ errc := make(chan error, 1)
+ go func() {
+ var err error
+ all, err = ioutil.ReadAll(res.Body)
+ errc <- err
+ res.Body.Close()
+ }()
+
+ const failTime = timeout * 2
+ select {
+ case err := <-errc:
+ if err == nil {
+ t.Error("expected error from ReadAll")
+ }
+ t.Logf("Got expected ReadAll error of %v after reading body %q", err, all)
+ case <-time.After(failTime):
+ t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout)
+ }
+}