]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/json: add Encoder.DisableHTMLEscaping
authorCaleb Spare <cespare@gmail.com>
Sun, 10 Apr 2016 04:18:22 +0000 (21:18 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 22 Apr 2016 21:35:56 +0000 (21:35 +0000)
This provides a way to disable the escaping of <, >, and & in JSON
strings.

Fixes #14749.

Change-Id: I1afeb0244455fc8b06c6cce920444532f229555b
Reviewed-on: https://go-review.googlesource.com/21796
Run-TryBot: Caleb Spare <cespare@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/encoding/json/encode.go
src/encoding/json/encode_test.go
src/encoding/json/stream.go
src/encoding/json/stream_test.go

index 0088f25ab82af3a489f99d28ef0cb7526513ad25..d8c779869b33f89be3d6e5964ce06763a7aba930 100644 (file)
@@ -49,6 +49,7 @@ import (
 // The angle brackets "<" and ">" are escaped to "\u003c" and "\u003e"
 // to keep some browsers from misinterpreting JSON output as HTML.
 // Ampersand "&" is also escaped to "\u0026" for the same reason.
+// This escaping can be disabled using an Encoder with DisableHTMLEscaping.
 //
 // Array and slice values encode as JSON arrays, except that
 // []byte encodes as a base64-encoded string, and a nil slice
@@ -136,7 +137,7 @@ import (
 //
 func Marshal(v interface{}) ([]byte, error) {
        e := &encodeState{}
-       err := e.marshal(v)
+       err := e.marshal(v, encOpts{escapeHTML: true})
        if err != nil {
                return nil, err
        }
@@ -259,7 +260,7 @@ func newEncodeState() *encodeState {
        return new(encodeState)
 }
 
-func (e *encodeState) marshal(v interface{}) (err error) {
+func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) {
        defer func() {
                if r := recover(); r != nil {
                        if _, ok := r.(runtime.Error); ok {
@@ -271,7 +272,7 @@ func (e *encodeState) marshal(v interface{}) (err error) {
                        err = r.(error)
                }
        }()
-       e.reflectValue(reflect.ValueOf(v))
+       e.reflectValue(reflect.ValueOf(v), opts)
        return nil
 }
 
@@ -297,11 +298,18 @@ func isEmptyValue(v reflect.Value) bool {
        return false
 }
 
-func (e *encodeState) reflectValue(v reflect.Value) {
-       valueEncoder(v)(e, v, false)
+func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) {
+       valueEncoder(v)(e, v, opts)
 }
 
-type encoderFunc func(e *encodeState, v reflect.Value, quoted bool)
+type encOpts struct {
+       // quoted causes primitive fields to be encoded inside JSON strings.
+       quoted bool
+       // escapeHTML causes '<', '>', and '&' to be escaped in JSON strings.
+       escapeHTML bool
+}
+
+type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts)
 
 var encoderCache struct {
        sync.RWMutex
@@ -333,9 +341,9 @@ func typeEncoder(t reflect.Type) encoderFunc {
        }
        var wg sync.WaitGroup
        wg.Add(1)
-       encoderCache.m[t] = func(e *encodeState, v reflect.Value, quoted bool) {
+       encoderCache.m[t] = func(e *encodeState, v reflect.Value, opts encOpts) {
                wg.Wait()
-               f(e, v, quoted)
+               f(e, v, opts)
        }
        encoderCache.Unlock()
 
@@ -405,11 +413,11 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
        }
 }
 
-func invalidValueEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func invalidValueEncoder(e *encodeState, v reflect.Value, _ encOpts) {
        e.WriteString("null")
 }
 
-func marshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
        if v.Kind() == reflect.Ptr && v.IsNil() {
                e.WriteString("null")
                return
@@ -418,14 +426,14 @@ func marshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
        b, err := m.MarshalJSON()
        if err == nil {
                // copy JSON into buffer, checking validity.
-               err = compact(&e.Buffer, b, true)
+               err = compact(&e.Buffer, b, opts.escapeHTML)
        }
        if err != nil {
                e.error(&MarshalerError{v.Type(), err})
        }
 }
 
-func addrMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func addrMarshalerEncoder(e *encodeState, v reflect.Value, _ encOpts) {
        va := v.Addr()
        if va.IsNil() {
                e.WriteString("null")
@@ -442,7 +450,7 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
        }
 }
 
-func textMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
        if v.Kind() == reflect.Ptr && v.IsNil() {
                e.WriteString("null")
                return
@@ -452,10 +460,10 @@ func textMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
        if err != nil {
                e.error(&MarshalerError{v.Type(), err})
        }
-       e.stringBytes(b)
+       e.stringBytes(b, opts.escapeHTML)
 }
 
-func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
        va := v.Addr()
        if va.IsNil() {
                e.WriteString("null")
@@ -466,11 +474,11 @@ func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
        if err != nil {
                e.error(&MarshalerError{v.Type(), err})
        }
-       e.stringBytes(b)
+       e.stringBytes(b, opts.escapeHTML)
 }
 
-func boolEncoder(e *encodeState, v reflect.Value, quoted bool) {
-       if quoted {
+func boolEncoder(e *encodeState, v reflect.Value, opts encOpts) {
+       if opts.quoted {
                e.WriteByte('"')
        }
        if v.Bool() {
@@ -478,46 +486,46 @@ func boolEncoder(e *encodeState, v reflect.Value, quoted bool) {
        } else {
                e.WriteString("false")
        }
-       if quoted {
+       if opts.quoted {
                e.WriteByte('"')
        }
 }
 
-func intEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func intEncoder(e *encodeState, v reflect.Value, opts encOpts) {
        b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
-       if quoted {
+       if opts.quoted {
                e.WriteByte('"')
        }
        e.Write(b)
-       if quoted {
+       if opts.quoted {
                e.WriteByte('"')
        }
 }
 
-func uintEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
        b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
-       if quoted {
+       if opts.quoted {
                e.WriteByte('"')
        }
        e.Write(b)
-       if quoted {
+       if opts.quoted {
                e.WriteByte('"')
        }
 }
 
 type floatEncoder int // number of bits
 
-func (bits floatEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
+func (bits floatEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
        f := v.Float()
        if math.IsInf(f, 0) || math.IsNaN(f) {
                e.error(&UnsupportedValueError{v, strconv.FormatFloat(f, 'g', -1, int(bits))})
        }
        b := strconv.AppendFloat(e.scratch[:0], f, 'g', -1, int(bits))
-       if quoted {
+       if opts.quoted {
                e.WriteByte('"')
        }
        e.Write(b)
-       if quoted {
+       if opts.quoted {
                e.WriteByte('"')
        }
 }
@@ -527,7 +535,7 @@ var (
        float64Encoder = (floatEncoder(64)).encode
 )
 
-func stringEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func stringEncoder(e *encodeState, v reflect.Value, opts encOpts) {
        if v.Type() == numberType {
                numStr := v.String()
                // In Go1.5 the empty string encodes to "0", while this is not a valid number literal
@@ -541,26 +549,26 @@ func stringEncoder(e *encodeState, v reflect.Value, quoted bool) {
                e.WriteString(numStr)
                return
        }
-       if quoted {
+       if opts.quoted {
                sb, err := Marshal(v.String())
                if err != nil {
                        e.error(err)
                }
-               e.string(string(sb))
+               e.string(string(sb), opts.escapeHTML)
        } else {
-               e.string(v.String())
+               e.string(v.String(), opts.escapeHTML)
        }
 }
 
-func interfaceEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func interfaceEncoder(e *encodeState, v reflect.Value, opts encOpts) {
        if v.IsNil() {
                e.WriteString("null")
                return
        }
-       e.reflectValue(v.Elem())
+       e.reflectValue(v.Elem(), opts)
 }
 
-func unsupportedTypeEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func unsupportedTypeEncoder(e *encodeState, v reflect.Value, _ encOpts) {
        e.error(&UnsupportedTypeError{v.Type()})
 }
 
@@ -569,7 +577,7 @@ type structEncoder struct {
        fieldEncs []encoderFunc
 }
 
-func (se *structEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
+func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
        e.WriteByte('{')
        first := true
        for i, f := range se.fields {
@@ -582,9 +590,10 @@ func (se *structEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
                } else {
                        e.WriteByte(',')
                }
-               e.string(f.name)
+               e.string(f.name, opts.escapeHTML)
                e.WriteByte(':')
-               se.fieldEncs[i](e, fv, f.quoted)
+               opts.quoted = f.quoted
+               se.fieldEncs[i](e, fv, opts)
        }
        e.WriteByte('}')
 }
@@ -605,7 +614,7 @@ type mapEncoder struct {
        elemEnc encoderFunc
 }
 
-func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
+func (me *mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
        if v.IsNil() {
                e.WriteString("null")
                return
@@ -627,9 +636,9 @@ func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
                if i > 0 {
                        e.WriteByte(',')
                }
-               e.string(kv.s)
+               e.string(kv.s, opts.escapeHTML)
                e.WriteByte(':')
-               me.elemEnc(e, v.MapIndex(kv.v), false)
+               me.elemEnc(e, v.MapIndex(kv.v), opts)
        }
        e.WriteByte('}')
 }
@@ -642,7 +651,7 @@ func newMapEncoder(t reflect.Type) encoderFunc {
        return me.encode
 }
 
-func encodeByteSlice(e *encodeState, v reflect.Value, _ bool) {
+func encodeByteSlice(e *encodeState, v reflect.Value, _ encOpts) {
        if v.IsNil() {
                e.WriteString("null")
                return
@@ -669,12 +678,12 @@ type sliceEncoder struct {
        arrayEnc encoderFunc
 }
 
-func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
+func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
        if v.IsNil() {
                e.WriteString("null")
                return
        }
-       se.arrayEnc(e, v, false)
+       se.arrayEnc(e, v, opts)
 }
 
 func newSliceEncoder(t reflect.Type) encoderFunc {
@@ -692,14 +701,14 @@ type arrayEncoder struct {
        elemEnc encoderFunc
 }
 
-func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
+func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
        e.WriteByte('[')
        n := v.Len()
        for i := 0; i < n; i++ {
                if i > 0 {
                        e.WriteByte(',')
                }
-               ae.elemEnc(e, v.Index(i), false)
+               ae.elemEnc(e, v.Index(i), opts)
        }
        e.WriteByte(']')
 }
@@ -713,12 +722,12 @@ type ptrEncoder struct {
        elemEnc encoderFunc
 }
 
-func (pe *ptrEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
+func (pe *ptrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
        if v.IsNil() {
                e.WriteString("null")
                return
        }
-       pe.elemEnc(e, v.Elem(), quoted)
+       pe.elemEnc(e, v.Elem(), opts)
 }
 
 func newPtrEncoder(t reflect.Type) encoderFunc {
@@ -730,11 +739,11 @@ type condAddrEncoder struct {
        canAddrEnc, elseEnc encoderFunc
 }
 
-func (ce *condAddrEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
+func (ce *condAddrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
        if v.CanAddr() {
-               ce.canAddrEnc(e, v, quoted)
+               ce.canAddrEnc(e, v, opts)
        } else {
-               ce.elseEnc(e, v, quoted)
+               ce.elseEnc(e, v, opts)
        }
 }
 
@@ -812,13 +821,14 @@ func (sv byString) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
 func (sv byString) Less(i, j int) bool { return sv[i].s < sv[j].s }
 
 // NOTE: keep in sync with stringBytes below.
-func (e *encodeState) string(s string) int {
+func (e *encodeState) string(s string, escapeHTML bool) int {
        len0 := e.Len()
        e.WriteByte('"')
        start := 0
        for i := 0; i < len(s); {
                if b := s[i]; b < utf8.RuneSelf {
-                       if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' {
+                       if 0x20 <= b && b != '\\' && b != '"' &&
+                               (!escapeHTML || b != '<' && b != '>' && b != '&') {
                                i++
                                continue
                        }
@@ -839,10 +849,11 @@ func (e *encodeState) string(s string) int {
                                e.WriteByte('\\')
                                e.WriteByte('t')
                        default:
-                               // This encodes bytes < 0x20 except for \n and \r,
-                               // as well as <, > and &. The latter are escaped because they
-                               // can lead to security holes when user-controlled strings
-                               // are rendered into JSON and served to some browsers.
+                               // This encodes bytes < 0x20 except for \t, \n and \r.
+                               // If escapeHTML is set, it also escapes <, >, and &
+                               // because they can lead to security holes when
+                               // user-controlled strings are rendered into JSON
+                               // and served to some browsers.
                                e.WriteString(`\u00`)
                                e.WriteByte(hex[b>>4])
                                e.WriteByte(hex[b&0xF])
@@ -888,13 +899,14 @@ func (e *encodeState) string(s string) int {
 }
 
 // NOTE: keep in sync with string above.
-func (e *encodeState) stringBytes(s []byte) int {
+func (e *encodeState) stringBytes(s []byte, escapeHTML bool) int {
        len0 := e.Len()
        e.WriteByte('"')
        start := 0
        for i := 0; i < len(s); {
                if b := s[i]; b < utf8.RuneSelf {
-                       if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' {
+                       if 0x20 <= b && b != '\\' && b != '"' &&
+                               (!escapeHTML || b != '<' && b != '>' && b != '&') {
                                i++
                                continue
                        }
@@ -915,10 +927,11 @@ func (e *encodeState) stringBytes(s []byte) int {
                                e.WriteByte('\\')
                                e.WriteByte('t')
                        default:
-                               // This encodes bytes < 0x20 except for \n and \r,
-                               // as well as <, >, and &. The latter are escaped because they
-                               // can lead to security holes when user-controlled strings
-                               // are rendered into JSON and served to some browsers.
+                               // This encodes bytes < 0x20 except for \t, \n and \r.
+                               // If escapeHTML is set, it also escapes <, >, and &
+                               // because they can lead to security holes when
+                               // user-controlled strings are rendered into JSON
+                               // and served to some browsers.
                                e.WriteString(`\u00`)
                                e.WriteByte(hex[b>>4])
                                e.WriteByte(hex[b&0xF])
index eee59ccb4959074dddb0446999ecd95721905e68..b484022a70e1980956a0f3136722eeb3b1b2d898 100644 (file)
@@ -376,41 +376,45 @@ func TestDuplicatedFieldDisappears(t *testing.T) {
 
 func TestStringBytes(t *testing.T) {
        // Test that encodeState.stringBytes and encodeState.string use the same encoding.
-       es := &encodeState{}
        var r []rune
        for i := '\u0000'; i <= unicode.MaxRune; i++ {
                r = append(r, i)
        }
        s := string(r) + "\xff\xff\xffhello" // some invalid UTF-8 too
-       es.string(s)
 
-       esBytes := &encodeState{}
-       esBytes.stringBytes([]byte(s))
+       for _, escapeHTML := range []bool{true, false} {
+               es := &encodeState{}
+               es.string(s, escapeHTML)
 
-       enc := es.Buffer.String()
-       encBytes := esBytes.Buffer.String()
-       if enc != encBytes {
-               i := 0
-               for i < len(enc) && i < len(encBytes) && enc[i] == encBytes[i] {
-                       i++
-               }
-               enc = enc[i:]
-               encBytes = encBytes[i:]
-               i = 0
-               for i < len(enc) && i < len(encBytes) && enc[len(enc)-i-1] == encBytes[len(encBytes)-i-1] {
-                       i++
-               }
-               enc = enc[:len(enc)-i]
-               encBytes = encBytes[:len(encBytes)-i]
+               esBytes := &encodeState{}
+               esBytes.stringBytes([]byte(s), escapeHTML)
 
-               if len(enc) > 20 {
-                       enc = enc[:20] + "..."
-               }
-               if len(encBytes) > 20 {
-                       encBytes = encBytes[:20] + "..."
-               }
+               enc := es.Buffer.String()
+               encBytes := esBytes.Buffer.String()
+               if enc != encBytes {
+                       i := 0
+                       for i < len(enc) && i < len(encBytes) && enc[i] == encBytes[i] {
+                               i++
+                       }
+                       enc = enc[i:]
+                       encBytes = encBytes[i:]
+                       i = 0
+                       for i < len(enc) && i < len(encBytes) && enc[len(enc)-i-1] == encBytes[len(encBytes)-i-1] {
+                               i++
+                       }
+                       enc = enc[:len(enc)-i]
+                       encBytes = encBytes[:len(encBytes)-i]
 
-               t.Errorf("encodings differ at %#q vs %#q", enc, encBytes)
+                       if len(enc) > 20 {
+                               enc = enc[:20] + "..."
+                       }
+                       if len(encBytes) > 20 {
+                               encBytes = encBytes[:20] + "..."
+                       }
+
+                       t.Errorf("with escapeHTML=%t, encodings differ at %#q vs %#q",
+                               escapeHTML, enc, encBytes)
+               }
        }
 }
 
index 422837bb63964d984451291f638e38e8503cf8e7..d6b2992e9be1c3827027a6baa7172f9a10c86014 100644 (file)
@@ -166,8 +166,9 @@ func nonSpace(b []byte) bool {
 
 // An Encoder writes JSON values to an output stream.
 type Encoder struct {
-       w   io.Writer
-       err error
+       w          io.Writer
+       err        error
+       escapeHTML bool
 
        indentBuf    *bytes.Buffer
        indentPrefix string
@@ -176,7 +177,7 @@ type Encoder struct {
 
 // NewEncoder returns a new encoder that writes to w.
 func NewEncoder(w io.Writer) *Encoder {
-       return &Encoder{w: w}
+       return &Encoder{w: w, escapeHTML: true}
 }
 
 // Encode writes the JSON encoding of v to the stream,
@@ -189,7 +190,7 @@ func (enc *Encoder) Encode(v interface{}) error {
                return enc.err
        }
        e := newEncodeState()
-       err := e.marshal(v)
+       err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
        if err != nil {
                return err
        }
@@ -225,6 +226,12 @@ func (enc *Encoder) Indent(prefix, indent string) {
        enc.indentValue = indent
 }
 
+// DisableHTMLEscaping causes the encoder not to escape angle brackets
+// ("<" and ">") or ampersands ("&") in JSON strings.
+func (enc *Encoder) DisableHTMLEscaping() {
+       enc.escapeHTML = false
+}
+
 // RawMessage is a raw encoded JSON value.
 // It implements Marshaler and Unmarshaler and can
 // be used to delay JSON decoding or precompute a JSON encoding.
index db25708f4cd3833c82a7ca8229f7303ba5d34c3a..3516ac3b83d9712eac40def7c1a5e493454fa2d7 100644 (file)
@@ -87,6 +87,39 @@ func TestEncoderIndent(t *testing.T) {
        }
 }
 
+func TestEncoderDisableHTMLEscaping(t *testing.T) {
+       var c C
+       var ct CText
+       for _, tt := range []struct {
+               name       string
+               v          interface{}
+               wantEscape string
+               want       string
+       }{
+               {"c", c, `"\u003c\u0026\u003e"`, `"<&>"`},
+               {"ct", ct, `"\"\u003c\u0026\u003e\""`, `"\"<&>\""`},
+               {`"<&>"`, "<&>", `"\u003c\u0026\u003e"`, `"<&>"`},
+       } {
+               var buf bytes.Buffer
+               enc := NewEncoder(&buf)
+               if err := enc.Encode(tt.v); err != nil {
+                       t.Fatalf("Encode(%s): %s", tt.name, err)
+               }
+               if got := strings.TrimSpace(buf.String()); got != tt.wantEscape {
+                       t.Errorf("Encode(%s) = %#q, want %#q", tt.name, got, tt.wantEscape)
+               }
+               buf.Reset()
+               enc.DisableHTMLEscaping()
+               if err := enc.Encode(tt.v); err != nil {
+                       t.Fatalf("DisableHTMLEscaping Encode(%s): %s", tt.name, err)
+               }
+               if got := strings.TrimSpace(buf.String()); got != tt.want {
+                       t.Errorf("DisableHTMLEscaping Encode(%s) = %#q, want %#q",
+                               tt.name, got, tt.want)
+               }
+       }
+}
+
 func TestDecoder(t *testing.T) {
        for i := 0; i <= len(streamTest); i++ {
                // Use stream without newlines as input,