]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Transport.MaxConnsPerHost knob
authorMark Fischer <meirfischer@gmail.com>
Sun, 22 Apr 2018 05:16:46 +0000 (01:16 -0400)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 9 Jul 2018 18:32:16 +0000 (18:32 +0000)
Add field to http.Transport which limits connections per host,
including dial-in-progress, in-use and idle (keep-alive) connections.

For HTTP/2, this field only controls the number of dials in progress.

Fixes #13957

Change-Id: I7a5e045b4d4793c6b5b1a7191e1342cd7df78e6c
Reviewed-on: https://go-review.googlesource.com/71272
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/http/export_test.go
src/net/http/transport.go
src/net/http/transport_test.go

index 7c7b5d566757346c2e4d9b6f7550cde950210d80..7cdb51b05b0eb7c633682aab238694b055306336 100644 (file)
@@ -133,9 +133,11 @@ func (t *Transport) IdleConnStrsForTesting_h2() []string {
        return ret
 }
 
-func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
+func (t *Transport) IdleConnCountForTesting(scheme, addr string) int {
        t.idleMu.Lock()
        defer t.idleMu.Unlock()
+       key := connectMethodKey{"", scheme, addr}
+       cacheKey := key.String()
        for k, conns := range t.idleConn {
                if k.String() == cacheKey {
                        return len(conns)
@@ -160,13 +162,19 @@ func (t *Transport) RequestIdleConnChForTesting() {
        t.getIdleConnCh(connectMethod{nil, "http", "example.com"})
 }
 
-func (t *Transport) PutIdleTestConn() bool {
+func (t *Transport) PutIdleTestConn(scheme, addr string) bool {
        c, _ := net.Pipe()
+       key := connectMethodKey{"", scheme, addr}
+       select {
+       case <-t.incHostConnCount(key):
+       default:
+               return false
+       }
        return t.tryPutIdleConn(&persistConn{
                t:        t,
                conn:     c,                   // dummy
                closech:  make(chan struct{}), // so it can be closed
-               cacheKey: connectMethodKey{"", "http", "example.com"},
+               cacheKey: key,
        }) == nil
 }
 
index 59bffd0ae89139e6a50070681daa019e96329a16..182390cf017bf0b536cfb307e5bf2449f5f6982a 100644 (file)
@@ -55,6 +55,15 @@ var DefaultTransport RoundTripper = &Transport{
 // MaxIdleConnsPerHost.
 const DefaultMaxIdleConnsPerHost = 2
 
+// connsPerHostClosedCh is a closed channel used by MaxConnsPerHost
+// for the property that receives from a closed channel return the
+// zero value.
+var connsPerHostClosedCh = make(chan struct{})
+
+func init() {
+       close(connsPerHostClosedCh)
+}
+
 // Transport is an implementation of RoundTripper that supports HTTP,
 // HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT).
 //
@@ -103,6 +112,10 @@ type Transport struct {
        altMu    sync.Mutex   // guards changing altProto only
        altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
 
+       connCountMu          sync.Mutex
+       connPerHostCount     map[connectMethodKey]int
+       connPerHostAvailable map[connectMethodKey]chan struct{}
+
        // Proxy specifies a function to return a proxy for a given
        // Request. If the function returns a non-nil error, the
        // request is aborted with the provided error.
@@ -183,6 +196,18 @@ type Transport struct {
        // DefaultMaxIdleConnsPerHost is used.
        MaxIdleConnsPerHost int
 
+       // MaxConnsPerHost optionally limits the total number of
+       // connections per host, including connections in the dialing,
+       // active, and idle states. On limit violation, dials will block.
+       //
+       // Zero means no limit.
+       //
+       // For HTTP/2, this currently only controls the number of new
+       // connections being created at a time, instead of the total
+       // number. In practice, hosts using HTTP/2 only have about one
+       // idle connection, though.
+       MaxConnsPerHost int
+
        // IdleConnTimeout is the maximum amount of time an idle
        // (keep-alive) connection will remain idle before closing
        // itself.
@@ -231,8 +256,6 @@ type Transport struct {
        // h2transport (via onceSetNextProtoDefaults)
        nextProtoOnce sync.Once
        h2transport   *http2Transport // non-nil if http2 wired up
-
-       // TODO: tunable on max per-host TCP dials in flight (Issue 13957)
 }
 
 // onceSetNextProtoDefaults initializes TLSNextProto.
@@ -409,7 +432,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
                var resp *Response
                if pconn.alt != nil {
                        // HTTP/2 path.
-                       t.setReqCanceler(req, nil) // not cancelable with CancelRequest
+                       t.decHostConnCount(cm.key()) // don't count cached http2 conns toward conns per host
+                       t.setReqCanceler(req, nil)   // not cancelable with CancelRequest
                        resp, err = pconn.alt.RoundTrip(req)
                } else {
                        resp, err = pconn.roundTrip(treq)
@@ -908,6 +932,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
                err error
        }
        dialc := make(chan dialRes)
+       cmKey := cm.key()
 
        // Copy these hooks so we don't race on the postPendingDial in
        // the goroutine we launch. Issue 11136.
@@ -919,6 +944,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
                go func() {
                        if v := <-dialc; v.err == nil {
                                t.putOrCloseIdleConn(v.pc)
+                       } else {
+                               t.decHostConnCount(cmKey)
                        }
                        testHookPostPendingDial()
                }()
@@ -927,6 +954,27 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
        cancelc := make(chan error, 1)
        t.setReqCanceler(req, func(err error) { cancelc <- err })
 
+       if t.MaxConnsPerHost > 0 {
+               select {
+               case <-t.incHostConnCount(cmKey):
+                       // count below conn per host limit; proceed
+               case pc := <-t.getIdleConnCh(cm):
+                       if trace != nil && trace.GotConn != nil {
+                               trace.GotConn(httptrace.GotConnInfo{Conn: pc.conn, Reused: pc.isReused()})
+                       }
+                       return pc, nil
+               case <-req.Cancel:
+                       return nil, errRequestCanceledConn
+               case <-req.Context().Done():
+                       return nil, req.Context().Err()
+               case err := <-cancelc:
+                       if err == errRequestCanceled {
+                               err = errRequestCanceledConn
+                       }
+                       return nil, err
+               }
+       }
+
        go func() {
                pc, err := t.dialConn(ctx, cm)
                dialc <- dialRes{pc, err}
@@ -944,6 +992,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
                }
                // Our dial failed. See why to return a nicer error
                // value.
+               t.decHostConnCount(cmKey)
                select {
                case <-req.Cancel:
                        // It was an error due to cancelation, so prioritize that
@@ -987,6 +1036,83 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
        }
 }
 
+// incHostConnCount increments the count of connections for a
+// given host. It returns an already-closed channel if the count
+// is not at its limit; otherwise it returns a channel which is
+// notified when the count is below the limit.
+func (t *Transport) incHostConnCount(cmKey connectMethodKey) <-chan struct{} {
+       if t.MaxConnsPerHost <= 0 {
+               return connsPerHostClosedCh
+       }
+       t.connCountMu.Lock()
+       defer t.connCountMu.Unlock()
+       if t.connPerHostCount[cmKey] == t.MaxConnsPerHost {
+               if t.connPerHostAvailable == nil {
+                       t.connPerHostAvailable = make(map[connectMethodKey]chan struct{})
+               }
+               ch, ok := t.connPerHostAvailable[cmKey]
+               if !ok {
+                       ch = make(chan struct{})
+                       t.connPerHostAvailable[cmKey] = ch
+               }
+               return ch
+       }
+       if t.connPerHostCount == nil {
+               t.connPerHostCount = make(map[connectMethodKey]int)
+       }
+       t.connPerHostCount[cmKey]++
+       // return a closed channel to avoid race: if decHostConnCount is called
+       // after incHostConnCount and during the nil check, decHostConnCount
+       // will delete the channel since it's not being listened on yet.
+       return connsPerHostClosedCh
+}
+
+// decHostConnCount decrements the count of connections
+// for a given host.
+// See Transport.MaxConnsPerHost.
+func (t *Transport) decHostConnCount(cmKey connectMethodKey) {
+       if t.MaxConnsPerHost <= 0 {
+               return
+       }
+       t.connCountMu.Lock()
+       defer t.connCountMu.Unlock()
+       t.connPerHostCount[cmKey]--
+       select {
+       case t.connPerHostAvailable[cmKey] <- struct{}{}:
+       default:
+               // close channel before deleting avoids getConn waiting forever in
+               // case getConn has reference to channel but hasn't started waiting.
+               // This could lead to more than MaxConnsPerHost in the unlikely case
+               // that > 1 go routine has fetched the channel but none started waiting.
+               if t.connPerHostAvailable[cmKey] != nil {
+                       close(t.connPerHostAvailable[cmKey])
+               }
+               delete(t.connPerHostAvailable, cmKey)
+       }
+       if t.connPerHostCount[cmKey] == 0 {
+               delete(t.connPerHostCount, cmKey)
+       }
+}
+
+// connCloseListener wraps a connection, the transport that dialed it
+// and the connected-to host key so the host connection count can be
+// transparently decremented by whatever closes the embedded connection.
+type connCloseListener struct {
+       net.Conn
+       t        *Transport
+       cmKey    connectMethodKey
+       didClose int32
+}
+
+func (c *connCloseListener) Close() error {
+       if atomic.AddInt32(&c.didClose, 1) != 1 {
+               return nil
+       }
+       err := c.Conn.Close()
+       c.t.decHostConnCount(c.cmKey)
+       return err
+}
+
 // The connect method and the transport can both specify a TLS
 // Host name.  The transport's name takes precedence if present.
 func chooseTLSHost(cm connectMethod, t *Transport) string {
@@ -1184,6 +1310,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
                }
        }
 
+       if t.MaxConnsPerHost > 0 {
+               pconn.conn = &connCloseListener{Conn: pconn.conn, t: t, cmKey: pconn.cacheKey}
+       }
        pconn.br = bufio.NewReader(pconn)
        pconn.bw = bufio.NewWriter(persistConnWriter{pconn})
        go pconn.readLoop()
index 87361e81ca50bb17246b441ed8bc0d9ad4f85040..5145da0ae07ee04609c92c9121c52cb8c45b5b1e 100644 (file)
@@ -446,27 +446,95 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
        if e, g := 1, len(keys); e != g {
                t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
        }
-       cacheKey := "|http|" + ts.Listener.Addr().String()
+       addr := ts.Listener.Addr().String()
+       cacheKey := "|http|" + addr
        if keys[0] != cacheKey {
                t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
        }
-       if e, g := 1, tr.IdleConnCountForTesting(cacheKey); e != g {
+       if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
                t.Errorf("after first response, expected %d idle conns; got %d", e, g)
        }
 
        resch <- "res2"
        <-donech
-       if g, w := tr.IdleConnCountForTesting(cacheKey), 2; g != w {
+       if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
                t.Errorf("after second response, idle conns = %d; want %d", g, w)
        }
 
        resch <- "res3"
        <-donech
-       if g, w := tr.IdleConnCountForTesting(cacheKey), maxIdleConnsPerHost; g != w {
+       if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
                t.Errorf("after third response, idle conns = %d; want %d", g, w)
        }
 }
 
+func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
+       defer afterTest(t)
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               _, err := w.Write([]byte("foo"))
+               if err != nil {
+                       t.Fatalf("Write: %v", err)
+               }
+       }))
+       defer ts.Close()
+       c := ts.Client()
+       tr := c.Transport.(*Transport)
+       dialStarted := make(chan struct{})
+       stallDial := make(chan struct{})
+       tr.Dial = func(network, addr string) (net.Conn, error) {
+               dialStarted <- struct{}{}
+               <-stallDial
+               return net.Dial(network, addr)
+       }
+
+       tr.DisableKeepAlives = true
+       tr.MaxConnsPerHost = 1
+
+       preDial := make(chan struct{})
+       reqComplete := make(chan struct{})
+       doReq := func(reqId string) {
+               req, _ := NewRequest("GET", ts.URL, nil)
+               trace := &httptrace.ClientTrace{
+                       GetConn: func(hostPort string) {
+                               preDial <- struct{}{}
+                       },
+               }
+               req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+               resp, err := tr.RoundTrip(req)
+               if err != nil {
+                       t.Errorf("unexpected error for request %s: %v", reqId, err)
+               }
+               _, err = ioutil.ReadAll(resp.Body)
+               if err != nil {
+                       t.Errorf("unexpected error for request %s: %v", reqId, err)
+               }
+               reqComplete <- struct{}{}
+       }
+       // get req1 to dial-in-progress
+       go doReq("req1")
+       <-preDial
+       <-dialStarted
+
+       // get req2 to waiting on conns per host to go down below max
+       go doReq("req2")
+       <-preDial
+       select {
+       case <-dialStarted:
+               t.Error("req2 dial started while req1 dial in progress")
+               return
+       default:
+       }
+
+       // let req1 complete
+       stallDial <- struct{}{}
+       <-reqComplete
+
+       // let req2 complete
+       <-dialStarted
+       stallDial <- struct{}{}
+       <-reqComplete
+}
+
 func TestTransportRemovesDeadIdleConnections(t *testing.T) {
        setParallel(t)
        defer afterTest(t)
@@ -3118,7 +3186,7 @@ func TestRoundTripReturnsProxyError(t *testing.T) {
 func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
        tr := &Transport{}
        wantIdle := func(when string, n int) bool {
-               got := tr.IdleConnCountForTesting("|http|example.com") // key used by PutIdleTestConn
+               got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn
                if got == n {
                        return true
                }
@@ -3126,10 +3194,10 @@ func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
                return false
        }
        wantIdle("start", 0)
-       if !tr.PutIdleTestConn() {
+       if !tr.PutIdleTestConn("http", "example.com") {
                t.Fatal("put failed")
        }
-       if !tr.PutIdleTestConn() {
+       if !tr.PutIdleTestConn("http", "example.com") {
                t.Fatal("second put failed")
        }
        wantIdle("after put", 2)
@@ -3138,7 +3206,7 @@ func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
                t.Error("should be idle after CloseIdleConnections")
        }
        wantIdle("after close idle", 0)
-       if tr.PutIdleTestConn() {
+       if tr.PutIdleTestConn("http", "example.com") {
                t.Fatal("put didn't fail")
        }
        wantIdle("after second put", 0)
@@ -3147,7 +3215,7 @@ func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
        if tr.IsIdleForTesting() {
                t.Error("shouldn't be idle after RequestIdleConnChForTesting")
        }
-       if !tr.PutIdleTestConn() {
+       if !tr.PutIdleTestConn("http", "example.com") {
                t.Fatal("after re-activation")
        }
        wantIdle("after final put", 1)