]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add ResponseController and per-handler timeouts
authorDamien Neil <dneil@google.com>
Fri, 29 Jul 2022 16:27:16 +0000 (09:27 -0700)
committerDamien Neil <dneil@google.com>
Thu, 10 Nov 2022 18:18:03 +0000 (18:18 +0000)
The ResponseController type provides a discoverable interface
to optional methods implemented by ResponseWriters.

c := http.NewResponseController(w)
c.Flush()

vs.

if f, ok := w.(http.Flusher); ok {
f.Flush()
}

Add the ability to control per-request read and write deadlines
via the ResponseController SetReadDeadline and SetWriteDeadline
methods.

For #54136

Change-Id: I3f97de60d4c9ff150cda559ef86c6620eee665d2
Reviewed-on: https://go-review.googlesource.com/c/go/+/436890
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Bryan Mills <bcmills@google.com>
Run-TryBot: Damien Neil <dneil@google.com>

api/next/54136.txt [new file with mode: 0644]
src/net/http/request.go
src/net/http/responsecontroller.go [new file with mode: 0644]
src/net/http/responsecontroller_test.go [new file with mode: 0644]
src/net/http/server.go

diff --git a/api/next/54136.txt b/api/next/54136.txt
new file mode 100644 (file)
index 0000000..feeba39
--- /dev/null
@@ -0,0 +1,6 @@
+pkg net/http, func NewResponseController(ResponseWriter) *ResponseController #54136
+pkg net/http, method (*ResponseController) Flush() error #54136
+pkg net/http, method (*ResponseController) Hijack() (net.Conn, *bufio.ReadWriter, error) #54136
+pkg net/http, method (*ResponseController) SetReadDeadline(time.Time) error #54136
+pkg net/http, method (*ResponseController) SetWriteDeadline(time.Time) error #54136
+pkg net/http, type ResponseController struct #54136
index 88d3d75af5aec75f272583e499637b4fff7aaf01..a45c9e3d18e16a74a5bbb8b78de077d71042aa43 100644 (file)
@@ -49,9 +49,12 @@ type ProtocolError struct {
 func (pe *ProtocolError) Error() string { return pe.ErrorString }
 
 var (
-       // ErrNotSupported is returned by the Push method of Pusher
-       // implementations to indicate that HTTP/2 Push support is not
-       // available.
+       // ErrNotSupported indicates that a feature is not supported.
+       //
+       // It is returned by ResponseController methods to indicate that
+       // the handler does not support the method, and by the Push method
+       // of Pusher implementations to indicate that HTTP/2 Push support
+       // is not available.
        ErrNotSupported = &ProtocolError{"feature not supported"}
 
        // Deprecated: ErrUnexpectedTrailer is no longer returned by
diff --git a/src/net/http/responsecontroller.go b/src/net/http/responsecontroller.go
new file mode 100644 (file)
index 0000000..018bdc0
--- /dev/null
@@ -0,0 +1,122 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+       "bufio"
+       "fmt"
+       "net"
+       "time"
+)
+
+// A ResponseController is used by an HTTP handler to control the response.
+//
+// A ResponseController may not be used after the Handler.ServeHTTP method has returned.
+type ResponseController struct {
+       rw ResponseWriter
+}
+
+// NewResponseController creates a ResponseController for a request.
+//
+// The ResponseWriter should be the original value passed to the Handler.ServeHTTP method,
+// or have an Unwrap method returning the original ResponseWriter.
+//
+// If the ResponseWriter implements any of the following methods, the ResponseController
+// will call them as appropriate:
+//
+//     Flush()
+//     FlushError() error // alternative Flush returning an error
+//     Hijack() (net.Conn, *bufio.ReadWriter, error)
+//     SetReadDeadline(deadline time.Time) error
+//     SetWriteDeadline(deadline time.Time) error
+//
+// If the ResponseWriter does not support a method, ResponseController returns
+// an error matching ErrNotSupported.
+func NewResponseController(rw ResponseWriter) *ResponseController {
+       return &ResponseController{rw}
+}
+
+type rwUnwrapper interface {
+       Unwrap() ResponseWriter
+}
+
+// Flush flushes buffered data to the client.
+func (c *ResponseController) Flush() error {
+       rw := c.rw
+       for {
+               switch t := rw.(type) {
+               case interface{ FlushError() error }:
+                       return t.FlushError()
+               case Flusher:
+                       t.Flush()
+                       return nil
+               case rwUnwrapper:
+                       rw = t.Unwrap()
+               default:
+                       return errNotSupported()
+               }
+       }
+}
+
+// Hijack lets the caller take over the connection.
+// See the Hijacker interface for details.
+func (c *ResponseController) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+       rw := c.rw
+       for {
+               switch t := rw.(type) {
+               case Hijacker:
+                       return t.Hijack()
+               case rwUnwrapper:
+                       rw = t.Unwrap()
+               default:
+                       return nil, nil, errNotSupported()
+               }
+       }
+}
+
+// SetReadDeadline sets the deadline for reading the entire request, including the body.
+// Reads from the request body after the deadline has been exceeded will return an error.
+// A zero value means no deadline.
+//
+// Setting the read deadline after it has been exceeded will not extend it.
+func (c *ResponseController) SetReadDeadline(deadline time.Time) error {
+       rw := c.rw
+       for {
+               switch t := rw.(type) {
+               case interface{ SetReadDeadline(time.Time) error }:
+                       return t.SetReadDeadline(deadline)
+               case rwUnwrapper:
+                       rw = t.Unwrap()
+               default:
+                       return errNotSupported()
+               }
+       }
+}
+
+// SetWriteDeadline sets the deadline for writing the response.
+// Writes to the response body after the deadline has been exceeded will not block,
+// but may succeed if the data has been buffered.
+// A zero value means no deadline.
+//
+// Setting the write deadline after it has been exceeded will not extend it.
+func (c *ResponseController) SetWriteDeadline(deadline time.Time) error {
+       rw := c.rw
+       for {
+               switch t := rw.(type) {
+               case interface{ SetWriteDeadline(time.Time) error }:
+                       return t.SetWriteDeadline(deadline)
+               case rwUnwrapper:
+                       rw = t.Unwrap()
+               default:
+                       return errNotSupported()
+               }
+       }
+}
+
+// errNotSupported returns an error that Is ErrNotSupported,
+// but is not == to it.
+func errNotSupported() error {
+       return fmt.Errorf("%w", ErrNotSupported)
+}
diff --git a/src/net/http/responsecontroller_test.go b/src/net/http/responsecontroller_test.go
new file mode 100644 (file)
index 0000000..af03683
--- /dev/null
@@ -0,0 +1,277 @@
+package http_test
+
+import (
+       "errors"
+       "fmt"
+       "io"
+       . "net/http"
+       "os"
+       "sync"
+       "testing"
+       "time"
+)
+
+func TestResponseControllerFlush(t *testing.T) { run(t, testResponseControllerFlush) }
+func testResponseControllerFlush(t *testing.T, mode testMode) {
+       if mode == http2Mode {
+               t.Skip("skip until h2_bundle.go is updated")
+       }
+       continuec := make(chan struct{})
+       cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               ctl := NewResponseController(w)
+               w.Write([]byte("one"))
+               if err := ctl.Flush(); err != nil {
+                       t.Errorf("ctl.Flush() = %v, want nil", err)
+                       return
+               }
+               <-continuec
+               w.Write([]byte("two"))
+       }))
+
+       res, err := cst.c.Get(cst.ts.URL)
+       if err != nil {
+               t.Fatalf("unexpected connection error: %v", err)
+       }
+       defer res.Body.Close()
+
+       buf := make([]byte, 16)
+       n, err := res.Body.Read(buf)
+       close(continuec)
+       if err != nil || string(buf[:n]) != "one" {
+               t.Fatalf("Body.Read = %q, %v, want %q, nil", string(buf[:n]), err, "one")
+       }
+
+       got, err := io.ReadAll(res.Body)
+       if err != nil || string(got) != "two" {
+               t.Fatalf("Body.Read = %q, %v, want %q, nil", string(got), err, "two")
+       }
+}
+
+func TestResponseControllerHijack(t *testing.T) { run(t, testResponseControllerHijack) }
+func testResponseControllerHijack(t *testing.T, mode testMode) {
+       if mode == http2Mode {
+               t.Skip("skip until h2_bundle.go is updated")
+       }
+       const header = "X-Header"
+       const value = "set"
+       cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               ctl := NewResponseController(w)
+               c, _, err := ctl.Hijack()
+               if mode == http2Mode {
+                       if err == nil {
+                               t.Errorf("ctl.Hijack = nil, want error")
+                       }
+                       w.Header().Set(header, value)
+                       return
+               }
+               if err != nil {
+                       t.Errorf("ctl.Hijack = _, _, %v, want _, _, nil", err)
+                       return
+               }
+               fmt.Fprintf(c, "HTTP/1.0 200 OK\r\n%v: %v\r\nContent-Length: 0\r\n\r\n", header, value)
+       }))
+       res, err := cst.c.Get(cst.ts.URL)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if got, want := res.Header.Get(header), value; got != want {
+               t.Errorf("response header %q = %q, want %q", header, got, want)
+       }
+}
+
+func TestResponseControllerSetPastWriteDeadline(t *testing.T) {
+       run(t, testResponseControllerSetPastWriteDeadline)
+}
+func testResponseControllerSetPastWriteDeadline(t *testing.T, mode testMode) {
+       if mode == http2Mode {
+               t.Skip("skip until h2_bundle.go is updated")
+       }
+       cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               ctl := NewResponseController(w)
+               w.Write([]byte("one"))
+               if err := ctl.Flush(); err != nil {
+                       t.Errorf("before setting deadline: ctl.Flush() = %v, want nil", err)
+               }
+               if err := ctl.SetWriteDeadline(time.Now()); err != nil {
+                       t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
+               }
+
+               w.Write([]byte("two"))
+               if err := ctl.Flush(); err == nil {
+                       t.Errorf("after setting deadline: ctl.Flush() = nil, want non-nil")
+               }
+               // Connection errors are sticky, so resetting the deadline does not permit
+               // making more progress. We might want to change this in the future, but verify
+               // the current behavior for now. If we do change this, we'll want to make sure
+               // to do so only for writing the response body, not headers.
+               if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Hour)); err != nil {
+                       t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
+               }
+               w.Write([]byte("three"))
+               if err := ctl.Flush(); err == nil {
+                       t.Errorf("after resetting deadline: ctl.Flush() = nil, want non-nil")
+               }
+       }))
+
+       res, err := cst.c.Get(cst.ts.URL)
+       if err != nil {
+               t.Fatalf("unexpected connection error: %v", err)
+       }
+       defer res.Body.Close()
+       b, _ := io.ReadAll(res.Body)
+       if string(b) != "one" {
+               t.Errorf("unexpected body: %q", string(b))
+       }
+}
+
+func TestResponseControllerSetFutureWriteDeadline(t *testing.T) {
+       run(t, testResponseControllerSetFutureWriteDeadline)
+}
+func testResponseControllerSetFutureWriteDeadline(t *testing.T, mode testMode) {
+       if mode == http2Mode {
+               t.Skip("skip until h2_bundle.go is updated")
+       }
+       errc := make(chan error, 1)
+       cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               ctl := NewResponseController(w)
+               if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Millisecond)); err != nil {
+                       t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
+               }
+               _, err := io.Copy(w, neverEnding('a'))
+               errc <- err
+       }))
+
+       res, err := cst.c.Get(cst.ts.URL)
+       if err != nil {
+               t.Fatalf("unexpected connection error: %v", err)
+       }
+       defer res.Body.Close()
+       _, err = io.Copy(io.Discard, res.Body)
+       if err == nil {
+               t.Errorf("client reading from truncated request body: got nil error, want non-nil")
+       }
+       err = <-errc // io.Copy error
+       if !errors.Is(err, os.ErrDeadlineExceeded) {
+               t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
+       }
+}
+
+func TestResponseControllerSetPastReadDeadline(t *testing.T) {
+       run(t, testResponseControllerSetPastReadDeadline)
+}
+func testResponseControllerSetPastReadDeadline(t *testing.T, mode testMode) {
+       if mode == http2Mode {
+               t.Skip("skip until h2_bundle.go is updated")
+       }
+       readc := make(chan struct{})
+       cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               ctl := NewResponseController(w)
+               b := make([]byte, 3)
+               n, err := io.ReadFull(r.Body, b)
+               b = b[:n]
+               if err != nil || string(b) != "one" {
+                       t.Errorf("before setting read deadline: Read = %v, %q, want nil, %q", err, string(b), "one")
+                       return
+               }
+               if err := ctl.SetReadDeadline(time.Now()); err != nil {
+                       t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
+                       return
+               }
+               b, err = io.ReadAll(r.Body)
+               if err == nil || string(b) != "" {
+                       t.Errorf("after setting read deadline: Read = %q, nil, want error", string(b))
+               }
+               close(readc)
+               // Connection errors are sticky, so resetting the deadline does not permit
+               // making more progress. We might want to change this in the future, but verify
+               // the current behavior for now.
+               if err := ctl.SetReadDeadline(time.Time{}); err != nil {
+                       t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
+                       return
+               }
+               b, err = io.ReadAll(r.Body)
+               if err == nil {
+                       t.Errorf("after resetting read deadline: Read = %q, nil, want error", string(b))
+               }
+       }))
+
+       pr, pw := io.Pipe()
+       var wg sync.WaitGroup
+       wg.Add(1)
+       go func() {
+               defer wg.Done()
+               pw.Write([]byte("one"))
+               <-readc
+               pw.Write([]byte("two"))
+               pw.Close()
+       }()
+       defer wg.Wait()
+       res, err := cst.c.Post(cst.ts.URL, "text/foo", pr)
+       if err == nil {
+               defer res.Body.Close()
+       }
+}
+
+func TestResponseControllerSetFutureReadDeadline(t *testing.T) {
+       run(t, testResponseControllerSetFutureReadDeadline)
+}
+func testResponseControllerSetFutureReadDeadline(t *testing.T, mode testMode) {
+       if mode == http2Mode {
+               t.Skip("skip until h2_bundle.go is updated")
+       }
+       respBody := "response body"
+       cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
+               ctl := NewResponseController(w)
+               if err := ctl.SetReadDeadline(time.Now().Add(1 * time.Millisecond)); err != nil {
+                       t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
+               }
+               _, err := io.Copy(io.Discard, req.Body)
+               if !errors.Is(err, os.ErrDeadlineExceeded) {
+                       t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
+               }
+               w.Write([]byte(respBody))
+       }))
+       pr, pw := io.Pipe()
+       res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer res.Body.Close()
+       got, err := io.ReadAll(res.Body)
+       if string(got) != respBody || err != nil {
+               t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
+       }
+       pw.Close()
+}
+
+type wrapWriter struct {
+       ResponseWriter
+}
+
+func (w wrapWriter) Unwrap() ResponseWriter {
+       return w.ResponseWriter
+}
+
+func TestWrappedResponseController(t *testing.T) { run(t, testWrappedResponseController) }
+func testWrappedResponseController(t *testing.T, mode testMode) {
+       if mode == http2Mode {
+               t.Skip("skip until h2_bundle.go is updated")
+       }
+       cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               ctl := NewResponseController(w)
+               if err := ctl.Flush(); err != nil {
+                       t.Errorf("ctl.Flush() = %v, want nil", err)
+               }
+               if err := ctl.SetReadDeadline(time.Time{}); err != nil {
+                       t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
+               }
+               if err := ctl.SetWriteDeadline(time.Time{}); err != nil {
+                       t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
+               }
+       }))
+       res, err := cst.c.Get(cst.ts.URL)
+       if err != nil {
+               t.Fatalf("unexpected connection error: %v", err)
+       }
+       defer res.Body.Close()
+}
index 8d3b0f3ad118f7d2017e48aea30493c3b1134529..698d0636fa8f12674299393b05df5ac55622134c 100644 (file)
@@ -395,11 +395,11 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) {
        return
 }
 
