]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: update http2 bundle to rev d62542
authorBrad Fitzpatrick <bradfitz@golang.org>
Sat, 7 Nov 2015 15:34:03 +0000 (16:34 +0100)
committerBrad Fitzpatrick <bradfitz@golang.org>
Sat, 7 Nov 2015 15:51:59 +0000 (15:51 +0000)
Updates to use new client pool abstraction.

Change-Id: I3552018038ee8394d313d3253af337b07be211f6
Reviewed-on: https://go-review.googlesource.com/16730
Reviewed-by: Blake Mizerany <blake.mizerany@gmail.com>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/http/h2_bundle.go
src/net/http/transport.go

index d6e749f8cdafe587e9250de23efd577138bbbf87..c129c98aa92afac71d2b9eb0b27e6c9df07f17a1 100644 (file)
@@ -34,6 +34,102 @@ import (
        "time"
 )
 
+// ClientConnPool manages a pool of HTTP/2 client connections.
+type http2ClientConnPool interface {
+       GetClientConn(req *Request, addr string) (*http2ClientConn, error)
+       MarkDead(*http2ClientConn)
+}
+
+type http2clientConnPool struct {
+       t  *http2Transport
+       mu sync.Mutex // TODO: switch to RWMutex
+       // TODO: add support for sharing conns based on cert names
+       // (e.g. share conn for googleapis.com and appspot.com)
+       conns map[string][]*http2ClientConn // key is host:port
+       keys  map[*http2ClientConn][]string
+}
+
+func (p *http2clientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) {
+       return p.getClientConn(req, addr, true)
+}
+
+func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) {
+       p.mu.Lock()
+       for _, cc := range p.conns[addr] {
+               if cc.CanTakeNewRequest() {
+                       p.mu.Unlock()
+                       return cc, nil
+               }
+       }
+       p.mu.Unlock()
+       if !dialOnMiss {
+               return nil, http2ErrNoCachedConn
+       }
+
+       cc, err := p.t.dialClientConn(addr)
+       if err != nil {
+               return nil, err
+       }
+       p.addConn(addr, cc)
+       return cc, nil
+}
+
+func (p *http2clientConnPool) addConn(key string, cc *http2ClientConn) {
+       p.mu.Lock()
+       defer p.mu.Unlock()
+       if p.conns == nil {
+               p.conns = make(map[string][]*http2ClientConn)
+       }
+       if p.keys == nil {
+               p.keys = make(map[*http2ClientConn][]string)
+       }
+       p.conns[key] = append(p.conns[key], cc)
+       p.keys[cc] = append(p.keys[cc], key)
+}
+
+func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) {
+       p.mu.Lock()
+       defer p.mu.Unlock()
+       for _, key := range p.keys[cc] {
+               vv, ok := p.conns[key]
+               if !ok {
+                       continue
+               }
+               newList := http2filterOutClientConn(vv, cc)
+               if len(newList) > 0 {
+                       p.conns[key] = newList
+               } else {
+                       delete(p.conns, key)
+               }
+       }
+       delete(p.keys, cc)
+}
+
+func (p *http2clientConnPool) closeIdleConnections() {
+       p.mu.Lock()
+       defer p.mu.Unlock()
+
+       for _, vv := range p.conns {
+               for _, cc := range vv {
+                       cc.closeIfIdle()
+               }
+       }
+}
+
+func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) []*http2ClientConn {
+       out := in[:0]
+       for _, v := range in {
+               if v != exclude {
+                       out = append(out, v)
+               }
+       }
+
+       if len(in) != len(out) {
+               in[len(in)-1] = nil
+       }
+       return out
+}
+
 // An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec.
 type http2ErrCode uint32
 
@@ -3503,20 +3599,33 @@ type http2Transport struct {
        // tls.Client. If nil, the default configuration is used.
        TLSClientConfig *tls.Config
 
-       // TODO: switch to RWMutex
-       // TODO: add support for sharing conns based on cert names
-       // (e.g. share conn for googleapis.com and appspot.com)
-       connMu sync.Mutex
-       conns  map[string][]*http2clientConn // key is host:port
+       // ConnPool optionally specifies an alternate connection pool to use.
+       // If nil, the default is used.
+       ConnPool http2ClientConnPool
+
+       connPoolOnce  sync.Once
+       connPoolOrDef http2ClientConnPool // non-nil version of ConnPool
+}
+
+func (t *http2Transport) connPool() http2ClientConnPool {
+       t.connPoolOnce.Do(t.initConnPool)
+       return t.connPoolOrDef
+}
+
+func (t *http2Transport) initConnPool() {
+       if t.ConnPool != nil {
+               t.connPoolOrDef = t.ConnPool
+       } else {
+               t.connPoolOrDef = &http2clientConnPool{t: t}
+       }
 }
 
