From a14ed2a82a1563ba89e1f22ab517bf3c9abe416f Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 9 Jan 2019 15:06:20 +0000 Subject: [PATCH] net/http/httputil: run the ReverseProxy.ModifyResponse hook for upgrades Fixes #29627 Change-Id: I08a5b45151a11b5a4f3b5a2d984c0322cf904697 Reviewed-on: https://go-review.googlesource.com/c/157098 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Ian Lance Taylor --- src/net/http/httputil/reverseproxy.go | 25 ++++++++++++++++------ src/net/http/httputil/reverseproxy_test.go | 8 +++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index 1c9feb7d7d..4e10bf3997 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -171,6 +171,20 @@ func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request return p.defaultErrorHandler } +// modifyResponse conditionally runs the optional ModifyResponse hook +// and reports whether the request should proceed. +func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool { + if p.ModifyResponse == nil { + return true + } + if err := p.ModifyResponse(res); err != nil { + res.Body.Close() + p.getErrorHandler()(rw, req, err) + return false + } + return true +} + func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { transport := p.Transport if transport == nil { @@ -250,6 +264,9 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) if res.StatusCode == http.StatusSwitchingProtocols { + if !p.modifyResponse(rw, res, outreq) { + return + } p.handleUpgradeResponse(rw, outreq, res) return } @@ -260,12 +277,8 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { res.Header.Del(h) } - if p.ModifyResponse != nil { - if err := p.ModifyResponse(res); err != nil { - res.Body.Close() - p.getErrorHandler()(rw, outreq, err) - return - } + if !p.modifyResponse(rw, res, outreq) { + return } copyHeader(rw.Header(), res.Header) diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index bda569acc7..5edefa08e5 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -1012,6 +1012,10 @@ func TestReverseProxyWebSocket(t *testing.T) { backURL, _ := url.Parse(backendServer.URL) rproxy := NewSingleHostReverseProxy(backURL) rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(res *http.Response) error { + res.Header.Add("X-Modified", "true") + return nil + } handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("X-Header", "X-Value") @@ -1049,6 +1053,10 @@ func TestReverseProxyWebSocket(t *testing.T) { } defer rwc.Close() + if got, want := res.Header.Get("X-Modified"), "true"; got != want { + t.Errorf("response X-Modified header = %q; want %q", got, want) + } + io.WriteString(rwc, "Hello\n") bs := bufio.NewScanner(rwc) if !bs.Scan() { -- 2.50.0