]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httputil: Made reverseproxy test less flaky.
authorColby Ranger <cranger@google.com>
Fri, 20 Apr 2012 16:31:23 +0000 (09:31 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 20 Apr 2012 16:31:23 +0000 (09:31 -0700)
The reverseproxy test depended on the behavior of
runtime.NumGoroutines(), which makes no guarantee when
goroutines are reaped. Instead, modify the flushLoop()
to invoke a callback when it returns, so the exit
from the loop can be tested, instead of the number
of gorountines running.

R=bradfitz
CC=golang-dev
https://golang.org/cl/6068046

src/pkg/net/http/httputil/reverseproxy.go
src/pkg/net/http/httputil/reverseproxy_test.go

index 2f08a8c0c9ca1f5f3d1f5f9e11f3837b5a2bd99a..479945fc211d3b6cfa27e49b03d3c9c740d540b3 100644 (file)
@@ -17,9 +17,9 @@ import (
        "time"
 )
 
-// beforeCopyResponse is a callback set by tests to intercept the state of the
-// output io.Writer before the data is copied to it.
-var beforeCopyResponse func(dst io.Writer)
+// onExitFlushLoop is a callback set by tests to detect the state of the
+// flushLoop() goroutine.
+var onExitFlushLoop func()
 
 // ReverseProxy is an HTTP Handler that takes an incoming request and
 // sends it to another server, proxying the response back to the
@@ -138,9 +138,6 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
                }
        }
 
-       if beforeCopyResponse != nil {
-               beforeCopyResponse(dst)
-       }
        io.Copy(dst, src)
 }
 
@@ -169,6 +166,9 @@ func (m *maxLatencyWriter) flushLoop() {
        for {
                select {
                case <-m.done:
+                       if onExitFlushLoop != nil {
+                               onExitFlushLoop()
+                       }
                        return
                case <-t.C:
                        m.lk.Lock()
index 3bcb23c077f985ab7d2eb80adf3cdbc1f27eefa8..b42c031caca248765d4950e9c1f260a40140668e 100644 (file)
@@ -7,12 +7,10 @@
 package httputil
 
 import (
-       "io"
        "io/ioutil"
        "net/http"
        "net/http/httptest"
        "net/url"
-       "runtime"
        "testing"
        "time"
 )
@@ -112,10 +110,6 @@ func TestReverseProxyQuery(t *testing.T) {
 }
 
 func TestReverseProxyFlushInterval(t *testing.T) {
-       if testing.Short() {
-               return
-       }
-
        const expected = "hi"
        backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
                w.Write([]byte(expected))
@@ -130,38 +124,28 @@ func TestReverseProxyFlushInterval(t *testing.T) {
        proxyHandler := NewSingleHostReverseProxy(backendURL)
        proxyHandler.FlushInterval = time.Microsecond
 
-       dstChan := make(chan io.Writer, 1)
-       beforeCopyResponse = func(dst io.Writer) { dstChan <- dst }
-       defer func() { beforeCopyResponse = nil }()
+       done := make(chan bool)
+       onExitFlushLoop = func() { done <- true }
+       defer func() { onExitFlushLoop = nil }()
 
        frontend := httptest.NewServer(proxyHandler)
        defer frontend.Close()
 
-       initGoroutines := runtime.NumGoroutine()
-       for i := 0; i < 100; i++ {
-               req, _ := http.NewRequest("GET", frontend.URL, nil)
-               req.Close = true
-               res, err := http.DefaultClient.Do(req)
-               if err != nil {
-                       t.Fatalf("Get: %v", err)
-               }
-               if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
-                       t.Errorf("got body %q; expected %q", bodyBytes, expected)
-               }
-
-               select {
-               case dst := <-dstChan:
-                       if _, ok := dst.(*maxLatencyWriter); !ok {
-                               t.Errorf("got writer %T; expected %T", dst, &maxLatencyWriter{})
-                       }
-               default:
-                       t.Error("maxLatencyWriter Write() was never called")
-               }
-
-               res.Body.Close()
+       req, _ := http.NewRequest("GET", frontend.URL, nil)
+       req.Close = true
+       res, err := http.DefaultClient.Do(req)
+       if err != nil {
+               t.Fatalf("Get: %v", err)
+       }
+       defer res.Body.Close()
+       if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
+               t.Errorf("got body %q; expected %q", bodyBytes, expected)
        }
-       // Allow up to 50 additional goroutines over 100 requests.
-       if delta := runtime.NumGoroutine() - initGoroutines; delta > 50 {
-               t.Errorf("grew %d goroutines; leak?", delta)
+
+       select {
+       case <-done:
+               // OK
+       case <-time.After(5 * time.Second):
+               t.Error("maxLatencyWriter flushLoop() never exited")
        }
 }