]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: follow certain redirects after POST requests
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 12 Dec 2012 19:09:55 +0000 (11:09 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 12 Dec 2012 19:09:55 +0000 (11:09 -0800)
Fixes #4145

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/6923055

src/pkg/net/http/client.go
src/pkg/net/http/client_test.go

index 2f957d23dbe7b27e94d5da191dbd67afcdc69767..2b28b77d1bebae38bb81e090b853c0b66fdf394e 100644 (file)
@@ -120,7 +120,10 @@ func (c *Client) send(req *Request) (*Response, error) {
 // Generally Get, Post, or PostForm will be used instead of Do.
 func (c *Client) Do(req *Request) (resp *Response, err error) {
        if req.Method == "GET" || req.Method == "HEAD" {
-               return c.doFollowingRedirects(req)
+               return c.doFollowingRedirects(req, shouldRedirectGet)
+       }
+       if req.Method == "POST" || req.Method == "PUT" {
+               return c.doFollowingRedirects(req, shouldRedirectPost)
        }
        return c.send(req)
 }
@@ -166,7 +169,7 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) {
 
 // True if the specified HTTP status code is one for which the Get utility should
 // automatically redirect.
-func shouldRedirect(statusCode int) bool {
+func shouldRedirectGet(statusCode int) bool {
        switch statusCode {
        case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect:
                return true
@@ -174,6 +177,16 @@ func shouldRedirect(statusCode int) bool {
        return false
 }
 
+// True if the specified HTTP status code is one for which the Post utility should
+// automatically redirect.
+func shouldRedirectPost(statusCode int) bool {
+       switch statusCode {
+       case StatusFound, StatusSeeOther:
+               return true
+       }
+       return false
+}
+
 // Get issues a GET to the specified URL.  If the response is one of the following
 // redirect codes, Get follows the redirect, up to a maximum of 10 redirects:
 //
@@ -214,10 +227,10 @@ func (c *Client) Get(url string) (resp *Response, err error) {
        if err != nil {
                return nil, err
        }
-       return c.doFollowingRedirects(req)
+       return c.doFollowingRedirects(req, shouldRedirectGet)
 }
 
-func (c *Client) doFollowingRedirects(ireq *Request) (resp *Response, err error) {
+func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bool) (resp *Response, err error) {
        // TODO: if/when we add cookie support, the redirected request shouldn't
        // necessarily supply the same cookies as the original.
        var base *url.URL
@@ -238,6 +251,9 @@ func (c *Client) doFollowingRedirects(ireq *Request) (resp *Response, err error)
                if redirect != 0 {
                        req = new(Request)
                        req.Method = ireq.Method
+                       if ireq.Method == "POST" || ireq.Method == "PUT" {
+                               req.Method = "GET"
+                       }
                        req.Header = make(Header)
                        req.URL, err = base.Parse(urlStr)
                        if err != nil {
@@ -321,7 +337,7 @@ func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Respon
                return nil, err
        }
        req.Header.Set("Content-Type", bodyType)
-       return c.send(req)
+       return c.doFollowingRedirects(req, shouldRedirectPost)
 }
 
 // PostForm issues a POST to the specified URL, with data's keys and
@@ -371,5 +387,5 @@ func (c *Client) Head(url string) (resp *Response, err error) {
        if err != nil {
                return nil, err
        }
-       return c.doFollowingRedirects(req)
+       return c.doFollowingRedirects(req, shouldRedirectGet)
 }
index f4ba6a9e652ac7d2d611c4c16b8c15f2a811930d..4bb336f1a9b9a2d4f04ea1312495fe29c6038288 100644 (file)
@@ -7,6 +7,7 @@
 package http_test
 
 import (
+       "bytes"
        "crypto/tls"
        "crypto/x509"
        "errors"
@@ -246,6 +247,52 @@ func TestRedirects(t *testing.T) {
        }
 }
 
+func TestPostRedirects(t *testing.T) {
+       var log struct {
+               sync.Mutex
+               bytes.Buffer
+       }
+       var ts *httptest.Server
+       ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               log.Lock()
+               fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI)
+               log.Unlock()
+               if v := r.URL.Query().Get("code"); v != "" {
+                       code, _ := strconv.Atoi(v)
+                       if code/100 == 3 {
+                               w.Header().Set("Location", ts.URL)
+                       }
+                       w.WriteHeader(code)
+               }
+       }))
+       tests := []struct {
+               suffix string
+               want   int // response code
+       }{
+               {"/", 200},
+               {"/?code=301", 301},
+               {"/?code=302", 200},
+               {"/?code=303", 200},
+               {"/?code=404", 404},
+       }
+       for _, tt := range tests {
+               res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content"))
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if res.StatusCode != tt.want {
+                       t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want)
+               }
+       }
+       log.Lock()
+       got := log.String()
+       log.Unlock()
+       want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 "
+       if got != want {
+               t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want)
+       }
+}
+
 var expectedCookies = []*Cookie{
        {Name: "ChocolateChip", Value: "tasty"},
        {Name: "First", Value: "Hit"},