]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Server.Close & Server.Shutdown for forced & graceful shutdown
authorBrad Fitzpatrick <bradfitz@golang.org>
Sun, 30 Oct 2016 03:28:05 +0000 (03:28 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 1 Nov 2016 04:42:33 +0000 (04:42 +0000)
Also updates x/net/http2 to git rev 541150 for:

   http2: add support for graceful shutdown of Server
   https://golang.org/cl/32412

   http2: make http2.Server access http1's Server via an interface check
   https://golang.org/cl/32417

Fixes #4674
Fixes #9478

Change-Id: I8021a18dee0ef2fe3946ac1776d2b10d3d429052
Reviewed-on: https://go-review.googlesource.com/32329
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/http/export_test.go
src/net/http/h2_bundle.go
src/net/http/http_test.go
src/net/http/serve_test.go
src/net/http/server.go

index 00824e754cb36a6a4cd819f650227b0798204120..fbed45070c880e6abb543f59216a62b28340e7e3 100644 (file)
@@ -87,6 +87,12 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) {
        return
 }
 
+func (t *Transport) IdleConnKeyCountForTesting() int {
+       t.idleMu.Lock()
+       defer t.idleMu.Unlock()
+       return len(t.idleConn)
+}
+
 func (t *Transport) IdleConnStrsForTesting() []string {
        var ret []string
        t.idleMu.Lock()
index da7c02578cd5a69dac0b3aa4099eaf98375e338d..f8398adb9239bde4a9bf2688b76edfa22d0f41bc 100644 (file)
@@ -2982,10 +2982,6 @@ func (s *http2Server) maxConcurrentStreams() uint32 {
        return http2defaultMaxStreams
 }
 
-// List of funcs for ConfigureServer to run. Both h1 and h2 are guaranteed
-// to be non-nil.
-var http2configServerFuncs []func(h1 *Server, h2 *http2Server) error
-
 // ConfigureServer adds HTTP/2 support to a net/http Server.
 //
 // The configuration conf may be nil.
@@ -3512,6 +3508,11 @@ func (sc *http2serverConn) serve() {
                sc.idleTimerCh = sc.idleTimer.C
        }
 
+       var gracefulShutdownCh <-chan struct{}
+       if sc.hs != nil {
+               gracefulShutdownCh = http2h1ServerShutdownChan(sc.hs)
+       }
+
        go sc.readFrames()
 
        settingsTimer := time.NewTimer(http2firstSettingsTimeout)
@@ -3539,6 +3540,9 @@ func (sc *http2serverConn) serve() {
                case <-settingsTimer.C:
                        sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
                        return
+               case <-gracefulShutdownCh:
+                       gracefulShutdownCh = nil
+                       sc.goAwayIn(http2ErrCodeNo, 0)
                case <-sc.shutdownTimerCh:
                        sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
                        return
@@ -3548,6 +3552,10 @@ func (sc *http2serverConn) serve() {
                case fn := <-sc.testHookCh:
                        fn(loopNum)
                }
+
+               if sc.inGoAway && sc.curClientStreams == 0 && !sc.needToSendGoAway && !sc.writingFrame {
+                       return
+               }
        }
 }
 
@@ -3803,7 +3811,7 @@ func (sc *http2serverConn) scheduleFrameWrite() {
                        sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}})
                        continue
                }
