func (mr *multiReader) Read(p []byte) (n int, err error) {
for len(mr.readers) > 0 {
+ // Optimization to flatten nested multiReaders (Issue 13558)
+ if len(mr.readers) == 1 {
+ if r, ok := mr.readers[0].(*multiReader); ok {
+ mr.readers = r.readers
+ continue
+ }
+ }
n, err = mr.readers[0].Read(p)
if n > 0 || err != EOF {
if err == EOF {
import (
"bytes"
"crypto/sha1"
+ "errors"
"fmt"
. "io"
"io/ioutil"
+ "runtime"
"strings"
"testing"
)
t.Errorf("buf.String() = %q, want %q", buf.String(), "hello world")
}
}
+
+// readerFunc is an io.Reader implemented by the underlying func.
+type readerFunc func(p []byte) (int, error)
+
+func (f readerFunc) Read(p []byte) (int, error) {
+ return f(p)
+}
+
+// Test that MultiReader properly flattens chained multiReaders when Read is called
+func TestMultiReaderFlatten(t *testing.T) {
+ pc := make([]uintptr, 1000) // 1000 should fit the full stack
+ var myDepth = runtime.Callers(0, pc)
+ var readDepth int // will contain the depth from which fakeReader.Read was called
+ var r Reader = MultiReader(readerFunc(func(p []byte) (int, error) {
+ readDepth = runtime.Callers(1, pc)
+ return 0, errors.New("irrelevant")
+ }))
+
+ // chain a bunch of multiReaders
+ for i := 0; i < 100; i++ {
+ r = MultiReader(r)
+ }
+
+ r.Read(nil) // don't care about errors, just want to check the call-depth for Read
+
+ if readDepth != myDepth+2 { // 2 should be multiReader.Read and fakeReader.Read
+ t.Errorf("multiReader did not flatten chained multiReaders: expected readDepth %d, got %d",
+ myDepth+2, readDepth)
+ }
+}