]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/xml: make sure Encoder.Encode reports Write errors.
authorOlivier Saingre <osaingre@gmail.com>
Wed, 20 Feb 2013 22:41:23 +0000 (14:41 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 20 Feb 2013 22:41:23 +0000 (14:41 -0800)
Fixes #4112.

R=remyoudompheng, daniel.morsing, dave, rsc
CC=golang-dev
https://golang.org/cl/7085053

src/pkg/encoding/xml/marshal.go
src/pkg/encoding/xml/marshal_test.go
src/pkg/encoding/xml/xml.go
src/pkg/encoding/xml/xml_test.go

index ea891bfb3ee5ba13cbb6d3a81d0d265e47eb345e..ea58ce25422d2da9436e3465c06ded6cc8a1a595 100644 (file)
@@ -193,7 +193,9 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error {
        if xmlns != "" {
                p.WriteString(` xmlns="`)
                // TODO: EscapeString, to avoid the allocation.
-               Escape(p, []byte(xmlns))
+               if err := EscapeText(p, []byte(xmlns)); err != nil {
+                       return err
+               }
                p.WriteByte('"')
        }
 
@@ -252,19 +254,22 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) error {
                p.WriteString(strconv.FormatFloat(val.Float(), 'g', -1, val.Type().Bits()))
        case reflect.String:
                // TODO: Add EscapeString.
-               Escape(p, []byte(val.String()))
+               EscapeText(p, []byte(val.String()))
        case reflect.Bool:
                p.WriteString(strconv.FormatBool(val.Bool()))
        case reflect.Array:
                // will be [...]byte
-               bytes := make([]byte, val.Len())
-               for i := range bytes {
-                       bytes[i] = val.Index(i).Interface().(byte)
+               var bytes []byte
+               if val.CanAddr() {
+                       bytes = val.Slice(0, val.Len()).Bytes()
+               } else {
+                       bytes = make([]byte, val.Len())
+                       reflect.Copy(reflect.ValueOf(bytes), val)
                }
-               Escape(p, bytes)
+               EscapeText(p, bytes)
        case reflect.Slice:
                // will be []byte
-               Escape(p, val.Bytes())
+               EscapeText(p, val.Bytes())
        default:
                return &UnsupportedTypeError{typ}
        }
@@ -298,10 +303,14 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
                        case reflect.Bool:
                                Escape(p, strconv.AppendBool(scratch[:0], vf.Bool()))
                        case reflect.String:
-                               Escape(p, []byte(vf.String()))
+                               if err := EscapeText(p, []byte(vf.String())); err != nil {
+                                       return err
+                               }
                        case reflect.Slice:
                                if elem, ok := vf.Interface().([]byte); ok {
-                                       Escape(p, elem)
+                                       if err := EscapeText(p, elem); err != nil {
+                                               return err
+                                       }
                                }
                        case reflect.Struct:
                                if vf.Type() == timeType {
index ed856813a7eb92dd161a60f3142036a69ff29f48..3a190def6c1c21e38a4121f9a5bb6449bf887200 100644 (file)
@@ -965,6 +965,16 @@ func TestMarshalWriteErrors(t *testing.T) {
        }
 }
 
+func TestMarshalWriteIOErrors(t *testing.T) {
+       enc := NewEncoder(errWriter{})
+
+       expectErr := "unwritable"
+       err := enc.Encode(&Passenger{})
+       if err == nil || err.Error() != expectErr {
+               t.Errorf("EscapeTest = [error] %v, want %v", err, expectErr)
+       }
+}
+
 func BenchmarkMarshal(b *testing.B) {
        for i := 0; i < b.N; i++ {
                Marshal(atomValue)
index 3e004306a1f3499046bf67fec138a73366f4d30d..143fec554cffbba72a3e7711662a6928374a477e 100644 (file)
@@ -1720,9 +1720,9 @@ var (
        esc_cr   = []byte("&#xD;")
 )
 
-// Escape writes to w the properly escaped XML equivalent
+// EscapeText writes to w the properly escaped XML equivalent
 // of the plain text data s.
-func Escape(w io.Writer, s []byte) {
+func EscapeText(w io.Writer, s []byte) error {
        var esc []byte
        last := 0
        for i, c := range s {
@@ -1746,11 +1746,25 @@ func Escape(w io.Writer, s []byte) {
                default:
                        continue
                }
-               w.Write(s[last:i])
-               w.Write(esc)
+               if _, err := w.Write(s[last:i]); err != nil {
+                       return err
+               }
+               if _, err := w.Write(esc); err != nil {
+                       return err
+               }
                last = i + 1
        }
-       w.Write(s[last:])
+       if _, err := w.Write(s[last:]); err != nil {
+               return err
+       }
+       return nil
+}
+
+// Escape is like EscapeText but omits the error return value.
+// It is provided for backwards compatibility with Go 1.0.
+// Code targeting Go 1.1 or later should use EscapeText.
+func Escape(w io.Writer, s []byte) {
+       EscapeText(w, s)
 }
 
 // procInstEncoding parses the `encoding="..."` or `encoding='...'`
index 981d3520313d18b01b0e11476748630cdca66215..54dab5484a6bdbdcc0c3b73c5217059c5f929f07 100644 (file)
@@ -689,3 +689,17 @@ func TestDirectivesWithComments(t *testing.T) {
                }
        }
 }
+
+// Writer whose Write method always returns an error.
+type errWriter struct{}
+
+func (errWriter) Write(p []byte) (n int, err error) { return 0, fmt.Errorf("unwritable") }
+
+func TestEscapeTextIOErrors(t *testing.T) {
+       expectErr := "unwritable"
+       err := EscapeText(errWriter{}, []byte{'A'})
+
+       if err == nil || err.Error() != expectErr {
+               t.Errorf("EscapeTest = [error] %v, want %v", err, expectErr)
+       }
+}