]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: avoid panic when writing 100-continue after handler done
authorDamien Neil <dneil@google.com>
Tue, 14 May 2024 16:55:11 +0000 (09:55 -0700)
committerDamien Neil <dneil@google.com>
Tue, 14 May 2024 22:34:15 +0000 (22:34 +0000)
When a request contains an "Expect: 100-continue" header,
the first read from the request body causes the server to
write a 100-continue status.

This write caused a panic when performed after the server handler
has exited. Disable the write when cleaning up after a handler
exits.

This also fixes a bug where an implicit 100-continue could be
sent after a call to WriteHeader has sent a non-1xx header.

This change drops tracking of whether we've written a
100-continue or not in response.wroteContinue. This tracking
was used to determine whether we should consume the remaining
request body in chunkWriter.writeHeader, but the discard-the-body
path was only taken when the body was already consumed.
(If the body is not consumed, we set closeAfterReply, and we
don't consume the remaining body when closeAfterReply is set.
If the body is consumed, then we may attempt to discard the
remaining body, but there is obviously no body remaining.)

Fixes #53808

Change-Id: I3542df26ad6cdfe93b50a45ae2d6e7ef031e46fa
Reviewed-on: https://go-review.googlesource.com/c/go/+/585395
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
src/net/http/serve_test.go
src/net/http/server.go

index e21af8b1595612b2f42164eb38db7669e70e0e3e..f454dcdbed3525fba279ec22c784996b97bb756f 100644 (file)
@@ -7166,3 +7166,79 @@ func TestError(t *testing.T) {
                t.Errorf("X-Content-Type-Options: %q, want %q", v, "nosniff")
        }
 }
