]> Cypherpunks repositories - gostls13.git/commitdiff
bufio: use underlying ReadFrom even when data is buffered
authorDamien Neil <dneil@google.com>
Fri, 6 Aug 2021 20:23:13 +0000 (13:23 -0700)
committerDamien Neil <dneil@google.com>
Mon, 18 Oct 2021 21:52:05 +0000 (21:52 +0000)
When (*bufio.Writer).ReadFrom is called with a partially filled buffer,
fill out and flush the buffer and then call the underlying writer's
ReadFrom method if present.

Fixes #44815.

Change-Id: I15b3ef0746d0d60fd62041189a9b9df11254dd29
Reviewed-on: https://go-review.googlesource.com/c/go/+/340530
Trust: Damien Neil <dneil@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/bufio/bufio.go
src/bufio/bufio_test.go

index a58df254941166ae72c71888c15d3d14eebf84d9..063a7785f3637aa59c6ef8b8f62ed8a49e3e41a1 100644 (file)
@@ -745,19 +745,14 @@ func (b *Writer) WriteString(s string) (int, error) {
 }
 
 // ReadFrom implements io.ReaderFrom. If the underlying writer
-// supports the ReadFrom method, and b has no buffered data yet,
-// this calls the underlying ReadFrom without buffering.
+// supports the ReadFrom method, this calls the underlying ReadFrom.
+// If there is buffered data and an underlying ReadFrom, this fills
+// the buffer and writes it before calling ReadFrom.
 func (b *Writer) ReadFrom(r io.Reader) (n int64, err error) {
        if b.err != nil {
                return 0, b.err
        }
-       if b.Buffered() == 0 {
-               if w, ok := b.wr.(io.ReaderFrom); ok {
-                       n, err = w.ReadFrom(r)
-                       b.err = err
-                       return n, err
-               }
-       }
+       readerFrom, readerFromOK := b.wr.(io.ReaderFrom)
        var m int
        for {
                if b.Available() == 0 {
@@ -765,6 +760,12 @@ func (b *Writer) ReadFrom(r io.Reader) (n int64, err error) {
                                return n, err1
                        }
                }
+               if readerFromOK && b.Buffered() == 0 {
+                       nn, err := readerFrom.ReadFrom(r)
+                       b.err = err
+                       n += nn
+                       return n, err
+               }
                nr := 0
                for nr < maxConsecutiveEmptyReads {
                        m, err = r.Read(b.buf[b.n:])
index 8e8a8a1778a690ade5bed0a4d8bef18ead5324e8..66b3e700531a98a8c876013ba951039ea434c721 100644 (file)
@@ -1351,6 +1351,54 @@ func TestWriterReadFromErrNoProgress(t *testing.T) {
        }
 }
 
+type readFromWriter struct {
+       buf           []byte
+       writeBytes    int
+       readFromBytes int
+}
+
+func (w *readFromWriter) Write(p []byte) (int, error) {
+       w.buf = append(w.buf, p...)
+       w.writeBytes += len(p)
+       return len(p), nil
+}
+
+func (w *readFromWriter) ReadFrom(r io.Reader) (int64, error) {
+       b, err := io.ReadAll(r)
+       w.buf = append(w.buf, b...)
+       w.readFromBytes += len(b)
+       return int64(len(b)), err
+}
+
+// Test that calling (*Writer).ReadFrom with a partially-filled buffer
+// fills the buffer before switching over to ReadFrom.
+func TestWriterReadFromWithBufferedData(t *testing.T) {
+       const bufsize = 16
+
+       input := createTestInput(64)
+       rfw := &readFromWriter{}
+       w := NewWriterSize(rfw, bufsize)
+
+       const writeSize = 8
+       if n, err := w.Write(input[:writeSize]); n != writeSize || err != nil {
+               t.Errorf("w.Write(%v bytes) = %v, %v; want %v, nil", writeSize, n, err, writeSize)
+       }
+       n, err := w.ReadFrom(bytes.NewReader(input[writeSize:]))
+       if wantn := len(input[writeSize:]); int(n) != wantn || err != nil {
+               t.Errorf("io.Copy(w, %v bytes) = %v, %v; want %v, nil", wantn, n, err, wantn)
+       }
+       if err := w.Flush(); err != nil {
+               t.Errorf("w.Flush() = %v, want nil", err)
+       }
+
+       if got, want := rfw.writeBytes, bufsize; got != want {
+               t.Errorf("wrote %v bytes with Write, want %v", got, want)
+       }
+       if got, want := rfw.readFromBytes, len(input)-bufsize; got != want {
+               t.Errorf("wrote %v bytes with ReadFrom, want %v", got, want)
+       }
+}
+
 func TestReadZero(t *testing.T) {
        for _, size := range []int{100, 2} {
                t.Run(fmt.Sprintf("bufsize=%d", size), func(t *testing.T) {