From: Brad Fitzpatrick Date: Mon, 4 Feb 2013 21:52:45 +0000 (-0800) Subject: net/http: fix when server deadlines get extended X-Git-Tag: go1.1rc2~1145 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=022504b3ab62a4d35aad13c58382bd0a7168805b;p=gostls13.git net/http: fix when server deadlines get extended Deadlines should be extended at the beginning of a request, not at the beginning of a connection. Fixes #4676 R=golang-dev, fullung, patrick.allen.higgins, adg CC=golang-dev https://golang.org/cl/7220076 --- diff --git a/src/pkg/net/http/serve_test.go b/src/pkg/net/http/serve_test.go index 886ed4e8f7..6c97eaf637 100644 --- a/src/pkg/net/http/serve_test.go +++ b/src/pkg/net/http/serve_test.go @@ -256,28 +256,20 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { } func TestServerTimeouts(t *testing.T) { - // TODO(bradfitz): convert this to use httptest.Server - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen error: %v", err) - } - addr, _ := l.Addr().(*net.TCPAddr) - reqNum := 0 - handler := HandlerFunc(func(res ResponseWriter, req *Request) { + ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ fmt.Fprintf(res, "req=%d", reqNum) - }) - - server := &Server{Handler: handler, ReadTimeout: 250 * time.Millisecond, WriteTimeout: 250 * time.Millisecond} - go server.Serve(l) - - url := fmt.Sprintf("http://%s/", addr) + })) + ts.Config.ReadTimeout = 250 * time.Millisecond + ts.Config.WriteTimeout = 250 * time.Millisecond + ts.Start() + defer ts.Close() // Hit the HTTP server successfully. tr := &Transport{DisableKeepAlives: true} // they interfere with this test c := &Client{Transport: tr} - r, err := c.Get(url) + r, err := c.Get(ts.URL) if err != nil { t.Fatalf("http Get #1: %v", err) } @@ -290,13 +282,13 @@ func TestServerTimeouts(t *testing.T) { // Slow client that should timeout. t1 := time.Now() - conn, err := net.Dial("tcp", addr.String()) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) } buf := make([]byte, 1) n, err := conn.Read(buf) - latency := time.Now().Sub(t1) + latency := time.Since(t1) if n != 0 || err != io.EOF { t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF) } @@ -307,7 +299,7 @@ func TestServerTimeouts(t *testing.T) { // Hit the HTTP server successfully again, verifying that the // previous slow connection didn't run our handler. (that we // get "req=2", not "req=3") - r, err = Get(url) + r, err = Get(ts.URL) if err != nil { t.Fatalf("http Get #2: %v", err) } @@ -317,7 +309,21 @@ func TestServerTimeouts(t *testing.T) { t.Errorf("Get #2 got %q, want %q", string(got), expected) } - l.Close() + if !testing.Short() { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + go io.Copy(ioutil.Discard, conn) + for i := 0; i < 5; i++ { + _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n")) + if err != nil { + t.Fatalf("on write %d: %v", i, err) + } + time.Sleep(ts.Config.ReadTimeout / 2) + } + } } // TestIdentityResponse verifies that a handler can unset diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go index 434943d49a..e24b0dd931 100644 --- a/src/pkg/net/http/server.go +++ b/src/pkg/net/http/server.go @@ -416,6 +416,16 @@ func (c *conn) readRequest() (w *response, err error) { if c.hijacked() { return nil, ErrHijacked } + + if d := c.server.ReadTimeout; d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + } + if d := c.server.WriteTimeout; d != 0 { + defer func() { + c.rwc.SetWriteDeadline(time.Now().Add(d)) + }() + } + c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */ var req *Request if req, err = ReadRequest(c.buf.Reader); err != nil { @@ -779,6 +789,12 @@ func (c *conn) serve() { }() if tlsConn, ok := c.rwc.(*tls.Conn); ok { + if d := c.server.ReadTimeout; d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + } + if d := c.server.WriteTimeout; d != 0 { + c.rwc.SetWriteDeadline(time.Now().Add(d)) + } if err := tlsConn.Handshake(); err != nil { return } @@ -1274,12 +1290,6 @@ func (srv *Server) Serve(l net.Listener) error { return e } tempDelay = 0 - if srv.ReadTimeout != 0 { - rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout)) - } - if srv.WriteTimeout != 0 { - rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout)) - } c, err := srv.newConn(rw) if err != nil { continue