]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: fix when server deadlines get extended
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 4 Feb 2013 21:52:45 +0000 (13:52 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 4 Feb 2013 21:52:45 +0000 (13:52 -0800)
Deadlines should be extended at the beginning of
a request, not at the beginning of a connection.

Fixes #4676

R=golang-dev, fullung, patrick.allen.higgins, adg
CC=golang-dev
https://golang.org/cl/7220076

src/pkg/net/http/serve_test.go
src/pkg/net/http/server.go

index 886ed4e8f7413e2d2504550ea5f5e999a29f9344..6c97eaf63707b1056ee167362b7f3b38cb84d26e 100644 (file)
@@ -256,28 +256,20 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) {
 }
 
 func TestServerTimeouts(t *testing.T) {
-       // TODO(bradfitz): convert this to use httptest.Server
-       l, err := net.Listen("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatalf("listen error: %v", err)
-       }
-       addr, _ := l.Addr().(*net.TCPAddr)
-
        reqNum := 0
-       handler := HandlerFunc(func(res ResponseWriter, req *Request) {
+       ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
                reqNum++
                fmt.Fprintf(res, "req=%d", reqNum)
-       })
-
-       server := &Server{Handler: handler, ReadTimeout: 250 * time.Millisecond, WriteTimeout: 250 * time.Millisecond}
-       go server.Serve(l)
-
-       url := fmt.Sprintf("http://%s/", addr)
+       }))
+       ts.Config.ReadTimeout = 250 * time.Millisecond
+       ts.Config.WriteTimeout = 250 * time.Millisecond
+       ts.Start()
+       defer ts.Close()
 
        // Hit the HTTP server successfully.
        tr := &Transport{DisableKeepAlives: true} // they interfere with this test
        c := &Client{Transport: tr}
-       r, err := c.Get(url)
+       r, err := c.Get(ts.URL)
        if err != nil {
                t.Fatalf("http Get #1: %v", err)
        }
@@ -290,13 +282,13 @@ func TestServerTimeouts(t *testing.T) {
 
        // Slow client that should timeout.
        t1 := time.Now()
-       conn, err := net.Dial("tcp", addr.String())
+       conn, err := net.Dial("tcp", ts.Listener.Addr().String())
        if err != nil {
                t.Fatalf("Dial: %v", err)
        }
        buf := make([]byte, 1)
        n, err := conn.Read(buf)
-       latency := time.Now().Sub(t1)
+       latency := time.Since(t1)
        if n != 0 || err != io.EOF {
                t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
        }
@@ -307,7 +299,7 @@ func TestServerTimeouts(t *testing.T) {
        // Hit the HTTP server successfully again, verifying that the
        // previous slow connection didn't run our handler.  (that we
        // get "req=2", not "req=3")
-       r, err = Get(url)
+       r, err = Get(ts.URL)
        if err != nil {
                t.Fatalf("http Get #2: %v", err)
        }
@@ -317,7 +309,21 @@ func TestServerTimeouts(t *testing.T) {
                t.Errorf("Get #2 got %q, want %q", string(got), expected)
        }
 
-       l.Close()
+       if !testing.Short() {
+               conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+               if err != nil {
+                       t.Fatalf("Dial: %v", err)
+               }
+               defer conn.Close()
+               go io.Copy(ioutil.Discard, conn)
+               for i := 0; i < 5; i++ {
+                       _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
+                       if err != nil {
+                               t.Fatalf("on write %d: %v", i, err)
+                       }
+                       time.Sleep(ts.Config.ReadTimeout / 2)
+               }
+       }
 }
 
 // TestIdentityResponse verifies that a handler can unset
index 434943d49a389087a2368fb1266294ee880d0f22..e24b0dd9312ac7c09f19d94ae20060cb547a2b6e 100644 (file)
@@ -416,6 +416,16 @@ func (c *conn) readRequest() (w *response, err error) {
        if c.hijacked() {
                return nil, ErrHijacked
        }
+
+       if d := c.server.ReadTimeout; d != 0 {
+               c.rwc.SetReadDeadline(time.Now().Add(d))
+       }
+       if d := c.server.WriteTimeout; d != 0 {
+               defer func() {
+                       c.rwc.SetWriteDeadline(time.Now().Add(d))
+               }()
+       }
+
        c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */
        var req *Request
        if req, err = ReadRequest(c.buf.Reader); err != nil {
@@ -779,6 +789,12 @@ func (c *conn) serve() {
        }()
 
        if tlsConn, ok := c.rwc.(*tls.Conn); ok {
+               if d := c.server.ReadTimeout; d != 0 {
+                       c.rwc.SetReadDeadline(time.Now().Add(d))
+               }
+               if d := c.server.WriteTimeout; d != 0 {
+                       c.rwc.SetWriteDeadline(time.Now().Add(d))
+               }
                if err := tlsConn.Handshake(); err != nil {
                        return
                }
@@ -1274,12 +1290,6 @@ func (srv *Server) Serve(l net.Listener) error {
                        return e
                }
                tempDelay = 0
-               if srv.ReadTimeout != 0 {
-                       rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
-               }
-               if srv.WriteTimeout != 0 {
-                       rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout))
-               }
                c, err := srv.newConn(rw)
                if err != nil {
                        continue