From 5e03c84a3d48d84229a9c24727f1af99583ac348 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sat, 7 Nov 2015 16:34:03 +0100 Subject: [PATCH] net/http: update http2 bundle to rev d62542 Updates to use new client pool abstraction. Change-Id: I3552018038ee8394d313d3253af337b07be211f6 Reviewed-on: https://go-review.googlesource.com/16730 Reviewed-by: Blake Mizerany Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- src/net/http/h2_bundle.go | 295 +++++++++++++++++++++----------------- src/net/http/transport.go | 49 +++++-- 2 files changed, 200 insertions(+), 144 deletions(-) diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go index d6e749f8cd..c129c98aa9 100644 --- a/src/net/http/h2_bundle.go +++ b/src/net/http/h2_bundle.go @@ -34,6 +34,102 @@ import ( "time" ) +// ClientConnPool manages a pool of HTTP/2 client connections. +type http2ClientConnPool interface { + GetClientConn(req *Request, addr string) (*http2ClientConn, error) + MarkDead(*http2ClientConn) +} + +type http2clientConnPool struct { + t *http2Transport + mu sync.Mutex // TODO: switch to RWMutex + // TODO: add support for sharing conns based on cert names + // (e.g. share conn for googleapis.com and appspot.com) + conns map[string][]*http2ClientConn // key is host:port + keys map[*http2ClientConn][]string +} + +func (p *http2clientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) { + return p.getClientConn(req, addr, true) +} + +func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) { + p.mu.Lock() + for _, cc := range p.conns[addr] { + if cc.CanTakeNewRequest() { + p.mu.Unlock() + return cc, nil + } + } + p.mu.Unlock() + if !dialOnMiss { + return nil, http2ErrNoCachedConn + } + + cc, err := p.t.dialClientConn(addr) + if err != nil { + return nil, err + } + p.addConn(addr, cc) + return cc, nil +} + +func (p *http2clientConnPool) addConn(key string, cc *http2ClientConn) { + p.mu.Lock() + defer p.mu.Unlock() + if p.conns == nil { + p.conns = make(map[string][]*http2ClientConn) + } + if p.keys == nil { + p.keys = make(map[*http2ClientConn][]string) + } + p.conns[key] = append(p.conns[key], cc) + p.keys[cc] = append(p.keys[cc], key) +} + +func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) { + p.mu.Lock() + defer p.mu.Unlock() + for _, key := range p.keys[cc] { + vv, ok := p.conns[key] + if !ok { + continue + } + newList := http2filterOutClientConn(vv, cc) + if len(newList) > 0 { + p.conns[key] = newList + } else { + delete(p.conns, key) + } + } + delete(p.keys, cc) +} + +func (p *http2clientConnPool) closeIdleConnections() { + p.mu.Lock() + defer p.mu.Unlock() + + for _, vv := range p.conns { + for _, cc := range vv { + cc.closeIfIdle() + } + } +} + +func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) []*http2ClientConn { + out := in[:0] + for _, v := range in { + if v != exclude { + out = append(out, v) + } + } + + if len(in) != len(out) { + in[len(in)-1] = nil + } + return out +} + // An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec. type http2ErrCode uint32 @@ -3503,20 +3599,33 @@ type http2Transport struct { // tls.Client. If nil, the default configuration is used. TLSClientConfig *tls.Config - // TODO: switch to RWMutex - // TODO: add support for sharing conns based on cert names - // (e.g. share conn for googleapis.com and appspot.com) - connMu sync.Mutex - conns map[string][]*http2clientConn // key is host:port + // ConnPool optionally specifies an alternate connection pool to use. + // If nil, the default is used. + ConnPool http2ClientConnPool + + connPoolOnce sync.Once + connPoolOrDef http2ClientConnPool // non-nil version of ConnPool +} + +func (t *http2Transport) connPool() http2ClientConnPool { + t.connPoolOnce.Do(t.initConnPool) + return t.connPoolOrDef +} + +func (t *http2Transport) initConnPool() { + if t.ConnPool != nil { + t.connPoolOrDef = t.ConnPool + } else { + t.connPoolOrDef = &http2clientConnPool{t: t} + } } -// clientConn is the state of a single HTTP/2 client connection to an +// ClientConn is the state of a single HTTP/2 client connection to an // HTTP/2 server. -type http2clientConn struct { +type http2ClientConn struct { t *http2Transport - tconn net.Conn - tlsState *tls.ConnectionState - connKey []string // key(s) this connection is cached in, in t.conns + tconn net.Conn // usually *tls.Conn, except specialized impls + tlsState *tls.ConnectionState // nil only for specialized impls // readLoop goroutine fields: readerDone chan struct{} // closed on error @@ -3548,7 +3657,7 @@ type http2clientConn struct { // clientStream is the state for a single HTTP/2 stream. One of these // is created for each Transport.RoundTrip call. type http2clientStream struct { - cc *http2clientConn + cc *http2ClientConn ID uint32 resc chan http2resAndError bufPipe http2pipe // buffered pipe with the flow-controlled response payload @@ -3600,24 +3709,28 @@ func (t *http2Transport) RoundTrip(req *Request) (*Response, error) { return t.RoundTripOpt(req, http2RoundTripOpt{}) } +// authorityAddr returns a given authority (a host/IP, or host:port / ip:port) +// and returns a host:port. The port 443 is added if needed. +func http2authorityAddr(authority string) (addr string) { + if _, _, err := net.SplitHostPort(authority); err == nil { + return authority + } + return net.JoinHostPort(authority, "443") +} + // RoundTripOpt is like RoundTrip, but takes options. func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Response, error) { if req.URL.Scheme != "https" { return nil, errors.New("http2: unsupported scheme") } - host, port, err := net.SplitHostPort(req.URL.Host) - if err != nil { - host = req.URL.Host - port = "443" - } - + addr := http2authorityAddr(req.URL.Host) for { - cc, err := t.getClientConn(host, port, opt.OnlyCachedConn) + cc, err := t.connPool().GetClientConn(req, addr) if err != nil { return nil, err } - res, err := cc.roundTrip(req) + res, err := cc.RoundTrip(req) if http2shouldRetryRequest(err) { continue } @@ -3632,12 +3745,8 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res // connected from previous requests but are now sitting idle. // It does not interrupt any connections currently in use. func (t *http2Transport) CloseIdleConnections() { - t.connMu.Lock() - defer t.connMu.Unlock() - for _, vv := range t.conns { - for _, cc := range vv { - cc.closeIfIdle() - } + if cp, ok := t.connPool().(*http2clientConnPool); ok { + cp.closeIdleConnections() } } @@ -3648,95 +3757,16 @@ func http2shouldRetryRequest(err error) bool { return err == http2errClientConnClosed } -func (t *http2Transport) removeClientConn(cc *http2clientConn) { - t.connMu.Lock() - defer t.connMu.Unlock() - for _, key := range cc.connKey { - vv, ok := t.conns[key] - if !ok { - continue - } - newList := http2filterOutClientConn(vv, cc) - if len(newList) > 0 { - t.conns[key] = newList - } else { - delete(t.conns, key) - } - } -} - -func http2filterOutClientConn(in []*http2clientConn, exclude *http2clientConn) []*http2clientConn { - out := in[:0] - for _, v := range in { - if v != exclude { - out = append(out, v) - } - } - - if len(in) != len(out) { - in[len(in)-1] = nil - } - return out -} - -// AddIdleConn adds c as an idle conn for Transport. -// It assumes that c has not yet exchanged SETTINGS frames. -// The addr maybe be either "host" or "host:port". -func (t *http2Transport) AddIdleConn(addr string, c *tls.Conn) error { - var key string - _, _, err := net.SplitHostPort(addr) - if err == nil { - key = addr - } else { - key = addr + ":443" - } - cc, err := t.newClientConn(key, c) - if err != nil { - return err - } - - t.addConn(key, cc) - return nil -} - -func (t *http2Transport) addConn(key string, cc *http2clientConn) { - t.connMu.Lock() - defer t.connMu.Unlock() - if t.conns == nil { - t.conns = make(map[string][]*http2clientConn) - } - t.conns[key] = append(t.conns[key], cc) -} - -func (t *http2Transport) getClientConn(host, port string, onlyCached bool) (*http2clientConn, error) { - key := net.JoinHostPort(host, port) - - t.connMu.Lock() - for _, cc := range t.conns[key] { - if cc.canTakeNewRequest() { - t.connMu.Unlock() - return cc, nil - } - } - t.connMu.Unlock() - if onlyCached { - return nil, http2ErrNoCachedConn - } - - cc, err := t.dialClientConn(host, port, key) +func (t *http2Transport) dialClientConn(addr string) (*http2ClientConn, error) { + host, _, err := net.SplitHostPort(addr) if err != nil { return nil, err } - t.addConn(key, cc) - return cc, nil -} - -func (t *http2Transport) dialClientConn(host, port, key string) (*http2clientConn, error) { - tconn, err := t.dialTLS()("tcp", net.JoinHostPort(host, port), t.newTLSConfig(host)) + tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host)) if err != nil { return nil, err } - return t.newClientConn(key, tconn) + return t.NewClientConn(tconn) } func (t *http2Transport) newTLSConfig(host string) *tls.Config { @@ -3779,15 +3809,14 @@ func (t *http2Transport) dialTLSDefault(network, addr string, cfg *tls.Config) ( return cn, nil } -func (t *http2Transport) newClientConn(key string, tconn net.Conn) (*http2clientConn, error) { - if _, err := tconn.Write(http2clientPreface); err != nil { +func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { + if _, err := c.Write(http2clientPreface); err != nil { return nil, err } - cc := &http2clientConn{ + cc := &http2ClientConn{ t: t, - tconn: tconn, - connKey: []string{key}, + tconn: c, readerDone: make(chan struct{}), nextStreamID: 1, maxFrameSize: 16 << 10, @@ -3798,15 +3827,15 @@ func (t *http2Transport) newClientConn(key string, tconn net.Conn) (*http2client cc.cond = sync.NewCond(&cc.mu) cc.flow.add(int32(http2initialWindowSize)) - cc.bw = bufio.NewWriter(http2stickyErrWriter{tconn, &cc.werr}) - cc.br = bufio.NewReader(tconn) + cc.bw = bufio.NewWriter(http2stickyErrWriter{c, &cc.werr}) + cc.br = bufio.NewReader(c) cc.fr = http2NewFramer(cc.bw, cc.br) cc.henc = hpack.NewEncoder(&cc.hbuf) type connectionStater interface { ConnectionState() tls.ConnectionState } - if cs, ok := tconn.(connectionStater); ok { + if cs, ok := c.(connectionStater); ok { state := cs.ConnectionState() cc.tlsState = &state } @@ -3852,13 +3881,13 @@ func (t *http2Transport) newClientConn(key string, tconn net.Conn) (*http2client return cc, nil } -func (cc *http2clientConn) setGoAway(f *http2GoAwayFrame) { +func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { cc.mu.Lock() defer cc.mu.Unlock() cc.goAway = f } -func (cc *http2clientConn) canTakeNewRequest() bool { +func (cc *http2ClientConn) CanTakeNewRequest() bool { cc.mu.Lock() defer cc.mu.Unlock() return cc.goAway == nil && @@ -3866,7 +3895,7 @@ func (cc *http2clientConn) canTakeNewRequest() bool { cc.nextStreamID < 2147483647 } -func (cc *http2clientConn) closeIfIdle() { +func (cc *http2ClientConn) closeIfIdle() { cc.mu.Lock() if len(cc.streams) > 0 { cc.mu.Unlock() @@ -3885,7 +3914,7 @@ const http2maxAllocFrameSize = 512 << 10 // They're capped at the min of the peer's max frame size or 512KB // (kinda arbitrarily), but definitely capped so we don't allocate 4GB // bufers. -func (cc *http2clientConn) frameScratchBuffer() []byte { +func (cc *http2ClientConn) frameScratchBuffer() []byte { cc.mu.Lock() size := cc.maxFrameSize if size > http2maxAllocFrameSize { @@ -3902,7 +3931,7 @@ func (cc *http2clientConn) frameScratchBuffer() []byte { return make([]byte, size) } -func (cc *http2clientConn) putFrameScratchBuffer(buf []byte) { +func (cc *http2ClientConn) putFrameScratchBuffer(buf []byte) { cc.mu.Lock() defer cc.mu.Unlock() const maxBufs = 4 // arbitrary; 4 concurrent requests per conn? investigate. @@ -3919,7 +3948,7 @@ func (cc *http2clientConn) putFrameScratchBuffer(buf []byte) { } -func (cc *http2clientConn) roundTrip(req *Request) (*Response, error) { +func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { cc.mu.Lock() if cc.closed { @@ -4086,7 +4115,7 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int32) (taken int32, err } // requires cc.mu be held. -func (cc *http2clientConn) encodeHeaders(req *Request) []byte { +func (cc *http2ClientConn) encodeHeaders(req *Request) []byte { cc.hbuf.Reset() host := req.Host @@ -4111,7 +4140,7 @@ func (cc *http2clientConn) encodeHeaders(req *Request) []byte { return cc.hbuf.Bytes() } -func (cc *http2clientConn) writeHeader(name, value string) { +func (cc *http2ClientConn) writeHeader(name, value string) { cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) } @@ -4121,7 +4150,7 @@ type http2resAndError struct { } // requires cc.mu be held. -func (cc *http2clientConn) newStream() *http2clientStream { +func (cc *http2ClientConn) newStream() *http2clientStream { cs := &http2clientStream{ cc: cc, ID: cc.nextStreamID, @@ -4137,7 +4166,7 @@ func (cc *http2clientConn) newStream() *http2clientStream { return cs } -func (cc *http2clientConn) streamByID(id uint32, andRemove bool) *http2clientStream { +func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStream { cc.mu.Lock() defer cc.mu.Unlock() cs := cc.streams[id] @@ -4149,7 +4178,7 @@ func (cc *http2clientConn) streamByID(id uint32, andRemove bool) *http2clientStr // clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. type http2clientConnReadLoop struct { - cc *http2clientConn + cc *http2ClientConn activeRes map[uint32]*http2clientStream // keyed by streamID // continueStreamID is the stream ID we're waiting for @@ -4165,7 +4194,7 @@ type http2clientConnReadLoop struct { } // readLoop runs in its own goroutine and reads and dispatches frames. -func (cc *http2clientConn) readLoop() { +func (cc *http2ClientConn) readLoop() { rl := &http2clientConnReadLoop{ cc: cc, activeRes: make(map[uint32]*http2clientStream), @@ -4185,7 +4214,7 @@ func (cc *http2clientConn) readLoop() { func (rl *http2clientConnReadLoop) cleanup() { cc := rl.cc defer cc.tconn.Close() - defer cc.t.removeClientConn(cc) + defer cc.t.connPool().MarkDead(cc) defer close(cc.readerDone) err := cc.readerErr @@ -4397,7 +4426,7 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { cc := rl.cc - cc.t.removeClientConn(cc) + cc.t.connPool().MarkDead(cc) if f.ErrCode != 0 { cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode) @@ -4472,7 +4501,7 @@ func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame) return http2ConnectionError(http2ErrCodeProtocol) } -func (cc *http2clientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) { +func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) { cc.wmu.Lock() cc.fr.WriteRSTStream(streamID, code) @@ -4511,11 +4540,11 @@ func (rl *http2clientConnReadLoop) onNewHeaderField(f hpack.HeaderField) { } } -func (cc *http2clientConn) logf(format string, args ...interface{}) { +func (cc *http2ClientConn) logf(format string, args ...interface{}) { cc.t.logf(format, args...) } -func (cc *http2clientConn) vlogf(format string, args ...interface{}) { +func (cc *http2ClientConn) vlogf(format string, args ...interface{}) { cc.t.vlogf(format, args...) } diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 809f6de289..cb35be20ce 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -25,9 +25,27 @@ import ( "time" ) -// h2DefaultTransport is the HTTP/2 version of DefaultTransport. -// DefaultTransport and h2DefaultTransport are wired up together. -var h2DefaultTransport = &http2Transport{} +// HTTP/2 transport, integrated with the DefaultTransport. +var ( + // h2ConnPool is the connection pool for HTTP/2 connections. + h2ConnPool = &http2clientConnPool{} + // h2Transport is the HTTP/2 version of DefaultTransport. + h2Transport = &http2Transport{ConnPool: noDialClientConnPool{h2ConnPool}} +) + +func init() { + h2ConnPool.t = h2Transport // avoid decalaration loop +} + +// noDialClientConnPool is an implementation of http2.ClientConnPool +// which never dials. We let the HTTP/1.1 client dial and use its TLS +// connection instead. +type noDialClientConnPool struct{ *http2clientConnPool } + +func (p noDialClientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) { + const doDial = false + return p.getClientConn(req, addr, doDial) +} // DefaultTransport is the default implementation of Transport and is // used by DefaultClient. It establishes network connections as needed @@ -50,24 +68,33 @@ func init() { return } t := DefaultTransport.(*Transport) - t.RegisterProtocol("https", noDialH2Transport{h2DefaultTransport}) + t.RegisterProtocol("https", noDialH2RoundTripper{}) t.TLSClientConfig = &tls.Config{ NextProtos: []string{"h2"}, } t.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{ "h2": func(authority string, c *tls.Conn) RoundTripper { - h2DefaultTransport.AddIdleConn(authority, c) - return h2DefaultTransport + cc, err := h2Transport.NewClientConn(c) + if err != nil { + c.Close() + return erringRoundTripper{err} + } + h2ConnPool.addConn(http2authorityAddr(authority), cc) + return h2Transport }, } } -// noDialH2Transport is a RoundTripper which only tries to complete the request if -// the wrapped *http2Transport already has a cached connection to the host. -type noDialH2Transport struct{ rt *http2Transport } +type erringRoundTripper struct{ err error } + +func (rt erringRoundTripper) RoundTrip(*Request) (*Response, error) { return nil, rt.err } + +// noDialH2RoundTripper is a RoundTripper which only tries to complete the request +// if there's already has a cached connection to the host. +type noDialH2RoundTripper struct{} -func (t noDialH2Transport) RoundTrip(req *Request) (*Response, error) { - res, err := t.rt.RoundTripOpt(req, http2RoundTripOpt{OnlyCachedConn: true}) +func (noDialH2RoundTripper) RoundTrip(req *Request) (*Response, error) { + res, err := h2Transport.RoundTrip(req) if err == http2ErrNoCachedConn { return nil, ErrSkipAltProtocol } -- 2.48.1