}
}
+// cancelableTimeoutContext overwrites the error message to DeadlineExceeded
+type cancelableTimeoutContext struct {
+ context.Context
+}
+
+func (c cancelableTimeoutContext) Err() error {
+ if c.Context.Err() != nil {
+ return context.DeadlineExceeded
+ }
+ return nil
+}
+
func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) }
func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) }
func testTimeoutHandler(t *testing.T, h2 bool) {
_, werr := w.Write([]byte("hi"))
writeErrors <- werr
})
- timeout := make(chan time.Time, 1) // write to this to force timeouts
- cst := newClientServerTest(t, h2, NewTestTimeoutHandler(sayHi, timeout))
+ ctx, cancel := context.WithCancel(context.Background())
+ h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
+ cst := newClientServerTest(t, h2, h)
defer cst.close()
// Succeed without timing out:
}
// Times out:
- timeout <- time.Time{}
+ cancel()
+
res, err = cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
_, werr := w.Write([]byte("hi"))
writeErrors <- werr
})
- timeout := make(chan time.Time, 1) // write to this to force timeouts
- cst := newClientServerTest(t, h1Mode, NewTestTimeoutHandler(sayHi, timeout))
+ ctx, cancel := context.WithCancel(context.Background())
+ h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
+ cst := newClientServerTest(t, h1Mode, h)
defer cst.close()
// Succeed without timing out:
}
// Times out:
- timeout <- time.Time{}
+ cancel()
+
res, err = cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
}
}
+func TestTimeoutHandlerContextCanceled(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ sendHi := make(chan bool, 1)
+ writeErrors := make(chan error, 1)
+ sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Type", "text/plain")
+ <-sendHi
+ _, werr := w.Write([]byte("hi"))
+ writeErrors <- werr
+ })
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Hour)
+ h := NewTestTimeoutHandler(sayHi, ctx)
+ cancel()
+ cst := newClientServerTest(t, h1Mode, h)
+ defer cst.close()
+
+ // Succeed without timing out:
+ sendHi <- true
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ := io.ReadAll(res.Body)
+ if g, e := string(body), ""; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+ if g, e := <-writeErrors, context.Canceled; g != e {
+ t.Errorf("got unexpected Write error on first request: %v", g)
+ }
+}
+
// https://golang.org/issue/15948
func TestTimeoutHandlerEmptyResponse(t *testing.T) {
setParallel(t)
case <-ctx.Done():
tw.mu.Lock()
defer tw.mu.Unlock()
- w.WriteHeader(StatusServiceUnavailable)
- io.WriteString(w, h.errorBody())
- tw.timedOut = true
+ switch err := ctx.Err(); err {
+ case context.DeadlineExceeded:
+ w.WriteHeader(StatusServiceUnavailable)
+ io.WriteString(w, h.errorBody())
+ tw.err = ErrHandlerTimeout
+ default:
+ w.WriteHeader(StatusServiceUnavailable)
+ tw.err = err
+ }
}
}
req *Request
mu sync.Mutex
- timedOut bool
+ err error
wroteHeader bool
code int
}
func (tw *timeoutWriter) Write(p []byte) (int, error) {
tw.mu.Lock()
defer tw.mu.Unlock()
- if tw.timedOut {
- return 0, ErrHandlerTimeout
+ if tw.err != nil {
+ return 0, tw.err
}
if !tw.wroteHeader {
tw.writeHeaderLocked(StatusOK)
checkWriteHeaderCode(code)
switch {
- case tw.timedOut:
+ case tw.err != nil:
return
case tw.wroteHeader:
if tw.req != nil {