]> Cypherpunks repositories - gostls13.git/commitdiff
io: MultiReader and MultiWriter
authorBrad Fitzpatrick <brad@danga.com>
Wed, 28 Jul 2010 18:30:00 +0000 (11:30 -0700)
committerRuss Cox <rsc@golang.org>
Wed, 28 Jul 2010 18:30:00 +0000 (11:30 -0700)
Little helpers I've found useful.

R=adg, rsc, r, gri
CC=golang-dev
https://golang.org/cl/1764043

src/pkg/io/Makefile
src/pkg/io/multi_reader.go [new file with mode: 0644]
src/pkg/io/multi_reader_test.go [new file with mode: 0644]
src/pkg/io/multi_writer.go [new file with mode: 0644]
src/pkg/io/multi_writer_test.go [new file with mode: 0644]

index 8c27ce551db32eb9d35af660fa301fd145d72845..ad2e4cec44849425e6e3fedb2c770eae7a2ab72d 100644 (file)
@@ -7,6 +7,8 @@ include ../../Make.$(GOARCH)
 TARG=io
 GOFILES=\
        io.go\
+       multi_reader.go\
+       multi_writer.go\
        pipe.go\
 
 include ../../Make.pkg
diff --git a/src/pkg/io/multi_reader.go b/src/pkg/io/multi_reader.go
new file mode 100644 (file)
index 0000000..b6fa5dd
--- /dev/null
@@ -0,0 +1,36 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package io
+
+import "os"
+
+type multiReader struct {
+       readers []Reader
+}
+
+func (mr *multiReader) Read(p []byte) (n int, err os.Error) {
+       for len(mr.readers) > 0 {
+               n, err = mr.readers[0].Read(p)
+               if n > 0 || err != os.EOF {
+                       if err == os.EOF {
+                               // This shouldn't happen.
+                               // Well-behaved Readers should never
+                               // return non-zero bytes read with an
+                               // EOF.  But if so, we clean it.
+                               err = nil
+                       }
+                       return
+               }
+               mr.readers = mr.readers[1:]
+       }
+       return 0, os.EOF
+}
+
+// MultiReader returns a Reader that's the logical concatenation of
+// the provided input readers.  They're read sequentially.  Once all
+// inputs are drained, Read will return os.EOF.
+func MultiReader(readers ...Reader) Reader {
+       return &multiReader{readers}
+}
diff --git a/src/pkg/io/multi_reader_test.go b/src/pkg/io/multi_reader_test.go
new file mode 100644 (file)
index 0000000..7d46396
--- /dev/null
@@ -0,0 +1,58 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package io_test
+
+import (
+       . "io"
+       "os"
+       "strings"
+       "testing"
+)
+
+func TestMultiReader(t *testing.T) {
+       var mr Reader
+       var buf []byte
+       nread := 0
+       withFooBar := func(tests func()) {
+               r1 := strings.NewReader("foo ")
+               r2 := strings.NewReader("bar")
+               mr = MultiReader(r1, r2)
+               buf = make([]byte, 20)
+               tests()
+       }
+       expectRead := func(size int, expected string, eerr os.Error) {
+               nread++
+               n, gerr := mr.Read(buf[0:size])
+               if n != len(expected) {
+                       t.Errorf("#%d, expected %d bytes; got %d",
+                               nread, len(expected), n)
+               }
+               got := string(buf[0:n])
+               if got != expected {
+                       t.Errorf("#%d, expected %q; got %q",
+                               nread, expected, got)
+               }
+               if gerr != eerr {
+                       t.Errorf("#%d, expected error %v; got %v",
+                               nread, eerr, gerr)
+               }
+               buf = buf[n:]
+       }
+       withFooBar(func() {
+               expectRead(2, "fo", nil)
+               expectRead(5, "o ", nil)
+               expectRead(5, "bar", nil)
+               expectRead(5, "", os.EOF)
+       })
+       withFooBar(func() {
+               expectRead(4, "foo ", nil)
+               expectRead(1, "b", nil)
+               expectRead(3, "ar", nil)
+               expectRead(1, "", os.EOF)
+       })
+       withFooBar(func() {
+               expectRead(5, "foo ", nil)
+       })
+}
diff --git a/src/pkg/io/multi_writer.go b/src/pkg/io/multi_writer.go
new file mode 100644 (file)
index 0000000..5825288
--- /dev/null
@@ -0,0 +1,31 @@
+// Copyright 2010 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package io
+
+import "os"
+
+type multiWriter struct {
+       writers []Writer
+}
+
+func (t *multiWriter) Write(p []byte) (n int, err os.Error) {
+       for _, w := range t.writers {
+               n, err = w.Write(p)
+               if err != nil {
+                       return
+               }
+               if n != len(p) {
+                       err = ErrShortWrite
+                       return
+               }
+       }
+       return len(p), nil
+}
+
+// MultiWriter creates a writer that duplicates its writes to all the
+// provided writers, similar to the Unix tee(1) command.
+func MultiWriter(writers ...Writer) Writer {
+       return &multiWriter{writers}
+}
diff --git a/src/pkg/io/multi_writer_test.go b/src/pkg/io/multi_writer_test.go
new file mode 100644 (file)
index 0000000..251a477
--- /dev/null
@@ -0,0 +1,41 @@
+// Copyright 2010 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package io_test
+
+import (
+       . "io"
+       "bytes"
+       "crypto/sha1"
+       "fmt"
+       "strings"
+       "testing"
+)
+
+func TestMultiWriter(t *testing.T) {
+       sha1 := sha1.New()
+       sink := new(bytes.Buffer)
+       mw := MultiWriter(sha1, sink)
+
+       sourceString := "My input text."
+       source := strings.NewReader(sourceString)
+       written, err := Copy(mw, source)
+
+       if written != int64(len(sourceString)) {
+               t.Errorf("short write of %d, not %d", written, len(sourceString))
+       }
+
+       if err != nil {
+               t.Errorf("unexpected error: %v", err)
+       }
+
+       sha1hex := fmt.Sprintf("%x", sha1.Sum())
+       if sha1hex != "01cb303fa8c30a64123067c5aa6284ba7ec2d31b" {
+               t.Error("incorrect sha1 value")
+       }
+
+       if sink.String() != sourceString {
+               t.Error("expected %q; got %q", sourceString, sink.String())
+       }
+}