From 5b588e668276b74b69be41fccbbc6135606d7d69 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 18 Jan 2016 14:50:52 -0800 Subject: [PATCH] net/http: make http2 Transport send Content Length Updates x/net/http2 to git rev 5c0dae8 for https://golang.org/cl/18709 Fixes #14003 Change-Id: I8bc205d6d089107b017e3458bbc7e05f6d0cae60 Reviewed-on: https://go-review.googlesource.com/18730 Reviewed-by: Andrew Gerrand Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- src/net/http/clientserver_test.go | 29 ++++++++++++++++++++++ src/net/http/h2_bundle.go | 40 +++++++++++++++++++++++++------ 2 files changed, 62 insertions(+), 7 deletions(-) diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index 3a601d304b..573ed93c05 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -421,6 +421,35 @@ func TestH12_ServerEmptyContentLength(t *testing.T) { }.run(t) } +func TestH12_RequestContentLength_Known_NonZero(t *testing.T) { + h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4) +} + +func TestH12_RequestContentLength_Known_Zero(t *testing.T) { + h12requestContentLength(t, func() io.Reader { return strings.NewReader("") }, 0) +} + +func TestH12_RequestContentLength_Unknown(t *testing.T) { + h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1) +} + +func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) { + h12Compare{ + Handler: func(w ResponseWriter, r *Request) { + w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength)) + fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength) + }, + ReqFunc: func(c *Client, url string) (*Response, error) { + return c.Post(url, "text/plain", bodyfn()) + }, + CheckResponse: func(proto string, res *Response) { + if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want { + t.Errorf("Proto %q got length %q; want %q", proto, got, want) + } + }, + }.run(t) +} + // Tests that closing the Request.Cancel channel also while still // reading the response body. Issue 13159. func TestCancelRequestMidBody_h1(t *testing.T) { testCancelRequestMidBody(t, h1Mode) } diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go index d40fabd021..42f0ac1c69 100644 --- a/src/net/http/h2_bundle.go +++ b/src/net/http/h2_bundle.go @@ -4779,6 +4779,25 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { } hasTrailers := trailers != "" + var body io.Reader = req.Body + contentLen := req.ContentLength + if req.Body != nil && contentLen == 0 { + // Test to see if it's actually zero or just unset. + var buf [1]byte + n, rerr := io.ReadFull(body, buf[:]) + if rerr != nil && rerr != io.EOF { + contentLen = -1 + body = http2errorReader{rerr} + } else if n == 1 { + + contentLen = -1 + body = io.MultiReader(bytes.NewReader(buf[:]), body) + } else { + + body = nil + } + } + cc.mu.Lock() if cc.closed || !cc.canTakeNewRequestLocked() { cc.mu.Unlock() @@ -4787,7 +4806,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { cs := cc.newStream() cs.req = req - hasBody := req.Body != nil + hasBody := body != nil if !cc.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && @@ -4797,7 +4816,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { cs.requestedGzip = true } - hdrs := cc.encodeHeaders(req, cs.requestedGzip, trailers) + hdrs := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen) cc.wmu.Lock() endStream := !hasBody && !hasTrailers werr := cc.writeHeaders(cs.ID, endStream, hdrs) @@ -4817,7 +4836,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { if hasBody { bodyCopyErrc = make(chan error, 1) go func() { - bodyCopyErrc <- cs.writeRequestBody(req.Body) + bodyCopyErrc <- cs.writeRequestBody(body, req.Body) }() } @@ -4901,7 +4920,7 @@ func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs [] // It doesn't escape to callers. var http2errAbortReqBodyWrite = errors.New("http2: aborting request body write") -func (cs *http2clientStream) writeRequestBody(body io.ReadCloser) (err error) { +func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) { cc := cs.cc sentEnd := false buf := cc.frameScratchBuffer() @@ -4909,7 +4928,7 @@ func (cs *http2clientStream) writeRequestBody(body io.ReadCloser) (err error) { defer func() { - cerr := body.Close() + cerr := bodyCloser.Close() if err == nil { err = cerr } @@ -5016,7 +5035,7 @@ type http2badStringError struct { func (e *http2badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) } // requires cc.mu be held. -func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string) []byte { +func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string, contentLength int64) []byte { cc.hbuf.Reset() host := req.Host @@ -5037,7 +5056,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail var didUA bool for k, vv := range req.Header { lowKey := strings.ToLower(k) - if lowKey == "host" { + if lowKey == "host" || lowKey == "content-length" { continue } if lowKey == "user-agent" { @@ -5055,6 +5074,9 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail cc.writeHeader(lowKey, v) } } + if contentLength >= 0 { + cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10)) + } if addGzipHeader { cc.writeHeader("accept-encoding", "gzip") } @@ -5745,6 +5767,10 @@ func (gz *http2gzipReader) Close() error { return gz.body.Close() } +type http2errorReader struct{ err error } + +func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err } + // writeFramer is implemented by any type that is used to write frames. type http2writeFramer interface { writeFrame(http2writeContext) error -- 2.48.1