From: Russ Cox Date: Wed, 14 Aug 2013 22:52:09 +0000 (-0400) Subject: encoding/xml: support generic encoding interfaces X-Git-Tag: go1.2rc2~564 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=904e11361581a5e0f3ffa3576489fe994bfeff6a;p=gostls13.git encoding/xml: support generic encoding interfaces Remove custom support for time.Time. No new tests: the tests for the time.Time special case now test the general case. R=golang-dev, bradfitz CC=golang-dev https://golang.org/cl/12751045 --- diff --git a/src/pkg/encoding/xml/marshal.go b/src/pkg/encoding/xml/marshal.go index 68efbcabb9..a6ee5d5128 100644 --- a/src/pkg/encoding/xml/marshal.go +++ b/src/pkg/encoding/xml/marshal.go @@ -7,12 +7,12 @@ package xml import ( "bufio" "bytes" + "encoding" "fmt" "io" "reflect" "strconv" "strings" - "time" ) const ( @@ -319,6 +319,7 @@ func (p *printer) popPrefix() { var ( marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() marshalerAttrType = reflect.TypeOf((*MarshalerAttr)(nil)).Elem() + textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() ) // marshalValue writes one or more XML elements representing val. @@ -348,14 +349,25 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat } // Check for marshaler. - if typ.Name() != "" && val.CanAddr() { + if val.CanInterface() && typ.Implements(marshalerType) { + return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate)) + } + if val.CanAddr() { pv := val.Addr() if pv.CanInterface() && pv.Type().Implements(marshalerType) { - return p.marshalInterface(pv.Interface().(Marshaler), pv.Type(), finfo, startTemplate) + return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate)) } } - if val.CanInterface() && typ.Implements(marshalerType) { - return p.marshalInterface(val.Interface().(Marshaler), typ, finfo, startTemplate) + + // Check for text marshaler. + if val.CanInterface() && typ.Implements(textMarshalerType) { + return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate)) + } + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { + return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate)) + } } // Slices and arrays iterate over the elements. They do not have an enclosing tag. @@ -416,6 +428,21 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat continue } + if fv.Kind() == reflect.Interface && fv.IsNil() { + continue + } + + if fv.CanInterface() && fv.Type().Implements(marshalerAttrType) { + attr, err := fv.Interface().(MarshalerAttr).MarshalXMLAttr(name) + if err != nil { + return err + } + if attr.Name.Local != "" { + start.Attr = append(start.Attr, attr) + } + continue + } + if fv.CanAddr() { pv := fv.Addr() if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) { @@ -430,20 +457,27 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat } } - if fv.CanInterface() && fv.Type().Implements(marshalerAttrType) { - if fv.Kind() == reflect.Interface && fv.IsNil() { - continue - } - attr, err := fv.Interface().(MarshalerAttr).MarshalXMLAttr(name) + if fv.CanInterface() && fv.Type().Implements(textMarshalerType) { + text, err := fv.Interface().(encoding.TextMarshaler).MarshalText() if err != nil { return err } - if attr.Name.Local != "" { - start.Attr = append(start.Attr, attr) - } + start.Attr = append(start.Attr, Attr{name, string(text)}) continue } + if fv.CanAddr() { + pv := fv.Addr() + if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { + text, err := pv.Interface().(encoding.TextMarshaler).MarshalText() + if err != nil { + return err + } + start.Attr = append(start.Attr, Attr{name, string(text)}) + continue + } + } + // Dereference or skip nil pointer, interface values. switch fv.Kind() { case reflect.Ptr, reflect.Interface: @@ -490,10 +524,10 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat return p.cachedWriteError() } -// marshalInterface marshals a Marshaler interface value. -func (p *printer) marshalInterface(val Marshaler, typ reflect.Type, finfo *fieldInfo, startTemplate *StartElement) error { +// defaultStart returns the default start element to use, +// given the reflect type, field info, and start template. +func defaultStart(typ reflect.Type, finfo *fieldInfo, startTemplate *StartElement) StartElement { var start StartElement - // Precedence for the XML element name is as above, // except that we do not look inside structs for the first field. if startTemplate != nil { @@ -509,7 +543,11 @@ func (p *printer) marshalInterface(val Marshaler, typ reflect.Type, finfo *field // since it has the Marshaler methods. start.Name.Local = typ.Elem().Name() } + return start +} +// marshalInterface marshals a Marshaler interface value. +func (p *printer) marshalInterface(val Marshaler, start StartElement) error { // Push a marker onto the tag stack so that MarshalXML // cannot close the XML tags that it did not open. p.tags = append(p.tags, Name{}) @@ -528,6 +566,19 @@ func (p *printer) marshalInterface(val Marshaler, typ reflect.Type, finfo *field return nil } +// marshalTextInterface marshals a TextMarshaler interface value. +func (p *printer) marshalTextInterface(val encoding.TextMarshaler, start StartElement) error { + if err := p.writeStart(&start); err != nil { + return err + } + text, err := val.MarshalText() + if err != nil { + return err + } + EscapeText(p, text) + return p.writeEnd(start.Name) +} + // writeStart writes the given start element. func (p *printer) writeStart(start *StartElement) error { if start.Name.Local == "" { @@ -591,13 +642,7 @@ func (p *printer) writeEnd(name Name) error { return nil } -var timeType = reflect.TypeOf(time.Time{}) - func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) (string, []byte, error) { - // Normally we don't see structs, but this can happen for an attribute. - if val.Type() == timeType { - return val.Interface().(time.Time).Format(time.RFC3339Nano), nil, nil - } switch val.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return strconv.FormatInt(val.Int(), 10), nil, nil @@ -629,10 +674,6 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) (string, [] var ddBytes = []byte("--") func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { - if val.Type() == timeType { - _, err := p.WriteString(val.Interface().(time.Time).Format(time.RFC3339Nano)) - return err - } s := parentStack{p: p} for i := range tinfo.fields { finfo := &tinfo.fields[i] @@ -651,6 +692,25 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { switch finfo.flags & fMode { case fCharData: + if vf.CanInterface() && vf.Type().Implements(textMarshalerType) { + data, err := vf.Interface().(encoding.TextMarshaler).MarshalText() + if err != nil { + return err + } + Escape(p, data) + continue + } + if vf.CanAddr() { + pv := vf.Addr() + if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { + data, err := pv.Interface().(encoding.TextMarshaler).MarshalText() + if err != nil { + return err + } + Escape(p, data) + continue + } + } var scratch [64]byte switch vf.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -671,10 +731,6 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { return err } } - case reflect.Struct: - if vf.Type() == timeType { - Escape(p, []byte(vf.Interface().(time.Time).Format(time.RFC3339Nano))) - } } continue diff --git a/src/pkg/encoding/xml/read.go b/src/pkg/encoding/xml/read.go index 698bf1a22e..da7ad3baed 100644 --- a/src/pkg/encoding/xml/read.go +++ b/src/pkg/encoding/xml/read.go @@ -6,12 +6,12 @@ package xml import ( "bytes" + "encoding" "errors" "fmt" "reflect" "strconv" "strings" - "time" ) // BUG(rsc): Mapping between XML elements and data structures is inherently flawed: @@ -178,8 +178,7 @@ func receiverType(val interface{}) string { return "(" + t.String() + ")" } -// unmarshalInterface unmarshals a single XML element into val, -// which is known to implement Unmarshaler. +// unmarshalInterface unmarshals a single XML element into val. // start is the opening tag of the element. func (p *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error { // Record that decoder must stop at end tag corresponding to start. @@ -200,6 +199,31 @@ func (p *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error return nil } +// unmarshalTextInterface unmarshals a single XML element into val. +// The chardata contained in the element (but not its children) +// is passed to the text unmarshaler. +func (p *Decoder) unmarshalTextInterface(val encoding.TextUnmarshaler, start *StartElement) error { + var buf []byte + depth := 1 + for depth > 0 { + t, err := p.Token() + if err != nil { + return err + } + switch t := t.(type) { + case CharData: + if depth == 1 { + buf = append(buf, t...) + } + case StartElement: + depth++ + case EndElement: + depth-- + } + } + return val.UnmarshalText(buf) +} + // unmarshalAttr unmarshals a single XML attribute into val. func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error { if val.Kind() == reflect.Ptr { @@ -221,7 +245,18 @@ func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error { } } - // TODO: Check for and use encoding.TextUnmarshaler. + // Not an UnmarshalerAttr; try encoding.TextUnmarshaler. + if val.CanInterface() && val.Type().Implements(textUnmarshalerType) { + // This is an unmarshaler with a non-pointer receiver, + // so it's likely to be incorrect, but we do what we're told. + return val.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value)) + } + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) { + return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value)) + } + } copyValue(val, []byte(attr.Value)) return nil @@ -230,6 +265,7 @@ func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error { var ( unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() unmarshalerAttrType = reflect.TypeOf((*UnmarshalerAttr)(nil)).Elem() + textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() ) // Unmarshal a single XML element into val. @@ -268,7 +304,16 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error { } } - // TODO: Check for and use encoding.TextUnmarshaler. + if val.CanInterface() && val.Type().Implements(textUnmarshalerType) { + return p.unmarshalTextInterface(val.Interface().(encoding.TextUnmarshaler), start) + } + + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) { + return p.unmarshalTextInterface(pv.Interface().(encoding.TextUnmarshaler), start) + } + } var ( data []byte @@ -332,10 +377,6 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error { v.Set(reflect.ValueOf(start.Name)) break } - if typ == timeType { - saveData = v - break - } sv = v tinfo, err = getTypeInfo(typ) @@ -464,6 +505,23 @@ Loop: } } + if saveData.IsValid() && saveData.CanInterface() && saveData.Type().Implements(textUnmarshalerType) { + if err := saveData.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil { + return err + } + saveData = reflect.Value{} + } + + if saveData.IsValid() && saveData.CanAddr() { + pv := saveData.Addr() + if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) { + if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil { + return err + } + saveData = reflect.Value{} + } + } + if err := copyValue(saveData, data); err != nil { return err } @@ -486,6 +544,8 @@ Loop: } func copyValue(dst reflect.Value, src []byte) (err error) { + dst0 := dst + if dst.Kind() == reflect.Ptr { if dst.IsNil() { dst.Set(reflect.New(dst.Type().Elem())) @@ -496,9 +556,9 @@ func copyValue(dst reflect.Value, src []byte) (err error) { // Save accumulated data. switch dst.Kind() { case reflect.Invalid: - // Probably a commendst. + // Probably a comment. default: - return errors.New("cannot happen: unknown type " + dst.Type().String()) + return errors.New("cannot unmarshal into " + dst0.Type().String()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: itmp, err := strconv.ParseInt(string(src), 10, dst.Type().Bits()) if err != nil { @@ -531,14 +591,6 @@ func copyValue(dst reflect.Value, src []byte) (err error) { src = []byte{} } dst.SetBytes(src) - case reflect.Struct: - if dst.Type() == timeType { - tv, err := time.Parse(time.RFC3339, string(src)) - if err != nil { - return err - } - dst.Set(reflect.ValueOf(tv)) - } } return nil } diff --git a/src/pkg/encoding/xml/xml.go b/src/pkg/encoding/xml/xml.go index da8eb2e5f9..467c2ae14f 100644 --- a/src/pkg/encoding/xml/xml.go +++ b/src/pkg/encoding/xml/xml.go @@ -67,6 +67,11 @@ func (e StartElement) Copy() StartElement { return e } +// End returns the corresponding XML end element. +func (e StartElement) End() EndElement { + return EndElement{e.Name} +} + // An EndElement represents an XML end element. type EndElement struct { Name Name diff --git a/src/pkg/go/build/deps_test.go b/src/pkg/go/build/deps_test.go index 5e5982422b..1a8564136f 100644 --- a/src/pkg/go/build/deps_test.go +++ b/src/pkg/go/build/deps_test.go @@ -200,7 +200,7 @@ var pkgDeps = map[string][]string{ "encoding/hex": {"L4"}, "encoding/json": {"L4", "encoding"}, "encoding/pem": {"L4"}, - "encoding/xml": {"L4"}, + "encoding/xml": {"L4", "encoding"}, "flag": {"L4", "OS"}, "go/build": {"L4", "OS", "GOPARSER"}, "html": {"L4"},