From 3aa3c052e302add1d30b0481b0347c47f190bef9 Mon Sep 17 00:00:00 2001 From: Aleksandr Razumov Date: Tue, 28 Aug 2018 03:29:01 +0300 Subject: [PATCH] net/http: rewind request body unconditionally When http2 fails with ErrNoCachedConn the request is retried with body that has already been read. Fixes #25009 Change-Id: I51ed5c8cf469dd8b17c73fff6140ab80162bf267 Reviewed-on: https://go-review.googlesource.com/c/131755 Run-TryBot: Iskander Sharipov TryBot-Result: Gobot Gobot Reviewed-by: Brad Fitzpatrick --- src/net/http/transport.go | 5 +- src/net/http/transport_internal_test.go | 83 +++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 3 deletions(-) diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 7f8fd505bd..e6493036e8 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -478,9 +478,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { } testHookRoundTripRetried() - // Rewind the body if we're able to. (HTTP/2 does this itself so we only - // need to do it for HTTP/1.1 connections.) - if req.GetBody != nil && pconn.alt == nil { + // Rewind the body if we're able to. + if req.GetBody != nil { newReq := *req var err error newReq.Body, err = req.GetBody() diff --git a/src/net/http/transport_internal_test.go b/src/net/http/transport_internal_test.go index a5f29c97a9..92729e65b2 100644 --- a/src/net/http/transport_internal_test.go +++ b/src/net/http/transport_internal_test.go @@ -7,8 +7,13 @@ package http import ( + "bytes" + "crypto/tls" "errors" + "io" + "io/ioutil" "net" + "net/http/internal" "strings" "testing" ) @@ -178,3 +183,81 @@ func TestTransportShouldRetryRequest(t *testing.T) { } } } + +type roundTripFunc func(r *Request) (*Response, error) + +func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { + return f(r) +} + +// Issue 25009 +func TestTransportBodyAltRewind(t *testing.T) { + cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey) + if err != nil { + t.Fatal(err) + } + ln := newLocalListener(t) + defer ln.Close() + + go func() { + tln := tls.NewListener(ln, &tls.Config{ + NextProtos: []string{"foo"}, + Certificates: []tls.Certificate{cert}, + }) + for i := 0; i < 2; i++ { + sc, err := tln.Accept() + if err != nil { + t.Error(err) + return + } + if err := sc.(*tls.Conn).Handshake(); err != nil { + t.Error(err) + return + } + sc.Close() + } + }() + + addr := ln.Addr().String() + req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request")) + roundTripped := false + tr := &Transport{ + DisableKeepAlives: true, + TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{ + "foo": func(authority string, c *tls.Conn) RoundTripper { + return roundTripFunc(func(r *Request) (*Response, error) { + n, _ := io.Copy(ioutil.Discard, r.Body) + if n == 0 { + t.Error("body length is zero") + } + if roundTripped { + return &Response{ + Body: NoBody, + StatusCode: 200, + }, nil + } + roundTripped = true + return nil, http2noCachedConnError{} + }) + }, + }, + DialTLS: func(_, _ string) (net.Conn, error) { + tc, err := tls.Dial("tcp", addr, &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"foo"}, + }) + if err != nil { + return nil, err + } + if err := tc.Handshake(); err != nil { + return nil, err + } + return tc, nil + }, + } + c := &Client{Transport: tr} + _, err = c.Do(req) + if err != nil { + t.Error(err) + } +} -- 2.50.0