-// clientConn is the state of a single HTTP/2 client connection to an
+// ClientConn is the state of a single HTTP/2 client connection to an
 // HTTP/2 server.
-type http2clientConn struct {
+type http2ClientConn struct {
        t        *http2Transport
-       tconn    net.Conn
-       tlsState *tls.ConnectionState
-       connKey  []string // key(s) this connection is cached in, in t.conns
+       tconn    net.Conn             // usually *tls.Conn, except specialized impls
+       tlsState *tls.ConnectionState // nil only for specialized impls
 
        // readLoop goroutine fields:
        readerDone chan struct{} // closed on error
@@ -3548,7 +3657,7 @@ type http2clientConn struct {
 // clientStream is the state for a single HTTP/2 stream. One of these
 // is created for each Transport.RoundTrip call.
 type http2clientStream struct {
-       cc      *http2clientConn
+       cc      *http2ClientConn
        ID      uint32
        resc    chan http2resAndError
        bufPipe http2pipe // buffered pipe with the flow-controlled response payload
@@ -3600,24 +3709,28 @@ func (t *http2Transport) RoundTrip(req *Request) (*Response, error) {
        return t.RoundTripOpt(req, http2RoundTripOpt{})
 }
 
+// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
+// and returns a host:port. The port 443 is added if needed.
+func http2authorityAddr(authority string) (addr string) {
+       if _, _, err := net.SplitHostPort(authority); err == nil {
+               return authority
+       }
+       return net.JoinHostPort(authority, "443")
+}
+
 // RoundTripOpt is like RoundTrip, but takes options.
 func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Response, error) {
        if req.URL.Scheme != "https" {
                return nil, errors.New("http2: unsupported scheme")
        }
 
-       host, port, err := net.SplitHostPort(req.URL.Host)
-       if err != nil {
-               host = req.URL.Host
-               port = "443"
-       }
-
+       addr := http2authorityAddr(req.URL.Host)
        for {
-               cc, err := t.getClientConn(host, port, opt.OnlyCachedConn)
+               cc, err := t.connPool().GetClientConn(req, addr)
                if err != nil {
                        return nil, err
                }
-               res, err := cc.roundTrip(req)
+               res, err := cc.RoundTrip(req)
                if http2shouldRetryRequest(err) {
                        continue
                }
@@ -3632,12 +3745,8 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res
 // connected from previous requests but are now sitting idle.
 // It does not interrupt any connections currently in use.
 func (t *http2Transport) CloseIdleConnections() {
-       t.connMu.Lock()
-       defer t.connMu.Unlock()
-       for _, vv := range t.conns {
-               for _, cc := range vv {
-                       cc.closeIfIdle()
-               }
+       if cp, ok := t.connPool().(*http2clientConnPool); ok {
+               cp.closeIdleConnections()
        }
 }
 
@@ -3648,95 +3757,16 @@ func http2shouldRetryRequest(err error) bool {
        return err == http2errClientConnClosed
 }
 
-func (t *http2Transport) removeClientConn(cc *http2clientConn) {
-       t.connMu.Lock()
-       defer t.connMu.Unlock()
-       for _, key := range cc.connKey {
-               vv, ok := t.conns[key]
-               if !ok {
-                       continue
-               }
-               newList := http2filterOutClientConn(vv, cc)
-               if len(newList) > 0 {
-                       t.conns[key] = newList
-               } else {
-                       delete(t.conns, key)
-               }
-       }
-}
-
-func http2filterOutClientConn(in []*http2clientConn, exclude *http2clientConn) []*http2clientConn {
-       out := in[:0]
-       for _, v := range in {
-               if v != exclude {
-                       out = append(out, v)
-               }
-       }
-
-       if len(in) != len(out) {
-               in[len(in)-1] = nil
-       }
-       return out
-}
-
-// AddIdleConn adds c as an idle conn for Transport.
-// It assumes that c has not yet exchanged SETTINGS frames.
-// The addr maybe be either "host" or "host:port".
-func (t *http2Transport) AddIdleConn(addr string, c *tls.Conn) error {
-       var key string
-       _, _, err := net.SplitHostPort(addr)
-       if err == nil {
-               key = addr
-       } else {
-               key = addr + ":443"
-       }
-       cc, err := t.newClientConn(key, c)
-       if err != nil {
-               return err
-       }
-
-       t.addConn(key, cc)
-       return nil
-}
-
-func (t *http2Transport) addConn(key string, cc *http2clientConn) {
-       t.connMu.Lock()
-       defer t.connMu.Unlock()
-       if t.conns == nil {
-               t.conns = make(map[string][]*http2clientConn)
-       }
-       t.conns[key] = append(t.conns[key], cc)
-}
-
-func (t *http2Transport) getClientConn(host, port string, onlyCached bool) (*http2clientConn, error) {
-       key := net.JoinHostPort(host, port)
-
-       t.connMu.Lock()
-       for _, cc := range t.conns[key] {
-               if cc.canTakeNewRequest() {
-                       t.connMu.Unlock()
-                       return cc, nil
-               }
-       }
-       t.connMu.Unlock()
-       if onlyCached {
-               return nil, http2ErrNoCachedConn
-       }
-
-       cc, err := t.dialClientConn(host, port, key)
+func (t *http2Transport) dialClientConn(addr string) (*http2ClientConn, error) {
+       host, _, err := net.SplitHostPort(addr)
        if err != nil {
                return nil, err
        }
-       t.addConn(key, cc)
-       return cc, nil
-}
-
-func (t *http2Transport) dialClientConn(host, port, key string) (*http2clientConn, error) {
-       tconn, err := t.dialTLS()("tcp", net.JoinHostPort(host, port), t.newTLSConfig(host))
+       tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host))
        if err != nil {
                return nil, err
        }
-       return t.newClientConn(key, tconn)
+       return t.NewClientConn(tconn)
 }
 
 func (t *http2Transport) newTLSConfig(host string) *tls.Config {
@@ -3779,15 +3809,14 @@ func (t *http2Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (
        return cn, nil
 }
 
-func (t *http2Transport) newClientConn(key string, tconn net.Conn) (*http2clientConn, error) {
-       if _, err := tconn.Write(http2clientPreface); err != nil {
+func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
+       if _, err := c.Write(http2clientPreface); err != nil {
                return nil, err
        }
 
-       cc := &http2clientConn{
+       cc := &http2ClientConn{
                t:                    t,
-               tconn:                tconn,
-               connKey:              []string{key},
+               tconn:                c,
                readerDone:           make(chan struct{}),
                nextStreamID:         1,
                maxFrameSize:         16 << 10,
@@ -3798,15 +3827,15 @@ func (t *http2Transport) newClientConn(key string, tconn net.Conn) (*http2client
        cc.cond = sync.NewCond(&cc.mu)
        cc.flow.add(int32(http2initialWindowSize))
 
-       cc.bw = bufio.NewWriter(http2stickyErrWriter{tconn, &cc.werr})
-       cc.br = bufio.NewReader(tconn)
+       cc.bw = bufio.NewWriter(http2stickyErrWriter{c, &cc.werr})
+       cc.br = bufio.NewReader(c)
        cc.fr = http2NewFramer(cc.bw, cc.br)
        cc.henc = hpack.NewEncoder(&cc.hbuf)
 
        type connectionStater interface {
                ConnectionState() tls.ConnectionState
        }
-       if cs, ok := tconn.(connectionStater); ok {
+       if cs, ok := c.(connectionStater); ok {
                state := cs.ConnectionState()
                cc.tlsState = &state
        }
@@ -3852,13 +3881,13 @@ func (t *http2Transport) newClientConn(key string, tconn net.Conn) (*http2client
        return cc, nil
 }
 
-func (cc *http2clientConn) setGoAway(f *http2GoAwayFrame) {
+func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) {
        cc.mu.Lock()
        defer cc.mu.Unlock()
        cc.goAway = f
 }
 
-func (cc *http2clientConn) canTakeNewRequest() bool {
+func (cc *http2ClientConn) CanTakeNewRequest() bool {
        cc.mu.Lock()
        defer cc.mu.Unlock()
        return cc.goAway == nil &&
@@ -3866,7 +3895,7 @@ func (cc *http2clientConn) canTakeNewRequest() bool {
                cc.nextStreamID < 2147483647
 }
 
-func (cc *http2clientConn) closeIfIdle() {
+func (cc *http2ClientConn) closeIfIdle() {
        cc.mu.Lock()
        if len(cc.streams) > 0 {
                cc.mu.Unlock()
@@ -3885,7 +3914,7 @@ const http2maxAllocFrameSize = 512 << 10
 // They're capped at the min of the peer's max frame size or 512KB
 // (kinda arbitrarily), but definitely capped so we don't allocate 4GB
 // bufers.
-func (cc *http2clientConn) frameScratchBuffer() []byte {
+func (cc *http2ClientConn) frameScratchBuffer() []byte {
        cc.mu.Lock()
        size := cc.maxFrameSize
        if size > http2maxAllocFrameSize {
@@ -3902,7 +3931,7 @@ func (cc *http2clientConn) frameScratchBuffer() []byte {
        return make([]byte, size)
 }
 
-func (cc *http2clientConn) putFrameScratchBuffer(buf []byte) {
+func (cc *http2ClientConn) putFrameScratchBuffer(buf []byte) {
        cc.mu.Lock()
        defer cc.mu.Unlock()
        const maxBufs = 4 // arbitrary; 4 concurrent requests per conn? investigate.
@@ -3919,7 +3948,7 @@ func (cc *http2clientConn) putFrameScratchBuffer(buf []byte) {
 
 }
 
-func (cc *http2clientConn) roundTrip(req *Request) (*Response, error) {
+func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
        cc.mu.Lock()
 
        if cc.closed {
@@ -4086,7 +4115,7 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int32) (taken int32, err
 }
 
 // requires cc.mu be held.
-func (cc *http2clientConn) encodeHeaders(req *Request) []byte {
+func (cc *http2ClientConn) encodeHeaders(req *Request) []byte {
        cc.hbuf.Reset()
 
        host := req.Host
@@ -4111,7 +4140,7 @@ func (cc *http2clientConn) encodeHeaders(req *Request) []byte {
        return cc.hbuf.Bytes()
 }
 
-func (cc *http2clientConn) writeHeader(name, value string) {
+func (cc *http2ClientConn) writeHeader(name, value string) {
        cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
 }
 
@@ -4121,7 +4150,7 @@ type http2resAndError struct {
 }
 
 // requires cc.mu be held.
-func (cc *http2clientConn) newStream() *http2clientStream {
+func (cc *http2ClientConn) newStream() *http2clientStream {
        cs := &http2clientStream{
                cc:        cc,
                ID:        cc.nextStreamID,
@@ -4137,7 +4166,7 @@ func (cc *http2clientConn) newStream() *http2clientStream {
        return cs
 }
 
-func (cc *http2clientConn) streamByID(id uint32, andRemove bool) *http2clientStream {
+func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStream {
        cc.mu.Lock()
        defer cc.mu.Unlock()
        cs := cc.streams[id]
@@ -4149,7 +4178,7 @@ func (cc *http2clientConn) streamByID(id uint32, andRemove bool) *http2clientStr
 
 // clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop.
 type http2clientConnReadLoop struct {
-       cc        *http2clientConn
+       cc        *http2ClientConn
        activeRes map[uint32]*http2clientStream // keyed by streamID
 
        // continueStreamID is the stream ID we're waiting for
@@ -4165,7 +4194,7 @@ type http2clientConnReadLoop struct {
 }
 
 // readLoop runs in its own goroutine and reads and dispatches frames.
-func (cc *http2clientConn) readLoop() {
+func (cc *http2ClientConn) readLoop() {
        rl := &http2clientConnReadLoop{
                cc:        cc,
                activeRes: make(map[uint32]*http2clientStream),
@@ -4185,7 +4214,7 @@ func (cc *http2clientConn) readLoop() {
 func (rl *http2clientConnReadLoop) cleanup() {
        cc := rl.cc
        defer cc.tconn.Close()
-       defer cc.t.removeClientConn(cc)
+       defer cc.t.connPool().MarkDead(cc)
        defer close(cc.readerDone)
 
        err := cc.readerErr
@@ -4397,7 +4426,7 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error {
 
 func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error {
        cc := rl.cc
-       cc.t.removeClientConn(cc)
+       cc.t.connPool().MarkDead(cc)
        if f.ErrCode != 0 {
 
                cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode)
@@ -4472,7 +4501,7 @@ func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame)
        return http2ConnectionError(http2ErrCodeProtocol)
 }
 
-func (cc *http2clientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) {
+func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) {
 
        cc.wmu.Lock()
        cc.fr.WriteRSTStream(streamID, code)
@@ -4511,11 +4540,11 @@ func (rl *http2clientConnReadLoop) onNewHeaderField(f hpack.HeaderField) {
        }
 }
 
-func (cc *http2clientConn) logf(format string, args ...interface{}) {
+func (cc *http2ClientConn) logf(format string, args ...interface{}) {
        cc.t.logf(format, args...)
 }
 
-func (cc *http2clientConn) vlogf(format string, args ...interface{}) {
+func (cc *http2ClientConn) vlogf(format string, args ...interface{}) {
        cc.t.vlogf(format, args...)
 }
 
index 809f6de2892bf10e991f6bafdc19a698db693393..cb35be20ce422ab150e7bea28f38d319b19bff6b 100644 (file)
@@ -25,9 +25,27 @@ import (
        "time"
 )
 
-// h2DefaultTransport is the HTTP/2 version of DefaultTransport.
-// DefaultTransport and h2DefaultTransport are wired up together.
-var h2DefaultTransport = &http2Transport{}
+// HTTP/2 transport, integrated with the DefaultTransport.
+var (
+       // h2ConnPool is the connection pool for HTTP/2 connections.
+       h2ConnPool = &http2clientConnPool{}
+       // h2Transport is the HTTP/2 version of DefaultTransport.
+       h2Transport = &http2Transport{ConnPool: noDialClientConnPool{h2ConnPool}}
+)
+
+func init() {
+       h2ConnPool.t = h2Transport // avoid decalaration loop
+}
+
+// noDialClientConnPool is an implementation of http2.ClientConnPool
+// which never dials.  We let the HTTP/1.1 client dial and use its TLS
+// connection instead.
+type noDialClientConnPool struct{ *http2clientConnPool }
+
+func (p noDialClientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) {
+       const doDial = false
+       return p.getClientConn(req, addr, doDial)
+}
 
 // DefaultTransport is the default implementation of Transport and is
 // used by DefaultClient. It establishes network connections as needed
@@ -50,24 +68,33 @@ func init() {
                return
        }
        t := DefaultTransport.(*Transport)
-       t.RegisterProtocol("https", noDialH2Transport{h2DefaultTransport})
+       t.RegisterProtocol("https", noDialH2RoundTripper{})
        t.TLSClientConfig = &tls.Config{
                NextProtos: []string{"h2"},
        }
        t.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{
                "h2": func(authority string, c *tls.Conn) RoundTripper {
-                       h2DefaultTransport.AddIdleConn(authority, c)
-                       return h2DefaultTransport
+                       cc, err := h2Transport.NewClientConn(c)
+                       if err != nil {
+                               c.Close()
+                               return erringRoundTripper{err}
+                       }
+                       h2ConnPool.addConn(http2authorityAddr(authority), cc)
+                       return h2Transport
                },
        }
 }
 
-// noDialH2Transport is a RoundTripper which only tries to complete the request if
-// the wrapped *http2Transport already has a cached connection to the host.
-type noDialH2Transport struct{ rt *http2Transport }
+type erringRoundTripper struct{ err error }
+
+func (rt erringRoundTripper) RoundTrip(*Request) (*Response, error) { return nil, rt.err }
+
+// noDialH2RoundTripper is a RoundTripper which only tries to complete the request
+// if there's already has a cached connection to the host.
+type noDialH2RoundTripper struct{}
 
-func (t noDialH2Transport) RoundTrip(req *Request) (*Response, error) {
-       res, err := t.rt.RoundTripOpt(req, http2RoundTripOpt{OnlyCachedConn: true})
+func (noDialH2RoundTripper) RoundTrip(req *Request) (*Response, error) {
+       res, err := h2Transport.RoundTrip(req)
        if err == http2ErrNoCachedConn {
                return nil, ErrSkipAltProtocol
        }