]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: make hidden http2 Transport respect remaining Transport fields
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 19 Jan 2016 05:10:58 +0000 (05:10 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 19 Jan 2016 05:31:38 +0000 (05:31 +0000)
Updates x/net/http2 to git rev 72aa00c6 for https://golang.org/cl/18721
(but actually at https://golang.org/cl/18722 now)

Fixes #14008

Change-Id: If05d5ad51ec0ba5ba7e4fe16605c0a83f0484bc8
Reviewed-on: https://go-review.googlesource.com/18723
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Andrew Gerrand <adg@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/http/clientserver_test.go
src/net/http/h2_bundle.go

index 573ed93c0584d03f4542b8692d49b2e7216ca834..9c1aa7920e347c5b0815a8fea5f631ca6dd1e19a 100644 (file)
@@ -47,7 +47,7 @@ const (
        h2Mode = true
 )
 
-func newClientServerTest(t *testing.T, h2 bool, h Handler) *clientServerTest {
+func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest {
        cst := &clientServerTest{
                t:  t,
                h2: h2,
@@ -55,6 +55,16 @@ func newClientServerTest(t *testing.T, h2 bool, h Handler) *clientServerTest {
                tr: &Transport{},
        }
        cst.c = &Client{Transport: cst.tr}
+
+       for _, opt := range opts {
+               switch opt := opt.(type) {
+               case func(*Transport):
+                       opt(cst.tr)
+               default:
+                       t.Fatalf("unhandled option type %T", opt)
+               }
+       }
+
        if !h2 {
                cst.ts = httptest.NewServer(h)
                return cst
@@ -139,6 +149,7 @@ type h12Compare struct {
        Handler       func(ResponseWriter, *Request)    // required
        ReqFunc       reqFunc                           // optional
        CheckResponse func(proto string, res *Response) // optional
+       Opts          []interface{}
 }
 
 func (tt h12Compare) reqFunc() reqFunc {
@@ -149,9 +160,9 @@ func (tt h12Compare) reqFunc() reqFunc {
 }
 
 func (tt h12Compare) run(t *testing.T) {
-       cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler))
+       cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...)
        defer cst1.close()
-       cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler))
+       cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...)
        defer cst2.close()
 
        res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
@@ -380,6 +391,20 @@ func TestH12_AutoGzip(t *testing.T) {
        }.run(t)
 }
 
+func TestH12_AutoGzip_Disabled(t *testing.T) {
+       h12Compare{
+               Opts: []interface{}{
+                       func(tr *Transport) { tr.DisableCompression = true },
+               },
+               Handler: func(w ResponseWriter, r *Request) {
+                       fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
+                       if ae := r.Header.Get("Accept-Encoding"); ae != "" {
+                               t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
+                       }
+               },
+       }.run(t)
+}
+
 // Test304Responses verifies that 304s don't declare that they're
 // chunking in their response headers and aren't allowed to produce
 // output.
index 42f0ac1c69dfc51dc81539f9c9b90cde616a056f..cd530f16cd0c6e85f812af6e842e5fed03ec7de0 100644 (file)
@@ -24,7 +24,6 @@ import (
        "encoding/binary"
        "errors"
        "fmt"
-       "golang.org/x/net/http2/hpack"
        "io"
        "io/ioutil"
        "log"
@@ -38,6 +37,8 @@ import (
        "strings"
        "sync"
        "time"
+
+       "golang.org/x/net/http2/hpack"
 )
 
 // ClientConnPool manages a pool of HTTP/2 client connections.
@@ -248,7 +249,11 @@ func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) [
 
 func http2configureTransport(t1 *Transport) (*http2Transport, error) {
        connPool := new(http2clientConnPool)
-       t2 := &http2Transport{ConnPool: http2noDialClientConnPool{connPool}}
+       t2 := &http2Transport{
+               ConnPool: http2noDialClientConnPool{connPool},
+               t1:       t1,
+       }
+       connPool.t = t2
        if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil {
                return nil, err
        }
@@ -2184,6 +2189,19 @@ func http2bodyAllowedForStatus(status int) bool {
        return true
 }
 
+type http2httpError struct {
+       msg     string
+       timeout bool
+}
+
+func (e *http2httpError) Error() string { return e.msg }
+
+func (e *http2httpError) Timeout() bool { return e.timeout }
+
+func (e *http2httpError) Temporary() bool { return true }
+
+var http2errTimeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true}
+
 // pipe is a goroutine-safe io.Reader/io.Writer pair.  It's like
 // io.Pipe except there are no PipeReader/PipeWriter halves, and the
 // underlying buffer is an interface. (io.Pipe is always unbuffered)
@@ -4320,6 +4338,11 @@ type http2Transport struct {
        // to mean no limit.
        MaxHeaderListSize uint32
 
+       // t1, if non-nil, is the standard library Transport using
+       // this transport. Its settings are used (but not its
+       // RoundTrip method, etc).
+       t1 *Transport
+
        connPoolOnce  sync.Once
        connPoolOrDef http2ClientConnPool // non-nil version of ConnPool
 }
@@ -4335,11 +4358,7 @@ func (t *http2Transport) maxHeaderListSize() uint32 {
 }
 
 func (t *http2Transport) disableCompression() bool {
-       if t.DisableCompression {
-               return true
-       }
-
-       return false
+       return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
 }
 
 var http2errTransportVersion = errors.New("http2: ConfigureTransport is only supported starting at Go 1.6")
@@ -4395,7 +4414,7 @@ type http2ClientConn struct {
        henc                 *hpack.Encoder
        freeBuf              [][]byte
 
-       wmu  sync.Mutex // held while writing; acquire AFTER wmu if holding both
+       wmu  sync.Mutex // held while writing; acquire AFTER mu if holding both
        werr error      // first write error that has occurred
 }
 
@@ -4413,7 +4432,7 @@ type http2clientStream struct {
        inflow      http2flow // guarded by cc.mu
        bytesRemain int64     // -1 means unknown; owned by transportResponseBody.Read
        readErr     error     // sticky read error; owned by transportResponseBody.Read
-       stopReqBody bool      // stop writing req body; guarded by cc.mu
+       stopReqBody error     // if non-nil, stop writing req body; guarded by cc.mu
 
        peerReset chan struct{} // closed on peer reset
        resetErr  error         // populated before peerReset is closed
@@ -4456,10 +4475,13 @@ func (cs *http2clientStream) checkReset() error {
        }
 }
 
-func (cs *http2clientStream) abortRequestBodyWrite() {
+func (cs *http2clientStream) abortRequestBodyWrite(err error) {
+       if err == nil {
+               panic("nil error")
+       }
        cc := cs.cc
        cc.mu.Lock()
-       cs.stopReqBody = true
+       cs.stopReqBody = err
        cc.cond.Broadcast()
        cc.mu.Unlock()
 }
@@ -4598,6 +4620,12 @@ func (t *http2Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (
        return cn, nil
 }
 
+// disableKeepAlives reports whether connections should be closed as
+// soon as possible after handling the first request.
+func (t *http2Transport) disableKeepAlives() bool {
+       return t.t1 != nil && t.t1.DisableKeepAlives
+}
+
 func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
        if http2VerboseLogs {
                t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr())
@@ -4692,7 +4720,7 @@ func (cc *http2ClientConn) CanTakeNewRequest() bool {
 }
 
 func (cc *http2ClientConn) canTakeNewRequestLocked() bool {
-       return cc.goAway == nil &&
+       return cc.goAway == nil && !cc.closed &&
                int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
                cc.nextStreamID < 2147483647
 }
@@ -4772,6 +4800,14 @@ func http2commaSeparatedTrailers(req *Request) (string, error) {
        return "", nil
 }
 
+func (cc *http2ClientConn) responseHeaderTimeout() time.Duration {
+       if cc.t.t1 != nil {
+               return cc.t.t1.ResponseHeaderTimeout
+       }
+
+       return 0
+}
+
 func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
        trailers, err := http2commaSeparatedTrailers(req)
        if err != nil {
@@ -4832,24 +4868,32 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
                return nil, werr
        }
 
+       var respHeaderTimer <-chan time.Time
        var bodyCopyErrc chan error // result of body copy
        if hasBody {
                bodyCopyErrc = make(chan error, 1)
                go func() {
                        bodyCopyErrc <- cs.writeRequestBody(body, req.Body)
                }()
+       } else {
+               if d := cc.responseHeaderTimeout(); d != 0 {
+                       timer := time.NewTimer(d)
+                       defer timer.Stop()
+                       respHeaderTimer = timer.C
+               }
        }
 
        readLoopResCh := cs.resc
        requestCanceledCh := http2requestCancel(req)
-       requestCanceled := false
+       bodyWritten := false
+
        for {
                select {
                case re := <-readLoopResCh:
                        res := re.res
                        if re.err != nil || res.StatusCode > 299 {
 
-                               cs.abortRequestBodyWrite()
+                               cs.abortRequestBodyWrite(http2errStopReqBodyWrite)
                        }
                        if re.err != nil {
                                cc.forgetStreamID(cs.ID)
@@ -4858,32 +4902,35 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
                        res.Request = req
                        res.TLS = cc.tlsState
                        return res, nil
+               case <-respHeaderTimer:
+                       cc.forgetStreamID(cs.ID)
+                       if !hasBody || bodyWritten {
+                               cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
+                       } else {
+                               cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel)
+                       }
+                       return nil, http2errTimeout
                case <-requestCanceledCh:
                        cc.forgetStreamID(cs.ID)
-                       cs.abortRequestBodyWrite()
-                       if !hasBody {
+                       if !hasBody || bodyWritten {
                                cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
-                               return nil, http2errRequestCanceled
+                       } else {
+                               cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel)
                        }
-
-                       requestCanceled = true
-                       requestCanceledCh = nil
-                       readLoopResCh = nil
+                       return nil, http2errRequestCanceled
                case <-cs.peerReset:
-                       if requestCanceled {
-
-                               return nil, http2errRequestCanceled
-                       }
 
                        return nil, cs.resetErr
                case err := <-bodyCopyErrc:
-                       if requestCanceled {
-                               cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
-                               return nil, http2errRequestCanceled
-                       }
                        if err != nil {
                                return nil, err
                        }
+                       bodyWritten = true
+                       if d := cc.responseHeaderTimeout(); d != 0 {
+                               timer := time.NewTimer(d)
+                               defer timer.Stop()
+                               respHeaderTimer = timer.C
+                       }
                }
        }
 }
@@ -4916,9 +4963,14 @@ func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []
        return cc.werr
 }
 
-// errAbortReqBodyWrite is an internal error value.
-// It doesn't escape to callers.
-var http2errAbortReqBodyWrite = errors.New("http2: aborting request body write")
+// internal error values; they don't escape to callers
+var (
+       // abort request body write; don't send cancel
+       http2errStopReqBodyWrite = errors.New("http2: aborting request body write")
+
+       // abort request body write, but send stream reset of cancel.
+       http2errStopReqBodyWriteAndCancel = errors.New("http2: canceling request")
+)
 
 func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) {
        cc := cs.cc
@@ -4951,7 +5003,13 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos
                for len(remain) > 0 && err == nil {
                        var allowed int32
                        allowed, err = cs.awaitFlowControl(len(remain))
-                       if err != nil {
+                       switch {
+                       case err == http2errStopReqBodyWrite:
+                               return err
+                       case err == http2errStopReqBodyWriteAndCancel:
+                               cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
+                               return err
+                       case err != nil:
                                return err
                        }
                        cc.wmu.Lock()
@@ -5005,8 +5063,8 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er
                if cc.closed {
                        return 0, http2errClientConnClosed
                }
-               if cs.stopReqBody {
-                       return 0, http2errAbortReqBodyWrite
+               if cs.stopReqBody != nil {
+                       return 0, cs.stopReqBody
                }
                if err := cs.checkReset(); err != nil {
                        return 0, err
@@ -5074,7 +5132,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail
                        cc.writeHeader(lowKey, v)
                }
        }
-       if contentLength >= 0 {
+       if http2shouldSendReqContentLength(req.Method, contentLength) {
                cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
        }
        if addGzipHeader {
@@ -5086,6 +5144,27 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail
        return cc.hbuf.Bytes()
 }
 
+// shouldSendReqContentLength reports whether the http2.Transport should send
+// a "content-length" request header. This logic is basically a copy of the net/http
+// transferWriter.shouldSendContentLength.
+// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
+// -1 means unknown.
+func http2shouldSendReqContentLength(method string, contentLength int64) bool {
+       if contentLength > 0 {
+               return true
+       }
+       if contentLength < 0 {
+               return false
+       }
+
+       switch method {
+       case "POST", "PUT", "PATCH":
+               return true
+       default:
+               return false
+       }
+}
+
 // requires cc.mu be held.
 func (cc *http2ClientConn) encodeTrailers(req *Request) []byte {
        cc.hbuf.Reset()
@@ -5204,6 +5283,8 @@ func (rl *http2clientConnReadLoop) cleanup() {
 
 func (rl *http2clientConnReadLoop) run() error {
        cc := rl.cc
+       closeWhenIdle := cc.t.disableKeepAlives()
+       gotReply := false
        for {
                f, err := cc.fr.ReadFrame()
                if err != nil {
@@ -5218,18 +5299,25 @@ func (rl *http2clientConnReadLoop) run() error {
                if http2VerboseLogs {
                        cc.vlogf("http2: Transport received %s", http2summarizeFrame(f))
                }
+               maybeIdle := false
 
                switch f := f.(type) {
                case *http2HeadersFrame:
                        err = rl.processHeaders(f)
+                       maybeIdle = true
+                       gotReply = true
                case *http2ContinuationFrame:
                        err = rl.processContinuation(f)
+                       maybeIdle = true
                case *http2DataFrame:
                        err = rl.processData(f)
+                       maybeIdle = true
                case *http2GoAwayFrame:
                        err = rl.processGoAway(f)
+                       maybeIdle = true
                case *http2RSTStreamFrame:
                        err = rl.processResetStream(f)
+                       maybeIdle = true
                case *http2SettingsFrame:
                        err = rl.processSettings(f)
                case *http2PushPromiseFrame:
@@ -5244,6 +5332,9 @@ func (rl *http2clientConnReadLoop) run() error {
                if err != nil {
                        return err
                }
+               if closeWhenIdle && gotReply && maybeIdle && len(rl.activeRes) == 0 {
+                       cc.closeIfIdle()
+               }
        }
 }