]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/json: allow non-string type keys for (un-)marshal
authorAugusto Roman <aroman@gmail.com>
Tue, 8 Mar 2016 20:41:35 +0000 (12:41 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 5 Apr 2016 15:08:04 +0000 (15:08 +0000)
This CL allows JSON-encoding & -decoding maps whose keys are types that
implement encoding.TextMarshaler / TextUnmarshaler.

During encode, the map keys are marshaled upfront so that they can be
sorted.

Fixes #12146

Change-Id: I43809750a7ad82a3603662f095c7baf75fd172da
Reviewed-on: https://go-review.googlesource.com/20356
Run-TryBot: Caleb Spare <cespare@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/encoding/json/decode.go
src/encoding/json/decode_test.go
src/encoding/json/encode.go
src/encoding/json/encode_test.go

index 3e4b16e4107d9c97e35134eae822a6c575d600a0..a7ff8cf3dc377efcb531eafd7fd136e0c95a013f 100644 (file)
@@ -61,10 +61,11 @@ import (
 // If the JSON array is smaller than the Go array,
 // the additional Go array elements are set to zero values.
 //
-// To unmarshal a JSON object into a string-keyed map, Unmarshal first
-// establishes a map to use, If the map is nil, Unmarshal allocates a new map.
-// Otherwise Unmarshal reuses the existing map, keeping existing entries.
-// Unmarshal then stores key-value pairs from the JSON object into the map.
+// To unmarshal a JSON object into a map, Unmarshal first establishes a map to
+// use, If the map is nil, Unmarshal allocates a new map. Otherwise Unmarshal
+// reuses the existing map, keeping existing entries. Unmarshal then stores key-
+// value pairs from the JSON object into the map.  The map's key type must
+// either be a string or implement encoding.TextUnmarshaler.
 //
 // If a JSON value is not appropriate for a given target type,
 // or if a JSON number overflows the target type, Unmarshal
@@ -549,6 +550,7 @@ func (d *decodeState) array(v reflect.Value) {
 }
 
 var nullLiteral = []byte("null")
+var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
 
 // object consumes an object from d.data[d.off-1:], decoding into the value v.
 // the first byte ('{') of the object has been read already.
@@ -577,12 +579,15 @@ func (d *decodeState) object(v reflect.Value) {
                return
        }
 
-       // Check type of target: struct or map[string]T
+       // Check type of target:
+       //   struct or
+       //   map[string]T or map[encoding.TextUnmarshaler]T
        switch v.Kind() {
        case reflect.Map:
-               // map must have string kind
+               // Map key must either have string kind or be an encoding.TextUnmarshaler.
                t := v.Type()
-               if t.Key().Kind() != reflect.String {
+               if t.Key().Kind() != reflect.String &&
+                       !reflect.PtrTo(t.Key()).Implements(textUnmarshalerType) {
                        d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)})
                        d.off--
                        d.next() // skip over { } in input
@@ -687,7 +692,18 @@ func (d *decodeState) object(v reflect.Value) {
                // Write value back to map;
                // if using struct, subv points into struct already.
                if v.Kind() == reflect.Map {
-                       kv := reflect.ValueOf(key).Convert(v.Type().Key())
+                       kt := v.Type().Key()
+                       var kv reflect.Value
+                       switch {
+                       case kt.Kind() == reflect.String:
+                               kv = reflect.ValueOf(key).Convert(v.Type().Key())
+                       case reflect.PtrTo(kt).Implements(textUnmarshalerType):
+                               kv = reflect.New(v.Type().Key())
+                               d.literalStore(item, kv, true)
+                               kv = kv.Elem()
+                       default:
+                               panic("json: Unexpected key type") // should never occur
+                       }
                        v.SetMapIndex(kv, subv)
                }
 
index 98291f85e96243b4041034ba3d2e6df61fec291a..30e46ca44f07bdd5d41fa8afda11758e5807024f 100644 (file)
@@ -7,6 +7,7 @@ package json
 import (
        "bytes"
        "encoding"
+       "errors"
        "fmt"
        "image"
        "net"
@@ -68,16 +69,20 @@ type ustruct struct {
 }
 
 type unmarshalerText struct {
-       T bool
+       A, B string
 }
 
 // needed for re-marshaling tests
-func (u *unmarshalerText) MarshalText() ([]byte, error) {
-       return []byte(""), nil
+func (u unmarshalerText) MarshalText() ([]byte, error) {
+       return []byte(u.A + ":" + u.B), nil
 }
 
 func (u *unmarshalerText) UnmarshalText(b []byte) error {
-       *u = unmarshalerText{true} // All we need to see that UnmarshalText is called.
+       pos := bytes.Index(b, []byte(":"))
+       if pos == -1 {
+               return errors.New("missing separator")
+       }
+       u.A, u.B = string(b[:pos]), string(b[pos+1:])
        return nil
 }
 
@@ -95,12 +100,16 @@ var (
        umslicep = new([]unmarshaler)
        umstruct = ustruct{unmarshaler{true}}
 
-       um0T, um1T unmarshalerText // target2 of unmarshaling
-       umpT       = &um1T
-       umtrueT    = unmarshalerText{true}
-       umsliceT   = []unmarshalerText{{true}}
-       umslicepT  = new([]unmarshalerText)
-       umstructT  = ustructText{unmarshalerText{true}}
+       um0T, um1T   unmarshalerText // target2 of unmarshaling
+       umpType      = &um1T
+       umtrueXY     = unmarshalerText{"x", "y"}
+       umsliceXY    = []unmarshalerText{{"x", "y"}}
+       umslicepType = new([]unmarshalerText)
+       umstructType = new(ustructText)
+       umstructXY   = ustructText{unmarshalerText{"x", "y"}}
+
+       ummapType = map[unmarshalerText]bool{}
+       ummapXY   = map[unmarshalerText]bool{unmarshalerText{"x", "y"}: true}
 )
 
 // Test data structures for anonymous fields.
@@ -302,14 +311,19 @@ var unmarshalTests = []unmarshalTest{
        {in: `{"T":false}`, ptr: &ump, out: &umtrue},
        {in: `[{"T":false}]`, ptr: &umslice, out: umslice},
        {in: `[{"T":false}]`, ptr: &umslicep, out: &umslice},
-       {in: `{"M":{"T":false}}`, ptr: &umstruct, out: umstruct},
+       {in: `{"M":{"T":"x:y"}}`, ptr: &umstruct, out: umstruct},
 
        // UnmarshalText interface test
-       {in: `"X"`, ptr: &um0T, out: umtrueT}, // use "false" so test will fail if custom unmarshaler is not called
-       {in: `"X"`, ptr: &umpT, out: &umtrueT},
-       {in: `["X"]`, ptr: &umsliceT, out: umsliceT},
-       {in: `["X"]`, ptr: &umslicepT, out: &umsliceT},
-       {in: `{"M":"X"}`, ptr: &umstructT, out: umstructT},
+       {in: `"x:y"`, ptr: &um0T, out: umtrueXY},
+       {in: `"x:y"`, ptr: &umpType, out: &umtrueXY},
+       {in: `["x:y"]`, ptr: &umsliceXY, out: umsliceXY},
+       {in: `["x:y"]`, ptr: &umslicepType, out: &umsliceXY},
+       {in: `{"M":"x:y"}`, ptr: umstructType, out: umstructXY},
+
+       // Map keys can be encoding.TextUnmarshalers
+       {in: `{"x:y":true}`, ptr: &ummapType, out: ummapXY},
+       // If multiple values for the same key exists, only the most recent value is used.
+       {in: `{"x:y":false,"x:y":true}`, ptr: &ummapType, out: ummapXY},
 
        // Overwriting of data.
        // This is different from package xml, but it's what we've always done.
@@ -426,11 +440,23 @@ var unmarshalTests = []unmarshalTest{
                out: "hello\ufffd\ufffd\ufffd\ufffd\ufffd\ufffdworld",
        },
 
-       // issue 8305
+       // Used to be issue 8305, but time.Time implements encoding.TextUnmarshaler so this works now.
        {
                in:  `{"2009-11-10T23:00:00Z": "hello world"}`,
                ptr: &map[time.Time]string{},
-               err: &UnmarshalTypeError{"object", reflect.TypeOf(map[time.Time]string{}), 1},
+               out: map[time.Time]string{time.Date(2009, 11, 10, 23, 0, 0, 0, time.UTC): "hello world"},
+       },
+
+       // issue 8305
+       {
+               in:  `{"2009-11-10T23:00:00Z": "hello world"}`,
+               ptr: &map[Point]string{},
+               err: &UnmarshalTypeError{"object", reflect.TypeOf(map[Point]string{}), 1},
+       },
+       {
+               in:  `{"asdf": "hello world"}`,
+               ptr: &map[unmarshaler]string{},
+               err: &UnmarshalTypeError{"object", reflect.TypeOf(map[unmarshaler]string{}), 1},
        },
 }
 
index 982561d6ecd1eafd4c59bf243c53b09f39508e26..bcae6838cc0574daf6ca0b2e88598004eaf68b4d 100644 (file)
@@ -116,8 +116,8 @@ import (
 // an anonymous struct field in both current and earlier versions, give the field
 // a JSON tag of "-".
 //
-// Map values encode as JSON objects.
-// The map's key type must be string; the map keys are used as JSON object
+// Map values encode as JSON objects. The map's key type must either be a string
+// or implement encoding.TextMarshaler.  The map keys are used as JSON object
 // keys, subject to the UTF-8 coercion described for string values above.
 //
 // Pointer values encode as the value pointed to.
@@ -611,21 +611,31 @@ func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
                return
        }
        e.WriteByte('{')
-       var sv stringValues = v.MapKeys()
-       sort.Sort(sv)
-       for i, k := range sv {
+
+       // Extract and sort the keys.
+       keys := v.MapKeys()
+       sv := make([]reflectWithString, len(keys))
+       for i, v := range keys {
+               sv[i].v = v
+               if err := sv[i].resolve(); err != nil {
+                       e.error(&MarshalerError{v.Type(), err})
+               }
+       }
+       sort.Sort(byString(sv))
+
+       for i, kv := range sv {
                if i > 0 {
                        e.WriteByte(',')
                }
-               e.string(k.String())
+               e.string(kv.s)
                e.WriteByte(':')
-               me.elemEnc(e, v.MapIndex(k), false)
+               me.elemEnc(e, v.MapIndex(kv.v), false)
        }
        e.WriteByte('}')
 }
 
 func newMapEncoder(t reflect.Type) encoderFunc {
-       if t.Key().Kind() != reflect.String {
+       if t.Key().Kind() != reflect.String && !t.Key().Implements(textMarshalerType) {
                return unsupportedTypeEncoder
        }
        me := &mapEncoder{typeEncoder(t.Elem())}
@@ -775,14 +785,29 @@ func typeByIndex(t reflect.Type, index []int) reflect.Type {
        return t
 }
 
-// stringValues is a slice of reflect.Value holding *reflect.StringValue.
+type reflectWithString struct {
+       v reflect.Value
+       s string
+}
+
+func (w *reflectWithString) resolve() error {
+       if w.v.Kind() == reflect.String {
+               w.s = w.v.String()
+               return nil
+       }
+       buf, err := w.v.Interface().(encoding.TextMarshaler).MarshalText()
+       w.s = string(buf)
+       return err
+}
+
+// byString is a slice of reflectWithString where the reflect.Value is either
+// a string or an encoding.TextMarshaler.
 // It implements the methods to sort by string.
-type stringValues []reflect.Value
+type byString []reflectWithString
 
-func (sv stringValues) Len() int           { return len(sv) }
-func (sv stringValues) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
-func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
-func (sv stringValues) get(i int) string   { return sv[i].String() }
+func (sv byString) Len() int           { return len(sv) }
+func (sv byString) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
+func (sv byString) Less(i, j int) bool { return sv[i].s < sv[j].s }
 
 // NOTE: keep in sync with stringBytes below.
 func (e *encodeState) string(s string) int {
index c00491e00c03163f1c1512819308fea548aef0de..eed40a42723ca994acbb167904f9d80656f1aa76 100644 (file)
@@ -536,3 +536,19 @@ func TestEncodeString(t *testing.T) {
                }
        }
 }
+
+func TestTextMarshalerMapKeysAreSorted(t *testing.T) {
+       b, err := Marshal(map[unmarshalerText]int{
+               {"x", "y"}: 1,
+               {"y", "x"}: 2,
+               {"a", "z"}: 3,
+               {"z", "a"}: 4,
+       })
+       if err != nil {
+               t.Fatalf("Failed to Marshal text.Marshaler: %v", err)
+       }
+       const want = `{"a:z":3,"x:y":1,"y:x":2,"z:a":4}`
+       if string(b) != want {
+               t.Errorf("Marshal map with text.Marshaler keys: got %#q, want %#q", b, want)
+       }
+}