]> Cypherpunks repositories - gostls13.git/commitdiff
[release-branch.go1.16] net/http: fix ResponseWriter.ReadFrom with short reads
authorDamien Neil <dneil@google.com>
Fri, 12 Mar 2021 21:53:11 +0000 (13:53 -0800)
committerDmitri Shuralyov <dmitshur@golang.org>
Tue, 29 Jun 2021 17:42:45 +0000 (17:42 +0000)
CL 249238 changes ResponseWriter.ReadFrom to probe the source with
a single read of sniffLen bytes before writing the response header.
If the source returns less than sniffLen bytes without reaching
EOF, this can cause Content-Type and Content-Length detection to
fail.

Fix ResponseWrite.ReadFrom to copy a full sniffLen bytes from
the source as a probe.

Drop the explicit call to w.WriteHeader; writing the probe will
trigger a WriteHeader call.

Consistently use io.CopyBuffer; ReadFrom has already acquired a
copy buffer, so it may as well use it.

Fixes #44984.
Updates #44953.

Change-Id: Ic49305fb827a2bd7da4764b68d64b797b5157dc0
Reviewed-on: https://go-review.googlesource.com/c/go/+/301449
Trust: Damien Neil <dneil@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
(cherry picked from commit 831f9376d8d730b16fb33dfd775618dffe13ce7a)
Reviewed-on: https://go-review.googlesource.com/c/go/+/324971
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
src/net/http/server.go
src/net/http/sniff_test.go

index ad99741177d50d20ad12a3ad8f5615eb8371b42e..198102ffd21ea1d79042871528fbd2b56fb59c27 100644 (file)
@@ -577,37 +577,17 @@ func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
                return io.CopyBuffer(writerOnly{w}, src, buf)
        }
 
-       // sendfile path:
-
-       // Do not start actually writing response until src is readable.
-       // If body length is <= sniffLen, sendfile/splice path will do
-       // little anyway. This small read also satisfies sniffing the
-       // body in case Content-Type is missing.
-       nr, er := src.Read(buf[:sniffLen])
-       atEOF := errors.Is(er, io.EOF)
-       n += int64(nr)
-
-       if nr > 0 {
-               // Write the small amount read normally.
-               nw, ew := w.Write(buf[:nr])
-               if ew != nil {
-                       err = ew
-               } else if nr != nw {
-                       err = io.ErrShortWrite
+       // Copy the first sniffLen bytes before switching to ReadFrom.
+       // This ensures we don't start writing the response before the
+       // source is available (see golang.org/issue/5660) and provides
+       // enough bytes to perform Content-Type sniffing when required.
+       if !w.cw.wroteHeader {
+               n0, err := io.CopyBuffer(writerOnly{w}, io.LimitReader(src, sniffLen), buf)
+               n += n0
+               if err != nil || n0 < sniffLen {
+                       return n, err
                }
        }
-       if err == nil && er != nil && !atEOF {
-               err = er
-       }
-
-       // Do not send StatusOK in the error case where nothing has been written.
-       if err == nil && !w.wroteHeader {
-               w.WriteHeader(StatusOK) // nr == 0, no error (or EOF)
-       }
-
-       if err != nil || atEOF {
-               return n, err
-       }
 
        w.w.Flush()  // get rid of any previous writes
        w.cw.flush() // make sure Header is written; flush data to rwc
@@ -620,7 +600,7 @@ func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
                return n, err
        }
 
-       n0, err := io.Copy(writerOnly{w}, src)
+       n0, err := io.CopyBuffer(writerOnly{w}, src, buf)
        n += n0
        return n, err
 }
index 8d5350374ddb61040af7c0d6b2a00e88436dd64c..e91335729af16171cfe4df274eb3298bfcc91903 100644 (file)
@@ -157,9 +157,25 @@ func testServerIssue5953(t *testing.T, h2 bool) {
        resp.Body.Close()
 }
 
