]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: simplify HTTP/1 request cancelation
authorDamien Neil <dneil@google.com>
Sat, 2 Dec 2023 00:26:14 +0000 (16:26 -0800)
committerDamien Neil <dneil@google.com>
Thu, 16 May 2024 15:57:17 +0000 (15:57 +0000)
HTTP requests have three separate user cancelation signals:

Transport.CancelRequest
Request.Cancel
Request.Context()

In addition, a request can be canceled due to errors.

The Transport keeps a map of all in-flight requests,
with an associated func to run if CancelRequest is
called. Confusingly, this func is *not* run if
Request.Cancel is closed or the request context expires.

The map of in-flight requests is also used to communicate
between roundTrip and readLoop. In particular, if readLoop
reads a response immediately followed by an EOF, it may
send racing signals to roundTrip: The connection has
closed, but also there is a response available.
This race is resolved by readLoop communicating through
the request map that this request has successfully
completed.

This CL refactors all of this.

In-flight requests now have a context which is canceled
when any of the above cancelation events occurs.

The map of requests to cancel funcs remains, but is
used strictly for implementing Transport.CancelRequest.
It is not used to communicate information about the
state of a request.

Change-Id: Ie157edc0ce35f719866a0a2cb0e70514fd119ff8
Reviewed-on: https://go-review.googlesource.com/c/go/+/546676
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
src/net/http/transport.go
src/net/http/transport_internal_test.go

