From fd0c0db4a411eae0483d1cb141e801af401e43d3 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 29 Jul 2022 09:27:16 -0700 Subject: [PATCH] net/http: add ResponseController and per-handler timeouts The ResponseController type provides a discoverable interface to optional methods implemented by ResponseWriters. c := http.NewResponseController(w) c.Flush() vs. if f, ok := w.(http.Flusher); ok { f.Flush() } Add the ability to control per-request read and write deadlines via the ResponseController SetReadDeadline and SetWriteDeadline methods. For #54136 Change-Id: I3f97de60d4c9ff150cda559ef86c6620eee665d2 Reviewed-on: https://go-review.googlesource.com/c/go/+/436890 TryBot-Result: Gopher Robot Reviewed-by: Brad Fitzpatrick Reviewed-by: Bryan Mills Run-TryBot: Damien Neil --- api/next/54136.txt | 6 + src/net/http/request.go | 9 +- src/net/http/responsecontroller.go | 122 +++++++++++ src/net/http/responsecontroller_test.go | 277 ++++++++++++++++++++++++ src/net/http/server.go | 29 ++- 5 files changed, 434 insertions(+), 9 deletions(-) create mode 100644 api/next/54136.txt create mode 100644 src/net/http/responsecontroller.go create mode 100644 src/net/http/responsecontroller_test.go diff --git a/api/next/54136.txt b/api/next/54136.txt new file mode 100644 index 0000000000..feeba39d8d --- /dev/null +++ b/api/next/54136.txt @@ -0,0 +1,6 @@ +pkg net/http, func NewResponseController(ResponseWriter) *ResponseController #54136 +pkg net/http, method (*ResponseController) Flush() error #54136 +pkg net/http, method (*ResponseController) Hijack() (net.Conn, *bufio.ReadWriter, error) #54136 +pkg net/http, method (*ResponseController) SetReadDeadline(time.Time) error #54136 +pkg net/http, method (*ResponseController) SetWriteDeadline(time.Time) error #54136 +pkg net/http, type ResponseController struct #54136 diff --git a/src/net/http/request.go b/src/net/http/request.go index 88d3d75af5..a45c9e3d18 100644 --- a/src/net/http/request.go +++ b/src/net/http/request.go @@ -49,9 +49,12 @@ type ProtocolError struct { func (pe *ProtocolError) Error() string { return pe.ErrorString } var ( - // ErrNotSupported is returned by the Push method of Pusher - // implementations to indicate that HTTP/2 Push support is not - // available. + // ErrNotSupported indicates that a feature is not supported. + // + // It is returned by ResponseController methods to indicate that + // the handler does not support the method, and by the Push method + // of Pusher implementations to indicate that HTTP/2 Push support + // is not available. ErrNotSupported = &ProtocolError{"feature not supported"} // Deprecated: ErrUnexpectedTrailer is no longer returned by diff --git a/src/net/http/responsecontroller.go b/src/net/http/responsecontroller.go new file mode 100644 index 0000000000..018bdc00eb --- /dev/null +++ b/src/net/http/responsecontroller.go @@ -0,0 +1,122 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bufio" + "fmt" + "net" + "time" +) + +// A ResponseController is used by an HTTP handler to control the response. +// +// A ResponseController may not be used after the Handler.ServeHTTP method has returned. +type ResponseController struct { + rw ResponseWriter +} + +// NewResponseController creates a ResponseController for a request. +// +// The ResponseWriter should be the original value passed to the Handler.ServeHTTP method, +// or have an Unwrap method returning the original ResponseWriter. +// +// If the ResponseWriter implements any of the following methods, the ResponseController +// will call them as appropriate: +// +// Flush() +// FlushError() error // alternative Flush returning an error +// Hijack() (net.Conn, *bufio.ReadWriter, error) +// SetReadDeadline(deadline time.Time) error +// SetWriteDeadline(deadline time.Time) error +// +// If the ResponseWriter does not support a method, ResponseController returns +// an error matching ErrNotSupported. +func NewResponseController(rw ResponseWriter) *ResponseController { + return &ResponseController{rw} +} + +type rwUnwrapper interface { + Unwrap() ResponseWriter +} + +// Flush flushes buffered data to the client. +func (c *ResponseController) Flush() error { + rw := c.rw + for { + switch t := rw.(type) { + case interface{ FlushError() error }: + return t.FlushError() + case Flusher: + t.Flush() + return nil + case rwUnwrapper: + rw = t.Unwrap() + default: + return errNotSupported() + } + } +} + +// Hijack lets the caller take over the connection. +// See the Hijacker interface for details. +func (c *ResponseController) Hijack() (net.Conn, *bufio.ReadWriter, error) { + rw := c.rw + for { + switch t := rw.(type) { + case Hijacker: + return t.Hijack() + case rwUnwrapper: + rw = t.Unwrap() + default: + return nil, nil, errNotSupported() + } + } +} + +// SetReadDeadline sets the deadline for reading the entire request, including the body. +// Reads from the request body after the deadline has been exceeded will return an error. +// A zero value means no deadline. +// +// Setting the read deadline after it has been exceeded will not extend it. +func (c *ResponseController) SetReadDeadline(deadline time.Time) error { + rw := c.rw + for { + switch t := rw.(type) { + case interface{ SetReadDeadline(time.Time) error }: + return t.SetReadDeadline(deadline) + case rwUnwrapper: + rw = t.Unwrap() + default: + return errNotSupported() + } + } +} + +// SetWriteDeadline sets the deadline for writing the response. +// Writes to the response body after the deadline has been exceeded will not block, +// but may succeed if the data has been buffered. +// A zero value means no deadline. +// +// Setting the write deadline after it has been exceeded will not extend it. +func (c *ResponseController) SetWriteDeadline(deadline time.Time) error { + rw := c.rw + for { + switch t := rw.(type) { + case interface{ SetWriteDeadline(time.Time) error }: + return t.SetWriteDeadline(deadline) + case rwUnwrapper: + rw = t.Unwrap() + default: + return errNotSupported() + } + } +} + +// errNotSupported returns an error that Is ErrNotSupported, +// but is not == to it. +func errNotSupported() error { + return fmt.Errorf("%w", ErrNotSupported) +} diff --git a/src/net/http/responsecontroller_test.go b/src/net/http/responsecontroller_test.go new file mode 100644 index 0000000000..af036837a4 --- /dev/null +++ b/src/net/http/responsecontroller_test.go @@ -0,0 +1,277 @@ +package http_test + +import ( + "errors" + "fmt" + "io" + . "net/http" + "os" + "sync" + "testing" + "time" +) + +func TestResponseControllerFlush(t *testing.T) { run(t, testResponseControllerFlush) } +func testResponseControllerFlush(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("skip until h2_bundle.go is updated") + } + continuec := make(chan struct{}) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ctl := NewResponseController(w) + w.Write([]byte("one")) + if err := ctl.Flush(); err != nil { + t.Errorf("ctl.Flush() = %v, want nil", err) + return + } + <-continuec + w.Write([]byte("two")) + })) + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatalf("unexpected connection error: %v", err) + } + defer res.Body.Close() + + buf := make([]byte, 16) + n, err := res.Body.Read(buf) + close(continuec) + if err != nil || string(buf[:n]) != "one" { + t.Fatalf("Body.Read = %q, %v, want %q, nil", string(buf[:n]), err, "one") + } + + got, err := io.ReadAll(res.Body) + if err != nil || string(got) != "two" { + t.Fatalf("Body.Read = %q, %v, want %q, nil", string(got), err, "two") + } +} + +func TestResponseControllerHijack(t *testing.T) { run(t, testResponseControllerHijack) } +func testResponseControllerHijack(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("skip until h2_bundle.go is updated") + } + const header = "X-Header" + const value = "set" + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ctl := NewResponseController(w) + c, _, err := ctl.Hijack() + if mode == http2Mode { + if err == nil { + t.Errorf("ctl.Hijack = nil, want error") + } + w.Header().Set(header, value) + return + } + if err != nil { + t.Errorf("ctl.Hijack = _, _, %v, want _, _, nil", err) + return + } + fmt.Fprintf(c, "HTTP/1.0 200 OK\r\n%v: %v\r\nContent-Length: 0\r\n\r\n", header, value) + })) + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + if got, want := res.Header.Get(header), value; got != want { + t.Errorf("response header %q = %q, want %q", header, got, want) + } +} + +func TestResponseControllerSetPastWriteDeadline(t *testing.T) { + run(t, testResponseControllerSetPastWriteDeadline) +} +func testResponseControllerSetPastWriteDeadline(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("skip until h2_bundle.go is updated") + } + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ctl := NewResponseController(w) + w.Write([]byte("one")) + if err := ctl.Flush(); err != nil { + t.Errorf("before setting deadline: ctl.Flush() = %v, want nil", err) + } + if err := ctl.SetWriteDeadline(time.Now()); err != nil { + t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err) + } + + w.Write([]byte("two")) + if err := ctl.Flush(); err == nil { + t.Errorf("after setting deadline: ctl.Flush() = nil, want non-nil") + } + // Connection errors are sticky, so resetting the deadline does not permit + // making more progress. We might want to change this in the future, but verify + // the current behavior for now. If we do change this, we'll want to make sure + // to do so only for writing the response body, not headers. + if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Hour)); err != nil { + t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err) + } + w.Write([]byte("three")) + if err := ctl.Flush(); err == nil { + t.Errorf("after resetting deadline: ctl.Flush() = nil, want non-nil") + } + })) + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatalf("unexpected connection error: %v", err) + } + defer res.Body.Close() + b, _ := io.ReadAll(res.Body) + if string(b) != "one" { + t.Errorf("unexpected body: %q", string(b)) + } +} + +func TestResponseControllerSetFutureWriteDeadline(t *testing.T) { + run(t, testResponseControllerSetFutureWriteDeadline) +} +func testResponseControllerSetFutureWriteDeadline(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("skip until h2_bundle.go is updated") + } + errc := make(chan error, 1) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ctl := NewResponseController(w) + if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Millisecond)); err != nil { + t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err) + } + _, err := io.Copy(w, neverEnding('a')) + errc <- err + })) + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatalf("unexpected connection error: %v", err) + } + defer res.Body.Close() + _, err = io.Copy(io.Discard, res.Body) + if err == nil { + t.Errorf("client reading from truncated request body: got nil error, want non-nil") + } + err = <-errc // io.Copy error + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err) + } +} + +func TestResponseControllerSetPastReadDeadline(t *testing.T) { + run(t, testResponseControllerSetPastReadDeadline) +} +func testResponseControllerSetPastReadDeadline(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("skip until h2_bundle.go is updated") + } + readc := make(chan struct{}) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ctl := NewResponseController(w) + b := make([]byte, 3) + n, err := io.ReadFull(r.Body, b) + b = b[:n] + if err != nil || string(b) != "one" { + t.Errorf("before setting read deadline: Read = %v, %q, want nil, %q", err, string(b), "one") + return + } + if err := ctl.SetReadDeadline(time.Now()); err != nil { + t.Errorf("ctl.SetReadDeadline() = %v, want nil", err) + return + } + b, err = io.ReadAll(r.Body) + if err == nil || string(b) != "" { + t.Errorf("after setting read deadline: Read = %q, nil, want error", string(b)) + } + close(readc) + // Connection errors are sticky, so resetting the deadline does not permit + // making more progress. We might want to change this in the future, but verify + // the current behavior for now. + if err := ctl.SetReadDeadline(time.Time{}); err != nil { + t.Errorf("ctl.SetReadDeadline() = %v, want nil", err) + return + } + b, err = io.ReadAll(r.Body) + if err == nil { + t.Errorf("after resetting read deadline: Read = %q, nil, want error", string(b)) + } + })) + + pr, pw := io.Pipe() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + pw.Write([]byte("one")) + <-readc + pw.Write([]byte("two")) + pw.Close() + }() + defer wg.Wait() + res, err := cst.c.Post(cst.ts.URL, "text/foo", pr) + if err == nil { + defer res.Body.Close() + } +} + +func TestResponseControllerSetFutureReadDeadline(t *testing.T) { + run(t, testResponseControllerSetFutureReadDeadline) +} +func testResponseControllerSetFutureReadDeadline(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("skip until h2_bundle.go is updated") + } + respBody := "response body" + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) { + ctl := NewResponseController(w) + if err := ctl.SetReadDeadline(time.Now().Add(1 * time.Millisecond)); err != nil { + t.Errorf("ctl.SetReadDeadline() = %v, want nil", err) + } + _, err := io.Copy(io.Discard, req.Body) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err) + } + w.Write([]byte(respBody)) + })) + pr, pw := io.Pipe() + res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + if string(got) != respBody || err != nil { + t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody) + } + pw.Close() +} + +type wrapWriter struct { + ResponseWriter +} + +func (w wrapWriter) Unwrap() ResponseWriter { + return w.ResponseWriter +} + +func TestWrappedResponseController(t *testing.T) { run(t, testWrappedResponseController) } +func testWrappedResponseController(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("skip until h2_bundle.go is updated") + } + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ctl := NewResponseController(w) + if err := ctl.Flush(); err != nil { + t.Errorf("ctl.Flush() = %v, want nil", err) + } + if err := ctl.SetReadDeadline(time.Time{}); err != nil { + t.Errorf("ctl.SetReadDeadline() = %v, want nil", err) + } + if err := ctl.SetWriteDeadline(time.Time{}); err != nil { + t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err) + } + })) + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatalf("unexpected connection error: %v", err) + } + defer res.Body.Close() +} diff --git a/src/net/http/server.go b/src/net/http/server.go index 8d3b0f3ad1..698d0636fa 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -395,11 +395,11 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) { return } -func (cw *chunkWriter) flush() { +func (cw *chunkWriter) flush() error { if !cw.wroteHeader { cw.writeHeader(nil) } - cw.res.conn.bufw.Flush() + return cw.res.conn.bufw.Flush() } func (cw *chunkWriter) close() { @@ -486,7 +486,15 @@ type response struct { // TODO(bradfitz): this is currently (for Go 1.8) always // non-nil. Make this lazily-created again as it used to be? closeNotifyCh chan bool - didCloseNotify int32 // atomic (only 0->1 winner should send) + didCloseNotify atomic.Bool // atomic (only false->true winner should send) +} + +func (c *response) SetReadDeadline(deadline time.Time) error { + return c.conn.rwc.SetReadDeadline(deadline) +} + +func (c *response) SetWriteDeadline(deadline time.Time) error { + return c.conn.rwc.SetWriteDeadline(deadline) } // TrailerPrefix is a magic prefix for ResponseWriter.Header map keys @@ -738,7 +746,7 @@ func (cr *connReader) handleReadError(_ error) { // may be called from multiple goroutines. func (cr *connReader) closeNotify() { res := cr.conn.curReq.Load() - if res != nil && atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) { + if res != nil && !res.didCloseNotify.Swap(true) { res.closeNotifyCh <- true } } @@ -1688,11 +1696,19 @@ func (w *response) closedRequestBodyEarly() bool { } func (w *response) Flush() { + w.FlushError() +} + +func (w *response) FlushError() error { if !w.wroteHeader { w.WriteHeader(StatusOK) } - w.w.Flush() - w.cw.flush() + err := w.w.Flush() + e2 := w.cw.flush() + if err == nil { + err = e2 + } + return err } func (c *conn) finalFlush() { @@ -1983,6 +1999,7 @@ func (c *conn) serve(ctx context.Context) { return } w.finishRequest() + c.rwc.SetWriteDeadline(time.Time{}) if !w.shouldReuseConnection() { if w.requestBodyLimitHit || w.closedRequestBodyEarly() { c.closeWriteAndWait() -- 2.50.0