]> Cypherpunks repositories - gostls13.git/commitdiff
io: prevent ReadAtLeast spinloop if min > len(buf)
authorAndrew Gerrand <adg@golang.org>
Mon, 23 Aug 2010 02:04:15 +0000 (12:04 +1000)
committerAndrew Gerrand <adg@golang.org>
Mon, 23 Aug 2010 02:04:15 +0000 (12:04 +1000)
R=r, heresy.mc
CC=golang-dev
https://golang.org/cl/2017042

src/pkg/io/io.go
src/pkg/io/io_test.go

index a41a674cea90ac6d919815390605e3950e147ff3..2b2f4d567144ae73cddad52bfd2a5b374bdf3e78 100644 (file)
@@ -19,6 +19,9 @@ type Error struct {
 // but failed to return an explicit error.
 var ErrShortWrite os.Error = &Error{"short write"}
 
+// ErrShortBuffer means that a read required a longer buffer than was provided.
+var ErrShortBuffer os.Error = &Error{"short buffer"}
+
 // ErrUnexpectedEOF means that os.EOF was encountered in the
 // middle of reading a fixed-size block or data structure.
 var ErrUnexpectedEOF os.Error = &Error{"unexpected EOF"}
@@ -165,8 +168,11 @@ func WriteString(w Writer, s string) (n int, err os.Error) {
 // The error is os.EOF only if no bytes were read.
 // If an EOF happens after reading fewer than min bytes,
 // ReadAtLeast returns ErrUnexpectedEOF.
+// If min is greater than the length of buf, ReadAtLeast returns ErrShortBuffer.
 func ReadAtLeast(r Reader, buf []byte, min int) (n int, err os.Error) {
-       n = 0
+       if len(buf) < min {
+               return 0, ErrShortBuffer
+       }
        for n < min {
                nn, e := r.Read(buf[n:])
                if nn > 0 {
@@ -179,7 +185,7 @@ func ReadAtLeast(r Reader, buf []byte, min int) (n int, err os.Error) {
                        return n, e
                }
        }
-       return n, nil
+       return
 }
 
 // ReadFull reads exactly len(buf) bytes from r into buf.
index 4ad1e595107ad29c769eddb16a72131bda927c2d..20f240a51a5efbc582dace07c3e9e41957078763 100644 (file)
@@ -7,6 +7,7 @@ package io_test
 import (
        "bytes"
        . "io"
+       "os"
        "testing"
 )
 
@@ -78,3 +79,42 @@ func TestCopynWriteTo(t *testing.T) {
                t.Errorf("Copyn did not work properly")
        }
 }
+
+func TestReadAtLeast(t *testing.T) {
+       var rb bytes.Buffer
+       rb.Write([]byte("0123"))
+       buf := make([]byte, 2)
+       n, err := ReadAtLeast(&rb, buf, 2)
+       if err != nil {
+               t.Error(err)
+       }
+       n, err = ReadAtLeast(&rb, buf, 4)
+       if err != ErrShortBuffer {
+               t.Errorf("expected ErrShortBuffer got %v", err)
+       }
+       if n != 0 {
+               t.Errorf("expected to have read 0 bytes, got %v", n)
+       }
+       n, err = ReadAtLeast(&rb, buf, 1)
+       if err != nil {
+               t.Error(err)
+       }
+       if n != 2 {
+               t.Errorf("expected to have read 2 bytes, got %v", n)
+       }
+       n, err = ReadAtLeast(&rb, buf, 2)
+       if err != os.EOF {
+               t.Errorf("expected EOF, got %v", err)
+       }
+       if n != 0 {
+               t.Errorf("expected to have read 0 bytes, got %v", n)
+       }
+       rb.Write([]byte("4"))
+       n, err = ReadAtLeast(&rb, buf, 2)
+       if err != ErrUnexpectedEOF {
+               t.Errorf("expected ErrUnexpectedEOF, got %v", err)
+       }
+       if n != 1 {
+               t.Errorf("expected to have read 1 bytes, got %v", n)
+       }
+}