]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/xml: support generic encoding interfaces
authorRuss Cox <rsc@golang.org>
Wed, 14 Aug 2013 22:52:09 +0000 (18:52 -0400)
committerRuss Cox <rsc@golang.org>
Wed, 14 Aug 2013 22:52:09 +0000 (18:52 -0400)
Remove custom support for time.Time.
No new tests: the tests for the time.Time special case
now test the general case.

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/12751045

src/pkg/encoding/xml/marshal.go
src/pkg/encoding/xml/read.go
src/pkg/encoding/xml/xml.go
src/pkg/go/build/deps_test.go

index 68efbcabb90068266c8b6af69386399bdb5b0ed2..a6ee5d51285b729ec307f8fc7e927e5f666b3175 100644 (file)
@@ -7,12 +7,12 @@ package xml
 import (
        "bufio"
        "bytes"
+       "encoding"
        "fmt"
        "io"
        "reflect"
        "strconv"
        "strings"
-       "time"
 )
 
 const (
@@ -319,6 +319,7 @@ func (p *printer) popPrefix() {
 var (
        marshalerType     = reflect.TypeOf((*Marshaler)(nil)).Elem()
        marshalerAttrType = reflect.TypeOf((*MarshalerAttr)(nil)).Elem()
+       textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
 )
 
 // marshalValue writes one or more XML elements representing val.
@@ -348,14 +349,25 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
        }
 
        // Check for marshaler.
-       if typ.Name() != "" && val.CanAddr() {
+       if val.CanInterface() && typ.Implements(marshalerType) {
+               return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate))
+       }
+       if val.CanAddr() {
                pv := val.Addr()
                if pv.CanInterface() && pv.Type().Implements(marshalerType) {
-                       return p.marshalInterface(pv.Interface().(Marshaler), pv.Type(), finfo, startTemplate)
+                       return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate))
                }
        }
-       if val.CanInterface() && typ.Implements(marshalerType) {
-               return p.marshalInterface(val.Interface().(Marshaler), typ, finfo, startTemplate)
+
+       // Check for text marshaler.
+       if val.CanInterface() && typ.Implements(textMarshalerType) {
+               return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate))
+       }
+       if val.CanAddr() {
+               pv := val.Addr()
+               if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
+                       return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate))
+               }
        }
 
        // Slices and arrays iterate over the elements. They do not have an enclosing tag.
@@ -416,6 +428,21 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
                        continue
                }
 
+               if fv.Kind() == reflect.Interface && fv.IsNil() {
+                       continue
+               }
+
+               if fv.CanInterface() && fv.Type().Implements(marshalerAttrType) {
+                       attr, err := fv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
+                       if err != nil {
+                               return err
+                       }
+                       if attr.Name.Local != "" {
+                               start.Attr = append(start.Attr, attr)
+                       }
+                       continue
+               }
+
                if fv.CanAddr() {
                        pv := fv.Addr()
                        if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) {
@@ -430,20 +457,27 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
                        }
                }
 
-               if fv.CanInterface() && fv.Type().Implements(marshalerAttrType) {
-                       if fv.Kind() == reflect.Interface && fv.IsNil() {
-                               continue
-                       }
-                       attr, err := fv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
+               if fv.CanInterface() && fv.Type().Implements(textMarshalerType) {
+                       text, err := fv.Interface().(encoding.TextMarshaler).MarshalText()
                        if err != nil {
                                return err
                        }
-                       if attr.Name.Local != "" {
-                               start.Attr = append(start.Attr, attr)
-                       }
+                       start.Attr = append(start.Attr, Attr{name, string(text)})
                        continue
                }
 
+               if fv.CanAddr() {
+                       pv := fv.Addr()
+                       if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
+                               text, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
+                               if err != nil {
+                                       return err
+                               }
+                               start.Attr = append(start.Attr, Attr{name, string(text)})
+                               continue
+                       }
+               }
+
                // Dereference or skip nil pointer, interface values.
                switch fv.Kind() {
                case reflect.Ptr, reflect.Interface:
@@ -490,10 +524,10 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
        return p.cachedWriteError()
 }
 
