]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Transport.CancelRequest
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 27 Feb 2013 23:20:13 +0000 (15:20 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 27 Feb 2013 23:20:13 +0000 (15:20 -0800)
Permits all sorts of custom HTTP timeout policies without
adding a new Transport timeout Duration for each combination
of HTTP phases.

This keeps track internally of which TCP connection a given
Request is on, and lets callers forcefully close the TCP
connection for a given request, without actually getting
the net.Conn directly.

Additionally, a future CL will implement res.Body.Close (Issue
3672) in terms of this.

Update #3362
Update #3672

R=golang-dev, rsc, adg
CC=golang-dev
https://golang.org/cl/7372054

src/pkg/net/http/export_test.go
src/pkg/net/http/transport.go
src/pkg/net/http/transport_test.go

index a7a07852d18b7dac221a694e1be0991292b9f484..a7bca20a07dfa80d2de4a65c0f4760bcff4ed62d 100644 (file)
@@ -16,10 +16,16 @@ func NewLoggingConn(baseName string, c net.Conn) net.Conn {
        return newLoggingConn(baseName, c)
 }
 
+func (t *Transport) NumPendingRequestsForTesting() int {
+       t.reqMu.Lock()
+       defer t.reqMu.Unlock()
+       return len(t.reqConn)
+}
+
 func (t *Transport) IdleConnKeysForTesting() (keys []string) {
        keys = make([]string, 0)
-       t.idleLk.Lock()
-       defer t.idleLk.Unlock()
+       t.idleMu.Lock()
+       defer t.idleMu.Unlock()
        if t.idleConn == nil {
                return
        }
@@ -30,8 +36,8 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) {
 }
 
 func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
-       t.idleLk.Lock()
-       defer t.idleLk.Unlock()
+       t.idleMu.Lock()
+       defer t.idleMu.Unlock()
        if t.idleConn == nil {
                return 0
        }
index 984c39154e4b3e23cceec6f1b7a169e0b20adc6e..685d7d56c45afc70a155239f5399407041386909 100644 (file)
@@ -42,9 +42,11 @@ const DefaultMaxIdleConnsPerHost = 2
 // https, and http proxies (for either http or https with CONNECT).
 // Transport can also cache connections for future re-use.
 type Transport struct {
-       idleLk   sync.Mutex
+       idleMu   sync.Mutex
        idleConn map[string][]*persistConn
-       altLk    sync.RWMutex
+       reqMu    sync.Mutex
+       reqConn  map[*Request]*persistConn
+       altMu    sync.RWMutex
        altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
 
        // TODO: tunable on global max cached connections
@@ -139,12 +141,12 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
                return nil, errors.New("http: nil Request.Header")
        }
        if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
-               t.altLk.RLock()
+               t.altMu.RLock()
                var rt RoundTripper
                if t.altProto != nil {
                        rt = t.altProto[req.URL.Scheme]
                }
-               t.altLk.RUnlock()
+               t.altMu.RUnlock()
                if rt == nil {
                        return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
                }
@@ -181,8 +183,8 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
        if scheme == "http" || scheme == "https" {
                panic("protocol " + scheme + " already registered")
        }
-       t.altLk.Lock()
-       defer t.altLk.Unlock()
+       t.altMu.Lock()
+       defer t.altMu.Unlock()
        if t.altProto == nil {
                t.altProto = make(map[string]RoundTripper)
        }
@@ -197,10 +199,10 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
 // a "keep-alive" state. It does not interrupt any connections currently
 // in use.
 func (t *Transport) CloseIdleConnections() {
-       t.idleLk.Lock()
+       t.idleMu.Lock()
        m := t.idleConn
        t.idleConn = nil
-       t.idleLk.Unlock()
+       t.idleMu.Unlock()
        if m == nil {
                return
        }
@@ -211,6 +213,17 @@ func (t *Transport) CloseIdleConnections() {
        }
 }
 
+// CancelRequest cancels an in-flight request by closing its
+// connection.
+func (t *Transport) CancelRequest(req *Request) {
+       t.reqMu.Lock()
+       pc := t.reqConn[req]
+       t.reqMu.Unlock()
+       if pc != nil {
+               pc.conn.Close()
+       }
+}
+
 //
 // Private implementation past this point.
 //
@@ -266,12 +279,12 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool {
        if max == 0 {
                max = DefaultMaxIdleConnsPerHost
        }
-       t.idleLk.Lock()
+       t.idleMu.Lock()
        if t.idleConn == nil {
                t.idleConn = make(map[string][]*persistConn)
        }
        if len(t.idleConn[key]) >= max {
-               t.idleLk.Unlock()
+               t.idleMu.Unlock()
                pconn.close()
                return false
        }
@@ -281,14 +294,14 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool {
                }
        }
        t.idleConn[key] = append(t.idleConn[key], pconn)
-       t.idleLk.Unlock()
+       t.idleMu.Unlock()
        return true
 }
 
 func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
        key := cm.String()
