]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/json: support encoding.TextMarshaler, encoding.TextUnmarshaler
authorRuss Cox <rsc@golang.org>
Wed, 14 Aug 2013 18:56:07 +0000 (14:56 -0400)
committerRuss Cox <rsc@golang.org>
Wed, 14 Aug 2013 18:56:07 +0000 (14:56 -0400)
R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/12703043

src/pkg/encoding/json/decode.go
src/pkg/encoding/json/decode_test.go
src/pkg/encoding/json/encode.go
src/pkg/encoding/json/encode_test.go

index 62ac294b89f3bd113c6b63f6974a1c4f7d147dcb..b6c23cc77a1d0548790291b10c244868d6c7d155 100644 (file)
@@ -8,6 +8,7 @@
 package json
 
 import (
+       "encoding"
        "encoding/base64"
        "errors"
        "fmt"
@@ -293,7 +294,7 @@ func (d *decodeState) value(v reflect.Value) {
 // until it gets to a non-pointer.
 // if it encounters an Unmarshaler, indirect stops and returns that.
 // if decodingNull is true, indirect stops at the last pointer so it can be set to nil.
-func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, reflect.Value) {
+func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
        // If v is a named type and is addressable,
        // start with its address, so that if the type has pointer methods,
        // we find them.
@@ -322,28 +323,38 @@ func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler,
                        v.Set(reflect.New(v.Type().Elem()))
                }
                if v.Type().NumMethod() > 0 {
-                       if unmarshaler, ok := v.Interface().(Unmarshaler); ok {
-                               return unmarshaler, reflect.Value{}
+                       if u, ok := v.Interface().(Unmarshaler); ok {
+                               return u, nil, reflect.Value{}
+                       }
+                       if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
+                               return nil, u, reflect.Value{}
                        }
                }
                v = v.Elem()
        }
