]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: rewind request body unconditionally
authorAleksandr Razumov <ar@cydev.ru>
Tue, 28 Aug 2018 00:29:01 +0000 (03:29 +0300)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 2 Oct 2018 21:11:23 +0000 (21:11 +0000)
When http2 fails with ErrNoCachedConn the request is retried with body
that has already been read.

Fixes #25009

Change-Id: I51ed5c8cf469dd8b17c73fff6140ab80162bf267
Reviewed-on: https://go-review.googlesource.com/c/131755
Run-TryBot: Iskander Sharipov <iskander.sharipov@intel.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/http/transport.go
src/net/http/transport_internal_test.go

index 7f8fd505bd9ddd8797e6d22c9558f8162ff4b5a6..e6493036e8cfdca96d359d86a04fa56c8decf23b 100644 (file)
@@ -478,9 +478,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
                }
                testHookRoundTripRetried()
 
-               // Rewind the body if we're able to.  (HTTP/2 does this itself so we only
-               // need to do it for HTTP/1.1 connections.)
-               if req.GetBody != nil && pconn.alt == nil {
+               // Rewind the body if we're able to.
+               if req.GetBody != nil {
                        newReq := *req
                        var err error
                        newReq.Body, err = req.GetBody()
index a5f29c97a9087cc2b7716b50ee247e178e5ef666..92729e65b26471f133eb91288f89241b0738bc9a 100644 (file)
@@ -7,8 +7,13 @@
 package http
 
 import (
+       "bytes"
+       "crypto/tls"
        "errors"
+       "io"
+       "io/ioutil"
        "net"
+       "net/http/internal"
        "strings"
        "testing"
 )
@@ -178,3 +183,81 @@ func TestTransportShouldRetryRequest(t *testing.T) {
                }
        }
 }
+
+type roundTripFunc func(r *Request) (*Response, error)
+
+func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
+       return f(r)
+}
+
+// Issue 25009
+func TestTransportBodyAltRewind(t *testing.T) {
+       cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
+       if err != nil {
+               t.Fatal(err)
+       }
+       ln := newLocalListener(t)
+       defer ln.Close()
+
+       go func() {
+               tln := tls.NewListener(ln, &tls.Config{
+                       NextProtos:   []string{"foo"},
+                       Certificates: []tls.Certificate{cert},
+               })
+               for i := 0; i < 2; i++ {
+                       sc, err := tln.Accept()
+                       if err != nil {
+                               t.Error(err)
+                               return
+                       }
+                       if err := sc.(*tls.Conn).Handshake(); err != nil {
+                               t.Error(err)
+                               return
+                       }
+                       sc.Close()
+               }
+       }()
+
+       addr := ln.Addr().String()
+       req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request"))
+       roundTripped := false
+       tr := &Transport{
+               DisableKeepAlives: true,
+               TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
+                       "foo": func(authority string, c *tls.Conn) RoundTripper {
+                               return roundTripFunc(func(r *Request) (*Response, error) {
+                                       n, _ := io.Copy(ioutil.Discard, r.Body)
+                                       if n == 0 {
+                                               t.Error("body length is zero")
+                                       }
+                                       if roundTripped {
+                                               return &Response{
+                                                       Body:       NoBody,
+                                                       StatusCode: 200,
+                                               }, nil
+                                       }
+                                       roundTripped = true
+                                       return nil, http2noCachedConnError{}
+                               })
+                       },
+               },
+               DialTLS: func(_, _ string) (net.Conn, error) {
+                       tc, err := tls.Dial("tcp", addr, &tls.Config{
+                               InsecureSkipVerify: true,
+                               NextProtos:         []string{"foo"},
+                       })
+                       if err != nil {
+                               return nil, err
+                       }
+                       if err := tc.Handshake(); err != nil {
+                               return nil, err
+                       }
+                       return tc, nil
+               },
+       }
+       c := &Client{Transport: tr}
+       _, err = c.Do(req)
+       if err != nil {
+               t.Error(err)
+       }
+}