]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/xml: add (*Encoder).Close
authorAxel Wagner <axel.wagner.hh@googlemail.com>
Thu, 18 Aug 2022 14:18:34 +0000 (16:18 +0200)
committerGopher Robot <gobot@golang.org>
Tue, 23 Aug 2022 18:24:30 +0000 (18:24 +0000)
Flush can not check for unclosed elements, as more data might be encoded
after Flush is called. Close implicitly calls Flush and also checks that
all opened elements are closed as well.

Fixes #53346

Change-Id: I889b9f5ae54e5dfabb9e6948d96c5ed7bc1110f9
Reviewed-on: https://go-review.googlesource.com/c/go/+/424777
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@google.com>
Reviewed-by: David Chase <drchase@google.com>
api/next/53346.txt [new file with mode: 0644]
src/encoding/xml/marshal.go
src/encoding/xml/marshal_test.go

diff --git a/api/next/53346.txt b/api/next/53346.txt
new file mode 100644 (file)
index 0000000..dd39f23
--- /dev/null
@@ -0,0 +1 @@
+pkg encoding/xml, method (*Encoder) Close() error #53346
index 01f673a85165a07ade0e22d75274490961cc77ad..07b6042da87bf929dd2c3256187fa704b05841e8 100644 (file)
@@ -8,6 +8,7 @@ import (
        "bufio"
        "bytes"
        "encoding"
+       "errors"
        "fmt"
        "io"
        "reflect"
@@ -78,7 +79,11 @@ const (
 // Marshal will return an error if asked to marshal a channel, function, or map.
 func Marshal(v any) ([]byte, error) {
        var b bytes.Buffer
-       if err := NewEncoder(&b).Encode(v); err != nil {
+       enc := NewEncoder(&b)
+       if err := enc.Encode(v); err != nil {
+               return nil, err
+       }
+       if err := enc.Close(); err != nil {
                return nil, err
        }
        return b.Bytes(), nil
@@ -129,6 +134,9 @@ func MarshalIndent(v any, prefix, indent string) ([]byte, error) {
        if err := enc.Encode(v); err != nil {
                return nil, err
        }
+       if err := enc.Close(); err != nil {
+               return nil, err
+       }
        return b.Bytes(), nil
 }
 
@@ -139,7 +147,7 @@ type Encoder struct {
 
 // NewEncoder returns a new encoder that writes to w.
 func NewEncoder(w io.Writer) *Encoder {
-       e := &Encoder{printer{Writer: bufio.NewWriter(w)}}
+       e := &Encoder{printer{w: bufio.NewWriter(w)}}
        e.p.encoder = e
        return e
 }
@@ -163,7 +171,7 @@ func (enc *Encoder) Encode(v any) error {
        if err != nil {
                return err
        }
-       return enc.p.Flush()
+       return enc.p.w.Flush()
 }
 
 // EncodeElement writes the XML encoding of v to the stream,
@@ -178,7 +186,7 @@ func (enc *Encoder) EncodeElement(v any, start StartElement) error {
        if err != nil {
                return err
        }
-       return enc.p.Flush()
+       return enc.p.w.Flush()
 }
 
 var (
@@ -224,7 +232,7 @@ func (enc *Encoder) EncodeToken(t Token) error {
        case ProcInst:
                // First token to be encoded which is also a ProcInst with target of xml
                // is the xml declaration. The only ProcInst where target of xml is allowed.
-               if t.Target == "xml" && p.Buffered() != 0 {
+               if t.Target == "xml" && p.w.Buffered() != 0 {
                        return fmt.Errorf("xml: EncodeToken of ProcInst xml target only valid for xml declaration, first token encoded")
                }
                if !isNameString(t.Target) {
@@ -297,11 +305,18 @@ func isValidDirective(dir Directive) bool {
 // Flush flushes any buffered XML to the underlying writer.
 // See the EncodeToken documentation for details about when it is necessary.
 func (enc *Encoder) Flush() error {
-       return enc.p.Flush()
+       return enc.p.w.Flush()
+}
+
+// Close the Encoder, indicating that no more data will be written. It flushes
+// any buffered XML to the underlying writer and returns an error if the
+// written XML is invalid (e.g. by containing unclosed elements).
+func (enc *Encoder) Close() error {
+       return enc.p.Close()
 }
 
 type printer struct {
-       *bufio.Writer
+       w          *bufio.Writer
        encoder    *Encoder
        seq        int
        indent     string
@@ -313,6 +328,8 @@ type printer struct {
        attrPrefix map[string]string // map name space -> prefix
        prefixes   []string
        tags       []Name
+       closed     bool
+       err        error
 }
 
 // createAttrPrefix finds the name space prefix attribute to use for the given name space,
@@ -961,6 +978,56 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
        return p.cachedWriteError()
 }
 
+// Write implements io.Writer
+func (p *printer) Write(b []byte) (n int, err error) {
+       if p.closed && p.err == nil {
+               p.err = errors.New("use of closed Encoder")
+       }
+       if p.err == nil {
+               n, p.err = p.w.Write(b)
+       }
+       return n, p.err
+}
+
+// WriteString implements io.StringWriter
+func (p *printer) WriteString(s string) (n int, err error) {
+       if p.closed && p.err == nil {
+               p.err = errors.New("use of closed Encoder")
+       }
+       if p.err == nil {
+               n, p.err = p.w.WriteString(s)
+       }
+       return n, p.err
+}
+
+// WriteByte implements io.ByteWriter
+func (p *printer) WriteByte(c byte) error {
+       if p.closed && p.err == nil {
+               p.err = errors.New("use of closed Encoder")
+       }
+       if p.err == nil {
+               p.err = p.w.WriteByte(c)
+       }
+       return p.err
+}
+
+// Close the Encoder, indicating that no more data will be written. It flushes
+// any buffered XML to the underlying writer and returns an error if the
+// written XML is invalid (e.g. by containing unclosed elements).
+func (p *printer) Close() error {
+       if p.closed {
+               return nil
+       }
+       p.closed = true
+       if err := p.w.Flush(); err != nil {
+               return err
+       }
+       if len(p.tags) > 0 {
+               return fmt.Errorf("unclosed tag <%s>", p.tags[len(p.tags)-1].Local)
+       }
+       return nil
+}
+
 // return the bufio Writer's cached write error
 func (p *printer) cachedWriteError() error {
        _, err := p.Write(nil)
index 3fe7e2dc001c4ddb3fda2fb6ef1913d42d72b701..774793a6c54ff52d21e00e83afe87cb66db92d17 100644 (file)
@@ -2531,3 +2531,61 @@ func TestMarshalZeroValue(t *testing.T) {
                t.Fatalf("unexpected unmarshal result, want %q but got %q", proofXml, anotherXML)
        }
 }
+
+var closeTests = []struct {
+       desc string
+       toks []Token
+       want string
+       err  string
+}{{
+       desc: "unclosed start element",
+       toks: []Token{
+               StartElement{Name{"", "foo"}, nil},
+       },
+       want: `<foo>`,
+       err:  "unclosed tag <foo>",
+}, {
+       desc: "closed element",
+       toks: []Token{
+               StartElement{Name{"", "foo"}, nil},
+               EndElement{Name{"", "foo"}},
+       },
+       want: `<foo></foo>`,
+}, {
+       desc: "directive",
+       toks: []Token{
+               Directive("foo"),
+       },
+       want: `<!foo>`,
+}}
+
+func TestClose(t *testing.T) {
+       for _, tt := range closeTests {
+               tt := tt
+               t.Run(tt.desc, func(t *testing.T) {
+                       var out strings.Builder
+                       enc := NewEncoder(&out)
+                       for j, tok := range tt.toks {
+                               if err := enc.EncodeToken(tok); err != nil {
+                                       t.Fatalf("token #%d: %v", j, err)
+                               }
+                       }
+                       err := enc.Close()
+                       switch {
+                       case tt.err != "" && err == nil:
+                               t.Error(" expected error; got none")
+                       case tt.err == "" && err != nil:
+                               t.Errorf(" got error: %v", err)
+                       case tt.err != "" && err != nil && tt.err != err.Error():
+                               t.Errorf(" error mismatch; got %v, want %v", err, tt.err)
+                       }
+                       if got := out.String(); got != tt.want {
+                               t.Errorf("\ngot  %v\nwant %v", got, tt.want)
+                       }
+                       t.Log(enc.p.closed)
+                       if err := enc.EncodeToken(Directive("foo")); err == nil {
+                               t.Errorf("unexpected success when encoding after Close")
+                       }
+               })
+       }
+}