From 0b5f2f0d1149bcff3c6b08458d7ffdd96970235c Mon Sep 17 00:00:00 2001 From: Ian Lance Taylor Date: Fri, 15 Jul 2016 15:30:53 -0700 Subject: [PATCH] net/http: if context is canceled, return its error This permits the error message to distinguish between a context that was canceled and a context that timed out. Updates #16381. Change-Id: I3994b98e32952abcd7ddb5fee08fa1535999be6d Reviewed-on: https://go-review.googlesource.com/24978 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Brad Fitzpatrick --- src/net/http/client_test.go | 4 +- src/net/http/transport.go | 73 ++++++++++++++++++++-------------- src/net/http/transport_test.go | 13 +++++- 3 files changed, 57 insertions(+), 33 deletions(-) diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go index a9b1948005..f5500b6d88 100644 --- a/src/net/http/client_test.go +++ b/src/net/http/client_test.go @@ -313,8 +313,8 @@ func TestClientRedirectContext(t *testing.T) { if !ok { t.Fatalf("got error %T; want *url.Error", err) } - if ue.Err != ExportErrRequestCanceled && ue.Err != ExportErrRequestCanceledConn { - t.Errorf("url.Error.Err = %v; want errRequestCanceled or errRequestCanceledConn", ue.Err) + if ue.Err != context.Canceled { + t.Errorf("url.Error.Err = %v; want %v", ue.Err, context.Canceled) } } diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 35cee82235..878b925a53 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -76,7 +76,7 @@ type Transport struct { idleLRU connLRU reqMu sync.Mutex - reqCanceler map[*Request]func() + reqCanceler map[*Request]func(error) altMu sync.RWMutex altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper @@ -498,12 +498,17 @@ func (t *Transport) CloseIdleConnections() { // cancelable context instead. CancelRequest cannot cancel HTTP/2 // requests. func (t *Transport) CancelRequest(req *Request) { + t.cancelRequest(req, errRequestCanceled) +} + +// Cancel an in-flight request, recording the error value. +func (t *Transport) cancelRequest(req *Request, err error) { t.reqMu.Lock() cancel := t.reqCanceler[req] delete(t.reqCanceler, req) t.reqMu.Unlock() if cancel != nil { - cancel() + cancel(err) } } @@ -783,11 +788,11 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) { } } -func (t *Transport) setReqCanceler(r *Request, fn func()) { +func (t *Transport) setReqCanceler(r *Request, fn func(error)) { t.reqMu.Lock() defer t.reqMu.Unlock() if t.reqCanceler == nil { - t.reqCanceler = make(map[*Request]func()) + t.reqCanceler = make(map[*Request]func(error)) } if fn != nil { t.reqCanceler[r] = fn @@ -800,7 +805,7 @@ func (t *Transport) setReqCanceler(r *Request, fn func()) { // for the request, we don't set the function and return false. // Since CancelRequest will clear the canceler, we can use the return value to detect if // the request was canceled since the last setReqCancel call. -func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool { +func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool { t.reqMu.Lock() defer t.reqMu.Unlock() _, ok := t.reqCanceler[r] @@ -849,7 +854,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC // set request canceler to some non-nil function so we // can detect whether it was cleared between now and when // we enter roundTrip - t.setReqCanceler(req, func() {}) + t.setReqCanceler(req, func(error) {}) return pc, nil } @@ -874,8 +879,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC }() } - cancelc := make(chan struct{}) - t.setReqCanceler(req, func() { close(cancelc) }) + cancelc := make(chan error, 1) + t.setReqCanceler(req, func(err error) { cancelc <- err }) go func() { pc, err := t.dialConn(ctx, cm) @@ -897,7 +902,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC select { case <-req.Cancel: case <-req.Context().Done(): - case <-cancelc: + return nil, req.Context().Err() + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn + } + return nil, err default: // It wasn't an error due to cancelation, so // return the original error message: @@ -922,10 +932,13 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC return nil, errRequestCanceledConn case <-req.Context().Done(): handlePendingDial() - return nil, errRequestCanceledConn - case <-cancelc: + return nil, req.Context().Err() + case err := <-cancelc: handlePendingDial() - return nil, errRequestCanceledConn + if err == errRequestCanceled { + err = errRequestCanceledConn + } + return nil, err } } @@ -1231,8 +1244,8 @@ type persistConn struct { mu sync.Mutex // guards following fields numExpectedResponses int closed error // set non-nil when conn is closed, before closech is closed + canceledErr error // set non-nil if conn is canceled broken bool // an error has happened on this connection; marked broken so it's not reused. - canceled bool // whether this conn was broken due a CancelRequest reused bool // whether conn has had successful request/response and is being reused. // mutateHeaderFunc is an optional func to modify extra // headers on each outbound request before it's written. (the @@ -1270,11 +1283,12 @@ func (pc *persistConn) isBroken() bool { return b } -// isCanceled reports whether this connection was closed due to CancelRequest. -func (pc *persistConn) isCanceled() bool { +// canceled returns non-nil if the connection was closed due to +// CancelRequest or due to context cancelation. +func (pc *persistConn) canceled() error { pc.mu.Lock() defer pc.mu.Unlock() - return pc.canceled + return pc.canceledErr } // isReused reports whether this connection is in a known broken state. @@ -1297,10 +1311,10 @@ func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnIn return } -func (pc *persistConn) cancelRequest() { +func (pc *persistConn) cancelRequest(err error) { pc.mu.Lock() defer pc.mu.Unlock() - pc.canceled = true + pc.canceledErr = err pc.closeLocked(errRequestCanceled) } @@ -1328,8 +1342,8 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er if err == nil { return nil } - if pc.isCanceled() { - return errRequestCanceled + if err := pc.canceled(); err != nil { + return err } if err == errServerClosedIdle { return err @@ -1351,8 +1365,8 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er // its pc.closech channel close, indicating the persistConn is dead. // (after closech is closed, pc.closed is valid). func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) error { - if pc.isCanceled() { - return errRequestCanceled + if err := pc.canceled(); err != nil { + return err } err := pc.closed if err == errServerClosedIdle { @@ -1509,8 +1523,10 @@ func (pc *persistConn) readLoop() { waitForBodyRead <- isEOF if isEOF { <-eofc // see comment above eofc declaration - } else if err != nil && pc.isCanceled() { - return errRequestCanceled + } else if err != nil { + if cerr := pc.canceled(); cerr != nil { + return cerr + } } return err }, @@ -1550,7 +1566,7 @@ func (pc *persistConn) readLoop() { pc.t.CancelRequest(rc.req) case <-rc.req.Context().Done(): alive = false - pc.t.CancelRequest(rc.req) + pc.t.cancelRequest(rc.req, rc.req.Context().Err()) case <-pc.closech: alive = false } @@ -1836,8 +1852,8 @@ WaitResponse: select { case err := <-writeErrCh: if err != nil { - if pc.isCanceled() { - err = errRequestCanceled + if cerr := pc.canceled(); cerr != nil { + err = cerr } re = responseAndError{err: err} pc.close(fmt.Errorf("write error: %v", err)) @@ -1861,9 +1877,8 @@ WaitResponse: case <-cancelChan: pc.t.CancelRequest(req.Request) cancelChan = nil - ctxDoneChan = nil case <-ctxDoneChan: - pc.t.CancelRequest(req.Request) + pc.t.cancelRequest(req.Request, req.Context().Err()) cancelChan = nil ctxDoneChan = nil } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 298682d04d..daf943e250 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -1718,8 +1718,17 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { } _, err := c.Do(req) - if err == nil || !strings.Contains(err.Error(), "canceled") { - t.Errorf("Do error = %v; want cancelation", err) + if ue, ok := err.(*url.Error); ok { + err = ue.Err + } + if withCtx { + if err != context.Canceled { + t.Errorf("Do error = %v; want %v", err, context.Canceled) + } + } else { + if err == nil || !strings.Contains(err.Error(), "canceled") { + t.Errorf("Do error = %v; want cancelation", err) + } } } -- 2.48.1