]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: fix location of StateHijacked and StateActive
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 4 Mar 2014 02:58:28 +0000 (18:58 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 4 Mar 2014 02:58:28 +0000 (18:58 -0800)
1) Move StateHijacked callback earlier, to make it
panic-proof.  A Hijack followed by a panic didn't previously
result in ConnState getting fired for StateHijacked.  Move it
earlier, to the time of hijack.

2) Don't fire StateActive unless any bytes were read off the
wire while waiting for a request. This means we don't
transition from New or Idle to Active if the client
disconnects or times out. This was documented before, but not
implemented properly.

This CL is required for an pending fix for Issue 7264

LGTM=josharian
R=josharian
CC=golang-codereviews
https://golang.org/cl/69860049

src/pkg/net/http/serve_test.go
src/pkg/net/http/server.go

index 36832140b409e7142a6bd47a8dec5ba0697c45a0..21cd67f9dca2a62093ad316abb6d6997372b23fb 100644 (file)
@@ -2270,6 +2270,12 @@ func TestServerConnState(t *testing.T) {
                        c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
                        c.Close()
                },
+               "/hijack-panic": func(w ResponseWriter, r *Request) {
+                       c, _, _ := w.(Hijacker).Hijack()
+                       c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
+                       c.Close()
+                       panic("intentional panic")
+               },
        }
        ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
                handler[r.URL.Path](w, r)
@@ -2280,6 +2286,7 @@ func TestServerConnState(t *testing.T) {
        var stateLog = map[int][]ConnState{}
        var connID = map[net.Conn]int{}
 
+       ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
        ts.Config.ConnState = func(c net.Conn, state ConnState) {
                if c == nil {
                        t.Error("nil conn seen in state %s", state)
@@ -2303,11 +2310,56 @@ func TestServerConnState(t *testing.T) {
        mustGet(t, ts.URL+"/", "Connection", "close")
 
        mustGet(t, ts.URL+"/hijack")
+       mustGet(t, ts.URL+"/hijack-panic")
+
+       // New->Closed
+       {
+               c, err := net.Dial("tcp", ts.Listener.Addr().String())
+               if err != nil {
+                       t.Fatal(err)
+               }
+               c.Close()
+       }
+
+       // New->Active->Closed
+       {
+               c, err := net.Dial("tcp", ts.Listener.Addr().String())
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
+                       t.Fatal(err)
+               }
+               c.Close()
+       }
+
+       // New->Idle->Closed
+       {
+               c, err := net.Dial("tcp", ts.Listener.Addr().String())
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
+                       t.Fatal(err)
+               }
+               res, err := ReadResponse(bufio.NewReader(c), nil)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if _, err := io.Copy(ioutil.Discard, res.Body); err != nil {
+                       t.Fatal(err)
+               }
+               c.Close()
+       }
 
        want := map[int][]ConnState{
                1: []ConnState{StateNew, StateActive, StateIdle, StateActive, StateClosed},
                2: []ConnState{StateNew, StateActive, StateIdle, StateActive, StateClosed},
                3: []ConnState{StateNew, StateActive, StateHijacked},
+               4: []ConnState{StateNew, StateActive, StateHijacked},
+               5: []ConnState{StateNew, StateClosed},
+               6: []ConnState{StateNew, StateActive, StateClosed},
+               7: []ConnState{StateNew, StateActive, StateIdle, StateClosed},
        }
        logString := func(m map[int][]ConnState) string {
                var b bytes.Buffer
@@ -2316,6 +2368,7 @@ func TestServerConnState(t *testing.T) {
                        for _, s := range l {
                                fmt.Fprintf(&b, "%s ", s)
                        }
+                       b.WriteString("\n")
                }
                return b.String()
        }
index ffe5838a06377871fccbac55c3aa6d67f698540a..273d5964f14e2c308649f28b662736c8f6a9970f 100644 (file)
@@ -139,6 +139,7 @@ func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
        buf = c.buf
        c.rwc = nil
        c.buf = nil
+       c.setState(rwc, StateHijacked)
        return
 }
 
@@ -497,6 +498,10 @@ func (srv *Server) maxHeaderBytes() int {
        return DefaultMaxHeaderBytes
 }
 
+func (srv *Server) initialLimitedReaderSize() int64 {
+       return int64(srv.maxHeaderBytes()) + 4096 // bufio slop
+}
+
 // wrapper around io.ReaderCloser which on first read, sends an
 // HTTP/1.1 100 Continue header
 type expectContinueReader struct {
@@ -567,7 +572,7 @@ func (c *conn) readRequest() (w *response, err error) {
                }()
        }
 
-       c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */
+       c.lr.N = c.server.initialLimitedReaderSize()
        var req *Request
        if req, err = ReadRequest(c.buf.Reader); err != nil {
                if c.lr.N == 0 {
@@ -1127,10 +1132,10 @@ func (c *conn) serve() {
 
        for {
                w, err := c.readRequest()
-               // TODO(bradfitz): could push this StateActive
-               // earlier, but in practice header will be all in one
-               // packet/Read:
-               c.setState(c.rwc, StateActive)
+               if c.lr.N != c.server.initialLimitedReaderSize() {
+                       // If we read any bytes off the wire, we're active.
+                       c.setState(c.rwc, StateActive)
+               }
                if err != nil {
                        if err == errTooLarge {
                                // Their HTTP client may or may not be
@@ -1176,7 +1181,6 @@ func (c *conn) serve() {
                // in parallel even if their responses need to be serialized.
                serverHandler{c.server}.ServeHTTP(w, w.req)
                if c.hijacked() {
-                       c.setState(origConn, StateHijacked)
                        return
                }
                w.finishRequest()