// If no byte is available, returns an error.
func (b *Reader) ReadByte() (c byte, err error) {
b.lastRuneSize = -1
- for b.w == b.r {
+ for b.r == b.w {
if b.err != nil {
return 0, b.readErr()
}
// UnreadByte unreads the last byte. Only the most recently read byte can be unread.
func (b *Reader) UnreadByte() error {
- b.lastRuneSize = -1
- if b.r == b.w && b.lastByte >= 0 {
- b.w = 1
- b.r = 0
- b.buf[0] = byte(b.lastByte)
- b.lastByte = -1
- return nil
- }
- if b.r <= 0 {
+ if b.lastByte < 0 || b.r == 0 && b.w > 0 {
return ErrInvalidUnreadByte
}
- b.r--
+ // b.r > 0 || b.w == 0
+ if b.r > 0 {
+ b.r--
+ } else {
+ // b.r == 0 && b.w == 0
+ b.w = 1
+ }
+ b.buf[b.r] = byte(b.lastByte)
b.lastByte = -1
+ b.lastRuneSize = -1
return nil
}
// regard it is stricter than UnreadByte, which will unread the last byte
// from any read operation.)
func (b *Reader) UnreadRune() error {
- if b.lastRuneSize < 0 || b.r == 0 {
+ if b.lastRuneSize < 0 || b.r < b.lastRuneSize {
return ErrInvalidUnreadRune
}
b.r -= b.lastRuneSize
}
func TestUnreadRune(t *testing.T) {
- got := ""
segments := []string{"Hello, world:", "日本語"}
- data := strings.Join(segments, "")
r := NewReader(&StringReader{data: segments})
+ got := ""
+ want := strings.Join(segments, "")
// Normal execution.
for {
r1, _, err := r.ReadRune()
if err != nil {
if err != io.EOF {
- t.Error("unexpected EOF")
+ t.Error("unexpected error on ReadRune:", err)
}
break
}
got += string(r1)
- // Put it back and read it again
+ // Put it back and read it again.
if err = r.UnreadRune(); err != nil {
- t.Error("unexpected error on UnreadRune:", err)
+ t.Fatal("unexpected error on UnreadRune:", err)
}
r2, _, err := r.ReadRune()
if err != nil {
- t.Error("unexpected error reading after unreading:", err)
+ t.Fatal("unexpected error reading after unreading:", err)
}
if r1 != r2 {
- t.Errorf("incorrect rune after unread: got %c wanted %c", r1, r2)
+ t.Fatalf("incorrect rune after unread: got %c, want %c", r1, r2)
}
}
- if got != data {
- t.Errorf("want=%q got=%q", data, got)
+ if got != want {
+ t.Errorf("got %q, want %q", got, want)
}
}
func TestUnreadByte(t *testing.T) {
- want := "Hello, world"
- got := ""
segments := []string{"Hello, ", "world"}
r := NewReader(&StringReader{data: segments})
+ got := ""
+ want := strings.Join(segments, "")
// Normal execution.
for {
b1, err := r.ReadByte()
if err != nil {
if err != io.EOF {
- t.Fatal("unexpected EOF")
+ t.Error("unexpected error on ReadByte:", err)
}
break
}
got += string(b1)
- // Put it back and read it again
+ // Put it back and read it again.
if err = r.UnreadByte(); err != nil {
- t.Fatalf("unexpected error on UnreadByte: %v", err)
+ t.Fatal("unexpected error on UnreadByte:", err)
}
b2, err := r.ReadByte()
if err != nil {
- t.Fatalf("unexpected error reading after unreading: %v", err)
+ t.Fatal("unexpected error reading after unreading:", err)
}
if b1 != b2 {
- t.Fatalf("incorrect byte after unread: got %c wanted %c", b1, b2)
+ t.Fatalf("incorrect byte after unread: got %q, want %q", b1, b2)
}
}
if got != want {
- t.Errorf("got=%q want=%q", got, want)
+ t.Errorf("got %q, want %q", got, want)
+ }
+}
+
+func TestUnreadByteMultiple(t *testing.T) {
+ segments := []string{"Hello, ", "world"}
+ data := strings.Join(segments, "")
+ for n := 0; n <= len(data); n++ {
+ r := NewReader(&StringReader{data: segments})
+ // Read n bytes.
+ for i := 0; i < n; i++ {
+ b, err := r.ReadByte()
+ if err != nil {
+ t.Fatalf("n = %d: unexpected error on ReadByte: %v", n, err)
+ }
+ if b != data[i] {
+ t.Fatalf("n = %d: incorrect byte returned from ReadByte: got %q, want %q", n, b, data[i])
+ }
+ }
+ // Unread one byte if there is one.
+ if n > 0 {
+ if err := r.UnreadByte(); err != nil {
+ t.Errorf("n = %d: unexpected error on UnreadByte: %v", n, err)
+ }
+ }
+ // Test that we cannot unread any further.
+ if err := r.UnreadByte(); err == nil {
+ t.Errorf("n = %d: expected error on UnreadByte", n)
+ }
}
}