]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: don't wrap request cancellation errors in timeouts
authorEmmanuel Odeke <emm.odeke@gmail.com>
Tue, 1 Nov 2016 07:44:48 +0000 (00:44 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 10 Nov 2016 20:42:55 +0000 (20:42 +0000)
Based on Filippo Valsorda's https://golang.org/cl/24230

Fixes #16094

Change-Id: Ie39b0834e220f0a0f4fbfb3bbb271e70837718c3
Reviewed-on: https://go-review.googlesource.com/32478
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/http/client.go
src/net/http/client_test.go

index 1af33af93778621eec3f2aef65eee9928248b2e1..9e7c15fe86728d7f0bc4f70d1fadc76351aa5418 100644 (file)
@@ -218,7 +218,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error)
        if !deadline.IsZero() {
                forkReq()
        }
-       stopTimer, wasCanceled := setRequestCancel(req, rt, deadline)
+       stopTimer, didTimeout := setRequestCancel(req, rt, deadline)
 
        resp, err := rt.RoundTrip(req)
        if err != nil {
@@ -238,9 +238,9 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error)
        }
        if !deadline.IsZero() {
                resp.Body = &cancelTimerBody{
-                       stop:           stopTimer,
-                       rc:             resp.Body,
-                       reqWasCanceled: wasCanceled,
+                       stop:          stopTimer,
+                       rc:            resp.Body,
+                       reqDidTimeout: didTimeout,
                }
        }
        return resp, nil
@@ -249,7 +249,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error)
 // setRequestCancel sets the Cancel field of req, if deadline is
 // non-zero. The RoundTripper's type is used to determine whether the legacy
 // CancelRequest behavior should be used.
-func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), wasCanceled func() bool) {
+func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), didTimeout func() bool) {
        if deadline.IsZero() {
                return nop, alwaysFalse
        }
@@ -259,15 +259,6 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi
        cancel := make(chan struct{})
        req.Cancel = cancel
 
-       wasCanceled = func() bool {
-               select {
-               case <-cancel:
-                       return true
-               default:
-                       return false
-               }
-       }
-
        doCancel := func() {
                // The new way:
                close(cancel)
@@ -292,19 +283,22 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi
        stopTimer = func() { once.Do(func() { close(stopTimerCh) }) }
 
        timer := time.NewTimer(time.Until(deadline))
+       var timedOut atomicBool
+
        go func() {
                select {
                case <-initialReqCancel:
                        doCancel()
                        timer.Stop()
                case <-timer.C:
+                       timedOut.setTrue()
                        doCancel()
                case <-stopTimerCh:
                        timer.Stop()
                }
        }()
 
-       return stopTimer, wasCanceled
+       return stopTimer, timedOut.isSet
 }
 
 // See 2 (end of page 4) http://www.ietf.org/rfc/rfc2617.txt
@@ -728,12 +722,12 @@ func (c *Client) Head(url string) (resp *Response, err error) {
 
 // cancelTimerBody is an io.ReadCloser that wraps rc with two features:
 // 1) on Read error or close, the stop func is called.
-// 2) On Read failure, if reqWasCanceled is true, the error is wrapped and
+// 2) On Read failure, if reqDidTimeout is true, the error is wrapped and
 //    marked as net.Error that hit its timeout.
 type cancelTimerBody struct {
-       stop           func() // stops the time.Timer waiting to cancel the request
-       rc             io.ReadCloser
-       reqWasCanceled func() bool
+       stop          func() // stops the time.Timer waiting to cancel the request
+       rc            io.ReadCloser
+       reqDidTimeout func() bool
 }
 
 func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
@@ -745,7 +739,7 @@ func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
        if err == io.EOF {
                return n, err
        }
-       if b.reqWasCanceled() {
+       if b.reqDidTimeout() {
                err = &httpError{
                        err:     err.Error() + " (Client.Timeout exceeded while reading body)",
                        timeout: true,
index 2fe6e2164fcf3bdb41150e270908440cd7754528..dc6d339264366864ea7e0df752813d26994509f0 100644 (file)
@@ -1185,10 +1185,12 @@ func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) }
 func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) }
 
 func testClientTimeout(t *testing.T, h2 bool) {
-       if testing.Short() {
-               t.Skip("skipping in short mode")
-       }
+       setParallel(t)
        defer afterTest(t)
+       testDone := make(chan struct{})
+
+       const timeout = 100 * time.Millisecond
+
        sawRoot := make(chan bool, 1)
        sawSlow := make(chan bool, 1)
        cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -1201,16 +1203,22 @@ func testClientTimeout(t *testing.T, h2 bool) {
                        w.Write([]byte("Hello"))
                        w.(Flusher).Flush()
                        sawSlow <- true
-                       time.Sleep(2 * time.Second)
+                       select {
+                       case <-testDone:
+                       case <-time.After(timeout * 10):
+                       }
                        return
                }
        }))
        defer cst.close()
