]> Cypherpunks repositories - gostls13.git/commitdiff
bytes.Buffer: turn buffer size overflows into errors
authorRob Pike <r@golang.org>
Fri, 20 Jan 2012 21:51:49 +0000 (13:51 -0800)
committerRob Pike <r@golang.org>
Fri, 20 Jan 2012 21:51:49 +0000 (13:51 -0800)
Fixes #2743.

R=golang-dev, gri, r
CC=golang-dev
https://golang.org/cl/5556072

src/pkg/bytes/buffer.go
src/pkg/bytes/buffer_test.go

index 77757af1d804fc501693ff416ac92c65ce425284..9d58326a4fb8014756e0873e422986001a2cad48 100644 (file)
@@ -33,6 +33,9 @@ const (
        opRead                   // Any other read operation.
 )
 
+// ErrTooLarge is returned if there is too much data to fit in a buffer.
+var ErrTooLarge = errors.New("bytes.Buffer: too large")
+
 // Bytes returns a slice of the contents of the unread portion of the buffer;
 // len(b.Bytes()) == b.Len().  If the caller changes the contents of the
 // returned slice, the contents of the buffer will change provided there
@@ -68,8 +71,10 @@ func (b *Buffer) Truncate(n int) {
 // b.Reset() is the same as b.Truncate(0).
 func (b *Buffer) Reset() { b.Truncate(0) }
 
-// Grow buffer to guarantee space for n more bytes.
-// Return index where bytes should be written.
+// grow grows the buffer to guarantee space for n more bytes.
+// It returns the index where bytes should be written.
+// If the buffer can't grow, it returns -1, which will
+// become ErrTooLarge in the caller.
 func (b *Buffer) grow(n int) int {
        m := b.Len()
        // If buffer is empty, reset to recover space.
@@ -82,7 +87,10 @@ func (b *Buffer) grow(n int) int {
                        buf = b.bootstrap[0:]
                } else {
                        // not enough space anywhere
-                       buf = make([]byte, 2*cap(b.buf)+n)
+                       buf = makeSlice(2*cap(b.buf) + n)
+                       if buf == nil {
+                               return -1
+                       }
                        copy(buf, b.buf[b.off:])
                }
                b.buf = buf
@@ -97,6 +105,9 @@ func (b *Buffer) grow(n int) int {
 func (b *Buffer) Write(p []byte) (n int, err error) {
        b.lastRead = opInvalid
        m := b.grow(len(p))
+       if m < 0 {
+               return 0, ErrTooLarge
+       }
        return copy(b.buf[m:], p), nil
 }
 
@@ -105,6 +116,9 @@ func (b *Buffer) Write(p []byte) (n int, err error) {
 func (b *Buffer) WriteString(s string) (n int, err error) {
        b.lastRead = opInvalid
        m := b.grow(len(s))
+       if m < 0 {
+               return 0, ErrTooLarge
+       }
        return copy(b.buf[m:], s), nil
 }
 
@@ -133,7 +147,10 @@ func (b *Buffer) ReadFrom(r io.Reader) (n int64, err error) {
                                newBuf = b.buf[0 : len(b.buf)-b.off]
                        } else {
                                // not enough space at end; put space on end
-                               newBuf = make([]byte, len(b.buf)-b.off, 2*(cap(b.buf)-b.off)+MinRead)
+                               newBuf = makeSlice(2*(cap(b.buf)-b.off) + MinRead)[:len(b.buf)-b.off]
+                               if newBuf == nil {
+                                       return n, ErrTooLarge
+                               }
                        }
                        copy(newBuf, b.buf[b.off:])
                        b.buf = newBuf
@@ -152,6 +169,18 @@ func (b *Buffer) ReadFrom(r io.Reader) (n int64, err error) {
        return n, nil // err is EOF, so return nil explicitly
 }
 
+// makeSlice allocates a slice of size n, returning nil if the slice cannot be allocated.
+func makeSlice(n int) []byte {
+       if n < 0 {
+               return nil
+       }
+       // Catch out of memory panics.
+       defer func() {
+               recover()
+       }()
+       return make([]byte, n)
+}
+
 // WriteTo writes data to w until the buffer is drained or an error
 // occurs. The return value n is the number of bytes written; it always
 // fits into an int, but it is int64 to match the io.WriterTo interface.
@@ -179,6 +208,9 @@ func (b *Buffer) WriteTo(w io.Writer) (n int64, err error) {
 func (b *Buffer) WriteByte(c byte) error {
        b.lastRead = opInvalid
        m := b.grow(1)
+       if m < 0 {
+               return ErrTooLarge
+       }
        b.buf[m] = c
        return nil
 }
index d0af11f104b9d1706a70121c2523bf946cf5ef22..a36d1d010a72a43c16afe07e48bd1b6ad6aeaac9 100644 (file)
@@ -386,3 +386,19 @@ func TestReadEmptyAtEOF(t *testing.T) {
                t.Errorf("wrong count; got %d want 0", n)
        }
 }
+
+func TestHuge(t *testing.T) {
+       // About to use tons of memory, so avoid for simple installation testing.
+       if testing.Short() {
+               return
+       }
+       b := new(Buffer)
+       big := make([]byte, 500e6)
+       for i := 0; i < 1000; i++ {
+               if _, err := b.Write(big); err != nil {
+                       // Got error as expected. Stop
+                       return
+               }
+       }
+       t.Error("error expected")
+}