"io"
"io/ioutil"
"log"
+ "net"
. "net/http"
"net/http/httptest"
"net/url"
"sort"
"strings"
"sync"
+ "sync/atomic"
"testing"
+ "time"
)
type clientServerTest struct {
t.Errorf("RequestURI = %q; want *", req.RequestURI)
}
}
+
+// Issue 13957
+func TestTransportDiscardsUnneededConns(t *testing.T) {
+ defer afterTest(t)
+ cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
+ }))
+ defer cst.close()
+
+ var numOpen, numClose int32 // atomic
+
+ tlsConfig := &tls.Config{InsecureSkipVerify: true}
+ tr := &Transport{
+ TLSClientConfig: tlsConfig,
+ DialTLS: func(_, addr string) (net.Conn, error) {
+ time.Sleep(10 * time.Millisecond)
+ rc, err := net.Dial("tcp", addr)
+ if err != nil {
+ return nil, err
+ }
+ atomic.AddInt32(&numOpen, 1)
+ c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
+ return tls.Client(c, tlsConfig), nil
+ },
+ }
+ if err := ExportHttp2ConfigureTransport(tr); err != nil {
+ t.Fatal(err)
+ }
+ defer tr.CloseIdleConnections()
+
+ c := &Client{Transport: tr}
+
+ const N = 10
+ gotBody := make(chan string, N)
+ var wg sync.WaitGroup
+ for i := 0; i < N; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ resp, err := c.Get(cst.ts.URL)
+ if err != nil {
+ t.Errorf("Get: %v", err)
+ return
+ }
+ defer resp.Body.Close()
+ slurp, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ gotBody <- string(slurp)
+ }()
+ }
+ wg.Wait()
+ close(gotBody)
+
+ var last string
+ for got := range gotBody {
+ if last == "" {
+ last = got
+ continue
+ }
+ if got != last {
+ t.Errorf("Response body changed: %q -> %q", last, got)
+ }
+ }
+
+ var open, close int32
+ for i := 0; i < 150; i++ {
+ open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
+ if open < 1 {
+ t.Fatalf("open = %d; want at least", open)
+ }
+ if close == open-1 {
+ // Success
+ return
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+ t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
+}
+
+type noteCloseConn struct {
+ net.Conn
+ closeFunc func()
+}
+
+func (x noteCloseConn) Close() error {
+ x.closeFunc()
+ return x.Conn.Close()
+}
"encoding/binary"
"errors"
"fmt"
+ "golang.org/x/net/http2/hpack"
"io"
"io/ioutil"
"log"
"strings"
"sync"
"time"
-
- "golang.org/x/net/http2/hpack"
)
// ClientConnPool manages a pool of HTTP/2 client connections.
MarkDead(*http2ClientConn)
}
+// TODO: use singleflight for dialing and addConnCalls?
type http2clientConnPool struct {
- t *http2Transport
+ t *http2Transport
+
mu sync.Mutex // TODO: maybe switch to RWMutex
// TODO: add support for sharing conns based on cert names
// (e.g. share conn for googleapis.com and appspot.com)
- conns map[string][]*http2ClientConn // key is host:port
- dialing map[string]*http2dialCall // currently in-flight dials
- keys map[*http2ClientConn][]string
+ conns map[string][]*http2ClientConn // key is host:port
+ dialing map[string]*http2dialCall // currently in-flight dials
+ keys map[*http2ClientConn][]string
+ addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeede calls
}
func (p *http2clientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) {
- return p.getClientConn(req, addr, true)
+ return p.getClientConn(req, addr, http2dialOnMiss)
}
-func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) {
+const (
+ http2dialOnMiss = true
+ http2noDialOnMiss = false
+)
+
+func (p *http2clientConnPool) getClientConn(_ *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) {
p.mu.Lock()
for _, cc := range p.conns[addr] {
if cc.CanTakeNewRequest() {
c.p.mu.Unlock()
}
+// addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't
+// already exist. It coalesces concurrent calls with the same key.
+// This is used by the http1 Transport code when it creates a new connection. Because
+// the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know
+// the protocol), it can get into a situation where it has multiple TLS connections.
+// This code decides which ones live or die.
+// The return value used is whether c was used.
+// c is never closed.
+func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c *tls.Conn) (used bool, err error) {
+ p.mu.Lock()
+ for _, cc := range p.conns[key] {
+ if cc.CanTakeNewRequest() {
+ p.mu.Unlock()
+ return false, nil
+ }
+ }
+ call, dup := p.addConnCalls[key]
+ if !dup {
+ if p.addConnCalls == nil {
+ p.addConnCalls = make(map[string]*http2addConnCall)
+ }
+ call = &http2addConnCall{
+ p: p,
+ done: make(chan struct{}),
+ }
+ p.addConnCalls[key] = call
+ go call.run(t, key, c)
+ }
+ p.mu.Unlock()
+
+ <-call.done
+ if call.err != nil {
+ return false, call.err
+ }
+ return !dup, nil
+}
+
+type http2addConnCall struct {
+ p *http2clientConnPool
+ done chan struct{} // closed when done
+ err error
+}
+
+func (c *http2addConnCall) run(t *http2Transport, key string, tc *tls.Conn) {
+ cc, err := t.NewClientConn(tc)
+
+ p := c.p
+ p.mu.Lock()
+ if err != nil {
+ c.err = err
+ } else {
+ p.addConnLocked(key, cc)
+ }
+ delete(p.addConnCalls, key)
+ p.mu.Unlock()
+ close(c.done)
+}
+
func (p *http2clientConnPool) addConn(key string, cc *http2ClientConn) {
p.mu.Lock()
p.addConnLocked(key, cc)
t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
}
upgradeFn := func(authority string, c *tls.Conn) RoundTripper {
- cc, err := t2.NewClientConn(c)
- if err != nil {
- c.Close()
+ addr := http2authorityAddr(authority)
+ if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil {
+ go c.Close()
return http2erringRoundTripper{err}
+ } else if !used {
+
+ go c.Close()
}
- connPool.addConn(http2authorityAddr(authority), cc)
return t2
}
if m := t1.TLSNextProto; len(m) == 0 {
type http2noDialClientConnPool struct{ *http2clientConnPool }
func (p http2noDialClientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) {
- const doDial = false
- return p.getClientConn(req, addr, doDial)
+ return p.getClientConn(req, addr, http2noDialOnMiss)
}
// noDialH2RoundTripper is a RoundTripper which only tries to complete the request