]> Cypherpunks repositories - gostls13.git/commitdiff
bufio: fix potential endless loop in ReadByte
authorRobert Griesemer <gri@golang.org>
Fri, 11 Apr 2014 04:46:00 +0000 (21:46 -0700)
committerRobert Griesemer <gri@golang.org>
Fri, 11 Apr 2014 04:46:00 +0000 (21:46 -0700)
Also: Simplify ReadSlice implementation and
ensure that it doesn't call fill() with a full
buffer (this caused a failure in net/textproto
TestLargeReadMIMEHeader because fill() wasn't able
to read more data).

Fixes #7745.

LGTM=bradfitz
R=r, bradfitz
CC=golang-codereviews
https://golang.org/cl/86590043

src/pkg/bufio/bufio.go
src/pkg/bufio/bufio_test.go

index 1e0cdae38e8ebdb862d9a46c6805a9684fc2e356..ecd2708f7845f9dba7517e7425eded2441f69229 100644 (file)
@@ -88,15 +88,26 @@ func (b *Reader) fill() {
                b.r = 0
        }
 
-       // Read new data.
-       n, err := b.rd.Read(b.buf[b.w:])
-       if n < 0 {
-               panic(errNegativeRead)
+       if b.w >= len(b.buf) {
+               panic("bufio: tried to fill full buffer")
        }
-       b.w += n
-       if err != nil {
-               b.err = err
+
+       // Read new data: try a limited number of times.
+       for i := maxConsecutiveEmptyReads; i > 0; i-- {
+               n, err := b.rd.Read(b.buf[b.w:])
+               if n < 0 {
+                       panic(errNegativeRead)
+               }
+               b.w += n
+               if err != nil {
+                       b.err = err
+                       return
+               }
+               if n > 0 {
+                       return
+               }
        }
+       b.err = io.ErrNoProgress
 }
 
 func (b *Reader) readErr() error {
@@ -116,8 +127,9 @@ func (b *Reader) Peek(n int) ([]byte, error) {
        if n > len(b.buf) {
                return nil, ErrBufferFull
        }
+       // 0 <= n <= len(b.buf)
        for b.w-b.r < n && b.err == nil {
-               b.fill()
+               b.fill() // b.w-b.r < len(b.buf) => buffer is not full
        }
        m := b.w - b.r
        if m > n {
@@ -143,7 +155,7 @@ func (b *Reader) Read(p []byte) (n int, err error) {
        if n == 0 {
                return 0, b.readErr()
        }
-       if b.w == b.r {
+       if b.r == b.w {
                if b.err != nil {
                        return 0, b.readErr()
                }
@@ -151,13 +163,16 @@ func (b *Reader) Read(p []byte) (n int, err error) {
                        // Large read, empty buffer.
                        // Read directly into p to avoid copy.
                        n, b.err = b.rd.Read(p)
+                       if n < 0 {
+                               panic(errNegativeRead)
+                       }
                        if n > 0 {
                                b.lastByte = int(p[n-1])
                                b.lastRuneSize = -1
                        }
                        return n, b.readErr()
                }
-               b.fill()
+               b.fill() // buffer is empty
                if b.w == b.r {
                        return 0, b.readErr()
                }
@@ -181,7 +196,7 @@ func (b *Reader) ReadByte() (c byte, err error) {
                if b.err != nil {
                        return 0, b.readErr()
                }
-               b.fill()
+               b.fill() // buffer is empty
        }
        c = b.buf[b.r]
        b.r++
@@ -211,8 +226,8 @@ func (b *Reader) UnreadByte() error {
 // rune and its size in bytes. If the encoded rune is invalid, it consumes one byte
 // and returns unicode.ReplacementChar (U+FFFD) with a size of 1.
 func (b *Reader) ReadRune() (r rune, size int, err error) {
-       for b.r+utf8.UTFMax > b.w && !utf8.FullRune(b.buf[b.r:b.w]) && b.err == nil {
-               b.fill()
+       for b.r+utf8.UTFMax > b.w && !utf8.FullRune(b.buf[b.r:b.w]) && b.err == nil && b.w-b.r < len(b.buf) {
+               b.fill() // b.w-b.r < len(buf) => buffer is not full
        }
        b.lastRuneSize = -1
        if b.r == b.w {
@@ -256,36 +271,28 @@ func (b *Reader) Buffered() int { return b.w - b.r }
 // ReadBytes or ReadString instead.
 // ReadSlice returns err != nil if and only if line does not end in delim.
 func (b *Reader) ReadSlice(delim byte) (line []byte, err error) {
-       // Look in buffer.
-       if i := bytes.IndexByte(b.buf[b.r:b.w], delim); i >= 0 {
-               line1 := b.buf[b.r : b.r+i+1]
-               b.r += i + 1
-               return line1, nil
-       }
-
-       // Read more into buffer, until buffer fills or we find delim.
        for {
+               // Search buffer.
+               if i := bytes.IndexByte(b.buf[b.r:b.w], delim); i >= 0 {
+                       line := b.buf[b.r : b.r+i+1]
+                       b.r += i + 1
+                       return line, nil
+               }
+
+               // Pending error?
                if b.err != nil {
                        line := b.buf[b.r:b.w]
                        b.r = b.w
                        return line, b.readErr()
                }
 
-               n := b.Buffered()
-               b.fill()
-
-               // Search new part of buffer
-               if i := bytes.IndexByte(b.buf[n:b.w], delim); i >= 0 {
-                       line := b.buf[0 : n+i+1]
-                       b.r = n + i + 1
-                       return line, nil
-               }
-
-               // Buffer is full?
-               if b.Buffered() >= len(b.buf) {
+               // Buffer full?
+               if n := b.Buffered(); n >= len(b.buf) {
                        b.r = b.w
                        return b.buf, ErrBufferFull
                }
+
+               b.fill() // buffer is not full
        }
 }
 
@@ -417,12 +424,18 @@ func (b *Reader) WriteTo(w io.Writer) (n int64, err error) {
                return n, err
        }
 
-       for b.fill(); b.r < b.w; b.fill() {
+       if b.w-b.r < len(b.buf) {
+               b.fill() // buffer not full
+       }
+
+       for b.r < b.w {
+               // b.r < b.w => buffer is not empty
                m, err := b.writeBuf(w)
                n += m
                if err != nil {
                        return n, err
                }
+               b.fill() // buffer is empty
        }
 
        if b.err == io.EOF {
@@ -435,6 +448,9 @@ func (b *Reader) WriteTo(w io.Writer) (n int64, err error) {
 // writeBuf writes the Reader's buffer to the writer.
 func (b *Reader) writeBuf(w io.Writer) (int64, error) {
        n, err := w.Write(b.buf[b.r:b.w])
+       if n < b.r-b.w {
+               panic(errors.New("bufio: writer did not write all data"))
+       }
        b.r += n
        return int64(n), err
 }
index 32ca86161fc0f087f16edd9c0a8d47296c5cba1a..406eb153ba11ad08f3ddac3b86fed151e47da644 100644 (file)
@@ -14,6 +14,7 @@ import (
        "strings"
        "testing"
        "testing/iotest"
+       "time"
        "unicode/utf8"
 )
 
@@ -174,6 +175,34 @@ func TestReader(t *testing.T) {
        }
 }
 
+type zeroReader struct{}
+
+func (zeroReader) Read(p []byte) (int, error) {
+       return 0, nil
+}
+
+func TestZeroReader(t *testing.T) {
+       var z zeroReader
+       r := NewReader(z)
+
+       c := make(chan error)
+       go func() {
+               _, err := r.ReadByte()
+               c <- err
+       }()
+
+       select {
+       case err := <-c:
+               if err == nil {
+                       t.Error("error expected")
+               } else if err != io.ErrNoProgress {
+                       t.Error("unexpected error:", err)
+               }
+       case <-time.After(time.Second):
+               t.Error("test timed out (endless loop in ReadByte?)")
+       }
+}
+
 // A StringReader delivers its data one string segment at a time via Read.
 type StringReader struct {
        data []string