]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: implement CloseNotifier
authorBrad Fitzpatrick <bradfitz@golang.org>
Thu, 6 Dec 2012 03:25:43 +0000 (19:25 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 6 Dec 2012 03:25:43 +0000 (19:25 -0800)
Fixes #2510

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/6867050

src/pkg/net/http/serve_test.go
src/pkg/net/http/server.go

index 355efb2cac9f3b5d9beb59eb3f30718b127c1ec2..8ca227f9dec71cc7e233c1218690638c3f945134 100644 (file)
@@ -1252,6 +1252,42 @@ func TestContentLengthZero(t *testing.T) {
        }
 }
 
+func TestCloseNotifier(t *testing.T) {
+       gotReq := make(chan bool, 1)
+       sawClose := make(chan bool, 1)
+       ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+               gotReq <- true
+               cc := rw.(CloseNotifier).CloseNotify()
+               <-cc
+               sawClose <- true
+       }))
+       conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+       if err != nil {
+               t.Fatalf("error dialing: %v", err)
+       }
+       diec := make(chan bool)
+       go func() {
+               _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
+               if err != nil {
+                       t.Fatal(err)
+               }
+               <-diec
+               conn.Close()
+       }()
+For:
+       for {
+               select {
+               case <-gotReq:
+                       diec <- true
+               case <-sawClose:
+                       break For
+               case <-time.After(5 * time.Second):
+                       t.Fatal("timeout")
+               }
+       }
+       ts.Close()
+}
+
 // goTimeout runs f, failing t if f takes more than ns to complete.
 func goTimeout(t *testing.T, d time.Duration, f func()) {
        ch := make(chan bool, 2)
index f786e81b9f460f07056ce2d4d61fbe53a9e54f18..21480458b683afa4583a5b7037b7ad7428dcd19f 100644 (file)
@@ -93,16 +93,104 @@ type Hijacker interface {
        Hijack() (net.Conn, *bufio.ReadWriter, error)
 }
 
+// The CloseNotifier interface is implemented by ResponseWriters which
+// allow detecting when the underlying connection has gone away.
+//
+// This mechanism can be used to cancel long operations on the server
+// if the client has disconnected before the response is ready.
+type CloseNotifier interface {
+       // CloseNotify returns a channel that receives a single value
+       // when the client connection has gone away.
+       CloseNotify() <-chan bool
+}
+
 // A conn represents the server side of an HTTP connection.
 type conn struct {
        remoteAddr string               // network address of remote side
        server     *Server              // the Server on which the connection arrived
        rwc        net.Conn             // i/o connection
-       lr         *io.LimitedReader    // io.LimitReader(rwc)
-       buf        *bufio.ReadWriter    // buffered(lr,rwc), reading from bufio->limitReader->rwc
-       hijacked   bool                 // connection has been hijacked by handler
+       sr         switchReader         // where the LimitReader reads from; usually the rwc
+       lr         *io.LimitedReader    // io.LimitReader(sr)
+       buf        *bufio.ReadWriter    // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc
        tlsState   *tls.ConnectionState // or nil when not using TLS
        body       []byte
+
+       mu           sync.Mutex // guards the following
+       clientGone   bool       // if client has disconnected mid-request
+       closeNotifyc chan bool  // made lazily
+       hijackedv    bool       // connection has been hijacked by handler
+}
+
+func (c *conn) hijacked() bool {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+       return c.hijackedv
+}
+
+func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+       if c.hijackedv {
+               return nil, nil, ErrHijacked
+       }
+       if c.closeNotifyc != nil {
+               return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier")
+       }
+       c.hijackedv = true
+       rwc = c.rwc
+       buf = c.buf
+       c.rwc = nil
+       c.buf = nil
+       return
+}
+
+func (c *conn) closeNotify() <-chan bool {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+       if c.closeNotifyc == nil {
+               c.closeNotifyc = make(chan bool)
+               if c.hijackedv {
+                       // to obey the function signature, even though
+                       // it'll never receive a value.
+                       return c.closeNotifyc
+               }
+               pr, pw := io.Pipe()
+
+               readSource := c.sr.r
+               c.sr.Lock()
+               c.sr.r = pr
+               c.sr.Unlock()
+               go func() {
+                       _, err := io.Copy(pw, readSource)
+                       if err == nil {
+                               err = io.EOF
+                       }
+                       pw.CloseWithError(err)
+                       c.noteClientGone()
+               }()
+       }
+       return c.closeNotifyc
+}
+
+func (c *conn) noteClientGone() {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+       if c.closeNotifyc != nil && !c.clientGone {
+               c.closeNotifyc <- true
+       }
+       c.clientGone = true
+}
+
+type switchReader struct {
+       sync.Mutex
+       r io.Reader
+}
+
+func (sr *switchReader) Read(p []byte) (n int, err error) {
+       sr.Lock()
+       r := sr.r
+       sr.Unlock()
+       return r.Read(p)
 }
 
 // A response represents the server side of an HTTP response.
