]> Cypherpunks repositories - gostls13.git/commitdiff
http: reverse proxy handler
authorBrad Fitzpatrick <bradfitz@golang.org>
Fri, 15 Apr 2011 15:13:52 +0000 (08:13 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 15 Apr 2011 15:13:52 +0000 (08:13 -0700)
R=rsc, petar-m
CC=golang-dev
https://golang.org/cl/4428041

src/pkg/http/Makefile
src/pkg/http/reverseproxy.go [new file with mode: 0644]
src/pkg/http/reverseproxy_test.go [new file with mode: 0644]

index 389b0422274bba8e559fee92f61d903ee105bdd2..2a2a2a3beb8a5152c3507e9c0344199bb3bfe629 100644 (file)
@@ -16,6 +16,7 @@ GOFILES=\
        persist.go\
        request.go\
        response.go\
+       reverseproxy.go\
        server.go\
        status.go\
        transfer.go\
diff --git a/src/pkg/http/reverseproxy.go b/src/pkg/http/reverseproxy.go
new file mode 100644 (file)
index 0000000..e4ce1e3
--- /dev/null
@@ -0,0 +1,100 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP reverse proxy handler
+
+package http
+
+import (
+       "io"
+       "log"
+       "net"
+       "strings"
+)
+
+// ReverseProxy is an HTTP Handler that takes an incoming request and
+// sends it to another server, proxying the response back to the
+// client.
+type ReverseProxy struct {
+       // Director must be a function which modifies
+       // the request into a new request to be sent
+       // using Transport. Its response is then copied
+       // back to the original client unmodified.
+       Director func(*Request)
+
+       // The Transport used to perform proxy requests.
+       // If nil, DefaultTransport is used.
+       Transport RoundTripper
+}
+
+func singleJoiningSlash(a, b string) string {
+       aslash := strings.HasSuffix(a, "/")
+       bslash := strings.HasPrefix(b, "/")
+       switch {
+       case aslash && bslash:
+               return a + b[1:]
+       case !aslash && !bslash:
+               return a + "/" + b
+       }
+       return a + b
+}
+
+// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
+// URLs to the scheme, host, and base path provided in target. If the
+// target's path is "/base" and the incoming request was for "/dir",
+// the target request will be for /base/dir.
+func NewSingleHostReverseProxy(target *URL) *ReverseProxy {
+       director := func(req *Request) {
+               req.URL.Scheme = target.Scheme
+               req.URL.Host = target.Host
+               req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
+               if q := req.URL.RawQuery; q != "" {
+                       req.URL.RawPath = req.URL.Path + "?" + q
+               } else {
+                       req.URL.RawPath = req.URL.Path
+               }
+               req.URL.RawQuery = target.RawQuery
+       }
+       return &ReverseProxy{Director: director}
+}
+
+func (p *ReverseProxy) ServeHTTP(rw ResponseWriter, req *Request) {
+       transport := p.Transport
+       if transport == nil {
+               transport = DefaultTransport
+       }
+
+       outreq := new(Request)
+       *outreq = *req // includes shallow copies of maps, but okay
+
+       p.Director(outreq)
+       outreq.Proto = "HTTP/1.1"
+       outreq.ProtoMajor = 1
+       outreq.ProtoMinor = 1
+       outreq.Close = false
+
+       if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
+               outreq.Header.Set("X-Forwarded-For", clientIp)
+       }
+
+       res, err := transport.RoundTrip(outreq)
+       if err != nil {
+               log.Printf("http: proxy error: %v", err)
+               rw.WriteHeader(StatusInternalServerError)
+               return
+       }
+
+       hdr := rw.Header()
+       for k, vv := range res.Header {
+               for _, v := range vv {
+                       hdr.Add(k, v)
+               }
+       }
+
+       rw.WriteHeader(res.StatusCode)
+
+       if res.Body != nil {
+               io.Copy(rw, res.Body)
+       }
+}
diff --git a/src/pkg/http/reverseproxy_test.go b/src/pkg/http/reverseproxy_test.go
new file mode 100644 (file)
index 0000000..8cf7705
--- /dev/null
@@ -0,0 +1,50 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Reverse proxy tests.
+
+package http_test
+
+import (
+       . "http"
+       "http/httptest"
+       "io/ioutil"
+       "testing"
+)
+
+func TestReverseProxy(t *testing.T) {
+       const backendResponse = "I am the backend"
+       const backendStatus = 404
+       backend := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               if r.Header.Get("X-Forwarded-For") == "" {
+                       t.Errorf("didn't get X-Forwarded-For header")
+               }
+               w.Header().Set("X-Foo", "bar")
+               w.WriteHeader(backendStatus)
+               w.Write([]byte(backendResponse))
+       }))
+       defer backend.Close()
+       backendURL, err := ParseURL(backend.URL)
+       if err != nil {
+               t.Fatal(err)
+       }
+       proxyHandler := NewSingleHostReverseProxy(backendURL)
+       frontend := httptest.NewServer(proxyHandler)
+       defer frontend.Close()
+
+       res, _, err := Get(frontend.URL)
+       if err != nil {
+               t.Fatalf("Get: %v", err)
+       }
+       if g, e := res.StatusCode, backendStatus; g != e {
+               t.Errorf("got res.StatusCode %d; expected %d", g, e)
+       }
+       if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
+               t.Errorf("got X-Foo %q; expected %q", g, e)
+       }
+       bodyBytes, _ := ioutil.ReadAll(res.Body)
+       if g, e := string(bodyBytes), backendResponse; g != e {
+               t.Errorf("got body %q; expected %q", g, e)
+       }
+}