From: Brad Fitzpatrick Date: Wed, 29 Feb 2012 20:18:26 +0000 (-0800) Subject: net/http/httptest: make Server.Close wait for outstanding requests to finish X-Git-Tag: weekly.2012-03-04~84 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=8f0bfc5a29ab942af5b8dd1caf143383a90c2170;p=gostls13.git net/http/httptest: make Server.Close wait for outstanding requests to finish Might fix issue 3050 R=golang-dev, rsc CC=golang-dev https://golang.org/cl/5708066 --- diff --git a/src/pkg/net/http/httptest/server.go b/src/pkg/net/http/httptest/server.go index 8d911f7575..57cf0c9417 100644 --- a/src/pkg/net/http/httptest/server.go +++ b/src/pkg/net/http/httptest/server.go @@ -13,6 +13,7 @@ import ( "net" "net/http" "os" + "sync" ) // A Server is an HTTP server listening on a system-chosen port on the @@ -25,6 +26,10 @@ type Server struct { // Config may be changed after calling NewUnstartedServer and // before Start or StartTLS. Config *http.Server + + // wg counts the number of outstanding HTTP requests on this server. + // Close blocks until all requests are finished. + wg sync.WaitGroup } // historyListener keeps track of all connections that it's ever @@ -93,6 +98,7 @@ func (s *Server) Start() { } s.Listener = &historyListener{s.Listener, make([]net.Conn, 0)} s.URL = "http://" + s.Listener.Addr().String() + s.wrapHandler() go s.Config.Serve(s.Listener) if *serve != "" { fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL) @@ -118,9 +124,21 @@ func (s *Server) StartTLS() { s.Listener = &historyListener{tlsListener, make([]net.Conn, 0)} s.URL = "https://" + s.Listener.Addr().String() + s.wrapHandler() go s.Config.Serve(s.Listener) } +func (s *Server) wrapHandler() { + h := s.Config.Handler + if h == nil { + h = http.DefaultServeMux + } + s.Config.Handler = &waitGroupHandler{ + s: s, + h: h, + } +} + // NewTLSServer starts and returns a new Server using TLS. // The caller should call Close when finished, to shut it down. func NewTLSServer(handler http.Handler) *Server { @@ -129,9 +147,11 @@ func NewTLSServer(handler http.Handler) *Server { return ts } -// Close shuts down the server. +// Close shuts down the server and blocks until all outstanding +// requests on this server have completed. func (s *Server) Close() { s.Listener.Close() + s.wg.Wait() } // CloseClientConnections closes any currently open HTTP connections @@ -146,6 +166,20 @@ func (s *Server) CloseClientConnections() { } } +// waitGroupHandler wraps a handler, incrementing and decrementing a +// sync.WaitGroup on each request, to enable Server.Close to block +// until outstanding requests are finished. +type waitGroupHandler struct { + s *Server + h http.Handler // non-nil +} + +func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.s.wg.Add(1) + defer h.s.wg.Done() // a defer, in case ServeHTTP below panics + h.h.ServeHTTP(w, r) +} + // localhostCert is a PEM-encoded TLS cert with SAN DNS names // "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end // of ASN.1 time). diff --git a/src/pkg/net/http/sniff_test.go b/src/pkg/net/http/sniff_test.go index 6efa8ce1ca..8ab72ac23f 100644 --- a/src/pkg/net/http/sniff_test.go +++ b/src/pkg/net/http/sniff_test.go @@ -129,9 +129,10 @@ func TestSniffWriteSize(t *testing.T) { })) defer ts.Close() for _, size := range []int{0, 1, 200, 600, 999, 1000, 1023, 1024, 512 << 10, 1 << 20} { - _, err := Get(fmt.Sprintf("%s/?size=%d", ts.URL, size)) + res, err := Get(fmt.Sprintf("%s/?size=%d", ts.URL, size)) if err != nil { t.Fatalf("size %d: %v", size, err) } + res.Body.Close() } }