-       const timeout = 500 * time.Millisecond
+       defer close(testDone)
        cst.c.Timeout = timeout
 
        res, err := cst.c.Get(cst.ts.URL)
        if err != nil {
+               if strings.Contains(err.Error(), "Client.Timeout") {
+                       t.Skip("host too slow to get fast resource in 100ms")
+               }
                t.Fatal(err)
        }
 
@@ -1260,11 +1268,9 @@ func TestClientTimeout_Headers_h2(t *testing.T) { testClientTimeout_Headers(t, h
 
 // Client.Timeout firing before getting to the body
 func testClientTimeout_Headers(t *testing.T, h2 bool) {
-       if testing.Short() {
-               t.Skip("skipping in short mode")
-       }
+       setParallel(t)
        defer afterTest(t)
-       donec := make(chan bool)
+       donec := make(chan bool, 1)
        cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
                <-donec
        }))
@@ -1278,9 +1284,10 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) {
        // doesn't know this, so synchronize explicitly.
        defer func() { donec <- true }()
 
-       cst.c.Timeout = 500 * time.Millisecond
-       _, err := cst.c.Get(cst.ts.URL)
+       cst.c.Timeout = 5 * time.Millisecond
+       res, err := cst.c.Get(cst.ts.URL)
        if err == nil {
+               res.Body.Close()
                t.Fatal("got response from Get; expected error")
        }
        if _, ok := err.(*url.Error); !ok {
@@ -1298,6 +1305,36 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) {
        }
 }
 
+// Issue 16094: if Client.Timeout is set but not hit, a Timeout error shouldn't be
+// returned.
+func TestClientTimeoutCancel(t *testing.T) {
+       setParallel(t)
+       defer afterTest(t)
+
+       testDone := make(chan struct{})
+       ctx, cancel := context.WithCancel(context.Background())
+
+       cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               w.(Flusher).Flush()
+               <-testDone
+       }))
+       defer cst.close()
+       defer close(testDone)
+
+       cst.c.Timeout = 1 * time.Hour
+       req, _ := NewRequest("GET", cst.ts.URL, nil)
+       req.Cancel = ctx.Done()
+       res, err := cst.c.Do(req)
+       if err != nil {
+               t.Fatal(err)
+       }
+       cancel()
+       _, err = io.Copy(ioutil.Discard, res.Body)
+       if err != ExportErrRequestCanceled {
+               t.Fatal("error = %v; want errRequestCanceled")
+       }
+}
+
 func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) }
 func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) }
 func testClientRedirectEatsBody(t *testing.T, h2 bool) {
@@ -1317,10 +1354,10 @@ func testClientRedirectEatsBody(t *testing.T, h2 bool) {
                t.Fatal(err)
        }
        _, err = ioutil.ReadAll(res.Body)
+       res.Body.Close()
        if err != nil {
                t.Fatal(err)
        }
-       res.Body.Close()
 
        var first string
        select {