]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/xml: fix xmlns= behavior
authorRoger Peppe <rogpeppe@gmail.com>
Mon, 29 Jun 2015 11:36:48 +0000 (12:36 +0100)
committerroger peppe <rogpeppe@gmail.com>
Tue, 30 Jun 2015 07:42:37 +0000 (07:42 +0000)
When an xmlns="..." attribute was explicitly generated,
it was being ignored because the name space on the
attribute was assumed to have been explicitly set (to the empty
name space) and it's not possible to have an element in the
empty name space when there is a non-empty name space set.

We fix this by recording when a default name space has been
explicitly set and setting the name space of the element to that
so printer.defineNS can do its work correctly.

We do not attempt to add our own xmlns="..." attribute
when one is explicitly set.

We also add tests for EncodeElement, as that's the only way
to attain coverage of some of the changed behaviour.
Some other test coverage is also increased, although
more work remains to be done in this area.

This change was jointly developed with Martin Hilton (mhilton on github).

Fixes #11431.

Change-Id: I7b85e06eea5b18b2c15ec16dcbd92a8e1d6a9a4e
Reviewed-on: https://go-review.googlesource.com/11635
Reviewed-by: Russ Cox <rsc@golang.org>
src/encoding/xml/marshal.go
src/encoding/xml/marshal_test.go
src/encoding/xml/xml.go

index 63f8e2aa8705c9ec8f4f7b589bbb7cda24540991..100e41df2436faa3d5822fdf4f03c56056d86e99 100644 (file)
@@ -578,12 +578,14 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
        // 3. type name
        var start StartElement
 
-       // Historic behaviour: elements use the default name space
-       // they are contained in by default.
-       start.Name.Space = p.defaultNS
+       // explicitNS records whether the element's name
+       // space has been explicitly set (for example an
+       // and XMLName field).
+       explicitNS := false
 
        if startTemplate != nil {
                start.Name = startTemplate.Name
+               explicitNS = true
                start.Attr = append(start.Attr, startTemplate.Attr...)
        } else if tinfo.xmlname != nil {
                xmlname := tinfo.xmlname
@@ -592,11 +594,13 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
                } else if v, ok := xmlname.value(val).Interface().(Name); ok && v.Local != "" {
                        start.Name = v
                }
+               explicitNS = true
        }
        if start.Name.Local == "" && finfo != nil {
                start.Name.Local = finfo.name
                if finfo.xmlns != "" {
                        start.Name.Space = finfo.xmlns
+                       explicitNS = true
                }
        }
        if start.Name.Local == "" {
@@ -606,9 +610,12 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
                }
                start.Name.Local = name
        }
-       // Historic behaviour: an element that's in a namespace sets
-       // the default namespace for all elements contained within it.
-       start.setDefaultNamespace()
+
+       // defaultNS records the default name space as set by a xmlns="..."
+       // attribute. We don't set p.defaultNS because we want to let
+       // the attribute writing code (in p.defineNS) be solely responsible
+       // for maintaining that.
+       defaultNS := p.defaultNS
 
        // Attributes
        for i := range tinfo.fields {
@@ -616,81 +623,26 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
                if finfo.flags&fAttr == 0 {
                        continue
                }
-               fv := finfo.value(val)
-               name := Name{Space: finfo.xmlns, Local: finfo.name}
-
-               if finfo.flags&fOmitEmpty != 0 && isEmptyValue(fv) {
-                       continue
-               }
-
-               if fv.Kind() == reflect.Interface && fv.IsNil() {
-                       continue
-               }
-
-               if fv.CanInterface() && fv.Type().Implements(marshalerAttrType) {
-                       attr, err := fv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
-                       if err != nil {
-                               return err
-                       }
-                       if attr.Name.Local != "" {
-                               start.Attr = append(start.Attr, attr)
-                       }
-                       continue
-               }
-
-               if fv.CanAddr() {
-                       pv := fv.Addr()
-                       if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) {
-                               attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
-                               if err != nil {
-                                       return err
-                               }
-                               if attr.Name.Local != "" {
-                                       start.Attr = append(start.Attr, attr)
-                               }
-                               continue
-                       }
-               }
-
-               if fv.CanInterface() && fv.Type().Implements(textMarshalerType) {
-                       text, err := fv.Interface().(encoding.TextMarshaler).MarshalText()
-                       if err != nil {
-                               return err
-                       }
-                       start.Attr = append(start.Attr, Attr{name, string(text)})
-                       continue
-               }
-
-               if fv.CanAddr() {
-                       pv := fv.Addr()
-                       if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
-                               text, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
-                               if err != nil {
-                                       return err
-                               }
-                               start.Attr = append(start.Attr, Attr{name, string(text)})
-                               continue
-                       }
-               }
-
-               // Dereference or skip nil pointer, interface values.
-               switch fv.Kind() {
-               case reflect.Ptr, reflect.Interface:
-                       if fv.IsNil() {
-                               continue
-                       }
-                       fv = fv.Elem()
-               }
-
-               s, b, err := p.marshalSimple(fv.Type(), fv)
+               attr, add, err := p.fieldAttr(finfo, val)
                if err != nil {
                        return err
                }
