]> Cypherpunks repositories - gostls13.git/commitdiff
bufio: add UnreadRune.
authorRob Pike <r@golang.org>
Sun, 12 Sep 2010 07:40:27 +0000 (17:40 +1000)
committerRob Pike <r@golang.org>
Sun, 12 Sep 2010 07:40:27 +0000 (17:40 +1000)
R=rsc
CC=golang-dev
https://golang.org/cl/2103046

src/pkg/bufio/bufio.go
src/pkg/bufio/bufio_test.go

index 37bdea274a6395e274f3f0e804a871f626121d8a..b85a0793ccf8bb4189c372d3945a992f09344c8c 100644 (file)
@@ -27,6 +27,7 @@ type Error struct {
 
 var (
        ErrInvalidUnreadByte os.Error = &Error{"bufio: invalid use of UnreadByte"}
+       ErrInvalidUnreadRune os.Error = &Error{"bufio: invalid use of UnreadRune"}
        ErrBufferFull        os.Error = &Error{"bufio: buffer full"}
        ErrNegativeCount     os.Error = &Error{"bufio: negative count"}
        errInternal          os.Error = &Error{"bufio: internal error"}
@@ -44,11 +45,12 @@ func (b BufSizeError) String() string {
 
 // Reader implements buffering for an io.Reader object.
 type Reader struct {
-       buf      []byte
-       rd       io.Reader
-       r, w     int
-       err      os.Error
-       lastbyte int
+       buf          []byte
+       rd           io.Reader
+       r, w         int
+       err          os.Error
+       lastByte     int
+       lastRuneSize int
 }
 
 // NewReaderSize creates a new Reader whose buffer has the specified size,
@@ -67,7 +69,8 @@ func NewReaderSize(rd io.Reader, size int) (*Reader, os.Error) {
        b = new(Reader)
        b.buf = make([]byte, size)
        b.rd = rd
-       b.lastbyte = -1
+       b.lastByte = -1
+       b.lastRuneSize = -1
        return b, nil
 }
 
@@ -141,7 +144,8 @@ func (b *Reader) Read(p []byte) (nn int, err os.Error) {
                                // Read directly into p to avoid copy.
                                n, b.err = b.rd.Read(p)
                                if n > 0 {
-                                       b.lastbyte = int(p[n-1])
+                                       b.lastByte = int(p[n-1])
+                                       b.lastRuneSize = -1
                                }
                                p = p[n:]
                                nn += n
@@ -156,7 +160,8 @@ func (b *Reader) Read(p []byte) (nn int, err os.Error) {
                copy(p[0:n], b.buf[b.r:])
                p = p[n:]
                b.r += n
-               b.lastbyte = int(b.buf[b.r-1])
+               b.lastByte = int(b.buf[b.r-1])
+               b.lastRuneSize = -1
                nn += n
        }
        return nn, nil
@@ -165,6 +170,7 @@ func (b *Reader) Read(p []byte) (nn int, err os.Error) {
 // ReadByte reads and returns a single byte.
 // If no byte is available, returns an error.
 func (b *Reader) ReadByte() (c byte, err os.Error) {
+       b.lastRuneSize = -1
        for b.w == b.r {
                if b.err != nil {
                        return 0, b.err
@@ -173,24 +179,25 @@ func (b *Reader) ReadByte() (c byte, err os.Error) {
        }
        c = b.buf[b.r]
        b.r++
-       b.lastbyte = int(c)
+       b.lastByte = int(c)
        return c, nil
 }
 
 // UnreadByte unreads the last byte.  Only the most recently read byte can be unread.
 func (b *Reader) UnreadByte() os.Error {
-       if b.r == b.w && b.lastbyte >= 0 {
+       b.lastRuneSize = -1
+       if b.r == b.w && b.lastByte >= 0 {
                b.w = 1
                b.r = 0
-               b.buf[0] = byte(b.lastbyte)
-               b.lastbyte = -1
+               b.buf[0] = byte(b.lastByte)
+               b.lastByte = -1
                return nil
        }
        if b.r <= 0 {
                return ErrInvalidUnreadByte
        }
        b.r--
-       b.lastbyte = -1
+       b.lastByte = -1
        return nil
 }
 
@@ -208,10 +215,25 @@ func (b *Reader) ReadRune() (rune int, size int, err os.Error) {
                rune, size = utf8.DecodeRune(b.buf[b.r:b.w])
        }
        b.r += size
-       b.lastbyte = int(b.buf[b.r-1])
+       b.lastByte = int(b.buf[b.r-1])
+       b.lastRuneSize = size
        return rune, size, nil
 }
 
+// UnreadRune unreads the last rune.  If the most recent read operation on
+// the buffer was not a ReadRune, UnreadRune returns an error.  (In this
+// regard it is stricter than UnreadByte, which will unread the last byte
+// from any read operation.)
+func (b *Reader) UnreadRune() os.Error {
+       if b.lastRuneSize < 0 {
+               return ErrInvalidUnreadRune
+       }
+       b.r -= b.lastRuneSize
+       b.lastByte = -1
+       b.lastRuneSize = -1
+       return nil
+}
+
 // Buffered returns the number of bytes that can be read from the current buffer.
 func (b *Reader) Buffered() int { return b.w - b.r }
 
index 876270fcaa32780acd030859af1879003105af77..10c14ecd0fd8225b88fb250c71bbeb71815230a1 100644 (file)
@@ -226,6 +226,99 @@ func TestReadRune(t *testing.T) {
        }
 }
 
+func TestUnreadRune(t *testing.T) {
+       got := ""
+       segments := []string{"Hello, world:", "日本語"}
+       data := strings.Join(segments, "")
+       r := NewReader(&StringReader{data: segments})
+       // Normal execution.
+       for {
+               rune, _, err := r.ReadRune()
+               if err != nil {
+                       if err != os.EOF {
+                               t.Error("unexpected EOF")
+                       }
+                       break
+               }
+               got += string(rune)
+               // Put it back and read it again
+               if err = r.UnreadRune(); err != nil {
+                       t.Error("unexpected error on UnreadRune:", err)
+               }
+               rune1, _, err := r.ReadRune()
+               if err != nil {
+                       t.Error("unexpected error reading after unreading:", err)
+               }
+               if rune != rune1 {
+                       t.Error("incorrect rune after unread: got %c wanted %c", rune1, rune)
+               }
+       }
+       if got != data {
+               t.Errorf("want=%q got=%q", data, got)
+       }
+}
+
+// Test that UnreadRune fails if the preceding operation was not a ReadRune.
+func TestUnreadRuneError(t *testing.T) {
+       buf := make([]byte, 3) // All runes in this test are 3 bytes long
+       r := NewReader(&StringReader{data: []string{"日本語日本語日本語"}})
+       if r.UnreadRune() == nil {
+               t.Error("expected error on UnreadRune from fresh buffer")
+       }
+       _, _, err := r.ReadRune()
+       if err != nil {
+               t.Error("unexpected error on ReadRune (1):", err)
+       }
+       if err = r.UnreadRune(); err != nil {
+               t.Error("unexpected error on UnreadRune (1):", err)
+       }
+       if r.UnreadRune() == nil {
+               t.Error("expected error after UnreadRune (1)")
+       }
+       // Test error after Read.
+       _, _, err = r.ReadRune() // reset state
+       if err != nil {
+               t.Error("unexpected error on ReadRune (2):", err)
+       }
+       _, err = r.Read(buf)
+       if err != nil {
+               t.Error("unexpected error on Read (2):", err)
+       }
+       if r.UnreadRune() == nil {
+               t.Error("expected error after Read (2)")
+       }
+       // Test error after ReadByte.
+       _, _, err = r.ReadRune() // reset state
+       if err != nil {
+               t.Error("unexpected error on ReadRune (2):", err)
+       }
+       for _ = range buf {
+               _, err = r.ReadByte()
+               if err != nil {
+                       t.Error("unexpected error on ReadByte (2):", err)
+               }
+       }
+       if r.UnreadRune() == nil {
+               t.Error("expected error after ReadByte")
+       }
+       // Test error after UnreadByte.
+       _, _, err = r.ReadRune() // reset state
+       if err != nil {
+               t.Error("unexpected error on ReadRune (3):", err)
+       }
+       _, err = r.ReadByte()
+       if err != nil {
+               t.Error("unexpected error on ReadByte (3):", err)
+       }
+       err = r.UnreadByte()
+       if err != nil {
+               t.Error("unexpected error on UnreadByte (3):", err)
+       }
+       if r.UnreadRune() == nil {
+               t.Error("expected error after UnreadByte (3)")
+       }
+}
+
 func TestReadWriteRune(t *testing.T) {
        const NRune = 1000
        byteBuf := new(bytes.Buffer)