]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: make Transport.CancelRequest work for requests blocked in a dial
authorBrad Fitzpatrick <bradfitz@golang.org>
Thu, 27 Feb 2014 21:32:40 +0000 (13:32 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 27 Feb 2014 21:32:40 +0000 (13:32 -0800)
Fixes #6951

LGTM=josharian
R=golang-codereviews, josharian
CC=golang-codereviews
https://golang.org/cl/69280043

src/pkg/net/http/export_test.go
src/pkg/net/http/transport.go
src/pkg/net/http/transport_test.go

index 8074df5bbde17f3ffad29dd1f1fe51be3c772d18..960563b240945d9547cb4d4c6e87d9a09708b6ff 100644 (file)
@@ -21,7 +21,7 @@ var ExportAppendTime = appendTime
 func (t *Transport) NumPendingRequestsForTesting() int {
        t.reqMu.Lock()
        defer t.reqMu.Unlock()
-       return len(t.reqConn)
+       return len(t.reqCanceler)
 }
 
 func (t *Transport) IdleConnKeysForTesting() (keys []string) {
index 1a7b459fe1bbc18050dba434f567d3a4c080ee43..9eb40a3e24ae566e5594b8ccafc29c6bd6a060af 100644 (file)
@@ -47,13 +47,13 @@ const DefaultMaxIdleConnsPerHost = 2
 // https, and http proxies (for either http or https with CONNECT).
 // Transport can also cache connections for future re-use.
 type Transport struct {
-       idleMu     sync.Mutex
-       idleConn   map[connectMethodKey][]*persistConn
-       idleConnCh map[connectMethodKey]chan *persistConn
-       reqMu      sync.Mutex
-       reqConn    map[*Request]*persistConn
-       altMu      sync.RWMutex
-       altProto   map[string]RoundTripper // nil or map of URI scheme => RoundTripper
+       idleMu      sync.Mutex
+       idleConn    map[connectMethodKey][]*persistConn
+       idleConnCh  map[connectMethodKey]chan *persistConn
+       reqMu       sync.Mutex
+       reqCanceler map[*Request]func()
+       altMu       sync.RWMutex
+       altProto    map[string]RoundTripper // nil or map of URI scheme => RoundTripper
 
        // Proxy specifies a function to return a proxy for a given
        // Request. If the function returns a non-nil error, the
@@ -190,8 +190,9 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
        // host (for http or https), the http proxy, or the http proxy
        // pre-CONNECTed to https server.  In any case, we'll be ready
        // to send it requests.
-       pconn, err := t.getConn(cm)
+       pconn, err := t.getConn(req, cm)
        if err != nil {
+               t.setReqCanceler(req, nil)
                return nil, err
        }
 
@@ -243,10 +244,10 @@ func (t *Transport) CloseIdleConnections() {
 // connection.
 func (t *Transport) CancelRequest(req *Request) {
        t.reqMu.Lock()
-       pc := t.reqConn[req]
+       cancel := t.reqCanceler[req]
        t.reqMu.Unlock()
-       if pc != nil {
-               pc.conn.Close()
+       if cancel != nil {
+               cancel()
        }
 }
 
@@ -417,16 +418,16 @@ func (t *Transport) getIdleConn(cm connectMethod) (pconn *persistConn) {
        }
 }
 
-func (t *Transport) setReqConn(r *Request, pc *persistConn) {
+func (t *Transport) setReqCanceler(r *Request, fn func()) {
        t.reqMu.Lock()
        defer t.reqMu.Unlock()
-       if t.reqConn == nil {
-               t.reqConn = make(map[*Request]*persistConn)
+       if t.reqCanceler == nil {
+               t.reqCanceler = make(map[*Request]func())
        }
-       if pc != nil {
-               t.reqConn[r] = pc
+       if fn != nil {
+               t.reqCanceler[r] = fn
        } else {
-               delete(t.reqConn, r)
+               delete(t.reqCanceler, r)
        }
 }
 
@@ -441,7 +442,7 @@ func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
 // specified in the connectMethod.  This includes doing a proxy CONNECT
 // and/or setting up TLS.  If this doesn't return an error, the persistConn
 // is ready to write requests to.
-func (t *Transport) getConn(cm connectMethod) (*persistConn, error) {
+func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) {
        if pc := t.getIdleConn(cm); pc != nil {
                return pc, nil
        }
@@ -451,6 +452,16 @@ func (t *Transport) getConn(cm connectMethod) (*persistConn, error) {
                err error
        }
        dialc := make(chan dialRes)
+
+       handlePendingDial := func() {
+               if v := <-dialc; v.err == nil {
+                       t.putIdleConn(v.pc)
+               }
+       }
+
+       cancelc := make(chan struct{})
+       t.setReqCanceler(req, func() { close(cancelc) })
+
        go func() {
                pc, err := t.dialConn(cm)
                dialc <- dialRes{pc, err}
@@ -467,12 +478,11 @@ func (t *Transport) getConn(cm connectMethod) (*persistConn, error) {
                // else's dial that they didn't use.
                // But our dial is still going, so give it away
                // when it finishes:
-               go func() {
-                       if v := <-dialc; v.err == nil {
-                               t.putIdleConn(v.pc)
-                       }
-               }()
+               go handlePendingDial()
                return pc, nil
+       case <-cancelc:
+               go handlePendingDial()
+               return nil, errors.New("net/http: request canceled while waiting for connection")
        }
 }
 
@@ -732,6 +742,10 @@ func (pc *persistConn) isBroken() bool {
        return b
 }
 
+func (pc *persistConn) cancelRequest() {
+       pc.conn.Close()
+}
+
 var remoteSideClosedFunc func(error) bool // or nil to use default
 
 func remoteSideClosed(err error) bool {
@@ -843,7 +857,7 @@ func (pc *persistConn) readLoop() {
                        alive = <-waitForBodyRead
                }
 
-               pc.t.setReqConn(rc.req, nil)
+               pc.t.setReqCanceler(rc.req, nil)
 
                if !alive {
                        pc.close()
@@ -910,7 +924,7 @@ var errTimeout error = &httpError{err: "net/http: timeout awaiting response head
 var errClosed error = &httpError{err: "net/http: transport closed before response was received"}
 
 func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
-       pc.t.setReqConn(req.Request, pc)
+       pc.t.setReqCanceler(req.Request, pc.cancelRequest)
        pc.lk.Lock()
        pc.numExpectedResponses++
        headerFn := pc.mutateHeaderFunc
@@ -995,7 +1009,7 @@ WaitResponse:
        pc.lk.Unlock()
 
        if re.err != nil {
-               pc.t.setReqConn(req.Request, nil)
+               pc.t.setReqCanceler(req.Request, nil)
        }
        return re.res, re.err
 }
index 510679e53b24ce211cfd114f48650d97d0040e8d..1d73633ea4bc4946bdcf3cad6ed1cacdfb5a751f 100644 (file)
@@ -11,9 +11,11 @@ import (
        "bytes"
        "compress/gzip"
        "crypto/rand"
+       "errors"
        "fmt"
        "io"
        "io/ioutil"
+       "log"
        "net"
        "net/http"
        . "net/http"
@@ -1321,6 +1323,61 @@ func TestTransportCancelRequest(t *testing.T) {
        }
 }
 
+func TestTransportCancelRequestInDial(t *testing.T) {
+       defer afterTest(t)
+       if testing.Short() {
+               t.Skip("skipping test in -short mode")
+       }
+       var logbuf bytes.Buffer
+       eventLog := log.New(&logbuf, "", 0)
+
+       unblockDial := make(chan bool)
+       defer close(unblockDial)
+
+       inDial := make(chan bool)
+       tr := &Transport{
+               Dial: func(network, addr string) (net.Conn, error) {
+                       eventLog.Println("dial: blocking")
+                       inDial <- true
+                       <-unblockDial
+                       return nil, errors.New("nope")
+               },
+       }
+       cl := &Client{Transport: tr}
+       gotres := make(chan bool)
+       req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
+       go func() {
+               _, err := cl.Do(req)
+               eventLog.Printf("Get = %v", err)
+               gotres <- true
+       }()
+
+       select {
+       case <-inDial:
+       case <-time.After(5 * time.Second):
+               t.Fatal("timeout; never saw blocking dial")
+       }
+
+       eventLog.Printf("canceling")
+       tr.CancelRequest(req)
+
+       select {
+       case <-gotres:
+       case <-time.After(5 * time.Second):
+               panic("hang. events are: " + logbuf.String())
+               t.Fatal("timeout; cancel didn't work?")
+       }
+
+       got := logbuf.String()
+       want := `dial: blocking
+canceling
+Get = Get http://something.no-network.tld/: net/http: request canceled while waiting for connection
+`
+       if got != want {
+               t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
+       }
+}
+
 // golang.org/issue/3672 -- Client can't close HTTP stream
 // Calling Close on a Response.Body used to just read until EOF.
 // Now it actually closes the TCP connection.