]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: make Transport.MaxConnsPerHost work for HTTP/2
authorMichael Fraenkel <michael.fraenkel@gmail.com>
Sat, 6 Oct 2018 17:32:19 +0000 (13:32 -0400)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 30 Apr 2019 23:25:55 +0000 (23:25 +0000)
Treat HTTP/2 connections as an ongoing persistent connection. When we
are told there is no cached connections, cleanup the associated
connection and host connection count.

Fixes #27753

Change-Id: I6b7bd915fc7819617cb5d3b35e46e225c75eda29
Reviewed-on: https://go-review.googlesource.com/c/go/+/140357
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/http/transport.go
src/net/http/transport_test.go

index c94d2b50bddb07b74111749ecaa0384916722b55..88761909fd503b8443f22d3360c4ea41c98eb174 100644 (file)
@@ -506,8 +506,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
                var resp *Response
                if pconn.alt != nil {
                        // HTTP/2 path.
-                       t.decHostConnCount(cm.key()) // don't count cached http2 conns toward conns per host
-                       t.setReqCanceler(req, nil)   // not cancelable with CancelRequest
+                       t.putOrCloseIdleConn(pconn)
+                       t.setReqCanceler(req, nil) // not cancelable with CancelRequest
                        resp, err = pconn.alt.RoundTrip(req)
                } else {
                        resp, err = pconn.roundTrip(treq)
@@ -515,7 +515,10 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
                if err == nil {
                        return resp, nil
                }
-               if !pconn.shouldRetryRequest(req, err) {
+               if http2isNoCachedConnError(err) {
+                       t.removeIdleConn(pconn)
+                       t.decHostConnCount(cm.key()) // clean up the persistent connection
+               } else if !pconn.shouldRetryRequest(req, err) {
                        // Issue 16465: return underlying net.Conn.Read error from peek,
                        // as we've historically done.
                        if e, ok := err.(transportReadFromServerError); ok {
@@ -778,9 +781,6 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
        if pconn.isBroken() {
                return errConnBroken
        }
-       if pconn.alt != nil {
-               return errNotCachingH2Conn
-       }
        pconn.markReused()
        key := pconn.cacheKey
 
@@ -829,7 +829,10 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
                if pconn.idleTimer != nil {
                        pconn.idleTimer.Reset(t.IdleConnTimeout)
                } else {
-                       pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle)
+                       // idleTimer does not apply to HTTP/2
+                       if pconn.alt == nil {
+                               pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle)
+                       }
                }
        }
        pconn.idleAt = time.Now()
@@ -1377,7 +1380,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
 
        if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" {
                if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok {
-                       return &persistConn{alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil
+                       return &persistConn{cacheKey: pconn.cacheKey, alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil
                }
        }
 
index 857f0d59288f58da460cd4a6af33575ddbb63fa8..44a935960ee6a6e8a5764bf1c1756c17053b41bb 100644 (file)
@@ -588,6 +588,107 @@ func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
        <-reqComplete
 }
 
+func TestTransportMaxConnsPerHost(t *testing.T) {
+       defer afterTest(t)
+       if runtime.GOOS == "js" {
+               t.Skipf("skipping test on js/wasm")
+       }
+       h := HandlerFunc(func(w ResponseWriter, r *Request) {
+               _, err := w.Write([]byte("foo"))
+               if err != nil {
+                       t.Fatalf("Write: %v", err)
+               }
+       })
+
+       testMaxConns := func(scheme string, ts *httptest.Server) {
+               defer ts.Close()
+
+               c := ts.Client()
+               tr := c.Transport.(*Transport)
+               tr.MaxConnsPerHost = 1
+               if err := ExportHttp2ConfigureTransport(tr); err != nil {
+                       t.Fatalf("ExportHttp2ConfigureTransport: %v", err)
+               }
+
+               connCh := make(chan net.Conn, 1)
+               var dialCnt, gotConnCnt, tlsHandshakeCnt int32
+               tr.Dial = func(network, addr string) (net.Conn, error) {
+                       atomic.AddInt32(&dialCnt, 1)
+                       c, err := net.Dial(network, addr)
+                       connCh <- c
+                       return c, err
+               }
+
+               doReq := func() {
+                       trace := &httptrace.ClientTrace{
+                               GotConn: func(connInfo httptrace.GotConnInfo) {
+                                       if !connInfo.Reused {
+                                               atomic.AddInt32(&gotConnCnt, 1)
+                                       }
+                               },
+                               TLSHandshakeStart: func() {
+                                       atomic.AddInt32(&tlsHandshakeCnt, 1)
+                               },
+                       }
+                       req, _ := NewRequest("GET", ts.URL, nil)
+                       req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+
+                       resp, err := c.Do(req)
+                       if err != nil {
+                               t.Fatalf("request failed: %v", err)
+                       }
+                       defer resp.Body.Close()
+                       _, err = ioutil.ReadAll(resp.Body)
+                       if err != nil {
+                               t.Fatalf("read body failed: %v", err)
+                       }
+               }
+
+               wg := sync.WaitGroup{}
+               for i := 0; i < 10; i++ {
+                       wg.Add(1)
+                       go func() {
+                               defer wg.Done()
+                               doReq()
+                       }()
+               }
+               wg.Wait()
+
+               expected := int32(tr.MaxConnsPerHost)
+               if dialCnt != expected {
+                       t.Errorf("Too many dials (%s): %d", scheme, dialCnt)
+               }
+               if gotConnCnt != expected {
+                       t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt)
+               }
+               if ts.TLS != nil && tlsHandshakeCnt != expected {
+                       t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt)
+               }
+
+               (<-connCh).Close()
+
+               doReq()
+               expected++
+               if dialCnt != expected {
+                       t.Errorf("Too many dials (%s): %d", scheme, dialCnt)
+               }
+               if gotConnCnt != expected {
+                       t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt)
+               }
+               if ts.TLS != nil && tlsHandshakeCnt != expected {
+                       t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt)
+               }
+       }
+
+       testMaxConns("http", httptest.NewServer(h))
+       testMaxConns("https", httptest.NewTLSServer(h))
+
+       ts := httptest.NewUnstartedServer(h)
+       ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
+       ts.StartTLS()
+       testMaxConns("http2", ts)
+}
+
 func TestTransportRemovesDeadIdleConnections(t *testing.T) {
        setParallel(t)
        defer afterTest(t)