-func TestContentTypeWithCopy_h1(t *testing.T) { testContentTypeWithCopy(t, h1Mode) }
-func TestContentTypeWithCopy_h2(t *testing.T) { testContentTypeWithCopy(t, h2Mode) }
-func testContentTypeWithCopy(t *testing.T, h2 bool) {
+type byteAtATimeReader struct {
+       buf []byte
+}
+
+func (b *byteAtATimeReader) Read(p []byte) (n int, err error) {
+       if len(p) < 1 {
+               return 0, nil
+       }
+       if len(b.buf) == 0 {
+               return 0, io.EOF
+       }
+       p[0] = b.buf[0]
+       b.buf = b.buf[1:]
+       return 1, nil
+}
+
+func TestContentTypeWithVariousSources_h1(t *testing.T) { testContentTypeWithVariousSources(t, h1Mode) }
+func TestContentTypeWithVariousSources_h2(t *testing.T) { testContentTypeWithVariousSources(t, h2Mode) }
+func testContentTypeWithVariousSources(t *testing.T, h2 bool) {
        defer afterTest(t)
 
        const (
@@ -167,30 +183,86 @@ func testContentTypeWithCopy(t *testing.T, h2 bool) {
                expected = "text/html; charset=utf-8"
        )
 
-       cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
-               // Use io.Copy from a bytes.Buffer to trigger ReadFrom.
-               buf := bytes.NewBuffer([]byte(input))
-               n, err := io.Copy(w, buf)
-               if int(n) != len(input) || err != nil {
-                       t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input))
-               }
-       }))
-       defer cst.close()
+       for _, test := range []struct {
+               name    string
+               handler func(ResponseWriter, *Request)
+       }{{
+               name: "write",
+               handler: func(w ResponseWriter, r *Request) {
+                       // Write the whole input at once.
+                       n, err := w.Write([]byte(input))
+                       if int(n) != len(input) || err != nil {
+                               t.Errorf("w.Write(%q) = %v, %v want %d, nil", input, n, err, len(input))
+                       }
+               },
+       }, {
+               name: "write one byte at a time",
+               handler: func(w ResponseWriter, r *Request) {
+                       // Write the input one byte at a time.
+                       buf := []byte(input)
+                       for i := range buf {
+                               n, err := w.Write(buf[i : i+1])
+                               if n != 1 || err != nil {
+                                       t.Errorf("w.Write(%q) = %v, %v want 1, nil", input, n, err)
+                               }
+                       }
+               },
+       }, {
+               name: "copy from Reader",
+               handler: func(w ResponseWriter, r *Request) {
+                       // Use io.Copy from a plain Reader.
+                       type readerOnly struct{ io.Reader }
+                       buf := bytes.NewBuffer([]byte(input))
+                       n, err := io.Copy(w, readerOnly{buf})
+                       if int(n) != len(input) || err != nil {
+                               t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input))
+                       }
+               },
+       }, {
+               name: "copy from bytes.Buffer",
+               handler: func(w ResponseWriter, r *Request) {
+                       // Use io.Copy from a bytes.Buffer to trigger ReadFrom.
+                       buf := bytes.NewBuffer([]byte(input))
+                       n, err := io.Copy(w, buf)
+                       if int(n) != len(input) || err != nil {
+                               t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input))
+                       }
+               },
+       }, {
+               name: "copy one byte at a time",
+               handler: func(w ResponseWriter, r *Request) {
+                       // Use io.Copy from a Reader that returns one byte at a time.
+                       n, err := io.Copy(w, &byteAtATimeReader{[]byte(input)})
+                       if int(n) != len(input) || err != nil {
+                               t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input))
+                       }
+               },
+       }} {
+               t.Run(test.name, func(t *testing.T) {
+                       cst := newClientServerTest(t, h2, HandlerFunc(test.handler))
+                       defer cst.close()
+
+                       resp, err := cst.c.Get(cst.ts.URL)
+                       if err != nil {
+                               t.Fatalf("Get: %v", err)
+                       }
+                       if ct := resp.Header.Get("Content-Type"); ct != expected {
+                               t.Errorf("Content-Type = %q, want %q", ct, expected)
+                       }
+                       if want, got := resp.Header.Get("Content-Length"), fmt.Sprint(len(input)); want != got {
+                               t.Errorf("Content-Length = %q, want %q", want, got)
+                       }
+                       data, err := io.ReadAll(resp.Body)
+                       if err != nil {
+                               t.Errorf("reading body: %v", err)
+                       } else if !bytes.Equal(data, []byte(input)) {
+                               t.Errorf("data is %q, want %q", data, input)
+                       }
+                       resp.Body.Close()
+
+               })
 
-       resp, err := cst.c.Get(cst.ts.URL)
-       if err != nil {
-               t.Fatalf("Get: %v", err)
-       }
-       if ct := resp.Header.Get("Content-Type"); ct != expected {
-               t.Errorf("Content-Type = %q, want %q", ct, expected)
-       }
-       data, err := io.ReadAll(resp.Body)
-       if err != nil {
-               t.Errorf("reading body: %v", err)
-       } else if !bytes.Equal(data, []byte(input)) {
-               t.Errorf("data is %q, want %q", data, input)
        }
-       resp.Body.Close()
 }
 
 func TestSniffWriteSize_h1(t *testing.T) { testSniffWriteSize(t, h1Mode) }