-       return nil, v
+       return nil, nil, v
 }
 
 // array consumes an array from d.data[d.off-1:], decoding into the value v.
 // the first byte of the array ('[') has been read already.
 func (d *decodeState) array(v reflect.Value) {
        // Check for unmarshaler.
-       unmarshaler, pv := d.indirect(v, false)
-       if unmarshaler != nil {
+       u, ut, pv := d.indirect(v, false)
+       if u != nil {
                d.off--
-               err := unmarshaler.UnmarshalJSON(d.next())
+               err := u.UnmarshalJSON(d.next())
                if err != nil {
                        d.error(err)
                }
                return
        }
+       if ut != nil {
+               d.saveError(&UnmarshalTypeError{"array", v.Type()})
+               d.off--
+               d.next()
+               return
+       }
+
        v = pv
 
        // Check type of target.
@@ -434,15 +445,21 @@ func (d *decodeState) array(v reflect.Value) {
 // the first byte of the object ('{') has been read already.
 func (d *decodeState) object(v reflect.Value) {
        // Check for unmarshaler.
-       unmarshaler, pv := d.indirect(v, false)
-       if unmarshaler != nil {
+       u, ut, pv := d.indirect(v, false)
+       if u != nil {
                d.off--
-               err := unmarshaler.UnmarshalJSON(d.next())
+               err := u.UnmarshalJSON(d.next())
                if err != nil {
                        d.error(err)
                }
                return
        }
+       if ut != nil {
+               d.saveError(&UnmarshalTypeError{"object", v.Type()})
+               d.off--
+               d.next() // skip over { } in input
+               return
+       }
        v = pv
 
        // Decoding into nil interface?  Switch to non-reflect code.
@@ -611,14 +628,37 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
                return
        }
        wantptr := item[0] == 'n' // null
-       unmarshaler, pv := d.indirect(v, wantptr)
-       if unmarshaler != nil {
-               err := unmarshaler.UnmarshalJSON(item)
+       u, ut, pv := d.indirect(v, wantptr)
+       if u != nil {
+               err := u.UnmarshalJSON(item)
+               if err != nil {
+                       d.error(err)
+               }
+               return
+       }
+       if ut != nil {
+               if item[0] != '"' {
+                       if fromQuoted {
+                               d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
+                       } else {
+                               d.saveError(&UnmarshalTypeError{"string", v.Type()})
+                       }
+               }
+               s, ok := unquoteBytes(item)
+               if !ok {
+                       if fromQuoted {
+                               d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
+                       } else {
+                               d.error(errPhase)
+                       }
+               }
+               err := ut.UnmarshalText(s)
                if err != nil {
                        d.error(err)
                }
                return
        }
+
        v = pv
 
        switch c := item[0]; c {
index 3fa366500ff08b3f4ab23aba2fa47672d25f12a5..6635ba6ec681597a55120ddc10b34f9298f85a9e 100644 (file)
@@ -6,6 +6,7 @@ package json
 
 import (
        "bytes"
+       "encoding"
        "fmt"
        "image"
        "reflect"
@@ -57,7 +58,7 @@ type unmarshaler struct {
 }
 
 func (u *unmarshaler) UnmarshalJSON(b []byte) error {
-       *u = unmarshaler{true} // All we need to see that UnmarshalJson is called.
+       *u = unmarshaler{true} // All we need to see that UnmarshalJSON is called.
        return nil
 }
 
@@ -65,6 +66,26 @@ type ustruct struct {
        M unmarshaler
 }
 
+type unmarshalerText struct {
+       T bool
+}
+
+// needed for re-marshaling tests
+func (u *unmarshalerText) MarshalText() ([]byte, error) {
+       return []byte(""), nil
+}
+
+func (u *unmarshalerText) UnmarshalText(b []byte) error {
+       *u = unmarshalerText{true} // All we need to see that UnmarshalText is called.
+       return nil
+}
+
+var _ encoding.TextUnmarshaler = (*unmarshalerText)(nil)
+
+type ustructText struct {
+       M unmarshalerText
+}
+
 var (
        um0, um1 unmarshaler // target2 of unmarshaling
        ump      = &um1
@@ -72,6 +93,13 @@ var (
        umslice  = []unmarshaler{{true}}
        umslicep = new([]unmarshaler)
        umstruct = ustruct{unmarshaler{true}}
+
+       um0T, um1T unmarshalerText // target2 of unmarshaling
+       umpT       = &um1T
+       umtrueT    = unmarshalerText{true}
+       umsliceT   = []unmarshalerText{{true}}
+       umslicepT  = new([]unmarshalerText)
+       umstructT  = ustructText{unmarshalerText{true}}
 )
 
 // Test data structures for anonymous fields.
@@ -261,6 +289,13 @@ var unmarshalTests = []unmarshalTest{
        {in: `[{"T":false}]`, ptr: &umslicep, out: &umslice},
        {in: `{"M":{"T":false}}`, ptr: &umstruct, out: umstruct},
 
+       // UnmarshalText interface test
+       {in: `"X"`, ptr: &um0T, out: umtrueT}, // use "false" so test will fail if custom unmarshaler is not called
+       {in: `"X"`, ptr: &umpT, out: &umtrueT},
+       {in: `["X"]`, ptr: &umsliceT, out: umsliceT},
+       {in: `["X"]`, ptr: &umslicepT, out: &umsliceT},
+       {in: `{"M":"X"}`, ptr: &umstructT, out: umstructT},
+
        {
                in: `{
                        "Level0": 1,
@@ -505,7 +540,7 @@ func TestUnmarshal(t *testing.T) {
                                dec.UseNumber()
                        }
                        if err := dec.Decode(vv.Interface()); err != nil {
-                               t.Errorf("#%d: error re-unmarshaling: %v", i, err)
+                               t.Errorf("#%d: error re-unmarshaling %#q: %v", i, enc, err)
                                continue
                        }
                        if !reflect.DeepEqual(v.Elem().Interface(), vv.Elem().Interface()) {
@@ -979,15 +1014,20 @@ func TestRefUnmarshal(t *testing.T) {
                // Ref is defined in encode_test.go.
                R0 Ref
                R1 *Ref
+               R2 RefText
+               R3 *RefText
        }
        want := S{
                R0: 12,
                R1: new(Ref),
+               R2: 13,
+               R3: new(RefText),
        }
        *want.R1 = 12
+       *want.R3 = 13
 
        var got S
-       if err := Unmarshal([]byte(`{"R0":"ref","R1":"ref"}`), &got); err != nil {
+       if err := Unmarshal([]byte(`{"R0":"ref","R1":"ref","R2":"ref","R3":"ref"}`), &got); err != nil {
                t.Fatalf("Unmarshal: %v", err)
        }
        if !reflect.DeepEqual(got, want) {
index a11270726947b120d5af873ac87c787d7efecde3..f951250e980b7eb19c86cb69624027bc9e1ce346 100644 (file)
@@ -12,6 +12,7 @@ package json
 
 import (
        "bytes"
+       "encoding"
        "encoding/base64"
        "math"
        "reflect"
@@ -361,17 +362,29 @@ func newTypeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
        if !vx.IsValid() {
                vx = reflect.New(t).Elem()
        }
+
        _, ok := vx.Interface().(Marshaler)
        if ok {
-               return valueIsMarshallerEncoder
+               return marshalerEncoder
        }
-       // T doesn't match the interface. Check against *T too.
        if vx.Kind() != reflect.Ptr && vx.CanAddr() {
                _, ok = vx.Addr().Interface().(Marshaler)
                if ok {
-                       return valueAddrIsMarshallerEncoder
+                       return addrMarshalerEncoder
                }
        }
+
+       _, ok = vx.Interface().(encoding.TextMarshaler)
+       if ok {
+               return textMarshalerEncoder
+       }
+       if vx.Kind() != reflect.Ptr && vx.CanAddr() {
+               _, ok = vx.Addr().Interface().(encoding.TextMarshaler)
+               if ok {
+                       return addrTextMarshalerEncoder
+               }
+       }
+
        switch vx.Kind() {
        case reflect.Bool:
                return boolEncoder
@@ -406,7 +419,7 @@ func invalidValueEncoder(e *encodeState, v reflect.Value, quoted bool) {
        e.WriteString("null")
 }
 
-func valueIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func marshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
        if v.Kind() == reflect.Ptr && v.IsNil() {
                e.WriteString("null")
                return
@@ -422,9 +435,9 @@ func valueIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) {
        }
 }
 
-func valueAddrIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+func addrMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
        va := v.Addr()
-       if va.Kind() == reflect.Ptr && va.IsNil() {
+       if va.IsNil() {
                e.WriteString("null")
                return
        }
@@ -439,6 +452,37 @@ func valueAddrIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool)
        }
 }
 
+func textMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+       if v.Kind() == reflect.Ptr && v.IsNil() {
+               e.WriteString("null")
+               return
+       }
+       m := v.Interface().(encoding.TextMarshaler)
+       b, err := m.MarshalText()
+       if err == nil {
+               _, err = e.stringBytes(b)
+       }
+       if err != nil {
+               e.error(&MarshalerError{v.Type(), err})
+       }
+}
+
+func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+       va := v.Addr()
+       if va.IsNil() {
+               e.WriteString("null")
+               return
+       }
+       m := va.Interface().(encoding.TextMarshaler)
+       b, err := m.MarshalText()
+       if err == nil {
+               _, err = e.stringBytes(b)
+       }
+       if err != nil {
+               e.error(&MarshalerError{v.Type(), err})
+       }
+}
+
 func boolEncoder(e *encodeState, v reflect.Value, quoted bool) {
        if quoted {
                e.WriteByte('"')
@@ -728,6 +772,7 @@ func (sv stringValues) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
 func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
 func (sv stringValues) get(i int) string   { return sv[i].String() }
 
+// NOTE: keep in sync with stringBytes below.
 func (e *encodeState) string(s string) (int, error) {
        len0 := e.Len()
        e.WriteByte('"')
@@ -800,6 +845,79 @@ func (e *encodeState) string(s string) (int, error) {
        return e.Len() - len0, nil
 }
 
+// NOTE: keep in sync with string above.
+func (e *encodeState) stringBytes(s []byte) (int, error) {
+       len0 := e.Len()
+       e.WriteByte('"')
+       start := 0
+       for i := 0; i < len(s); {
+               if b := s[i]; b < utf8.RuneSelf {
+                       if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' {
+                               i++
+                               continue
+                       }
+                       if start < i {
+                               e.Write(s[start:i])
+                       }
+                       switch b {
+                       case '\\', '"':
+                               e.WriteByte('\\')
+                               e.WriteByte(b)
+                       case '\n':
+                               e.WriteByte('\\')
+                               e.WriteByte('n')
+                       case '\r':
+                               e.WriteByte('\\')
+                               e.WriteByte('r')
+                       default:
+                               // This encodes bytes < 0x20 except for \n and \r,
+                               // as well as < and >. The latter are escaped because they
+                               // can lead to security holes when user-controlled strings
+                               // are rendered into JSON and served to some browsers.
+                               e.WriteString(`\u00`)
+                               e.WriteByte(hex[b>>4])
+                               e.WriteByte(hex[b&0xF])
+                       }
+                       i++
+                       start = i
+                       continue
+               }
+               c, size := utf8.DecodeRune(s[i:])
+               if c == utf8.RuneError && size == 1 {
+                       if start < i {
+                               e.Write(s[start:i])
+                       }
+                       e.WriteString(`\ufffd`)
+                       i += size
+                       start = i
+                       continue
+               }
+               // U+2028 is LINE SEPARATOR.
+               // U+2029 is PARAGRAPH SEPARATOR.
+               // They are both technically valid characters in JSON strings,
+               // but don't work in JSONP, which has to be evaluated as JavaScript,
+               // and can lead to security holes there. It is valid JSON to
+               // escape them, so we do so unconditionally.
+               // See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
+               if c == '\u2028' || c == '\u2029' {
+                       if start < i {
+                               e.Write(s[start:i])
+                       }
+                       e.WriteString(`\u202`)
+                       e.WriteByte(hex[c&0xF])
+                       i += size
+                       start = i
+                       continue
+               }
+               i += size
+       }
+       if start < len(s) {
+               e.Write(s[start:])
+       }
+       e.WriteByte('"')
+       return e.Len() - len0, nil
+}
+
 // A field represents a single field found in a struct.
 type field struct {
        name      string
index 5be0a992e1cab8522d6353ea6ed343ff07d3dfed..7052e1db7c93228e53d479056dadd13f5f6980ba 100644 (file)
@@ -9,6 +9,7 @@ import (
        "math"
        "reflect"
        "testing"
+       "unicode"
 )
 
 type Optionals struct {
@@ -146,19 +147,46 @@ func (Val) MarshalJSON() ([]byte, error) {
        return []byte(`"val"`), nil
 }
 
+// RefText has Marshaler and Unmarshaler methods with pointer receiver.
+type RefText int
+
+func (*RefText) MarshalText() ([]byte, error) {
+       return []byte(`"ref"`), nil
+}
+
+func (r *RefText) UnmarshalText([]byte) error {
+       *r = 13
+       return nil
+}
+
+// ValText has Marshaler methods with value receiver.
+type ValText int
+
+func (ValText) MarshalText() ([]byte, error) {
+       return []byte(`"val"`), nil
+}
+
 func TestRefValMarshal(t *testing.T) {
        var s = struct {
                R0 Ref
                R1 *Ref
+               R2 RefText
+               R3 *RefText
                V0 Val
                V1 *Val
+               V2 ValText
+               V3 *ValText
        }{
                R0: 12,
                R1: new(Ref),
+               R2: 14,
+               R3: new(RefText),
                V0: 13,
                V1: new(Val),
+               V2: 15,
+               V3: new(ValText),
        }
-       const want = `{"R0":"ref","R1":"ref","V0":"val","V1":"val"}`
+       const want = `{"R0":"ref","R1":"ref","R2":"\"ref\"","R3":"\"ref\"","V0":"val","V1":"val","V2":"\"val\"","V3":"\"val\""}`
        b, err := Marshal(&s)
        if err != nil {
                t.Fatalf("Marshal: %v", err)
@@ -175,15 +203,32 @@ func (C) MarshalJSON() ([]byte, error) {
        return []byte(`"<&>"`), nil
 }
 
+// CText implements Marshaler and returns unescaped text.
+type CText int
+
+func (CText) MarshalText() ([]byte, error) {
+       return []byte(`"<&>"`), nil
+}
+
 func TestMarshalerEscaping(t *testing.T) {
        var c C
-       const want = `"\u003c\u0026\u003e"`
+       want := `"\u003c\u0026\u003e"`
        b, err := Marshal(c)
        if err != nil {
-               t.Fatalf("Marshal: %v", err)
+               t.Fatalf("Marshal(c): %v", err)
        }
        if got := string(b); got != want {
-               t.Errorf("got %q, want %q", got, want)
+               t.Errorf("Marshal(c) = %#q, want %#q", got, want)
+       }
+
+       var ct CText
+       want = `"\"\u003c\u0026\u003e\""`
+       b, err = Marshal(ct)
+       if err != nil {
+               t.Fatalf("Marshal(ct): %v", err)
+       }
+       if got := string(b); got != want {
+               t.Errorf("Marshal(ct) = %#q, want %#q", got, want)
        }
 }
 
@@ -310,3 +355,49 @@ func TestDuplicatedFieldDisappears(t *testing.T) {
                t.Fatalf("Marshal: got %s want %s", got, want)
        }
 }
+
+func TestStringBytes(t *testing.T) {
+       // Test that encodeState.stringBytes and encodeState.string use the same encoding.
+       es := &encodeState{}
+       var r []rune
+       for i := '\u0000'; i <= unicode.MaxRune; i++ {
+               r = append(r, i)
+       }
+       s := string(r) + "\xff\xff\xffhello" // some invalid UTF-8 too
+       _, err := es.string(s)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       esBytes := &encodeState{}
+       _, err = esBytes.stringBytes([]byte(s))
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       enc := es.Buffer.String()
+       encBytes := esBytes.Buffer.String()
+       if enc != encBytes {
+               i := 0
+               for i < len(enc) && i < len(encBytes) && enc[i] == encBytes[i] {
+                       i++
+               }
+               enc = enc[i:]
+               encBytes = encBytes[i:]
+               i = 0
+               for i < len(enc) && i < len(encBytes) && enc[len(enc)-i-1] == encBytes[len(encBytes)-i-1] {
+                       i++
+               }
+               enc = enc[:len(enc)-i]
+               encBytes = encBytes[:len(encBytes)-i]
+
+               if len(enc) > 20 {
+                       enc = enc[:20] + "..."
+               }
+               if len(encBytes) > 20 {
+                       encBytes = encBytes[:20] + "..."
+               }
+
+               t.Errorf("encodings differ at %#q vs %#q", enc, encBytes)
+       }
+}