"runtime"
"strconv"
"strings"
+ "sync"
"sync/atomic"
"syscall"
"testing"
}
}
+func TestServerConnState(t *testing.T) {
+ defer afterTest(t)
+ handler := map[string]func(w ResponseWriter, r *Request){
+ "/": func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Hello.")
+ },
+ "/close": func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ fmt.Fprintf(w, "Hello.")
+ },
+ "/hijack": func(w ResponseWriter, r *Request) {
+ c, _, _ := w.(Hijacker).Hijack()
+ c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
+ c.Close()
+ },
+ }
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ handler[r.URL.Path](w, r)
+ }))
+ defer ts.Close()
+
+ type connIDAndState struct {
+ connID int
+ state ConnState
+ }
+ var mu sync.Mutex // guard stateLog and connID
+ var stateLog []connIDAndState
+ var connID = map[net.Conn]int{}
+
+ ts.Config.ConnState = func(c net.Conn, state ConnState) {
+ if c == nil {
+ t.Error("nil conn seen in state %s", state)
+ return
+ }
+ mu.Lock()
+ defer mu.Unlock()
+ id, ok := connID[c]
+ if !ok {
+ id = len(connID) + 1
+ connID[c] = id
+ }
+ stateLog = append(stateLog, connIDAndState{id, state})
+ }
+ ts.Start()
+
+ mustGet(t, ts.URL+"/")
+ mustGet(t, ts.URL+"/close")
+
+ mustGet(t, ts.URL+"/")
+ mustGet(t, ts.URL+"/", "Connection", "close")
+
+ mustGet(t, ts.URL+"/hijack")
+
+ want := []connIDAndState{
+ {1, StateNew},
+ {1, StateActive},
+ {1, StateIdle},
+ {1, StateActive},
+ {1, StateClosed},
+
+ {2, StateNew},
+ {2, StateActive},
+ {2, StateIdle},
+ {2, StateActive},
+ {2, StateClosed},
+
+ {3, StateNew},
+ {3, StateActive},
+ {3, StateHijacked},
+ }
+ logString := func(l []connIDAndState) string {
+ var b bytes.Buffer
+ for _, cs := range l {
+ fmt.Fprintf(&b, "[%d %s] ", cs.connID, cs.state)
+ }
+ return b.String()
+ }
+
+ for i := 0; i < 5; i++ {
+ time.Sleep(time.Duration(i) * 50 * time.Millisecond)
+ mu.Lock()
+ match := reflect.DeepEqual(stateLog, want)
+ mu.Unlock()
+ if match {
+ return
+ }
+ }
+
+ mu.Lock()
+ t.Errorf("Unexpected events.\nGot log: %s\n Want: %s\n", logString(stateLog), logString(want))
+ mu.Unlock()
+}
+
+func mustGet(t *testing.T, url string, headers ...string) {
+ req, err := NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for len(headers) > 0 {
+ req.Header.Add(headers[0], headers[1])
+ headers = headers[2:]
+ }
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Errorf("Error fetching %s: %v", url, err)
+ return
+ }
+ _, err = ioutil.ReadAll(res.Body)
+ defer res.Body.Close()
+ if err != nil {
+ t.Errorf("Error reading %s: %v", url, err)
+ }
+}
+
func BenchmarkClientServer(b *testing.B) {
b.ReportAllocs()
b.StopTimer()
return true
}
+func (c *conn) setState(nc net.Conn, state ConnState) {
+ if hook := c.server.ConnState; hook != nil {
+ hook(nc, state)
+ }
+}
+
// Serve a new connection.
func (c *conn) serve() {
+ origConn := c.rwc // copy it before it's set nil on Close or Hijack
+ c.setState(origConn, StateNew)
defer func() {
if err := recover(); err != nil {
const size = 64 << 10
}
if !c.hijacked() {
c.close()
+ c.setState(origConn, StateClosed)
}
}()
for {
w, err := c.readRequest()
+ // TODO(bradfitz): could push this StateActive
+ // earlier, but in practice header will be all in one
+ // packet/Read:
+ c.setState(c.rwc, StateActive)
if err != nil {
if err == errTooLarge {
// Their HTTP client may or may not be
// in parallel even if their responses need to be serialized.
serverHandler{c.server}.ServeHTTP(w, w.req)
if c.hijacked() {
+ c.setState(origConn, StateHijacked)
return
}
w.finishRequest()
}
break
}
+ c.setState(c.rwc, StateIdle)
}
}
// and RemoteAddr if not already set. The connection is
// automatically closed when the function returns.
TLSNextProto map[string]func(*Server, *tls.Conn, Handler)
+
+ // ConnState specifies an optional callback function that is
+ // called when a client connection changes state. See the
+ // ConnState type and associated constants for details.
+ ConnState func(net.Conn, ConnState)
+}
+
+// A ConnState represents the state of a client connection to a server.
+// It's used by the optional Server.ConnState hook.
+type ConnState int
+
+const (
+ // StateNew represents a new connection that is expected to
+ // send a request immediately. Connections begin at this
+ // state and then transition to either StateActive or
+ // StateClosed.
+ StateNew ConnState = iota
+
+ // StateActive represents a connection that has read 1 or more
+ // bytes of a request. The Server.ConnState hook for
+ // StateActive fires before the request has entered a handler
+ // and doesn't fire again until the request has been
+ // handled. After the request is handled, the state
+ // transitions to StateClosed, StateHijacked, or StateIdle.
+ StateActive
+
+ // StateIdle represents a connection that has finished
+ // handling a request and is in the keep-alive state, waiting
+ // for a new request. Connections transition from StateIdle
+ // to either StateActive or StateClosed.
+ StateIdle
+
+ // StateHijacked represents a hijacked connection.
+ // This is a terminal state. It does not transition to StateClosed.
+ StateHijacked
+
+ // StateClosed represents a closed connection.
+ // This is a terminal state. Hijacked connections do not
+ // transition to StateClosed.
+ StateClosed
+)
+
+var stateName = map[ConnState]string{
+ StateNew: "new",
+ StateActive: "active",
+ StateIdle: "idle",
+ StateHijacked: "hijacked",
+ StateClosed: "closed",
+}
+
+func (c ConnState) String() string {
+ return stateName[c]
}
// serverHandler delegates to either the server's Handler or