]> Cypherpunks repositories - gostls13.git/commitdiff
http: make Client redirect policy configurable
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 26 Apr 2011 05:41:50 +0000 (22:41 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 26 Apr 2011 05:41:50 +0000 (22:41 -0700)
Work on issue 155

R=rsc, bradfitzwork
CC=golang-dev
https://golang.org/cl/4435071

src/pkg/http/client.go
src/pkg/http/client_test.go

index daba3a89b0c21effce0c6f9e4f85c3d97afb6968..d73cbc8550cab60b1c0cb142e1d8d3bb46300658 100644 (file)
@@ -22,6 +22,16 @@ import (
 // Client is not yet very configurable.
 type Client struct {
        Transport RoundTripper // if nil, DefaultTransport is used
+
+       // If CheckRedirect is not nil, the client calls it before
+       // following an HTTP redirect. The arguments req and via
+       // are the upcoming request and the requests made already,
+       // oldest first. If CheckRedirect returns an error, the client
+       // returns that error instead of issue the Request req.
+       //
+       // If CheckRedirect is nil, the Client uses its default policy,
+       // which is to stop after 10 consecutive requests.
+       CheckRedirect func(req *Request, via []*Request) os.Error
 }
 
 // DefaultClient is the default Client and is used by Get, Head, and Post.
@@ -109,7 +119,7 @@ func shouldRedirect(statusCode int) bool {
 }
 
 // Get issues a GET to the specified URL.  If the response is one of the following
-// redirect codes, it follows the redirect, up to a maximum of 10 redirects:
+// redirect codes, Get follows the redirect, up to a maximum of 10 redirects:
 //
 //    301 (Moved Permanently)
 //    302 (Found)
@@ -126,35 +136,33 @@ func Get(url string) (r *Response, finalURL string, err os.Error) {
        return DefaultClient.Get(url)
 }
 
-// Get issues a GET to the specified URL.  If the response is one of the following
-// redirect codes, it follows the redirect, up to a maximum of 10 redirects:
+// Get issues a GET to the specified URL.  If the response is one of the
+// following redirect codes, Get follows the redirect after calling the
+// Client's CheckRedirect function.
 //
 //    301 (Moved Permanently)
 //    302 (Found)
 //    303 (See Other)
 //    307 (Temporary Redirect)
 //
-// finalURL is the URL from which the response was fetched -- identical to the
-// input URL unless redirects were followed.
+// finalURL is the URL from which the response was fetched -- identical
+// to the input URL unless redirects were followed.
 //
 // Caller should close r.Body when done reading from it.
 func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) {
        // TODO: if/when we add cookie support, the redirected request shouldn't
        // necessarily supply the same cookies as the original.
-       // TODO: set referrer header on redirects.
        var base *URL
-       // TODO: remove this hard-coded 10 and use the Client's policy
-       // (ClientConfig) instead.
-       for redirect := 0; ; redirect++ {
-               if redirect >= 10 {
-                       err = os.ErrorString("stopped after 10 redirects")
-                       break
-               }
+       redirectChecker := c.CheckRedirect
+       if redirectChecker == nil {
+               redirectChecker = defaultCheckRedirect
+       }
+       var via []*Request
 
+       for redirect := 0; ; redirect++ {
                var req Request
                req.Method = "GET"
-               req.ProtoMajor = 1
-               req.ProtoMinor = 1
+               req.Header = make(Header)
                if base == nil {
                        req.URL, err = ParseURL(url)
                } else {
@@ -163,6 +171,19 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) {
                if err != nil {
                        break
                }
+               if len(via) > 0 {
+                       // Add the Referer header.
+                       lastReq := via[len(via)-1]
+                       if lastReq.URL.Scheme != "https" {
+                               req.Referer = lastReq.URL.String()
+                       }
+
+                       err = redirectChecker(&req, via)
+                       if err != nil {
+                               break
+                       }
+               }
+
                url = req.URL.String()
                if r, err = send(&req, c.Transport); err != nil {
                        break
@@ -174,6 +195,7 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) {
                                break
                        }
                        base = req.URL
+                       via = append(via, &req)
                        continue
                }
                finalURL = url
@@ -184,6 +206,13 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) {
        return
 }
 
+func defaultCheckRedirect(req *Request, via []*Request) os.Error {
+       if len(via) >= 10 {
+               return os.ErrorString("stopped after 10 redirects")
+       }
+       return nil
+}
+
 // Post issues a POST to the specified URL.
 //
 // Caller should close r.Body when done reading from it.
index 3a6f834253b3fef59b693a8a176df63ddb939ede..59d62c1c9d4cbd885e5299912a5d9b7e6658ab58 100644 (file)
@@ -12,6 +12,7 @@ import (
        "http/httptest"
        "io/ioutil"
        "os"
+       "strconv"
        "strings"
        "testing"
 )
@@ -75,3 +76,51 @@ func TestGetRequestFormat(t *testing.T) {
                t.Errorf("expected non-nil request Header")
        }
 }
+
+func TestRedirects(t *testing.T) {
+       var ts *httptest.Server
+       ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               n, _ := strconv.Atoi(r.FormValue("n"))
+               // Test Referer header. (7 is arbitrary position to test at)
+               if n == 7 {
+                       if g, e := r.Referer, ts.URL+"/?n=6"; e != g {
+                               t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g)
+                       }
+               }
+               if n < 15 {
+                       Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound)
+                       return
+               }
+               fmt.Fprintf(w, "n=%d", n)
+       }))
+       defer ts.Close()
+
+       c := &Client{}
+       _, _, err := c.Get(ts.URL)
+       if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
+               t.Errorf("with default client, expected error %q, got %q", e, g)
+       }
+
+       var checkErr os.Error
+       var lastVia []*Request
+       c = &Client{CheckRedirect: func(_ *Request, via []*Request) os.Error {
+               lastVia = via
+               return checkErr
+       }}
+       _, finalUrl, err := c.Get(ts.URL)
+       if e, g := "<nil>", fmt.Sprintf("%v", err); e != g {
+               t.Errorf("with custom client, expected error %q, got %q", e, g)
+       }
+       if !strings.HasSuffix(finalUrl, "/?n=15") {
+               t.Errorf("expected final url to end in /?n=15; got url %q", finalUrl)
+       }
+       if e, g := 15, len(lastVia); e != g {
+               t.Errorf("expected lastVia to have contained %d elements; got %d", e, g)
+       }
+
+       checkErr = os.NewError("no redirects allowed")
+       _, finalUrl, err = c.Get(ts.URL)
+       if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g {
+               t.Errorf("with redirects forbidden, expected error %q, got %q", e, g)
+       }
+}