-               if b != nil {
-                       s = string(b)
+               if !add {
+                       continue
+               }
+               start.Attr = append(start.Attr, attr)
+               if attr.Name.Space == "" && attr.Name.Local == "xmlns" {
+                       defaultNS = attr.Value
                }
-               start.Attr = append(start.Attr, Attr{name, s})
        }
+       if !explicitNS {
+               // Historic behavior: elements use the default name space
+               // they are contained in by default.
+               start.Name.Space = defaultNS
+       }
+       // Historic behaviour: an element that's in a namespace sets
+       // the default namespace for all elements contained within it.
+       start.setDefaultNamespace()
 
        if err := p.writeStart(&start); err != nil {
                return err
@@ -719,6 +671,64 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
        return p.cachedWriteError()
 }
 
+// fieldAttr returns the attribute of the given field and
+// whether it should actually be added as an attribute;
+// val holds the value containing the field.
+func (p *printer) fieldAttr(finfo *fieldInfo, val reflect.Value) (Attr, bool, error) {
+       fv := finfo.value(val)
+       name := Name{Space: finfo.xmlns, Local: finfo.name}
+       if finfo.flags&fOmitEmpty != 0 && isEmptyValue(fv) {
+               return Attr{}, false, nil
+       }
+       if fv.Kind() == reflect.Interface && fv.IsNil() {
+               return Attr{}, false, nil
+       }
+       if fv.CanInterface() && fv.Type().Implements(marshalerAttrType) {
+               attr, err := fv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
+               return attr, attr.Name.Local != "", err
+       }
+       if fv.CanAddr() {
+               pv := fv.Addr()
+               if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) {
+                       attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
+                       return attr, attr.Name.Local != "", err
+               }
+       }
+       if fv.CanInterface() && fv.Type().Implements(textMarshalerType) {
+               text, err := fv.Interface().(encoding.TextMarshaler).MarshalText()
+               if err != nil {
+                       return Attr{}, false, err
+               }
+               return Attr{name, string(text)}, true, nil
+       }
+       if fv.CanAddr() {
+               pv := fv.Addr()
+               if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
+                       text, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
+                       if err != nil {
+                               return Attr{}, false, err
+                       }
+                       return Attr{name, string(text)}, true, nil
+               }
+       }
+       // Dereference or skip nil pointer, interface values.
+       switch fv.Kind() {
+       case reflect.Ptr, reflect.Interface:
+               if fv.IsNil() {
+                       return Attr{}, false, nil
+               }
+               fv = fv.Elem()
+       }
+       s, b, err := p.marshalSimple(fv.Type(), fv)
+       if err != nil {
+               return Attr{}, false, err
+       }
+       if b != nil {
+               s = string(b)
+       }
+       return Attr{name, s}, true, nil
+}
+
 // defaultStart returns the default start element to use,
 // given the reflect type, field info, and start template.
 func (p *printer) defaultStart(typ reflect.Type, finfo *fieldInfo, startTemplate *StartElement) StartElement {
index 394855782ec13b069006b2e4801c074b2e5a2109..4c478ddded94efb9975139e5676afa2a9cf60790 100644 (file)
@@ -174,6 +174,11 @@ type XMLNameWithTag struct {
        Value   string `xml:",chardata"`
 }
 
+type XMLNameWithNSTag struct {
+       XMLName Name   `xml:"ns InXMLNameWithNSTag"`
+       Value   string `xml:",chardata"`
+}
+
 type XMLNameWithoutTag struct {
        XMLName Name
        Value   string `xml:",chardata"`
@@ -302,8 +307,7 @@ func (m *MyMarshalerTest) MarshalXML(e *Encoder, start StartElement) error {
        return nil
 }
 
-type MyMarshalerAttrTest struct {
-}
+type MyMarshalerAttrTest struct{}
 
 var _ MarshalerAttr = (*MyMarshalerAttrTest)(nil)
 
@@ -311,10 +315,22 @@ func (m *MyMarshalerAttrTest) MarshalXMLAttr(name Name) (Attr, error) {
        return Attr{name, "hello world"}, nil
 }
 
+type MyMarshalerValueAttrTest struct{}
+
+var _ MarshalerAttr = MyMarshalerValueAttrTest{}
+
+func (m MyMarshalerValueAttrTest) MarshalXMLAttr(name Name) (Attr, error) {
+       return Attr{name, "hello world"}, nil
+}
+
 type MarshalerStruct struct {
        Foo MyMarshalerAttrTest `xml:",attr"`
 }
 
+type MarshalerValueStruct struct {
+       Foo MyMarshalerValueAttrTest `xml:",attr"`
+}
+
 type InnerStruct struct {
        XMLName Name `xml:"testns outer"`
 }
@@ -350,6 +366,34 @@ type NestedAndComment struct {
        Comment string   `xml:",comment"`
 }
 
+type XMLNSFieldStruct struct {
+       Ns   string `xml:"xmlns,attr"`
+       Body string
+}
+
+type NamedXMLNSFieldStruct struct {
+       XMLName struct{} `xml:"testns test"`
+       Ns      string   `xml:"xmlns,attr"`
+       Body    string
+}
+
+type XMLNSFieldStructWithOmitEmpty struct {
+       Ns   string `xml:"xmlns,attr,omitempty"`
+       Body string
+}
+
+type NamedXMLNSFieldStructWithEmptyNamespace struct {
+       XMLName struct{} `xml:"test"`
+       Ns      string   `xml:"xmlns,attr"`
+       Body    string
+}
+
+type RecursiveXMLNSFieldStruct struct {
+       Ns   string                     `xml:"xmlns,attr"`
+       Body *RecursiveXMLNSFieldStruct `xml:",omitempty"`
+       Text string                     `xml:",omitempty"`
+}
+
 func ifaceptr(x interface{}) interface{} {
        return &x
 }
@@ -989,6 +1033,10 @@ var marshalTests = []struct {
                ExpectXML: `<MarshalerStruct Foo="hello world"></MarshalerStruct>`,
                Value:     &MarshalerStruct{},
        },
+       {
+               ExpectXML: `<MarshalerValueStruct Foo="hello world"></MarshalerValueStruct>`,
+               Value:     &MarshalerValueStruct{},
+       },
        {
                ExpectXML: `<outer xmlns="testns" int="10"></outer>`,
                Value:     &OuterStruct{IntAttr: 10},
@@ -1013,6 +1061,39 @@ var marshalTests = []struct {
                ExpectXML: `<NestedAndComment><A><B></B><B></B></A><!--test--></NestedAndComment>`,
                Value:     &NestedAndComment{AB: make([]string, 2), Comment: "test"},
        },
+       {
+               ExpectXML: `<XMLNSFieldStruct xmlns="http://example.com/ns"><Body>hello world</Body></XMLNSFieldStruct>`,
+               Value:     &XMLNSFieldStruct{Ns: "http://example.com/ns", Body: "hello world"},
+       },
+       {
+               ExpectXML: `<testns:test xmlns:testns="testns" xmlns="http://example.com/ns"><Body>hello world</Body></testns:test>`,
+               Value:     &NamedXMLNSFieldStruct{Ns: "http://example.com/ns", Body: "hello world"},
+       },
+       {
+               ExpectXML: `<testns:test xmlns:testns="testns"><Body>hello world</Body></testns:test>`,
+               Value:     &NamedXMLNSFieldStruct{Ns: "", Body: "hello world"},
+       },
+       {
+               ExpectXML: `<XMLNSFieldStructWithOmitEmpty><Body>hello world</Body></XMLNSFieldStructWithOmitEmpty>`,
+               Value:     &XMLNSFieldStructWithOmitEmpty{Body: "hello world"},
+       },
+       {
+               // The xmlns attribute must be ignored because the <test>
+               // element is in the empty namespace, so it's not possible
+               // to set the default namespace to something non-empty.
+               ExpectXML:   `<test><Body>hello world</Body></test>`,
+               Value:       &NamedXMLNSFieldStructWithEmptyNamespace{Ns: "foo", Body: "hello world"},
+               MarshalOnly: true,
+       },
+       {
+               ExpectXML: `<RecursiveXMLNSFieldStruct xmlns="foo"><Body xmlns=""><Text>hello world</Text></Body></RecursiveXMLNSFieldStruct>`,
+               Value: &RecursiveXMLNSFieldStruct{
+                       Ns: "foo",
+                       Body: &RecursiveXMLNSFieldStruct{
+                               Text: "hello world",
+                       },
+               },
+       },
 }
 
 func TestMarshal(t *testing.T) {
@@ -1235,6 +1316,100 @@ func TestMarshalFlush(t *testing.T) {
        }
 }
 
+var encodeElementTests = []struct {
+       desc      string
+       value     interface{}
+       start     StartElement
+       expectXML string
+}{{
+       desc:  "simple string",
+       value: "hello",
+       start: StartElement{
+               Name: Name{Local: "a"},
+       },
+       expectXML: `<a>hello</a>`,
+}, {
+       desc:  "string with added attributes",
+       value: "hello",
+       start: StartElement{
+               Name: Name{Local: "a"},
+               Attr: []Attr{{
+                       Name:  Name{Local: "x"},
+                       Value: "y",
+               }, {
+                       Name:  Name{Local: "foo"},
+                       Value: "bar",
+               }},
+       },
+       expectXML: `<a x="y" foo="bar">hello</a>`,
+}, {
+       desc: "start element with default name space",
+       value: struct {
+               Foo XMLNameWithNSTag
+       }{
+               Foo: XMLNameWithNSTag{
+                       Value: "hello",
+               },
+       },
+       start: StartElement{
+               Name: Name{Space: "ns", Local: "a"},
+               Attr: []Attr{{
+                       Name: Name{Local: "xmlns"},
+                       // "ns" is the name space defined in XMLNameWithNSTag
+                       Value: "ns",
+               }},
+       },
+       expectXML: `<a xmlns="ns"><InXMLNameWithNSTag>hello</InXMLNameWithNSTag></a>`,
+}, {
+       desc: "start element in name space with different default name space",
+       value: struct {
+               Foo XMLNameWithNSTag
+       }{
+               Foo: XMLNameWithNSTag{
+                       Value: "hello",
+               },
+       },
+       start: StartElement{
+               Name: Name{Space: "ns2", Local: "a"},
+               Attr: []Attr{{
+                       Name: Name{Local: "xmlns"},
+                       // "ns" is the name space defined in XMLNameWithNSTag
+                       Value: "ns",
+               }},
+       },
+       expectXML: `<ns2:a xmlns:ns2="ns2" xmlns="ns"><InXMLNameWithNSTag>hello</InXMLNameWithNSTag></ns2:a>`,
+}, {
+       desc:  "XMLMarshaler with start element with default name space",
+       value: &MyMarshalerTest{},
+       start: StartElement{
+               Name: Name{Space: "ns2", Local: "a"},
+               Attr: []Attr{{
+                       Name: Name{Local: "xmlns"},
+                       // "ns" is the name space defined in XMLNameWithNSTag
+                       Value: "ns",
+               }},
+       },
+       expectXML: `<ns2:a xmlns:ns2="ns2" xmlns="ns">hello world</ns2:a>`,
+}}
+
+func TestEncodeElement(t *testing.T) {
+       for idx, test := range encodeElementTests {
+               var buf bytes.Buffer
+               enc := NewEncoder(&buf)
+               err := enc.EncodeElement(test.value, test.start)
+               if err != nil {
+                       t.Fatalf("enc.EncodeElement: %v", err)
+               }
+               err = enc.Flush()
+               if err != nil {
+                       t.Fatalf("enc.Flush: %v", err)
+               }
+               if got, want := buf.String(), test.expectXML; got != want {
+                       t.Errorf("#%d(%s): EncodeElement(%#v, %#v):\nhave %#q\nwant %#q", idx, test.desc, test.value, test.start, got, want)
+               }
+       }
+}
+
 func BenchmarkMarshal(b *testing.B) {
        b.ReportAllocs()
        for i := 0; i < b.N; i++ {
index 3090750c483336006325e55567b2c44b19802f21..ffab4a70c9dc069bea742e978ead4ab742f15752 100644 (file)
@@ -91,6 +91,12 @@ func (e *StartElement) setDefaultNamespace() {
                // or was just using the default namespace.
                return
        }
+       // Don't add a default name space if there's already one set.
+       for _, attr := range e.Attr {
+               if attr.Name.Space == "" && attr.Name.Local == "xmlns" {
+                       return
+               }
+       }
        e.Attr = append(e.Attr, Attr{
                Name: Name{
                        Local: "xmlns",