]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: fix cancelation of requests with a readTrackingBody wrapper
authorDamien Neil <dneil@google.com>
Tue, 28 Jul 2020 19:49:52 +0000 (12:49 -0700)
committerDamien Neil <dneil@google.com>
Tue, 4 Aug 2020 19:27:13 +0000 (19:27 +0000)
Use the original *Request in the reqCanceler map, not the transient
wrapper created to handle body rewinding.

Change the key of reqCanceler to a struct{*Request}, to make it more
difficult to accidentally use the wrong request as the key.

Fixes #40453.

Change-Id: I4e61ee9ff2c794fb4c920a3a66c9a0458693d757
Reviewed-on: https://go-review.googlesource.com/c/go/+/245357
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Russ Cox <rsc@golang.org>
src/net/http/transport.go
src/net/http/transport_test.go

index a41e732d983da433c1c7c103d06e115eb8b3fc0a..d37b52b13d06006619d341692680bf13cc1a4d52 100644 (file)
@@ -100,7 +100,7 @@ type Transport struct {
        idleLRU      connLRU
 
        reqMu       sync.Mutex
-       reqCanceler map[*Request]func(error)
+       reqCanceler map[cancelKey]func(error)
 
        altMu    sync.Mutex   // guards changing altProto only
        altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
@@ -273,6 +273,13 @@ type Transport struct {
        ForceAttemptHTTP2 bool
 }
 
+// A cancelKey is the key of the reqCanceler map.
+// We wrap the *Request in this type since we want to use the original request,
+// not any transient one created by roundTrip.
+type cancelKey struct {
+       req *Request
+}
+
 func (t *Transport) writeBufferSize() int {
        if t.WriteBufferSize > 0 {
                return t.WriteBufferSize
@@ -433,9 +440,10 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
 // optional extra headers to write and stores any error to return
 // from roundTrip.
 type transportRequest struct {
-       *Request                        // original request, not to be mutated
-       extra    Header                 // extra headers to write, or nil
-       trace    *httptrace.ClientTrace // optional
+       *Request                         // original request, not to be mutated
+       extra     Header                 // extra headers to write, or nil
+       trace     *httptrace.ClientTrace // optional
+       cancelKey cancelKey
 
        mu  sync.Mutex // guards err
        err error      // first setError value for mapRoundTripError to consider
@@ -512,6 +520,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
        }
 
        origReq := req
+       cancelKey := cancelKey{origReq}
        req = setupRewindBody(req)
 
        if altRT := t.alternateRoundTripper(req); altRT != nil {
@@ -546,7 +555,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
                }
 
                // treq gets modified by roundTrip, so we need to recreate for each retry.
-               treq := &transportRequest{Request: req, trace: trace}
+               treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey}
                cm, err := t.connectMethodForRequest(treq)
                if err != nil {
                        req.closeBody()
@@ -559,7 +568,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
                // to send it requests.
                pconn, err := t.getConn(treq, cm)
                if err != nil {
-                       t.setReqCanceler(req, nil)
+                       t.setReqCanceler(cancelKey, nil)
                        req.closeBody()
                        return nil, err
                }
@@ -567,7 +576,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
                var resp *Response
                if pconn.alt != nil {
                        // HTTP/2 path.
-                       t.setReqCanceler(req, nil) // not cancelable with CancelRequest
+                       t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest
                        resp, err = pconn.alt.RoundTrip(req)
                } else {
                        resp, err = pconn.roundTrip(treq)
@@ -753,14 +762,14 @@ func (t *Transport) CloseIdleConnections() {
 // cancelable context instead. CancelRequest cannot cancel HTTP/2
 // requests.
 func (t *Transport) CancelRequest(req *Request) {
-       t.cancelRequest(req, errRequestCanceled)
+       t.cancelRequest(cancelKey{req}, errRequestCanceled)
 }
 
 // Cancel an in-flight request, recording the error value.
-func (t *Transport) cancelRequest(req *Request, err error) {
+func (t *Transport) cancelRequest(key cancelKey, err error) {
        t.reqMu.Lock()
-       cancel := t.reqCanceler[req]
-       delete(t.reqCanceler, req)
+       cancel := t.reqCanceler[key]
+       delete(t.reqCanceler, key)
        t.reqMu.Unlock()
        if cancel != nil {
                cancel(err)
@@ -1093,16 +1102,16 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool {
        return removed
 }
 
-func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
+func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) {
        t.reqMu.Lock()
        defer t.reqMu.Unlock()
        if t.reqCanceler == nil {
-               t.reqCanceler = make(map[*Request]func(error))
+               t.reqCanceler = make(map[cancelKey]func(error))
        }
        if fn != nil {
-               t.reqCanceler[r] = fn
+               t.reqCanceler[key] = fn
        } else {
-               delete(t.reqCanceler, r)
+               delete(t.reqCanceler, key)
        }
 }
 
@@ -1110,17 +1119,17 @@ func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
 // for the request, we don't set the function and return false.
 // Since CancelRequest will clear the canceler, we can use the return value to detect if
 // the request was canceled since the last setReqCancel call.
-func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool {
+func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool {
        t.reqMu.Lock()
        defer t.reqMu.Unlock()
-       _, ok := t.reqCanceler[r]
+       _, ok := t.reqCanceler[key]
        if !ok {
                return false
        }
        if fn != nil {
-               t.reqCanceler[r] = fn
+               t.reqCanceler[key] = fn
        } else {
-               delete(t.reqCanceler, r)
+               delete(t.reqCanceler, key)
        }
        return true
 }
@@ -1324,12 +1333,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi
                // set request canceler to some non-nil function so we
                // can detect whether it was cleared between now and when
                // we enter roundTrip
-               t.setReqCanceler(req, func(error) {})
+               t.setReqCanceler(treq.cancelKey, func(error) {})
                return pc, nil
        }
 
        cancelc := make(chan error, 1)
-       t.setReqCanceler(req, func(err error) { cancelc <- err })
+       t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err })
 
        // Queue for permission to dial.
        t.queueForDial(w)
@@ -2078,7 +2087,7 @@ func (pc *persistConn) readLoop() {
                }
 
                if !hasBody || bodyWritable {
-                       pc.t.setReqCanceler(rc.req, nil)
+                       pc.t.setReqCanceler(rc.cancelKey, nil)
 
                        // Put the idle conn back into the pool before we send the response
                        // so if they process it quickly and make another request, they'll
@@ -2151,7 +2160,7 @@ func (pc *persistConn) readLoop() {
                // reading the response body. (or for cancellation or death)
                select {
                case bodyEOF := <-waitForBodyRead:
-                       pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool
+                       pc.t.setReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool
                        alive = alive &&
                                bodyEOF &&
                                !pc.sawEOF &&
@@ -2165,7 +2174,7 @@ func (pc *persistConn) readLoop() {
                        pc.t.CancelRequest(rc.req)
                case <-rc.req.Context().Done():
                        alive = false
-                       pc.t.cancelRequest(rc.req, rc.req.Context().Err())
+                       pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err())
                case <-pc.closech:
                        alive = false
                }
@@ -2408,9 +2417,10 @@ type responseAndError struct {
 }
 
 type requestAndChan struct {
-       _   incomparable
-       req *Request
-       ch  chan responseAndError // unbuffered; always send in select on callerGone
+       _         incomparable
+       req       *Request
+       cancelKey cancelKey
+       ch        chan responseAndError // unbuffered; always send in select on callerGone
 
        // whether the Transport (as opposed to the user client code)
        // added the Accept-Encoding gzip header. If the Transport
@@ -2472,7 +2482,7 @@ var (
 
 func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
        testHookEnterRoundTrip()
-       if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) {
+       if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) {
                pc.t.putOrCloseIdleConn(pc)
                return nil, errRequestCanceled
        }
@@ -2524,7 +2534,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
 
        defer func() {
                if err != nil {
-                       pc.t.setReqCanceler(req.Request, nil)
+                       pc.t.setReqCanceler(req.cancelKey, nil)
                }
        }()
 
@@ -2540,6 +2550,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
        resc := make(chan responseAndError)
        pc.reqch <- requestAndChan{
                req:        req.Request,
+               cancelKey:  req.cancelKey,
                ch:         resc,
                addedGzip:  requestedGzip,
                continueCh: continueCh,
@@ -2591,10 +2602,10 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
                        }
                        return re.res, nil
                case <-cancelChan:
-                       pc.t.CancelRequest(req.Request)
+                       pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
                        cancelChan = nil
                case <-ctxDoneChan:
-                       pc.t.cancelRequest(req.Request, req.Context().Err())
+                       pc.t.cancelRequest(req.cancelKey, req.Context().Err())
                        cancelChan = nil
                        ctxDoneChan = nil
                }
index 31a41f535115236817110fcd0bf6374cddcb3772..0a47687d9afbacac2ac7411c2e91e41b8b0915a9 100644 (file)
@@ -2364,6 +2364,50 @@ func TestTransportCancelRequest(t *testing.T) {
        }
 }
 
+func testTransportCancelRequestInDo(t *testing.T, body io.Reader) {
+       setParallel(t)
+       defer afterTest(t)
+       if testing.Short() {
+               t.Skip("skipping test in -short mode")
+       }
+       unblockc := make(chan bool)
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               <-unblockc
+       }))
+       defer ts.Close()
+       defer close(unblockc)
+
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
+
+       donec := make(chan bool)
+       req, _ := NewRequest("GET", ts.URL, body)
+       go func() {
+               defer close(donec)
+               c.Do(req)
+       }()
+       start := time.Now()
+       timeout := 10 * time.Second
+       for time.Since(start) < timeout {
+               time.Sleep(100 * time.Millisecond)
+               tr.CancelRequest(req)
+               select {
+               case <-donec:
+                       return
+               default:
+               }
+       }
+       t.Errorf("Do of canceled request has not returned after %v", timeout)
+}
+
+func TestTransportCancelRequestInDo(t *testing.T) {
+       testTransportCancelRequestInDo(t, nil)
+}
+
+func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
+       testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0}))
+}
+
 func TestTransportCancelRequestInDial(t *testing.T) {
        defer afterTest(t)
        if testing.Short() {