-// marshalInterface marshals a Marshaler interface value.
-func (p *printer) marshalInterface(val Marshaler, typ reflect.Type, finfo *fieldInfo, startTemplate *StartElement) error {
+// defaultStart returns the default start element to use,
+// given the reflect type, field info, and start template.
+func defaultStart(typ reflect.Type, finfo *fieldInfo, startTemplate *StartElement) StartElement {
        var start StartElement
-
        // Precedence for the XML element name is as above,
        // except that we do not look inside structs for the first field.
        if startTemplate != nil {
@@ -509,7 +543,11 @@ func (p *printer) marshalInterface(val Marshaler, typ reflect.Type, finfo *field
                // since it has the Marshaler methods.
                start.Name.Local = typ.Elem().Name()
        }
+       return start
+}
 
+// marshalInterface marshals a Marshaler interface value.
+func (p *printer) marshalInterface(val Marshaler, start StartElement) error {
        // Push a marker onto the tag stack so that MarshalXML
        // cannot close the XML tags that it did not open.
        p.tags = append(p.tags, Name{})
@@ -528,6 +566,19 @@ func (p *printer) marshalInterface(val Marshaler, typ reflect.Type, finfo *field
        return nil
 }
 
+// marshalTextInterface marshals a TextMarshaler interface value.
+func (p *printer) marshalTextInterface(val encoding.TextMarshaler, start StartElement) error {
+       if err := p.writeStart(&start); err != nil {
+               return err
+       }
+       text, err := val.MarshalText()
+       if err != nil {
+               return err
+       }
+       EscapeText(p, text)
+       return p.writeEnd(start.Name)
+}
+
 // writeStart writes the given start element.
 func (p *printer) writeStart(start *StartElement) error {
        if start.Name.Local == "" {
@@ -591,13 +642,7 @@ func (p *printer) writeEnd(name Name) error {
        return nil
 }
 
-var timeType = reflect.TypeOf(time.Time{})
-
 func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) (string, []byte, error) {
-       // Normally we don't see structs, but this can happen for an attribute.
-       if val.Type() == timeType {
-               return val.Interface().(time.Time).Format(time.RFC3339Nano), nil, nil
-       }
        switch val.Kind() {
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
                return strconv.FormatInt(val.Int(), 10), nil, nil
@@ -629,10 +674,6 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) (string, []
 var ddBytes = []byte("--")
 
 func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
-       if val.Type() == timeType {
-               _, err := p.WriteString(val.Interface().(time.Time).Format(time.RFC3339Nano))
-               return err
-       }
        s := parentStack{p: p}
        for i := range tinfo.fields {
                finfo := &tinfo.fields[i]
@@ -651,6 +692,25 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
 
                switch finfo.flags & fMode {
                case fCharData:
+                       if vf.CanInterface() && vf.Type().Implements(textMarshalerType) {
+                               data, err := vf.Interface().(encoding.TextMarshaler).MarshalText()
+                               if err != nil {
+                                       return err
+                               }
+                               Escape(p, data)
+                               continue
+                       }
+                       if vf.CanAddr() {
+                               pv := vf.Addr()
+                               if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
+                                       data, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
+                                       if err != nil {
+                                               return err
+                                       }
+                                       Escape(p, data)
+                                       continue
+                               }
+                       }
                        var scratch [64]byte
                        switch vf.Kind() {
                        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@@ -671,10 +731,6 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
                                                return err
                                        }
                                }
-                       case reflect.Struct:
-                               if vf.Type() == timeType {
-                                       Escape(p, []byte(vf.Interface().(time.Time).Format(time.RFC3339Nano)))
-                               }
                        }
                        continue
 
index 698bf1a22efba53a91af4565ccc6ea15c1937962..da7ad3baedc28ee0a755c26a61280aed9309f08c 100644 (file)
@@ -6,12 +6,12 @@ package xml
 
 import (
        "bytes"
+       "encoding"
        "errors"
        "fmt"
        "reflect"
        "strconv"
        "strings"
-       "time"
 )
 
 // BUG(rsc): Mapping between XML elements and data structures is inherently flawed:
@@ -178,8 +178,7 @@ func receiverType(val interface{}) string {
        return "(" + t.String() + ")"
 }
 
-// unmarshalInterface unmarshals a single XML element into val,
-// which is known to implement Unmarshaler.
+// unmarshalInterface unmarshals a single XML element into val.
 // start is the opening tag of the element.
 func (p *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error {
        // Record that decoder must stop at end tag corresponding to start.
@@ -200,6 +199,31 @@ func (p *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error
        return nil
 }
 
+// unmarshalTextInterface unmarshals a single XML element into val.
+// The chardata contained in the element (but not its children)
+// is passed to the text unmarshaler.
+func (p *Decoder) unmarshalTextInterface(val encoding.TextUnmarshaler, start *StartElement) error {
+       var buf []byte
+       depth := 1
+       for depth > 0 {
+               t, err := p.Token()
+               if err != nil {
+                       return err
+               }
+               switch t := t.(type) {
+               case CharData:
+                       if depth == 1 {
+                               buf = append(buf, t...)
+                       }
+               case StartElement:
+                       depth++
+               case EndElement:
+                       depth--
+               }
+       }
+       return val.UnmarshalText(buf)
+}
+
 // unmarshalAttr unmarshals a single XML attribute into val.
 func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
        if val.Kind() == reflect.Ptr {
@@ -221,7 +245,18 @@ func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
                }
        }
 
-       // TODO: Check for and use encoding.TextUnmarshaler.
+       // Not an UnmarshalerAttr; try encoding.TextUnmarshaler.
+       if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
+               // This is an unmarshaler with a non-pointer receiver,
+               // so it's likely to be incorrect, but we do what we're told.
+               return val.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
+       }
+       if val.CanAddr() {
+               pv := val.Addr()
+               if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
+                       return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
+               }
+       }
 
        copyValue(val, []byte(attr.Value))
        return nil
@@ -230,6 +265,7 @@ func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
 var (
        unmarshalerType     = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
        unmarshalerAttrType = reflect.TypeOf((*UnmarshalerAttr)(nil)).Elem()
+       textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
 )
 
 // Unmarshal a single XML element into val.
@@ -268,7 +304,16 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
                }
        }
 
-       // TODO: Check for and use encoding.TextUnmarshaler.
+       if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
+               return p.unmarshalTextInterface(val.Interface().(encoding.TextUnmarshaler), start)
+       }
+
+       if val.CanAddr() {
+               pv := val.Addr()
+               if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
+                       return p.unmarshalTextInterface(pv.Interface().(encoding.TextUnmarshaler), start)
+               }
+       }
 
        var (
                data         []byte
@@ -332,10 +377,6 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
                        v.Set(reflect.ValueOf(start.Name))
                        break
                }