-func (cw *chunkWriter) flush() {
+func (cw *chunkWriter) flush() error {
        if !cw.wroteHeader {
                cw.writeHeader(nil)
        }
-       cw.res.conn.bufw.Flush()
+       return cw.res.conn.bufw.Flush()
 }
 
 func (cw *chunkWriter) close() {
@@ -486,7 +486,15 @@ type response struct {
        // TODO(bradfitz): this is currently (for Go 1.8) always
        // non-nil. Make this lazily-created again as it used to be?
        closeNotifyCh  chan bool
-       didCloseNotify int32 // atomic (only 0->1 winner should send)
+       didCloseNotify atomic.Bool // atomic (only false->true winner should send)
+}
+
+func (c *response) SetReadDeadline(deadline time.Time) error {
+       return c.conn.rwc.SetReadDeadline(deadline)
+}
+
+func (c *response) SetWriteDeadline(deadline time.Time) error {
+       return c.conn.rwc.SetWriteDeadline(deadline)
 }
 
 // TrailerPrefix is a magic prefix for ResponseWriter.Header map keys
@@ -738,7 +746,7 @@ func (cr *connReader) handleReadError(_ error) {
 // may be called from multiple goroutines.
 func (cr *connReader) closeNotify() {
        res := cr.conn.curReq.Load()
-       if res != nil && atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) {
+       if res != nil && !res.didCloseNotify.Swap(true) {
                res.closeNotifyCh <- true
        }
 }
@@ -1688,11 +1696,19 @@ func (w *response) closedRequestBodyEarly() bool {
 }
 
 func (w *response) Flush() {
+       w.FlushError()
+}
+
+func (w *response) FlushError() error {
        if !w.wroteHeader {
                w.WriteHeader(StatusOK)
        }
-       w.w.Flush()
-       w.cw.flush()
+       err := w.w.Flush()
+       e2 := w.cw.flush()
+       if err == nil {
+               err = e2
+       }
+       return err
 }
 
 func (c *conn) finalFlush() {
@@ -1983,6 +1999,7 @@ func (c *conn) serve(ctx context.Context) {
                        return
                }
                w.finishRequest()
+               c.rwc.SetWriteDeadline(time.Time{})
                if !w.shouldReuseConnection() {
                        if w.requestBodyLimitHit || w.closedRequestBodyEarly() {
                                c.closeWriteAndWait()