}
}
+// replaceReqCanceler replaces an existing cancel function. If there is no cancel function
+// for the request, we don't set the function and return false.
+// Since CancelRequest will clear the canceler, we can use the return value to detect if
+// the request was canceled since the last setReqCancel call.
+func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool {
+ t.reqMu.Lock()
+ defer t.reqMu.Unlock()
+ _, ok := t.reqCanceler[r]
+ if !ok {
+ return false
+ }
+ if fn != nil {
+ t.reqCanceler[r] = fn
+ } else {
+ delete(t.reqCanceler, r)
+ }
+ return true
+}
+
func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
if t.Dial != nil {
return t.Dial(network, addr)
// is ready to write requests to.
func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) {
if pc := t.getIdleConn(cm); pc != nil {
+ // set request canceler to some non-nil function so we
+ // can detect whether it was cleared between now and when
+ // we enter roundTrip
+ t.setReqCanceler(req, func() {})
return pc, nil
}
var errClosed error = &httpError{err: "net/http: transport closed before response was received"}
var errRequestCanceled = errors.New("net/http: request canceled")
-var testHookPersistConnClosedGotRes func() // nil except for tests
+// nil except for tests
+var (
+ testHookPersistConnClosedGotRes func()
+ testHookEnterRoundTrip func()
+)
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
- pc.t.setReqCanceler(req.Request, pc.cancelRequest)
+ if hook := testHookEnterRoundTrip; hook != nil {
+ hook()
+ }
+ if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) {
+ pc.t.putIdleConn(pc)
+ return nil, errRequestCanceled
+ }
pc.lk.Lock()
pc.numExpectedResponses++
headerFn := pc.mutateHeaderFunc
res.Body.Close()
}
+func TestTransportDialCancelRace(t *testing.T) {
+ defer afterTest(t)
+
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ defer ts.Close()
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ SetEnterRoundTripHook(func() {
+ tr.CancelRequest(req)
+ })
+ defer SetEnterRoundTripHook(nil)
+ res, err := tr.RoundTrip(req)
+ if err != ExportErrRequestCanceled {
+ t.Errorf("expected canceled request error; got %v", err)
+ if err == nil {
+ res.Body.Close()
+ }
+ }
+}
+
func wantBody(res *http.Response, err error, want string) error {
if err != nil {
return err