--- /dev/null
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package iotest
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "log"
+ "testing"
+)
+
+type errWriter struct {
+ err error
+}
+
+func (w errWriter) Write([]byte) (int, error) {
+ return 0, w.err
+}
+
+func TestWriteLogger(t *testing.T) {
+ olw := log.Writer()
+ olf := log.Flags()
+ olp := log.Prefix()
+
+ // Revert the original log settings before we exit.
+ defer func() {
+ log.SetFlags(olf)
+ log.SetPrefix(olp)
+ log.SetOutput(olw)
+ }()
+
+ lOut := new(bytes.Buffer)
+ log.SetPrefix("lw: ")
+ log.SetOutput(lOut)
+ log.SetFlags(0)
+
+ lw := new(bytes.Buffer)
+ wl := NewWriteLogger("write:", lw)
+ if _, err := wl.Write([]byte("Hello, World!")); err != nil {
+ t.Fatalf("Unexpectedly failed to write: %v", err)
+ }
+
+ if g, w := lw.String(), "Hello, World!"; g != w {
+ t.Errorf("WriteLogger mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+ wantLogWithHex := fmt.Sprintf("lw: write: %x\n", "Hello, World!")
+ if g, w := lOut.String(), wantLogWithHex; g != w {
+ t.Errorf("WriteLogger mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+}
+
+func TestWriteLogger_errorOnWrite(t *testing.T) {
+ olw := log.Writer()
+ olf := log.Flags()
+ olp := log.Prefix()
+
+ // Revert the original log settings before we exit.
+ defer func() {
+ log.SetFlags(olf)
+ log.SetPrefix(olp)
+ log.SetOutput(olw)
+ }()
+
+ lOut := new(bytes.Buffer)
+ log.SetPrefix("lw: ")
+ log.SetOutput(lOut)
+ log.SetFlags(0)
+
+ lw := errWriter{err: errors.New("Write Error!")}
+ wl := NewWriteLogger("write:", lw)
+ if _, err := wl.Write([]byte("Hello, World!")); err == nil {
+ t.Fatalf("Unexpectedly succeeded to write: %v", err)
+ }
+
+ wantLogWithHex := fmt.Sprintf("lw: write: %x: %v\n", "", "Write Error!")
+ if g, w := lOut.String(), wantLogWithHex; g != w {
+ t.Errorf("WriteLogger mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+}
+
+type errReader struct {
+ err error
+}
+
+func (r errReader) Read([]byte) (int, error) {
+ return 0, r.err
+}
+
+func TestReadLogger(t *testing.T) {
+ olw := log.Writer()
+ olf := log.Flags()
+ olp := log.Prefix()
+
+ // Revert the original log settings before we exit.
+ defer func() {
+ log.SetFlags(olf)
+ log.SetPrefix(olp)
+ log.SetOutput(olw)
+ }()
+
+ lOut := new(bytes.Buffer)
+ log.SetPrefix("lr: ")
+ log.SetOutput(lOut)
+ log.SetFlags(0)
+
+ data := []byte("Hello, World!")
+ p := make([]byte, len(data))
+ lr := bytes.NewReader(data)
+ rl := NewReadLogger("read:", lr)
+
+ n, err := rl.Read(p)
+ if err != nil {
+ t.Fatalf("Unexpectedly failed to read: %v", err)
+ }
+
+ if g, w := p[:n], data; !bytes.Equal(g, w) {
+ t.Errorf("ReadLogger mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+
+ wantLogWithHex := fmt.Sprintf("lr: read: %x\n", "Hello, World!")
+ if g, w := lOut.String(), wantLogWithHex; g != w {
+ t.Errorf("ReadLogger mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+}
+
+func TestReadLogger_errorOnRead(t *testing.T) {
+ olw := log.Writer()
+ olf := log.Flags()
+ olp := log.Prefix()
+
+ // Revert the original log settings before we exit.
+ defer func() {
+ log.SetFlags(olf)
+ log.SetPrefix(olp)
+ log.SetOutput(olw)
+ }()
+
+ lOut := new(bytes.Buffer)
+ log.SetPrefix("lr: ")
+ log.SetOutput(lOut)
+ log.SetFlags(0)
+
+ data := []byte("Hello, World!")
+ p := make([]byte, len(data))
+
+ lr := errReader{err: errors.New("Read Error!")}
+ rl := NewReadLogger("read", lr)
+ n, err := rl.Read(p)
+ if err == nil {
+ t.Fatalf("Unexpectedly succeeded to read: %v", err)
+ }
+
+ wantLogWithHex := fmt.Sprintf("lr: read %x: %v\n", p[:n], "Read Error!")
+ if g, w := lOut.String(), wantLogWithHex; g != w {
+ t.Errorf("ReadLogger mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+}
--- /dev/null
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package iotest
+
+import (
+ "bytes"
+ "io"
+ "testing"
+)
+
+func TestOneByteReader_nonEmptyReader(t *testing.T) {
+ msg := "Hello, World!"
+ buf := new(bytes.Buffer)
+ buf.WriteString(msg)
+
+ obr := OneByteReader(buf)
+ var b []byte
+ n, err := obr.Read(b)
+ if err != nil || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+
+ b = make([]byte, 3)
+ // Read from obr until EOF.
+ got := new(bytes.Buffer)
+ for i := 0; ; i++ {
+ n, err = obr.Read(b)
+ if err != nil {
+ break
+ }
+ if g, w := n, 1; g != w {
+ t.Errorf("Iteration #%d read %d bytes, want %d", i, g, w)
+ }
+ got.Write(b[:n])
+ }
+ if g, w := err, io.EOF; g != w {
+ t.Errorf("Unexpected error after reading all bytes\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := got.String(), "Hello, World!"; g != w {
+ t.Errorf("Read mismatch\n\tGot: %q\n\tWant: %q", g, w)
+ }
+}
+
+func TestOneByteReader_emptyReader(t *testing.T) {
+ r := new(bytes.Buffer)
+
+ obr := OneByteReader(r)
+ var b []byte
+ if n, err := obr.Read(b); err != nil || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+
+ b = make([]byte, 5)
+ n, err := obr.Read(b)
+ if g, w := err, io.EOF; g != w {
+ t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := n, 0; g != w {
+ t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w)
+ }
+}
+
+func TestHalfReader_nonEmptyReader(t *testing.T) {
+ msg := "Hello, World!"
+ buf := new(bytes.Buffer)
+ buf.WriteString(msg)
+ // empty read buffer
+ hr := HalfReader(buf)
+ var b []byte
+ n, err := hr.Read(b)
+ if err != nil || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+ // non empty read buffer
+ b = make([]byte, 2)
+ got := new(bytes.Buffer)
+ for i := 0; ; i++ {
+ n, err = hr.Read(b)
+ if err != nil {
+ break
+ }
+ if g, w := n, 1; g != w {
+ t.Errorf("Iteration #%d read %d bytes, want %d", i, g, w)
+ }
+ got.Write(b[:n])
+ }
+ if g, w := err, io.EOF; g != w {
+ t.Errorf("Unexpected error after reading all bytes\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := got.String(), "Hello, World!"; g != w {
+ t.Errorf("Read mismatch\n\tGot: %q\n\tWant: %q", g, w)
+ }
+}
+
+func TestHalfReader_emptyReader(t *testing.T) {
+ r := new(bytes.Buffer)
+
+ hr := HalfReader(r)
+ var b []byte
+ if n, err := hr.Read(b); err != nil || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+
+ b = make([]byte, 5)
+ n, err := hr.Read(b)
+ if g, w := err, io.EOF; g != w {
+ t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := n, 0; g != w {
+ t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w)
+ }
+}
+
+func TestTimeOutReader_nonEmptyReader(t *testing.T) {
+ msg := "Hello, World!"
+ buf := new(bytes.Buffer)
+ buf.WriteString(msg)
+ // empty read buffer
+ tor := TimeoutReader(buf)
+ var b []byte
+ n, err := tor.Read(b)
+ if err != nil || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+ // Second call should timeout
+ n, err = tor.Read(b)
+ if g, w := err, ErrTimeout; g != w {
+ t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := n, 0; g != w {
+ t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w)
+ }
+ // non empty read buffer
+ tor2 := TimeoutReader(buf)
+ b = make([]byte, 3)
+ if n, err := tor2.Read(b); err != nil || n == 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+ // Second call should timeout
+ n, err = tor2.Read(b)
+ if g, w := err, ErrTimeout; g != w {
+ t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := n, 0; g != w {
+ t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w)
+ }
+}
+
+func TestTimeOutReader_emptyReader(t *testing.T) {
+ r := new(bytes.Buffer)
+ // empty read buffer
+ tor := TimeoutReader(r)
+ var b []byte
+ if n, err := tor.Read(b); err != nil || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+ // Second call should timeout
+ n, err := tor.Read(b)
+ if g, w := err, ErrTimeout; g != w {
+ t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := n, 0; g != w {
+ t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w)
+ }
+ // non empty read buffer
+ tor2 := TimeoutReader(r)
+ b = make([]byte, 5)
+ if n, err := tor2.Read(b); err != io.EOF || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+ // Second call should timeout
+ n, err = tor2.Read(b)
+ if g, w := err, ErrTimeout; g != w {
+ t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := n, 0; g != w {
+ t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w)
+ }
+}
+
+func TestDataErrReader_nonEmptyReader(t *testing.T) {
+ msg := "Hello, World!"
+ buf := new(bytes.Buffer)
+ buf.WriteString(msg)
+
+ der := DataErrReader(buf)
+
+ b := make([]byte, 3)
+ got := new(bytes.Buffer)
+ var n int
+ var err error
+ for {
+ n, err = der.Read(b)
+ got.Write(b[:n])
+ if err != nil {
+ break
+ }
+ }
+ if err != io.EOF || n == 0 {
+ t.Errorf("Last Read returned n=%d err=%v", n, err)
+ }
+ if g, w := got.String(), "Hello, World!"; g != w {
+ t.Errorf("Read mismatch\n\tGot: %q\n\tWant: %q", g, w)
+ }
+}
+
+func TestDataErrReader_emptyReader(t *testing.T) {
+ r := new(bytes.Buffer)
+
+ der := DataErrReader(r)
+ var b []byte
+ if n, err := der.Read(b); err != io.EOF || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+
+ b = make([]byte, 5)
+ n, err := der.Read(b)
+ if g, w := err, io.EOF; g != w {
+ t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := n, 0; g != w {
+ t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w)
+ }
+}