From 32a0cbb88187deecfcbcf62f11056ca09fe4a4a0 Mon Sep 17 00:00:00 2001 From: Jan Ziak <0xe2.0x9a.0x9b@gmail.com> Date: Mon, 25 Jun 2012 16:00:35 -0400 Subject: [PATCH] encoding/csv, encoding/xml: report write errors Fixes #3773. R=bradfitz, rsc CC=golang-dev https://golang.org/cl/6327053 --- src/pkg/encoding/csv/writer.go | 5 ++- src/pkg/encoding/xml/marshal.go | 26 ++++++++------ src/pkg/encoding/xml/marshal_test.go | 52 ++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 13 deletions(-) diff --git a/src/pkg/encoding/csv/writer.go b/src/pkg/encoding/csv/writer.go index c4dcba5668..324944cc82 100644 --- a/src/pkg/encoding/csv/writer.go +++ b/src/pkg/encoding/csv/writer.go @@ -101,11 +101,10 @@ func (w *Writer) WriteAll(records [][]string) (err error) { for _, record := range records { err = w.Write(record) if err != nil { - break + return err } } - w.Flush() - return nil + return w.w.Flush() } // fieldNeedsQuotes returns true if our field must be enclosed in quotes. diff --git a/src/pkg/encoding/xml/marshal.go b/src/pkg/encoding/xml/marshal.go index 51e1dc8f96..8592a0c15c 100644 --- a/src/pkg/encoding/xml/marshal.go +++ b/src/pkg/encoding/xml/marshal.go @@ -83,9 +83,7 @@ func MarshalIndent(v interface{}, prefix, indent string) ([]byte, error) { enc := NewEncoder(&b) enc.prefix = prefix enc.indent = indent - err := enc.marshalValue(reflect.ValueOf(v), nil) - enc.Flush() - if err != nil { + if err := enc.Encode(v); err != nil { return nil, err } return b.Bytes(), nil @@ -107,8 +105,10 @@ func NewEncoder(w io.Writer) *Encoder { // of Go values to XML. func (enc *Encoder) Encode(v interface{}) error { err := enc.marshalValue(reflect.ValueOf(v), nil) - enc.Flush() - return err + if err != nil { + return err + } + return enc.Flush() } type printer struct { @@ -224,7 +224,7 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error { p.WriteString(name) p.WriteByte('>') - return nil + return p.cachedWriteError() } var timeType = reflect.TypeOf(time.Time{}) @@ -260,15 +260,15 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) error { default: return &UnsupportedTypeError{typ} } - return nil + return p.cachedWriteError() } var ddBytes = []byte("--") func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { if val.Type() == timeType { - p.WriteString(val.Interface().(time.Time).Format(time.RFC3339Nano)) - return nil + _, err := p.WriteString(val.Interface().(time.Time).Format(time.RFC3339Nano)) + return err } s := parentStack{printer: p} for i := range tinfo.fields { @@ -353,7 +353,13 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { } } s.trim(nil) - return nil + return p.cachedWriteError() +} + +// return the bufio Writer's cached write error +func (p *printer) cachedWriteError() error { + _, err := p.Write(nil) + return err } func (p *printer) writeIndent(depthDelta int) { diff --git a/src/pkg/encoding/xml/marshal_test.go b/src/pkg/encoding/xml/marshal_test.go index 90b4925e7f..e729a247af 100644 --- a/src/pkg/encoding/xml/marshal_test.go +++ b/src/pkg/encoding/xml/marshal_test.go @@ -5,6 +5,9 @@ package xml import ( + "bytes" + "errors" + "io" "reflect" "strconv" "strings" @@ -779,6 +782,55 @@ func TestUnmarshal(t *testing.T) { } } +type limitedBytesWriter struct { + w io.Writer + remain int // until writes fail +} + +func (lw *limitedBytesWriter) Write(p []byte) (n int, err error) { + if lw.remain <= 0 { + println("error") + return 0, errors.New("write limit hit") + } + if len(p) > lw.remain { + p = p[:lw.remain] + n, _ = lw.w.Write(p) + lw.remain = 0 + return n, errors.New("write limit hit") + } + n, err = lw.w.Write(p) + lw.remain -= n + return n, err +} + +func TestMarshalWriteErrors(t *testing.T) { + var buf bytes.Buffer + const writeCap = 1024 + w := &limitedBytesWriter{&buf, writeCap} + enc := NewEncoder(w) + var err error + var i int + const n = 4000 + for i = 1; i <= n; i++ { + err = enc.Encode(&Passenger{ + Name: []string{"Alice", "Bob"}, + Weight: 5, + }) + if err != nil { + break + } + } + if err == nil { + t.Error("expected an error") + } + if i == n { + t.Errorf("expected to fail before the end") + } + if buf.Len() != writeCap { + t.Errorf("buf.Len() = %d; want %d", buf.Len(), writeCap) + } +} + func BenchmarkMarshal(b *testing.B) { for i := 0; i < b.N; i++ { Marshal(atomValue) -- 2.48.1