package json
import (
+ "encoding"
"encoding/base64"
"errors"
"fmt"
// until it gets to a non-pointer.
// if it encounters an Unmarshaler, indirect stops and returns that.
// if decodingNull is true, indirect stops at the last pointer so it can be set to nil.
-func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, reflect.Value) {
+func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
// If v is a named type and is addressable,
// start with its address, so that if the type has pointer methods,
// we find them.
v.Set(reflect.New(v.Type().Elem()))
}
if v.Type().NumMethod() > 0 {
- if unmarshaler, ok := v.Interface().(Unmarshaler); ok {
- return unmarshaler, reflect.Value{}
+ if u, ok := v.Interface().(Unmarshaler); ok {
+ return u, nil, reflect.Value{}
+ }
+ if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
+ return nil, u, reflect.Value{}
}
}
v = v.Elem()
}
- return nil, v
+ return nil, nil, v
}
// array consumes an array from d.data[d.off-1:], decoding into the value v.
// the first byte of the array ('[') has been read already.
func (d *decodeState) array(v reflect.Value) {
// Check for unmarshaler.
- unmarshaler, pv := d.indirect(v, false)
- if unmarshaler != nil {
+ u, ut, pv := d.indirect(v, false)
+ if u != nil {
d.off--
- err := unmarshaler.UnmarshalJSON(d.next())
+ err := u.UnmarshalJSON(d.next())
if err != nil {
d.error(err)
}
return
}
+ if ut != nil {
+ d.saveError(&UnmarshalTypeError{"array", v.Type()})
+ d.off--
+ d.next()
+ return
+ }
+
v = pv
// Check type of target.
// the first byte of the object ('{') has been read already.
func (d *decodeState) object(v reflect.Value) {
// Check for unmarshaler.
- unmarshaler, pv := d.indirect(v, false)
- if unmarshaler != nil {
+ u, ut, pv := d.indirect(v, false)
+ if u != nil {
d.off--
- err := unmarshaler.UnmarshalJSON(d.next())
+ err := u.UnmarshalJSON(d.next())
if err != nil {
d.error(err)
}
return
}
+ if ut != nil {
+ d.saveError(&UnmarshalTypeError{"object", v.Type()})
+ d.off--
+ d.next() // skip over { } in input
+ return
+ }
v = pv
// Decoding into nil interface? Switch to non-reflect code.
return
}
wantptr := item[0] == 'n' // null
- unmarshaler, pv := d.indirect(v, wantptr)
- if unmarshaler != nil {
- err := unmarshaler.UnmarshalJSON(item)
+ u, ut, pv := d.indirect(v, wantptr)
+ if u != nil {
+ err := u.UnmarshalJSON(item)
+ if err != nil {
+ d.error(err)
+ }
+ return
+ }
+ if ut != nil {
+ if item[0] != '"' {
+ if fromQuoted {
+ d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
+ } else {
+ d.saveError(&UnmarshalTypeError{"string", v.Type()})
+ }
+ }
+ s, ok := unquoteBytes(item)
+ if !ok {
+ if fromQuoted {
+ d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
+ } else {
+ d.error(errPhase)
+ }
+ }
+ err := ut.UnmarshalText(s)
if err != nil {
d.error(err)
}
return
}
+
v = pv
switch c := item[0]; c {
import (
"bytes"
+ "encoding"
"fmt"
"image"
"reflect"
}
func (u *unmarshaler) UnmarshalJSON(b []byte) error {
- *u = unmarshaler{true} // All we need to see that UnmarshalJson is called.
+ *u = unmarshaler{true} // All we need to see that UnmarshalJSON is called.
return nil
}
M unmarshaler
}
+type unmarshalerText struct {
+ T bool
+}
+
+// needed for re-marshaling tests
+func (u *unmarshalerText) MarshalText() ([]byte, error) {
+ return []byte(""), nil
+}
+
+func (u *unmarshalerText) UnmarshalText(b []byte) error {
+ *u = unmarshalerText{true} // All we need to see that UnmarshalText is called.
+ return nil
+}
+
+var _ encoding.TextUnmarshaler = (*unmarshalerText)(nil)
+
+type ustructText struct {
+ M unmarshalerText
+}
+
var (
um0, um1 unmarshaler // target2 of unmarshaling
ump = &um1
umslice = []unmarshaler{{true}}
umslicep = new([]unmarshaler)
umstruct = ustruct{unmarshaler{true}}
+
+ um0T, um1T unmarshalerText // target2 of unmarshaling
+ umpT = &um1T
+ umtrueT = unmarshalerText{true}
+ umsliceT = []unmarshalerText{{true}}
+ umslicepT = new([]unmarshalerText)
+ umstructT = ustructText{unmarshalerText{true}}
)
// Test data structures for anonymous fields.
{in: `[{"T":false}]`, ptr: &umslicep, out: &umslice},
{in: `{"M":{"T":false}}`, ptr: &umstruct, out: umstruct},
+ // UnmarshalText interface test
+ {in: `"X"`, ptr: &um0T, out: umtrueT}, // use "false" so test will fail if custom unmarshaler is not called
+ {in: `"X"`, ptr: &umpT, out: &umtrueT},
+ {in: `["X"]`, ptr: &umsliceT, out: umsliceT},
+ {in: `["X"]`, ptr: &umslicepT, out: &umsliceT},
+ {in: `{"M":"X"}`, ptr: &umstructT, out: umstructT},
+
{
in: `{
"Level0": 1,
dec.UseNumber()
}
if err := dec.Decode(vv.Interface()); err != nil {
- t.Errorf("#%d: error re-unmarshaling: %v", i, err)
+ t.Errorf("#%d: error re-unmarshaling %#q: %v", i, enc, err)
continue
}
if !reflect.DeepEqual(v.Elem().Interface(), vv.Elem().Interface()) {
// Ref is defined in encode_test.go.
R0 Ref
R1 *Ref
+ R2 RefText
+ R3 *RefText
}
want := S{
R0: 12,
R1: new(Ref),
+ R2: 13,
+ R3: new(RefText),
}
*want.R1 = 12
+ *want.R3 = 13
var got S
- if err := Unmarshal([]byte(`{"R0":"ref","R1":"ref"}`), &got); err != nil {
+ if err := Unmarshal([]byte(`{"R0":"ref","R1":"ref","R2":"ref","R3":"ref"}`), &got); err != nil {
t.Fatalf("Unmarshal: %v", err)
}
if !reflect.DeepEqual(got, want) {
import (
"bytes"
+ "encoding"
"encoding/base64"
"math"
"reflect"
if !vx.IsValid() {
vx = reflect.New(t).Elem()
}
+
_, ok := vx.Interface().(Marshaler)
if ok {
- return valueIsMarshallerEncoder
+ return marshalerEncoder
}
- // T doesn't match the interface. Check against *T too.
if vx.Kind() != reflect.Ptr && vx.CanAddr() {
_, ok = vx.Addr().Interface().(Marshaler)
if ok {
- return valueAddrIsMarshallerEncoder
+ return addrMarshalerEncoder
}
}
+
+ _, ok = vx.Interface().(encoding.TextMarshaler)
+ if ok {
+ return textMarshalerEncoder
+ }
+ if vx.Kind() != reflect.Ptr && vx.CanAddr() {
+ _, ok = vx.Addr().Interface().(encoding.TextMarshaler)
+ if ok {
+ return addrTextMarshalerEncoder
+ }
+ }
+
switch vx.Kind() {
case reflect.Bool:
return boolEncoder
e.WriteString("null")
}
-func valueIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func marshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
if v.Kind() == reflect.Ptr && v.IsNil() {
e.WriteString("null")
return
}
}
-func valueAddrIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func addrMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
va := v.Addr()
- if va.Kind() == reflect.Ptr && va.IsNil() {
+ if va.IsNil() {
e.WriteString("null")
return
}
}
}
+func textMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+ if v.Kind() == reflect.Ptr && v.IsNil() {
+ e.WriteString("null")
+ return
+ }
+ m := v.Interface().(encoding.TextMarshaler)
+ b, err := m.MarshalText()
+ if err == nil {
+ _, err = e.stringBytes(b)
+ }
+ if err != nil {
+ e.error(&MarshalerError{v.Type(), err})
+ }
+}
+
+func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+ va := v.Addr()
+ if va.IsNil() {
+ e.WriteString("null")
+ return
+ }
+ m := va.Interface().(encoding.TextMarshaler)
+ b, err := m.MarshalText()
+ if err == nil {
+ _, err = e.stringBytes(b)
+ }
+ if err != nil {
+ e.error(&MarshalerError{v.Type(), err})
+ }
+}
+
func boolEncoder(e *encodeState, v reflect.Value, quoted bool) {
if quoted {
e.WriteByte('"')
func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
func (sv stringValues) get(i int) string { return sv[i].String() }
+// NOTE: keep in sync with stringBytes below.
func (e *encodeState) string(s string) (int, error) {
len0 := e.Len()
e.WriteByte('"')
return e.Len() - len0, nil
}
+// NOTE: keep in sync with string above.
+func (e *encodeState) stringBytes(s []byte) (int, error) {
+ len0 := e.Len()
+ e.WriteByte('"')
+ start := 0
+ for i := 0; i < len(s); {
+ if b := s[i]; b < utf8.RuneSelf {
+ if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' {
+ i++
+ continue
+ }
+ if start < i {
+ e.Write(s[start:i])
+ }
+ switch b {
+ case '\\', '"':
+ e.WriteByte('\\')
+ e.WriteByte(b)
+ case '\n':
+ e.WriteByte('\\')
+ e.WriteByte('n')
+ case '\r':
+ e.WriteByte('\\')
+ e.WriteByte('r')
+ default:
+ // This encodes bytes < 0x20 except for \n and \r,
+ // as well as < and >. The latter are escaped because they
+ // can lead to security holes when user-controlled strings
+ // are rendered into JSON and served to some browsers.
+ e.WriteString(`\u00`)
+ e.WriteByte(hex[b>>4])
+ e.WriteByte(hex[b&0xF])
+ }
+ i++
+ start = i
+ continue
+ }
+ c, size := utf8.DecodeRune(s[i:])
+ if c == utf8.RuneError && size == 1 {
+ if start < i {
+ e.Write(s[start:i])
+ }
+ e.WriteString(`\ufffd`)
+ i += size
+ start = i
+ continue
+ }
+ // U+2028 is LINE SEPARATOR.
+ // U+2029 is PARAGRAPH SEPARATOR.
+ // They are both technically valid characters in JSON strings,
+ // but don't work in JSONP, which has to be evaluated as JavaScript,
+ // and can lead to security holes there. It is valid JSON to
+ // escape them, so we do so unconditionally.
+ // See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
+ if c == '\u2028' || c == '\u2029' {
+ if start < i {
+ e.Write(s[start:i])
+ }
+ e.WriteString(`\u202`)
+ e.WriteByte(hex[c&0xF])
+ i += size
+ start = i
+ continue
+ }
+ i += size
+ }
+ if start < len(s) {
+ e.Write(s[start:])
+ }
+ e.WriteByte('"')
+ return e.Len() - len0, nil
+}
+
// A field represents a single field found in a struct.
type field struct {
name string
"math"
"reflect"
"testing"
+ "unicode"
)
type Optionals struct {
return []byte(`"val"`), nil
}
+// RefText has Marshaler and Unmarshaler methods with pointer receiver.
+type RefText int
+
+func (*RefText) MarshalText() ([]byte, error) {
+ return []byte(`"ref"`), nil
+}
+
+func (r *RefText) UnmarshalText([]byte) error {
+ *r = 13
+ return nil
+}
+
+// ValText has Marshaler methods with value receiver.
+type ValText int
+
+func (ValText) MarshalText() ([]byte, error) {
+ return []byte(`"val"`), nil
+}
+
func TestRefValMarshal(t *testing.T) {
var s = struct {
R0 Ref
R1 *Ref
+ R2 RefText
+ R3 *RefText
V0 Val
V1 *Val
+ V2 ValText
+ V3 *ValText
}{
R0: 12,
R1: new(Ref),
+ R2: 14,
+ R3: new(RefText),
V0: 13,
V1: new(Val),
+ V2: 15,
+ V3: new(ValText),
}
- const want = `{"R0":"ref","R1":"ref","V0":"val","V1":"val"}`
+ const want = `{"R0":"ref","R1":"ref","R2":"\"ref\"","R3":"\"ref\"","V0":"val","V1":"val","V2":"\"val\"","V3":"\"val\""}`
b, err := Marshal(&s)
if err != nil {
t.Fatalf("Marshal: %v", err)
return []byte(`"<&>"`), nil
}
+// CText implements Marshaler and returns unescaped text.
+type CText int
+
+func (CText) MarshalText() ([]byte, error) {
+ return []byte(`"<&>"`), nil
+}
+
func TestMarshalerEscaping(t *testing.T) {
var c C
- const want = `"\u003c\u0026\u003e"`
+ want := `"\u003c\u0026\u003e"`
b, err := Marshal(c)
if err != nil {
- t.Fatalf("Marshal: %v", err)
+ t.Fatalf("Marshal(c): %v", err)
}
if got := string(b); got != want {
- t.Errorf("got %q, want %q", got, want)
+ t.Errorf("Marshal(c) = %#q, want %#q", got, want)
+ }
+
+ var ct CText
+ want = `"\"\u003c\u0026\u003e\""`
+ b, err = Marshal(ct)
+ if err != nil {
+ t.Fatalf("Marshal(ct): %v", err)
+ }
+ if got := string(b); got != want {
+ t.Errorf("Marshal(ct) = %#q, want %#q", got, want)
}
}
t.Fatalf("Marshal: got %s want %s", got, want)
}
}
+
+func TestStringBytes(t *testing.T) {
+ // Test that encodeState.stringBytes and encodeState.string use the same encoding.
+ es := &encodeState{}
+ var r []rune
+ for i := '\u0000'; i <= unicode.MaxRune; i++ {
+ r = append(r, i)
+ }
+ s := string(r) + "\xff\xff\xffhello" // some invalid UTF-8 too
+ _, err := es.string(s)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ esBytes := &encodeState{}
+ _, err = esBytes.stringBytes([]byte(s))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ enc := es.Buffer.String()
+ encBytes := esBytes.Buffer.String()
+ if enc != encBytes {
+ i := 0
+ for i < len(enc) && i < len(encBytes) && enc[i] == encBytes[i] {
+ i++
+ }
+ enc = enc[i:]
+ encBytes = encBytes[i:]
+ i = 0
+ for i < len(enc) && i < len(encBytes) && enc[len(enc)-i-1] == encBytes[len(encBytes)-i-1] {
+ i++
+ }
+ enc = enc[:len(enc)-i]
+ encBytes = encBytes[:len(encBytes)-i]
+
+ if len(enc) > 20 {
+ enc = enc[:20] + "..."
+ }
+ if len(encBytes) > 20 {
+ encBytes = encBytes[:20] + "..."
+ }
+
+ t.Errorf("encodings differ at %#q vs %#q", enc, encBytes)
+ }
+}