]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/json: add "overflow" struct tag option
authorAndrew Gerrand <adg@golang.org>
Thu, 29 Aug 2013 04:39:55 +0000 (14:39 +1000)
committerAndrew Gerrand <adg@golang.org>
Thu, 29 Aug 2013 04:39:55 +0000 (14:39 +1000)
Fixes #6213.

R=golang-dev, dsymonds, bradfitz
CC=golang-dev
https://golang.org/cl/13180043

src/pkg/encoding/json/decode.go
src/pkg/encoding/json/decode_test.go
src/pkg/encoding/json/encode.go
src/pkg/encoding/json/encode_test.go
src/pkg/encoding/json/example_test.go

index b6c23cc77a1d0548790291b10c244868d6c7d155..c3167674428fa40a7d4622e91aa3693b7adbae59 100644 (file)
@@ -61,6 +61,10 @@ import (
 // Instead, they are replaced by the Unicode replacement
 // character U+FFFD.
 //
+// When unmarshalling into struct values, a struct field of map type with the
+// "overflow" option will store values whose keys do not match the other struct
+// fields.
+//
 func Unmarshal(data []byte, v interface{}) error {
        // Check for well-formedness.
        // Avoids filling out half a data structure
@@ -489,8 +493,6 @@ func (d *decodeState) object(v reflect.Value) {
                return
        }
 
-       var mapElem reflect.Value
-
        for {
                // Read opening " of string key or closing }.
                op := d.scanWhile(scanSkipSpace)
@@ -512,44 +514,7 @@ func (d *decodeState) object(v reflect.Value) {
                }
 
                // Figure out field corresponding to key.
-               var subv reflect.Value
-               destring := false // whether the value is wrapped in a string to be decoded first
-
-               if v.Kind() == reflect.Map {
-                       elemType := v.Type().Elem()
-                       if !mapElem.IsValid() {
-                               mapElem = reflect.New(elemType).Elem()
-                       } else {
-                               mapElem.Set(reflect.Zero(elemType))
-                       }
-                       subv = mapElem
-               } else {
-                       var f *field
-                       fields := cachedTypeFields(v.Type())
-                       for i := range fields {
-                               ff := &fields[i]
-                               if ff.name == key {
-                                       f = ff
-                                       break
-                               }
-                               if f == nil && strings.EqualFold(ff.name, key) {
-                                       f = ff
-                               }
-                       }
-                       if f != nil {
-                               subv = v
-                               destring = f.quoted
-                               for _, i := range f.index {
-                                       if subv.Kind() == reflect.Ptr {
-                                               if subv.IsNil() {
-                                                       subv.Set(reflect.New(subv.Type().Elem()))
-                                               }
-                                               subv = subv.Elem()
-                                       }
-                                       subv = subv.Field(i)
-                               }
-                       }
-               }
+               subv, mapv, destring := subValue(v, key)
 
                // Read : before value.
                if op == scanSkipSpace {
@@ -569,9 +534,9 @@ func (d *decodeState) object(v reflect.Value) {
 
                // Write value back to map;
                // if using struct, subv points into struct already.
-               if v.Kind() == reflect.Map {
-                       kv := reflect.ValueOf(key).Convert(v.Type().Key())
-                       v.SetMapIndex(kv, subv)
+               if mapv.IsValid() {
+                       kv := reflect.ValueOf(key).Convert(mapv.Type().Key())
+                       mapv.SetMapIndex(kv, subv)
                }
 
                // Next token must be , or }.
@@ -585,6 +550,57 @@ func (d *decodeState) object(v reflect.Value) {
        }
 }
 
+// subValue returns (and allocates, if necessary) the field in the struct or
+// map v whose name matches key.
+func subValue(v reflect.Value, key string) (subv, mapv reflect.Value, destring bool) {
+       // Create new map element.
+       if v.Kind() == reflect.Map {
+               subv = reflect.New(v.Type().Elem()).Elem()
+               mapv = v
+               return
+       }
+
+       // Get struct field.
+       var f *field
+       fields := cachedTypeFields(v.Type())
+       for i := range fields {
+               ff := &fields[i]
+               if ff.name == key {
+                       f = ff
+                       break
+               }
+               if f == nil && strings.EqualFold(ff.name, key) {
+                       f = ff
+               }
+       }
+       if f != nil {
+               subv = fieldByIndex(v, f.index, true)
+               destring = f.quoted
+               return
+       }
+
+       // Decode into overflow field if present.
+       for _, f := range fields {
+               if f.overflow {
+                       // Find overflow field.
+                       mapv = fieldByIndex(v, f.index, true)
+                       if k := mapv.Kind(); k != reflect.Map {
+                               panic("unsupported overflow field kind: " + k.String())
+                       }
+                       // Make map if necessary.
+                       if mapv.IsNil() {
+                               mapv.Set(reflect.MakeMap(mapv.Type()))
+                       }
+                       // Create new map element.
+                       subv = reflect.New(mapv.Type().Elem()).Elem()
+                       return
+               }
+       }
+
+       // Not found.
+       return
+}
+
 // literal consumes a literal from d.data[d.off-1:], decoding into the value v.
 // The first byte of the literal has been read already
 // (that's how the caller knows it's a literal).
index 6635ba6ec681597a55120ddc10b34f9298f85a9e..4531e99656a61c5b9988252d98a2e97c59722709 100644 (file)
@@ -1275,3 +1275,23 @@ func TestSkipArrayObjects(t *testing.T) {
                t.Errorf("got error %q, want nil", err)
        }
 }
+
+func TestDecodeOverflow(t *testing.T) {
+       json := `{"A":1,"B":2,"C":3}`
+       type S struct {
+               A int
+               E map[string]interface{} `json:",overflow"`
+               C int
+       }
+       var (
+               want = S{1, map[string]interface{}{"B": float64(2)}, 3}
+               dest S
+       )
+       err := Unmarshal([]byte(json), &dest)
+       if err != nil {
+               t.Errorf("got error %q, want nil", err)
+       }
+       if !reflect.DeepEqual(dest, want) {
+               t.Errorf("Got %+v; want %+v", dest, want)
+       }
+}
index f951250e980b7eb19c86cb69624027bc9e1ce346..590010f3b24ffc076e01d3d6c3979da9e25208f5 100644 (file)
@@ -109,6 +109,10 @@ import (
 // an anonymous struct field in both current and earlier versions, give the field
 // a JSON tag of "-".
 //
+// The "overflow" option may be used with a struct field of a map type to
+// indicate that the map contents should be marshalled as if the keys are part
+// of the struct object itself.
+//
 // Map values encode as JSON objects.
 // The map's key type must be string; the object keys are used directly
 // as map keys.
@@ -239,6 +243,32 @@ var hex = "0123456789abcdef"
 type encodeState struct {
        bytes.Buffer // accumulated output
        scratch      [64]byte
+       overflow     int
+}
+
+func (e *encodeState) startOverflow() {
+       e.overflow = e.Len()
+}
+
+func (e *encodeState) endOverflow() {
+       if e.overflow == 0 {
+               panic("endOverflow called before startOverflow")
+       }
+       start, end := e.overflow, e.Len()
+       b := e.Bytes()
+       if b[start] == '{' && b[end-1] == '}' {
+               // Remove surrounding { and }.
+               copy(b[start:], b[start+1:])
+               e.Truncate(end - 2)
+       } else if bytes.Equal(b[start:end], []byte("null")) {
+               // Drop "null".
+               e.Truncate(start)
+       }
+       // Remove trailing comma if overflow value was null or {}.
+       if start > 0 && e.Len() == start && b[start-1] == ',' {
+               e.Truncate(start - 1)
+       }
+       e.overflow = 0
 }
 
 // TODO(bradfitz): use a sync.Cache here
@@ -582,7 +612,7 @@ func (se *structEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
        e.WriteByte('{')
        first := true
        for i, f := range se.fields {
-               fv := fieldByIndex(v, f.index)
+               fv := fieldByIndex(v, f.index, false)
                if !fv.IsValid() || f.omitEmpty && isEmptyValue(fv) {
                        continue
                }
@@ -591,6 +621,16 @@ func (se *structEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
                } else {
                        e.WriteByte(',')
                }
+               if f.overflow {
+                       if tenc := se.fieldEncs[i]; tenc != nil {
+                               e.startOverflow()
+                               tenc(e, fv, f.quoted)
+                               e.endOverflow()
+                       } else {
+                               panic("no encoder for " + fv.String())
+                       }
+                       continue
+               }
                e.string(f.name)
                e.WriteByte(':')
                if tenc := se.fieldEncs[i]; tenc != nil {
@@ -610,7 +650,7 @@ func newStructEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
                fieldEncs: make([]encoderFunc, len(fields)),
        }
        for i, f := range fields {
-               vxf := fieldByIndex(vx, f.index)
+               vxf := fieldByIndex(vx, f.index, false)
                if vxf.IsValid() {
                        se.fieldEncs[i] = typeEncoder(vxf.Type(), vxf)
                }
@@ -750,11 +790,16 @@ func isValidTag(s string) bool {
        return true
 }
 
-func fieldByIndex(v reflect.Value, index []int) reflect.Value {
+// fieldByIndex fetches (and allocates, if create is true) the field in v
+// indentified by index.
+func fieldByIndex(v reflect.Value, index []int, create bool) reflect.Value {
        for _, i := range index {
                if v.Kind() == reflect.Ptr {
                        if v.IsNil() {
-                               return reflect.Value{}
+                               if !create {
+                                       return reflect.Value{}
+                               }
+                               v.Set(reflect.New(v.Type().Elem()))
                        }
                        v = v.Elem()
                }
@@ -926,6 +971,7 @@ type field struct {
        typ       reflect.Type
        omitEmpty bool
        quoted    bool
+       overflow  bool
 }
 
 // byName sorts field by name, breaking ties with depth,
@@ -1027,8 +1073,15 @@ func typeFields(t reflect.Type) []field {
                                        if name == "" {
                                                name = sf.Name
                                        }
-                                       fields = append(fields, field{name, tagged, index, ft,
-                                               opts.Contains("omitempty"), opts.Contains("string")})
+                                       fields = append(fields, field{
+                                               name:      name,
+                                               tag:       tagged,
+                                               index:     index,
+                                               typ:       ft,
+                                               omitEmpty: opts.Contains("omitempty"),
+                                               quoted:    opts.Contains("string"),
+                                               overflow:  opts.Contains("overflow")},
+                                       )
                                        if count[f.typ] > 1 {
                                                // If there were multiple instances, add a second,
                                                // so that the annihilation code will see a duplicate.
index 7052e1db7c93228e53d479056dadd13f5f6980ba..f1bb144a0e7d4d4ed8d195d8798f2d04e7ef9f8d 100644 (file)
@@ -401,3 +401,59 @@ func TestStringBytes(t *testing.T) {
                t.Errorf("encodings differ at %#q vs %#q", enc, encBytes)
        }
 }
+
+func TestEncodeOverflow(t *testing.T) {
+       for _, c := range []struct {
+               in   interface{}
+               want string
+       }{
+               {
+                       struct {
+                               A int
+                               E map[string]interface{} `json:",overflow"`
+                               C int
+                       }{
+                               A: 12,
+                               E: map[string]interface{}{"B": 42},
+                               C: 64,
+                       },
+                       `{"A":12,"B":42,"C":64}`,
+               },
+               {
+                       struct {
+                               E map[string]interface{} `json:",overflow"`
+                       }{
+                               E: map[string]interface{}{"B": 42},
+                       },
+                       `{"B":42}`,
+               },
+               {
+                       struct {
+                               A int
+                               E map[string]interface{} `json:",overflow"`
+                       }{
+                               A: 12,
+                       },
+                       `{"A":12}`,
+               },
+               {
+                       struct {
+                               A int
+                               E map[string]interface{} `json:",overflow"`
+                       }{
+                               A: 12,
+                               E: map[string]interface{}{},
+                       },
+                       `{"A":12}`,
+               },
+       } {
+               b, err := Marshal(c.in)
+               if err != nil {
+                       t.Error(err)
+                       continue
+               }
+               if got := string(b); got != c.want {
+                       t.Errorf("Marshal(%q) = %s, want %s", c.in, got, c.want)
+               }
+       }
+}
index b8d150eda504c2f986da19a6f94dd2f025bdc449..1f84e151e557983ccced8cdfd2b234f2ee010e62 100644 (file)
@@ -33,6 +33,29 @@ func ExampleMarshal() {
        // {"ID":1,"Name":"Reds","Colors":["Crimson","Red","Ruby","Maroon"]}
 }
 
+func ExampleMarshal_overflow() {
+       type Record struct {
+               ID   string
+               Seq  int
+               Meta map[string]string `json:",overflow"`
+       }
+       r := Record{
+               ID:  "CheeseWhiz",
+               Seq: 42,
+               Meta: map[string]string{
+                       "Created":   "1980-06-20",
+                       "Destroyed": "1998-02-06",
+               },
+       }
+       b, err := json.Marshal(r)
+       if err != nil {
+               fmt.Println("error:", err)
+       }
+       os.Stdout.Write(b)
+       // Output:
+       // {"ID":"CheeseWhiz","Seq":42,"Created":"1980-06-20","Destroyed":"1998-02-06"}
+}
+
 func ExampleUnmarshal() {
        var jsonBlob = []byte(`[
                {"Name": "Platypus", "Order": "Monotremata"},
@@ -52,6 +75,29 @@ func ExampleUnmarshal() {
        // [{Name:Platypus Order:Monotremata} {Name:Quoll Order:Dasyuromorphia}]
 }
 
+func ExampleUnmarshal_overflow() {
+       var jsonBlob = []byte(`
+               {
+                       "Token": "Kaip4uM1ieng6Eiw",
+                       "User":  "bimmler",
+                       "Animal": "rabbit"
+               }
+       `)
+       type Auth struct {
+               Token string
+               User  string
+               Extra map[string]string `json:",overflow"`
+       }
+       var auth Auth
+       err := json.Unmarshal(jsonBlob, &auth)
+       if err != nil {
+               fmt.Println("error:", err)
+       }
+       fmt.Printf("%+v", auth)
+       // Output:
+       // {Token:Kaip4uM1ieng6Eiw User:bimmler Extra:map[Animal:rabbit]}
+}
+
 // This example uses a Decoder to decode a stream of distinct JSON values.
 func ExampleDecoder() {
        const jsonStream = `