]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: fix race in TimeoutHandler
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 14 Dec 2015 01:04:07 +0000 (01:04 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 14 Dec 2015 14:55:37 +0000 (14:55 +0000)
New implementation of TimeoutHandler: buffer everything to memory.

All or nothing: either the handler finishes completely within the
timeout (in which case the wrapper writes it all), or it misses the
timeout and none of it gets written, in which case handler wrapper can
reliably print the error response without fear that some of the
wrapped Handler's code already wrote to the output.

Now the goroutine running the wrapped Handler has its own write buffer
and Header copy.

Document the limitations.

Fixes #9162

Change-Id: Ia058c1d62cefd11843e7a2fc1ae1609d75de2441
Reviewed-on: https://go-review.googlesource.com/17752
Reviewed-by: David Symonds <dsymonds@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>

src/net/http/export_test.go
src/net/http/serve_test.go
src/net/http/server.go

index e0ae49afa710bf70bc0c3d9ef498645be7898888..4ccce08b43e118a2f5d4abe5786e36897c78b9a9 100644 (file)
@@ -58,10 +58,11 @@ func SetPendingDialHooks(before, after func()) {
 func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
 
 func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
-       f := func() <-chan time.Time {
-               return ch
+       return &timeoutHandler{
+               handler: handler,
+               timeout: func() <-chan time.Time { return ch },
+               // (no body and nil cancelTimer)
        }
-       return &timeoutHandler{handler, f, ""}
 }
 
 func ResetCachedEnvironment() {
index 28cc12a3608b5da0371aeab1a89b0b8acf059015..5a0706e06e744a0a3b5cada2de359534b37b044e 100644 (file)
@@ -1785,6 +1785,60 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) {
        wg.Wait()
 }
 
+// Issue 9162
+func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
+       defer afterTest(t)
+       sendHi := make(chan bool, 1)
+       writeErrors := make(chan error, 1)
+       sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+               w.Header().Set("Content-Type", "text/plain")
+               <-sendHi
+               _, werr := w.Write([]byte("hi"))
+               writeErrors <- werr
+       })
+       timeout := make(chan time.Time, 1) // write to this to force timeouts
+       cst := newClientServerTest(t, h1Mode, NewTestTimeoutHandler(sayHi, timeout))
+       defer cst.close()
+
+       // Succeed without timing out:
+       sendHi <- true
+       res, err := cst.c.Get(cst.ts.URL)
+       if err != nil {
+               t.Error(err)
+       }
+       if g, e := res.StatusCode, StatusOK; g != e {
+               t.Errorf("got res.StatusCode %d; expected %d", g, e)
+       }
+       body, _ := ioutil.ReadAll(res.Body)
+       if g, e := string(body), "hi"; g != e {
+               t.Errorf("got body %q; expected %q", g, e)
+       }
+       if g := <-writeErrors; g != nil {
+               t.Errorf("got unexpected Write error on first request: %v", g)
+       }
+
+       // Times out:
+       timeout <- time.Time{}
+       res, err = cst.c.Get(cst.ts.URL)
+       if err != nil {
+               t.Error(err)
+       }
+       if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
+               t.Errorf("got res.StatusCode %d; expected %d", g, e)
+       }
+       body, _ = ioutil.ReadAll(res.Body)
+       if !strings.Contains(string(body), "<title>Timeout</title>") {
+               t.Errorf("expected timeout body; got %q", string(body))
+       }
+
+       // Now make the previously-timed out handler speak again,
+       // which verifies the panic is handled:
+       sendHi <- true
+       if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
+               t.Errorf("expected Write error of %v; got %v", e, g)
+       }
+}
+
 // Verifies we don't path.Clean() on the wrong parts in redirects.
 func TestRedirectMunging(t *testing.T) {
        req, _ := NewRequest("GET", "http://example.com/", nil)
index 9b3313b7e7b648917e95c1b785538da207e810a8..35f41e734e0cb7a374e0e5ff46cde5300a37de44 100644 (file)
@@ -8,6 +8,7 @@ package http
 
 import (
        "bufio"
+       "bytes"
        "crypto/tls"
        "errors"
        "fmt"
@@ -2101,11 +2102,20 @@ func (srv *Server) onceSetNextProtoDefaults() {
 // (If msg is empty, a suitable default message will be sent.)
 // After such a timeout, writes by h to its ResponseWriter will return
 // ErrHandlerTimeout.
+//
+// TimeoutHandler buffers all Handler writes to memory and does not
+// support the Hijacker or Flusher interfaces.
 func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler {
-       f := func() <-chan time.Time {
-               return time.After(dt)
+       t := time.NewTimer(dt)
+       return &timeoutHandler{
+               handler: h,
+               body:    msg,
+
+               // Effectively storing a *time.Timer, but decomposed
+               // for testing:
+               timeout:     func() <-chan time.Time { return t.C },
+               cancelTimer: t.Stop,
        }
-       return &timeoutHandler{h, f, msg}
 }
 
 // ErrHandlerTimeout is returned on ResponseWriter Write calls
@@ -2114,8 +2124,13 @@ var ErrHandlerTimeout = errors.New("http: Handler timeout")
 
 type timeoutHandler struct {
        handler Handler
-       timeout func() <-chan time.Time // returns channel producing a timeout
        body    string
+
+       // timeout returns the channel of a *time.Timer and
+       // cancelTimer cancels it.  They're stored separately for
+       // testing purposes.
+       timeout     func() <-chan time.Time // returns channel producing a timeout
+       cancelTimer func() bool             // optional
 }
 
 func (h *timeoutHandler) errorBody() string {
@@ -2126,46 +2141,61 @@ func (h *timeoutHandler) errorBody() string {
 }
 
 func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
-       done := make(chan bool, 1)
-       tw := &timeoutWriter{w: w}
+       done := make(chan struct{})
+       tw := &timeoutWriter{
+               w: w,
+               h: make(Header),
+       }
        go func() {
                h.handler.ServeHTTP(tw, r)
-               done <- true
+               close(done)
        }()
        select {
        case <-done:
-               return
-       case <-h.timeout():
                tw.mu.Lock()
                defer tw.mu.Unlock()
-               if !tw.wroteHeader {
-                       tw.w.WriteHeader(StatusServiceUnavailable)
-                       tw.w.Write([]byte(h.errorBody()))
+               dst := w.Header()
+               for k, vv := range tw.h {
+                       dst[k] = vv
                }
+               w.WriteHeader(tw.code)
+               w.Write(tw.wbuf.Bytes())
+               if h.cancelTimer != nil {
+                       h.cancelTimer()
+               }
+       case <-h.timeout():
+               tw.mu.Lock()
+               defer tw.mu.Unlock()
+               w.WriteHeader(StatusServiceUnavailable)
+               io.WriteString(w, h.errorBody())
                tw.timedOut = true
+               return
        }
 }
 
 type timeoutWriter struct {
-       w ResponseWriter
+       w    ResponseWriter
+       h    Header
+       wbuf bytes.Buffer
 
        mu          sync.Mutex
        timedOut    bool
        wroteHeader bool
+       code        int
 }
 
-func (tw *timeoutWriter) Header() Header {
-       return tw.w.Header()
-}
+func (tw *timeoutWriter) Header() Header { return tw.h }
 
 func (tw *timeoutWriter) Write(p []byte) (int, error) {
        tw.mu.Lock()
        defer tw.mu.Unlock()
-       tw.wroteHeader = true // implicitly at least
        if tw.timedOut {
                return 0, ErrHandlerTimeout
        }
-       return tw.w.Write(p)
+       if !tw.wroteHeader {
+               tw.writeHeader(StatusOK)
+       }
+       return tw.wbuf.Write(p)
 }
 
 func (tw *timeoutWriter) WriteHeader(code int) {
@@ -2174,8 +2204,12 @@ func (tw *timeoutWriter) WriteHeader(code int) {
        if tw.timedOut || tw.wroteHeader {
                return
        }
+       tw.writeHeader(code)
+}
+
+func (tw *timeoutWriter) writeHeader(code int) {
        tw.wroteHeader = true
-       tw.w.WriteHeader(code)
+       tw.code = code
 }
 
 // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted