]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httptest: make Server.Close wait for outstanding requests to finish
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 29 Feb 2012 20:18:26 +0000 (12:18 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 29 Feb 2012 20:18:26 +0000 (12:18 -0800)
Might fix issue 3050

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/5708066

src/pkg/net/http/httptest/server.go
src/pkg/net/http/sniff_test.go

index 8d911f7575bfbfa0592bb54e9e4e3912b8a43ed7..57cf0c9417dbae8c822d28d7dc88a23e14a14e41 100644 (file)
@@ -13,6 +13,7 @@ import (
        "net"
        "net/http"
        "os"
+       "sync"
 )
 
 // A Server is an HTTP server listening on a system-chosen port on the
@@ -25,6 +26,10 @@ type Server struct {
        // Config may be changed after calling NewUnstartedServer and
        // before Start or StartTLS.
        Config *http.Server
+
+       // wg counts the number of outstanding HTTP requests on this server.
+       // Close blocks until all requests are finished.
+       wg sync.WaitGroup
 }
 
 // historyListener keeps track of all connections that it's ever
@@ -93,6 +98,7 @@ func (s *Server) Start() {
        }
        s.Listener = &historyListener{s.Listener, make([]net.Conn, 0)}
        s.URL = "http://" + s.Listener.Addr().String()
+       s.wrapHandler()
        go s.Config.Serve(s.Listener)
        if *serve != "" {
                fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
@@ -118,9 +124,21 @@ func (s *Server) StartTLS() {
 
        s.Listener = &historyListener{tlsListener, make([]net.Conn, 0)}
        s.URL = "https://" + s.Listener.Addr().String()
+       s.wrapHandler()
        go s.Config.Serve(s.Listener)
 }
 
+func (s *Server) wrapHandler() {
+       h := s.Config.Handler
+       if h == nil {
+               h = http.DefaultServeMux
+       }
+       s.Config.Handler = &waitGroupHandler{
+               s: s,
+               h: h,
+       }
+}
+
 // NewTLSServer starts and returns a new Server using TLS.
 // The caller should call Close when finished, to shut it down.
 func NewTLSServer(handler http.Handler) *Server {
@@ -129,9 +147,11 @@ func NewTLSServer(handler http.Handler) *Server {
        return ts
 }
 
-// Close shuts down the server.
+// Close shuts down the server and blocks until all outstanding
+// requests on this server have completed.
 func (s *Server) Close() {
        s.Listener.Close()
+       s.wg.Wait()
 }
 
 // CloseClientConnections closes any currently open HTTP connections
@@ -146,6 +166,20 @@ func (s *Server) CloseClientConnections() {
        }
 }
 
+// waitGroupHandler wraps a handler, incrementing and decrementing a
+// sync.WaitGroup on each request, to enable Server.Close to block
+// until outstanding requests are finished.
+type waitGroupHandler struct {
+       s *Server
+       h http.Handler // non-nil
+}
+
+func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+       h.s.wg.Add(1)
+       defer h.s.wg.Done() // a defer, in case ServeHTTP below panics
+       h.h.ServeHTTP(w, r)
+}
+
 // localhostCert is a PEM-encoded TLS cert with SAN DNS names
 // "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
 // of ASN.1 time).
index 6efa8ce1ca228de6cd938eb76dafa44ed11a3472..8ab72ac23f9af139a0f2011e502c5b98d3c19d1c 100644 (file)
@@ -129,9 +129,10 @@ func TestSniffWriteSize(t *testing.T) {
        }))
        defer ts.Close()
        for _, size := range []int{0, 1, 200, 600, 999, 1000, 1023, 1024, 512 << 10, 1 << 20} {
-               _, err := Get(fmt.Sprintf("%s/?size=%d", ts.URL, size))
+               res, err := Get(fmt.Sprintf("%s/?size=%d", ts.URL, size))
                if err != nil {
                        t.Fatalf("size %d: %v", size, err)
                }
+               res.Body.Close()
        }
 }