]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/xml: add, support Unmarshaler interface
authorRuss Cox <rsc@golang.org>
Wed, 14 Aug 2013 18:57:45 +0000 (14:57 -0400)
committerRuss Cox <rsc@golang.org>
Wed, 14 Aug 2013 18:57:45 +0000 (14:57 -0400)
See golang.org/s/go12xml for design.

R=golang-dev, dominik.honnef, dan.kortschak
CC=golang-dev
https://golang.org/cl/12556043

src/pkg/encoding/xml/read.go
src/pkg/encoding/xml/read_test.go
src/pkg/encoding/xml/xml.go

index f960f5649caa7acef81c568b37a4b21450e99bd8..698bf1a22efba53a91af4565ccc6ea15c1937962 100644 (file)
@@ -7,6 +7,7 @@ package xml
 import (
        "bytes"
        "errors"
+       "fmt"
        "reflect"
        "strconv"
        "strings"
@@ -137,6 +138,100 @@ type UnmarshalError string
 
 func (e UnmarshalError) Error() string { return string(e) }
 
+// Unmarshaler is the interface implemented by objects that can unmarshal
+// an XML element description of themselves.
+//
+// UnmarshalXML decodes a single XML element
+// beginning with the given start element.
+// If it returns an error, the outer call to Unmarshal stops and
+// returns that error.
+// UnmarshalXML must consume exactly one XML element.
+// One common implementation strategy is to unmarshal into
+// a separate value with a layout matching the expected XML
+// using d.DecodeElement,  and then to copy the data from
+// that value into the receiver.
+// Another common strategy is to use d.Token to process the
+// XML object one token at a time.
+// UnmarshalXML may not use d.RawToken.
+type Unmarshaler interface {
+       UnmarshalXML(d *Decoder, start StartElement) error
+}
+
+// UnmarshalerAttr is the interface implemented by objects that can unmarshal
+// an XML attribute description of themselves.
+//
+// UnmarshalXMLAttr decodes a single XML attribute.
+// If it returns an error, the outer call to Unmarshal stops and
+// returns that error.
+// UnmarshalXMLAttr is used only for struct fields with the
+// "attr" option in the field tag.
+type UnmarshalerAttr interface {
+       UnmarshalXMLAttr(attr Attr) error
+}
+
+// receiverType returns the receiver type to use in an expression like "%s.MethodName".
+func receiverType(val interface{}) string {
+       t := reflect.TypeOf(val)
+       if t.Name() != "" {
+               return t.String()
+       }
+       return "(" + t.String() + ")"
+}
+
+// unmarshalInterface unmarshals a single XML element into val,
+// which is known to implement Unmarshaler.
+// start is the opening tag of the element.
+func (p *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error {
+       // Record that decoder must stop at end tag corresponding to start.
+       p.pushEOF()
+
+       p.unmarshalDepth++
+       err := val.UnmarshalXML(p, *start)
+       p.unmarshalDepth--
+       if err != nil {
+               p.popEOF()
+               return err
+       }
+
+       if !p.popEOF() {
+               return fmt.Errorf("xml: %s.UnmarshalXML did not consume entire <%s> element", receiverType(val), start.Name.Local)
+       }
+
+       return nil
+}
+
+// unmarshalAttr unmarshals a single XML attribute into val.
+func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
+       if val.Kind() == reflect.Ptr {
+               if val.IsNil() {
+                       val.Set(reflect.New(val.Type().Elem()))
+               }
+               val = val.Elem()
+       }
+
+       if val.CanInterface() && val.Type().Implements(unmarshalerAttrType) {
+               // This is an unmarshaler with a non-pointer receiver,
+               // so it's likely to be incorrect, but we do what we're told.
+               return val.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr)
+       }
+       if val.CanAddr() {
+               pv := val.Addr()
+               if pv.CanInterface() && pv.Type().Implements(unmarshalerAttrType) {
+                       return pv.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr)
+               }
+       }
+
+       // TODO: Check for and use encoding.TextUnmarshaler.
+
+       copyValue(val, []byte(attr.Value))
+       return nil
+}
+
+var (
+       unmarshalerType     = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
+       unmarshalerAttrType = reflect.TypeOf((*UnmarshalerAttr)(nil)).Elem()
+)
+
 // Unmarshal a single XML element into val.
 func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
        // Find start element if we need it.
@@ -153,13 +248,28 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
                }
        }
 
