]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: fix race between dialing and canceling
authorDaniel Morsing <daniel.morsing@gmail.com>
Mon, 20 Apr 2015 22:02:07 +0000 (23:02 +0100)
committerDaniel Morsing <daniel.morsing@gmail.com>
Wed, 22 Apr 2015 12:23:55 +0000 (12:23 +0000)
In the brief window between getConn and persistConn.roundTrip,
a cancel could end up going missing.

Fix by making it possible to inspect if a cancel function was cleared
and checking if we were canceled before entering roundTrip.

Fixes #10511

Change-Id: If6513e63fbc2edb703e36d6356ccc95a1dc33144
Reviewed-on: https://go-review.googlesource.com/9181
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/http/export_test.go
src/net/http/transport.go
src/net/http/transport_test.go

index 69757bdca6518a6f649fd03753d975df3d700ae8..b656aa973184bf284d27a937d359ee17cb3a6f24 100644 (file)
@@ -82,6 +82,10 @@ func SetInstallConnClosedHook(f func()) {
        testHookPersistConnClosedGotRes = f
 }
 
+func SetEnterRoundTripHook(f func()) {
+       testHookEnterRoundTrip = f
+}
+
 func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
        f := func() <-chan time.Time {
                return ch
index b754472be697f1cac4222bd88e621b61b92d98bf..e31ae93e2a8d52316349aaf379b61b716640d160 100644 (file)
@@ -475,6 +475,25 @@ func (t *Transport) setReqCanceler(r *Request, fn func()) {
        }
 }
 
+// replaceReqCanceler replaces an existing cancel function. If there is no cancel function
+// for the request, we don't set the function and return false.
+// Since CancelRequest will clear the canceler, we can use the return value to detect if
+// the request was canceled since the last setReqCancel call.
+func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool {
+       t.reqMu.Lock()
+       defer t.reqMu.Unlock()
+       _, ok := t.reqCanceler[r]
+       if !ok {
+               return false
+       }
+       if fn != nil {
+               t.reqCanceler[r] = fn
+       } else {
+               delete(t.reqCanceler, r)
+       }
+       return true
+}
+
 func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
        if t.Dial != nil {
                return t.Dial(network, addr)
@@ -491,6 +510,10 @@ var prePendingDial, postPendingDial func()
 // is ready to write requests to.
 func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) {
        if pc := t.getIdleConn(cm); pc != nil {
+               // set request canceler to some non-nil function so we
+               // can detect whether it was cleared between now and when
+               // we enter roundTrip
+               t.setReqCanceler(req, func() {})
                return pc, nil
        }
 
@@ -1063,10 +1086,20 @@ var errTimeout error = &httpError{err: "net/http: timeout awaiting response head
 var errClosed error = &httpError{err: "net/http: transport closed before response was received"}
 var errRequestCanceled = errors.New("net/http: request canceled")
 
-var testHookPersistConnClosedGotRes func() // nil except for tests
+// nil except for tests
+var (
+       testHookPersistConnClosedGotRes func()
+       testHookEnterRoundTrip          func()
+)
 
 func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
-       pc.t.setReqCanceler(req.Request, pc.cancelRequest)
+       if hook := testHookEnterRoundTrip; hook != nil {
+               hook()
+       }
+       if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) {
+               pc.t.putIdleConn(pc)
+               return nil, errRequestCanceled
+       }
        pc.lk.Lock()
        pc.numExpectedResponses++
        headerFn := pc.mutateHeaderFunc
index 2d52f17721d84f4ab1971f73e3873db930948c77..d20ba13208b4c53004cf6516f23b7f6c37efef5a 100644 (file)
@@ -2400,6 +2400,32 @@ func TestTransportResponseCancelRace(t *testing.T) {
        res.Body.Close()
 }
 
+func TestTransportDialCancelRace(t *testing.T) {
+       defer afterTest(t)
+
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+       defer ts.Close()
+
+       tr := &Transport{}
+       defer tr.CloseIdleConnections()
+
+       req, err := NewRequest("GET", ts.URL, nil)
+       if err != nil {
+               t.Fatal(err)
+       }
+       SetEnterRoundTripHook(func() {
+               tr.CancelRequest(req)
+       })
+       defer SetEnterRoundTripHook(nil)
+       res, err := tr.RoundTrip(req)
+       if err != ExportErrRequestCanceled {
+               t.Errorf("expected canceled request error; got %v", err)
+               if err == nil {
+                       res.Body.Close()
+               }
+       }
+}
+
 func wantBody(res *http.Response, err error, want string) error {
        if err != nil {
                return err