]> Cypherpunks repositories - gostls13.git/commitdiff
bytes, strings: fix Reader WriteTo return value on 0 bytes copied
authorBrad Fitzpatrick <bradfitz@golang.org>
Sun, 25 Nov 2012 17:04:13 +0000 (09:04 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Sun, 25 Nov 2012 17:04:13 +0000 (09:04 -0800)
Fixes #4421

R=golang-dev, dave, minux.ma, mchaten, rsc
CC=golang-dev
https://golang.org/cl/6855083

src/pkg/bytes/reader.go
src/pkg/bytes/reader_test.go
src/pkg/strings/reader.go
src/pkg/strings/reader_test.go

index b34dfc11bffde916d5b963e4cff3626a2e9ddf32..77511b94555634d1ab934e3bd692802351eebce1 100644 (file)
@@ -125,7 +125,7 @@ func (r *Reader) Seek(offset int64, whence int) (int64, error) {
 func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
        r.prevRune = -1
        if r.i >= len(r.s) {
-               return 0, io.EOF
+               return 0, nil
        }
        b := r.s[r.i:]
        m, err := w.Write(b)
index 666881886760ea8ca6a6f78d2c9032800a14c2a4..f0a3e26c4a773380824d357d4028ce018b9724e9 100644 (file)
@@ -8,6 +8,7 @@ import (
        . "bytes"
        "fmt"
        "io"
+       "io/ioutil"
        "os"
        "testing"
 )
@@ -88,16 +89,20 @@ func TestReaderAt(t *testing.T) {
 }
 
 func TestReaderWriteTo(t *testing.T) {
-       for i := 3; i < 30; i += 3 {
-               s := data[:len(data)/i]
-               r := NewReader(testBytes[:len(testBytes)/i])
+       for i := 0; i < 30; i += 3 {
+               var l int
+               if i > 0 {
+                       l = len(data) / i
+               }
+               s := data[:l]
+               r := NewReader(testBytes[:l])
                var b Buffer
                n, err := r.WriteTo(&b)
                if expect := int64(len(s)); n != expect {
                        t.Errorf("got %v; want %v", n, expect)
                }
                if err != nil {
-                       t.Errorf("got error = %v; want nil", err)
+                       t.Errorf("for length %d: got error = %v; want nil", l, err)
                }
                if b.String() != s {
                        t.Errorf("got string %q; want %q", b.String(), s)
@@ -107,3 +112,26 @@ func TestReaderWriteTo(t *testing.T) {
                }
        }
 }
+
+// verify that copying from an empty reader always has the same results,
+// regardless of the presence of a WriteTo method.
+func TestReaderCopyNothing(t *testing.T) {
+       type nErr struct {
+               n   int64
+               err error
+       }
+       type justReader struct {
+               io.Reader
+       }
+       type justWriter struct {
+               io.Writer
+       }
+       discard := justWriter{ioutil.Discard} // hide ReadFrom
+
+       var with, withOut nErr
+       with.n, with.err = io.Copy(discard, NewReader(nil))
+       withOut.n, withOut.err = io.Copy(discard, justReader{NewReader(nil)})
+       if with != withOut {
+               t.Errorf("behavior differs: with = %#v; without: %#v", with, withOut)
+       }
+}
index 98325ce75bf659f69e14db9228e03d91dad38914..11240efc0780f7223d1e88a8b9161fb481b9414e 100644 (file)
@@ -124,7 +124,7 @@ func (r *Reader) Seek(offset int64, whence int) (int64, error) {
 func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
        r.prevRune = -1
        if r.i >= len(r.s) {
-               return 0, io.EOF
+               return 0, nil
        }
        s := r.s[r.i:]
        m, err := io.WriteString(w, s)
index bab91fc71979767871946702c2329de43eff5912..4fdddcdb58e47719271fc1e23ec4c4468d1eb8d9 100644 (file)
@@ -90,7 +90,7 @@ func TestReaderAt(t *testing.T) {
 
 func TestWriteTo(t *testing.T) {
        const str = "0123456789"
-       for i := 0; i < len(str); i++ {
+       for i := 0; i <= len(str); i++ {
                s := str[i:]
                r := strings.NewReader(s)
                var b bytes.Buffer
@@ -99,7 +99,7 @@ func TestWriteTo(t *testing.T) {
                        t.Errorf("got %v; want %v", n, expect)
                }
                if err != nil {
-                       t.Errorf("got error = %v; want nil", err)
+                       t.Errorf("for length %d: got error = %v; want nil", len(s), err)
                }
                if b.String() != s {
                        t.Errorf("got string %q; want %q", b.String(), s)