]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/csv, encoding/xml: report write errors
authorJan Ziak <0xe2.0x9a.0x9b@gmail.com>
Mon, 25 Jun 2012 20:00:35 +0000 (16:00 -0400)
committerRuss Cox <rsc@golang.org>
Mon, 25 Jun 2012 20:00:35 +0000 (16:00 -0400)
Fixes #3773.

R=bradfitz, rsc
CC=golang-dev
https://golang.org/cl/6327053

src/pkg/encoding/csv/writer.go
src/pkg/encoding/xml/marshal.go
src/pkg/encoding/xml/marshal_test.go

index c4dcba5668af61672db93d4abb0a7a87aafda8ec..324944cc8299ff6ab3eac21629ac8b1b8404136f 100644 (file)
@@ -101,11 +101,10 @@ func (w *Writer) WriteAll(records [][]string) (err error) {
        for _, record := range records {
                err = w.Write(record)
                if err != nil {
-                       break
+                       return err
                }
        }
-       w.Flush()
-       return nil
+       return w.w.Flush()
 }
 
 // fieldNeedsQuotes returns true if our field must be enclosed in quotes.
index 51e1dc8f96461653115f1f2fea45fb7e722268fd..8592a0c15cb963872ce3b5fabf896111929f2473 100644 (file)
@@ -83,9 +83,7 @@ func MarshalIndent(v interface{}, prefix, indent string) ([]byte, error) {
        enc := NewEncoder(&b)
        enc.prefix = prefix
        enc.indent = indent
-       err := enc.marshalValue(reflect.ValueOf(v), nil)
-       enc.Flush()
-       if err != nil {
+       if err := enc.Encode(v); err != nil {
                return nil, err
        }
        return b.Bytes(), nil
@@ -107,8 +105,10 @@ func NewEncoder(w io.Writer) *Encoder {
 // of Go values to XML.
 func (enc *Encoder) Encode(v interface{}) error {
        err := enc.marshalValue(reflect.ValueOf(v), nil)
-       enc.Flush()
-       return err
+       if err != nil {
+               return err
+       }
+       return enc.Flush()
 }
 
 type printer struct {
@@ -224,7 +224,7 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error {
        p.WriteString(name)
        p.WriteByte('>')
 
-       return nil
+       return p.cachedWriteError()
 }
 
 var timeType = reflect.TypeOf(time.Time{})
@@ -260,15 +260,15 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) error {
        default:
                return &UnsupportedTypeError{typ}
        }
-       return nil
+       return p.cachedWriteError()
 }
 
 var ddBytes = []byte("--")
 
 func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
        if val.Type() == timeType {
-               p.WriteString(val.Interface().(time.Time).Format(time.RFC3339Nano))
-               return nil
+               _, err := p.WriteString(val.Interface().(time.Time).Format(time.RFC3339Nano))
+               return err
        }
        s := parentStack{printer: p}
        for i := range tinfo.fields {
@@ -353,7 +353,13 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
                }
        }
        s.trim(nil)
-       return nil
+       return p.cachedWriteError()
+}
+
+// return the bufio Writer's cached write error
+func (p *printer) cachedWriteError() error {
+       _, err := p.Write(nil)
+       return err
 }
 
 func (p *printer) writeIndent(depthDelta int) {
index 90b4925e7f9cb42a3a95c7b9727ed7fdc7dc7170..e729a247af40d5110068f519c10f69c328e1b371 100644 (file)
@@ -5,6 +5,9 @@
 package xml
 
 import (
+       "bytes"
+       "errors"
+       "io"
        "reflect"
        "strconv"
        "strings"
@@ -779,6 +782,55 @@ func TestUnmarshal(t *testing.T) {
        }
 }
 
+type limitedBytesWriter struct {
+       w      io.Writer
+       remain int // until writes fail
+}
+
+func (lw *limitedBytesWriter) Write(p []byte) (n int, err error) {
+       if lw.remain <= 0 {
+               println("error")
+               return 0, errors.New("write limit hit")
+       }
+       if len(p) > lw.remain {
+               p = p[:lw.remain]
+               n, _ = lw.w.Write(p)
+               lw.remain = 0
+               return n, errors.New("write limit hit")
+       }
+       n, err = lw.w.Write(p)
+       lw.remain -= n
+       return n, err
+}
+
+func TestMarshalWriteErrors(t *testing.T) {
+       var buf bytes.Buffer
+       const writeCap = 1024
+       w := &limitedBytesWriter{&buf, writeCap}
+       enc := NewEncoder(w)
+       var err error
+       var i int
+       const n = 4000
+       for i = 1; i <= n; i++ {
+               err = enc.Encode(&Passenger{
+                       Name:   []string{"Alice", "Bob"},
+                       Weight: 5,
+               })
+               if err != nil {
+                       break
+               }
+       }
+       if err == nil {
+               t.Error("expected an error")
+       }
+       if i == n {
+               t.Errorf("expected to fail before the end")
+       }
+       if buf.Len() != writeCap {
+               t.Errorf("buf.Len() = %d; want %d", buf.Len(), writeCap)
+       }
+}
+
 func BenchmarkMarshal(b *testing.B) {
        for i := 0; i < b.N; i++ {
                Marshal(atomValue)