-       t.idleLk.Lock()
-       defer t.idleLk.Unlock()
+       t.idleMu.Lock()
+       defer t.idleMu.Unlock()
        if t.idleConn == nil {
                return nil
        }
@@ -313,6 +326,19 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
        panic("unreachable")
 }
 
+func (t *Transport) setReqConn(r *Request, pc *persistConn) {
+       t.reqMu.Lock()
+       defer t.reqMu.Unlock()
+       if t.reqConn == nil {
+               t.reqConn = make(map[*Request]*persistConn)
+       }
+       if pc != nil {
+               t.reqConn[r] = pc
+       } else {
+               delete(t.reqConn, r)
+       }
+}
+
 func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
        if t.Dial != nil {
                return t.Dial(network, addr)
@@ -662,6 +688,8 @@ func (pc *persistConn) readLoop() {
                        alive = <-waitForBodyRead
                }
 
+               pc.t.setReqConn(rc.req, nil)
+
                if !alive {
                        pc.close()
                }
@@ -715,6 +743,7 @@ type writeRequest struct {
 }
 
 func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
+       pc.t.setReqConn(req.Request, pc)
        pc.lk.Lock()
        pc.numExpectedResponses++
        headerFn := pc.mutateHeaderFunc
@@ -793,6 +822,9 @@ WaitResponse:
        pc.numExpectedResponses--
        pc.lk.Unlock()
 
+       if re.err != nil {
+               pc.t.setReqConn(req.Request, nil)
+       }
        return re.res, re.err
 }
 
index 248e1507a9d09270cddc805bdc9478c5e4188c08..68010e68b33003628d97d90b59b26460fca08d37 100644 (file)
@@ -1118,7 +1118,6 @@ func TestTransportResponseHeaderTimeout(t *testing.T) {
        if testing.Short() {
                t.Skip("skipping timeout test in -short mode")
        }
-       const debug = false
        mux := NewServeMux()
        mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {})
        mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
@@ -1161,6 +1160,60 @@ func TestTransportResponseHeaderTimeout(t *testing.T) {
        }
 }
 
+func TestTransportCancelRequest(t *testing.T) {
+       defer checkLeakedTransports(t)
+       if testing.Short() {
+               t.Skip("skipping test in -short mode")
+       }
+       unblockc := make(chan bool)
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               fmt.Fprintf(w, "Hello")
+               w.(Flusher).Flush() // send headers and some body
+               <-unblockc
+       }))
+       defer ts.Close()
+       defer close(unblockc)
+
+       tr := &Transport{}
+       defer tr.CloseIdleConnections()
+       c := &Client{Transport: tr}
+
+       req, _ := NewRequest("GET", ts.URL, nil)
+       res, err := c.Do(req)
+       if err != nil {
+               t.Fatal(err)
+       }
+       go func() {
+               time.Sleep(1 * time.Second)
+               tr.CancelRequest(req)
+       }()
+       t0 := time.Now()
+       body, err := ioutil.ReadAll(res.Body)
+       d := time.Since(t0)
+
+       if err == nil {
+               t.Error("expected an error reading the body")
+       }
+       if string(body) != "Hello" {
+               t.Errorf("Body = %q; want Hello", body)
+       }
+       if d < 500*time.Millisecond {
+               t.Errorf("expected ~1 second delay; got %v", d)
+       }
+       // Verify no outstanding requests after readLoop/writeLoop
+       // goroutines shut down.
+       for tries := 3; tries > 0; tries-- {
+               n := tr.NumPendingRequestsForTesting()
+               if n == 0 {
+                       break
+               }
+               time.Sleep(100 * time.Millisecond)
+               if tries == 1 {
+                       t.Errorf("pending requests = %d; want 0", n)
+               }
+       }
+}
+
 type fooProto struct{}
 
 func (fooProto) RoundTrip(req *Request) (*Response, error) {