// MultiWriter creates a writer that duplicates its writes to all the
// provided writers, similar to the Unix tee(1) command.
func MultiWriter(writers ...Writer) Writer {
- w := make([]Writer, len(writers))
- copy(w, writers)
- return &multiWriter{w}
+ allWriters := make([]Writer, 0, len(writers))
+ for _, w := range writers {
+ if mw, ok := w.(*multiWriter); ok {
+ allWriters = append(allWriters, mw.writers...)
+ } else {
+ allWriters = append(allWriters, w)
+ }
+ }
+ return &multiWriter{allWriters}
}
}
}
+// writerFunc is an io.Writer implemented by the underlying func.
+type writerFunc func(p []byte) (int, error)
+
+func (f writerFunc) Write(p []byte) (int, error) {
+ return f(p)
+}
+
+// Test that MultiWriter properly flattens chained multiWriters,
+func TestMultiWriterSingleChainFlatten(t *testing.T) {
+ pc := make([]uintptr, 1000) // 1000 should fit the full stack
+ n := runtime.Callers(0, pc)
+ var myDepth = callDepth(pc[:n])
+ var writeDepth int // will contain the depth from which writerFunc.Writer was called
+ var w Writer = MultiWriter(writerFunc(func(p []byte) (int, error) {
+ n := runtime.Callers(1, pc)
+ writeDepth += callDepth(pc[:n])
+ return 0, nil
+ }))
+
+ mw := w
+ // chain a bunch of multiWriters
+ for i := 0; i < 100; i++ {
+ mw = MultiWriter(w)
+ }
+
+ mw = MultiWriter(w, mw, w, mw)
+ mw.Write(nil) // don't care about errors, just want to check the call-depth for Write
+
+ if writeDepth != 4*(myDepth+2) { // 2 should be multiWriter.Write and writerFunc.Write
+ t.Errorf("multiWriter did not flatten chained multiWriters: expected writeDepth %d, got %d",
+ 4*(myDepth+2), writeDepth)
+ }
+}
+
// Test that MultiReader copies the input slice and is insulated from future modification.
func TestMultiReaderCopy(t *testing.T) {
slice := []Reader{strings.NewReader("hello world")}