]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httputil: run the ReverseProxy.ModifyResponse hook for upgrades
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 9 Jan 2019 15:06:20 +0000 (15:06 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 9 Jan 2019 15:51:59 +0000 (15:51 +0000)
Fixes #29627

Change-Id: I08a5b45151a11b5a4f3b5a2d984c0322cf904697
Reviewed-on: https://go-review.googlesource.com/c/157098
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/net/http/httputil/reverseproxy.go
src/net/http/httputil/reverseproxy_test.go

index 1c9feb7d7d2f540b693436551408403bbe7e0754..4e10bf399711b5f54e9cb985f0049738db46a97b 100644 (file)
@@ -171,6 +171,20 @@ func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request
        return p.defaultErrorHandler
 }
 
+// modifyResponse conditionally runs the optional ModifyResponse hook
+// and reports whether the request should proceed.
+func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
+       if p.ModifyResponse == nil {
+               return true
+       }
+       if err := p.ModifyResponse(res); err != nil {
+               res.Body.Close()
+               p.getErrorHandler()(rw, req, err)
+               return false
+       }
+       return true
+}
+
 func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
        transport := p.Transport
        if transport == nil {
@@ -250,6 +264,9 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 
        // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
        if res.StatusCode == http.StatusSwitchingProtocols {
+               if !p.modifyResponse(rw, res, outreq) {
+                       return
+               }
                p.handleUpgradeResponse(rw, outreq, res)
                return
        }
@@ -260,12 +277,8 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
                res.Header.Del(h)
        }
 
-       if p.ModifyResponse != nil {
-               if err := p.ModifyResponse(res); err != nil {
-                       res.Body.Close()
-                       p.getErrorHandler()(rw, outreq, err)
-                       return
-               }
+       if !p.modifyResponse(rw, res, outreq) {
+               return
        }
 
        copyHeader(rw.Header(), res.Header)
index bda569acc73bdc5bde1362d03659c0bd56f43d88..5edefa08e55a3c1a17bda4d4957b12dbc4bfeb91 100644 (file)
@@ -1012,6 +1012,10 @@ func TestReverseProxyWebSocket(t *testing.T) {
        backURL, _ := url.Parse(backendServer.URL)
        rproxy := NewSingleHostReverseProxy(backURL)
        rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
+       rproxy.ModifyResponse = func(res *http.Response) error {
+               res.Header.Add("X-Modified", "true")
+               return nil
+       }
 
        handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
                rw.Header().Set("X-Header", "X-Value")
@@ -1049,6 +1053,10 @@ func TestReverseProxyWebSocket(t *testing.T) {
        }
        defer rwc.Close()
 
+       if got, want := res.Header.Get("X-Modified"), "true"; got != want {
+               t.Errorf("response X-Modified header = %q; want %q", got, want)
+       }
+
        io.WriteString(rwc, "Hello\n")
        bs := bufio.NewScanner(rwc)
        if !bs.Scan() {