index e6a97a00c63392c2c58f73c56aa7bfbf0db06a31..f7a7092ef749fe7d6afdeaa3bc1ee809a320c851 100644 (file)
@@ -100,7 +100,7 @@ type Transport struct {
        idleLRU      connLRU
 
        reqMu       sync.Mutex
-       reqCanceler map[cancelKey]func(error)
+       reqCanceler map[*Request]context.CancelCauseFunc
 
        altMu    sync.Mutex   // guards changing altProto only
        altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
@@ -294,13 +294,6 @@ type Transport struct {
        ForceAttemptHTTP2 bool
 }
 
-// A cancelKey is the key of the reqCanceler map.
-// We wrap the *Request in this type since we want to use the original request,
-// not any transient one created by roundTrip.
-type cancelKey struct {
-       req *Request
-}
-
 func (t *Transport) writeBufferSize() int {
        if t.WriteBufferSize > 0 {
                return t.WriteBufferSize
@@ -466,10 +459,12 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
 // optional extra headers to write and stores any error to return
 // from roundTrip.
 type transportRequest struct {
-       *Request                         // original request, not to be mutated
-       extra     Header                 // extra headers to write, or nil
-       trace     *httptrace.ClientTrace // optional
-       cancelKey cancelKey
+       *Request                        // original request, not to be mutated
+       extra    Header                 // extra headers to write, or nil
+       trace    *httptrace.ClientTrace // optional
+
+       ctx    context.Context // canceled when we are done with the request
+       cancel context.CancelCauseFunc
 
        mu  sync.Mutex // guards err
        err error      // first setError value for mapRoundTripError to consider
@@ -531,7 +526,7 @@ func validateHeaders(hdrs Header) string {
 }
 
 // roundTrip implements a RoundTripper over HTTP.
-func (t *Transport) roundTrip(req *Request) (*Response, error) {
+func (t *Transport) roundTrip(req *Request) (_ *Response, err error) {
        t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
        ctx := req.Context()
        trace := httptrace.ContextClientTrace(ctx)
@@ -561,7 +556,6 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
        }
 
        origReq := req
-       cancelKey := cancelKey{origReq}
        req = setupRewindBody(req)
 
        if altRT := t.alternateRoundTripper(req); altRT != nil {
@@ -587,16 +581,44 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
                return nil, errors.New("http: no Host in request URL")
        }
 
+       // Transport request context.
+       //
+       // If RoundTrip returns an error, it cancels this context before returning.
+       //
+       // If RoundTrip returns no error:
+       //   - For an HTTP/1 request, persistConn.readLoop cancels this context
+       //     after reading the request body.
+       //   - For an HTTP/2 request, RoundTrip cancels this context after the HTTP/2
+       //     RoundTripper returns.
+       ctx, cancel := context.WithCancelCause(req.Context())
+
+       // Convert Request.Cancel into context cancelation.
+       if origReq.Cancel != nil {
+               go awaitLegacyCancel(ctx, cancel, origReq)
+       }
+
+       // Convert Transport.CancelRequest into context cancelation.
+       //
+       // This is lamentably expensive. CancelRequest has been deprecated for a long time
+       // and doesn't work on HTTP/2 requests. Perhaps we should drop support for it entirely.
+       cancel = t.prepareTransportCancel(origReq, cancel)
+
+       defer func() {
+               if err != nil {
+                       cancel(err)
+               }
+       }()
+
        for {
                select {
                case <-ctx.Done():
                        req.closeBody()
-                       return nil, ctx.Err()
+                       return nil, context.Cause(ctx)
                default:
                }
 
                // treq gets modified by roundTrip, so we need to recreate for each retry.
-               treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey}
+               treq := &transportRequest{Request: req, trace: trace, ctx: ctx, cancel: cancel}
                cm, err := t.connectMethodForRequest(treq)
                if err != nil {
                        req.closeBody()
@@ -609,7 +631,6 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
                // to send it requests.
                pconn, err := t.getConn(treq, cm)
                if err != nil {
-                       t.setReqCanceler(cancelKey, nil)
                        req.closeBody()
                        return nil, err
                }
@@ -617,12 +638,19 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
                var resp *Response
                if pconn.alt != nil {
                        // HTTP/2 path.
-                       t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest
                        resp, err = pconn.alt.RoundTrip(req)
                } else {
                        resp, err = pconn.roundTrip(treq)
                }
                if err == nil {
+                       if pconn.alt != nil {
+                               // HTTP/2 requests are not cancelable with CancelRequest,
+                               // so we have no further need for the request context.
+                               //
+                               // On the HTTP/1 path, roundTrip takes responsibility for
+                               // canceling the context after the response body is read.
+                               cancel(errRequestDone)
+                       }
                        resp.Request = origReq
                        return resp, nil
                }
@@ -659,6 +687,14 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
        }
 }
 
+func awaitLegacyCancel(ctx context.Context, cancel context.CancelCauseFunc, req *Request) {
+       select {
+       case <-req.Cancel:
+               cancel(errRequestCanceled)
+       case <-ctx.Done():
+       }
+}
+
 var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss")
 
 type readTrackingBody struct {
@@ -820,30 +856,42 @@ func (t *Transport) CloseIdleConnections() {
        }
 }
 
+// prepareTransportCancel sets up state to convert Transport.CancelRequest into context cancelation.
+func (t *Transport) prepareTransportCancel(req *Request, origCancel context.CancelCauseFunc) context.CancelCauseFunc {
+       // Historically, RoundTrip has not modified the Request in any way.
+       // We could avoid the need to keep a map of all in-flight requests by adding
+       // a field to the Request containing its cancel func, and setting that field
+       // while the request is in-flight. Callers aren't supposed to reuse a Request
+       // until after the response body is closed, so this wouldn't violate any
+       // concurrency guarantees.
+       cancel := func(err error) {
+               origCancel(err)
+               t.reqMu.Lock()
+               delete(t.reqCanceler, req)
+               t.reqMu.Unlock()
+       }
+       t.reqMu.Lock()
+       if t.reqCanceler == nil {
+               t.reqCanceler = make(map[*Request]context.CancelCauseFunc)
+       }
+       t.reqCanceler[req] = cancel
+       t.reqMu.Unlock()
+       return cancel
+}
+
 // CancelRequest cancels an in-flight request by closing its connection.
 // CancelRequest should only be called after [Transport.RoundTrip] has returned.
 //
 // Deprecated: Use [Request.WithContext] to create a request with a
 // cancelable context instead. CancelRequest cannot cancel HTTP/2
-// requests.
+// requests. This may become a no-op in a future release of Go.
 func (t *Transport) CancelRequest(req *Request) {
-       t.cancelRequest(cancelKey{req}, errRequestCanceled)
-}
-
-// Cancel an in-flight request, recording the error value.
-// Returns whether the request was canceled.
-func (t *Transport) cancelRequest(key cancelKey, err error) bool {
-       // This function must not return until the cancel func has completed.
-       // See: https://golang.org/issue/34658
        t.reqMu.Lock()
-       defer t.reqMu.Unlock()
-       cancel := t.reqCanceler[key]
-       delete(t.reqCanceler, key)
+       cancel := t.reqCanceler[req]
+       t.reqMu.Unlock()
        if cancel != nil {
-               cancel(err)
+               cancel(errRequestCanceled)
        }
-
-       return cancel != nil
 }
 
 //
@@ -1170,38 +1218,6 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool {
        return removed
 }
 
-func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) {
-       t.reqMu.Lock()
-       defer t.reqMu.Unlock()
-       if t.reqCanceler == nil {
-               t.reqCanceler = make(map[cancelKey]func(error))
-       }
-       if fn != nil {
-               t.reqCanceler[key] = fn
-       } else {
-               delete(t.reqCanceler, key)
-       }
-}
-
-// replaceReqCanceler replaces an existing cancel function. If there is no cancel function
-// for the request, we don't set the function and return false.
-// Since CancelRequest will clear the canceler, we can use the return value to detect if
-// the request was canceled since the last setReqCancel call.
-func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool {
-       t.reqMu.Lock()
-       defer t.reqMu.Unlock()
-       _, ok := t.reqCanceler[key]
-       if !ok {
-               return false
-       }
-       if fn != nil {
-               t.reqCanceler[key] = fn
-       } else {
-               delete(t.reqCanceler, key)
-       }
-       return true
-}
-
 var zeroDialer net.Dialer
 
 func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, error) {
@@ -1442,19 +1458,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis
                }
        }()
 
-       var cancelc chan error
-
        // Queue for idle connection.
-       if delivered := t.queueForIdleConn(w); delivered {
-               // set request canceler to some non-nil function so we
-               // can detect whether it was cleared between now and when
-               // we enter roundTrip
-               t.setReqCanceler(treq.cancelKey, func(error) {})
-       } else {
-               cancelc = make(chan error, 1)
-               t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err })
-
-               // Queue for permission to dial.
+       if delivered := t.queueForIdleConn(w); !delivered {
                t.queueForDial(w)
        }
 
