From 11776a39a118742b61510a3ef3bfa2a80e4b3005 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 27 Feb 2013 15:20:13 -0800 Subject: [PATCH] net/http: add Transport.CancelRequest Permits all sorts of custom HTTP timeout policies without adding a new Transport timeout Duration for each combination of HTTP phases. This keeps track internally of which TCP connection a given Request is on, and lets callers forcefully close the TCP connection for a given request, without actually getting the net.Conn directly. Additionally, a future CL will implement res.Body.Close (Issue 3672) in terms of this. Update #3362 Update #3672 R=golang-dev, rsc, adg CC=golang-dev https://golang.org/cl/7372054 --- src/pkg/net/http/export_test.go | 14 +++++--- src/pkg/net/http/transport.go | 58 +++++++++++++++++++++++------- src/pkg/net/http/transport_test.go | 55 +++++++++++++++++++++++++++- 3 files changed, 109 insertions(+), 18 deletions(-) diff --git a/src/pkg/net/http/export_test.go b/src/pkg/net/http/export_test.go index a7a07852d1..a7bca20a07 100644 --- a/src/pkg/net/http/export_test.go +++ b/src/pkg/net/http/export_test.go @@ -16,10 +16,16 @@ func NewLoggingConn(baseName string, c net.Conn) net.Conn { return newLoggingConn(baseName, c) } +func (t *Transport) NumPendingRequestsForTesting() int { + t.reqMu.Lock() + defer t.reqMu.Unlock() + return len(t.reqConn) +} + func (t *Transport) IdleConnKeysForTesting() (keys []string) { keys = make([]string, 0) - t.idleLk.Lock() - defer t.idleLk.Unlock() + t.idleMu.Lock() + defer t.idleMu.Unlock() if t.idleConn == nil { return } @@ -30,8 +36,8 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) { } func (t *Transport) IdleConnCountForTesting(cacheKey string) int { - t.idleLk.Lock() - defer t.idleLk.Unlock() + t.idleMu.Lock() + defer t.idleMu.Unlock() if t.idleConn == nil { return 0 } diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go index 984c39154e..685d7d56c4 100644 --- a/src/pkg/net/http/transport.go +++ b/src/pkg/net/http/transport.go @@ -42,9 +42,11 @@ const DefaultMaxIdleConnsPerHost = 2 // https, and http proxies (for either http or https with CONNECT). // Transport can also cache connections for future re-use. type Transport struct { - idleLk sync.Mutex + idleMu sync.Mutex idleConn map[string][]*persistConn - altLk sync.RWMutex + reqMu sync.Mutex + reqConn map[*Request]*persistConn + altMu sync.RWMutex altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper // TODO: tunable on global max cached connections @@ -139,12 +141,12 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { return nil, errors.New("http: nil Request.Header") } if req.URL.Scheme != "http" && req.URL.Scheme != "https" { - t.altLk.RLock() + t.altMu.RLock() var rt RoundTripper if t.altProto != nil { rt = t.altProto[req.URL.Scheme] } - t.altLk.RUnlock() + t.altMu.RUnlock() if rt == nil { return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} } @@ -181,8 +183,8 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { if scheme == "http" || scheme == "https" { panic("protocol " + scheme + " already registered") } - t.altLk.Lock() - defer t.altLk.Unlock() + t.altMu.Lock() + defer t.altMu.Unlock() if t.altProto == nil { t.altProto = make(map[string]RoundTripper) } @@ -197,10 +199,10 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { // a "keep-alive" state. It does not interrupt any connections currently // in use. func (t *Transport) CloseIdleConnections() { - t.idleLk.Lock() + t.idleMu.Lock() m := t.idleConn t.idleConn = nil - t.idleLk.Unlock() + t.idleMu.Unlock() if m == nil { return } @@ -211,6 +213,17 @@ func (t *Transport) CloseIdleConnections() { } } +// CancelRequest cancels an in-flight request by closing its +// connection. +func (t *Transport) CancelRequest(req *Request) { + t.reqMu.Lock() + pc := t.reqConn[req] + t.reqMu.Unlock() + if pc != nil { + pc.conn.Close() + } +} + // // Private implementation past this point. // @@ -266,12 +279,12 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { if max == 0 { max = DefaultMaxIdleConnsPerHost } - t.idleLk.Lock() + t.idleMu.Lock() if t.idleConn == nil { t.idleConn = make(map[string][]*persistConn) } if len(t.idleConn[key]) >= max { - t.idleLk.Unlock() + t.idleMu.Unlock() pconn.close() return false } @@ -281,14 +294,14 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { } } t.idleConn[key] = append(t.idleConn[key], pconn) - t.idleLk.Unlock() + t.idleMu.Unlock() return true } func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { key := cm.String() - t.idleLk.Lock() - defer t.idleLk.Unlock() + t.idleMu.Lock() + defer t.idleMu.Unlock() if t.idleConn == nil { return nil } @@ -313,6 +326,19 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { panic("unreachable") } +func (t *Transport) setReqConn(r *Request, pc *persistConn) { + t.reqMu.Lock() + defer t.reqMu.Unlock() + if t.reqConn == nil { + t.reqConn = make(map[*Request]*persistConn) + } + if pc != nil { + t.reqConn[r] = pc + } else { + delete(t.reqConn, r) + } +} + func (t *Transport) dial(network, addr string) (c net.Conn, err error) { if t.Dial != nil { return t.Dial(network, addr) @@ -662,6 +688,8 @@ func (pc *persistConn) readLoop() { alive = <-waitForBodyRead } + pc.t.setReqConn(rc.req, nil) + if !alive { pc.close() } @@ -715,6 +743,7 @@ type writeRequest struct { } func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { + pc.t.setReqConn(req.Request, pc) pc.lk.Lock() pc.numExpectedResponses++ headerFn := pc.mutateHeaderFunc @@ -793,6 +822,9 @@ WaitResponse: pc.numExpectedResponses-- pc.lk.Unlock() + if re.err != nil { + pc.t.setReqConn(req.Request, nil) + } return re.res, re.err } diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go index 248e1507a9..68010e68b3 100644 --- a/src/pkg/net/http/transport_test.go +++ b/src/pkg/net/http/transport_test.go @@ -1118,7 +1118,6 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { if testing.Short() { t.Skip("skipping timeout test in -short mode") } - const debug = false mux := NewServeMux() mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {}) mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) { @@ -1161,6 +1160,60 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { } } +func TestTransportCancelRequest(t *testing.T) { + defer checkLeakedTransports(t) + if testing.Short() { + t.Skip("skipping test in -short mode") + } + unblockc := make(chan bool) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Hello") + w.(Flusher).Flush() // send headers and some body + <-unblockc + })) + defer ts.Close() + defer close(unblockc) + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + req, _ := NewRequest("GET", ts.URL, nil) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + go func() { + time.Sleep(1 * time.Second) + tr.CancelRequest(req) + }() + t0 := time.Now() + body, err := ioutil.ReadAll(res.Body) + d := time.Since(t0) + + if err == nil { + t.Error("expected an error reading the body") + } + if string(body) != "Hello" { + t.Errorf("Body = %q; want Hello", body) + } + if d < 500*time.Millisecond { + t.Errorf("expected ~1 second delay; got %v", d) + } + // Verify no outstanding requests after readLoop/writeLoop + // goroutines shut down. + for tries := 3; tries > 0; tries-- { + n := tr.NumPendingRequestsForTesting() + if n == 0 { + break + } + time.Sleep(100 * time.Millisecond) + if tries == 1 { + t.Errorf("pending requests = %d; want 0", n) + } + } +} + type fooProto struct{} func (fooProto) RoundTrip(req *Request) (*Response, error) { -- 2.48.1