--- /dev/null
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package iotest_test
+
+import (
+ "errors"
+ "fmt"
+ "testing/iotest"
+)
+
+func ExampleErrReader() {
+ // A reader that always returns a custom error.
+ r := iotest.ErrReader(errors.New("custom error"))
+ n, err := r.Read(nil)
+ fmt.Printf("n: %d\nerr: %q\n", n, err)
+
+ // Output:
+ // n: 0
+ // err: "custom error"
+}
data := []byte("Hello, World!")
p := make([]byte, len(data))
- lr := ErrReader()
+ lr := ErrReader(errors.New("io failure"))
rl := NewReadLogger("read", lr)
n, err := rl.Read(p)
if err == nil {
t.Fatalf("Unexpectedly succeeded to read: %v", err)
}
- wantLogWithHex := fmt.Sprintf("lr: read %x: %v\n", p[:n], "io")
+ wantLogWithHex := fmt.Sprintf("lr: read %x: io failure\n", p[:n])
if g, w := lOut.String(), wantLogWithHex; g != w {
t.Errorf("ReadLogger mismatch\n\tgot: %q\n\twant: %q", g, w)
}
return r.r.Read(p)
}
-// ErrIO is a fake IO error.
-var ErrIO = errors.New("io")
-
-// ErrReader returns a fake error every time it is read from.
-func ErrReader() io.Reader {
- return errReader(0)
+// ErrReader returns an io.Reader that returns 0, err from all Read calls.
+func ErrReader(err error) io.Reader {
+ return &alwaysErrReader{err: err}
}
-type errReader int
+type alwaysErrReader struct {
+ err error
+}
-func (r errReader) Read(p []byte) (int, error) {
- return 0, ErrIO
+func (aer *alwaysErrReader) Read(p []byte) (int, error) {
+ return 0, aer.err
}
import (
"bytes"
+ "errors"
"io"
"testing"
)
}
func TestErrReader(t *testing.T) {
- n, err := ErrReader().Read([]byte{})
- if err != ErrIO {
- t.Errorf("ErrReader.Read(any) should have returned ErrIO, returned %v", err)
- }
- if n != 0 {
- t.Errorf("ErrReader.Read(any) should have read 0 bytes, read %v", n)
+ cases := []struct {
+ name string
+ err error
+ }{
+ {"nil error", nil},
+ {"non-nil error", errors.New("io failure")},
+ {"io.EOF", io.EOF},
+ }
+
+ for _, tt := range cases {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ n, err := ErrReader(tt.err).Read(nil)
+ if err != tt.err {
+ t.Fatalf("Error mismatch\nGot: %v\nWant: %v", err, tt.err)
+ }
+ if n != 0 {
+ t.Fatalf("Byte count mismatch: got %d want 0", n)
+ }
+ })
}
}