]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/xml: only initialize nil struct fields when decoding
authorDaniel Martí <mvdan@mvdan.cc>
Tue, 24 Sep 2019 17:14:10 +0000 (18:14 +0100)
committerDaniel Martí <mvdan@mvdan.cc>
Thu, 28 May 2020 22:48:53 +0000 (22:48 +0000)
fieldInfo.value used to initialize nil anonymous struct fields if they
were encountered. This behavior is wanted when decoding, but not when
encoding. When encoding, the value should never be modified, and these
nil fields should be skipped entirely.

To fix the bug, add a bool argument to the function which tells the
code whether we are encoding or decoding.

Finally, add a couple of tests to cover the edge cases pointed out in
the original issue.

Fixes #27240.

Change-Id: Ic97ae4bfe5f2062c8518e03d1dec07c3875e18f6
Reviewed-on: https://go-review.googlesource.com/c/go/+/196809
Run-TryBot: Daniel Martí <mvdan@mvdan.cc>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
src/encoding/xml/marshal.go
src/encoding/xml/marshal_test.go
src/encoding/xml/read.go
src/encoding/xml/typeinfo.go

index 2440a51e205b70b0771366533e01619baae7f909..d8a04a95a2520fa39b8028dd97f43bef70804aba 100644 (file)
@@ -482,8 +482,11 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
                xmlname := tinfo.xmlname
                if xmlname.name != "" {
                        start.Name.Space, start.Name.Local = xmlname.xmlns, xmlname.name
-               } else if v, ok := xmlname.value(val).Interface().(Name); ok && v.Local != "" {
-                       start.Name = v
+               } else {
+                       fv := xmlname.value(val, dontInitNilPointers)
+                       if v, ok := fv.Interface().(Name); ok && v.Local != "" {
+                               start.Name = v
+                       }
                }
        }
        if start.Name.Local == "" && finfo != nil {
@@ -503,7 +506,7 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
                if finfo.flags&fAttr == 0 {
                        continue
                }
-               fv := finfo.value(val)
+               fv := finfo.value(val, dontInitNilPointers)
 
                if finfo.flags&fOmitEmpty != 0 && isEmptyValue(fv) {
                        continue
@@ -806,7 +809,12 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
                if finfo.flags&fAttr != 0 {
                        continue
                }
-               vf := finfo.value(val)
+               vf := finfo.value(val, dontInitNilPointers)
+               if !vf.IsValid() {
+                       // The field is behind an anonymous struct field that's
+                       // nil. Skip it.
+                       continue
+               }
 
                switch finfo.flags & fMode {
                case fCDATA, fCharData:
index 6085ddbba2300d22305686cf2960248d6a023e70..d2e5137afd7c05a7ad0ad1414e227f168ee4bf54 100644 (file)
@@ -309,6 +309,11 @@ type ChardataEmptyTest struct {
        Contents *string `xml:",chardata"`
 }
 
+type PointerAnonFields struct {
+       *MyInt
+       *NamedType
+}
+
 type MyMarshalerTest struct {
 }
 
@@ -889,6 +894,18 @@ var marshalTests = []struct {
                        `</EmbedA>`,
        },
 
+       // Anonymous struct pointer field which is nil
+       {
+               Value:     &EmbedB{},
+               ExpectXML: `<EmbedB><FieldB></FieldB></EmbedB>`,
+       },
+
+       // Other kinds of nil anonymous fields
+       {
+               Value:     &PointerAnonFields{},
+               ExpectXML: `<PointerAnonFields></PointerAnonFields>`,
+       },
+
        // Test that name casing matters
        {
                Value:     &NameCasing{Xy: "mixed", XY: "upper", XyA: "mixedA", XYA: "upperA"},
index 10a60eed1a903a5e60fbf5151bdef021a66cd9cd..ef5df3f7f6aecca09302cb07050711d8b79b51da 100644 (file)
@@ -435,7 +435,7 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
                                }
                                return UnmarshalError(e)
                        }
-                       fv := finfo.value(sv)
+                       fv := finfo.value(sv, initNilPointers)
                        if _, ok := fv.Interface().(Name); ok {
                                fv.Set(reflect.ValueOf(start.Name))
                        }
@@ -449,7 +449,7 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
                                finfo := &tinfo.fields[i]
                                switch finfo.flags & fMode {
                                case fAttr:
-                                       strv := finfo.value(sv)
+                                       strv := finfo.value(sv, initNilPointers)
                                        if a.Name.Local == finfo.name && (finfo.xmlns == "" || finfo.xmlns == a.Name.Space) {
                                                if err := d.unmarshalAttr(strv, a); err != nil {
                                                        return err
@@ -465,7 +465,7 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
                        }
                        if !handled && any >= 0 {
                                finfo := &tinfo.fields[any]
-                               strv := finfo.value(sv)
+                               strv := finfo.value(sv, initNilPointers)
                                if err := d.unmarshalAttr(strv, a); err != nil {
                                        return err
                                }
@@ -478,22 +478,22 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
                        switch finfo.flags & fMode {
                        case fCDATA, fCharData:
                                if !saveData.IsValid() {
-                                       saveData = finfo.value(sv)
+                                       saveData = finfo.value(sv, initNilPointers)
                                }
 
                        case fComment:
                                if !saveComment.IsValid() {
-                                       saveComment = finfo.value(sv)
+                                       saveComment = finfo.value(sv, initNilPointers)
                                }
 
                        case fAny, fAny | fElement:
                                if !saveAny.IsValid() {
-                                       saveAny = finfo.value(sv)
+                                       saveAny = finfo.value(sv, initNilPointers)
                                }
 
                        case fInnerXML:
                                if !saveXML.IsValid() {
-                                       saveXML = finfo.value(sv)
+                                       saveXML = finfo.value(sv, initNilPointers)
                                        if d.saved == nil {
                                                saveXMLIndex = 0
                                                d.saved = new(bytes.Buffer)
@@ -687,7 +687,7 @@ Loop:
                }
                if len(finfo.parents) == len(parents) && finfo.name == start.Name.Local {
                        // It's a perfect match, unmarshal the field.
-                       return true, d.unmarshal(finfo.value(sv), start)
+                       return true, d.unmarshal(finfo.value(sv, initNilPointers), start)
                }
                if len(finfo.parents) > len(parents) && finfo.parents[len(parents)] == start.Name.Local {
                        // It's a prefix for the field. Break and recurse
index 639952c74adda7c13e5d3be3e59219eb79b8f9b0..f30fe58590bd774d6bf5c423506d0c107aae4819 100644 (file)
@@ -344,15 +344,25 @@ func (e *TagPathError) Error() string {
        return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
 }
 
+const (
+       initNilPointers     = true
+       dontInitNilPointers = false
+)
+
 // value returns v's field value corresponding to finfo.
-// It's equivalent to v.FieldByIndex(finfo.idx), but initializes
-// and dereferences pointers as necessary.
-func (finfo *fieldInfo) value(v reflect.Value) reflect.Value {
+// It's equivalent to v.FieldByIndex(finfo.idx), but when passed
+// initNilPointers, it initializes and dereferences pointers as necessary.
+// When passed dontInitNilPointers and a nil pointer is reached, the function
+// returns a zero reflect.Value.
+func (finfo *fieldInfo) value(v reflect.Value, shouldInitNilPointers bool) reflect.Value {
        for i, x := range finfo.idx {
                if i > 0 {
                        t := v.Type()
                        if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct {
                                if v.IsNil() {
+                                       if !shouldInitNilPointers {
+                                               return reflect.Value{}
+                                       }
                                        v.Set(reflect.New(v.Type().Elem()))
                                }
                                v = v.Elem()