From 4c7cafdd03426bc2b9fb1275d13d0abc755dde16 Mon Sep 17 00:00:00 2001 From: Charlie Getzen Date: Fri, 5 Nov 2021 17:27:35 +0000 Subject: [PATCH] net/http: distinguish between timeouts and client hangups in TimeoutHandler Fixes #48948 Change-Id: I411e3be99c7979ae289fd937388aae63d81adb59 GitHub-Last-Rev: 14abd7e4d774ed5ef63aa0a69e80fbc8b5a5af26 GitHub-Pull-Request: golang/go#48993 Reviewed-on: https://go-review.googlesource.com/c/go/+/356009 Reviewed-by: Damien Neil Trust: Damien Neil Trust: Ian Lance Taylor Run-TryBot: Damien Neil TryBot-Result: Go Bot --- src/net/http/export_test.go | 7 +---- src/net/http/serve_test.go | 63 +++++++++++++++++++++++++++++++++---- src/net/http/server.go | 20 +++++++----- 3 files changed, 71 insertions(+), 19 deletions(-) diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go index 096a6d382a..a849327f45 100644 --- a/src/net/http/export_test.go +++ b/src/net/http/export_test.go @@ -88,12 +88,7 @@ func SetPendingDialHooks(before, after func()) { func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn } -func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { - ctx, cancel := context.WithCancel(context.Background()) - go func() { - <-ch - cancel() - }() +func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler { return &timeoutHandler{ handler: handler, testContext: ctx, diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index a98d6c313f..e8fb77446c 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -2274,6 +2274,18 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) { } } +// cancelableTimeoutContext overwrites the error message to DeadlineExceeded +type cancelableTimeoutContext struct { + context.Context +} + +func (c cancelableTimeoutContext) Err() error { + if c.Context.Err() != nil { + return context.DeadlineExceeded + } + return nil +} + func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) } func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) } func testTimeoutHandler(t *testing.T, h2 bool) { @@ -2286,8 +2298,9 @@ func testTimeoutHandler(t *testing.T, h2 bool) { _, werr := w.Write([]byte("hi")) writeErrors <- werr }) - timeout := make(chan time.Time, 1) // write to this to force timeouts - cst := newClientServerTest(t, h2, NewTestTimeoutHandler(sayHi, timeout)) + ctx, cancel := context.WithCancel(context.Background()) + h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) + cst := newClientServerTest(t, h2, h) defer cst.close() // Succeed without timing out: @@ -2308,7 +2321,8 @@ func testTimeoutHandler(t *testing.T, h2 bool) { } // Times out: - timeout <- time.Time{} + cancel() + res, err = cst.c.Get(cst.ts.URL) if err != nil { t.Error(err) @@ -2429,8 +2443,9 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { _, werr := w.Write([]byte("hi")) writeErrors <- werr }) - timeout := make(chan time.Time, 1) // write to this to force timeouts - cst := newClientServerTest(t, h1Mode, NewTestTimeoutHandler(sayHi, timeout)) + ctx, cancel := context.WithCancel(context.Background()) + h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) + cst := newClientServerTest(t, h1Mode, h) defer cst.close() // Succeed without timing out: @@ -2451,7 +2466,8 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { } // Times out: - timeout <- time.Time{} + cancel() + res, err = cst.c.Get(cst.ts.URL) if err != nil { t.Error(err) @@ -2501,6 +2517,41 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { } } +func TestTimeoutHandlerContextCanceled(t *testing.T) { + setParallel(t) + defer afterTest(t) + sendHi := make(chan bool, 1) + writeErrors := make(chan error, 1) + sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Type", "text/plain") + <-sendHi + _, werr := w.Write([]byte("hi")) + writeErrors <- werr + }) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Hour) + h := NewTestTimeoutHandler(sayHi, ctx) + cancel() + cst := newClientServerTest(t, h1Mode, h) + defer cst.close() + + // Succeed without timing out: + sendHi <- true + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusServiceUnavailable; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ := io.ReadAll(res.Body) + if g, e := string(body), ""; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + if g, e := <-writeErrors, context.Canceled; g != e { + t.Errorf("got unexpected Write error on first request: %v", g) + } +} + // https://golang.org/issue/15948 func TestTimeoutHandlerEmptyResponse(t *testing.T) { setParallel(t) diff --git a/src/net/http/server.go b/src/net/http/server.go index 91fad68694..08fd478ed9 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -3391,9 +3391,15 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { case <-ctx.Done(): tw.mu.Lock() defer tw.mu.Unlock() - w.WriteHeader(StatusServiceUnavailable) - io.WriteString(w, h.errorBody()) - tw.timedOut = true + switch err := ctx.Err(); err { + case context.DeadlineExceeded: + w.WriteHeader(StatusServiceUnavailable) + io.WriteString(w, h.errorBody()) + tw.err = ErrHandlerTimeout + default: + w.WriteHeader(StatusServiceUnavailable) + tw.err = err + } } } @@ -3404,7 +3410,7 @@ type timeoutWriter struct { req *Request mu sync.Mutex - timedOut bool + err error wroteHeader bool code int } @@ -3424,8 +3430,8 @@ func (tw *timeoutWriter) Header() Header { return tw.h } func (tw *timeoutWriter) Write(p []byte) (int, error) { tw.mu.Lock() defer tw.mu.Unlock() - if tw.timedOut { - return 0, ErrHandlerTimeout + if tw.err != nil { + return 0, tw.err } if !tw.wroteHeader { tw.writeHeaderLocked(StatusOK) @@ -3437,7 +3443,7 @@ func (tw *timeoutWriter) writeHeaderLocked(code int) { checkWriteHeaderCode(code) switch { - case tw.timedOut: + case tw.err != nil: return case tw.wroteHeader: if tw.req != nil { -- 2.50.0