]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: distinguish between timeouts and client hangups in TimeoutHandler
authorCharlie Getzen <charlie@bolt.com>
Fri, 5 Nov 2021 17:27:35 +0000 (17:27 +0000)
committerDamien Neil <dneil@google.com>
Fri, 5 Nov 2021 21:18:28 +0000 (21:18 +0000)
Fixes #48948

Change-Id: I411e3be99c7979ae289fd937388aae63d81adb59
GitHub-Last-Rev: 14abd7e4d774ed5ef63aa0a69e80fbc8b5a5af26
GitHub-Pull-Request: golang/go#48993
Reviewed-on: https://go-review.googlesource.com/c/go/+/356009
Reviewed-by: Damien Neil <dneil@google.com>
Trust: Damien Neil <dneil@google.com>
Trust: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Go Bot <gobot@golang.org>

src/net/http/export_test.go
src/net/http/serve_test.go
src/net/http/server.go

index 096a6d382a8afab3ea3e8ce7cb231524994ce02c..a849327f4528b430cac53046168f4a4a9440aaf2 100644 (file)
@@ -88,12 +88,7 @@ func SetPendingDialHooks(before, after func()) {
 
 func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
 
-func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
-       ctx, cancel := context.WithCancel(context.Background())
-       go func() {
-               <-ch
-               cancel()
-       }()
+func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler {
        return &timeoutHandler{
                handler:     handler,
                testContext: ctx,
index a98d6c313f86cba5f0e86c4339cca146cb96cf22..e8fb77446c3d45668dcfda36026971abe21c49ac 100644 (file)
@@ -2274,6 +2274,18 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
        }
 }
 
+// cancelableTimeoutContext overwrites the error message to DeadlineExceeded
+type cancelableTimeoutContext struct {
+       context.Context
+}
+
+func (c cancelableTimeoutContext) Err() error {
+       if c.Context.Err() != nil {
+               return context.DeadlineExceeded
+       }
+       return nil
+}
+
 func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) }
 func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) }
 func testTimeoutHandler(t *testing.T, h2 bool) {
@@ -2286,8 +2298,9 @@ func testTimeoutHandler(t *testing.T, h2 bool) {
                _, werr := w.Write([]byte("hi"))
                writeErrors <- werr
        })
-       timeout := make(chan time.Time, 1) // write to this to force timeouts
-       cst := newClientServerTest(t, h2, NewTestTimeoutHandler(sayHi, timeout))
+       ctx, cancel := context.WithCancel(context.Background())
+       h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
+       cst := newClientServerTest(t, h2, h)
        defer cst.close()
 
        // Succeed without timing out:
@@ -2308,7 +2321,8 @@ func testTimeoutHandler(t *testing.T, h2 bool) {
        }
 
        // Times out:
-       timeout <- time.Time{}
+       cancel()
+
        res, err = cst.c.Get(cst.ts.URL)
        if err != nil {
                t.Error(err)
@@ -2429,8 +2443,9 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
                _, werr := w.Write([]byte("hi"))
                writeErrors <- werr
        })
-       timeout := make(chan time.Time, 1) // write to this to force timeouts
-       cst := newClientServerTest(t, h1Mode, NewTestTimeoutHandler(sayHi, timeout))
+       ctx, cancel := context.WithCancel(context.Background())
+       h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
+       cst := newClientServerTest(t, h1Mode, h)
        defer cst.close()
 
        // Succeed without timing out:
@@ -2451,7 +2466,8 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
        }
 
        // Times out:
-       timeout <- time.Time{}
+       cancel()
+
        res, err = cst.c.Get(cst.ts.URL)
        if err != nil {
                t.Error(err)
@@ -2501,6 +2517,41 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
        }
 }
 
+func TestTimeoutHandlerContextCanceled(t *testing.T) {
+       setParallel(t)
+       defer afterTest(t)
+       sendHi := make(chan bool, 1)
+       writeErrors := make(chan error, 1)
+       sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+               w.Header().Set("Content-Type", "text/plain")
+               <-sendHi
+               _, werr := w.Write([]byte("hi"))
+               writeErrors <- werr
+       })
+       ctx, cancel := context.WithTimeout(context.Background(), 1*time.Hour)
+       h := NewTestTimeoutHandler(sayHi, ctx)
+       cancel()
+       cst := newClientServerTest(t, h1Mode, h)
+       defer cst.close()
+
+       // Succeed without timing out:
+       sendHi <- true
+       res, err := cst.c.Get(cst.ts.URL)
+       if err != nil {
+               t.Error(err)
+       }
+       if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
+               t.Errorf("got res.StatusCode %d; expected %d", g, e)
+       }
+       body, _ := io.ReadAll(res.Body)
+       if g, e := string(body), ""; g != e {
+               t.Errorf("got body %q; expected %q", g, e)
+       }
+       if g, e := <-writeErrors, context.Canceled; g != e {
+               t.Errorf("got unexpected Write error on first request: %v", g)
+       }
+}
+
 // https://golang.org/issue/15948
 func TestTimeoutHandlerEmptyResponse(t *testing.T) {
        setParallel(t)
index 91fad68694b51659b579ee8eef7e1360ce44a0f6..08fd478ed9e48d69fa7170bce998fe4d5ad6fcec 100644 (file)
@@ -3391,9 +3391,15 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
        case <-ctx.Done():
                tw.mu.Lock()
                defer tw.mu.Unlock()
-               w.WriteHeader(StatusServiceUnavailable)
-               io.WriteString(w, h.errorBody())
-               tw.timedOut = true
+               switch err := ctx.Err(); err {
+               case context.DeadlineExceeded:
+                       w.WriteHeader(StatusServiceUnavailable)
+                       io.WriteString(w, h.errorBody())
+                       tw.err = ErrHandlerTimeout
+               default:
+                       w.WriteHeader(StatusServiceUnavailable)
+                       tw.err = err
+               }
        }
 }
 
@@ -3404,7 +3410,7 @@ type timeoutWriter struct {
        req  *Request
 
        mu          sync.Mutex
-       timedOut    bool
+       err         error
        wroteHeader bool
        code        int
 }
@@ -3424,8 +3430,8 @@ func (tw *timeoutWriter) Header() Header { return tw.h }
 func (tw *timeoutWriter) Write(p []byte) (int, error) {
        tw.mu.Lock()
        defer tw.mu.Unlock()
-       if tw.timedOut {
-               return 0, ErrHandlerTimeout
+       if tw.err != nil {
+               return 0, tw.err
        }
        if !tw.wroteHeader {
                tw.writeHeaderLocked(StatusOK)
@@ -3437,7 +3443,7 @@ func (tw *timeoutWriter) writeHeaderLocked(code int) {
        checkWriteHeaderCode(code)
 
        switch {
-       case tw.timedOut:
+       case tw.err != nil:
                return
        case tw.wroteHeader:
                if tw.req != nil {