]> Cypherpunks repositories - gostls13.git/commitdiff
bufio: fix UnreadByte
authorRobert Griesemer <gri@golang.org>
Wed, 9 Apr 2014 21:19:13 +0000 (14:19 -0700)
committerRobert Griesemer <gri@golang.org>
Wed, 9 Apr 2014 21:19:13 +0000 (14:19 -0700)
Also:
- fix error messages in tests
- make tests more symmetric

Fixes #7607.

LGTM=r
R=r
CC=golang-codereviews
https://golang.org/cl/86180043

src/pkg/bufio/bufio.go
src/pkg/bufio/bufio_test.go

index de81b4ddfdbadc863ff600e8657fc91ea16cd619..1e0cdae38e8ebdb862d9a46c6805a9684fc2e356 100644 (file)
@@ -177,7 +177,7 @@ func (b *Reader) Read(p []byte) (n int, err error) {
 // If no byte is available, returns an error.
 func (b *Reader) ReadByte() (c byte, err error) {
        b.lastRuneSize = -1
-       for b.w == b.r {
+       for b.r == b.w {
                if b.err != nil {
                        return 0, b.readErr()
                }
@@ -191,19 +191,19 @@ func (b *Reader) ReadByte() (c byte, err error) {
 
 // UnreadByte unreads the last byte.  Only the most recently read byte can be unread.
 func (b *Reader) UnreadByte() error {
-       b.lastRuneSize = -1
-       if b.r == b.w && b.lastByte >= 0 {
-               b.w = 1
-               b.r = 0
-               b.buf[0] = byte(b.lastByte)
-               b.lastByte = -1
-               return nil
-       }
-       if b.r <= 0 {
+       if b.lastByte < 0 || b.r == 0 && b.w > 0 {
                return ErrInvalidUnreadByte
        }
-       b.r--
+       // b.r > 0 || b.w == 0
+       if b.r > 0 {
+               b.r--
+       } else {
+               // b.r == 0 && b.w == 0
+               b.w = 1
+       }
+       b.buf[b.r] = byte(b.lastByte)
        b.lastByte = -1
+       b.lastRuneSize = -1
        return nil
 }
 
@@ -233,7 +233,7 @@ func (b *Reader) ReadRune() (r rune, size int, err error) {
 // regard it is stricter than UnreadByte, which will unread the last byte
 // from any read operation.)
 func (b *Reader) UnreadRune() error {
-       if b.lastRuneSize < 0 || b.r == 0 {
+       if b.lastRuneSize < 0 || b.r < b.lastRuneSize {
                return ErrInvalidUnreadRune
        }
        b.r -= b.lastRuneSize
index 800c6d2717f2893b88c0f9f90df5df20521b111a..32ca86161fc0f087f16edd9c0a8d47296c5cba1a 100644 (file)
@@ -228,66 +228,94 @@ func TestReadRune(t *testing.T) {
 }
 
 func TestUnreadRune(t *testing.T) {
-       got := ""
        segments := []string{"Hello, world:", "日本語"}
-       data := strings.Join(segments, "")
        r := NewReader(&StringReader{data: segments})
+       got := ""
+       want := strings.Join(segments, "")
        // Normal execution.
        for {
                r1, _, err := r.ReadRune()
                if err != nil {
                        if err != io.EOF {
-                               t.Error("unexpected EOF")
+                               t.Error("unexpected error on ReadRune:", err)
                        }
                        break
                }
                got += string(r1)
-               // Put it back and read it again
+               // Put it back and read it again.
                if err = r.UnreadRune(); err != nil {
-                       t.Error("unexpected error on UnreadRune:", err)
+                       t.Fatal("unexpected error on UnreadRune:", err)
                }
                r2, _, err := r.ReadRune()
                if err != nil {
-                       t.Error("unexpected error reading after unreading:", err)
+                       t.Fatal("unexpected error reading after unreading:", err)
                }
                if r1 != r2 {
-                       t.Errorf("incorrect rune after unread: got %c wanted %c", r1, r2)
+                       t.Fatalf("incorrect rune after unread: got %c, want %c", r1, r2)
                }
        }
-       if got != data {
-               t.Errorf("want=%q got=%q", data, got)
+       if got != want {
+               t.Errorf("got %q, want %q", got, want)
        }
 }
 
 func TestUnreadByte(t *testing.T) {
-       want := "Hello, world"
-       got := ""
        segments := []string{"Hello, ", "world"}
        r := NewReader(&StringReader{data: segments})
+       got := ""
+       want := strings.Join(segments, "")
        // Normal execution.
        for {
                b1, err := r.ReadByte()
                if err != nil {
                        if err != io.EOF {
-                               t.Fatal("unexpected EOF")
+                               t.Error("unexpected error on ReadByte:", err)
                        }
                        break
                }
                got += string(b1)
-               // Put it back and read it again
+               // Put it back and read it again.
                if err = r.UnreadByte(); err != nil {
-                       t.Fatalf("unexpected error on UnreadByte: %v", err)
+                       t.Fatal("unexpected error on UnreadByte:", err)
                }
                b2, err := r.ReadByte()
                if err != nil {
-                       t.Fatalf("unexpected error reading after unreading: %v", err)
+                       t.Fatal("unexpected error reading after unreading:", err)
                }
                if b1 != b2 {
-                       t.Fatalf("incorrect byte after unread: got %c wanted %c", b1, b2)
+                       t.Fatalf("incorrect byte after unread: got %q, want %q", b1, b2)
                }
        }
        if got != want {
-               t.Errorf("got=%q want=%q", got, want)
+               t.Errorf("got %q, want %q", got, want)
+       }
+}
+
+func TestUnreadByteMultiple(t *testing.T) {
+       segments := []string{"Hello, ", "world"}
+       data := strings.Join(segments, "")
+       for n := 0; n <= len(data); n++ {
+               r := NewReader(&StringReader{data: segments})
+               // Read n bytes.
+               for i := 0; i < n; i++ {
+                       b, err := r.ReadByte()
+                       if err != nil {
+                               t.Fatalf("n = %d: unexpected error on ReadByte: %v", n, err)
+                       }
+                       if b != data[i] {
+                               t.Fatalf("n = %d: incorrect byte returned from ReadByte: got %q, want %q", n, b, data[i])
+                       }
+               }
+               // Unread one byte if there is one.
+               if n > 0 {
+                       if err := r.UnreadByte(); err != nil {
+                               t.Errorf("n = %d: unexpected error on UnreadByte: %v", n, err)
+                       }
+               }
+               // Test that we cannot unread any further.
+               if err := r.UnreadByte(); err == nil {
+                       t.Errorf("n = %d: expected error on UnreadByte", n)
+               }
        }
 }