]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httptest: change Server to use http.Server.ConnState for accounting
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 29 Sep 2015 21:26:48 +0000 (14:26 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 20 Oct 2015 00:07:19 +0000 (00:07 +0000)
With this CL, httptest.Server now uses connection-level accounting of
outstanding requests instead of ServeHTTP-level accounting. This is
more robust and results in a non-racy shutdown.

This is much easier now that net/http.Server has the ConnState hook.

Fixes #12789
Fixes #12781

Change-Id: I098cf334a6494316acb66cd07df90766df41764b
Reviewed-on: https://go-review.googlesource.com/15151
Reviewed-by: Andrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/http/httptest/server.go
src/net/http/httptest/server_test.go

index 96eb0ef6d2f147b74e86e7c9b0a09ef7afec3d57..e4f680fe92e361dbe40498e8afee694e51b7afe2 100644 (file)
@@ -7,13 +7,17 @@
 package httptest
 
 import (
+       "bytes"
        "crypto/tls"
        "flag"
        "fmt"
+       "log"
        "net"
        "net/http"
        "os"
+       "runtime"
        "sync"
+       "time"
 )
 
 // A Server is an HTTP server listening on a system-chosen port on the
@@ -34,24 +38,10 @@ type Server struct {
        // wg counts the number of outstanding HTTP requests on this server.
        // Close blocks until all requests are finished.
        wg sync.WaitGroup
-}
-
-// historyListener keeps track of all connections that it's ever
-// accepted.
-type historyListener struct {
-       net.Listener
-       sync.Mutex // protects history
-       history    []net.Conn
-}
 
-func (hs *historyListener) Accept() (c net.Conn, err error) {
-       c, err = hs.Listener.Accept()
-       if err == nil {
-               hs.Lock()
-               hs.history = append(hs.history, c)
-               hs.Unlock()
-       }
-       return
+       mu     sync.Mutex // guards closed and conns
+       closed bool
+       conns  map[net.Conn]http.ConnState // except terminal states
 }
 
 func newLocalListener() net.Listener {
@@ -103,10 +93,9 @@ func (s *Server) Start() {
        if s.URL != "" {
                panic("Server already started")
        }
-       s.Listener = &historyListener{Listener: s.Listener}
        s.URL = "http://" + s.Listener.Addr().String()
-       s.wrapHandler()
-       go s.Config.Serve(s.Listener)
+       s.wrap()
+       s.goServe()
        if *serve != "" {
                fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
                select {}
@@ -134,23 +123,10 @@ func (s *Server) StartTLS() {
        if len(s.TLS.Certificates) == 0 {
                s.TLS.Certificates = []tls.Certificate{cert}
        }
-       tlsListener := tls.NewListener(s.Listener, s.TLS)
-
-       s.Listener = &historyListener{Listener: tlsListener}
+       s.Listener = tls.NewListener(s.Listener, s.TLS)
        s.URL = "https://" + s.Listener.Addr().String()
-       s.wrapHandler()
-       go s.Config.Serve(s.Listener)
-}
-
-func (s *Server) wrapHandler() {
-       h := s.Config.Handler
-       if h == nil {
-               h = http.DefaultServeMux
-       }
-       s.Config.Handler = &waitGroupHandler{
-               s: s,
-               h: h,
-       }
+       s.wrap()
+       s.goServe()
 }
 
 // NewTLSServer starts and returns a new Server using TLS.
@@ -161,43 +137,139 @@ func NewTLSServer(handler http.Handler) *Server {
        return ts
 }
 
+type closeIdleTransport interface {
+       CloseIdleConnections()
+}
+
 // Close shuts down the server and blocks until all outstanding
 // requests on this server have completed.
 func (s *Server) Close() {
-       s.Listener.Close()
-       s.wg.Wait()
-       s.CloseClientConnections()
-       if t, ok := http.DefaultTransport.(*http.Transport); ok {
+       s.mu.Lock()
+       if !s.closed {
+               s.closed = true
+               s.Listener.Close()
+               s.Config.SetKeepAlivesEnabled(false)
+               for c, st := range s.conns {
+                       if st == http.StateIdle {
+                               s.closeConn(c)
+                       }
+               }
+               // If this server doesn't shut down in 5 seconds, tell the user why.
+               t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
+               defer t.Stop()
+       }
+       s.mu.Unlock()
+
+       // Not part of httptest.Server's correctness, but assume most
+       // users of httptest.Server will be using the standard
+       // transport, so help them out and close any idle connections for them.
+       if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
                t.CloseIdleConnections()
        }
+
+       s.wg.Wait()
 }
 
-// CloseClientConnections closes any currently open HTTP connections
-// to the test Server.
+func (s *Server) logCloseHangDebugInfo() {
+       s.mu.Lock()
+       defer s.mu.Unlock()
+       var buf bytes.Buffer
+       buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
+       for c, st := range s.conns {
+               fmt.Fprintf(&buf, "  %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
+       }
+       log.Print(buf.String())
+}
+
+// CloseClientConnections closes any open HTTP connections to the test Server.
 func (s *Server) CloseClientConnections() {
-       hl, ok := s.Listener.(*historyListener)
-       if !ok {
-               return
+       s.mu.Lock()
+       defer s.mu.Unlock()
+       for c := range s.conns {
+               s.closeConn(c)
        }
-       hl.Lock()
-       for _, conn := range hl.history {
-               conn.Close()
+}
+
+func (s *Server) goServe() {
+       s.wg.Add(1)
+       go func() {
+               defer s.wg.Done()
+               s.Config.Serve(s.Listener)
+       }()
+}
+
+// wrap installs the connection state-tracking hook to know which
+// connections are idle.
+func (s *Server) wrap() {
+       oldHook := s.Config.ConnState
+       s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
+               s.mu.Lock()
+               defer s.mu.Unlock()
+               switch cs {
+               case http.StateNew:
+                       s.wg.Add(1)
+                       if _, exists := s.conns[c]; exists {
+                               panic("invalid state transition")
+                       }
+                       if s.conns == nil {
+                               s.conns = make(map[net.Conn]http.ConnState)
+                       }
+                       s.conns[c] = cs
+                       if s.closed {
+                               // Probably just a socket-late-binding dial from
+                               // the default transport that lost the race (and
+                               // thus this connection is now idle and will
+                               // never be used).
+                               s.closeConn(c)
+                       }
+               case http.StateActive:
+                       if oldState, ok := s.conns[c]; ok {
+                               if oldState != http.StateNew && oldState != http.StateIdle {
+                                       panic("invalid state transition")
+                               }
+                               s.conns[c] = cs
+                       }
+               case http.StateIdle:
+                       if oldState, ok := s.conns[c]; ok {
+                               if oldState != http.StateActive {
+                                       panic("invalid state transition")
+                               }
+                               s.conns[c] = cs
+                       }
+                       if s.closed {
+                               s.closeConn(c)
+                       }
+               case http.StateHijacked, http.StateClosed:
+                       s.forgetConn(c)
+               }
+               if oldHook != nil {
+                       oldHook(c, cs)
+               }
        }
-       hl.Unlock()
 }
 
-// waitGroupHandler wraps a handler, incrementing and decrementing a
-// sync.WaitGroup on each request, to enable Server.Close to block
-// until outstanding requests are finished.
-type waitGroupHandler struct {
-       s *Server
-       h http.Handler // non-nil
+// closeConn closes c. Except on plan9, which is special. See comment below.
+// s.mu must be held.
+func (s *Server) closeConn(c net.Conn) {
+       if runtime.GOOS == "plan9" {
+               // Go's Plan 9 net package isn't great at unblocking reads when
+               // their underlying TCP connections are closed.  Don't trust
+               // that that the ConnState state machine will get to
+               // StateClosed. Instead, just go there directly. Plan 9 may leak
+               // resources if the syscall doesn't end up returning. Oh well.
+               s.forgetConn(c)
+       }
+       go c.Close()
 }
 
-func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-       h.s.wg.Add(1)
-       defer h.s.wg.Done() // a defer, in case ServeHTTP below panics
-       h.h.ServeHTTP(w, r)
+// forgetConn removes c from the set of tracked conns and decrements it from the
+// waitgroup, unless it was previously removed.
+// s.mu must be held.
+func (s *Server) forgetConn(c net.Conn) {
+       if _, ok := s.conns[c]; ok {
+               delete(s.conns, c)
+               s.wg.Done()
+       }
 }
 
 // localhostCert is a PEM-encoded TLS cert with SAN IPs
index 500a9f0b80000b93a1c85f9f6fd1689e487c7239..90901ceb76170dc95a25222190ffc38340aa6846 100644 (file)
@@ -27,3 +27,30 @@ func TestServer(t *testing.T) {
                t.Errorf("got %q, want hello", string(got))
        }
 }
+
+// Issue 12781
+func TestGetAfterClose(t *testing.T) {
+       ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               w.Write([]byte("hello"))
+       }))
+
+       res, err := http.Get(ts.URL)
+       if err != nil {
+               t.Fatal(err)
+       }
+       got, err := ioutil.ReadAll(res.Body)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if string(got) != "hello" {
+               t.Fatalf("got %q, want hello", string(got))
+       }
+
+       ts.Close()
+
+       res, err = http.Get(ts.URL)
+       if err == nil {
+               body, _ := ioutil.ReadAll(res.Body)
+               t.Fatalf("Unexected response after close: %v, %v, %s", res.Status, res.Header, body)
+       }
+}