From: Brad Fitzpatrick Date: Wed, 12 Dec 2012 19:09:55 +0000 (-0800) Subject: net/http: follow certain redirects after POST requests X-Git-Tag: go1.1rc2~1651 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=08ce7f1d5c309cf4187fbb1e442ee4388c29212b;p=gostls13.git net/http: follow certain redirects after POST requests Fixes #4145 R=golang-dev, rsc CC=golang-dev https://golang.org/cl/6923055 --- diff --git a/src/pkg/net/http/client.go b/src/pkg/net/http/client.go index 2f957d23db..2b28b77d1b 100644 --- a/src/pkg/net/http/client.go +++ b/src/pkg/net/http/client.go @@ -120,7 +120,10 @@ func (c *Client) send(req *Request) (*Response, error) { // Generally Get, Post, or PostForm will be used instead of Do. func (c *Client) Do(req *Request) (resp *Response, err error) { if req.Method == "GET" || req.Method == "HEAD" { - return c.doFollowingRedirects(req) + return c.doFollowingRedirects(req, shouldRedirectGet) + } + if req.Method == "POST" || req.Method == "PUT" { + return c.doFollowingRedirects(req, shouldRedirectPost) } return c.send(req) } @@ -166,7 +169,7 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) { // True if the specified HTTP status code is one for which the Get utility should // automatically redirect. -func shouldRedirect(statusCode int) bool { +func shouldRedirectGet(statusCode int) bool { switch statusCode { case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect: return true @@ -174,6 +177,16 @@ func shouldRedirect(statusCode int) bool { return false } +// True if the specified HTTP status code is one for which the Post utility should +// automatically redirect. +func shouldRedirectPost(statusCode int) bool { + switch statusCode { + case StatusFound, StatusSeeOther: + return true + } + return false +} + // Get issues a GET to the specified URL. If the response is one of the following // redirect codes, Get follows the redirect, up to a maximum of 10 redirects: // @@ -214,10 +227,10 @@ func (c *Client) Get(url string) (resp *Response, err error) { if err != nil { return nil, err } - return c.doFollowingRedirects(req) + return c.doFollowingRedirects(req, shouldRedirectGet) } -func (c *Client) doFollowingRedirects(ireq *Request) (resp *Response, err error) { +func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bool) (resp *Response, err error) { // TODO: if/when we add cookie support, the redirected request shouldn't // necessarily supply the same cookies as the original. var base *url.URL @@ -238,6 +251,9 @@ func (c *Client) doFollowingRedirects(ireq *Request) (resp *Response, err error) if redirect != 0 { req = new(Request) req.Method = ireq.Method + if ireq.Method == "POST" || ireq.Method == "PUT" { + req.Method = "GET" + } req.Header = make(Header) req.URL, err = base.Parse(urlStr) if err != nil { @@ -321,7 +337,7 @@ func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Respon return nil, err } req.Header.Set("Content-Type", bodyType) - return c.send(req) + return c.doFollowingRedirects(req, shouldRedirectPost) } // PostForm issues a POST to the specified URL, with data's keys and @@ -371,5 +387,5 @@ func (c *Client) Head(url string) (resp *Response, err error) { if err != nil { return nil, err } - return c.doFollowingRedirects(req) + return c.doFollowingRedirects(req, shouldRedirectGet) } diff --git a/src/pkg/net/http/client_test.go b/src/pkg/net/http/client_test.go index f4ba6a9e65..4bb336f1a9 100644 --- a/src/pkg/net/http/client_test.go +++ b/src/pkg/net/http/client_test.go @@ -7,6 +7,7 @@ package http_test import ( + "bytes" "crypto/tls" "crypto/x509" "errors" @@ -246,6 +247,52 @@ func TestRedirects(t *testing.T) { } } +func TestPostRedirects(t *testing.T) { + var log struct { + sync.Mutex + bytes.Buffer + } + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + log.Lock() + fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI) + log.Unlock() + if v := r.URL.Query().Get("code"); v != "" { + code, _ := strconv.Atoi(v) + if code/100 == 3 { + w.Header().Set("Location", ts.URL) + } + w.WriteHeader(code) + } + })) + tests := []struct { + suffix string + want int // response code + }{ + {"/", 200}, + {"/?code=301", 301}, + {"/?code=302", 200}, + {"/?code=303", 200}, + {"/?code=404", 404}, + } + for _, tt := range tests { + res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content")) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != tt.want { + t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want) + } + } + log.Lock() + got := log.String() + log.Unlock() + want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 " + if got != want { + t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want) + } +} + var expectedCookies = []*Cookie{ {Name: "ChocolateChip", Value: "tasty"}, {Name: "First", Value: "Hit"},