]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/xml: expand allowed entity names
authorPatrick Smith <pat42smith@gmail.com>
Mon, 22 Oct 2012 00:33:24 +0000 (20:33 -0400)
committerRuss Cox <rsc@golang.org>
Mon, 22 Oct 2012 00:33:24 +0000 (20:33 -0400)
Previously, multi-byte characters were not allowed. Also certain single-byte
characters, such as '-', were disallowed.
Fixes #3813.

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/6641052

src/pkg/encoding/xml/xml.go
src/pkg/encoding/xml/xml_test.go

index ab853c61a430d89019726d70a979ebecb955ce9a..decb2bec65047091bcafe7e47002c540e914766f 100644 (file)
@@ -181,7 +181,6 @@ type Decoder struct {
        ns        map[string]string
        err       error
        line      int
-       tmp       [32]byte
 }
 
 // NewDecoder creates a new XML parser reading from r.
@@ -877,92 +876,92 @@ Input:
                        // XML in all its glory allows a document to define and use
                        // its own character names with <!ENTITY ...> directives.
                        // Parsers are required to recognize lt, gt, amp, apos, and quot
-                       // even if they have not been declared.  That's all we allow.
-                       var i int
-                       var semicolon bool
-                       var valid bool
-                       for i = 0; i < len(d.tmp); i++ {
-                               var ok bool
-                               d.tmp[i], ok = d.getc()
-                               if !ok {
-                                       if d.err == io.EOF {
-                                               d.err = d.syntaxError("unexpected EOF")
-                                       }
+                       // even if they have not been declared.
+                       before := d.buf.Len()
+                       d.buf.WriteByte('&')
+                       var ok bool
+                       var text string
+                       var haveText bool
+                       if b, ok = d.mustgetc(); !ok {
+                               return nil
+                       }
+                       if b == '#' {
+                               d.buf.WriteByte(b)
+                               if b, ok = d.mustgetc(); !ok {
                                        return nil
                                }
-                               c := d.tmp[i]
-                               if c == ';' {
-                                       semicolon = true
-                                       valid = i > 0
-                                       break
-                               }
-                               if 'a' <= c && c <= 'z' ||
-                                       'A' <= c && c <= 'Z' ||
-                                       '0' <= c && c <= '9' ||
-                                       c == '_' || c == '#' {
-                                       continue
-                               }
-                               d.ungetc(c)
-                               break
-                       }
-                       s := string(d.tmp[0:i])
-                       if !valid {
-                               if !d.Strict {
-                                       b0, b1 = 0, 0
-                                       d.buf.WriteByte('&')
-                                       d.buf.Write(d.tmp[0:i])
-                                       if semicolon {
-                                               d.buf.WriteByte(';')
+                               base := 10
+                               if b == 'x' {
+                                       base = 16
+                                       d.buf.WriteByte(b)
+                                       if b, ok = d.mustgetc(); !ok {
+                                               return nil
                                        }
-                                       continue Input
                                }
-                               semi := ";"
-                               if !semicolon {
-                                       semi = " (no semicolon)"
+                               start := d.buf.Len()
+                               for '0' <= b && b <= '9' ||
+                                       base == 16 && 'a' <= b && b <= 'f' ||
+                                       base == 16 && 'A' <= b && b <= 'F' {
+                                       d.buf.WriteByte(b)
+                                       if b, ok = d.mustgetc(); !ok {
+                                               return nil
+                                       }
                                }
-                               if i < len(d.tmp) {
-                                       d.err = d.syntaxError("invalid character entity &" + s + semi)
+                               if b != ';' {
+                                       d.ungetc(b)
                                } else {
-                                       d.err = d.syntaxError("invalid character entity &" + s + "... too long")
-                               }
-                               return nil
-                       }
-                       var haveText bool
-                       var text string
-                       if i >= 2 && s[0] == '#' {
-                               var n uint64
-                               var err error
-                               if i >= 3 && s[1] == 'x' {
-                                       n, err = strconv.ParseUint(s[2:], 16, 64)
-                               } else {
-                                       n, err = strconv.ParseUint(s[1:], 10, 64)
-                               }
-                               if err == nil && n <= unicode.MaxRune {
-                                       text = string(n)
-                                       haveText = true
+                                       s := string(d.buf.Bytes()[start:])
+                                       d.buf.WriteByte(';')
+                                       n, err := strconv.ParseUint(s, base, 64)
+                                       if err == nil && n <= unicode.MaxRune {
+                                               text = string(n)
+                                               haveText = true
+                                       }
                                }
                        } else {
-                               if r, ok := entity[s]; ok {
-                                       text = string(r)
-                                       haveText = true
-                               } else if d.Entity != nil {
-                                       text, haveText = d.Entity[s]
+                               d.ungetc(b)
+                               if !d.readName() {
+                                       if d.err != nil {
+                                               return nil
+                                       }
+                                       ok = false
                                }
-                       }
-                       if !haveText {
-                               if !d.Strict {
-                                       b0, b1 = 0, 0
-                                       d.buf.WriteByte('&')
-                                       d.buf.Write(d.tmp[0:i])
+                               if b, ok = d.mustgetc(); !ok {
+                                       return nil
+                               }
+                               if b != ';' {
+                                       d.ungetc(b)
+                               } else {
+                                       name := d.buf.Bytes()[before+1:]
                                        d.buf.WriteByte(';')
-                                       continue Input
+                                       if isName(name) {
+                                               s := string(name)
+                                               if r, ok := entity[s]; ok {
+                                                       text = string(r)
+                                                       haveText = true
+                                               } else if d.Entity != nil {
+                                                       text, haveText = d.Entity[s]
+                                               }
+                                       }
                                }
-                               d.err = d.syntaxError("invalid character entity &" + s + ";")
-                               return nil
                        }
-                       d.buf.Write([]byte(text))
-                       b0, b1 = 0, 0
-                       continue Input
+
+                       if haveText {
+                               d.buf.Truncate(before)
+                               d.buf.Write([]byte(text))
+                               b0, b1 = 0, 0
+                               continue Input
+                       }
+                       if !d.Strict {
+                               b0, b1 = 0, 0
+                               continue Input
+                       }
+                       ent := string(d.buf.Bytes()[before])
+                       if ent[len(ent)-1] != ';' {
+                               ent += " (no semicolon)"
+                       }
+                       d.err = d.syntaxError("invalid character entity " + ent)
+                       return nil
                }
 
                // We must rewrite unescaped \r and \r\n into \n.
@@ -1030,18 +1029,34 @@ func (d *Decoder) nsname() (name Name, ok bool) {
 // Do not set d.err if the name is missing (unless unexpected EOF is received):
 // let the caller provide better context.
 func (d *Decoder) name() (s string, ok bool) {
+       d.buf.Reset()
+       if !d.readName() {
+               return "", false
+       }
+
+       // Now we check the characters.
+       s = d.buf.String()
+       if !isName([]byte(s)) {
+               d.err = d.syntaxError("invalid XML name: " + s)
+               return "", false
+       }
+       return s, true
+}
+
+// Read a name and append its bytes to d.buf.
+// The name is delimited by any single-byte character not valid in names.
+// All multi-byte characters are accepted; the caller must check their validity.
+func (d *Decoder) readName() (ok bool) {
        var b byte
        if b, ok = d.mustgetc(); !ok {
                return
        }
-
-       // As a first approximation, we gather the bytes [A-Za-z_:.-\x80-\xFF]*
        if b < utf8.RuneSelf && !isNameByte(b) {
                d.ungetc(b)
-               return "", false
+               return false
        }
-       d.buf.Reset()
        d.buf.WriteByte(b)
+
        for {
                if b, ok = d.mustgetc(); !ok {
                        return
@@ -1052,16 +1067,7 @@ func (d *Decoder) name() (s string, ok bool) {
                }
                d.buf.WriteByte(b)
        }
-
-       // Then we check the characters.
-       s = d.buf.String()
-       for i, c := range s {
-               if !unicode.Is(first, c) && (i == 0 || !unicode.Is(second, c)) {
-                       d.err = d.syntaxError("invalid XML name: " + s)
-                       return "", false
-               }
-       }
-       return s, true
+       return true
 }
 
 func isNameByte(c byte) bool {
@@ -1071,6 +1077,30 @@ func isNameByte(c byte) bool {
                c == '_' || c == ':' || c == '.' || c == '-'
 }
 
+func isName(s []byte) bool {
+       if len(s) == 0 {
+               return false
+       }
+       c, n := utf8.DecodeRune(s)
+       if c == utf8.RuneError && n == 1 {
+               return false
+       }
+       if !unicode.Is(first, c) {
+               return false
+       }
+       for n < len(s) {
+               s = s[n:]
+               c, n = utf8.DecodeRune(s)
+               if c == utf8.RuneError && n == 1 {
+                       return false
+               }
+               if !unicode.Is(first, c) && !unicode.Is(second, c) {
+                       return false
+               }
+       }
+       return true
+}
+
 // These tables were generated by cut and paste from Appendix B of
 // the XML spec at http://www.xml.com/axml/testaxml.htm
 // and then reformatting.  First corresponds to (Letter | '_' | ':')
index 2ad4d4af5df595fee839ca85d0517422a72b3db4..981d3520313d18b01b0e11476748630cdca66215 100644 (file)
@@ -19,6 +19,7 @@ const testInput = `
 <body xmlns:foo="ns1" xmlns="ns2" xmlns:tag="ns3" ` +
        "\r\n\t" + `  >
   <hello lang="en">World &lt;&gt;&apos;&quot; &#x767d;&#40300;翔</hello>
+  <query>&何; &is-it;</query>
   <goodbye />
   <outer foo:attr="value" xmlns:tag="ns4">
     <inner/>
@@ -28,6 +29,8 @@ const testInput = `
   </tag:name>
 </body><!-- missing final newline -->`
 
+var testEntity = map[string]string{"何": "What", "is-it": "is it?"}
+
 var rawTokens = []Token{
        CharData("\n"),
        ProcInst{"xml", []byte(`version="1.0" encoding="UTF-8"`)},
@@ -41,6 +44,10 @@ var rawTokens = []Token{
        CharData("World <>'\" 白鵬翔"),
        EndElement{Name{"", "hello"}},
        CharData("\n  "),
+       StartElement{Name{"", "query"}, []Attr{}},
+       CharData("What is it?"),
+       EndElement{Name{"", "query"}},
+       CharData("\n  "),
        StartElement{Name{"", "goodbye"}, []Attr{}},
        EndElement{Name{"", "goodbye"}},
        CharData("\n  "),
@@ -74,6 +81,10 @@ var cookedTokens = []Token{
        CharData("World <>'\" 白鵬翔"),
        EndElement{Name{"ns2", "hello"}},
        CharData("\n  "),
+       StartElement{Name{"ns2", "query"}, []Attr{}},
+       CharData("What is it?"),
+       EndElement{Name{"ns2", "query"}},
+       CharData("\n  "),
        StartElement{Name{"ns2", "goodbye"}, []Attr{}},
        EndElement{Name{"ns2", "goodbye"}},
        CharData("\n  "),
@@ -156,6 +167,7 @@ var xmlInput = []string{
 
 func TestRawToken(t *testing.T) {
        d := NewDecoder(strings.NewReader(testInput))
+       d.Entity = testEntity
        testRawToken(t, d, rawTokens)
 }
 
@@ -164,8 +176,14 @@ const nonStrictInput = `
 <tag>&unknown;entity</tag>
 <tag>&#123</tag>
 <tag>&#zzz;</tag>
+<tag>&なまえ3;</tag>
+<tag>&lt-gt;</tag>
+<tag>&;</tag>
+<tag>&0a;</tag>
 `
 
+var nonStringEntity = map[string]string{"": "oops!", "0a": "oops!"}
+
 var nonStrictTokens = []Token{
        CharData("\n"),
        StartElement{Name{"", "tag"}, []Attr{}},
@@ -184,6 +202,22 @@ var nonStrictTokens = []Token{
        CharData("&#zzz;"),
        EndElement{Name{"", "tag"}},
        CharData("\n"),
+       StartElement{Name{"", "tag"}, []Attr{}},
+       CharData("&なまえ3;"),
+       EndElement{Name{"", "tag"}},
+       CharData("\n"),
+       StartElement{Name{"", "tag"}, []Attr{}},
+       CharData("&lt-gt;"),
+       EndElement{Name{"", "tag"}},
+       CharData("\n"),
+       StartElement{Name{"", "tag"}, []Attr{}},
+       CharData("&;"),
+       EndElement{Name{"", "tag"}},
+       CharData("\n"),
+       StartElement{Name{"", "tag"}, []Attr{}},
+       CharData("&0a;"),
+       EndElement{Name{"", "tag"}},
+       CharData("\n"),
 }
 
 func TestNonStrictRawToken(t *testing.T) {
@@ -317,6 +351,7 @@ func TestNestedDirectives(t *testing.T) {
 
 func TestToken(t *testing.T) {
        d := NewDecoder(strings.NewReader(testInput))
+       d.Entity = testEntity
 
        for i, want := range cookedTokens {
                have, err := d.Token()