]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add optional Server.ConnState callback
authorBrad Fitzpatrick <bradfitz@golang.org>
Fri, 28 Feb 2014 02:29:00 +0000 (18:29 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 28 Feb 2014 02:29:00 +0000 (18:29 -0800)
Update #4674

This allows for all sorts of graceful shutdown policies,
without picking a policy (e.g. lameduck period) and without
adding lots of locking to the server core. That policy and
locking can be implemented outside of net/http now.

LGTM=adg
R=golang-codereviews, josharian, r, adg, dvyukov
CC=golang-codereviews
https://golang.org/cl/69260044

src/pkg/net/http/serve_test.go
src/pkg/net/http/server.go

index 7e306bb02193dbcf4e9f57ae4b706e0e0ce24b10..77d2b97a7db5928829fb37f9c17e139bd66f9d61 100644 (file)
@@ -26,6 +26,7 @@ import (
        "runtime"
        "strconv"
        "strings"
+       "sync"
        "sync/atomic"
        "syscall"
        "testing"
@@ -2243,6 +2244,120 @@ func TestAppendTime(t *testing.T) {
        }
 }
 
+func TestServerConnState(t *testing.T) {
+       defer afterTest(t)
+       handler := map[string]func(w ResponseWriter, r *Request){
+               "/": func(w ResponseWriter, r *Request) {
+                       fmt.Fprintf(w, "Hello.")
+               },
+               "/close": func(w ResponseWriter, r *Request) {
+                       w.Header().Set("Connection", "close")
+                       fmt.Fprintf(w, "Hello.")
+               },
+               "/hijack": func(w ResponseWriter, r *Request) {
+                       c, _, _ := w.(Hijacker).Hijack()
+                       c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
+                       c.Close()
+               },
+       }
+       ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               handler[r.URL.Path](w, r)
+       }))
+       defer ts.Close()
+
+       type connIDAndState struct {
+               connID int
+               state  ConnState
+       }
+       var mu sync.Mutex // guard stateLog and connID
+       var stateLog []connIDAndState
+       var connID = map[net.Conn]int{}
+
+       ts.Config.ConnState = func(c net.Conn, state ConnState) {
+               if c == nil {
+                       t.Error("nil conn seen in state %s", state)
+                       return
+               }
+               mu.Lock()
+               defer mu.Unlock()
+               id, ok := connID[c]
+               if !ok {
+                       id = len(connID) + 1
+                       connID[c] = id
+               }
+               stateLog = append(stateLog, connIDAndState{id, state})
+       }
+       ts.Start()
+
+       mustGet(t, ts.URL+"/")
+       mustGet(t, ts.URL+"/close")
+
+       mustGet(t, ts.URL+"/")
+       mustGet(t, ts.URL+"/", "Connection", "close")
+
+       mustGet(t, ts.URL+"/hijack")
+
+       want := []connIDAndState{
+               {1, StateNew},
+               {1, StateActive},
+               {1, StateIdle},
+               {1, StateActive},
+               {1, StateClosed},
+
+               {2, StateNew},
+               {2, StateActive},
+               {2, StateIdle},
+               {2, StateActive},
+               {2, StateClosed},
+
+               {3, StateNew},
+               {3, StateActive},
+               {3, StateHijacked},
+       }
+       logString := func(l []connIDAndState) string {
+               var b bytes.Buffer
+               for _, cs := range l {
+                       fmt.Fprintf(&b, "[%d %s] ", cs.connID, cs.state)
+               }
+               return b.String()
+       }
+
+       for i := 0; i < 5; i++ {
+               time.Sleep(time.Duration(i) * 50 * time.Millisecond)
+               mu.Lock()
+               match := reflect.DeepEqual(stateLog, want)
+               mu.Unlock()
+               if match {
+                       return
+               }
+       }
+
+       mu.Lock()
+       t.Errorf("Unexpected events.\nGot log: %s\n   Want: %s\n", logString(stateLog), logString(want))
+       mu.Unlock()
+}
+
+func mustGet(t *testing.T, url string, headers ...string) {
+       req, err := NewRequest("GET", url, nil)
+       if err != nil {
+               t.Fatal(err)
+       }
+       for len(headers) > 0 {
+               req.Header.Add(headers[0], headers[1])
+               headers = headers[2:]
+       }
+       res, err := DefaultClient.Do(req)
+       if err != nil {
+               t.Errorf("Error fetching %s: %v", url, err)
+               return
+       }
+       _, err = ioutil.ReadAll(res.Body)
+       defer res.Body.Close()
+       if err != nil {
+               t.Errorf("Error reading %s: %v", url, err)
+       }
+}
+
 func BenchmarkClientServer(b *testing.B) {
        b.ReportAllocs()
        b.StopTimer()
index fea1898fd7e794c3e98f973808c931a9f2acb22e..8ca48ab3ce0a4c4d790d5785bbd911916ca781e2 100644 (file)
@@ -1079,8 +1079,16 @@ func validNPN(proto string) bool {
        return true
 }
 
+func (c *conn) setState(nc net.Conn, state ConnState) {
+       if hook := c.server.ConnState; hook != nil {
+               hook(nc, state)
+       }
+}
+
 // Serve a new connection.
 func (c *conn) serve() {
+       origConn := c.rwc // copy it before it's set nil on Close or Hijack
+       c.setState(origConn, StateNew)
        defer func() {
                if err := recover(); err != nil {
                        const size = 64 << 10
@@ -1090,6 +1098,7 @@ func (c *conn) serve() {
                }
                if !c.hijacked() {
                        c.close()
+                       c.setState(origConn, StateClosed)
                }
        }()
 
@@ -1116,6 +1125,10 @@ func (c *conn) serve() {
 
        for {
                w, err := c.readRequest()
+               // TODO(bradfitz): could push this StateActive
+               // earlier, but in practice header will be all in one
+               // packet/Read:
+               c.setState(c.rwc, StateActive)
                if err != nil {
                        if err == errTooLarge {
                                // Their HTTP client may or may not be
@@ -1161,6 +1174,7 @@ func (c *conn) serve() {
                // in parallel even if their responses need to be serialized.
                serverHandler{c.server}.ServeHTTP(w, w.req)
                if c.hijacked() {
+                       c.setState(origConn, StateHijacked)
                        return
                }
                w.finishRequest()
@@ -1170,6 +1184,7 @@ func (c *conn) serve() {
                        }
                        break
                }
+               c.setState(c.rwc, StateIdle)
        }
 }
 
@@ -1580,6 +1595,58 @@ type Server struct {
        // and RemoteAddr if not already set.  The connection is
        // automatically closed when the function returns.
        TLSNextProto map[string]func(*Server, *tls.Conn, Handler)
+
+       // ConnState specifies an optional callback function that is
+       // called when a client connection changes state. See the
+       // ConnState type and associated constants for details.
+       ConnState func(net.Conn, ConnState)
+}
+
+// A ConnState represents the state of a client connection to a server.
+// It's used by the optional Server.ConnState hook.
+type ConnState int
+
+const (
+       // StateNew represents a new connection that is expected to
+       // send a request immediately. Connections begin at this
+       // state and then transition to either StateActive or
+       // StateClosed.
+       StateNew ConnState = iota
+
+       // StateActive represents a connection that has read 1 or more
+       // bytes of a request. The Server.ConnState hook for
+       // StateActive fires before the request has entered a handler
+       // and doesn't fire again until the request has been
+       // handled. After the request is handled, the state
+       // transitions to StateClosed, StateHijacked, or StateIdle.
+       StateActive
+
+       // StateIdle represents a connection that has finished
+       // handling a request and is in the keep-alive state, waiting
+       // for a new request. Connections transition from StateIdle
+       // to either StateActive or StateClosed.
+       StateIdle
+
+       // StateHijacked represents a hijacked connection.
+       // This is a terminal state. It does not transition to StateClosed.
+       StateHijacked
+
+       // StateClosed represents a closed connection.
+       // This is a terminal state. Hijacked connections do not
+       // transition to StateClosed.
+       StateClosed
+)
+
+var stateName = map[ConnState]string{
+       StateNew:      "new",
+       StateActive:   "active",
+       StateIdle:     "idle",
+       StateHijacked: "hijacked",
+       StateClosed:   "closed",
+}
+
+func (c ConnState) String() string {
+       return stateName[c]
 }
 
 // serverHandler delegates to either the server's Handler or