return newLoggingConn(baseName, c)
}
+func (t *Transport) NumPendingRequestsForTesting() int {
+ t.reqMu.Lock()
+ defer t.reqMu.Unlock()
+ return len(t.reqConn)
+}
+
func (t *Transport) IdleConnKeysForTesting() (keys []string) {
keys = make([]string, 0)
- t.idleLk.Lock()
- defer t.idleLk.Unlock()
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
if t.idleConn == nil {
return
}
}
func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
- t.idleLk.Lock()
- defer t.idleLk.Unlock()
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
if t.idleConn == nil {
return 0
}
// https, and http proxies (for either http or https with CONNECT).
// Transport can also cache connections for future re-use.
type Transport struct {
- idleLk sync.Mutex
+ idleMu sync.Mutex
idleConn map[string][]*persistConn
- altLk sync.RWMutex
+ reqMu sync.Mutex
+ reqConn map[*Request]*persistConn
+ altMu sync.RWMutex
altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
// TODO: tunable on global max cached connections
return nil, errors.New("http: nil Request.Header")
}
if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
- t.altLk.RLock()
+ t.altMu.RLock()
var rt RoundTripper
if t.altProto != nil {
rt = t.altProto[req.URL.Scheme]
}
- t.altLk.RUnlock()
+ t.altMu.RUnlock()
if rt == nil {
return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
}
if scheme == "http" || scheme == "https" {
panic("protocol " + scheme + " already registered")
}
- t.altLk.Lock()
- defer t.altLk.Unlock()
+ t.altMu.Lock()
+ defer t.altMu.Unlock()
if t.altProto == nil {
t.altProto = make(map[string]RoundTripper)
}
// a "keep-alive" state. It does not interrupt any connections currently
// in use.
func (t *Transport) CloseIdleConnections() {
- t.idleLk.Lock()
+ t.idleMu.Lock()
m := t.idleConn
t.idleConn = nil
- t.idleLk.Unlock()
+ t.idleMu.Unlock()
if m == nil {
return
}
}
}
+// CancelRequest cancels an in-flight request by closing its
+// connection.
+func (t *Transport) CancelRequest(req *Request) {
+ t.reqMu.Lock()
+ pc := t.reqConn[req]
+ t.reqMu.Unlock()
+ if pc != nil {
+ pc.conn.Close()
+ }
+}
+
//
// Private implementation past this point.
//
if max == 0 {
max = DefaultMaxIdleConnsPerHost
}
- t.idleLk.Lock()
+ t.idleMu.Lock()
if t.idleConn == nil {
t.idleConn = make(map[string][]*persistConn)
}
if len(t.idleConn[key]) >= max {
- t.idleLk.Unlock()
+ t.idleMu.Unlock()
pconn.close()
return false
}
}
}
t.idleConn[key] = append(t.idleConn[key], pconn)
- t.idleLk.Unlock()
+ t.idleMu.Unlock()
return true
}
func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
key := cm.String()
- t.idleLk.Lock()
- defer t.idleLk.Unlock()
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
if t.idleConn == nil {
return nil
}
panic("unreachable")
}
+func (t *Transport) setReqConn(r *Request, pc *persistConn) {
+ t.reqMu.Lock()
+ defer t.reqMu.Unlock()
+ if t.reqConn == nil {
+ t.reqConn = make(map[*Request]*persistConn)
+ }
+ if pc != nil {
+ t.reqConn[r] = pc
+ } else {
+ delete(t.reqConn, r)
+ }
+}
+
func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
if t.Dial != nil {
return t.Dial(network, addr)
alive = <-waitForBodyRead
}
+ pc.t.setReqConn(rc.req, nil)
+
if !alive {
pc.close()
}
}
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
+ pc.t.setReqConn(req.Request, pc)
pc.lk.Lock()
pc.numExpectedResponses++
headerFn := pc.mutateHeaderFunc
pc.numExpectedResponses--
pc.lk.Unlock()
+ if re.err != nil {
+ pc.t.setReqConn(req.Request, nil)
+ }
return re.res, re.err
}
if testing.Short() {
t.Skip("skipping timeout test in -short mode")
}
- const debug = false
mux := NewServeMux()
mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {})
mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
}
}
+func TestTransportCancelRequest(t *testing.T) {
+ defer checkLeakedTransports(t)
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+ unblockc := make(chan bool)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Hello")
+ w.(Flusher).Flush() // send headers and some body
+ <-unblockc
+ }))
+ defer ts.Close()
+ defer close(unblockc)
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ go func() {
+ time.Sleep(1 * time.Second)
+ tr.CancelRequest(req)
+ }()
+ t0 := time.Now()
+ body, err := ioutil.ReadAll(res.Body)
+ d := time.Since(t0)
+
+ if err == nil {
+ t.Error("expected an error reading the body")
+ }
+ if string(body) != "Hello" {
+ t.Errorf("Body = %q; want Hello", body)
+ }
+ if d < 500*time.Millisecond {
+ t.Errorf("expected ~1 second delay; got %v", d)
+ }
+ // Verify no outstanding requests after readLoop/writeLoop
+ // goroutines shut down.
+ for tries := 3; tries > 0; tries-- {
+ n := tr.NumPendingRequestsForTesting()
+ if n == 0 {
+ break
+ }
+ time.Sleep(100 * time.Millisecond)
+ if tries == 1 {
+ t.Errorf("pending requests = %d; want 0", n)
+ }
+ }
+}
+
type fooProto struct{}
func (fooProto) RoundTrip(req *Request) (*Response, error) {