]> Cypherpunks repositories - gostls13.git/commitdiff
http: configurable and default request header size limit
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 9 Aug 2011 17:55:14 +0000 (10:55 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 9 Aug 2011 17:55:14 +0000 (10:55 -0700)
This addresses the biggest DoS in issue 2093

R=golang-dev, dsymonds
CC=golang-dev
https://golang.org/cl/4841050

src/pkg/http/serve_test.go
src/pkg/http/server.go

index 9c8a122ff0413a88b8adaa7051d916ad74050746..2725c3b428babab120aa45e276068b770c3e68d3 100644 (file)
@@ -873,6 +873,28 @@ func TestStripPrefix(t *testing.T) {
        }
 }
 
+func TestRequestLimit(t *testing.T) {
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               t.Fatalf("didn't expect to get request in Handler")
+       }))
+       defer ts.Close()
+       req, _ := NewRequest("GET", ts.URL, nil)
+       var bytesPerHeader = len("header12345: val12345\r\n")
+       for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
+               req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
+       }
+       res, err := DefaultClient.Do(req)
+       if err != nil {
+               // Some HTTP clients may fail on this undefined behavior (server replying and
+               // closing the connection while the request is still being written), but
+               // we do support it (at least currently), so we expect a response below.
+               t.Fatalf("Do: %v", err)
+       }
+       if res.StatusCode != 400 {
+               t.Fatalf("expected 400 response status; got: %d %s", res.StatusCode, res.Status)
+       }
+}
+
 type errorListener struct {
        errs []os.Error
 }
index 96547c4eff0d2addb077c730cd38674c7ab0dd13..1955b67e65957b03b0e3f3cf7335c3e2cb6a753e 100644 (file)
@@ -94,9 +94,10 @@ type Hijacker interface {
 // A conn represents the server side of an HTTP connection.
 type conn struct {
        remoteAddr string               // network address of remote side
-       handler    Handler              // request handler
+       server     *Server              // the Server on which the connection arrived
        rwc        net.Conn             // i/o connection
-       buf        *bufio.ReadWriter    // buffered rwc
+       lr         *io.LimitedReader    // io.LimitReader(rwc)
+       buf        *bufio.ReadWriter    // buffered(lr,rwc), reading from bufio->limitReader->rwc
        hijacked   bool                 // connection has been hijacked by handler
        tlsState   *tls.ConnectionState // or nil when not using TLS
        body       []byte
@@ -143,14 +144,18 @@ func (r *response) ReadFrom(src io.Reader) (n int64, err os.Error) {
        return io.Copy(writerOnly{r}, src)
 }
 
+// noLimit is an effective infinite upper bound for io.LimitedReader
+const noLimit int64 = (1 << 63) - 1
+
 // Create new connection from rwc.
-func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) {
+func (srv *Server) newConn(rwc net.Conn) (c *conn, err os.Error) {
        c = new(conn)
        c.remoteAddr = rwc.RemoteAddr().String()
-       c.handler = handler
+       c.server = srv
        c.rwc = rwc
        c.body = make([]byte, sniffLen)
-       br := bufio.NewReader(rwc)
+       c.lr = io.LimitReader(rwc, noLimit).(*io.LimitedReader)
+       br := bufio.NewReader(c.lr)
        bw := bufio.NewWriter(rwc)
        c.buf = bufio.NewReadWriter(br, bw)
 
@@ -163,6 +168,18 @@ func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) {
        return c, nil
 }
 
+// DefaultMaxHeaderBytes is the maximum permitted size of the headers
+// in an HTTP request.
+// This can be overridden by setting Server.MaxHeaderBytes.
+const DefaultMaxHeaderBytes = 1 << 20 // 1 MB
+
+func (srv *Server) maxHeaderBytes() int {
+       if srv.MaxHeaderBytes > 0 {
+               return srv.MaxHeaderBytes
+       }
+       return DefaultMaxHeaderBytes
+}
+
 // wrapper around io.ReaderCloser which on first read, sends an
 // HTTP/1.1 100 Continue header
 type expectContinueReader struct {
@@ -194,15 +211,22 @@ func (ecr *expectContinueReader) Close() os.Error {
 // It is like time.RFC1123 but hard codes GMT as the time zone.
 const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
 
+var errTooLarge = os.NewError("http: request too large")
+
 // Read next request from connection.
 func (c *conn) readRequest() (w *response, err os.Error) {
        if c.hijacked {
                return nil, ErrHijacked
        }
+       c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */
        var req *Request
        if req, err = ReadRequest(c.buf.Reader); err != nil {
+               if c.lr.N == 0 {
+                       return nil, errTooLarge
+               }
                return nil, err
        }
+       c.lr.N = noLimit
 
        req.RemoteAddr = c.remoteAddr
        req.TLS = c.tlsState
@@ -567,6 +591,14 @@ func (c *conn) serve() {
        for {
                w, err := c.readRequest()
                if err != nil {
+                       if err == errTooLarge {
+                               // Their HTTP client may or may not be
+                               // able to read this if we're
+                               // responding to them and hanging up
+                               // while they're still writing their
+                               // request.  Undefined behavior.
+                               fmt.Fprintf(c.rwc, "HTTP/1.1 400 Request Too Large\r\n\r\n")
+                       }
                        break
                }
 
@@ -603,12 +635,17 @@ func (c *conn) serve() {
                        break
                }
 
+               handler := c.server.Handler
+               if handler == nil {
+                       handler = DefaultServeMux
+               }
+
                // HTTP cannot have multiple simultaneous active requests.[*]
                // Until the server replies to this request, it can't read another,
                // so we might as well run the handler in this goroutine.
                // [*] Not strictly true: HTTP pipelining.  We could let them all process
                // in parallel even if their responses need to be serialized.
-               c.handler.ServeHTTP(w, w.req)
+               handler.ServeHTTP(w, w.req)
                if c.hijacked {
                        return
                }
@@ -906,10 +943,11 @@ func Serve(l net.Listener, handler Handler) os.Error {
 
 // A Server defines parameters for running an HTTP server.
 type Server struct {
-       Addr         string  // TCP address to listen on, ":http" if empty
-       Handler      Handler // handler to invoke, http.DefaultServeMux if nil
-       ReadTimeout  int64   // the net.Conn.SetReadTimeout value for new connections
-       WriteTimeout int64   // the net.Conn.SetWriteTimeout value for new connections
+       Addr           string  // TCP address to listen on, ":http" if empty
+       Handler        Handler // handler to invoke, http.DefaultServeMux if nil
+       ReadTimeout    int64   // the net.Conn.SetReadTimeout value for new connections
+       WriteTimeout   int64   // the net.Conn.SetWriteTimeout value for new connections
+       MaxHeaderBytes int     // maximum size of request headers, DefaultMaxHeaderBytes if 0
 }
 
 // ListenAndServe listens on the TCP network address srv.Addr and then
@@ -932,10 +970,6 @@ func (srv *Server) ListenAndServe() os.Error {
 // then call srv.Handler to reply to them.
 func (srv *Server) Serve(l net.Listener) os.Error {
        defer l.Close()
-       handler := srv.Handler
-       if handler == nil {
-               handler = DefaultServeMux
-       }
        for {
                rw, e := l.Accept()
                if e != nil {
@@ -951,7 +985,7 @@ func (srv *Server) Serve(l net.Listener) os.Error {
                if srv.WriteTimeout != 0 {
                        rw.SetWriteTimeout(srv.WriteTimeout)
                }
-               c, err := newConn(rw, handler)
+               c, err := srv.newConn(rw)
                if err != nil {
                        continue
                }