From 5694ebf057889444e8bbe97741004c4ecdcb7785 Mon Sep 17 00:00:00 2001 From: Colby Ranger Date: Wed, 18 Apr 2012 11:33:02 -0700 Subject: [PATCH] net/http/httputil: Clean up ReverseProxy maxLatencyWriter goroutines. When FlushInterval is specified on ReverseProxy, the ResponseWriter is wrapped with a maxLatencyWriter that periodically flushes in a goroutine. That goroutine was not being cleaned up at the end of the request. This resulted in a panic when Flush() was being called on a ResponseWriter that was closed. The code was updated to always send the done message to the flushLoop() goroutine after copying the body. Futhermore, the code was refactored to allow the test to verify the maxLatencyWriter behavior. R=golang-dev, bradfitz CC=golang-dev https://golang.org/cl/6033043 --- src/pkg/net/http/httputil/reverseproxy.go | 48 +++++++++------ .../net/http/httputil/reverseproxy_test.go | 58 +++++++++++++++++++ 2 files changed, 87 insertions(+), 19 deletions(-) diff --git a/src/pkg/net/http/httputil/reverseproxy.go b/src/pkg/net/http/httputil/reverseproxy.go index 9c4bd6e09a..2f08a8c0c9 100644 --- a/src/pkg/net/http/httputil/reverseproxy.go +++ b/src/pkg/net/http/httputil/reverseproxy.go @@ -17,6 +17,10 @@ import ( "time" ) +// beforeCopyResponse is a callback set by tests to intercept the state of the +// output io.Writer before the data is copied to it. +var beforeCopyResponse func(dst io.Writer) + // ReverseProxy is an HTTP Handler that takes an incoming request and // sends it to another server, proxying the response back to the // client. @@ -112,20 +116,32 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusInternalServerError) return } + defer res.Body.Close() copyHeader(rw.Header(), res.Header) rw.WriteHeader(res.StatusCode) + p.copyResponse(rw, res.Body) +} - if res.Body != nil { - var dst io.Writer = rw - if p.FlushInterval != 0 { - if wf, ok := rw.(writeFlusher); ok { - dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval} +func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { + if p.FlushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: p.FlushInterval, + done: make(chan bool), } + go mlw.flushLoop() + defer mlw.stop() + dst = mlw } - io.Copy(dst, res.Body) } + + if beforeCopyResponse != nil { + beforeCopyResponse(dst) + } + io.Copy(dst, src) } type writeFlusher interface { @@ -137,22 +153,14 @@ type maxLatencyWriter struct { dst writeFlusher latency time.Duration - lk sync.Mutex // protects init of done, as well Write + Flush + lk sync.Mutex // protects Write + Flush done chan bool } -func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { +func (m *maxLatencyWriter) Write(p []byte) (int, error) { m.lk.Lock() defer m.lk.Unlock() - if m.done == nil { - m.done = make(chan bool) - go m.flushLoop() - } - n, err = m.dst.Write(p) - if err != nil { - m.done <- true - } - return + return m.dst.Write(p) } func (m *maxLatencyWriter) flushLoop() { @@ -160,13 +168,15 @@ func (m *maxLatencyWriter) flushLoop() { defer t.Stop() for { select { + case <-m.done: + return case <-t.C: m.lk.Lock() m.dst.Flush() m.lk.Unlock() - case <-m.done: - return } } panic("unreached") } + +func (m *maxLatencyWriter) stop() { m.done <- true } diff --git a/src/pkg/net/http/httputil/reverseproxy_test.go b/src/pkg/net/http/httputil/reverseproxy_test.go index 28e9c90ad3..3bcb23c077 100644 --- a/src/pkg/net/http/httputil/reverseproxy_test.go +++ b/src/pkg/net/http/httputil/reverseproxy_test.go @@ -7,11 +7,14 @@ package httputil import ( + "io" "io/ioutil" "net/http" "net/http/httptest" "net/url" + "runtime" "testing" + "time" ) func TestReverseProxy(t *testing.T) { @@ -107,3 +110,58 @@ func TestReverseProxyQuery(t *testing.T) { frontend.Close() } } + +func TestReverseProxyFlushInterval(t *testing.T) { + if testing.Short() { + return + } + + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + dstChan := make(chan io.Writer, 1) + beforeCopyResponse = func(dst io.Writer) { dstChan <- dst } + defer func() { beforeCopyResponse = nil }() + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + initGoroutines := runtime.NumGoroutine() + for i := 0; i < 100; i++ { + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } + + select { + case dst := <-dstChan: + if _, ok := dst.(*maxLatencyWriter); !ok { + t.Errorf("got writer %T; expected %T", dst, &maxLatencyWriter{}) + } + default: + t.Error("maxLatencyWriter Write() was never called") + } + + res.Body.Close() + } + // Allow up to 50 additional goroutines over 100 requests. + if delta := runtime.NumGoroutine() - initGoroutines; delta > 50 { + t.Errorf("grew %d goroutines; leak?", delta) + } +} -- 2.50.0