// Under a ~256KB (maxPostHandlerReadBytes) threshold, the server
// should consume client request bodies that a handler didn't read.
func TestServerUnreadRequestBodyLittle(t *testing.T) {
+ defer afterTest(t)
conn := new(testConn)
body := strings.Repeat("x", 100<<10)
conn.readBuf.Write([]byte(fmt.Sprintf(
}
}
+// testHandlerBodyConsumer represents a function injected into a test handler to
+// vary work done on a request Body.
+type testHandlerBodyConsumer struct {
+ name string
+ f func(io.ReadCloser)
+}
+
+var testHandlerBodyConsumers = []testHandlerBodyConsumer{
+ {"nil", func(io.ReadCloser) {}},
+ {"close", func(r io.ReadCloser) { r.Close() }},
+ {"discard", func(r io.ReadCloser) { io.Copy(ioutil.Discard, r) }},
+}
+
+func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
+ defer afterTest(t)
+ for _, handler := range testHandlerBodyConsumers {
+ conn := new(testConn)
+ conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
+ "Host: test\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "\r\n" +
+ "hax\r\n" + // Invalid chunked encoding
+ "GET /secret HTTP/1.1\r\n" +
+ "Host: test\r\n" +
+ "\r\n")
+
+ conn.closec = make(chan bool, 1)
+ ls := &oneConnListener{conn}
+ var numReqs int
+ go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
+ numReqs++
+ if strings.Contains(req.URL.Path, "secret") {
+ t.Error("Request for /secret encountered, should not have happened.")
+ }
+ handler.f(req.Body)
+ }))
+ <-conn.closec
+ if numReqs != 1 {
+ t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
+ }
+ }
+}
+
+// slowTestConn is a net.Conn that provides a means to simulate parts of a
+// request being received piecemeal. Deadlines can be set and enforced in both
+// Read and Write.
+type slowTestConn struct {
+ // over multiple calls to Read, time.Durations are slept, strings are read.
+ script []interface{}
+ closec chan bool
+ rd, wd time.Time // read, write deadline
+ noopConn
+}
+
+func (c *slowTestConn) SetDeadline(t time.Time) error {
+ c.SetReadDeadline(t)
+ c.SetWriteDeadline(t)
+ return nil
+}
+
+func (c *slowTestConn) SetReadDeadline(t time.Time) error {
+ c.rd = t
+ return nil
+}
+
+func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
+ c.wd = t
+ return nil
+}
+
+func (c *slowTestConn) Read(b []byte) (n int, err error) {
+restart:
+ if !c.rd.IsZero() && time.Now().After(c.rd) {
+ return 0, syscall.ETIMEDOUT
+ }
+ if len(c.script) == 0 {
+ return 0, io.EOF
+ }
+
+ switch cue := c.script[0].(type) {
+ case time.Duration:
+ if !c.rd.IsZero() {
+ // If the deadline falls in the middle of our sleep window, deduct
+ // part of the sleep, then return a timeout.
+ if remaining := c.rd.Sub(time.Now()); remaining < cue {
+ c.script[0] = cue - remaining
+ time.Sleep(remaining)
+ return 0, syscall.ETIMEDOUT
+ }
+ }
+ c.script = c.script[1:]
+ time.Sleep(cue)
+ goto restart
+
+ case string:
+ n = copy(b, cue)
+ // If cue is too big for the buffer, leave the end for the next Read.
+ if len(cue) > n {
+ c.script[0] = cue[n:]
+ } else {
+ c.script = c.script[1:]
+ }
+
+ default:
+ panic("unknown cue in slowTestConn script")
+ }
+
+ return
+}
+
+func (c *slowTestConn) Close() error {
+ select {
+ case c.closec <- true:
+ default:
+ }
+ return nil
+}
+
+func (c *slowTestConn) Write(b []byte) (int, error) {
+ if !c.wd.IsZero() && time.Now().After(c.wd) {
+ return 0, syscall.ETIMEDOUT
+ }
+ return len(b), nil
+}
+
+func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in -short mode")
+ }
+ defer afterTest(t)
+ for _, handler := range testHandlerBodyConsumers {
+ conn := &slowTestConn{
+ script: []interface{}{
+ "POST /public HTTP/1.1\r\n" +
+ "Host: test\r\n" +
+ "Content-Length: 10000\r\n" +
+ "\r\n",
+ "foo bar baz",
+ 600 * time.Millisecond, // Request deadline should hit here
+ "GET /secret HTTP/1.1\r\n" +
+ "Host: test\r\n" +
+ "\r\n",
+ },
+ closec: make(chan bool, 1),
+ }
+ ls := &oneConnListener{conn}
+
+ var numReqs int
+ s := Server{
+ Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
+ numReqs++
+ if strings.Contains(req.URL.Path, "secret") {
+ t.Error("Request for /secret encountered, should not have happened.")
+ }
+ handler.f(req.Body)
+ }),
+ ReadTimeout: 400 * time.Millisecond,
+ }
+ go s.Serve(ls)
+ <-conn.closec
+
+ if numReqs != 1 {
+ t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
+ }
+ }
+}
+
func TestTimeoutHandler(t *testing.T) {
defer afterTest(t)
sendHi := make(chan bool, 1)
// don't want to do an unbounded amount of reading here for
// DoS reasons, so we only try up to a threshold.
if w.req.ContentLength != 0 && !w.closeAfterReply {
- ecr, isExpecter := w.req.Body.(*expectContinueReader)
- if !isExpecter || ecr.resp.wroteContinue {
- var tooBig bool
- if reqBody, ok := w.req.Body.(*body); ok && reqBody.unreadDataSize() >= maxPostHandlerReadBytes {
+ var discard, tooBig bool
+
+ switch bdy := w.req.Body.(type) {
+ case *expectContinueReader:
+ if bdy.resp.wroteContinue {
+ discard = true
+ }
+ case *body:
+ switch {
+ case bdy.closed:
+ if !bdy.sawEOF {
+ // Body was closed in handler with non-EOF error.
+ w.closeAfterReply = true
+ }
+ case bdy.unreadDataSize() >= maxPostHandlerReadBytes:
tooBig = true
- } else {
- n, _ := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1)
- tooBig = n >= maxPostHandlerReadBytes
+ default:
+ discard = true
}
- if tooBig {
- w.requestTooLarge()
- delHeader("Connection")
- setHeader.connection = "close"
- } else {
- w.req.Body.Close()
+ default:
+ discard = true
+ }
+
+ if discard {
+ _, err := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1)
+ switch err {
+ case nil:
+ // There must be even more data left over.
+ tooBig = true
+ case ErrBodyReadAfterClose:
+ // Body was already consumed and closed.
+ case io.EOF:
+ // The remaining body was just consumed, close it.
+ err = w.req.Body.Close()
+ if err != nil {
+ w.closeAfterReply = true
+ }
+ default:
+ // Some other kind of error occured, like a read timeout, or
+ // corrupt chunked encoding. In any case, whatever remains
+ // on the wire must not be parsed as another HTTP request.
+ w.closeAfterReply = true
}
}
+
+ if tooBig {
+ w.requestTooLarge()
+ delHeader("Connection")
+ setHeader.connection = "close"
+ }
}
code := w.status