"bufio"
"bytes"
"encoding"
+ "errors"
"fmt"
"io"
"reflect"
// Marshal will return an error if asked to marshal a channel, function, or map.
func Marshal(v any) ([]byte, error) {
var b bytes.Buffer
- if err := NewEncoder(&b).Encode(v); err != nil {
+ enc := NewEncoder(&b)
+ if err := enc.Encode(v); err != nil {
+ return nil, err
+ }
+ if err := enc.Close(); err != nil {
return nil, err
}
return b.Bytes(), nil
if err := enc.Encode(v); err != nil {
return nil, err
}
+ if err := enc.Close(); err != nil {
+ return nil, err
+ }
return b.Bytes(), nil
}
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
- e := &Encoder{printer{Writer: bufio.NewWriter(w)}}
+ e := &Encoder{printer{w: bufio.NewWriter(w)}}
e.p.encoder = e
return e
}
if err != nil {
return err
}
- return enc.p.Flush()
+ return enc.p.w.Flush()
}
// EncodeElement writes the XML encoding of v to the stream,
if err != nil {
return err
}
- return enc.p.Flush()
+ return enc.p.w.Flush()
}
var (
case ProcInst:
// First token to be encoded which is also a ProcInst with target of xml
// is the xml declaration. The only ProcInst where target of xml is allowed.
- if t.Target == "xml" && p.Buffered() != 0 {
+ if t.Target == "xml" && p.w.Buffered() != 0 {
return fmt.Errorf("xml: EncodeToken of ProcInst xml target only valid for xml declaration, first token encoded")
}
if !isNameString(t.Target) {
// Flush flushes any buffered XML to the underlying writer.
// See the EncodeToken documentation for details about when it is necessary.
func (enc *Encoder) Flush() error {
- return enc.p.Flush()
+ return enc.p.w.Flush()
+}
+
+// Close the Encoder, indicating that no more data will be written. It flushes
+// any buffered XML to the underlying writer and returns an error if the
+// written XML is invalid (e.g. by containing unclosed elements).
+func (enc *Encoder) Close() error {
+ return enc.p.Close()
}
type printer struct {
- *bufio.Writer
+ w *bufio.Writer
encoder *Encoder
seq int
indent string
attrPrefix map[string]string // map name space -> prefix
prefixes []string
tags []Name
+ closed bool
+ err error
}
// createAttrPrefix finds the name space prefix attribute to use for the given name space,
return p.cachedWriteError()
}
+// Write implements io.Writer
+func (p *printer) Write(b []byte) (n int, err error) {
+ if p.closed && p.err == nil {
+ p.err = errors.New("use of closed Encoder")
+ }
+ if p.err == nil {
+ n, p.err = p.w.Write(b)
+ }
+ return n, p.err
+}
+
+// WriteString implements io.StringWriter
+func (p *printer) WriteString(s string) (n int, err error) {
+ if p.closed && p.err == nil {
+ p.err = errors.New("use of closed Encoder")
+ }
+ if p.err == nil {
+ n, p.err = p.w.WriteString(s)
+ }
+ return n, p.err
+}
+
+// WriteByte implements io.ByteWriter
+func (p *printer) WriteByte(c byte) error {
+ if p.closed && p.err == nil {
+ p.err = errors.New("use of closed Encoder")
+ }
+ if p.err == nil {
+ p.err = p.w.WriteByte(c)
+ }
+ return p.err
+}
+
+// Close the Encoder, indicating that no more data will be written. It flushes
+// any buffered XML to the underlying writer and returns an error if the
+// written XML is invalid (e.g. by containing unclosed elements).
+func (p *printer) Close() error {
+ if p.closed {
+ return nil
+ }
+ p.closed = true
+ if err := p.w.Flush(); err != nil {
+ return err
+ }
+ if len(p.tags) > 0 {
+ return fmt.Errorf("unclosed tag <%s>", p.tags[len(p.tags)-1].Local)
+ }
+ return nil
+}
+
// return the bufio Writer's cached write error
func (p *printer) cachedWriteError() error {
_, err := p.Write(nil)
t.Fatalf("unexpected unmarshal result, want %q but got %q", proofXml, anotherXML)
}
}
+
+var closeTests = []struct {
+ desc string
+ toks []Token
+ want string
+ err string
+}{{
+ desc: "unclosed start element",
+ toks: []Token{
+ StartElement{Name{"", "foo"}, nil},
+ },
+ want: `<foo>`,
+ err: "unclosed tag <foo>",
+}, {
+ desc: "closed element",
+ toks: []Token{
+ StartElement{Name{"", "foo"}, nil},
+ EndElement{Name{"", "foo"}},
+ },
+ want: `<foo></foo>`,
+}, {
+ desc: "directive",
+ toks: []Token{
+ Directive("foo"),
+ },
+ want: `<!foo>`,
+}}
+
+func TestClose(t *testing.T) {
+ for _, tt := range closeTests {
+ tt := tt
+ t.Run(tt.desc, func(t *testing.T) {
+ var out strings.Builder
+ enc := NewEncoder(&out)
+ for j, tok := range tt.toks {
+ if err := enc.EncodeToken(tok); err != nil {
+ t.Fatalf("token #%d: %v", j, err)
+ }
+ }
+ err := enc.Close()
+ switch {
+ case tt.err != "" && err == nil:
+ t.Error(" expected error; got none")
+ case tt.err == "" && err != nil:
+ t.Errorf(" got error: %v", err)
+ case tt.err != "" && err != nil && tt.err != err.Error():
+ t.Errorf(" error mismatch; got %v, want %v", err, tt.err)
+ }
+ if got := out.String(); got != tt.want {
+ t.Errorf("\ngot %v\nwant %v", got, tt.want)
+ }
+ t.Log(enc.p.closed)
+ if err := enc.EncodeToken(Directive("foo")); err == nil {
+ t.Errorf("unexpected success when encoding after Close")
+ }
+ })
+ }
+}