"time"
)
-// beforeCopyResponse is a callback set by tests to intercept the state of the
-// output io.Writer before the data is copied to it.
-var beforeCopyResponse func(dst io.Writer)
+// onExitFlushLoop is a callback set by tests to detect the state of the
+// flushLoop() goroutine.
+var onExitFlushLoop func()
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
}
}
- if beforeCopyResponse != nil {
- beforeCopyResponse(dst)
- }
io.Copy(dst, src)
}
for {
select {
case <-m.done:
+ if onExitFlushLoop != nil {
+ onExitFlushLoop()
+ }
return
case <-t.C:
m.lk.Lock()
package httputil
import (
- "io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
- "runtime"
"testing"
"time"
)
}
func TestReverseProxyFlushInterval(t *testing.T) {
- if testing.Short() {
- return
- }
-
const expected = "hi"
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(expected))
proxyHandler := NewSingleHostReverseProxy(backendURL)
proxyHandler.FlushInterval = time.Microsecond
- dstChan := make(chan io.Writer, 1)
- beforeCopyResponse = func(dst io.Writer) { dstChan <- dst }
- defer func() { beforeCopyResponse = nil }()
+ done := make(chan bool)
+ onExitFlushLoop = func() { done <- true }
+ defer func() { onExitFlushLoop = nil }()
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()
- initGoroutines := runtime.NumGoroutine()
- for i := 0; i < 100; i++ {
- req, _ := http.NewRequest("GET", frontend.URL, nil)
- req.Close = true
- res, err := http.DefaultClient.Do(req)
- if err != nil {
- t.Fatalf("Get: %v", err)
- }
- if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
- t.Errorf("got body %q; expected %q", bodyBytes, expected)
- }
-
- select {
- case dst := <-dstChan:
- if _, ok := dst.(*maxLatencyWriter); !ok {
- t.Errorf("got writer %T; expected %T", dst, &maxLatencyWriter{})
- }
- default:
- t.Error("maxLatencyWriter Write() was never called")
- }
-
- res.Body.Close()
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+ res, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
+ t.Errorf("got body %q; expected %q", bodyBytes, expected)
}
- // Allow up to 50 additional goroutines over 100 requests.
- if delta := runtime.NumGoroutine() - initGoroutines; delta > 50 {
- t.Errorf("grew %d goroutines; leak?", delta)
+
+ select {
+ case <-done:
+ // OK
+ case <-time.After(5 * time.Second):
+ t.Error("maxLatencyWriter flushLoop() never exited")
}
}