var resp *Response
if pconn.alt != nil {
// HTTP/2 path.
- t.decHostConnCount(cm.key()) // don't count cached http2 conns toward conns per host
- t.setReqCanceler(req, nil) // not cancelable with CancelRequest
+ t.putOrCloseIdleConn(pconn)
+ t.setReqCanceler(req, nil) // not cancelable with CancelRequest
resp, err = pconn.alt.RoundTrip(req)
} else {
resp, err = pconn.roundTrip(treq)
if err == nil {
return resp, nil
}
- if !pconn.shouldRetryRequest(req, err) {
+ if http2isNoCachedConnError(err) {
+ t.removeIdleConn(pconn)
+ t.decHostConnCount(cm.key()) // clean up the persistent connection
+ } else if !pconn.shouldRetryRequest(req, err) {
// Issue 16465: return underlying net.Conn.Read error from peek,
// as we've historically done.
if e, ok := err.(transportReadFromServerError); ok {
if pconn.isBroken() {
return errConnBroken
}
- if pconn.alt != nil {
- return errNotCachingH2Conn
- }
pconn.markReused()
key := pconn.cacheKey
if pconn.idleTimer != nil {
pconn.idleTimer.Reset(t.IdleConnTimeout)
} else {
- pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle)
+ // idleTimer does not apply to HTTP/2
+ if pconn.alt == nil {
+ pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle)
+ }
}
}
pconn.idleAt = time.Now()
if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" {
if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok {
- return &persistConn{alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil
+ return &persistConn{cacheKey: pconn.cacheKey, alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil
}
}
<-reqComplete
}
+func TestTransportMaxConnsPerHost(t *testing.T) {
+ defer afterTest(t)
+ if runtime.GOOS == "js" {
+ t.Skipf("skipping test on js/wasm")
+ }
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := w.Write([]byte("foo"))
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ })
+
+ testMaxConns := func(scheme string, ts *httptest.Server) {
+ defer ts.Close()
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.MaxConnsPerHost = 1
+ if err := ExportHttp2ConfigureTransport(tr); err != nil {
+ t.Fatalf("ExportHttp2ConfigureTransport: %v", err)
+ }
+
+ connCh := make(chan net.Conn, 1)
+ var dialCnt, gotConnCnt, tlsHandshakeCnt int32
+ tr.Dial = func(network, addr string) (net.Conn, error) {
+ atomic.AddInt32(&dialCnt, 1)
+ c, err := net.Dial(network, addr)
+ connCh <- c
+ return c, err
+ }
+
+ doReq := func() {
+ trace := &httptrace.ClientTrace{
+ GotConn: func(connInfo httptrace.GotConnInfo) {
+ if !connInfo.Reused {
+ atomic.AddInt32(&gotConnCnt, 1)
+ }
+ },
+ TLSHandshakeStart: func() {
+ atomic.AddInt32(&tlsHandshakeCnt, 1)
+ },
+ }
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+
+ resp, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("request failed: %v", err)
+ }
+ defer resp.Body.Close()
+ _, err = ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("read body failed: %v", err)
+ }
+ }
+
+ wg := sync.WaitGroup{}
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ doReq()
+ }()
+ }
+ wg.Wait()
+
+ expected := int32(tr.MaxConnsPerHost)
+ if dialCnt != expected {
+ t.Errorf("Too many dials (%s): %d", scheme, dialCnt)
+ }
+ if gotConnCnt != expected {
+ t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt)
+ }
+ if ts.TLS != nil && tlsHandshakeCnt != expected {
+ t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt)
+ }
+
+ (<-connCh).Close()
+
+ doReq()
+ expected++
+ if dialCnt != expected {
+ t.Errorf("Too many dials (%s): %d", scheme, dialCnt)
+ }
+ if gotConnCnt != expected {
+ t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt)
+ }
+ if ts.TLS != nil && tlsHandshakeCnt != expected {
+ t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt)
+ }
+ }
+
+ testMaxConns("http", httptest.NewServer(h))
+ testMaxConns("https", httptest.NewTLSServer(h))
+
+ ts := httptest.NewUnstartedServer(h)
+ ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
+ ts.StartTLS()
+ testMaxConns("http2", ts)
+}
+
func TestTransportRemovesDeadIdleConnections(t *testing.T) {
setParallel(t)
defer afterTest(t)