-               if !sc.inGoAway {
+               if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo {
                        if wr, ok := sc.writeSched.Pop(); ok {
                                sc.startFrameWrite(wr)
                                continue
@@ -3821,14 +3829,23 @@ func (sc *http2serverConn) scheduleFrameWrite() {
 
 func (sc *http2serverConn) goAway(code http2ErrCode) {
        sc.serveG.check()
-       if sc.inGoAway {
-               return
-       }
+       var forceCloseIn time.Duration
        if code != http2ErrCodeNo {
-               sc.shutDownIn(250 * time.Millisecond)
+               forceCloseIn = 250 * time.Millisecond
        } else {
 
-               sc.shutDownIn(1 * time.Second)
+               forceCloseIn = 1 * time.Second
+       }
+       sc.goAwayIn(code, forceCloseIn)
+}
+
+func (sc *http2serverConn) goAwayIn(code http2ErrCode, forceCloseIn time.Duration) {
+       sc.serveG.check()
+       if sc.inGoAway {
+               return
+       }
+       if forceCloseIn != 0 {
+               sc.shutDownIn(forceCloseIn)
        }
        sc.inGoAway = true
        sc.needToSendGoAway = true
@@ -5264,6 +5281,31 @@ var http2badTrailer = map[string]bool{
        "Www-Authenticate":    true,
 }
 
+// h1ServerShutdownChan returns a channel that will be closed when the
+// provided *http.Server wants to shut down.
+//
+// This is a somewhat hacky way to get at http1 innards. It works
+// when the http2 code is bundled into the net/http package in the
+// standard library. The alternatives ended up making the cmd/go tool
+// depend on http Servers. This is the lightest option for now.
+// This is tested via the TestServeShutdown* tests in net/http.
+func http2h1ServerShutdownChan(hs *Server) <-chan struct{} {
+       if fn := http2testh1ServerShutdownChan; fn != nil {
+               return fn(hs)
+       }
+       var x interface{} = hs
+       type I interface {
+               getDoneChan() <-chan struct{}
+       }
+       if hs, ok := x.(I); ok {
+               return hs.getDoneChan()
+       }
+       return nil
+}
+
+// optional test hook for h1ServerShutdownChan.
+var http2testh1ServerShutdownChan func(hs *Server) <-chan struct{}
+
 const (
        // transportDefaultConnFlow is how many connection-level flow control
        // tokens we give the server at start-up, past the default 64k.
index c6c38ffcaedaeade9fef9e14adfd1b2470db3a05..aaae67cf2918e8c5c51ae072ddea2f396afe4cbf 100644 (file)
@@ -12,8 +12,13 @@ import (
        "os/exec"
        "reflect"
        "testing"
+       "time"
 )
 
+func init() {
+       shutdownPollInterval = 5 * time.Millisecond
+}
+
 func TestForeachHeaderElement(t *testing.T) {
        tests := []struct {
                in   string
index 7834478352cd62769484227d9f5122e0375651a3..f855c358221dadcb2962d9bc6a28ab156a562cbf 100644 (file)
@@ -4832,3 +4832,96 @@ func TestServerIdleTimeout(t *testing.T) {
                t.Fatal("copy byte succeeded; want err")
        }
 }
+
+func get(t *testing.T, c *Client, url string) string {
+       res, err := c.Get(url)
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer res.Body.Close()
+       slurp, err := ioutil.ReadAll(res.Body)
+       if err != nil {
+               t.Fatal(err)
+       }
+       return string(slurp)
+}
+
+// Tests that calls to Server.SetKeepAlivesEnabled(false) closes any
+// currently-open connections.
+func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
+       if runtime.GOOS == "nacl" {
+               t.Skip("skipping on nacl; see golang.org/issue/17695")
+       }
+       defer afterTest(t)
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               io.WriteString(w, r.RemoteAddr)
+       }))
+       defer ts.Close()
+
+       tr := &Transport{}
+       defer tr.CloseIdleConnections()
+       c := &Client{Transport: tr}
+
+       get := func() string { return get(t, c, ts.URL) }
+
+       a1, a2 := get(), get()
+       if a1 != a2 {
+               t.Fatal("expected first two requests on same connection")
+       }
+       var idle0 int
+       if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool {
+               idle0 = tr.IdleConnKeyCountForTesting()
+               return idle0 == 1
+       }) {
+               t.Fatalf("idle count before SetKeepAlivesEnabled called = %v; want 1", idle0)
+       }
+
+       ts.Config.SetKeepAlivesEnabled(false)
+
+       var idle1 int
+       if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool {
+               idle1 = tr.IdleConnKeyCountForTesting()
+               return idle1 == 0
+       }) {
+               t.Fatalf("idle count after SetKeepAlivesEnabled called = %v; want 0", idle1)
+       }
+
+       a3 := get()
+       if a3 == a2 {
+               t.Fatal("expected third request on new connection")
+       }
+}
+
+func TestServerShutdown_h1(t *testing.T) { testServerShutdown(t, h1Mode) }
+func TestServerShutdown_h2(t *testing.T) { testServerShutdown(t, h2Mode) }
+
+func testServerShutdown(t *testing.T, h2 bool) {
+       defer afterTest(t)
+       var doShutdown func() // set later
+       var shutdownRes = make(chan error, 1)
+       cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+               go doShutdown()
+               // Shutdown is graceful, so it should not interrupt
+               // this in-flight response. Add a tiny sleep here to
+               // increase the odds of a failure if shutdown has
+               // bugs.
+               time.Sleep(20 * time.Millisecond)
+               io.WriteString(w, r.RemoteAddr)
+       }))
+       defer cst.close()
+
+       doShutdown = func() {
+               shutdownRes <- cst.ts.Config.Shutdown(context.Background())
+       }
+       get(t, cst.c, cst.ts.URL) // calls t.Fail on failure
+
+       if err := <-shutdownRes; err != nil {
+               t.Fatalf("Shutdown: %v", err)
+       }
+
+       res, err := cst.c.Get(cst.ts.URL)
+       if err == nil {
+               res.Body.Close()
+               t.Fatal("second request should fail. server should be shut down")
+       }
+}
index 51a66c37d53f0c1af3e5f73e2bcdb1945ed43ef8..eae065f6730f40a8e7d54d38e5cd07b0ba35262b 100644 (file)
@@ -248,6 +248,8 @@ type conn struct {
 
        curReq atomic.Value // of *response (which has a Request in it)
 
+       curState atomic.Value // of ConnectionState
+
        // mu guards hijackedv
        mu sync.Mutex
 
@@ -1586,11 +1588,30 @@ func validNPN(proto string) bool {
 }
 
 func (c *conn) setState(nc net.Conn, state ConnState) {
-       if hook := c.server.ConnState; hook != nil {
+       srv := c.server
+       switch state {
+       case StateNew:
+               srv.trackConn(c, true)
+       case StateHijacked, StateClosed:
+               srv.trackConn(c, false)
+       }
+       c.curState.Store(connStateInterface[state])
+       if hook := srv.ConnState; hook != nil {
                hook(nc, state)
        }
 }
 
+// connStateInterface is an array of the interface{} versions of
+// ConnState values, so we can use them in atomic.Values later without
+// paying the cost of shoving their integers in an interface{}.
+var connStateInterface = [...]interface{}{
+       StateNew:      StateNew,
+       StateActive:   StateActive,
+       StateIdle:     StateIdle,
+       StateHijacked: StateHijacked,
+       StateClosed:   StateClosed,
+}
+
 // badRequestError is a literal string (used by in the server in HTML,
 // unescaped) to tell the user why their request was bad. It should
 // be plain text without user info or other embedded errors.
@@ -2247,8 +2268,120 @@ type Server struct {
        ErrorLog *log.Logger
 
        disableKeepAlives int32     // accessed atomically.
+       inShutdown        int32     // accessed atomically (non-zero means we're in Shutdown)
        nextProtoOnce     sync.Once // guards setupHTTP2_* init
        nextProtoErr      error     // result of http2.ConfigureServer if used
+
+       mu         sync.Mutex
+       listeners  map[net.Listener]struct{}
+       activeConn map[*conn]struct{}
+       doneChan   chan struct{}
+}
+
+func (s *Server) getDoneChan() <-chan struct{} {
+       s.mu.Lock()
+       defer s.mu.Unlock()
+       return s.getDoneChanLocked()
+}
+
+func (s *Server) getDoneChanLocked() chan struct{} {
+       if s.doneChan == nil {
+               s.doneChan = make(chan struct{})
+       }
+       return s.doneChan
+}
+
+func (s *Server) closeDoneChanLocked() {
+       ch := s.getDoneChanLocked()
+       select {
+       case <-ch:
+               // Already closed. Don't close again.
+       default:
+               // Safe to close here. We're the only closer, guarded
+               // by s.mu.
+               close(ch)
+       }
+}
+
+// Close immediately closes all active net.Listeners and connections,
+// regardless of their state. For a graceful shutdown, use Shutdown.
+func (s *Server) Close() error {
+       s.mu.Lock()
+       defer s.mu.Lock()
+       s.closeDoneChanLocked()
+       err := s.closeListenersLocked()
+       for c := range s.activeConn {
+               c.rwc.Close()
+               delete(s.activeConn, c)
+       }
+       return err
+}
+
+// shutdownPollInterval is how often we poll for quiescence
+// during Server.Shutdown. This is lower during tests, to
+// speed up tests.
+// Ideally we could find a solution that doesn't involve polling,
+// but which also doesn't have a high runtime cost (and doesn't
+// involve any contentious mutexes), but that is left as an
+// exercise for the reader.
+var shutdownPollInterval = 500 * time.Millisecond
+
+// Shutdown gracefully shuts down the server without interrupting any
+// active connections. Shutdown works by first closing all open
+// listeners, then closing all idle connections, and then waiting
+// indefinitely for connections to return to idle and then shut down.
+// If the provided context expires before the shutdown is complete,
+// then the context's error is returned.
+func (s *Server) Shutdown(ctx context.Context) error {
+       atomic.AddInt32(&s.inShutdown, 1)
+       defer atomic.AddInt32(&s.inShutdown, -1)
+
+       s.mu.Lock()
+       lnerr := s.closeListenersLocked()
+       s.closeDoneChanLocked()
+       s.mu.Unlock()
+
+       ticker := time.NewTicker(shutdownPollInterval)
+       defer ticker.Stop()
+       for {
+               if s.closeIdleConns() {
+                       return lnerr
+               }
+               select {
+               case <-ctx.Done():
+                       return ctx.Err()
+               case <-ticker.C:
+               }
+       }
+}
+
+// closeIdleConns closes all idle connections and reports whether the
+// server is quiescent.
+func (s *Server) closeIdleConns() bool {
+       s.mu.Lock()
+       defer s.mu.Unlock()
+       quiescent := true
+       for c := range s.activeConn {
+               st, ok := c.curState.Load().(ConnState)
+               if !ok || st != StateIdle {
+                       quiescent = false
+                       continue
+               }
+               c.rwc.Close()
+               delete(s.activeConn, c)
+       }
+       return quiescent
+}
+
+func (s *Server) closeListenersLocked() error {
+       var err error
+       for ln := range s.listeners {
+               if cerr := ln.Close(); cerr != nil && err == nil {
+                       err = cerr
+               }
+               delete(s.listeners, ln)
+       }
+       return err
 }
 
 // A ConnState represents the state of a client connection to a server.
@@ -2361,6 +2494,8 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool {
        return strSliceContains(srv.TLSConfig.NextProtos, http2NextProtoTLS)
 }
 
+var ErrServerClosed = errors.New("http: Server closed")
+
 // Serve accepts incoming connections on the Listener l, creating a
 // new service goroutine for each. The service goroutines read requests and
 // then call srv.Handler to reply to them.
@@ -2370,7 +2505,8 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool {
 // srv.TLSConfig is non-nil and doesn't include the string "h2" in
 // Config.NextProtos, HTTP/2 support is not enabled.
 //
-// Serve always returns a non-nil error.
+// Serve always returns a non-nil error. After Shutdown or Close, the
+// returned error is ErrServerClosed.
 func (srv *Server) Serve(l net.Listener) error {
        defer l.Close()
        if fn := testHookServerServe; fn != nil {
@@ -2382,12 +2518,20 @@ func (srv *Server) Serve(l net.Listener) error {
                return err
        }
 
+       srv.trackListener(l, true)
+       defer srv.trackListener(l, false)
+
        baseCtx := context.Background() // base is always background, per Issue 16220
        ctx := context.WithValue(baseCtx, ServerContextKey, srv)
        ctx = context.WithValue(ctx, LocalAddrContextKey, l.Addr())
        for {
                rw, e := l.Accept()
                if e != nil {
+                       select {
+                       case <-srv.getDoneChan():
+                               return ErrServerClosed
+                       default:
+                       }
                        if ne, ok := e.(net.Error); ok && ne.Temporary() {
                                if tempDelay == 0 {
                                        tempDelay = 5 * time.Millisecond
@@ -2410,6 +2554,37 @@ func (srv *Server) Serve(l net.Listener) error {
        }
 }
 
+func (s *Server) trackListener(ln net.Listener, add bool) {
+       s.mu.Lock()
+       defer s.mu.Unlock()
+       if s.listeners == nil {
+               s.listeners = make(map[net.Listener]struct{})
+       }
+       if add {
+               // If the *Server is being reused after a previous
+               // Close or Shutdown, reset its doneChan:
+               if len(s.listeners) == 0 && len(s.activeConn) == 0 {
+                       s.doneChan = nil
+               }
+               s.listeners[ln] = struct{}{}
+       } else {
+               delete(s.listeners, ln)
+       }
+}
+
+func (s *Server) trackConn(c *conn, add bool) {
+       s.mu.Lock()
+       defer s.mu.Unlock()
+       if s.activeConn == nil {
+               s.activeConn = make(map[*conn]struct{})
+       }
+       if add {
+               s.activeConn[c] = struct{}{}
+       } else {
+               delete(s.activeConn, c)
+       }
+}
+
 func (s *Server) idleTimeout() time.Duration {
        if s.IdleTimeout != 0 {
                return s.IdleTimeout
@@ -2425,7 +2600,11 @@ func (s *Server) readHeaderTimeout() time.Duration {
 }
 
 func (s *Server) doKeepAlives() bool {
-       return atomic.LoadInt32(&s.disableKeepAlives) == 0
+       return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown()
+}
+
+func (s *Server) shuttingDown() bool {
+       return atomic.LoadInt32(&s.inShutdown) != 0
 }
 
 // SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled.
@@ -2435,9 +2614,21 @@ func (s *Server) doKeepAlives() bool {
 func (srv *Server) SetKeepAlivesEnabled(v bool) {
        if v {
                atomic.StoreInt32(&srv.disableKeepAlives, 0)
-       } else {
-               atomic.StoreInt32(&srv.disableKeepAlives, 1)
+               return
        }
+       atomic.StoreInt32(&srv.disableKeepAlives, 1)
+
+       // Close idle HTTP/1 conns:
+       srv.closeIdleConns()
+
+       // Close HTTP/2 conns, as soon as they become idle, but reset
+       // the chan so future conns (if the listener is still active)
+       // still work and don't get a GOAWAY immediately, before their
+       // first request:
+       srv.mu.Lock()
+       defer srv.mu.Unlock()
+       srv.closeDoneChanLocked() // closes http2 conns
+       srv.doneChan = nil
 }
 
 func (s *Server) logf(format string, args ...interface{}) {