]> Cypherpunks repositories - gostls13.git/commitdiff
http/http/httputil: add ReverseProxy.ErrorHandler
authorJulien Salleyron <julien.salleyron@gmail.com>
Mon, 13 Nov 2017 22:32:07 +0000 (23:32 +0100)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 13 Jul 2018 16:45:34 +0000 (16:45 +0000)
This permits specifying an ErrorHandler to customize the RoundTrip
error handling if the backend fails to return a response.

Fixes #22700
Fixes #21255

Change-Id: I8879f0956e2472a07f584660afa10105ef23bf11
Reviewed-on: https://go-review.googlesource.com/77410
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/http/httputil/reverseproxy.go
src/net/http/httputil/reverseproxy_test.go

index 6f0a2418b3282553e0d89c95d3091a6c35cbb669..1dddaa95a7707c134bcdd07ec6c385b1c5dfec81 100644 (file)
@@ -55,10 +55,23 @@ type ReverseProxy struct {
        // copying HTTP response bodies.
        BufferPool BufferPool
 
-       // ModifyResponse is an optional function that
-       // modifies the Response from the backend.
-       // If it returns an error, the proxy returns a StatusBadGateway error.
+       // ModifyResponse is an optional function that modifies the
+       // Response from the backend. It is called if the backend
+       // returns a response at all, with any HTTP status code.
+       // If the backend is unreachable, the optional ErrorHandler is
+       // called without any call to ModifyResponse.
+       //
+       // If ModifyResponse returns an error, ErrorHandler is called
+       // with its error value. If ErrorHandler is nil, its default
+       // implementation is used.
        ModifyResponse func(*http.Response) error
+
+       // ErrorHandler is an optional function that handles errors
+       // reaching the backend or errors from ModifyResponse.
+       //
+       // If nil, the default is to log the provided error and return
+       // a 502 Status Bad Gateway response.
+       ErrorHandler func(http.ResponseWriter, *http.Request, error)
 }
 
 // A BufferPool is an interface for getting and returning temporary
@@ -141,6 +154,18 @@ var hopHeaders = []string{
        "Upgrade",
 }
 
+func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
+       p.logf("http: proxy error: %v", err)
+       rw.WriteHeader(http.StatusBadGateway)
+}
+
+func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
+       if p.ErrorHandler != nil {
+               return p.ErrorHandler
+       }
+       return p.defaultErrorHandler
+}
+
 func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
        transport := p.Transport
        if transport == nil {
@@ -206,8 +231,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 
        res, err := transport.RoundTrip(outreq)
        if err != nil {
-               p.logf("http: proxy error: %v", err)
-               rw.WriteHeader(http.StatusBadGateway)
+               p.getErrorHandler()(rw, outreq, err)
                return
        }
 
@@ -219,9 +243,8 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 
        if p.ModifyResponse != nil {
                if err := p.ModifyResponse(res); err != nil {
-                       p.logf("http: proxy error: %v", err)
-                       rw.WriteHeader(http.StatusBadGateway)
                        res.Body.Close()
+                       p.getErrorHandler()(rw, outreq, err)
                        return
                }
        }
index 2a12e753b5c94cfbceaac6ffa54ffff5430889f9..2f75b4e34ec9917356c0a11806446fd0f99962f2 100644 (file)
@@ -637,6 +637,93 @@ func TestReverseProxyModifyResponse(t *testing.T) {
        }
 }
 
+type failingRoundTripper struct{}
+
+func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
+       return nil, errors.New("some error")
+}
+
+type staticResponseRoundTripper struct{ res *http.Response }
+
+func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
+       return rt.res, nil
+}
+
+func TestReverseProxyErrorHandler(t *testing.T) {
+       tests := []struct {
+               name           string
+               wantCode       int
+               errorHandler   func(http.ResponseWriter, *http.Request, error)
+               transport      http.RoundTripper // defaults to failingRoundTripper
+               modifyResponse func(*http.Response) error
+       }{
+               {
+                       name:     "default",
+                       wantCode: http.StatusBadGateway,
+               },
+               {
+                       name:         "errorhandler",
+                       wantCode:     http.StatusTeapot,
+                       errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
+               },
+               {
+                       name: "modifyresponse_noerr",
+                       transport: staticResponseRoundTripper{
+                               &http.Response{StatusCode: 345, Body: http.NoBody},
+                       },
+                       modifyResponse: func(res *http.Response) error {
+                               res.StatusCode++
+                               return nil
+                       },
+                       errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
+                       wantCode:     346,
+               },
+               {
+                       name: "modifyresponse_err",
+                       transport: staticResponseRoundTripper{
+                               &http.Response{StatusCode: 345, Body: http.NoBody},
+                       },
+                       modifyResponse: func(res *http.Response) error {
+                               res.StatusCode++
+                               return errors.New("some error to trigger errorHandler")
+                       },
+                       errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
+                       wantCode:     http.StatusTeapot,
+               },
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       target := &url.URL{
+                               Scheme: "http",
+                               Host:   "dummy.tld",
+                               Path:   "/",
+                       }
+                       rproxy := NewSingleHostReverseProxy(target)
+                       rproxy.Transport = tt.transport
+                       rproxy.ModifyResponse = tt.modifyResponse
+                       if rproxy.Transport == nil {
+                               rproxy.Transport = failingRoundTripper{}
+                       }
+                       rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
+                       if tt.errorHandler != nil {
+                               rproxy.ErrorHandler = tt.errorHandler
+                       }
+                       frontendProxy := httptest.NewServer(rproxy)
+                       defer frontendProxy.Close()
+
+                       resp, err := http.Get(frontendProxy.URL + "/test")
+                       if err != nil {
+                               t.Fatalf("failed to reach proxy: %v", err)
+                       }
+                       if g, e := resp.StatusCode, tt.wantCode; g != e {
+                               t.Errorf("got res.StatusCode %d; expected %d", g, e)
+                       }
+                       resp.Body.Close()
+               })
+       }
+}
+
 // Issue 16659: log errors from short read
 func TestReverseProxy_CopyBuffer(t *testing.T) {
        backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {