From: Brad Fitzpatrick Date: Wed, 6 Apr 2016 19:31:55 +0000 (-0700) Subject: net/http: set the Request context for incoming server requests X-Git-Tag: go1.7beta1~784 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=1faa8869c6c72f055cdaa2b547964830909c96c6;p=gostls13.git net/http: set the Request context for incoming server requests Updates #13021 Updates #15224 Change-Id: Ia3cd608bb887fcfd8d81b035fa57bd5eb8edf09b Reviewed-on: https://go-review.googlesource.com/21810 Reviewed-by: Andrew Gerrand Run-TryBot: Brad Fitzpatrick Reviewed-by: Emmanuel Odeke TryBot-Result: Gobot Gobot --- diff --git a/src/net/http/request.go b/src/net/http/request.go index 5510691912..5bca888845 100644 --- a/src/net/http/request.go +++ b/src/net/http/request.go @@ -266,9 +266,13 @@ type Request struct { // // The returned context is always non-nil; it defaults to the // background context. +// +// For outgoing client requests, the context controls cancelation. +// +// For incoming server requests, the context is canceled when either +// the client's connection closes, or when the ServeHTTP method +// returns. func (r *Request) Context() context.Context { - // TODO(bradfitz): document above what Context means for server and client - // requests, once implemented. if r.ctx != nil { return r.ctx } diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 638ba5f48f..4cd6ed077f 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -9,6 +9,7 @@ package http_test import ( "bufio" "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -3989,6 +3990,72 @@ func TestServerValidatesHeaders(t *testing.T) { } } +func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) { + defer afterTest(t) + ctxc := make(chan context.Context, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ctx := r.Context() + select { + case <-ctx.Done(): + t.Error("should not be Done in ServeHTTP") + default: + } + ctxc <- ctx + })) + defer ts.Close() + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + ctx := <-ctxc + select { + case <-ctx.Done(): + default: + t.Error("context should be done after ServeHTTP completes") + } +} + +func TestServerRequestContextCancel_ConnClose(t *testing.T) { + // Currently the context is not canceled when the connection + // is closed because we're not reading from the connection + // until after ServeHTTP for the previous handler is done. + // Until the server code is modified to always be in a read + // (Issue 15224), this test doesn't work yet. + t.Skip("TODO(bradfitz): this test doesn't yet work; golang.org/issue/15224") + defer afterTest(t) + inHandler := make(chan struct{}) + handlerDone := make(chan struct{}) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + close(inHandler) + select { + case <-r.Context().Done(): + case <-time.After(3 * time.Second): + t.Errorf("timeout waiting for context to be done") + } + close(handlerDone) + })) + defer ts.Close() + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer c.Close() + io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n") + select { + case <-inHandler: + case <-time.After(3 * time.Second): + t.Fatalf("timeout waiting to see ServeHTTP get called") + } + c.Close() // this should trigger the context being done + + select { + case <-handlerDone: + case <-time.After(3 * time.Second): + t.Fatalf("timeout waiting to see ServeHTTP exit") + } +} + func BenchmarkClientServer(b *testing.B) { b.ReportAllocs() b.StopTimer() diff --git a/src/net/http/server.go b/src/net/http/server.go index f4e697169d..e37df99deb 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -9,6 +9,7 @@ package http import ( "bufio" "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -312,10 +313,11 @@ type response struct { conn *conn req *Request // request for this response reqBody io.ReadCloser - wroteHeader bool // reply header has been (logically) written - wroteContinue bool // 100 Continue response was written - wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive" - wantsClose bool // HTTP request has Connection "close" + cancelCtx context.CancelFunc // when ServeHTTP exits + wroteHeader bool // reply header has been (logically) written + wroteContinue bool // 100 Continue response was written + wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive" + wantsClose bool // HTTP request has Connection "close" w *bufio.Writer // buffers output in chunks to chunkWriter cw chunkWriter @@ -686,7 +688,7 @@ func appendTime(b []byte, t time.Time) []byte { var errTooLarge = errors.New("http: request too large") // Read next request from connection. -func (c *conn) readRequest() (w *response, err error) { +func (c *conn) readRequest(ctx context.Context) (w *response, err error) { if c.hijacked() { return nil, ErrHijacked } @@ -715,6 +717,10 @@ func (c *conn) readRequest() (w *response, err error) { } return nil, err } + + ctx, cancelCtx := context.WithCancel(ctx) + req.ctx = ctx + c.lastMethod = req.Method c.r.setInfiniteReadLimit() @@ -749,6 +755,7 @@ func (c *conn) readRequest() (w *response, err error) { w = &response{ conn: c, + cancelCtx: cancelCtx, req: req, reqBody: req.Body, handlerHeader: make(Header), @@ -1432,12 +1439,20 @@ func (c *conn) serve() { } } + // HTTP/1.x from here on. + c.r = &connReader{r: c.rwc} c.bufr = newBufioReader(c.r) c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) + // TODO: allow changing base context? can't imagine concrete + // use cases yet. + baseCtx := context.Background() + ctx, cancelCtx := context.WithCancel(baseCtx) + defer cancelCtx() + for { - w, err := c.readRequest() + w, err := c.readRequest(ctx) if c.r.remain != c.server.initialReadLimitSize() { // If we read any bytes off the wire, we're active. c.setState(c.rwc, StateActive) @@ -1485,6 +1500,7 @@ func (c *conn) serve() { // [*] Not strictly true: HTTP pipelining. We could let them all process // in parallel even if their responses need to be serialized. serverHandler{c.server}.ServeHTTP(w, w.req) + w.cancelCtx() if c.hijacked() { return }