"time"
)
+// beforeCopyResponse is a callback set by tests to intercept the state of the
+// output io.Writer before the data is copied to it.
+var beforeCopyResponse func(dst io.Writer)
+
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
rw.WriteHeader(http.StatusInternalServerError)
return
}
+ defer res.Body.Close()
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
+ p.copyResponse(rw, res.Body)
+}
- if res.Body != nil {
- var dst io.Writer = rw
- if p.FlushInterval != 0 {
- if wf, ok := rw.(writeFlusher); ok {
- dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval}
+func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
+ if p.FlushInterval != 0 {
+ if wf, ok := dst.(writeFlusher); ok {
+ mlw := &maxLatencyWriter{
+ dst: wf,
+ latency: p.FlushInterval,
+ done: make(chan bool),
}
+ go mlw.flushLoop()
+ defer mlw.stop()
+ dst = mlw
}
- io.Copy(dst, res.Body)
}
+
+ if beforeCopyResponse != nil {
+ beforeCopyResponse(dst)
+ }
+ io.Copy(dst, src)
}
type writeFlusher interface {
dst writeFlusher
latency time.Duration
- lk sync.Mutex // protects init of done, as well Write + Flush
+ lk sync.Mutex // protects Write + Flush
done chan bool
}
-func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
+func (m *maxLatencyWriter) Write(p []byte) (int, error) {
m.lk.Lock()
defer m.lk.Unlock()
- if m.done == nil {
- m.done = make(chan bool)
- go m.flushLoop()
- }
- n, err = m.dst.Write(p)
- if err != nil {
- m.done <- true
- }
- return
+ return m.dst.Write(p)
}
func (m *maxLatencyWriter) flushLoop() {
defer t.Stop()
for {
select {
+ case <-m.done:
+ return
case <-t.C:
m.lk.Lock()
m.dst.Flush()
m.lk.Unlock()
- case <-m.done:
- return
}
}
panic("unreached")
}
+
+func (m *maxLatencyWriter) stop() { m.done <- true }
package httputil
import (
+ "io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
+ "runtime"
"testing"
+ "time"
)
func TestReverseProxy(t *testing.T) {
frontend.Close()
}
}
+
+func TestReverseProxyFlushInterval(t *testing.T) {
+ if testing.Short() {
+ return
+ }
+
+ const expected = "hi"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(expected))
+ }))
+ defer backend.Close()
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.FlushInterval = time.Microsecond
+
+ dstChan := make(chan io.Writer, 1)
+ beforeCopyResponse = func(dst io.Writer) { dstChan <- dst }
+ defer func() { beforeCopyResponse = nil }()
+
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ initGoroutines := runtime.NumGoroutine()
+ for i := 0; i < 100; i++ {
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+ res, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
+ t.Errorf("got body %q; expected %q", bodyBytes, expected)
+ }
+
+ select {
+ case dst := <-dstChan:
+ if _, ok := dst.(*maxLatencyWriter); !ok {
+ t.Errorf("got writer %T; expected %T", dst, &maxLatencyWriter{})
+ }
+ default:
+ t.Error("maxLatencyWriter Write() was never called")
+ }
+
+ res.Body.Close()
+ }
+ // Allow up to 50 additional goroutines over 100 requests.
+ if delta := runtime.NumGoroutine() - initGoroutines; delta > 50 {
+ t.Errorf("grew %d goroutines; leak?", delta)
+ }
+}