]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: prevent Server reuse after a Shutdown
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 4 Dec 2017 19:34:52 +0000 (19:34 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 3 Jul 2018 01:06:34 +0000 (01:06 +0000)
Fixes #20239

Change-Id: Icb021daad82e6905f536e4ef09ab219500b08167
Reviewed-on: https://go-review.googlesource.com/81778
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/net/http/serve_test.go
src/net/http/server.go

index 4e62cee054fa2b43fdc70688ecbab22e8305f874..e597ac35a48580fd08da01e7775f58bfc0df860f 100644 (file)
@@ -5980,6 +5980,21 @@ func TestServerCloseListenerOnce(t *testing.T) {
        }
 }
 
+// Issue 20239: don't block in Serve if Shutdown is called first.
+func TestServerShutdownThenServe(t *testing.T) {
+       var srv Server
+       cl := &countCloseListener{Listener: nil}
+       srv.Shutdown(context.Background())
+       got := srv.Serve(cl)
+       if got != ErrServerClosed {
+               t.Errorf("Serve err = %v; want ErrServerClosed", got)
+       }
+       nclose := atomic.LoadInt32(&cl.closes)
+       if nclose != 1 {
+               t.Errorf("Close calls = %v; want 1", nclose)
+       }
+}
+
 // Issue 23351: document and test behavior of ServeMux with ports
 func TestStripPortFromHost(t *testing.T) {
        mux := NewServeMux()
index 5349c39c61ac7378b404f0ad5255ca3534418f94..de77485bd697e5da48f8484674e37ae6c1224746 100644 (file)
@@ -2541,6 +2541,7 @@ func (s *Server) closeDoneChanLocked() {
 // Close returns any error returned from closing the Server's
 // underlying Listener(s).
 func (srv *Server) Close() error {
+       atomic.StoreInt32(&srv.inShutdown, 1)
        srv.mu.Lock()
        defer srv.mu.Unlock()
        srv.closeDoneChanLocked()
@@ -2578,9 +2579,11 @@ var shutdownPollInterval = 500 * time.Millisecond
 // separately notify such long-lived connections of shutdown and wait
 // for them to close, if desired. See RegisterOnShutdown for a way to
 // register shutdown notification functions.
+//
+// Once Shutdown has been called on a server, it may not be reused;
+// future calls to methods such as Serve will return ErrServerClosed.
 func (srv *Server) Shutdown(ctx context.Context) error {
-       atomic.AddInt32(&srv.inShutdown, 1)
-       defer atomic.AddInt32(&srv.inShutdown, -1)
+       atomic.StoreInt32(&srv.inShutdown, 1)
 
        srv.mu.Lock()
        lnerr := srv.closeListenersLocked()
@@ -2727,6 +2730,9 @@ func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) {
 // If srv.Addr is blank, ":http" is used.
 // ListenAndServe always returns a non-nil error.
 func (srv *Server) ListenAndServe() error {
+       if srv.shuttingDown() {
+               return ErrServerClosed
+       }
        addr := srv.Addr
        if addr == "" {
                addr = ":http"
@@ -2775,8 +2781,8 @@ var ErrServerClosed = errors.New("http: Server closed")
 // srv.TLSConfig is non-nil and doesn't include the string "h2" in
 // Config.NextProtos, HTTP/2 support is not enabled.
 //
-// Serve always returns a non-nil error. After Shutdown or Close, the
-// returned error is ErrServerClosed.
+// Serve always returns a non-nil error and closes l.
+// After Shutdown or Close, the returned error is ErrServerClosed.
 func (srv *Server) Serve(l net.Listener) error {
        if fn := testHookServerServe; fn != nil {
                fn(srv, l) // call hook with unwrapped listener
@@ -2785,15 +2791,19 @@ func (srv *Server) Serve(l net.Listener) error {
        l = &onceCloseListener{Listener: l}
        defer l.Close()
 
-       var tempDelay time.Duration // how long to sleep on accept failure
-
        if err := srv.setupHTTP2_Serve(); err != nil {
                return err
        }
 
-       srv.trackListener(&l, true)
+       serveDone := make(chan struct{})
+       defer close(serveDone)
+
+       if !srv.trackListener(&l, true) {
+               return ErrServerClosed
+       }
        defer srv.trackListener(&l, false)
 
+       var tempDelay time.Duration     // how long to sleep on accept failure
        baseCtx := context.Background() // base is always background, per Issue 16220
        ctx := context.WithValue(baseCtx, ServerContextKey, srv)
        for {
@@ -2877,13 +2887,18 @@ func (srv *Server) ServeTLS(l net.Listener, certFile, keyFile string) error {
 // trackListener via Serve and can track+defer untrack the same
 // pointer to local variable there. We never need to compare a
 // Listener from another caller.
-func (s *Server) trackListener(ln *net.Listener, add bool) {
+//
+// It reports whether the server is still up (not Shutdown or Closed).
+func (s *Server) trackListener(ln *net.Listener, add bool) bool {
        s.mu.Lock()
        defer s.mu.Unlock()
        if s.listeners == nil {
                s.listeners = make(map[*net.Listener]struct{})
        }
        if add {
+               if s.shuttingDown() {
+                       return false
+               }
                // If the *Server is being reused after a previous
                // Close or Shutdown, reset its doneChan:
                if len(s.listeners) == 0 && len(s.activeConn) == 0 {
@@ -2893,6 +2908,7 @@ func (s *Server) trackListener(ln *net.Listener, add bool) {
        } else {
                delete(s.listeners, ln)
        }
+       return true
 }
 
 func (s *Server) trackConn(c *conn, add bool) {
@@ -2927,6 +2943,8 @@ func (s *Server) doKeepAlives() bool {
 }
 
 func (s *Server) shuttingDown() bool {
+       // TODO: replace inShutdown with the existing atomicBool type;
+       // see https://github.com/golang/go/issues/20239#issuecomment-381434582
        return atomic.LoadInt32(&s.inShutdown) != 0
 }
 
@@ -3055,6 +3073,9 @@ func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
 //
 // ListenAndServeTLS always returns a non-nil error.
 func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
+       if srv.shuttingDown() {
+               return ErrServerClosed
+       }
        addr := srv.Addr
        if addr == "" {
                addr = ":https"