]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/xml: handle anonymous pointer fields
authorGustavo Niemeyer <gustavo@niemeyer.net>
Thu, 17 May 2012 02:21:31 +0000 (23:21 -0300)
committerGustavo Niemeyer <gustavo@niemeyer.net>
Thu, 17 May 2012 02:21:31 +0000 (23:21 -0300)
This CL makes

    type T struct { *U }

behave in a similar way to:

    type T struct { U }

Fixes #3108.

R=golang-dev, rsc, gustavo
CC=golang-dev
https://golang.org/cl/5694044

src/pkg/encoding/xml/marshal.go
src/pkg/encoding/xml/marshal_test.go
src/pkg/encoding/xml/read.go
src/pkg/encoding/xml/typeinfo.go

index 6c3170bdda3d42074b9721b9ac21ebc401e7f82d..51e1dc8f96461653115f1f2fea45fb7e722268fd 100644 (file)
@@ -57,8 +57,8 @@ const (
 //       if the field value is empty. The empty values are false, 0, any
 //       nil pointer or interface value, and any array, slice, map, or
 //       string of length zero.
-//     - a non-pointer anonymous struct field is handled as if the
-//       fields of its value were part of the outer struct.
+//     - an anonymous struct field is handled as if the fields of its
+//       value were part of the outer struct.
 //
 // If a field uses a tag "a>b>c", then the element c will be nested inside
 // parent elements a and b.  Fields that appear next to each other that name
@@ -164,7 +164,7 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error {
                xmlname := tinfo.xmlname
                if xmlname.name != "" {
                        xmlns, name = xmlname.xmlns, xmlname.name
-               } else if v, ok := val.FieldByIndex(xmlname.idx).Interface().(Name); ok && v.Local != "" {
+               } else if v, ok := xmlname.value(val).Interface().(Name); ok && v.Local != "" {
                        xmlns, name = v.Space, v.Local
                }
        }
@@ -195,7 +195,7 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error {
                if finfo.flags&fAttr == 0 {
                        continue
                }
-               fv := val.FieldByIndex(finfo.idx)
+               fv := finfo.value(val)
                if finfo.flags&fOmitEmpty != 0 && isEmptyValue(fv) {
                        continue
                }
@@ -276,7 +276,7 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
                if finfo.flags&(fAttr|fAny) != 0 {
                        continue
                }
-               vf := val.FieldByIndex(finfo.idx)
+               vf := finfo.value(val)
                switch finfo.flags & fMode {
                case fCharData:
                        switch vf.Kind() {
index b6978a1e65b5a590c6c722153af5d675224386fe..90b4925e7f9cb42a3a95c7b9727ed7fdc7dc7170 100644 (file)
@@ -108,7 +108,7 @@ type EmbedA struct {
 
 type EmbedB struct {
        FieldB string
-       EmbedC
+       *EmbedC
 }
 
 type EmbedC struct {
@@ -493,7 +493,7 @@ var marshalTests = []struct {
                        },
                        EmbedB: EmbedB{
                                FieldB: "A.B.B",
-                               EmbedC: EmbedC{
+                               EmbedC: &EmbedC{
                                        FieldA1: "A.B.C.A1",
                                        FieldA2: "A.B.C.A2",
                                        FieldB:  "", // Shadowed by A.B.B
index c2168242091a493f3c7af28e182c6da2aa32048c..0e6761d66adac7916ea3b7be8fff331142c1c0f6 100644 (file)
@@ -81,8 +81,8 @@ import (
 //      of the above rules and the struct has a field with tag ",any",
 //      unmarshal maps the sub-element to that struct field.
 //
-//   * A non-pointer anonymous struct field is handled as if the
-//      fields of its value were part of the outer struct.
+//   * An anonymous struct field is handled as if the fields of its
+//      value were part of the outer struct.
 //
 //   * A struct field with tag "-" is never unmarshalled into.
 //
@@ -248,7 +248,7 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
                                }
                                return UnmarshalError(e)
                        }
-                       fv := sv.FieldByIndex(finfo.idx)
+                       fv := finfo.value(sv)
                        if _, ok := fv.Interface().(Name); ok {
                                fv.Set(reflect.ValueOf(start.Name))
                        }
@@ -260,7 +260,7 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
                        finfo := &tinfo.fields[i]
                        switch finfo.flags & fMode {
                        case fAttr:
-                               strv := sv.FieldByIndex(finfo.idx)
+                               strv := finfo.value(sv)
                                // Look for attribute.
                                for _, a := range start.Attr {
                                        if a.Name.Local == finfo.name {
@@ -271,22 +271,22 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
 
                        case fCharData:
                                if !saveData.IsValid() {
-                                       saveData = sv.FieldByIndex(finfo.idx)
+                                       saveData = finfo.value(sv)
                                }
 
                        case fComment:
                                if !saveComment.IsValid() {
-                                       saveComment = sv.FieldByIndex(finfo.idx)
+                                       saveComment = finfo.value(sv)
                                }
 
                        case fAny:
                                if !saveAny.IsValid() {
-                                       saveAny = sv.FieldByIndex(finfo.idx)
+                                       saveAny = finfo.value(sv)
                                }
 
                        case fInnerXml:
                                if !saveXML.IsValid() {
-                                       saveXML = sv.FieldByIndex(finfo.idx)
+                                       saveXML = finfo.value(sv)
                                        if p.saved == nil {
                                                saveXMLIndex = 0
                                                p.saved = new(bytes.Buffer)
@@ -461,7 +461,7 @@ Loop:
                }
                if len(finfo.parents) == len(parents) && finfo.name == start.Name.Local {
                        // It's a perfect match, unmarshal the field.
-                       return true, p.unmarshal(sv.FieldByIndex(finfo.idx), start)
+                       return true, p.unmarshal(finfo.value(sv), start)
                }
                if len(finfo.parents) > len(parents) && finfo.parents[len(parents)] == start.Name.Local {
                        // It's a prefix for the field. Break and recurse
index 8e2e4508b104880af8cf0ab1215d9947eda8b136..970d1701932d4a1eea02cf35ddefb914eb413a7c 100644 (file)
@@ -66,10 +66,14 @@ func getTypeInfo(typ reflect.Type) (*typeInfo, error) {
 
                        // For embedded structs, embed its fields.
                        if f.Anonymous {
-                               if f.Type.Kind() != reflect.Struct {
+                               t := f.Type
+                               if t.Kind() == reflect.Ptr {
+                                       t = t.Elem()
+                               }
+                               if t.Kind() != reflect.Struct {
                                        continue
                                }
-                               inner, err := getTypeInfo(f.Type)
+                               inner, err := getTypeInfo(t)
                                if err != nil {
                                        return nil, err
                                }
@@ -327,3 +331,22 @@ type TagPathError struct {
 func (e *TagPathError) Error() string {
        return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
 }
+
+// value returns v's field value corresponding to finfo.
+// It's equivalent to v.FieldByIndex(finfo.idx), but initializes
+// and dereferences pointers as necessary.
+func (finfo *fieldInfo) value(v reflect.Value) reflect.Value {
+       for i, x := range finfo.idx {
+               if i > 0 {
+                       t := v.Type()
+                       if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct {
+                               if v.IsNil() {
+                                       v.Set(reflect.New(v.Type().Elem()))
+                               }
+                               v = v.Elem()
+                       }
+               }
+               v = v.Field(x)
+       }
+       return v
+}