"time"
)
+// ClientConnPool manages a pool of HTTP/2 client connections.
+type http2ClientConnPool interface {
+ GetClientConn(req *Request, addr string) (*http2ClientConn, error)
+ MarkDead(*http2ClientConn)
+}
+
+type http2clientConnPool struct {
+ t *http2Transport
+ mu sync.Mutex // TODO: switch to RWMutex
+ // TODO: add support for sharing conns based on cert names
+ // (e.g. share conn for googleapis.com and appspot.com)
+ conns map[string][]*http2ClientConn // key is host:port
+ keys map[*http2ClientConn][]string
+}
+
+func (p *http2clientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) {
+ return p.getClientConn(req, addr, true)
+}
+
+func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) {
+ p.mu.Lock()
+ for _, cc := range p.conns[addr] {
+ if cc.CanTakeNewRequest() {
+ p.mu.Unlock()
+ return cc, nil
+ }
+ }
+ p.mu.Unlock()
+ if !dialOnMiss {
+ return nil, http2ErrNoCachedConn
+ }
+
+ cc, err := p.t.dialClientConn(addr)
+ if err != nil {
+ return nil, err
+ }
+ p.addConn(addr, cc)
+ return cc, nil
+}
+
+func (p *http2clientConnPool) addConn(key string, cc *http2ClientConn) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.conns == nil {
+ p.conns = make(map[string][]*http2ClientConn)
+ }
+ if p.keys == nil {
+ p.keys = make(map[*http2ClientConn][]string)
+ }
+ p.conns[key] = append(p.conns[key], cc)
+ p.keys[cc] = append(p.keys[cc], key)
+}
+
+func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ for _, key := range p.keys[cc] {
+ vv, ok := p.conns[key]
+ if !ok {
+ continue
+ }
+ newList := http2filterOutClientConn(vv, cc)
+ if len(newList) > 0 {
+ p.conns[key] = newList
+ } else {
+ delete(p.conns, key)
+ }
+ }
+ delete(p.keys, cc)
+}
+
+func (p *http2clientConnPool) closeIdleConnections() {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ for _, vv := range p.conns {
+ for _, cc := range vv {
+ cc.closeIfIdle()
+ }
+ }
+}
+
+func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) []*http2ClientConn {
+ out := in[:0]
+ for _, v := range in {
+ if v != exclude {
+ out = append(out, v)
+ }
+ }
+
+ if len(in) != len(out) {
+ in[len(in)-1] = nil
+ }
+ return out
+}
+
// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec.
type http2ErrCode uint32
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
- // TODO: switch to RWMutex
- // TODO: add support for sharing conns based on cert names
- // (e.g. share conn for googleapis.com and appspot.com)
- connMu sync.Mutex
- conns map[string][]*http2clientConn // key is host:port
+ // ConnPool optionally specifies an alternate connection pool to use.
+ // If nil, the default is used.
+ ConnPool http2ClientConnPool
+
+ connPoolOnce sync.Once
+ connPoolOrDef http2ClientConnPool // non-nil version of ConnPool
+}
+
+func (t *http2Transport) connPool() http2ClientConnPool {
+ t.connPoolOnce.Do(t.initConnPool)
+ return t.connPoolOrDef
+}
+
+func (t *http2Transport) initConnPool() {
+ if t.ConnPool != nil {
+ t.connPoolOrDef = t.ConnPool
+ } else {
+ t.connPoolOrDef = &http2clientConnPool{t: t}
+ }
}
-// clientConn is the state of a single HTTP/2 client connection to an
+// ClientConn is the state of a single HTTP/2 client connection to an
// HTTP/2 server.
-type http2clientConn struct {
+type http2ClientConn struct {
t *http2Transport
- tconn net.Conn
- tlsState *tls.ConnectionState
- connKey []string // key(s) this connection is cached in, in t.conns
+ tconn net.Conn // usually *tls.Conn, except specialized impls
+ tlsState *tls.ConnectionState // nil only for specialized impls
// readLoop goroutine fields:
readerDone chan struct{} // closed on error
// clientStream is the state for a single HTTP/2 stream. One of these
// is created for each Transport.RoundTrip call.
type http2clientStream struct {
- cc *http2clientConn
+ cc *http2ClientConn
ID uint32
resc chan http2resAndError
bufPipe http2pipe // buffered pipe with the flow-controlled response payload
return t.RoundTripOpt(req, http2RoundTripOpt{})
}
+// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
+// and returns a host:port. The port 443 is added if needed.
+func http2authorityAddr(authority string) (addr string) {
+ if _, _, err := net.SplitHostPort(authority); err == nil {
+ return authority
+ }
+ return net.JoinHostPort(authority, "443")
+}
+
// RoundTripOpt is like RoundTrip, but takes options.
func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Response, error) {
if req.URL.Scheme != "https" {
return nil, errors.New("http2: unsupported scheme")
}
- host, port, err := net.SplitHostPort(req.URL.Host)
- if err != nil {
- host = req.URL.Host
- port = "443"
- }
-
+ addr := http2authorityAddr(req.URL.Host)
for {
- cc, err := t.getClientConn(host, port, opt.OnlyCachedConn)
+ cc, err := t.connPool().GetClientConn(req, addr)
if err != nil {
return nil, err
}
- res, err := cc.roundTrip(req)
+ res, err := cc.RoundTrip(req)
if http2shouldRetryRequest(err) {
continue
}
// connected from previous requests but are now sitting idle.
// It does not interrupt any connections currently in use.
func (t *http2Transport) CloseIdleConnections() {
- t.connMu.Lock()
- defer t.connMu.Unlock()
- for _, vv := range t.conns {
- for _, cc := range vv {
- cc.closeIfIdle()
- }
+ if cp, ok := t.connPool().(*http2clientConnPool); ok {
+ cp.closeIdleConnections()
}
}
return err == http2errClientConnClosed
}
-func (t *http2Transport) removeClientConn(cc *http2clientConn) {
- t.connMu.Lock()
- defer t.connMu.Unlock()
- for _, key := range cc.connKey {
- vv, ok := t.conns[key]
- if !ok {
- continue
- }
- newList := http2filterOutClientConn(vv, cc)
- if len(newList) > 0 {
- t.conns[key] = newList
- } else {
- delete(t.conns, key)
- }
- }
-}
-
-func http2filterOutClientConn(in []*http2clientConn, exclude *http2clientConn) []*http2clientConn {
- out := in[:0]
- for _, v := range in {
- if v != exclude {
- out = append(out, v)
- }
- }
-
- if len(in) != len(out) {
- in[len(in)-1] = nil
- }
- return out
-}
-
-// AddIdleConn adds c as an idle conn for Transport.
-// It assumes that c has not yet exchanged SETTINGS frames.
-// The addr maybe be either "host" or "host:port".
-func (t *http2Transport) AddIdleConn(addr string, c *tls.Conn) error {
- var key string
- _, _, err := net.SplitHostPort(addr)
- if err == nil {
- key = addr
- } else {
- key = addr + ":443"
- }
- cc, err := t.newClientConn(key, c)
- if err != nil {
- return err
- }
-
- t.addConn(key, cc)
- return nil
-}
-
-func (t *http2Transport) addConn(key string, cc *http2clientConn) {
- t.connMu.Lock()
- defer t.connMu.Unlock()
- if t.conns == nil {
- t.conns = make(map[string][]*http2clientConn)
- }
- t.conns[key] = append(t.conns[key], cc)
-}
-
-func (t *http2Transport) getClientConn(host, port string, onlyCached bool) (*http2clientConn, error) {
- key := net.JoinHostPort(host, port)
-
- t.connMu.Lock()
- for _, cc := range t.conns[key] {
- if cc.canTakeNewRequest() {
- t.connMu.Unlock()
- return cc, nil
- }
- }
- t.connMu.Unlock()
- if onlyCached {
- return nil, http2ErrNoCachedConn
- }
-
- cc, err := t.dialClientConn(host, port, key)
+func (t *http2Transport) dialClientConn(addr string) (*http2ClientConn, error) {
+ host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
- t.addConn(key, cc)
- return cc, nil
-}
-
-func (t *http2Transport) dialClientConn(host, port, key string) (*http2clientConn, error) {
- tconn, err := t.dialTLS()("tcp", net.JoinHostPort(host, port), t.newTLSConfig(host))
+ tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host))
if err != nil {
return nil, err
}
- return t.newClientConn(key, tconn)
+ return t.NewClientConn(tconn)
}
func (t *http2Transport) newTLSConfig(host string) *tls.Config {
return cn, nil
}
-func (t *http2Transport) newClientConn(key string, tconn net.Conn) (*http2clientConn, error) {
- if _, err := tconn.Write(http2clientPreface); err != nil {
+func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
+ if _, err := c.Write(http2clientPreface); err != nil {
return nil, err
}
- cc := &http2clientConn{
+ cc := &http2ClientConn{
t: t,
- tconn: tconn,
- connKey: []string{key},
+ tconn: c,
readerDone: make(chan struct{}),
nextStreamID: 1,
maxFrameSize: 16 << 10,
cc.cond = sync.NewCond(&cc.mu)
cc.flow.add(int32(http2initialWindowSize))
- cc.bw = bufio.NewWriter(http2stickyErrWriter{tconn, &cc.werr})
- cc.br = bufio.NewReader(tconn)
+ cc.bw = bufio.NewWriter(http2stickyErrWriter{c, &cc.werr})
+ cc.br = bufio.NewReader(c)
cc.fr = http2NewFramer(cc.bw, cc.br)
cc.henc = hpack.NewEncoder(&cc.hbuf)
type connectionStater interface {
ConnectionState() tls.ConnectionState
}
- if cs, ok := tconn.(connectionStater); ok {
+ if cs, ok := c.(connectionStater); ok {
state := cs.ConnectionState()
cc.tlsState = &state
}
return cc, nil
}
-func (cc *http2clientConn) setGoAway(f *http2GoAwayFrame) {
+func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) {
cc.mu.Lock()
defer cc.mu.Unlock()
cc.goAway = f
}
-func (cc *http2clientConn) canTakeNewRequest() bool {
+func (cc *http2ClientConn) CanTakeNewRequest() bool {
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.goAway == nil &&
cc.nextStreamID < 2147483647
}
-func (cc *http2clientConn) closeIfIdle() {
+func (cc *http2ClientConn) closeIfIdle() {
cc.mu.Lock()
if len(cc.streams) > 0 {
cc.mu.Unlock()
// They're capped at the min of the peer's max frame size or 512KB
// (kinda arbitrarily), but definitely capped so we don't allocate 4GB
// bufers.
-func (cc *http2clientConn) frameScratchBuffer() []byte {
+func (cc *http2ClientConn) frameScratchBuffer() []byte {
cc.mu.Lock()
size := cc.maxFrameSize
if size > http2maxAllocFrameSize {
return make([]byte, size)
}
-func (cc *http2clientConn) putFrameScratchBuffer(buf []byte) {
+func (cc *http2ClientConn) putFrameScratchBuffer(buf []byte) {
cc.mu.Lock()
defer cc.mu.Unlock()
const maxBufs = 4 // arbitrary; 4 concurrent requests per conn? investigate.
}
-func (cc *http2clientConn) roundTrip(req *Request) (*Response, error) {
+func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
cc.mu.Lock()
if cc.closed {
}
// requires cc.mu be held.
-func (cc *http2clientConn) encodeHeaders(req *Request) []byte {
+func (cc *http2ClientConn) encodeHeaders(req *Request) []byte {
cc.hbuf.Reset()
host := req.Host
return cc.hbuf.Bytes()
}
-func (cc *http2clientConn) writeHeader(name, value string) {
+func (cc *http2ClientConn) writeHeader(name, value string) {
cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
}
}
// requires cc.mu be held.
-func (cc *http2clientConn) newStream() *http2clientStream {
+func (cc *http2ClientConn) newStream() *http2clientStream {
cs := &http2clientStream{
cc: cc,
ID: cc.nextStreamID,
return cs
}
-func (cc *http2clientConn) streamByID(id uint32, andRemove bool) *http2clientStream {
+func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStream {
cc.mu.Lock()
defer cc.mu.Unlock()
cs := cc.streams[id]
// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop.
type http2clientConnReadLoop struct {
- cc *http2clientConn
+ cc *http2ClientConn
activeRes map[uint32]*http2clientStream // keyed by streamID
// continueStreamID is the stream ID we're waiting for
}
// readLoop runs in its own goroutine and reads and dispatches frames.
-func (cc *http2clientConn) readLoop() {
+func (cc *http2ClientConn) readLoop() {
rl := &http2clientConnReadLoop{
cc: cc,
activeRes: make(map[uint32]*http2clientStream),
func (rl *http2clientConnReadLoop) cleanup() {
cc := rl.cc
defer cc.tconn.Close()
- defer cc.t.removeClientConn(cc)
+ defer cc.t.connPool().MarkDead(cc)
defer close(cc.readerDone)
err := cc.readerErr
func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error {
cc := rl.cc
- cc.t.removeClientConn(cc)
+ cc.t.connPool().MarkDead(cc)
if f.ErrCode != 0 {
cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode)
return http2ConnectionError(http2ErrCodeProtocol)
}
-func (cc *http2clientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) {
+func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) {
cc.wmu.Lock()
cc.fr.WriteRSTStream(streamID, code)
}
}
-func (cc *http2clientConn) logf(format string, args ...interface{}) {
+func (cc *http2ClientConn) logf(format string, args ...interface{}) {
cc.t.logf(format, args...)
}
-func (cc *http2clientConn) vlogf(format string, args ...interface{}) {
+func (cc *http2ClientConn) vlogf(format string, args ...interface{}) {
cc.t.vlogf(format, args...)
}
"time"
)
-// h2DefaultTransport is the HTTP/2 version of DefaultTransport.
-// DefaultTransport and h2DefaultTransport are wired up together.
-var h2DefaultTransport = &http2Transport{}
+// HTTP/2 transport, integrated with the DefaultTransport.
+var (
+ // h2ConnPool is the connection pool for HTTP/2 connections.
+ h2ConnPool = &http2clientConnPool{}
+ // h2Transport is the HTTP/2 version of DefaultTransport.
+ h2Transport = &http2Transport{ConnPool: noDialClientConnPool{h2ConnPool}}
+)
+
+func init() {
+ h2ConnPool.t = h2Transport // avoid decalaration loop
+}
+
+// noDialClientConnPool is an implementation of http2.ClientConnPool
+// which never dials. We let the HTTP/1.1 client dial and use its TLS
+// connection instead.
+type noDialClientConnPool struct{ *http2clientConnPool }
+
+func (p noDialClientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) {
+ const doDial = false
+ return p.getClientConn(req, addr, doDial)
+}
// DefaultTransport is the default implementation of Transport and is
// used by DefaultClient. It establishes network connections as needed
return
}
t := DefaultTransport.(*Transport)
- t.RegisterProtocol("https", noDialH2Transport{h2DefaultTransport})
+ t.RegisterProtocol("https", noDialH2RoundTripper{})
t.TLSClientConfig = &tls.Config{
NextProtos: []string{"h2"},
}
t.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{
"h2": func(authority string, c *tls.Conn) RoundTripper {
- h2DefaultTransport.AddIdleConn(authority, c)
- return h2DefaultTransport
+ cc, err := h2Transport.NewClientConn(c)
+ if err != nil {
+ c.Close()
+ return erringRoundTripper{err}
+ }
+ h2ConnPool.addConn(http2authorityAddr(authority), cc)
+ return h2Transport
},
}
}
-// noDialH2Transport is a RoundTripper which only tries to complete the request if
-// the wrapped *http2Transport already has a cached connection to the host.
-type noDialH2Transport struct{ rt *http2Transport }
+type erringRoundTripper struct{ err error }
+
+func (rt erringRoundTripper) RoundTrip(*Request) (*Response, error) { return nil, rt.err }
+
+// noDialH2RoundTripper is a RoundTripper which only tries to complete the request
+// if there's already has a cached connection to the host.
+type noDialH2RoundTripper struct{}
-func (t noDialH2Transport) RoundTrip(req *Request) (*Response, error) {
- res, err := t.rt.RoundTripOpt(req, http2RoundTripOpt{OnlyCachedConn: true})
+func (noDialH2RoundTripper) RoundTrip(req *Request) (*Response, error) {
+ res, err := h2Transport.RoundTrip(req)
if err == http2ErrNoCachedConn {
return nil, ErrSkipAltProtocol
}