]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httputil: make ReverseProxy close response body if ModifyResponse returns...
authorEdan B <3d4nb3@gmail.com>
Sat, 11 Nov 2017 08:10:14 +0000 (10:10 +0200)
committerBrad Fitzpatrick <bradfitz@golang.org>
Sat, 11 Nov 2017 20:12:59 +0000 (20:12 +0000)
Fixes #22658

Change-Id: I00e2b007d77b6f54798f7755d0b08e4fea824392
Reviewed-on: https://go-review.googlesource.com/77170
Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
Run-TryBot: Emmanuel Odeke <emm.odeke@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/http/httputil/reverseproxy.go
src/net/http/httputil/reverseproxy_test.go

index a0f36d1221d5271d511f4896249258d1f218cc31..b96bb21019b0fc1492c4086c9f6c75867b38d1d2 100644 (file)
@@ -207,6 +207,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
                if err := p.ModifyResponse(res); err != nil {
                        p.logf("http: proxy error: %v", err)
                        rw.WriteHeader(http.StatusBadGateway)
+                       res.Body.Close()
                        return
                }
        }
index 37a9992375d4d1e20a374c0652ec146c578c3ea3..2232042d3ed14eec76116bc0b778e5b14e3fdb74 100644 (file)
@@ -769,3 +769,47 @@ type roundTripperFunc func(req *http.Request) (*http.Response, error)
 func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
        return fn(req)
 }
+
+func TestModifyResponseClosesBody(t *testing.T) {
+       req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
+       req.RemoteAddr = "1.2.3.4:56789"
+       closeCheck := new(checkCloser)
+       logBuf := new(bytes.Buffer)
+       outErr := errors.New("ModifyResponse error")
+       rp := &ReverseProxy{
+               Director: func(req *http.Request) {},
+               Transport: &staticTransport{&http.Response{
+                       StatusCode: 200,
+                       Body:       closeCheck,
+               }},
+               ErrorLog: log.New(logBuf, "", 0),
+               ModifyResponse: func(*http.Response) error {
+                       return outErr
+               },
+       }
+       rec := httptest.NewRecorder()
+       rp.ServeHTTP(rec, req)
+       res := rec.Result()
+       if g, e := res.StatusCode, http.StatusBadGateway; g != e {
+               t.Errorf("got res.StatusCode %d; expected %d", g, e)
+       }
+       if !closeCheck.closed {
+               t.Errorf("body should have been closed")
+       }
+       if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
+               t.Errorf("ErrorLog %q does not contain %q", g, e)
+       }
+}
+
+type checkCloser struct {
+       closed bool
+}
+
+func (cc *checkCloser) Close() error {
+       cc.closed = true
+       return nil
+}
+
+func (cc *checkCloser) Read(b []byte) (int, error) {
+       return len(b), nil
+}