@@ -183,8 +271,9 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
        if debugServerConnections {
                c.rwc = newLoggingConn("server", c.rwc)
        }
+       c.sr = switchReader{r: c.rwc}
        c.body = make([]byte, sniffLen)
-       c.lr = io.LimitReader(c.rwc, noLimit).(*io.LimitedReader)
+       c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader)
        br := bufio.NewReader(c.lr)
        bw := bufio.NewWriter(c.rwc)
        c.buf = bufio.NewReadWriter(br, bw)
@@ -215,7 +304,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
        if ecr.closed {
                return 0, errors.New("http: Read after Close on request Body")
        }
-       if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked {
+       if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() {
                ecr.resp.wroteContinue = true
                io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n")
                ecr.resp.conn.buf.Flush()
@@ -238,7 +327,7 @@ var errTooLarge = errors.New("http: request too large")
 
 // Read next request from connection.
 func (c *conn) readRequest() (w *response, err error) {
-       if c.hijacked {
+       if c.hijacked() {
                return nil, ErrHijacked
        }
        c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */
@@ -279,7 +368,7 @@ func (w *response) Header() Header {
 const maxPostHandlerReadBytes = 256 << 10
 
 func (w *response) WriteHeader(code int) {
-       if w.conn.hijacked {
+       if w.conn.hijacked() {
                log.Print("http: response.WriteHeader on hijacked connection")
                return
        }
@@ -455,7 +544,7 @@ func (w *response) bodyAllowed() bool {
 }
 
 func (w *response) Write(data []byte) (n int, err error) {
-       if w.conn.hijacked {
+       if w.conn.hijacked() {
                log.Print("http: response.Write on hijacked connection")
                return 0, ErrHijacked
        }
@@ -673,21 +762,7 @@ func (c *conn) serve() {
                        }
                        req.Header.Del("Expect")
                } else if req.Header.get("Expect") != "" {
-                       // TODO(bradfitz): let ServeHTTP handlers handle
-                       // requests with non-standard expectation[s]? Seems
-                       // theoretical at best, and doesn't fit into the
-                       // current ServeHTTP model anyway.  We'd need to
-                       // make the ResponseWriter an optional
-                       // "ExpectReplier" interface or something.
-                       //
-                       // For now we'll just obey RFC 2616 14.20 which says
-                       // "If a server receives a request containing an
-                       // Expect field that includes an expectation-
-                       // extension that it does not support, it MUST
-                       // respond with a 417 (Expectation Failed) status."
-                       w.Header().Set("Connection", "close")
-                       w.WriteHeader(StatusExpectationFailed)
-                       w.finishRequest()
+                       w.sendExpectationFailed()
                        break
                }
 
@@ -702,7 +777,7 @@ func (c *conn) serve() {
                // [*] Not strictly true: HTTP pipelining.  We could let them all process
                // in parallel even if their responses need to be serialized.
                handler.ServeHTTP(w, w.req)
-               if c.hijacked {
+               if c.hijacked() {
                        return
                }
                w.finishRequest()
@@ -716,18 +791,32 @@ func (c *conn) serve() {
        c.close()
 }
 
+func (w *response) sendExpectationFailed() {
+       // TODO(bradfitz): let ServeHTTP handlers handle
+       // requests with non-standard expectation[s]? Seems
+       // theoretical at best, and doesn't fit into the
+       // current ServeHTTP model anyway.  We'd need to
+       // make the ResponseWriter an optional
+       // "ExpectReplier" interface or something.
+       //
+       // For now we'll just obey RFC 2616 14.20 which says
+       // "If a server receives a request containing an
+       // Expect field that includes an expectation-
+       // extension that it does not support, it MUST
+       // respond with a 417 (Expectation Failed) status."
+       w.Header().Set("Connection", "close")
+       w.WriteHeader(StatusExpectationFailed)
+       w.finishRequest()
+}
+
 // Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter
 // and a Hijacker.
 func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
-       if w.conn.hijacked {
-               return nil, nil, ErrHijacked
-       }
-       w.conn.hijacked = true
-       rwc = w.conn.rwc
-       buf = w.conn.buf
-       w.conn.rwc = nil
-       w.conn.buf = nil
-       return
+       return w.conn.hijack()
+}
+
+func (w *response) CloseNotify() <-chan bool {
+       return w.conn.closeNotify()
 }
 
 // The HandlerFunc type is an adapter to allow the use of