From afde71cfbdb1032cce8fefc6bc665b816b3276e6 Mon Sep 17 00:00:00 2001 From: Olivier Saingre Date: Wed, 20 Feb 2013 14:41:23 -0800 Subject: [PATCH] encoding/xml: make sure Encoder.Encode reports Write errors. Fixes #4112. R=remyoudompheng, daniel.morsing, dave, rsc CC=golang-dev https://golang.org/cl/7085053 --- src/pkg/encoding/xml/marshal.go | 27 ++++++++++++++++++--------- src/pkg/encoding/xml/marshal_test.go | 10 ++++++++++ src/pkg/encoding/xml/xml.go | 24 +++++++++++++++++++----- src/pkg/encoding/xml/xml_test.go | 14 ++++++++++++++ 4 files changed, 61 insertions(+), 14 deletions(-) diff --git a/src/pkg/encoding/xml/marshal.go b/src/pkg/encoding/xml/marshal.go index ea891bfb3e..ea58ce2542 100644 --- a/src/pkg/encoding/xml/marshal.go +++ b/src/pkg/encoding/xml/marshal.go @@ -193,7 +193,9 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error { if xmlns != "" { p.WriteString(` xmlns="`) // TODO: EscapeString, to avoid the allocation. - Escape(p, []byte(xmlns)) + if err := EscapeText(p, []byte(xmlns)); err != nil { + return err + } p.WriteByte('"') } @@ -252,19 +254,22 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) error { p.WriteString(strconv.FormatFloat(val.Float(), 'g', -1, val.Type().Bits())) case reflect.String: // TODO: Add EscapeString. - Escape(p, []byte(val.String())) + EscapeText(p, []byte(val.String())) case reflect.Bool: p.WriteString(strconv.FormatBool(val.Bool())) case reflect.Array: // will be [...]byte - bytes := make([]byte, val.Len()) - for i := range bytes { - bytes[i] = val.Index(i).Interface().(byte) + var bytes []byte + if val.CanAddr() { + bytes = val.Slice(0, val.Len()).Bytes() + } else { + bytes = make([]byte, val.Len()) + reflect.Copy(reflect.ValueOf(bytes), val) } - Escape(p, bytes) + EscapeText(p, bytes) case reflect.Slice: // will be []byte - Escape(p, val.Bytes()) + EscapeText(p, val.Bytes()) default: return &UnsupportedTypeError{typ} } @@ -298,10 +303,14 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { case reflect.Bool: Escape(p, strconv.AppendBool(scratch[:0], vf.Bool())) case reflect.String: - Escape(p, []byte(vf.String())) + if err := EscapeText(p, []byte(vf.String())); err != nil { + return err + } case reflect.Slice: if elem, ok := vf.Interface().([]byte); ok { - Escape(p, elem) + if err := EscapeText(p, elem); err != nil { + return err + } } case reflect.Struct: if vf.Type() == timeType { diff --git a/src/pkg/encoding/xml/marshal_test.go b/src/pkg/encoding/xml/marshal_test.go index ed856813a7..3a190def6c 100644 --- a/src/pkg/encoding/xml/marshal_test.go +++ b/src/pkg/encoding/xml/marshal_test.go @@ -965,6 +965,16 @@ func TestMarshalWriteErrors(t *testing.T) { } } +func TestMarshalWriteIOErrors(t *testing.T) { + enc := NewEncoder(errWriter{}) + + expectErr := "unwritable" + err := enc.Encode(&Passenger{}) + if err == nil || err.Error() != expectErr { + t.Errorf("EscapeTest = [error] %v, want %v", err, expectErr) + } +} + func BenchmarkMarshal(b *testing.B) { for i := 0; i < b.N; i++ { Marshal(atomValue) diff --git a/src/pkg/encoding/xml/xml.go b/src/pkg/encoding/xml/xml.go index 3e004306a1..143fec554c 100644 --- a/src/pkg/encoding/xml/xml.go +++ b/src/pkg/encoding/xml/xml.go @@ -1720,9 +1720,9 @@ var ( esc_cr = []byte(" ") ) -// Escape writes to w the properly escaped XML equivalent +// EscapeText writes to w the properly escaped XML equivalent // of the plain text data s. -func Escape(w io.Writer, s []byte) { +func EscapeText(w io.Writer, s []byte) error { var esc []byte last := 0 for i, c := range s { @@ -1746,11 +1746,25 @@ func Escape(w io.Writer, s []byte) { default: continue } - w.Write(s[last:i]) - w.Write(esc) + if _, err := w.Write(s[last:i]); err != nil { + return err + } + if _, err := w.Write(esc); err != nil { + return err + } last = i + 1 } - w.Write(s[last:]) + if _, err := w.Write(s[last:]); err != nil { + return err + } + return nil +} + +// Escape is like EscapeText but omits the error return value. +// It is provided for backwards compatibility with Go 1.0. +// Code targeting Go 1.1 or later should use EscapeText. +func Escape(w io.Writer, s []byte) { + EscapeText(w, s) } // procInstEncoding parses the `encoding="..."` or `encoding='...'` diff --git a/src/pkg/encoding/xml/xml_test.go b/src/pkg/encoding/xml/xml_test.go index 981d352031..54dab5484a 100644 --- a/src/pkg/encoding/xml/xml_test.go +++ b/src/pkg/encoding/xml/xml_test.go @@ -689,3 +689,17 @@ func TestDirectivesWithComments(t *testing.T) { } } } + +// Writer whose Write method always returns an error. +type errWriter struct{} + +func (errWriter) Write(p []byte) (n int, err error) { return 0, fmt.Errorf("unwritable") } + +func TestEscapeTextIOErrors(t *testing.T) { + expectErr := "unwritable" + err := EscapeText(errWriter{}, []byte{'A'}) + + if err == nil || err.Error() != expectErr { + t.Errorf("EscapeTest = [error] %v, want %v", err, expectErr) + } +} -- 2.48.1