-       if pv := val; pv.Kind() == reflect.Ptr {
-               if pv.IsNil() {
-                       pv.Set(reflect.New(pv.Type().Elem()))
+       if val.Kind() == reflect.Ptr {
+               if val.IsNil() {
+                       val.Set(reflect.New(val.Type().Elem()))
                }
-               val = pv.Elem()
+               val = val.Elem()
+       }
+
+       if val.CanInterface() && val.Type().Implements(unmarshalerType) {
+               // This is an unmarshaler with a non-pointer receiver,
+               // so it's likely to be incorrect, but we do what we're told.
+               return p.unmarshalInterface(val.Interface().(Unmarshaler), start)
        }
 
+       if val.CanAddr() {
+               pv := val.Addr()
+               if pv.CanInterface() && pv.Type().Implements(unmarshalerType) {
+                       return p.unmarshalInterface(pv.Interface().(Unmarshaler), start)
+               }
+       }
+
+       // TODO: Check for and use encoding.TextUnmarshaler.
+
        var (
                data         []byte
                saveData     reflect.Value
@@ -264,7 +374,9 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
                                // Look for attribute.
                                for _, a := range start.Attr {
                                        if a.Name.Local == finfo.name && (finfo.xmlns == "" || finfo.xmlns == a.Name.Space) {
-                                               copyValue(strv, []byte(a.Value))
+                                               if err := p.unmarshalAttr(strv, a); err != nil {
+                                                       return err
+                                               }
                                                break
                                        }
                                }
index 7d28c5d7d6ca41cbe360e113772150e6462c7528..1404c900f50506977c98c3f6d80f02cd0c841fa3 100644 (file)
@@ -5,6 +5,7 @@
 package xml
 
 import (
+       "io"
        "reflect"
        "strings"
        "testing"
@@ -621,3 +622,66 @@ func TestMarshalNSAttr(t *testing.T) {
                t.Errorf("Unmarshal = %q, want %q", dst, src)
        }
 }
+
+type MyCharData struct {
+       body string
+}
+
+func (m *MyCharData) UnmarshalXML(d *Decoder, start StartElement) error {
+       for {
+               t, err := d.Token()
+               if err == io.EOF { // found end of element
+                       break
+               }
+               if err != nil {
+                       return err
+               }
+               if char, ok := t.(CharData); ok {
+                       m.body += string(char)
+               }
+       }
+       return nil
+}
+
+var _ Unmarshaler = (*MyCharData)(nil)
+
+func (m *MyCharData) UnmarshalXMLAttr(attr Attr) error {
+       panic("must not call")
+}
+
+type MyAttr struct {
+       attr string
+}
+
+func (m *MyAttr) UnmarshalXMLAttr(attr Attr) error {
+       m.attr = attr.Value
+       return nil
+}
+
+var _ UnmarshalerAttr = (*MyAttr)(nil)
+
+type MyStruct struct {
+       Data *MyCharData
+       Attr *MyAttr `xml:",attr"`
+
+       Data2 MyCharData
+       Attr2 MyAttr `xml:",attr"`
+}
+
+func TestUnmarshaler(t *testing.T) {
+       xml := `<?xml version="1.0" encoding="utf-8"?>
+               <MyStruct Attr="attr1" Attr2="attr2">
+               <Data>hello <!-- comment -->world</Data>
+               <Data2>howdy <!-- comment -->world</Data2>
+               </MyStruct>
+       `
+
+       var m MyStruct
+       if err := Unmarshal([]byte(xml), &m); err != nil {
+               t.Fatal(err)
+       }
+
+       if m.Data == nil || m.Attr == nil || m.Data.body != "hello world" || m.Attr.attr != "attr1" || m.Data2.body != "howdy world" || m.Attr2.attr != "attr2" {
+               t.Errorf("m=%#+v\n", m)
+       }
+}
index 2f36604797df01a7257bbfc6e41b1d592f3ea8f4..da8eb2e5f9d7d60582d3dfa9ea9a6de92080a32e 100644 (file)
@@ -16,6 +16,7 @@ package xml
 import (
        "bufio"
        "bytes"
+       "errors"
        "fmt"
        "io"
        "strconv"
@@ -174,18 +175,19 @@ type Decoder struct {
        // the attribute xmlns="DefaultSpace".
        DefaultSpace string
 
-       r         io.ByteReader
-       buf       bytes.Buffer
-       saved     *bytes.Buffer
-       stk       *stack
-       free      *stack
-       needClose bool
-       toClose   Name
-       nextToken Token
-       nextByte  int
-       ns        map[string]string
-       err       error
-       line      int
+       r              io.ByteReader
+       buf            bytes.Buffer
+       saved          *bytes.Buffer
+       stk            *stack
+       free           *stack
+       needClose      bool
+       toClose        Name
+       nextToken      Token
+       nextByte       int
+       ns             map[string]string
+       err            error
+       line           int
+       unmarshalDepth int
 }
 
 // NewDecoder creates a new XML parser reading from r.
@@ -223,10 +225,14 @@ func NewDecoder(r io.Reader) *Decoder {
 // If Token encounters an unrecognized name space prefix,
 // it uses the prefix as the Space rather than report an error.
 func (d *Decoder) Token() (t Token, err error) {
+       if d.stk != nil && d.stk.kind == stkEOF {
+               err = io.EOF
+               return
+       }
        if d.nextToken != nil {
                t = d.nextToken
                d.nextToken = nil
-       } else if t, err = d.RawToken(); err != nil {
+       } else if t, err = d.rawToken(); err != nil {
                return
        }
 
@@ -322,6 +328,7 @@ type stack struct {
 const (
        stkStart = iota
        stkNs
+       stkEOF
 )
 
 func (d *Decoder) push(kind int) *stack {
@@ -347,6 +354,43 @@ func (d *Decoder) pop() *stack {
        return s
 }
 
+// Record that after the current element is finished
+// (that element is already pushed on the stack)
+// Token should return EOF until popEOF is called.
+func (d *Decoder) pushEOF() {
+       // Walk down stack to find Start.
+       // It might not be the top, because there might be stkNs
+       // entries above it.
+       start := d.stk
+       for start.kind != stkStart {
+               start = start.next
+       }
+       // The stkNs entries below a start are associated with that
+       // element too; skip over them.
+       for start.next != nil && start.next.kind == stkNs {
+               start = start.next
+       }
+       s := d.free
+       if s != nil {
+               d.free = s.next
+       } else {
+               s = new(stack)
+       }
+       s.kind = stkEOF
+       s.next = start.next
+       start.next = s
+}
+
+// Undo a pushEOF.
+// The element must have been finished, so the EOF should be at the top of the stack.
+func (d *Decoder) popEOF() bool {
+       if d.stk == nil || d.stk.kind != stkEOF {
+               return false
+       }
+       d.pop()
+       return true
+}
+
 // Record that we are starting an element with the given name.
 func (d *Decoder) pushElement(name Name) {
        s := d.push(stkStart)
@@ -395,9 +439,9 @@ func (d *Decoder) popElement(t *EndElement) bool {
                return false
        }
 
-       // Pop stack until a Start is on the top, undoing the
+       // Pop stack until a Start or EOF is on the top, undoing the
        // translations that were associated with the element we just closed.
-       for d.stk != nil && d.stk.kind != stkStart {
+       for d.stk != nil && d.stk.kind != stkStart && d.stk.kind != stkEOF {
                s := d.pop()
                if s.ok {
                        d.ns[s.name.Local] = s.name.Space
@@ -429,10 +473,19 @@ func (d *Decoder) autoClose(t Token) (Token, bool) {
        return nil, false
 }
 
+var errRawToken = errors.New("xml: cannot use RawToken from UnmarshalXML method")
+
 // RawToken is like Token but does not verify that
 // start and end elements match and does not translate
 // name space prefixes to their corresponding URLs.
 func (d *Decoder) RawToken() (Token, error) {
+       if d.unmarshalDepth > 0 {
+               return nil, errRawToken
+       }
+       return d.rawToken()
+}
+
+func (d *Decoder) rawToken() (Token, error) {
        if d.err != nil {
                return nil, d.err
        }
@@ -484,8 +537,7 @@ func (d *Decoder) RawToken() (Token, error) {
 
        case '?':
                // <?: Processing instruction.
-               // TODO(rsc): Should parse the <?xml declaration to make sure
-               // the version is 1.0 and the encoding is UTF-8.
+               // TODO(rsc): Should parse the <?xml declaration to make sure the version is 1.0.
                var target string
                if target, ok = d.name(); !ok {
                        if d.err == nil {
@@ -1112,6 +1164,30 @@ func isName(s []byte) bool {
        return true
 }
 
+func isNameString(s string) bool {
+       if len(s) == 0 {
+               return false
+       }
+       c, n := utf8.DecodeRuneInString(s)
+       if c == utf8.RuneError && n == 1 {
+               return false
+       }
+       if !unicode.Is(first, c) {
+               return false
+       }
+       for n < len(s) {
+               s = s[n:]
+               c, n = utf8.DecodeRuneInString(s)
+               if c == utf8.RuneError && n == 1 {
+                       return false
+               }
+               if !unicode.Is(first, c) && !unicode.Is(second, c) {
+                       return false
+               }
+       }
+       return true
+}
+
 // These tables were generated by cut and paste from Appendix B of
 // the XML spec at http://www.xml.com/axml/testaxml.htm
 // and then reformatting.  First corresponds to (Letter | '_' | ':')
@@ -1778,6 +1854,45 @@ func EscapeText(w io.Writer, s []byte) error {
        return nil
 }
 
+// EscapeString writes to p the properly escaped XML equivalent
+// of the plain text data s.
+func (p *printer) EscapeString(s string) {
+       var esc []byte
+       last := 0
+       for i := 0; i < len(s); {
+               r, width := utf8.DecodeRuneInString(s[i:])
+               i += width
+               switch r {
+               case '"':
+                       esc = esc_quot
+               case '\'':
+                       esc = esc_apos
+               case '&':
+                       esc = esc_amp
+               case '<':
+                       esc = esc_lt
+               case '>':
+                       esc = esc_gt
+               case '\t':
+                       esc = esc_tab
+               case '\n':
+                       esc = esc_nl
+               case '\r':
+                       esc = esc_cr
+               default:
+                       if !isInCharacterRange(r) || (r == 0xFFFD && width == 1) {
+                               esc = esc_fffd
+                               break
+                       }
+                       continue
+               }
+               p.WriteString(s[last : i-width])
+               p.Write(esc)
+               last = i
+       }
+       p.WriteString(s[last:])
+}
+
 // Escape is like EscapeText but omits the error return value.
 // It is provided for backwards compatibility with Go 1.0.
 // Code targeting Go 1.1 or later should use EscapeText.