]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httputil: Clean up ReverseProxy maxLatencyWriter goroutines.
authorColby Ranger <cranger@google.com>
Wed, 18 Apr 2012 18:33:02 +0000 (11:33 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 18 Apr 2012 18:33:02 +0000 (11:33 -0700)
When FlushInterval is specified on ReverseProxy, the ResponseWriter is
wrapped with a maxLatencyWriter that periodically flushes in a
goroutine. That goroutine was not being cleaned up at the end of the
request. This resulted in a panic when Flush() was being called on a
ResponseWriter that was closed.

The code was updated to always send the done message to the flushLoop()
goroutine after copying the body. Futhermore, the code was refactored to
allow the test to verify the maxLatencyWriter behavior.

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/6033043

src/pkg/net/http/httputil/reverseproxy.go
src/pkg/net/http/httputil/reverseproxy_test.go

index 9c4bd6e09a5c81aa09c9ea264b597b19ab21a2d0..2f08a8c0c9ca1f5f3d1f5f9e11f3837b5a2bd99a 100644 (file)
@@ -17,6 +17,10 @@ import (
        "time"
 )
 
+// beforeCopyResponse is a callback set by tests to intercept the state of the
+// output io.Writer before the data is copied to it.
+var beforeCopyResponse func(dst io.Writer)
+
 // ReverseProxy is an HTTP Handler that takes an incoming request and
 // sends it to another server, proxying the response back to the
 // client.
@@ -112,20 +116,32 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
                rw.WriteHeader(http.StatusInternalServerError)
                return
        }
+       defer res.Body.Close()
 
        copyHeader(rw.Header(), res.Header)
 
        rw.WriteHeader(res.StatusCode)
+       p.copyResponse(rw, res.Body)
+}
 
-       if res.Body != nil {
-               var dst io.Writer = rw
-               if p.FlushInterval != 0 {
-                       if wf, ok := rw.(writeFlusher); ok {
-                               dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval}
+func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
+       if p.FlushInterval != 0 {
+               if wf, ok := dst.(writeFlusher); ok {
+                       mlw := &maxLatencyWriter{
+                               dst:     wf,
+                               latency: p.FlushInterval,
+                               done:    make(chan bool),
                        }
+                       go mlw.flushLoop()
+                       defer mlw.stop()
+                       dst = mlw
                }
-               io.Copy(dst, res.Body)
        }
+
+       if beforeCopyResponse != nil {
+               beforeCopyResponse(dst)
+       }
+       io.Copy(dst, src)
 }
 
 type writeFlusher interface {
@@ -137,22 +153,14 @@ type maxLatencyWriter struct {
        dst     writeFlusher
        latency time.Duration
 
-       lk   sync.Mutex // protects init of done, as well Write + Flush
+       lk   sync.Mutex // protects Write + Flush
        done chan bool
 }
 
-func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
+func (m *maxLatencyWriter) Write(p []byte) (int, error) {
        m.lk.Lock()
        defer m.lk.Unlock()
-       if m.done == nil {
-               m.done = make(chan bool)
-               go m.flushLoop()
-       }
-       n, err = m.dst.Write(p)
-       if err != nil {
-               m.done <- true
-       }
-       return
+       return m.dst.Write(p)
 }
 
 func (m *maxLatencyWriter) flushLoop() {
@@ -160,13 +168,15 @@ func (m *maxLatencyWriter) flushLoop() {
        defer t.Stop()
        for {
                select {
+               case <-m.done:
+                       return
                case <-t.C:
                        m.lk.Lock()
                        m.dst.Flush()
                        m.lk.Unlock()
-               case <-m.done:
-                       return
                }
        }
        panic("unreached")
 }
+
+func (m *maxLatencyWriter) stop() { m.done <- true }
index 28e9c90ad36cdac93a4550a4384b1a481edbf1b6..3bcb23c077f985ab7d2eb80adf3cdbc1f27eefa8 100644 (file)
@@ -7,11 +7,14 @@
 package httputil
 
 import (
+       "io"
        "io/ioutil"
        "net/http"
        "net/http/httptest"
        "net/url"
+       "runtime"
        "testing"
+       "time"
 )
 
 func TestReverseProxy(t *testing.T) {
@@ -107,3 +110,58 @@ func TestReverseProxyQuery(t *testing.T) {
                frontend.Close()
        }
 }
+
+func TestReverseProxyFlushInterval(t *testing.T) {
+       if testing.Short() {
+               return
+       }
+
+       const expected = "hi"
+       backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               w.Write([]byte(expected))
+       }))
+       defer backend.Close()
+
+       backendURL, err := url.Parse(backend.URL)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       proxyHandler := NewSingleHostReverseProxy(backendURL)
+       proxyHandler.FlushInterval = time.Microsecond
+
+       dstChan := make(chan io.Writer, 1)
+       beforeCopyResponse = func(dst io.Writer) { dstChan <- dst }
+       defer func() { beforeCopyResponse = nil }()
+
+       frontend := httptest.NewServer(proxyHandler)
+       defer frontend.Close()
+
+       initGoroutines := runtime.NumGoroutine()
+       for i := 0; i < 100; i++ {
+               req, _ := http.NewRequest("GET", frontend.URL, nil)
+               req.Close = true
+               res, err := http.DefaultClient.Do(req)
+               if err != nil {
+                       t.Fatalf("Get: %v", err)
+               }
+               if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
+                       t.Errorf("got body %q; expected %q", bodyBytes, expected)
+               }
+
+               select {
+               case dst := <-dstChan:
+                       if _, ok := dst.(*maxLatencyWriter); !ok {
+                               t.Errorf("got writer %T; expected %T", dst, &maxLatencyWriter{})
+                       }
+               default:
+                       t.Error("maxLatencyWriter Write() was never called")
+               }
+
+               res.Body.Close()
+       }
+       // Allow up to 50 additional goroutines over 100 requests.
+       if delta := runtime.NumGoroutine() - initGoroutines; delta > 50 {
+               t.Errorf("grew %d goroutines; leak?", delta)
+       }
+}