]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: if context is canceled, return its error
authorIan Lance Taylor <iant@golang.org>
Fri, 15 Jul 2016 22:30:53 +0000 (15:30 -0700)
committerIan Lance Taylor <iant@golang.org>
Tue, 23 Aug 2016 05:31:45 +0000 (05:31 +0000)
This permits the error message to distinguish between a context that was
canceled and a context that timed out.

Updates #16381.

Change-Id: I3994b98e32952abcd7ddb5fee08fa1535999be6d
Reviewed-on: https://go-review.googlesource.com/24978
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/http/client_test.go
src/net/http/transport.go
src/net/http/transport_test.go

index a9b1948005cb66c1e7a233268603756ee9b6ca5b..f5500b6d887a9f6311becd99e6525bb04080ff8a 100644 (file)
@@ -313,8 +313,8 @@ func TestClientRedirectContext(t *testing.T) {
        if !ok {
                t.Fatalf("got error %T; want *url.Error", err)
        }
-       if ue.Err != ExportErrRequestCanceled && ue.Err != ExportErrRequestCanceledConn {
-               t.Errorf("url.Error.Err = %v; want errRequestCanceled or errRequestCanceledConn", ue.Err)
+       if ue.Err != context.Canceled {
+               t.Errorf("url.Error.Err = %v; want %v", ue.Err, context.Canceled)
        }
 }
 
index 35cee822351b749f8ba77e5fde81b8ad53c197c5..878b925a533347cc4e8bcfcfb7f44ceb86abb224 100644 (file)
@@ -76,7 +76,7 @@ type Transport struct {
        idleLRU    connLRU
 
        reqMu       sync.Mutex
-       reqCanceler map[*Request]func()
+       reqCanceler map[*Request]func(error)
 
        altMu    sync.RWMutex
        altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
@@ -498,12 +498,17 @@ func (t *Transport) CloseIdleConnections() {
 // cancelable context instead. CancelRequest cannot cancel HTTP/2
 // requests.
 func (t *Transport) CancelRequest(req *Request) {
+       t.cancelRequest(req, errRequestCanceled)
+}
+
+// Cancel an in-flight request, recording the error value.
+func (t *Transport) cancelRequest(req *Request, err error) {
        t.reqMu.Lock()
        cancel := t.reqCanceler[req]
        delete(t.reqCanceler, req)
        t.reqMu.Unlock()
        if cancel != nil {
-               cancel()
+               cancel(err)
        }
 }
 
@@ -783,11 +788,11 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) {
        }
 }
 
-func (t *Transport) setReqCanceler(r *Request, fn func()) {
+func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
        t.reqMu.Lock()
        defer t.reqMu.Unlock()
        if t.reqCanceler == nil {
-               t.reqCanceler = make(map[*Request]func())
+               t.reqCanceler = make(map[*Request]func(error))
        }
        if fn != nil {
                t.reqCanceler[r] = fn
@@ -800,7 +805,7 @@ func (t *Transport) setReqCanceler(r *Request, fn func()) {
 // for the request, we don't set the function and return false.
 // Since CancelRequest will clear the canceler, we can use the return value to detect if
 // the request was canceled since the last setReqCancel call.
-func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool {
+func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool {
        t.reqMu.Lock()
        defer t.reqMu.Unlock()
        _, ok := t.reqCanceler[r]
@@ -849,7 +854,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
                // set request canceler to some non-nil function so we
                // can detect whether it was cleared between now and when
                // we enter roundTrip
-               t.setReqCanceler(req, func() {})
+               t.setReqCanceler(req, func(error) {})
                return pc, nil
        }
 
@@ -874,8 +879,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
                }()
        }
 
-       cancelc := make(chan struct{})
-       t.setReqCanceler(req, func() { close(cancelc) })
+       cancelc := make(chan error, 1)
+       t.setReqCanceler(req, func(err error) { cancelc <- err })
 
        go func() {
                pc, err := t.dialConn(ctx, cm)
@@ -897,7 +902,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
                select {
                case <-req.Cancel:
                case <-req.Context().Done():
-               case <-cancelc:
+                       return nil, req.Context().Err()
+               case err := <-cancelc:
+                       if err == errRequestCanceled {
+                               err = errRequestCanceledConn
+                       }
+                       return nil, err
                default:
                        // It wasn't an error due to cancelation, so
                        // return the original error message:
@@ -922,10 +932,13 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
                return nil, errRequestCanceledConn
        case <-req.Context().Done():
                handlePendingDial()
-               return nil, errRequestCanceledConn
-       case <-cancelc:
+               return nil, req.Context().Err()
+       case err := <-cancelc:
                handlePendingDial()
-               return nil, errRequestCanceledConn
+               if err == errRequestCanceled {
+                       err = errRequestCanceledConn
+               }
+               return nil, err
        }
 }
 
@@ -1231,8 +1244,8 @@ type persistConn struct {
        mu                   sync.Mutex // guards following fields
        numExpectedResponses int
        closed               error // set non-nil when conn is closed, before closech is closed
+       canceledErr          error // set non-nil if conn is canceled
        broken               bool  // an error has happened on this connection; marked broken so it's not reused.
-       canceled             bool  // whether this conn was broken due a CancelRequest
        reused               bool  // whether conn has had successful request/response and is being reused.
        // mutateHeaderFunc is an optional func to modify extra
        // headers on each outbound request before it's written. (the
@@ -1270,11 +1283,12 @@ func (pc *persistConn) isBroken() bool {
        return b
 }
 
-// isCanceled reports whether this connection was closed due to CancelRequest.
-func (pc *persistConn) isCanceled() bool {
+// canceled returns non-nil if the connection was closed due to
+// CancelRequest or due to context cancelation.
+func (pc *persistConn) canceled() error {
        pc.mu.Lock()
        defer pc.mu.Unlock()
-       return pc.canceled
+       return pc.canceledErr
 }
 
 // isReused reports whether this connection is in a known broken state.
@@ -1297,10 +1311,10 @@ func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnIn
        return
 }
 
-func (pc *persistConn) cancelRequest() {
+func (pc *persistConn) cancelRequest(err error) {
        pc.mu.Lock()
        defer pc.mu.Unlock()
-       pc.canceled = true
+       pc.canceledErr = err
        pc.closeLocked(errRequestCanceled)
 }
 
@@ -1328,8 +1342,8 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er
        if err == nil {
                return nil
        }
-       if pc.isCanceled() {
-               return errRequestCanceled
+       if err := pc.canceled(); err != nil {
+               return err
        }
        if err == errServerClosedIdle {
                return err
@@ -1351,8 +1365,8 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er
 // its pc.closech channel close, indicating the persistConn is dead.
 // (after closech is closed, pc.closed is valid).
 func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) error {
-       if pc.isCanceled() {
-               return errRequestCanceled
+       if err := pc.canceled(); err != nil {
+               return err
        }
        err := pc.closed
        if err == errServerClosedIdle {
@@ -1509,8 +1523,10 @@ func (pc *persistConn) readLoop() {
                                waitForBodyRead <- isEOF
                                if isEOF {
                                        <-eofc // see comment above eofc declaration
-                               } else if err != nil && pc.isCanceled() {
-                                       return errRequestCanceled
+                               } else if err != nil {
+                                       if cerr := pc.canceled(); cerr != nil {
+                                               return cerr
+                                       }
                                }
                                return err
                        },
@@ -1550,7 +1566,7 @@ func (pc *persistConn) readLoop() {
                        pc.t.CancelRequest(rc.req)
                case <-rc.req.Context().Done():
                        alive = false
-                       pc.t.CancelRequest(rc.req)
+                       pc.t.cancelRequest(rc.req, rc.req.Context().Err())
                case <-pc.closech:
                        alive = false
                }
@@ -1836,8 +1852,8 @@ WaitResponse:
                select {
                case err := <-writeErrCh:
                        if err != nil {
-                               if pc.isCanceled() {
-                                       err = errRequestCanceled
+                               if cerr := pc.canceled(); cerr != nil {
+                                       err = cerr
                                }
                                re = responseAndError{err: err}
                                pc.close(fmt.Errorf("write error: %v", err))
@@ -1861,9 +1877,8 @@ WaitResponse:
                case <-cancelChan:
                        pc.t.CancelRequest(req.Request)
                        cancelChan = nil
-                       ctxDoneChan = nil
                case <-ctxDoneChan:
-                       pc.t.CancelRequest(req.Request)
+                       pc.t.cancelRequest(req.Request, req.Context().Err())
                        cancelChan = nil
                        ctxDoneChan = nil
                }
index 298682d04de93a4f727063d5dc20ad509dd586fd..daf943e250dbbdbbc3a98a420434476b032be15f 100644 (file)
@@ -1718,8 +1718,17 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) {
        }
 
        _, err := c.Do(req)
-       if err == nil || !strings.Contains(err.Error(), "canceled") {
-               t.Errorf("Do error = %v; want cancelation", err)
+       if ue, ok := err.(*url.Error); ok {
+               err = ue.Err
+       }
+       if withCtx {
+               if err != context.Canceled {
+                       t.Errorf("Do error = %v; want %v", err, context.Canceled)
+               }
+       } else {
+               if err == nil || !strings.Contains(err.Error(), "canceled") {
+                       t.Errorf("Do error = %v; want cancelation", err)
+               }
        }
 }