]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/json: reuse values when decoding map elements
authorDaniel Martí <mvdan@mvdan.cc>
Wed, 29 May 2019 10:09:50 +0000 (11:09 +0100)
committerDaniel Martí <mvdan@mvdan.cc>
Fri, 8 May 2020 21:19:55 +0000 (21:19 +0000)
When we decode into a struct, each input key-value may be decoded into
one of the struct's fields. Particularly, existing data isn't dropped,
so that some sub-fields can be decoded into without zeroing all other
data.

However, decoding into a map behaved in the opposite way. Whenever a
key-value was decoded, it completely replaced the previous map element.
If the map contained any non-zero data in that key, it's dropped.

Instead, try to reuse the existing element value if possible. If the map
element type is a pointer, and the value is non-nil, we can decode
directly into it. If it's not a pointer, make a copy and decode into
that copy, as map element values aren't addressable.

This means we have to parse and convert the map element key before the
value, to be able to obtain the existing element value. This is fine,
though. Moreover, reporting errors on the key before the value follows
the input order more closely.

Finally, add a test to explore the four combinations, involving pointer
and non-pointer, and non-zero and zero values. A table-driven test
wasn't used, as each case required different checks, such as checking
that the non-nil pointer case doesn't end up with a different pointer.

Fixes #31924.

Change-Id: I5ca40c9963a98aaf92f26f0b35843c021028dfca
Reviewed-on: https://go-review.googlesource.com/c/go/+/179337
Run-TryBot: Daniel Martí <mvdan@mvdan.cc>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/encoding/json/decode.go
src/encoding/json/decode_test.go

index 5f34af44eacf0c57f59c14896c62d760414bb671..5acc6d8b26ee09412fbb4748908f5ed00175fa39 100644 (file)
@@ -677,7 +677,6 @@ func (d *decodeState) object(v reflect.Value) error {
                return nil
        }
 
-       var mapElem reflect.Value
        origErrorContext := d.errorContext
 
        for {
@@ -701,17 +700,66 @@ func (d *decodeState) object(v reflect.Value) error {
                }
 
                // Figure out field corresponding to key.
-               var subv reflect.Value
+               var kv, subv reflect.Value
                destring := false // whether the value is wrapped in a string to be decoded first
 
                if v.Kind() == reflect.Map {
-                       elemType := t.Elem()
-                       if !mapElem.IsValid() {
-                               mapElem = reflect.New(elemType).Elem()
-                       } else {
-                               mapElem.Set(reflect.Zero(elemType))
+                       // First, figure out the key value from the input.
+                       kt := t.Key()
+                       switch {
+                       case reflect.PtrTo(kt).Implements(textUnmarshalerType):
+                               kv = reflect.New(kt)
+                               if err := d.literalStore(item, kv, true); err != nil {
+                                       return err
+                               }
+                               kv = kv.Elem()
+                       case kt.Kind() == reflect.String:
+                               kv = reflect.ValueOf(key).Convert(kt)
+                       default:
+                               switch kt.Kind() {
+                               case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+                                       s := string(key)
+                                       n, err := strconv.ParseInt(s, 10, 64)
+                                       if err != nil || reflect.Zero(kt).OverflowInt(n) {
+                                               d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)})
+                                               break
+                                       }
+                                       kv = reflect.ValueOf(n).Convert(kt)
+                               case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+                                       s := string(key)
+                                       n, err := strconv.ParseUint(s, 10, 64)
+                                       if err != nil || reflect.Zero(kt).OverflowUint(n) {
+                                               d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)})
+                                               break
+                                       }
+                                       kv = reflect.ValueOf(n).Convert(kt)
+                               default:
+                                       panic("json: Unexpected key type") // should never occur
+                               }
+                       }
+
+                       // Then, decide what element value we'll decode into.
+                       et := t.Elem()
+                       if kv.IsValid() {
+                               if existing := v.MapIndex(kv); !existing.IsValid() {
+                                       // Nothing to reuse.
+                               } else if et.Kind() == reflect.Ptr {
+                                       // Pointer; decode directly into it if non-nil.
+                                       if !existing.IsNil() {
+                                               subv = existing
+                                       }
+                               } else {
+                                       // Non-pointer. Make a copy and decode into the
+                                       // addressable copy. Don't just use a new/zero
+                                       // value, as that would lose existing data.
+                                       subv = reflect.New(et).Elem()
+                                       subv.Set(existing)
+                               }
+                       }
+                       if !subv.IsValid() {
+                               // We couldn't reuse an existing value.
+                               subv = reflect.New(et).Elem()
                        }