-               if typ == timeType {
-                       saveData = v
-                       break
-               }
 
                sv = v
                tinfo, err = getTypeInfo(typ)
@@ -464,6 +505,23 @@ Loop:
                }
        }
 
+       if saveData.IsValid() && saveData.CanInterface() && saveData.Type().Implements(textUnmarshalerType) {
+               if err := saveData.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
+                       return err
+               }
+               saveData = reflect.Value{}
+       }
+
+       if saveData.IsValid() && saveData.CanAddr() {
+               pv := saveData.Addr()
+               if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
+                       if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
+                               return err
+                       }
+                       saveData = reflect.Value{}
+               }
+       }
+
        if err := copyValue(saveData, data); err != nil {
                return err
        }
@@ -486,6 +544,8 @@ Loop:
 }
 
 func copyValue(dst reflect.Value, src []byte) (err error) {
+       dst0 := dst
+
        if dst.Kind() == reflect.Ptr {
                if dst.IsNil() {
                        dst.Set(reflect.New(dst.Type().Elem()))
@@ -496,9 +556,9 @@ func copyValue(dst reflect.Value, src []byte) (err error) {
        // Save accumulated data.
        switch dst.Kind() {
        case reflect.Invalid:
-               // Probably a commendst.
+               // Probably a comment.
        default:
-               return errors.New("cannot happen: unknown type " + dst.Type().String())
+               return errors.New("cannot unmarshal into " + dst0.Type().String())
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
                itmp, err := strconv.ParseInt(string(src), 10, dst.Type().Bits())
                if err != nil {
@@ -531,14 +591,6 @@ func copyValue(dst reflect.Value, src []byte) (err error) {
                        src = []byte{}
                }
                dst.SetBytes(src)
-       case reflect.Struct:
-               if dst.Type() == timeType {
-                       tv, err := time.Parse(time.RFC3339, string(src))
-                       if err != nil {
-                               return err
-                       }
-                       dst.Set(reflect.ValueOf(tv))
-               }
        }
        return nil
 }
index da8eb2e5f9d7d60582d3dfa9ea9a6de92080a32e..467c2ae14f81b937b1178dffbf5ba39e82936d59 100644 (file)
@@ -67,6 +67,11 @@ func (e StartElement) Copy() StartElement {
        return e
 }
 
+// End returns the corresponding XML end element.
+func (e StartElement) End() EndElement {
+       return EndElement{e.Name}
+}
+
 // An EndElement represents an XML end element.
 type EndElement struct {
        Name Name
index 5e5982422bf2436e30f9c8d957813447d2cd9d16..1a8564136f8a15d6263260bddd8a3cde607044a6 100644 (file)
@@ -200,7 +200,7 @@ var pkgDeps = map[string][]string{
        "encoding/hex":        {"L4"},
        "encoding/json":       {"L4", "encoding"},
        "encoding/pem":        {"L4"},
-       "encoding/xml":        {"L4"},
+       "encoding/xml":        {"L4", "encoding"},
        "flag":                {"L4", "OS"},
        "go/build":            {"L4", "OS", "GOPARSER"},
        "html":                {"L4"},