"time"
)
-// onExitFlushLoop is a callback set by tests to detect the state of the
-// flushLoop() goroutine.
-var onExitFlushLoop func()
-
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
// to flush to the client while copying the
// response body.
// If zero, no periodic flushing is done.
+ // A negative value means to flush immediately
+ // after each write to the client.
+ // The FlushInterval is ignored when ReverseProxy
+ // recognizes a response as a streaming response;
+ // for such reponses, writes are flushed to the client
+ // immediately.
FlushInterval time.Duration
// ErrorLog specifies an optional logger for errors
fl.Flush()
}
}
- err = p.copyResponse(rw, res.Body)
+ err = p.copyResponse(rw, res.Body, p.flushInterval(req, res))
if err != nil {
defer res.Body.Close()
// Since we're streaming the response, if we run into an error all we can do
}
}
-func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) error {
- if p.FlushInterval != 0 {
+// flushInterval returns the p.FlushInterval value, conditionally
+// overriding its value for a specific request/response.
+func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time.Duration {
+ resCT := res.Header.Get("Content-Type")
+
+ // For Server-Sent Events responses, flush immediately.
+ // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream
+ if resCT == "text/event-stream" {
+ return -1 // negative means immediately
+ }
+
+ // TODO: more specific cases? e.g. res.ContentLength == -1?
+ return p.FlushInterval
+}
+
+func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
+ if flushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{
dst: wf,
- latency: p.FlushInterval,
- done: make(chan bool),
+ latency: flushInterval,
}
- go mlw.flushLoop()
defer mlw.stop()
dst = mlw
}
type maxLatencyWriter struct {
dst writeFlusher
- latency time.Duration
+ latency time.Duration // non-zero; negative means to flush immediately
- mu sync.Mutex // protects Write + Flush
- done chan bool
+ mu sync.Mutex // protects t, flushPending, and dst.Flush
+ t *time.Timer
+ flushPending bool
}
-func (m *maxLatencyWriter) Write(p []byte) (int, error) {
+func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
m.mu.Lock()
defer m.mu.Unlock()
- return m.dst.Write(p)
+ n, err = m.dst.Write(p)
+ if m.latency < 0 {
+ m.dst.Flush()
+ return
+ }
+ if m.flushPending {
+ return
+ }
+ if m.t == nil {
+ m.t = time.AfterFunc(m.latency, m.delayedFlush)
+ } else {
+ m.t.Reset(m.latency)
+ }
+ m.flushPending = true
+ return
}
-func (m *maxLatencyWriter) flushLoop() {
- t := time.NewTicker(m.latency)
- defer t.Stop()
- for {
- select {
- case <-m.done:
- if onExitFlushLoop != nil {
- onExitFlushLoop()
- }
- return
- case <-t.C:
- m.mu.Lock()
- m.dst.Flush()
- m.mu.Unlock()
- }
- }
+func (m *maxLatencyWriter) delayedFlush() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.dst.Flush()
+ m.flushPending = false
}
-func (m *maxLatencyWriter) stop() { m.done <- true }
+func (m *maxLatencyWriter) stop() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if m.t != nil {
+ m.t.Stop()
+ }
+}
proxyHandler := NewSingleHostReverseProxy(backendURL)
proxyHandler.FlushInterval = time.Microsecond
- done := make(chan bool)
- onExitFlushLoop = func() { done <- true }
- defer func() { onExitFlushLoop = nil }()
-
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()
if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
t.Errorf("got body %q; expected %q", bodyBytes, expected)
}
-
- select {
- case <-done:
- // OK
- case <-time.After(5 * time.Second):
- t.Error("maxLatencyWriter flushLoop() never exited")
- }
}
func TestReverseProxyCancelation(t *testing.T) {
req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
rproxy.ServeHTTP(httptest.NewRecorder(), req)
}
+
+func TestSelectFlushInterval(t *testing.T) {
+ tests := []struct {
+ name string
+ p *ReverseProxy
+ req *http.Request
+ res *http.Response
+ want time.Duration
+ }{
+ {
+ name: "default",
+ res: &http.Response{},
+ p: &ReverseProxy{FlushInterval: 123},
+ want: 123,
+ },
+ {
+ name: "server-sent events overrides non-zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 123},
+ want: -1,
+ },
+ {
+ name: "server-sent events overrides zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 0},
+ want: -1,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.p.flushInterval(tt.req, tt.res)
+ if got != tt.want {
+ t.Errorf("flushLatency = %v; want %v", got, tt.want)
+ }
+ })
+ }
+}