From: Brad Fitzpatrick Date: Thu, 6 Dec 2012 03:25:43 +0000 (-0800) Subject: net/http: implement CloseNotifier X-Git-Tag: go1.1rc2~1711 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=4fb78c3a16580329c1c465fbc67c12456b8297dd;p=gostls13.git net/http: implement CloseNotifier Fixes #2510 R=golang-dev, rsc CC=golang-dev https://golang.org/cl/6867050 --- diff --git a/src/pkg/net/http/serve_test.go b/src/pkg/net/http/serve_test.go index 355efb2cac..8ca227f9de 100644 --- a/src/pkg/net/http/serve_test.go +++ b/src/pkg/net/http/serve_test.go @@ -1252,6 +1252,42 @@ func TestContentLengthZero(t *testing.T) { } } +func TestCloseNotifier(t *testing.T) { + gotReq := make(chan bool, 1) + sawClose := make(chan bool, 1) + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + gotReq <- true + cc := rw.(CloseNotifier).CloseNotify() + <-cc + sawClose <- true + })) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + diec := make(chan bool) + go func() { + _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n") + if err != nil { + t.Fatal(err) + } + <-diec + conn.Close() + }() +For: + for { + select { + case <-gotReq: + diec <- true + case <-sawClose: + break For + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } + } + ts.Close() +} + // goTimeout runs f, failing t if f takes more than ns to complete. func goTimeout(t *testing.T, d time.Duration, f func()) { ch := make(chan bool, 2) diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go index f786e81b9f..21480458b6 100644 --- a/src/pkg/net/http/server.go +++ b/src/pkg/net/http/server.go @@ -93,16 +93,104 @@ type Hijacker interface { Hijack() (net.Conn, *bufio.ReadWriter, error) } +// The CloseNotifier interface is implemented by ResponseWriters which +// allow detecting when the underlying connection has gone away. +// +// This mechanism can be used to cancel long operations on the server +// if the client has disconnected before the response is ready. +type CloseNotifier interface { + // CloseNotify returns a channel that receives a single value + // when the client connection has gone away. + CloseNotify() <-chan bool +} + // A conn represents the server side of an HTTP connection. type conn struct { remoteAddr string // network address of remote side server *Server // the Server on which the connection arrived rwc net.Conn // i/o connection - lr *io.LimitedReader // io.LimitReader(rwc) - buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->rwc - hijacked bool // connection has been hijacked by handler + sr switchReader // where the LimitReader reads from; usually the rwc + lr *io.LimitedReader // io.LimitReader(sr) + buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc tlsState *tls.ConnectionState // or nil when not using TLS body []byte + + mu sync.Mutex // guards the following + clientGone bool // if client has disconnected mid-request + closeNotifyc chan bool // made lazily + hijackedv bool // connection has been hijacked by handler +} + +func (c *conn) hijacked() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.hijackedv +} + +func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.hijackedv { + return nil, nil, ErrHijacked + } + if c.closeNotifyc != nil { + return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier") + } + c.hijackedv = true + rwc = c.rwc + buf = c.buf + c.rwc = nil + c.buf = nil + return +} + +func (c *conn) closeNotify() <-chan bool { + c.mu.Lock() + defer c.mu.Unlock() + if c.closeNotifyc == nil { + c.closeNotifyc = make(chan bool) + if c.hijackedv { + // to obey the function signature, even though + // it'll never receive a value. + return c.closeNotifyc + } + pr, pw := io.Pipe() + + readSource := c.sr.r + c.sr.Lock() + c.sr.r = pr + c.sr.Unlock() + go func() { + _, err := io.Copy(pw, readSource) + if err == nil { + err = io.EOF + } + pw.CloseWithError(err) + c.noteClientGone() + }() + } + return c.closeNotifyc +} + +func (c *conn) noteClientGone() { + c.mu.Lock() + defer c.mu.Unlock() + if c.closeNotifyc != nil && !c.clientGone { + c.closeNotifyc <- true + } + c.clientGone = true +} + +type switchReader struct { + sync.Mutex + r io.Reader +} + +func (sr *switchReader) Read(p []byte) (n int, err error) { + sr.Lock() + r := sr.r + sr.Unlock() + return r.Read(p) } // A response represents the server side of an HTTP response. @@ -183,8 +271,9 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { if debugServerConnections { c.rwc = newLoggingConn("server", c.rwc) } + c.sr = switchReader{r: c.rwc} c.body = make([]byte, sniffLen) - c.lr = io.LimitReader(c.rwc, noLimit).(*io.LimitedReader) + c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader) br := bufio.NewReader(c.lr) bw := bufio.NewWriter(c.rwc) c.buf = bufio.NewReadWriter(br, bw) @@ -215,7 +304,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { if ecr.closed { return 0, errors.New("http: Read after Close on request Body") } - if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked { + if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() { ecr.resp.wroteContinue = true io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n") ecr.resp.conn.buf.Flush() @@ -238,7 +327,7 @@ var errTooLarge = errors.New("http: request too large") // Read next request from connection. func (c *conn) readRequest() (w *response, err error) { - if c.hijacked { + if c.hijacked() { return nil, ErrHijacked } c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */ @@ -279,7 +368,7 @@ func (w *response) Header() Header { const maxPostHandlerReadBytes = 256 << 10 func (w *response) WriteHeader(code int) { - if w.conn.hijacked { + if w.conn.hijacked() { log.Print("http: response.WriteHeader on hijacked connection") return } @@ -455,7 +544,7 @@ func (w *response) bodyAllowed() bool { } func (w *response) Write(data []byte) (n int, err error) { - if w.conn.hijacked { + if w.conn.hijacked() { log.Print("http: response.Write on hijacked connection") return 0, ErrHijacked } @@ -673,21 +762,7 @@ func (c *conn) serve() { } req.Header.Del("Expect") } else if req.Header.get("Expect") != "" { - // TODO(bradfitz): let ServeHTTP handlers handle - // requests with non-standard expectation[s]? Seems - // theoretical at best, and doesn't fit into the - // current ServeHTTP model anyway. We'd need to - // make the ResponseWriter an optional - // "ExpectReplier" interface or something. - // - // For now we'll just obey RFC 2616 14.20 which says - // "If a server receives a request containing an - // Expect field that includes an expectation- - // extension that it does not support, it MUST - // respond with a 417 (Expectation Failed) status." - w.Header().Set("Connection", "close") - w.WriteHeader(StatusExpectationFailed) - w.finishRequest() + w.sendExpectationFailed() break } @@ -702,7 +777,7 @@ func (c *conn) serve() { // [*] Not strictly true: HTTP pipelining. We could let them all process // in parallel even if their responses need to be serialized. handler.ServeHTTP(w, w.req) - if c.hijacked { + if c.hijacked() { return } w.finishRequest() @@ -716,18 +791,32 @@ func (c *conn) serve() { c.close() } +func (w *response) sendExpectationFailed() { + // TODO(bradfitz): let ServeHTTP handlers handle + // requests with non-standard expectation[s]? Seems + // theoretical at best, and doesn't fit into the + // current ServeHTTP model anyway. We'd need to + // make the ResponseWriter an optional + // "ExpectReplier" interface or something. + // + // For now we'll just obey RFC 2616 14.20 which says + // "If a server receives a request containing an + // Expect field that includes an expectation- + // extension that it does not support, it MUST + // respond with a 417 (Expectation Failed) status." + w.Header().Set("Connection", "close") + w.WriteHeader(StatusExpectationFailed) + w.finishRequest() +} + // Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter // and a Hijacker. func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { - if w.conn.hijacked { - return nil, nil, ErrHijacked - } - w.conn.hijacked = true - rwc = w.conn.rwc - buf = w.conn.buf - w.conn.rwc = nil - w.conn.buf = nil - return + return w.conn.hijack() +} + +func (w *response) CloseNotify() <-chan bool { + return w.conn.closeNotify() } // The HandlerFunc type is an adapter to allow the use of