From: Brad Fitzpatrick Date: Thu, 27 Feb 2014 21:32:40 +0000 (-0800) Subject: net/http: make Transport.CancelRequest work for requests blocked in a dial X-Git-Tag: go1.3beta1~541 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=dc6bf295b95f3b1141e81fea3e128d22e4282962;p=gostls13.git net/http: make Transport.CancelRequest work for requests blocked in a dial Fixes #6951 LGTM=josharian R=golang-codereviews, josharian CC=golang-codereviews https://golang.org/cl/69280043 --- diff --git a/src/pkg/net/http/export_test.go b/src/pkg/net/http/export_test.go index 8074df5bbd..960563b240 100644 --- a/src/pkg/net/http/export_test.go +++ b/src/pkg/net/http/export_test.go @@ -21,7 +21,7 @@ var ExportAppendTime = appendTime func (t *Transport) NumPendingRequestsForTesting() int { t.reqMu.Lock() defer t.reqMu.Unlock() - return len(t.reqConn) + return len(t.reqCanceler) } func (t *Transport) IdleConnKeysForTesting() (keys []string) { diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go index 1a7b459fe1..9eb40a3e24 100644 --- a/src/pkg/net/http/transport.go +++ b/src/pkg/net/http/transport.go @@ -47,13 +47,13 @@ const DefaultMaxIdleConnsPerHost = 2 // https, and http proxies (for either http or https with CONNECT). // Transport can also cache connections for future re-use. type Transport struct { - idleMu sync.Mutex - idleConn map[connectMethodKey][]*persistConn - idleConnCh map[connectMethodKey]chan *persistConn - reqMu sync.Mutex - reqConn map[*Request]*persistConn - altMu sync.RWMutex - altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper + idleMu sync.Mutex + idleConn map[connectMethodKey][]*persistConn + idleConnCh map[connectMethodKey]chan *persistConn + reqMu sync.Mutex + reqCanceler map[*Request]func() + altMu sync.RWMutex + altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the @@ -190,8 +190,9 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { // host (for http or https), the http proxy, or the http proxy // pre-CONNECTed to https server. In any case, we'll be ready // to send it requests. - pconn, err := t.getConn(cm) + pconn, err := t.getConn(req, cm) if err != nil { + t.setReqCanceler(req, nil) return nil, err } @@ -243,10 +244,10 @@ func (t *Transport) CloseIdleConnections() { // connection. func (t *Transport) CancelRequest(req *Request) { t.reqMu.Lock() - pc := t.reqConn[req] + cancel := t.reqCanceler[req] t.reqMu.Unlock() - if pc != nil { - pc.conn.Close() + if cancel != nil { + cancel() } } @@ -417,16 +418,16 @@ func (t *Transport) getIdleConn(cm connectMethod) (pconn *persistConn) { } } -func (t *Transport) setReqConn(r *Request, pc *persistConn) { +func (t *Transport) setReqCanceler(r *Request, fn func()) { t.reqMu.Lock() defer t.reqMu.Unlock() - if t.reqConn == nil { - t.reqConn = make(map[*Request]*persistConn) + if t.reqCanceler == nil { + t.reqCanceler = make(map[*Request]func()) } - if pc != nil { - t.reqConn[r] = pc + if fn != nil { + t.reqCanceler[r] = fn } else { - delete(t.reqConn, r) + delete(t.reqCanceler, r) } } @@ -441,7 +442,7 @@ func (t *Transport) dial(network, addr string) (c net.Conn, err error) { // specified in the connectMethod. This includes doing a proxy CONNECT // and/or setting up TLS. If this doesn't return an error, the persistConn // is ready to write requests to. -func (t *Transport) getConn(cm connectMethod) (*persistConn, error) { +func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) { if pc := t.getIdleConn(cm); pc != nil { return pc, nil } @@ -451,6 +452,16 @@ func (t *Transport) getConn(cm connectMethod) (*persistConn, error) { err error } dialc := make(chan dialRes) + + handlePendingDial := func() { + if v := <-dialc; v.err == nil { + t.putIdleConn(v.pc) + } + } + + cancelc := make(chan struct{}) + t.setReqCanceler(req, func() { close(cancelc) }) + go func() { pc, err := t.dialConn(cm) dialc <- dialRes{pc, err} @@ -467,12 +478,11 @@ func (t *Transport) getConn(cm connectMethod) (*persistConn, error) { // else's dial that they didn't use. // But our dial is still going, so give it away // when it finishes: - go func() { - if v := <-dialc; v.err == nil { - t.putIdleConn(v.pc) - } - }() + go handlePendingDial() return pc, nil + case <-cancelc: + go handlePendingDial() + return nil, errors.New("net/http: request canceled while waiting for connection") } } @@ -732,6 +742,10 @@ func (pc *persistConn) isBroken() bool { return b } +func (pc *persistConn) cancelRequest() { + pc.conn.Close() +} + var remoteSideClosedFunc func(error) bool // or nil to use default func remoteSideClosed(err error) bool { @@ -843,7 +857,7 @@ func (pc *persistConn) readLoop() { alive = <-waitForBodyRead } - pc.t.setReqConn(rc.req, nil) + pc.t.setReqCanceler(rc.req, nil) if !alive { pc.close() @@ -910,7 +924,7 @@ var errTimeout error = &httpError{err: "net/http: timeout awaiting response head var errClosed error = &httpError{err: "net/http: transport closed before response was received"} func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { - pc.t.setReqConn(req.Request, pc) + pc.t.setReqCanceler(req.Request, pc.cancelRequest) pc.lk.Lock() pc.numExpectedResponses++ headerFn := pc.mutateHeaderFunc @@ -995,7 +1009,7 @@ WaitResponse: pc.lk.Unlock() if re.err != nil { - pc.t.setReqConn(req.Request, nil) + pc.t.setReqCanceler(req.Request, nil) } return re.res, re.err } diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go index 510679e53b..1d73633ea4 100644 --- a/src/pkg/net/http/transport_test.go +++ b/src/pkg/net/http/transport_test.go @@ -11,9 +11,11 @@ import ( "bytes" "compress/gzip" "crypto/rand" + "errors" "fmt" "io" "io/ioutil" + "log" "net" "net/http" . "net/http" @@ -1321,6 +1323,61 @@ func TestTransportCancelRequest(t *testing.T) { } } +func TestTransportCancelRequestInDial(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping test in -short mode") + } + var logbuf bytes.Buffer + eventLog := log.New(&logbuf, "", 0) + + unblockDial := make(chan bool) + defer close(unblockDial) + + inDial := make(chan bool) + tr := &Transport{ + Dial: func(network, addr string) (net.Conn, error) { + eventLog.Println("dial: blocking") + inDial <- true + <-unblockDial + return nil, errors.New("nope") + }, + } + cl := &Client{Transport: tr} + gotres := make(chan bool) + req, _ := NewRequest("GET", "http://something.no-network.tld/", nil) + go func() { + _, err := cl.Do(req) + eventLog.Printf("Get = %v", err) + gotres <- true + }() + + select { + case <-inDial: + case <-time.After(5 * time.Second): + t.Fatal("timeout; never saw blocking dial") + } + + eventLog.Printf("canceling") + tr.CancelRequest(req) + + select { + case <-gotres: + case <-time.After(5 * time.Second): + panic("hang. events are: " + logbuf.String()) + t.Fatal("timeout; cancel didn't work?") + } + + got := logbuf.String() + want := `dial: blocking +canceling +Get = Get http://something.no-network.tld/: net/http: request canceled while waiting for connection +` + if got != want { + t.Errorf("Got events:\n%s\nWant:\n%s", got, want) + } +} + // golang.org/issue/3672 -- Client can't close HTTP stream // Calling Close on a Response.Body used to just read until EOF. // Now it actually closes the TCP connection.