]> Cypherpunks repositories - gostls13.git/commitdiff
io: flatten MultiWriter writers
authorMichael Fraenkel <michael.fraenkel@gmail.com>
Wed, 25 Oct 2017 16:42:14 +0000 (12:42 -0400)
committerIan Lance Taylor <iant@golang.org>
Wed, 25 Oct 2017 21:48:50 +0000 (21:48 +0000)
Replace any nested Writer that is a MultiWriter with its associated
Writers.

Fixes #22431

Change-Id: Ida7c4c83926363c1780689e216cf0c5241a5b8eb
Reviewed-on: https://go-review.googlesource.com/73470
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/io/multi.go
src/io/multi_test.go

index d784846862c455cd8ce9deef25ddb62db6425a1e..c662765a3b50a60d32e4bb549b73f74a4605e5a1 100644 (file)
@@ -96,7 +96,13 @@ func (t *multiWriter) WriteString(s string) (n int, err error) {
 // MultiWriter creates a writer that duplicates its writes to all the
 // provided writers, similar to the Unix tee(1) command.
 func MultiWriter(writers ...Writer) Writer {
-       w := make([]Writer, len(writers))
-       copy(w, writers)
-       return &multiWriter{w}
+       allWriters := make([]Writer, 0, len(writers))
+       for _, w := range writers {
+               if mw, ok := w.(*multiWriter); ok {
+                       allWriters = append(allWriters, mw.writers...)
+               } else {
+                       allWriters = append(allWriters, w)
+               }
+       }
+       return &multiWriter{allWriters}
 }
index 0a7eb43032b99f823e1a931dbdae9c860dae22b2..83eef756fd7d7126464576cd373cf742a4f13d95 100644 (file)
@@ -142,6 +142,40 @@ func testMultiWriter(t *testing.T, sink interface {
        }
 }
 
+// writerFunc is an io.Writer implemented by the underlying func.
+type writerFunc func(p []byte) (int, error)
+
+func (f writerFunc) Write(p []byte) (int, error) {
+       return f(p)
+}
+
+// Test that MultiWriter properly flattens chained multiWriters,
+func TestMultiWriterSingleChainFlatten(t *testing.T) {
+       pc := make([]uintptr, 1000) // 1000 should fit the full stack
+       n := runtime.Callers(0, pc)
+       var myDepth = callDepth(pc[:n])
+       var writeDepth int // will contain the depth from which writerFunc.Writer was called
+       var w Writer = MultiWriter(writerFunc(func(p []byte) (int, error) {
+               n := runtime.Callers(1, pc)
+               writeDepth += callDepth(pc[:n])
+               return 0, nil
+       }))
+
+       mw := w
+       // chain a bunch of multiWriters
+       for i := 0; i < 100; i++ {
+               mw = MultiWriter(w)
+       }
+
+       mw = MultiWriter(w, mw, w, mw)
+       mw.Write(nil) // don't care about errors, just want to check the call-depth for Write
+
+       if writeDepth != 4*(myDepth+2) { // 2 should be multiWriter.Write and writerFunc.Write
+               t.Errorf("multiWriter did not flatten chained multiWriters: expected writeDepth %d, got %d",
+                       4*(myDepth+2), writeDepth)
+       }
+}
+
 // Test that MultiReader copies the input slice and is insulated from future modification.
 func TestMultiReaderCopy(t *testing.T) {
        slice := []Reader{strings.NewReader("hello world")}