From ffbd31e9f79ad8b6aaeceac1397678e237581064 Mon Sep 17 00:00:00 2001 From: Augusto Roman Date: Tue, 8 Mar 2016 12:41:35 -0800 Subject: [PATCH] encoding/json: allow non-string type keys for (un-)marshal This CL allows JSON-encoding & -decoding maps whose keys are types that implement encoding.TextMarshaler / TextUnmarshaler. During encode, the map keys are marshaled upfront so that they can be sorted. Fixes #12146 Change-Id: I43809750a7ad82a3603662f095c7baf75fd172da Reviewed-on: https://go-review.googlesource.com/20356 Run-TryBot: Caleb Spare TryBot-Result: Gobot Gobot Reviewed-by: Brad Fitzpatrick --- src/encoding/json/decode.go | 32 ++++++++++++----- src/encoding/json/decode_test.go | 62 ++++++++++++++++++++++---------- src/encoding/json/encode.go | 53 +++++++++++++++++++-------- src/encoding/json/encode_test.go | 16 +++++++++ 4 files changed, 123 insertions(+), 40 deletions(-) diff --git a/src/encoding/json/decode.go b/src/encoding/json/decode.go index 3e4b16e410..a7ff8cf3dc 100644 --- a/src/encoding/json/decode.go +++ b/src/encoding/json/decode.go @@ -61,10 +61,11 @@ import ( // If the JSON array is smaller than the Go array, // the additional Go array elements are set to zero values. // -// To unmarshal a JSON object into a string-keyed map, Unmarshal first -// establishes a map to use, If the map is nil, Unmarshal allocates a new map. -// Otherwise Unmarshal reuses the existing map, keeping existing entries. -// Unmarshal then stores key-value pairs from the JSON object into the map. +// To unmarshal a JSON object into a map, Unmarshal first establishes a map to +// use, If the map is nil, Unmarshal allocates a new map. Otherwise Unmarshal +// reuses the existing map, keeping existing entries. Unmarshal then stores key- +// value pairs from the JSON object into the map. The map's key type must +// either be a string or implement encoding.TextUnmarshaler. // // If a JSON value is not appropriate for a given target type, // or if a JSON number overflows the target type, Unmarshal @@ -549,6 +550,7 @@ func (d *decodeState) array(v reflect.Value) { } var nullLiteral = []byte("null") +var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem() // object consumes an object from d.data[d.off-1:], decoding into the value v. // the first byte ('{') of the object has been read already. @@ -577,12 +579,15 @@ func (d *decodeState) object(v reflect.Value) { return } - // Check type of target: struct or map[string]T + // Check type of target: + // struct or + // map[string]T or map[encoding.TextUnmarshaler]T switch v.Kind() { case reflect.Map: - // map must have string kind + // Map key must either have string kind or be an encoding.TextUnmarshaler. t := v.Type() - if t.Key().Kind() != reflect.String { + if t.Key().Kind() != reflect.String && + !reflect.PtrTo(t.Key()).Implements(textUnmarshalerType) { d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)}) d.off-- d.next() // skip over { } in input @@ -687,7 +692,18 @@ func (d *decodeState) object(v reflect.Value) { // Write value back to map; // if using struct, subv points into struct already. if v.Kind() == reflect.Map { - kv := reflect.ValueOf(key).Convert(v.Type().Key()) + kt := v.Type().Key() + var kv reflect.Value + switch { + case kt.Kind() == reflect.String: + kv = reflect.ValueOf(key).Convert(v.Type().Key()) + case reflect.PtrTo(kt).Implements(textUnmarshalerType): + kv = reflect.New(v.Type().Key()) + d.literalStore(item, kv, true) + kv = kv.Elem() + default: + panic("json: Unexpected key type") // should never occur + } v.SetMapIndex(kv, subv) } diff --git a/src/encoding/json/decode_test.go b/src/encoding/json/decode_test.go index 98291f85e9..30e46ca44f 100644 --- a/src/encoding/json/decode_test.go +++ b/src/encoding/json/decode_test.go @@ -7,6 +7,7 @@ package json import ( "bytes" "encoding" + "errors" "fmt" "image" "net" @@ -68,16 +69,20 @@ type ustruct struct { } type unmarshalerText struct { - T bool + A, B string } // needed for re-marshaling tests -func (u *unmarshalerText) MarshalText() ([]byte, error) { - return []byte(""), nil +func (u unmarshalerText) MarshalText() ([]byte, error) { + return []byte(u.A + ":" + u.B), nil } func (u *unmarshalerText) UnmarshalText(b []byte) error { - *u = unmarshalerText{true} // All we need to see that UnmarshalText is called. + pos := bytes.Index(b, []byte(":")) + if pos == -1 { + return errors.New("missing separator") + } + u.A, u.B = string(b[:pos]), string(b[pos+1:]) return nil } @@ -95,12 +100,16 @@ var ( umslicep = new([]unmarshaler) umstruct = ustruct{unmarshaler{true}} - um0T, um1T unmarshalerText // target2 of unmarshaling - umpT = &um1T - umtrueT = unmarshalerText{true} - umsliceT = []unmarshalerText{{true}} - umslicepT = new([]unmarshalerText) - umstructT = ustructText{unmarshalerText{true}} + um0T, um1T unmarshalerText // target2 of unmarshaling + umpType = &um1T + umtrueXY = unmarshalerText{"x", "y"} + umsliceXY = []unmarshalerText{{"x", "y"}} + umslicepType = new([]unmarshalerText) + umstructType = new(ustructText) + umstructXY = ustructText{unmarshalerText{"x", "y"}} + + ummapType = map[unmarshalerText]bool{} + ummapXY = map[unmarshalerText]bool{unmarshalerText{"x", "y"}: true} ) // Test data structures for anonymous fields. @@ -302,14 +311,19 @@ var unmarshalTests = []unmarshalTest{ {in: `{"T":false}`, ptr: &ump, out: &umtrue}, {in: `[{"T":false}]`, ptr: &umslice, out: umslice}, {in: `[{"T":false}]`, ptr: &umslicep, out: &umslice}, - {in: `{"M":{"T":false}}`, ptr: &umstruct, out: umstruct}, + {in: `{"M":{"T":"x:y"}}`, ptr: &umstruct, out: umstruct}, // UnmarshalText interface test - {in: `"X"`, ptr: &um0T, out: umtrueT}, // use "false" so test will fail if custom unmarshaler is not called - {in: `"X"`, ptr: &umpT, out: &umtrueT}, - {in: `["X"]`, ptr: &umsliceT, out: umsliceT}, - {in: `["X"]`, ptr: &umslicepT, out: &umsliceT}, - {in: `{"M":"X"}`, ptr: &umstructT, out: umstructT}, + {in: `"x:y"`, ptr: &um0T, out: umtrueXY}, + {in: `"x:y"`, ptr: &umpType, out: &umtrueXY}, + {in: `["x:y"]`, ptr: &umsliceXY, out: umsliceXY}, + {in: `["x:y"]`, ptr: &umslicepType, out: &umsliceXY}, + {in: `{"M":"x:y"}`, ptr: umstructType, out: umstructXY}, + + // Map keys can be encoding.TextUnmarshalers + {in: `{"x:y":true}`, ptr: &ummapType, out: ummapXY}, + // If multiple values for the same key exists, only the most recent value is used. + {in: `{"x:y":false,"x:y":true}`, ptr: &ummapType, out: ummapXY}, // Overwriting of data. // This is different from package xml, but it's what we've always done. @@ -426,11 +440,23 @@ var unmarshalTests = []unmarshalTest{ out: "hello\ufffd\ufffd\ufffd\ufffd\ufffd\ufffdworld", }, - // issue 8305 + // Used to be issue 8305, but time.Time implements encoding.TextUnmarshaler so this works now. { in: `{"2009-11-10T23:00:00Z": "hello world"}`, ptr: &map[time.Time]string{}, - err: &UnmarshalTypeError{"object", reflect.TypeOf(map[time.Time]string{}), 1}, + out: map[time.Time]string{time.Date(2009, 11, 10, 23, 0, 0, 0, time.UTC): "hello world"}, + }, + + // issue 8305 + { + in: `{"2009-11-10T23:00:00Z": "hello world"}`, + ptr: &map[Point]string{}, + err: &UnmarshalTypeError{"object", reflect.TypeOf(map[Point]string{}), 1}, + }, + { + in: `{"asdf": "hello world"}`, + ptr: &map[unmarshaler]string{}, + err: &UnmarshalTypeError{"object", reflect.TypeOf(map[unmarshaler]string{}), 1}, }, } diff --git a/src/encoding/json/encode.go b/src/encoding/json/encode.go index 982561d6ec..bcae6838cc 100644 --- a/src/encoding/json/encode.go +++ b/src/encoding/json/encode.go @@ -116,8 +116,8 @@ import ( // an anonymous struct field in both current and earlier versions, give the field // a JSON tag of "-". // -// Map values encode as JSON objects. -// The map's key type must be string; the map keys are used as JSON object +// Map values encode as JSON objects. The map's key type must either be a string +// or implement encoding.TextMarshaler. The map keys are used as JSON object // keys, subject to the UTF-8 coercion described for string values above. // // Pointer values encode as the value pointed to. @@ -611,21 +611,31 @@ func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) { return } e.WriteByte('{') - var sv stringValues = v.MapKeys() - sort.Sort(sv) - for i, k := range sv { + + // Extract and sort the keys. + keys := v.MapKeys() + sv := make([]reflectWithString, len(keys)) + for i, v := range keys { + sv[i].v = v + if err := sv[i].resolve(); err != nil { + e.error(&MarshalerError{v.Type(), err}) + } + } + sort.Sort(byString(sv)) + + for i, kv := range sv { if i > 0 { e.WriteByte(',') } - e.string(k.String()) + e.string(kv.s) e.WriteByte(':') - me.elemEnc(e, v.MapIndex(k), false) + me.elemEnc(e, v.MapIndex(kv.v), false) } e.WriteByte('}') } func newMapEncoder(t reflect.Type) encoderFunc { - if t.Key().Kind() != reflect.String { + if t.Key().Kind() != reflect.String && !t.Key().Implements(textMarshalerType) { return unsupportedTypeEncoder } me := &mapEncoder{typeEncoder(t.Elem())} @@ -775,14 +785,29 @@ func typeByIndex(t reflect.Type, index []int) reflect.Type { return t } -// stringValues is a slice of reflect.Value holding *reflect.StringValue. +type reflectWithString struct { + v reflect.Value + s string +} + +func (w *reflectWithString) resolve() error { + if w.v.Kind() == reflect.String { + w.s = w.v.String() + return nil + } + buf, err := w.v.Interface().(encoding.TextMarshaler).MarshalText() + w.s = string(buf) + return err +} + +// byString is a slice of reflectWithString where the reflect.Value is either +// a string or an encoding.TextMarshaler. // It implements the methods to sort by string. -type stringValues []reflect.Value +type byString []reflectWithString -func (sv stringValues) Len() int { return len(sv) } -func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] } -func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) } -func (sv stringValues) get(i int) string { return sv[i].String() } +func (sv byString) Len() int { return len(sv) } +func (sv byString) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] } +func (sv byString) Less(i, j int) bool { return sv[i].s < sv[j].s } // NOTE: keep in sync with stringBytes below. func (e *encodeState) string(s string) int { diff --git a/src/encoding/json/encode_test.go b/src/encoding/json/encode_test.go index c00491e00c..eed40a4272 100644 --- a/src/encoding/json/encode_test.go +++ b/src/encoding/json/encode_test.go @@ -536,3 +536,19 @@ func TestEncodeString(t *testing.T) { } } } + +func TestTextMarshalerMapKeysAreSorted(t *testing.T) { + b, err := Marshal(map[unmarshalerText]int{ + {"x", "y"}: 1, + {"y", "x"}: 2, + {"a", "z"}: 3, + {"z", "a"}: 4, + }) + if err != nil { + t.Fatalf("Failed to Marshal text.Marshaler: %v", err) + } + const want = `{"a:z":3,"x:y":1,"y:x":2,"z:a":4}` + if string(b) != want { + t.Errorf("Marshal map with text.Marshaler keys: got %#q, want %#q", b, want) + } +} -- 2.48.1