From c40a73d80c38eacc27e8cb9cf25c2bacaee60a3d Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 19 Jan 2016 05:10:58 +0000 Subject: [PATCH] net/http: make hidden http2 Transport respect remaining Transport fields Updates x/net/http2 to git rev 72aa00c6 for https://golang.org/cl/18721 (but actually at https://golang.org/cl/18722 now) Fixes #14008 Change-Id: If05d5ad51ec0ba5ba7e4fe16605c0a83f0484bc8 Reviewed-on: https://go-review.googlesource.com/18723 Run-TryBot: Brad Fitzpatrick Reviewed-by: Andrew Gerrand TryBot-Result: Gobot Gobot --- src/net/http/clientserver_test.go | 31 +++++- src/net/http/h2_bundle.go | 163 +++++++++++++++++++++++------- 2 files changed, 155 insertions(+), 39 deletions(-) diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index 573ed93c05..9c1aa7920e 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -47,7 +47,7 @@ const ( h2Mode = true ) -func newClientServerTest(t *testing.T, h2 bool, h Handler) *clientServerTest { +func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest { cst := &clientServerTest{ t: t, h2: h2, @@ -55,6 +55,16 @@ func newClientServerTest(t *testing.T, h2 bool, h Handler) *clientServerTest { tr: &Transport{}, } cst.c = &Client{Transport: cst.tr} + + for _, opt := range opts { + switch opt := opt.(type) { + case func(*Transport): + opt(cst.tr) + default: + t.Fatalf("unhandled option type %T", opt) + } + } + if !h2 { cst.ts = httptest.NewServer(h) return cst @@ -139,6 +149,7 @@ type h12Compare struct { Handler func(ResponseWriter, *Request) // required ReqFunc reqFunc // optional CheckResponse func(proto string, res *Response) // optional + Opts []interface{} } func (tt h12Compare) reqFunc() reqFunc { @@ -149,9 +160,9 @@ func (tt h12Compare) reqFunc() reqFunc { } func (tt h12Compare) run(t *testing.T) { - cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler)) + cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...) defer cst1.close() - cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler)) + cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...) defer cst2.close() res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL) @@ -380,6 +391,20 @@ func TestH12_AutoGzip(t *testing.T) { }.run(t) } +func TestH12_AutoGzip_Disabled(t *testing.T) { + h12Compare{ + Opts: []interface{}{ + func(tr *Transport) { tr.DisableCompression = true }, + }, + Handler: func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"]) + if ae := r.Header.Get("Accept-Encoding"); ae != "" { + t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae) + } + }, + }.run(t) +} + // Test304Responses verifies that 304s don't declare that they're // chunking in their response headers and aren't allowed to produce // output. diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go index 42f0ac1c69..cd530f16cd 100644 --- a/src/net/http/h2_bundle.go +++ b/src/net/http/h2_bundle.go @@ -24,7 +24,6 @@ import ( "encoding/binary" "errors" "fmt" - "golang.org/x/net/http2/hpack" "io" "io/ioutil" "log" @@ -38,6 +37,8 @@ import ( "strings" "sync" "time" + + "golang.org/x/net/http2/hpack" ) // ClientConnPool manages a pool of HTTP/2 client connections. @@ -248,7 +249,11 @@ func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) [ func http2configureTransport(t1 *Transport) (*http2Transport, error) { connPool := new(http2clientConnPool) - t2 := &http2Transport{ConnPool: http2noDialClientConnPool{connPool}} + t2 := &http2Transport{ + ConnPool: http2noDialClientConnPool{connPool}, + t1: t1, + } + connPool.t = t2 if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil { return nil, err } @@ -2184,6 +2189,19 @@ func http2bodyAllowedForStatus(status int) bool { return true } +type http2httpError struct { + msg string + timeout bool +} + +func (e *http2httpError) Error() string { return e.msg } + +func (e *http2httpError) Timeout() bool { return e.timeout } + +func (e *http2httpError) Temporary() bool { return true } + +var http2errTimeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true} + // pipe is a goroutine-safe io.Reader/io.Writer pair. It's like // io.Pipe except there are no PipeReader/PipeWriter halves, and the // underlying buffer is an interface. (io.Pipe is always unbuffered) @@ -4320,6 +4338,11 @@ type http2Transport struct { // to mean no limit. MaxHeaderListSize uint32 + // t1, if non-nil, is the standard library Transport using + // this transport. Its settings are used (but not its + // RoundTrip method, etc). + t1 *Transport + connPoolOnce sync.Once connPoolOrDef http2ClientConnPool // non-nil version of ConnPool } @@ -4335,11 +4358,7 @@ func (t *http2Transport) maxHeaderListSize() uint32 { } func (t *http2Transport) disableCompression() bool { - if t.DisableCompression { - return true - } - - return false + return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression) } var http2errTransportVersion = errors.New("http2: ConfigureTransport is only supported starting at Go 1.6") @@ -4395,7 +4414,7 @@ type http2ClientConn struct { henc *hpack.Encoder freeBuf [][]byte - wmu sync.Mutex // held while writing; acquire AFTER wmu if holding both + wmu sync.Mutex // held while writing; acquire AFTER mu if holding both werr error // first write error that has occurred } @@ -4413,7 +4432,7 @@ type http2clientStream struct { inflow http2flow // guarded by cc.mu bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read readErr error // sticky read error; owned by transportResponseBody.Read - stopReqBody bool // stop writing req body; guarded by cc.mu + stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu peerReset chan struct{} // closed on peer reset resetErr error // populated before peerReset is closed @@ -4456,10 +4475,13 @@ func (cs *http2clientStream) checkReset() error { } } -func (cs *http2clientStream) abortRequestBodyWrite() { +func (cs *http2clientStream) abortRequestBodyWrite(err error) { + if err == nil { + panic("nil error") + } cc := cs.cc cc.mu.Lock() - cs.stopReqBody = true + cs.stopReqBody = err cc.cond.Broadcast() cc.mu.Unlock() } @@ -4598,6 +4620,12 @@ func (t *http2Transport) dialTLSDefault(network, addr string, cfg *tls.Config) ( return cn, nil } +// disableKeepAlives reports whether connections should be closed as +// soon as possible after handling the first request. +func (t *http2Transport) disableKeepAlives() bool { + return t.t1 != nil && t.t1.DisableKeepAlives +} + func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { if http2VerboseLogs { t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr()) @@ -4692,7 +4720,7 @@ func (cc *http2ClientConn) CanTakeNewRequest() bool { } func (cc *http2ClientConn) canTakeNewRequestLocked() bool { - return cc.goAway == nil && + return cc.goAway == nil && !cc.closed && int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) && cc.nextStreamID < 2147483647 } @@ -4772,6 +4800,14 @@ func http2commaSeparatedTrailers(req *Request) (string, error) { return "", nil } +func (cc *http2ClientConn) responseHeaderTimeout() time.Duration { + if cc.t.t1 != nil { + return cc.t.t1.ResponseHeaderTimeout + } + + return 0 +} + func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { trailers, err := http2commaSeparatedTrailers(req) if err != nil { @@ -4832,24 +4868,32 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { return nil, werr } + var respHeaderTimer <-chan time.Time var bodyCopyErrc chan error // result of body copy if hasBody { bodyCopyErrc = make(chan error, 1) go func() { bodyCopyErrc <- cs.writeRequestBody(body, req.Body) }() + } else { + if d := cc.responseHeaderTimeout(); d != 0 { + timer := time.NewTimer(d) + defer timer.Stop() + respHeaderTimer = timer.C + } } readLoopResCh := cs.resc requestCanceledCh := http2requestCancel(req) - requestCanceled := false + bodyWritten := false + for { select { case re := <-readLoopResCh: res := re.res if re.err != nil || res.StatusCode > 299 { - cs.abortRequestBodyWrite() + cs.abortRequestBodyWrite(http2errStopReqBodyWrite) } if re.err != nil { cc.forgetStreamID(cs.ID) @@ -4858,32 +4902,35 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { res.Request = req res.TLS = cc.tlsState return res, nil + case <-respHeaderTimer: + cc.forgetStreamID(cs.ID) + if !hasBody || bodyWritten { + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + } else { + cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) + } + return nil, http2errTimeout case <-requestCanceledCh: cc.forgetStreamID(cs.ID) - cs.abortRequestBodyWrite() - if !hasBody { + if !hasBody || bodyWritten { cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) - return nil, http2errRequestCanceled + } else { + cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) } - - requestCanceled = true - requestCanceledCh = nil - readLoopResCh = nil + return nil, http2errRequestCanceled case <-cs.peerReset: - if requestCanceled { - - return nil, http2errRequestCanceled - } return nil, cs.resetErr case err := <-bodyCopyErrc: - if requestCanceled { - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) - return nil, http2errRequestCanceled - } if err != nil { return nil, err } + bodyWritten = true + if d := cc.responseHeaderTimeout(); d != 0 { + timer := time.NewTimer(d) + defer timer.Stop() + respHeaderTimer = timer.C + } } } } @@ -4916,9 +4963,14 @@ func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs [] return cc.werr } -// errAbortReqBodyWrite is an internal error value. -// It doesn't escape to callers. -var http2errAbortReqBodyWrite = errors.New("http2: aborting request body write") +// internal error values; they don't escape to callers +var ( + // abort request body write; don't send cancel + http2errStopReqBodyWrite = errors.New("http2: aborting request body write") + + // abort request body write, but send stream reset of cancel. + http2errStopReqBodyWriteAndCancel = errors.New("http2: canceling request") +) func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) { cc := cs.cc @@ -4951,7 +5003,13 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos for len(remain) > 0 && err == nil { var allowed int32 allowed, err = cs.awaitFlowControl(len(remain)) - if err != nil { + switch { + case err == http2errStopReqBodyWrite: + return err + case err == http2errStopReqBodyWriteAndCancel: + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + return err + case err != nil: return err } cc.wmu.Lock() @@ -5005,8 +5063,8 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er if cc.closed { return 0, http2errClientConnClosed } - if cs.stopReqBody { - return 0, http2errAbortReqBodyWrite + if cs.stopReqBody != nil { + return 0, cs.stopReqBody } if err := cs.checkReset(); err != nil { return 0, err @@ -5074,7 +5132,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail cc.writeHeader(lowKey, v) } } - if contentLength >= 0 { + if http2shouldSendReqContentLength(req.Method, contentLength) { cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10)) } if addGzipHeader { @@ -5086,6 +5144,27 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail return cc.hbuf.Bytes() } +// shouldSendReqContentLength reports whether the http2.Transport should send +// a "content-length" request header. This logic is basically a copy of the net/http +// transferWriter.shouldSendContentLength. +// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown). +// -1 means unknown. +func http2shouldSendReqContentLength(method string, contentLength int64) bool { + if contentLength > 0 { + return true + } + if contentLength < 0 { + return false + } + + switch method { + case "POST", "PUT", "PATCH": + return true + default: + return false + } +} + // requires cc.mu be held. func (cc *http2ClientConn) encodeTrailers(req *Request) []byte { cc.hbuf.Reset() @@ -5204,6 +5283,8 @@ func (rl *http2clientConnReadLoop) cleanup() { func (rl *http2clientConnReadLoop) run() error { cc := rl.cc + closeWhenIdle := cc.t.disableKeepAlives() + gotReply := false for { f, err := cc.fr.ReadFrame() if err != nil { @@ -5218,18 +5299,25 @@ func (rl *http2clientConnReadLoop) run() error { if http2VerboseLogs { cc.vlogf("http2: Transport received %s", http2summarizeFrame(f)) } + maybeIdle := false switch f := f.(type) { case *http2HeadersFrame: err = rl.processHeaders(f) + maybeIdle = true + gotReply = true case *http2ContinuationFrame: err = rl.processContinuation(f) + maybeIdle = true case *http2DataFrame: err = rl.processData(f) + maybeIdle = true case *http2GoAwayFrame: err = rl.processGoAway(f) + maybeIdle = true case *http2RSTStreamFrame: err = rl.processResetStream(f) + maybeIdle = true case *http2SettingsFrame: err = rl.processSettings(f) case *http2PushPromiseFrame: @@ -5244,6 +5332,9 @@ func (rl *http2clientConnReadLoop) run() error { if err != nil { return err } + if closeWhenIdle && gotReply && maybeIdle && len(rl.activeRes) == 0 { + cc.closeIfIdle() + } } } -- 2.48.1