]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Request.GetBody func for 307/308 redirects
authorBrad Fitzpatrick <bradfitz@golang.org>
Fri, 21 Oct 2016 11:03:41 +0000 (12:03 +0100)
committerBrad Fitzpatrick <bradfitz@golang.org>
Sat, 22 Oct 2016 01:22:56 +0000 (01:22 +0000)
Updates #10767

Change-Id: I197535f71bc2dc45e783f38d8031aa717d50fd80
Reviewed-on: https://go-review.googlesource.com/31733
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/http/client.go
src/net/http/request.go
src/net/http/request_test.go

index 39c38bd8dd0ff8ae92968bb37344a7a5b43f445e..9b60f35708c539b56d304c0e84bc6083cb3bba54 100644 (file)
@@ -485,8 +485,15 @@ func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) boo
                                Cancel:   ireq.Cancel,
                                ctx:      ireq.ctx,
                        }
+                       if ireq.GetBody != nil {
+                               req.Body, err = ireq.GetBody()
+                               if err != nil {
+                                       return nil, uerr(err)
+                               }
+                       }
                        if ireq.Method == "POST" || ireq.Method == "PUT" {
                                req.Method = "GET"
+                               req.Body = nil // TODO: fix this when 307/308 support happens
                        }
                        // Copy the initial request's Header values
                        // (at least the safe ones).  Do this before
index 83d6c81de9c4227aeb1c3219d48b907b4c0b9283..551310cab0464f7f612b34c0248e93dddcb867bb 100644 (file)
@@ -151,6 +151,14 @@ type Request struct {
        // Handler does not need to.
        Body io.ReadCloser
 
+       // GetBody defines an optional func to return a new copy of
+       // Body. It used for client requests when a redirect requires
+       // reading the body more than once. Use of GetBody still
+       // requires setting Body.
+       //
+       // For server requests it is unused.
+       GetBody func() (io.ReadCloser, error)
+
        // ContentLength records the length of the associated content.
        // The value -1 indicates that the length is unknown.
        // Values >= 0 indicate that the given number of bytes may
@@ -738,10 +746,25 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
                switch v := body.(type) {
                case *bytes.Buffer:
                        req.ContentLength = int64(v.Len())
+                       buf := v.Bytes()
+                       req.GetBody = func() (io.ReadCloser, error) {
+                               r := bytes.NewReader(buf)
+                               return ioutil.NopCloser(r), nil
+                       }
                case *bytes.Reader:
                        req.ContentLength = int64(v.Len())
+                       snapshot := *v
+                       req.GetBody = func() (io.ReadCloser, error) {
+                               r := snapshot
+                               return ioutil.NopCloser(&r), nil
+                       }
                case *strings.Reader:
                        req.ContentLength = int64(v.Len())
+                       snapshot := *v
+                       req.GetBody = func() (io.ReadCloser, error) {
+                               r := snapshot
+                               return ioutil.NopCloser(&r), nil
+                       }
                default:
                        req.ContentLength = -1 // unknown
                }
@@ -751,6 +774,7 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
                // to set the Body to nil.
                if req.ContentLength == 0 {
                        req.Body = nil
+                       req.GetBody = nil
                }
        }
 
index f12b41cf1bc10a0529c2cad0c86050b4a6deeba8..e463d79492fb88c5145ffa72c528d7406b08c53c 100644 (file)
@@ -784,6 +784,47 @@ func TestMaxBytesReaderStickyError(t *testing.T) {
        }
 }
 
+// verify that NewRequest sets Request.GetBody and that it works
+func TestNewRequestGetBody(t *testing.T) {
+       tests := []struct {
+               r io.Reader
+       }{
+               {r: strings.NewReader("hello")},
+               {r: bytes.NewReader([]byte("hello"))},
+               {r: bytes.NewBuffer([]byte("hello"))},
+       }
+       for i, tt := range tests {
+               req, err := NewRequest("POST", "http://foo.tld/", tt.r)
+               if err != nil {
+                       t.Errorf("test[%d]: %v", i, err)
+                       continue
+               }
+               if req.Body == nil {
+                       t.Errorf("test[%d]: Body = nil", i)
+                       continue
+               }
+               if req.GetBody == nil {
+                       t.Errorf("test[%d]: GetBody = nil", i)
+                       continue
+               }
+               slurp1, err := ioutil.ReadAll(req.Body)
+               if err != nil {
+                       t.Errorf("test[%d]: ReadAll(Body) = %v", i, err)
+               }
+               newBody, err := req.GetBody()
+               if err != nil {
+                       t.Errorf("test[%d]: GetBody = %v", i, err)
+               }
+               slurp2, err := ioutil.ReadAll(newBody)
+               if err != nil {
+                       t.Errorf("test[%d]: ReadAll(GetBody()) = %v", i, err)
+               }
+               if string(slurp1) != string(slurp2) {
+                       t.Errorf("test[%d]: Body %q != GetBody %q", i, slurp1, slurp2)
+               }
+       }
+}
+
 func testMissingFile(t *testing.T, req *Request) {
        f, fh, err := req.FormFile("missing")
        if f != nil {