}
}
-func TestServerConsumesRequestBody(t *testing.T) {
+// Under a ~256KB (maxPostHandlerReadBytes) threshold, the server
+// should consume client request bodies that a handler didn't read.
+func TestServerUnreadRequestBodyLittle(t *testing.T) {
conn := new(testConn)
- body := strings.Repeat("x", 1<<20)
+ body := strings.Repeat("x", 100<<10)
conn.readBuf.Write([]byte(fmt.Sprintf(
"POST / HTTP/1.1\r\n"+
"Host: test\r\n"+
ls := &oneConnListener{conn}
go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ defer close(done)
if conn.readBuf.Len() < len(body)/2 {
- t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
+ t.Errorf("on request, read buffer length is %d; expected about 100 KB", conn.readBuf.Len())
}
rw.WriteHeader(200)
if g, e := conn.readBuf.Len(), 0; g != e {
t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
}
- done <- true
+ if c := rw.Header().Get("Connection"); c != "" {
+ t.Errorf(`Connection header = %q; want ""`, c)
+ }
+ }))
+ <-done
+}
+
+// Over a ~256KB (maxPostHandlerReadBytes) threshold, the server
+// should ignore client request bodies that a handler didn't read
+// and close the connection.
+func TestServerUnreadRequestBodyLarge(t *testing.T) {
+ conn := new(testConn)
+ body := strings.Repeat("x", 1<<20)
+ conn.readBuf.Write([]byte(fmt.Sprintf(
+ "POST / HTTP/1.1\r\n"+
+ "Host: test\r\n"+
+ "Content-Length: %d\r\n"+
+ "\r\n", len(body))))
+ conn.readBuf.Write([]byte(body))
+
+ done := make(chan bool)
+
+ ls := &oneConnListener{conn}
+ go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ defer close(done)
+ if conn.readBuf.Len() < len(body)/2 {
+ t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
+ }
+ rw.WriteHeader(200)
+ if conn.readBuf.Len() < len(body)/2 {
+ t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
+ }
+ if c := rw.Header().Get("Connection"); c != "close" {
+ t.Errorf(`Connection header = %q; want "close"`, c)
+ }
}))
<-done
}
"crypto/tls"
"fmt"
"io"
+ "io/ioutil"
"log"
"net"
"os"
// requestTooLarge is called by maxBytesReader when too much input has
// been read from the client.
-func (r *response) requestTooLarge() {
- r.closeAfterReply = true
- r.requestBodyLimitHit = true
- if !r.wroteHeader {
- r.Header().Set("Connection", "close")
+func (w *response) requestTooLarge() {
+ w.closeAfterReply = true
+ w.requestBodyLimitHit = true
+ if !w.wroteHeader {
+ w.Header().Set("Connection", "close")
}
}
io.Writer
}
-func (r *response) ReadFrom(src io.Reader) (n int64, err os.Error) {
- // Flush before checking r.chunking, as Flush will call
+func (w *response) ReadFrom(src io.Reader) (n int64, err os.Error) {
+ // Flush before checking w.chunking, as Flush will call
// WriteHeader if it hasn't been called yet, and WriteHeader
- // is what sets r.chunking.
- r.Flush()
- if !r.chunking && r.bodyAllowed() && !r.needSniff {
- if rf, ok := r.conn.rwc.(io.ReaderFrom); ok {
+ // is what sets w.chunking.
+ w.Flush()
+ if !w.chunking && w.bodyAllowed() && !w.needSniff {
+ if rf, ok := w.conn.rwc.(io.ReaderFrom); ok {
n, err = rf.ReadFrom(src)
- r.written += n
+ w.written += n
return
}
}
// Fall back to default io.Copy implementation.
- // Use wrapper to hide r.ReadFrom from io.Copy.
- return io.Copy(writerOnly{r}, src)
+ // Use wrapper to hide w.ReadFrom from io.Copy.
+ return io.Copy(writerOnly{w}, src)
}
// noLimit is an effective infinite upper bound for io.LimitedReader
return w.header
}
+// maxPostHandlerReadBytes is the max number of Request.Body bytes not
+// consumed by a handler that the server will read from the a client
+// in order to keep a connection alive. If there are more bytes than
+// this then the server to be paranoid instead sends a "Connection:
+// close" response.
+//
+// This number is approximately what a typical machine's TCP buffer
+// size is anyway. (if we have the bytes on the machine, we might as
+// well read them)
+const maxPostHandlerReadBytes = 256 << 10
+
func (w *response) WriteHeader(code int) {
if w.conn.hijacked {
log.Print("http: response.WriteHeader on hijacked connection")
log.Print("http: multiple response.WriteHeader calls")
return
}
+ w.wroteHeader = true
+ w.status = code
+
+ // Check for a explicit (and valid) Content-Length header.
+ var hasCL bool
+ var contentLength int64
+ if clenStr := w.header.Get("Content-Length"); clenStr != "" {
+ var err os.Error
+ contentLength, err = strconv.Atoi64(clenStr)
+ if err == nil {
+ hasCL = true
+ } else {
+ log.Printf("http: invalid Content-Length of %q sent", clenStr)
+ w.header.Del("Content-Length")
+ }
+ }
+
+ if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) {
+ _, connectionHeaderSet := w.header["Connection"]
+ if !connectionHeaderSet {
+ w.header.Set("Connection", "keep-alive")
+ }
+ } else if !w.req.ProtoAtLeast(1, 1) {
+ // Client did not ask to keep connection alive.
+ w.closeAfterReply = true
+ }
+
+ if w.header.Get("Connection") == "close" {
+ w.closeAfterReply = true
+ }
// Per RFC 2616, we should consume the request body before
- // replying, if the handler hasn't already done so.
- if w.req.ContentLength != 0 && !w.requestBodyLimitHit {
+ // replying, if the handler hasn't already done so. But we
+ // don't want to do an unbounded amount of reading here for
+ // DoS reasons, so we only try up to a threshold.
+ if w.req.ContentLength != 0 && !w.closeAfterReply {
ecr, isExpecter := w.req.Body.(*expectContinueReader)
if !isExpecter || ecr.resp.wroteContinue {
- w.req.Body.Close()
+ n, _ := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1)
+ if n >= maxPostHandlerReadBytes {
+ w.requestTooLarge()
+ w.header.Set("Connection", "close")
+ } else {
+ w.req.Body.Close()
+ }
}
}
- w.wroteHeader = true
- w.status = code
if code == StatusNotModified {
// Must not have body.
for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} {
w.Header().Set("Date", time.UTC().Format(TimeFormat))
}
- // Check for a explicit (and valid) Content-Length header.
- var hasCL bool
- var contentLength int64
- if clenStr := w.header.Get("Content-Length"); clenStr != "" {
- var err os.Error
- contentLength, err = strconv.Atoi64(clenStr)
- if err == nil {
- hasCL = true
- } else {
- log.Printf("http: invalid Content-Length of %q sent", clenStr)
- w.header.Del("Content-Length")
- }
- }
-
te := w.header.Get("Transfer-Encoding")
hasTE := te != ""
if hasCL && hasTE && te != "identity" {
w.header.Del("Transfer-Encoding") // in case already set
}
- if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) {
- _, connectionHeaderSet := w.header["Connection"]
- if !connectionHeaderSet {
- w.header.Set("Connection", "keep-alive")
- }
- } else if !w.req.ProtoAtLeast(1, 1) {
- // Client did not ask to keep connection alive.
- w.closeAfterReply = true
- }
-
- if w.header.Get("Connection") == "close" {
- w.closeAfterReply = true
- }
-
// Cannot use Content-Length with non-identity Transfer-Encoding.
if w.chunking {
w.header.Del("Content-Length")