]> Cypherpunks repositories - gostls13.git/commitdiff
http: add Server type supporting timeouts
authorBrad Fitzpatrick <bradfitz@golang.org>
Thu, 10 Feb 2011 22:36:22 +0000 (14:36 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 10 Feb 2011 22:36:22 +0000 (14:36 -0800)
R=rsc
CC=golang-dev
https://golang.org/cl/4172041

src/pkg/http/serve_test.go
src/pkg/http/server.go

index 7da3fc6f3434c39a3b5bcc1a57f8d01e84c1a208..80ad86290d4982928463a65ffc9377bccb301306 100644 (file)
@@ -9,10 +9,13 @@ package http
 import (
        "bufio"
        "bytes"
+       "fmt"
        "io"
+       "io/ioutil"
        "os"
        "net"
        "testing"
+       "time"
 )
 
 type dummyAddr string
@@ -283,3 +286,66 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) {
                }
        }
 }
+
+func TestServerTimeouts(t *testing.T) {
+       l, err := net.ListenTCP("tcp", &net.TCPAddr{Port: 0})
+       if err != nil {
+               t.Fatalf("listen error: %v", err)
+       }
+       addr, _ := l.Addr().(*net.TCPAddr)
+
+       reqNum := 0
+       handler := HandlerFunc(func(res ResponseWriter, req *Request) {
+               reqNum++
+               fmt.Fprintf(res, "req=%d", reqNum)
+       })
+
+       const second = 1000000000 /* nanos */
+       server := &Server{Handler: handler, ReadTimeout: 0.25 * second, WriteTimeout: 0.25 * second}
+       go server.Serve(l)
+
+       url := fmt.Sprintf("http://localhost:%d/", addr.Port)
+
+       // Hit the HTTP server successfully.
+       r, _, err := Get(url)
+       if err != nil {
+               t.Fatalf("http Get #1: %v", err)
+       }
+       got, _ := ioutil.ReadAll(r.Body)
+       expected := "req=1"
+       if string(got) != expected {
+               t.Errorf("Unexpected response for request #1; got %q; expected %q",
+                       string(got), expected)
+       }
+
+       // Slow client that should timeout.
+       t1 := time.Nanoseconds()
+       conn, err := net.Dial("tcp", "", fmt.Sprintf("localhost:%d", addr.Port))
+       if err != nil {
+               t.Fatalf("Dial: %v", err)
+       }
+       buf := make([]byte, 1)
+       n, err := conn.Read(buf)
+       latency := time.Nanoseconds() - t1
+       if n != 0 || err != os.EOF {
+               t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, os.EOF)
+       }
+       if latency < second*0.20 /* fudge from 0.25 above */ {
+               t.Errorf("got EOF after %d ns, want >= %d", latency, second*0.20)
+       }
+
+       // Hit the HTTP server successfully again, verifying that the
+       // previous slow connection didn't run our handler.  (that we
+       // get "req=2", not "req=3")
+       r, _, err = Get(url)
+       if err != nil {
+               t.Fatalf("http Get #2: %v", err)
+       }
+       got, _ = ioutil.ReadAll(r.Body)
+       expected = "req=2"
+       if string(got) != expected {
+               t.Errorf("Get #2 got %q, want %q", string(got), expected)
+       }
+
+       l.Close()
+}
index 6672c494bf1f95896616807f929c8e16349fcfa2..0be270ad30dd92e8cee228c5a841686191335052 100644 (file)
@@ -670,6 +670,39 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
 // read requests and then call handler to reply to them.
 // Handler is typically nil, in which case the DefaultServeMux is used.
 func Serve(l net.Listener, handler Handler) os.Error {
+       srv := &Server{Handler: handler}
+       return srv.Serve(l)
+}
+
+// A Server defines parameters for running an HTTP server.
+type Server struct {
+       Addr         string  // TCP address to listen on, ":http" if empty
+       Handler      Handler // handler to invoke, http.DefaultServeMux if nil
+       ReadTimeout  int64   // the net.Conn.SetReadTimeout value for new connections
+       WriteTimeout int64   // the net.Conn.SetWriteTimeout value for new connections
+}
+
+// ListenAndServe listens on the TCP network address srv.Addr and then
+// calls Serve to handle requests on incoming connections.  If
+// srv.Addr is blank, ":http" is used.
+func (srv *Server) ListenAndServe() os.Error {
+       addr := srv.Addr
+       if addr == "" {
+               addr = ":http"
+       }
+       l, e := net.Listen("tcp", addr)
+       if e != nil {
+               return e
+       }
+       return srv.Serve(l)
+}
+
+// Serve accepts incoming connections on the Listener l, creating a
+// new service thread for each.  The service threads read requests and
+// then call srv.Handler to reply to them.
+func (srv *Server) Serve(l net.Listener) os.Error {
+       defer l.Close()
+       handler := srv.Handler
        if handler == nil {
                handler = DefaultServeMux
        }
@@ -678,6 +711,12 @@ func Serve(l net.Listener, handler Handler) os.Error {
                if e != nil {
                        return e
                }
+               if srv.ReadTimeout != 0 {
+                       rw.SetReadTimeout(srv.ReadTimeout)
+               }
+               if srv.WriteTimeout != 0 {
+                       rw.SetWriteTimeout(srv.WriteTimeout)
+               }
                c, err := newConn(rw, handler)
                if err != nil {
                        continue
@@ -715,13 +754,8 @@ func Serve(l net.Listener, handler Handler) os.Error {
 //             }
 //     }
 func ListenAndServe(addr string, handler Handler) os.Error {
-       l, e := net.Listen("tcp", addr)
-       if e != nil {
-               return e
-       }
-       e = Serve(l, handler)
-       l.Close()
-       return e
+       server := &Server{Addr: addr, Handler: handler}
+       return server.ListenAndServe()
 }
 
 // ListenAndServeTLS acts identically to ListenAndServe, except that it