+
+func TestServerReadAfterWriteHeader100Continue(t *testing.T) {
+       run(t, testServerReadAfterWriteHeader100Continue)
+}
+func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) {
+       body := []byte("body")
+       cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               w.WriteHeader(200)
+               NewResponseController(w).Flush()
+               io.ReadAll(r.Body)
+               w.Write(body)
+       }))
+
+       req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
+       req.Header.Set("Expect", "100-continue")
+       res, err := cst.c.Do(req)
+       if err != nil {
+               t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
+       }
+       defer res.Body.Close()
+       got, err := io.ReadAll(res.Body)
+       if err != nil {
+               t.Fatalf("io.ReadAll(res.Body) = %v", err)
+       }
+       if !bytes.Equal(got, body) {
+               t.Fatalf("response body = %q, want %q", got, body)
+       }
+}
+
+func TestServerReadAfterHandlerDone100Continue(t *testing.T) {
+       run(t, testServerReadAfterHandlerDone100Continue)
+}
+func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) {
+       readyc := make(chan struct{})
+       cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               go func() {
+                       <-readyc
+                       io.ReadAll(r.Body)
+                       <-readyc
+               }()
+       }))
+
+       req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
+       req.Header.Set("Expect", "100-continue")
+       res, err := cst.c.Do(req)
+       if err != nil {
+               t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
+       }
+       res.Body.Close()
+       readyc <- struct{}{} // server starts reading from the request body
+       readyc <- struct{}{} // server finishes reading from the request body
+}
+
+func TestServerReadAfterHandlerAbort100Continue(t *testing.T) {
+       run(t, testServerReadAfterHandlerAbort100Continue)
+}
+func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) {
+       readyc := make(chan struct{})
+       cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               go func() {
+                       <-readyc
+                       io.ReadAll(r.Body)
+                       <-readyc
+               }()
+               panic(ErrAbortHandler)
+       }))
+
+       req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
+       req.Header.Set("Expect", "100-continue")
+       res, err := cst.c.Do(req)
+       if err == nil {
+               res.Body.Close()
+       }
+       readyc <- struct{}{} // server starts reading from the request body
+       readyc <- struct{}{} // server finishes reading from the request body
+}
index a50b20b7da905b2ef9dc497000de382ee4be3492..b76c869567185b77652816b34f963816bb6a43dc 100644 (file)
@@ -425,7 +425,6 @@ type response struct {
        reqBody          io.ReadCloser
        cancelCtx        context.CancelFunc // when ServeHTTP exits
        wroteHeader      bool               // a non-1xx header has been (logically) written
-       wroteContinue    bool               // 100 Continue response was written
        wants10KeepAlive bool               // HTTP/1.0 w/ Connection "keep-alive"
        wantsClose       bool               // HTTP request has Connection "close"
 
@@ -436,8 +435,8 @@ type response struct {
        // These two fields together synchronize the body reader (the
        // expectContinueReader, which wants to write 100 Continue)
        // against the main writer.
-       canWriteContinue atomic.Bool
        writeContinueMu  sync.Mutex
+       canWriteContinue atomic.Bool
 
        w  *bufio.Writer // buffers output in chunks to chunkWriter
        cw chunkWriter
@@ -565,6 +564,14 @@ func (w *response) requestTooLarge() {
        }
 }
 
+// disableWriteContinue stops Request.Body.Read from sending an automatic 100-Continue.
+// If a 100-Continue is being written, it waits for it to complete before continuing.
+func (w *response) disableWriteContinue() {
+       w.writeContinueMu.Lock()
+       w.canWriteContinue.Store(false)
+       w.writeContinueMu.Unlock()
+}
+
 // writerOnly hides an io.Writer value's optional ReadFrom method
 // from io.Copy.
 type writerOnly struct {
@@ -917,8 +924,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
                return 0, ErrBodyReadAfterClose
        }
        w := ecr.resp
-       if !w.wroteContinue && w.canWriteContinue.Load() && !w.conn.hijacked() {
-               w.wroteContinue = true
+       if w.canWriteContinue.Load() {
                w.writeContinueMu.Lock()
                if w.canWriteContinue.Load() {
                        w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n")
@@ -1159,18 +1165,17 @@ func (w *response) WriteHeader(code int) {
        }
        checkWriteHeaderCode(code)
 
+       if code < 101 || code > 199 {
+               // Sending a 100 Continue or any non-1xx header disables the
+               // automatically-sent 100 Continue from Request.Body.Read.
+               w.disableWriteContinue()
+       }
+
        // Handle informational headers.
        //
        // We shouldn't send any further headers after 101 Switching Protocols,
        // so it takes the non-informational path.
        if code >= 100 && code <= 199 && code != StatusSwitchingProtocols {
-               // Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read()
-               if code == 100 && w.canWriteContinue.Load() {
-                       w.writeContinueMu.Lock()
-                       w.canWriteContinue.Store(false)
-                       w.writeContinueMu.Unlock()
-               }
-
                writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:])
 
                // Per RFC 8297 we must not clear the current header map
@@ -1378,14 +1383,20 @@ func (cw *chunkWriter) writeHeader(p []byte) {
        //
        // If full duplex mode has been enabled with ResponseController.EnableFullDuplex,
        // then leave the request body alone.
+       //
+       // We don't take this path when w.closeAfterReply is set.
+       // We may not need to consume the request to get ready for the next one
+       // (since we're closing the conn), but a client which sends a full request
+       // before reading a response may deadlock in this case.
+       // This behavior has been present since CL 5268043 (2011), however,
+       // so it doesn't seem to be causing problems.
        if w.req.ContentLength != 0 && !w.closeAfterReply && !w.fullDuplex {
                var discard, tooBig bool
 
                switch bdy := w.req.Body.(type) {
                case *expectContinueReader:
-                       if bdy.resp.wroteContinue {
-                               discard = true
-                       }
+                       // We only get here if we have already fully consumed the request body
+                       // (see above).
                case *body:
                        bdy.mu.Lock()
                        switch {
@@ -1626,13 +1637,8 @@ func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err er
        }
 
        if w.canWriteContinue.Load() {
-               // Body reader wants to write 100 Continue but hasn't yet.
-               // Tell it not to. The store must be done while holding the lock
-               // because the lock makes sure that there is not an active write
-               // this very moment.
-               w.writeContinueMu.Lock()
-               w.canWriteContinue.Store(false)
-               w.writeContinueMu.Unlock()
+               // Body reader wants to write 100 Continue but hasn't yet. Tell it not to.
+               w.disableWriteContinue()
        }
 
        if !w.wroteHeader {
@@ -1900,6 +1906,7 @@ func (c *conn) serve(ctx context.Context) {
                }
                if inFlightResponse != nil {
                        inFlightResponse.cancelCtx()
+                       inFlightResponse.disableWriteContinue()
                }
                if !c.hijacked() {
                        if inFlightResponse != nil {
@@ -2106,6 +2113,7 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
        if w.handlerDone.Load() {
                panic("net/http: Hijack called after ServeHTTP finished")
        }
+       w.disableWriteContinue()
        if w.wroteHeader {
                w.cw.flush()
        }