@@ -1479,11 +1484,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis
                        // what caused r.err; if so, prefer to return the
                        // cancellation error (see golang.org/issue/16049).
                        select {
-                       case <-req.Cancel:
-                               return nil, errRequestCanceledConn
-                       case <-req.Context().Done():
-                               return nil, req.Context().Err()
-                       case err := <-cancelc:
+                       case <-treq.ctx.Done():
+                               err := context.Cause(treq.ctx)
                                if err == errRequestCanceled {
                                        err = errRequestCanceledConn
                                }
@@ -1493,11 +1495,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis
                        }
                }
                return r.pc, r.err
-       case <-req.Cancel:
-               return nil, errRequestCanceledConn
-       case <-req.Context().Done():
-               return nil, req.Context().Err()
-       case err := <-cancelc:
+       case <-treq.ctx.Done():
+               err := context.Cause(treq.ctx)
                if err == errRequestCanceled {
                        err = errRequestCanceledConn
                }
@@ -2173,7 +2172,8 @@ func (pc *persistConn) readLoop() {
                pc.t.removeIdleConn(pc)
        }()
 
-       tryPutIdleConn := func(trace *httptrace.ClientTrace) bool {
+       tryPutIdleConn := func(treq *transportRequest) bool {
+               trace := treq.trace
                if err := pc.t.tryPutIdleConn(pc); err != nil {
                        closeErr = err
                        if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled {
@@ -2212,7 +2212,7 @@ func (pc *persistConn) readLoop() {
                pc.mu.Unlock()
 
                rc := <-pc.reqch
-               trace := httptrace.ContextClientTrace(rc.req.Context())
+               trace := rc.treq.trace
 
                var resp *Response
                if err == nil {
@@ -2241,9 +2241,9 @@ func (pc *persistConn) readLoop() {
                pc.mu.Unlock()
 
                bodyWritable := resp.bodyIsWritable()
-               hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0
+               hasBody := rc.treq.Request.Method != "HEAD" && resp.ContentLength != 0
 
-               if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable {
+               if resp.Close || rc.treq.Request.Close || resp.StatusCode <= 199 || bodyWritable {
                        // Don't do keep-alive on error if either party requested a close
                        // or we get an unexpected informational (1xx) response.
                        // StatusCode 100 is already handled above.
@@ -2251,8 +2251,6 @@ func (pc *persistConn) readLoop() {
                }
 
                if !hasBody || bodyWritable {
-                       replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil)
-
                        // Put the idle conn back into the pool before we send the response
                        // so if they process it quickly and make another request, they'll
                        // get this same conn. But we use the unbuffered channel 'rc'
@@ -2261,7 +2259,7 @@ func (pc *persistConn) readLoop() {
                        alive = alive &&
                                !pc.sawEOF &&
                                pc.wroteRequest() &&
-                               replaced && tryPutIdleConn(trace)
+                               tryPutIdleConn(rc.treq)
 
                        if bodyWritable {
                                closeErr = errCallerOwnsConn
@@ -2273,6 +2271,8 @@ func (pc *persistConn) readLoop() {
                                return
                        }
 
+                       rc.treq.cancel(errRequestDone)
+
                        // Now that they've read from the unbuffered channel, they're safely
                        // out of the select that also waits on this goroutine to die, so
                        // we're allowed to exit now if needed (if alive is false)
@@ -2323,26 +2323,22 @@ func (pc *persistConn) readLoop() {
                // reading the response body. (or for cancellation or death)
                select {
                case bodyEOF := <-waitForBodyRead:
-                       replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool
                        alive = alive &&
                                bodyEOF &&
                                !pc.sawEOF &&
                                pc.wroteRequest() &&
-                               replaced && tryPutIdleConn(trace)
+                               tryPutIdleConn(rc.treq)
                        if bodyEOF {
                                eofc <- struct{}{}
                        }
-               case <-rc.req.Cancel:
-                       alive = false
-                       pc.t.cancelRequest(rc.cancelKey, errRequestCanceled)
-               case <-rc.req.Context().Done():
+               case <-rc.treq.ctx.Done():
                        alive = false
-                       pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err())
+                       pc.cancelRequest(errRequestCanceled)
                case <-pc.closech:
                        alive = false
-                       pc.t.setReqCanceler(rc.cancelKey, nil)
                }
 
+               rc.treq.cancel(errRequestDone)
                testHookReadLoopBeforeNextRead()
        }
 }
@@ -2395,7 +2391,7 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr
 
        continueCh := rc.continueCh
        for {
-               resp, err = ReadResponse(pc.br, rc.req)
+               resp, err = ReadResponse(pc.br, rc.treq.Request)
                if err != nil {
                        return
                }
@@ -2587,10 +2583,9 @@ type responseAndError struct {
 }
 
 type requestAndChan struct {
-       _         incomparable
-       req       *Request
-       cancelKey cancelKey
-       ch        chan responseAndError // unbuffered; always send in select on callerGone
+       _    incomparable
+       treq *transportRequest
+       ch   chan responseAndError // unbuffered; always send in select on callerGone
 
        // whether the Transport (as opposed to the user client code)
        // added the Accept-Encoding gzip header. If the Transport
@@ -2638,6 +2633,10 @@ var errTimeout error = &timeoutError{"net/http: timeout awaiting response header
 var errRequestCanceled = http2errRequestCanceled
 var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify?
 
+// errRequestDone is used to cancel the round trip Context after a request is successfully done.
+// It should not be seen by the user.
+var errRequestDone = errors.New("net/http: request completed")
+
 func nop() {}
 
 // testHooks. Always non-nil.
@@ -2654,10 +2653,6 @@ var (
 
 func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
        testHookEnterRoundTrip()
-       if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) {
-               pc.t.putOrCloseIdleConn(pc)
-               return nil, errRequestCanceled
-       }
        pc.mu.Lock()
        pc.numExpectedResponses++
        headerFn := pc.mutateHeaderFunc
@@ -2706,12 +2701,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
        gone := make(chan struct{})
        defer close(gone)
 
-       defer func() {
-               if err != nil {
-                       pc.t.setReqCanceler(req.cancelKey, nil)
-               }
-       }()
-
        const debugRoundTrip = false
 
        // Write the request concurrently with waiting for a response,
@@ -2723,19 +2712,29 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
 
        resc := make(chan responseAndError)
        pc.reqch <- requestAndChan{
-               req:        req.Request,
-               cancelKey:  req.cancelKey,
+               treq:       req,
                ch:         resc,
                addedGzip:  requestedGzip,
                continueCh: continueCh,
                callerGone: gone,
        }
 
+       handleResponse := func(re responseAndError) (*Response, error) {
+               if (re.res == nil) == (re.err == nil) {
+                       panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil))
+               }
+               if debugRoundTrip {
+                       req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err)
+               }
+               if re.err != nil {
+                       return nil, pc.mapRoundTripError(req, startBytesWritten, re.err)
+               }
+               return re.res, nil
+       }
+
        var respHeaderTimer <-chan time.Time
-       cancelChan := req.Request.Cancel
-       ctxDoneChan := req.Context().Done()
+       ctxDoneChan := req.ctx.Done()
        pcClosed := pc.closech
-       canceled := false
        for {
                testHookWaitResLoop()
                select {
@@ -2756,13 +2755,18 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
                                respHeaderTimer = timer.C
                        }
                case <-pcClosed:
-                       pcClosed = nil
-                       if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) {
-                               if debugRoundTrip {
-                                       req.logf("closech recv: %T %#v", pc.closed, pc.closed)
-                               }
-                               return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed)
+                       select {
+                       case re := <-resc:
+                               // The pconn closing raced with the response to the request,
+                               // probably after the server wrote a response and immediately
+                               // closed the connection. Use the response.
+                               return handleResponse(re)
+                       default:
                        }
+                       if debugRoundTrip {
+                               req.logf("closech recv: %T %#v", pc.closed, pc.closed)
+                       }
+                       return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed)
                case <-respHeaderTimer:
                        if debugRoundTrip {
                                req.logf("timeout waiting for response headers.")
@@ -2770,23 +2774,17 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
                        pc.close(errTimeout)
                        return nil, errTimeout
                case re := <-resc:
-                       if (re.res == nil) == (re.err == nil) {
-                               panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil))
-                       }
-                       if debugRoundTrip {
-                               req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err)
-                       }
-                       if re.err != nil {
-                               return nil, pc.mapRoundTripError(req, startBytesWritten, re.err)
-                       }
-                       return re.res, nil
-               case <-cancelChan:
-                       canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
-                       cancelChan = nil
+                       return handleResponse(re)
                case <-ctxDoneChan:
-                       canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err())
-                       cancelChan = nil
-                       ctxDoneChan = nil
+                       select {
+                       case re := <-resc:
+                               // readLoop is responsible for canceling req.ctx after
+                               // it reads the response body. Check for a response racing
+                               // the context close, and use the response if available.
+                               return handleResponse(re)
+                       default:
+                       }
+                       pc.cancelRequest(context.Cause(req.ctx))
                }
        }
 }
index dc3259fadf90b67ee24bcdf2ccda813afe5d2953..f86970b2480f946cfc8c9bd81f90aa6e7ef3fc9a 100644 (file)
@@ -8,6 +8,7 @@ package http
 
 import (
        "bytes"
+       "context"
        "crypto/tls"
        "errors"
        "io"
@@ -36,7 +37,8 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) {
        tr := new(Transport)
        req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
        req = req.WithT(t)
-       treq := &transportRequest{Request: req}
+       ctx, cancel := context.WithCancelCause(context.Background())
+       treq := &transportRequest{Request: req, ctx: ctx, cancel: cancel}
        cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
        pc, err := tr.getConn(treq, cm)
        if err != nil {