]> Cypherpunks repositories - gostls13.git/commitdiff
http: add Transport.ProxySelector
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 18 May 2011 16:23:29 +0000 (09:23 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 18 May 2011 16:23:29 +0000 (09:23 -0700)
R=mattn.jp, rsc
CC=golang-dev
https://golang.org/cl/4528077

src/pkg/http/proxy_test.go
src/pkg/http/transport.go
src/pkg/http/transport_test.go

index 308bf44b48aec78f161592c0c050976cd97e76a7..9b320b3aa5b9904f26cc5aba7983f60ba5ddd2b7 100644 (file)
@@ -40,10 +40,8 @@ func TestUseProxy(t *testing.T) {
        no_proxy := "foobar.com, .barbaz.net"
        os.Setenv("NO_PROXY", no_proxy)
 
-       tr := &Transport{}
-
        for _, test := range UseProxyTests {
-               if tr.useProxy(test.host+":80") != test.match {
+               if useProxy(test.host+":80") != test.match {
                        t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match)
                }
        }
index fa912b1e18d1541f77149b1bfcb648b9abcae693..34bfbdd34a41844c4b436f92ebe9402c331ec1ee 100644 (file)
@@ -24,7 +24,7 @@ import (
 // each call to Do and uses HTTP proxies as directed by the
 // $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy)
 // environment variables.
-var DefaultTransport RoundTripper = &Transport{}
+var DefaultTransport RoundTripper = &Transport{Proxy: ProxyFromEnvironment}
 
 // DefaultMaxIdleConnsPerHost is the default value of Transport's
 // MaxIdleConnsPerHost.
@@ -41,7 +41,12 @@ type Transport struct {
        // TODO: tunable on timeout on cached connections
        // TODO: optional pipelining
 
-       IgnoreEnvironment  bool // don't look at environment variables for proxy configuration
+       // Proxy optionally specifies a function to return a proxy for
+       // a given Request. If the function returns a non-nil error,
+       // the request is aborted with the provided error. If Proxy is
+       // nil or returns a nil *URL, no proxy is used.
+       Proxy func(*Request) (*URL, os.Error)
+
        DisableKeepAlives  bool
        DisableCompression bool
 
@@ -51,6 +56,39 @@ type Transport struct {
        MaxIdleConnsPerHost int
 }
 
+// ProxyFromEnvironment returns the URL of the proxy to use for a
+// given request, as indicated by the environment variables
+// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy).
+// Either URL or an error is returned.
+func ProxyFromEnvironment(req *Request) (*URL, os.Error) {
+       proxy := getenvEitherCase("HTTP_PROXY")
+       if proxy == "" {
+               return nil, nil
+       }
+       if !useProxy(canonicalAddr(req.URL)) {
+               return nil, nil
+       }
+       proxyURL, err := ParseRequestURL(proxy)
+       if err != nil {
+               return nil, os.ErrorString("invalid proxy address")
+       }
+       if proxyURL.Host == "" {
+               proxyURL, err = ParseRequestURL("http://" + proxy)
+               if err != nil {
+                       return nil, os.ErrorString("invalid proxy address")
+               }
+       }
+       return proxyURL, nil
+}
+
+// ProxyURL returns a proxy function (for use in a Transport)
+// that always returns the same URL.
+func ProxyURL(url *URL) func(*Request) (*URL, os.Error) {
+       return func(*Request) (*URL, os.Error) {
+               return url, nil
+       }
+}
+
 // RoundTrip implements the RoundTripper interface.
 func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
        if req.URL == nil {
@@ -101,21 +139,11 @@ func (t *Transport) CloseIdleConnections() {
 // Private implementation past this point.
 //
 
-func (t *Transport) getenvEitherCase(k string) string {
-       if t.IgnoreEnvironment {
-               return ""
-       }
-       if v := t.getenv(strings.ToUpper(k)); v != "" {
+func getenvEitherCase(k string) string {
+       if v := os.Getenv(strings.ToUpper(k)); v != "" {
                return v
        }
-       return t.getenv(strings.ToLower(k))
-}
-
-func (t *Transport) getenv(k string) string {
-       if t.IgnoreEnvironment {
-               return ""
-       }
-       return os.Getenv(k)
+       return os.Getenv(strings.ToLower(k))
 }
 
 func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) {
@@ -123,20 +151,12 @@ func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Er
                targetScheme: req.URL.Scheme,
                targetAddr:   canonicalAddr(req.URL),
        }
-
-       proxy := t.getenvEitherCase("HTTP_PROXY")
-       if proxy != "" && t.useProxy(cm.targetAddr) {
-               proxyURL, err := ParseRequestURL(proxy)
+       if t.Proxy != nil {
+               var err os.Error
+               cm.proxyURL, err = t.Proxy(req)
                if err != nil {
-                       return nil, os.ErrorString("invalid proxy address")
-               }
-               if proxyURL.Host == "" {
-                       proxyURL, err = ParseRequestURL("http://" + proxy)
-                       if err != nil {
-                               return nil, os.ErrorString("invalid proxy address")
-                       }
+                       return nil, err
                }
-               cm.proxyURL = proxyURL
        }
        return cm, nil
 }
@@ -296,7 +316,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
 // useProxy returns true if requests to addr should use a proxy,
 // according to the NO_PROXY or no_proxy environment variable.
 // addr is always a canonicalAddr with a host and port.
-func (t *Transport) useProxy(addr string) bool {
+func useProxy(addr string) bool {
        if len(addr) == 0 {
                return true
        }
@@ -313,7 +333,7 @@ func (t *Transport) useProxy(addr string) bool {
                }
        }
 
-       no_proxy := t.getenvEitherCase("NO_PROXY")
+       no_proxy := getenvEitherCase("NO_PROXY")
        if no_proxy == "*" {
                return false
        }
index 13865505efc545cd406670b4db11b5dc241d0d84..9cd18ffecf42b91e1a28fb136d80504aa53653b6 100644 (file)
@@ -478,6 +478,30 @@ func TestTransportGzip(t *testing.T) {
        }
 }
 
+func TestTransportProxy(t *testing.T) {
+       ch := make(chan string, 1)
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               ch <- "real server"
+       }))
+       defer ts.Close()
+       proxy := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               ch <- "proxy for " + r.URL.String()
+       }))
+       defer proxy.Close()
+
+       pu, err := ParseURL(proxy.URL)
+       if err != nil {
+               t.Fatal(err)
+       }
+       c := &Client{Transport: &Transport{Proxy: ProxyURL(pu)}}
+       c.Head(ts.URL)
+       got := <-ch
+       want := "proxy for " + ts.URL + "/"
+       if got != want {
+               t.Errorf("want %q, got %q", want, got)
+       }
+}
+
 // TestTransportGzipRecursive sends a gzip quine and checks that the
 // client gets the same value back. This is more cute than anything,
 // but checks that we don't recurse forever, and checks that