]> Cypherpunks repositories - gostls13.git/commitdiff
http: handler timeout support
authorBrad Fitzpatrick <bradfitz@golang.org>
Fri, 15 Apr 2011 19:53:32 +0000 (12:53 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 15 Apr 2011 19:53:32 +0000 (12:53 -0700)
Fixes #213

R=r, rsc
CC=golang-dev
https://golang.org/cl/4432043

src/pkg/http/export_test.go
src/pkg/http/serve_test.go
src/pkg/http/server.go

index 47c6877602d0ba167b67106f62c21c1bca32085b..3fe658641f8b4e546a4b7b95f5bbb21f1a5e8dd2 100644 (file)
@@ -32,3 +32,10 @@ func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
        }
        return len(conns)
 }
+
+func NewTestTimeoutHandler(handler Handler, ch <-chan int64) Handler {
+       f := func() <-chan int64 {
+               return ch
+       }
+       return &timeoutHandler{handler, f, ""}
+}
index eb1ecfdd32c3322b173233b9a2fec4f69539b4ba..4dce3781de4e6e89032a1d1bfc227153b6409539 100644 (file)
@@ -662,3 +662,54 @@ func TestServerConsumesRequestBody(t *testing.T) {
                }
        }
 }
+
+func TestTimeoutHandler(t *testing.T) {
+       sendHi := make(chan bool, 1)
+       writeErrors := make(chan os.Error, 1)
+       sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+               <-sendHi
+               _, werr := w.Write([]byte("hi"))
+               writeErrors <- werr
+       })
+       timeout := make(chan int64, 1) // write to this to force timeouts
+       ts := httptest.NewServer(NewTestTimeoutHandler(sayHi, timeout))
+       defer ts.Close()
+
+       // Succeed without timing out:
+       sendHi <- true
+       res, _, err := Get(ts.URL)
+       if err != nil {
+               t.Error(err)
+       }
+       if g, e := res.StatusCode, StatusOK; g != e {
+               t.Errorf("got res.StatusCode %d; expected %d", g, e)
+       }
+       body, _ := ioutil.ReadAll(res.Body)
+       if g, e := string(body), "hi"; g != e {
+               t.Errorf("got body %q; expected %q", g, e)
+       }
+       if g := <-writeErrors; g != nil {
+               t.Errorf("got unexpected Write error on first request: %v", g)
+       }
+
+       // Times out:
+       timeout <- 1
+       res, _, err = Get(ts.URL)
+       if err != nil {
+               t.Error(err)
+       }
+       if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
+               t.Errorf("got res.StatusCode %d; expected %d", g, e)
+       }
+       body, _ = ioutil.ReadAll(res.Body)
+       if !strings.Contains(string(body), "<title>Timeout</title>") {
+               t.Errorf("expected timeout body; got %q", string(body))
+       }
+
+       // Now make the previously-timed out handler speak again,
+       // which verifies the panic is handled:
+       sendHi <- true
+       if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
+               t.Errorf("expected Write error of %v; got %v", e, g)
+       }
+}
index aa4dc294224c04436ee0eab47c59a06f715dbf6f..db8b23ca2374c6e65041f2fb5e8888cfc7e6af7a 100644 (file)
@@ -22,6 +22,7 @@ import (
        "path"
        "strconv"
        "strings"
+       "sync"
        "time"
 )
 
@@ -898,3 +899,89 @@ func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Han
        tlsListener := tls.NewListener(conn, config)
        return Serve(tlsListener, handler)
 }
+
+// TimeoutHandler returns a Handler that runs h with the given time limit.
+//
+// The new Handler calls h.ServeHTTP to handle each request, but if a
+// call runs for more than ns nanoseconds, the handler responds with
+// a 503 Service Unavailable error and the given message in its body.
+// (If msg is empty, a suitable default message will be sent.)
+// After such a timeout, writes by h to its ResponseWriter will return
+// ErrHandlerTimeout.
+func TimeoutHandler(h Handler, ns int64, msg string) Handler {
+       f := func() <-chan int64 {
+               return time.After(ns)
+       }
+       return &timeoutHandler{h, f, msg}
+}
+
+// ErrHandlerTimeout is returned on ResponseWriter Write calls
+// in handlers which have timed out.
+var ErrHandlerTimeout = os.NewError("http: Handler timeout")
+
+type timeoutHandler struct {
+       handler Handler
+       timeout func() <-chan int64 // returns channel producing a timeout
+       body    string
+}
+
+func (h *timeoutHandler) errorBody() string {
+       if h.body != "" {
+               return h.body
+       }
+       return "<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>"
+}
+
+func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
+       done := make(chan bool)
+       tw := &timeoutWriter{w: w}
+       go func() {
+               h.handler.ServeHTTP(tw, r)
+               done <- true
+       }()
+       select {
+       case <-done:
+               return
+       case <-h.timeout():
+               tw.mu.Lock()
+               defer tw.mu.Unlock()
+               if !tw.wroteHeader {
+                       tw.w.WriteHeader(StatusServiceUnavailable)
+                       tw.w.Write([]byte(h.errorBody()))
+               }
+               tw.timedOut = true
+       }
+}
+
+type timeoutWriter struct {
+       w ResponseWriter
+
+       mu          sync.Mutex
+       timedOut    bool
+       wroteHeader bool
+}
+
+func (tw *timeoutWriter) Header() Header {
+       return tw.w.Header()
+}
+
+func (tw *timeoutWriter) Write(p []byte) (int, os.Error) {
+       tw.mu.Lock()
+       timedOut := tw.timedOut
+       tw.mu.Unlock()
+       if timedOut {
+               return 0, ErrHandlerTimeout
+       }
+       return tw.w.Write(p)
+}
+
+func (tw *timeoutWriter) WriteHeader(code int) {
+       tw.mu.Lock()
+       if tw.timedOut || tw.wroteHeader {
+               tw.mu.Unlock()
+               return
+       }
+       tw.wroteHeader = true
+       tw.mu.Unlock()
+       tw.w.WriteHeader(code)
+}