// MaxIdleConnsPerHost.
const DefaultMaxIdleConnsPerHost = 2
+// connsPerHostClosedCh is a closed channel used by MaxConnsPerHost
+// for the property that receives from a closed channel return the
+// zero value.
+var connsPerHostClosedCh = make(chan struct{})
+
+func init() {
+ close(connsPerHostClosedCh)
+}
+
// Transport is an implementation of RoundTripper that supports HTTP,
// HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT).
//
altMu sync.Mutex // guards changing altProto only
altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
+ connCountMu sync.Mutex
+ connPerHostCount map[connectMethodKey]int
+ connPerHostAvailable map[connectMethodKey]chan struct{}
+
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
// DefaultMaxIdleConnsPerHost is used.
MaxIdleConnsPerHost int
+ // MaxConnsPerHost optionally limits the total number of
+ // connections per host, including connections in the dialing,
+ // active, and idle states. On limit violation, dials will block.
+ //
+ // Zero means no limit.
+ //
+ // For HTTP/2, this currently only controls the number of new
+ // connections being created at a time, instead of the total
+ // number. In practice, hosts using HTTP/2 only have about one
+ // idle connection, though.
+ MaxConnsPerHost int
+
// IdleConnTimeout is the maximum amount of time an idle
// (keep-alive) connection will remain idle before closing
// itself.
// h2transport (via onceSetNextProtoDefaults)
nextProtoOnce sync.Once
h2transport *http2Transport // non-nil if http2 wired up
-
- // TODO: tunable on max per-host TCP dials in flight (Issue 13957)
}
// onceSetNextProtoDefaults initializes TLSNextProto.
var resp *Response
if pconn.alt != nil {
// HTTP/2 path.
- t.setReqCanceler(req, nil) // not cancelable with CancelRequest
+ t.decHostConnCount(cm.key()) // don't count cached http2 conns toward conns per host
+ t.setReqCanceler(req, nil) // not cancelable with CancelRequest
resp, err = pconn.alt.RoundTrip(req)
} else {
resp, err = pconn.roundTrip(treq)
err error
}
dialc := make(chan dialRes)
+ cmKey := cm.key()
// Copy these hooks so we don't race on the postPendingDial in
// the goroutine we launch. Issue 11136.
go func() {
if v := <-dialc; v.err == nil {
t.putOrCloseIdleConn(v.pc)
+ } else {
+ t.decHostConnCount(cmKey)
}
testHookPostPendingDial()
}()
cancelc := make(chan error, 1)
t.setReqCanceler(req, func(err error) { cancelc <- err })
+ if t.MaxConnsPerHost > 0 {
+ select {
+ case <-t.incHostConnCount(cmKey):
+ // count below conn per host limit; proceed
+ case pc := <-t.getIdleConnCh(cm):
+ if trace != nil && trace.GotConn != nil {
+ trace.GotConn(httptrace.GotConnInfo{Conn: pc.conn, Reused: pc.isReused()})
+ }
+ return pc, nil
+ case <-req.Cancel:
+ return nil, errRequestCanceledConn
+ case <-req.Context().Done():
+ return nil, req.Context().Err()
+ case err := <-cancelc:
+ if err == errRequestCanceled {
+ err = errRequestCanceledConn
+ }
+ return nil, err
+ }
+ }
+
go func() {
pc, err := t.dialConn(ctx, cm)
dialc <- dialRes{pc, err}
}
// Our dial failed. See why to return a nicer error
// value.
+ t.decHostConnCount(cmKey)
select {
case <-req.Cancel:
// It was an error due to cancelation, so prioritize that
}
}
+// incHostConnCount increments the count of connections for a
+// given host. It returns an already-closed channel if the count
+// is not at its limit; otherwise it returns a channel which is
+// notified when the count is below the limit.
+func (t *Transport) incHostConnCount(cmKey connectMethodKey) <-chan struct{} {
+ if t.MaxConnsPerHost <= 0 {
+ return connsPerHostClosedCh
+ }
+ t.connCountMu.Lock()
+ defer t.connCountMu.Unlock()
+ if t.connPerHostCount[cmKey] == t.MaxConnsPerHost {
+ if t.connPerHostAvailable == nil {
+ t.connPerHostAvailable = make(map[connectMethodKey]chan struct{})
+ }
+ ch, ok := t.connPerHostAvailable[cmKey]
+ if !ok {
+ ch = make(chan struct{})
+ t.connPerHostAvailable[cmKey] = ch
+ }
+ return ch
+ }
+ if t.connPerHostCount == nil {
+ t.connPerHostCount = make(map[connectMethodKey]int)
+ }
+ t.connPerHostCount[cmKey]++
+ // return a closed channel to avoid race: if decHostConnCount is called
+ // after incHostConnCount and during the nil check, decHostConnCount
+ // will delete the channel since it's not being listened on yet.
+ return connsPerHostClosedCh
+}
+
+// decHostConnCount decrements the count of connections
+// for a given host.
+// See Transport.MaxConnsPerHost.
+func (t *Transport) decHostConnCount(cmKey connectMethodKey) {
+ if t.MaxConnsPerHost <= 0 {
+ return
+ }
+ t.connCountMu.Lock()
+ defer t.connCountMu.Unlock()
+ t.connPerHostCount[cmKey]--
+ select {
+ case t.connPerHostAvailable[cmKey] <- struct{}{}:
+ default:
+ // close channel before deleting avoids getConn waiting forever in
+ // case getConn has reference to channel but hasn't started waiting.
+ // This could lead to more than MaxConnsPerHost in the unlikely case
+ // that > 1 go routine has fetched the channel but none started waiting.
+ if t.connPerHostAvailable[cmKey] != nil {
+ close(t.connPerHostAvailable[cmKey])
+ }
+ delete(t.connPerHostAvailable, cmKey)
+ }
+ if t.connPerHostCount[cmKey] == 0 {
+ delete(t.connPerHostCount, cmKey)
+ }
+}
+
+// connCloseListener wraps a connection, the transport that dialed it
+// and the connected-to host key so the host connection count can be
+// transparently decremented by whatever closes the embedded connection.
+type connCloseListener struct {
+ net.Conn
+ t *Transport
+ cmKey connectMethodKey
+ didClose int32
+}
+
+func (c *connCloseListener) Close() error {
+ if atomic.AddInt32(&c.didClose, 1) != 1 {
+ return nil
+ }
+ err := c.Conn.Close()
+ c.t.decHostConnCount(c.cmKey)
+ return err
+}
+
// The connect method and the transport can both specify a TLS
// Host name. The transport's name takes precedence if present.
func chooseTLSHost(cm connectMethod, t *Transport) string {
}
}
+ if t.MaxConnsPerHost > 0 {
+ pconn.conn = &connCloseListener{Conn: pconn.conn, t: t, cmKey: pconn.cacheKey}
+ }
pconn.br = bufio.NewReader(pconn)
pconn.bw = bufio.NewWriter(persistConnWriter{pconn})
go pconn.readLoop()
if e, g := 1, len(keys); e != g {
t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
}
- cacheKey := "|http|" + ts.Listener.Addr().String()
+ addr := ts.Listener.Addr().String()
+ cacheKey := "|http|" + addr
if keys[0] != cacheKey {
t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
}
- if e, g := 1, tr.IdleConnCountForTesting(cacheKey); e != g {
+ if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
t.Errorf("after first response, expected %d idle conns; got %d", e, g)
}
resch <- "res2"
<-donech
- if g, w := tr.IdleConnCountForTesting(cacheKey), 2; g != w {
+ if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
t.Errorf("after second response, idle conns = %d; want %d", g, w)
}
resch <- "res3"
<-donech
- if g, w := tr.IdleConnCountForTesting(cacheKey), maxIdleConnsPerHost; g != w {
+ if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
t.Errorf("after third response, idle conns = %d; want %d", g, w)
}
}
+func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := w.Write([]byte("foo"))
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ }))
+ defer ts.Close()
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ dialStarted := make(chan struct{})
+ stallDial := make(chan struct{})
+ tr.Dial = func(network, addr string) (net.Conn, error) {
+ dialStarted <- struct{}{}
+ <-stallDial
+ return net.Dial(network, addr)
+ }
+
+ tr.DisableKeepAlives = true
+ tr.MaxConnsPerHost = 1
+
+ preDial := make(chan struct{})
+ reqComplete := make(chan struct{})
+ doReq := func(reqId string) {
+ req, _ := NewRequest("GET", ts.URL, nil)
+ trace := &httptrace.ClientTrace{
+ GetConn: func(hostPort string) {
+ preDial <- struct{}{}
+ },
+ }
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+ resp, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Errorf("unexpected error for request %s: %v", reqId, err)
+ }
+ _, err = ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("unexpected error for request %s: %v", reqId, err)
+ }
+ reqComplete <- struct{}{}
+ }
+ // get req1 to dial-in-progress
+ go doReq("req1")
+ <-preDial
+ <-dialStarted
+
+ // get req2 to waiting on conns per host to go down below max
+ go doReq("req2")
+ <-preDial
+ select {
+ case <-dialStarted:
+ t.Error("req2 dial started while req1 dial in progress")
+ return
+ default:
+ }
+
+ // let req1 complete
+ stallDial <- struct{}{}
+ <-reqComplete
+
+ // let req2 complete
+ <-dialStarted
+ stallDial <- struct{}{}
+ <-reqComplete
+}
+
func TestTransportRemovesDeadIdleConnections(t *testing.T) {
setParallel(t)
defer afterTest(t)
func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
tr := &Transport{}
wantIdle := func(when string, n int) bool {
- got := tr.IdleConnCountForTesting("|http|example.com") // key used by PutIdleTestConn
+ got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn
if got == n {
return true
}
return false
}
wantIdle("start", 0)
- if !tr.PutIdleTestConn() {
+ if !tr.PutIdleTestConn("http", "example.com") {
t.Fatal("put failed")
}
- if !tr.PutIdleTestConn() {
+ if !tr.PutIdleTestConn("http", "example.com") {
t.Fatal("second put failed")
}
wantIdle("after put", 2)
t.Error("should be idle after CloseIdleConnections")
}
wantIdle("after close idle", 0)
- if tr.PutIdleTestConn() {
+ if tr.PutIdleTestConn("http", "example.com") {
t.Fatal("put didn't fail")
}
wantIdle("after second put", 0)
if tr.IsIdleForTesting() {
t.Error("shouldn't be idle after RequestIdleConnChForTesting")
}
- if !tr.PutIdleTestConn() {
+ if !tr.PutIdleTestConn("http", "example.com") {
t.Fatal("after re-activation")
}
wantIdle("after final put", 1)