-                       subv = mapElem
                } else {
                        var f *field
                        if i, ok := fields.nameIndex[string(key)]; ok {
@@ -790,43 +838,8 @@ func (d *decodeState) object(v reflect.Value) error {
 
                // Write value back to map;
                // if using struct, subv points into struct already.
-               if v.Kind() == reflect.Map {
-                       kt := t.Key()
-                       var kv reflect.Value
-                       switch {
-                       case reflect.PtrTo(kt).Implements(textUnmarshalerType):
-                               kv = reflect.New(kt)
-                               if err := d.literalStore(item, kv, true); err != nil {
-                                       return err
-                               }
-                               kv = kv.Elem()
-                       case kt.Kind() == reflect.String:
-                               kv = reflect.ValueOf(key).Convert(kt)
-                       default:
-                               switch kt.Kind() {
-                               case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-                                       s := string(key)
-                                       n, err := strconv.ParseInt(s, 10, 64)
-                                       if err != nil || reflect.Zero(kt).OverflowInt(n) {
-                                               d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)})
-                                               break
-                                       }
-                                       kv = reflect.ValueOf(n).Convert(kt)
-                               case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
-                                       s := string(key)
-                                       n, err := strconv.ParseUint(s, 10, 64)
-                                       if err != nil || reflect.Zero(kt).OverflowUint(n) {
-                                               d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)})
-                                               break
-                                       }
-                                       kv = reflect.ValueOf(n).Convert(kt)
-                               default:
-                                       panic("json: Unexpected key type") // should never occur
-                               }
-                       }
-                       if kv.IsValid() {
-                               v.SetMapIndex(kv, subv)
-                       }
+               if v.Kind() == reflect.Map && kv.IsValid() {
+                       v.SetMapIndex(kv, subv)
                }
 
                // Next token must be , or }.
index 5ac1022207c9e8c38504434dab9b46ee6cd36fd2..a62488d447af66db934db45975c702289bb1689d 100644 (file)
@@ -2569,3 +2569,57 @@ func TestUnmarshalMaxDepth(t *testing.T) {
                }
        }
 }
+
+func TestUnmarshalMapPointerElem(t *testing.T) {
+       type S struct{ Unchanged, Changed int }
+       input := []byte(`{"S":{"Changed":5}}`)
+       want := S{1, 5}
+
+       // First, a map with struct pointer elements. The key-value pair exists,
+       // so reuse the existing value.
+       s := &S{1, 2}
+       ptrMap := map[string]*S{"S": s}
+       if err := Unmarshal(input, &ptrMap); err != nil {
+               t.Fatal(err)
+       }
+       if s != ptrMap["S"] {
+               t.Fatal("struct pointer element in map was completely replaced")
+       }
+       if got := *s; got != want {
+               t.Fatalf("want %#v, got %#v", want, got)
+       }
+
+       // Second, a map with struct elements. The key-value pair exists, but
+       // the value isn't addresable, so make a copy and use that.
+       s = &S{1, 2}
+       strMap := map[string]S{"S": *s}
+       if err := Unmarshal(input, &strMap); err != nil {
+               t.Fatal(err)
+       }
+       if *s == strMap["S"] {
+               t.Fatal("struct element in map wasn't copied")
+       }
+       if got := strMap["S"]; got != want {
+               t.Fatalf("want %#v, got %#v", want, got)
+       }
+
+       // Finally, check the cases where the key-value pair exists, but the
+       // value is zero.
+       want = S{0, 5}
+
+       ptrMap = map[string]*S{"S": nil}
+       if err := Unmarshal(input, &ptrMap); err != nil {
+               t.Fatal(err)
+       }
+       if got := *ptrMap["S"]; got != want {
+               t.Fatalf("want %#v, got %#v", want, got)
+       }
+
+       strMap = map[string]S{"S": {}}
+       if err := Unmarshal(input, &strMap); err != nil {
+               t.Fatal(err)
+       }
+       if got := strMap["S"]; got != want {
+               t.Fatalf("want %#v, got %#v", want, got)
+       }
+}