]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/json: handle anonymous fields
authorRuss Cox <rsc@golang.org>
Tue, 11 Sep 2012 03:31:40 +0000 (23:31 -0400)
committerRuss Cox <rsc@golang.org>
Tue, 11 Sep 2012 03:31:40 +0000 (23:31 -0400)
Fixes #3069.

R=golang-dev, n13m3y3r
CC=golang-dev
https://golang.org/cl/6460044

src/pkg/encoding/json/decode.go
src/pkg/encoding/json/decode_test.go
src/pkg/encoding/json/encode.go

index bce868bb8fd1c60a82322427afa2ced17b0d3355..47e3d89aa358ca0a8cb529aa779b45ffd0730fae 100644 (file)
@@ -493,56 +493,39 @@ func (d *decodeState) object(v reflect.Value) {
                        }
                        subv = mapElem
                } else {
-                       var f reflect.StructField
-                       var ok bool
-                       st := sv.Type()
-                       for i := 0; i < sv.NumField(); i++ {
-                               sf := st.Field(i)
-                               tag := sf.Tag.Get("json")
-                               if tag == "-" {
-                                       // Pretend this field doesn't exist.
-                                       continue
+                       var f *field
+                       fields := cachedTypeFields(sv.Type())
+                       for i := range fields {
+                               ff := &fields[i]
+                               if ff.name == key {
+                                       f = ff
+                                       break
                                }
-                               if sf.Anonymous {
-                                       // Pretend this field doesn't exist,
-                                       // so that we can do a good job with
-                                       // these in a later version.
-                                       continue
+                               if f == nil && strings.EqualFold(ff.name, key) {
+                                       f = ff
                                }
-                               // First, tag match
-                               tagName, _ := parseTag(tag)
-                               if tagName != "" {
-                                       if tagName == key {
-                                               f = sf
-                                               ok = true
-                                               break // no better match possible
+                       }
+                       if f != nil {
+                               subv = sv
+                               destring = f.quoted
+                               for _, i := range f.index {
+                                       if subv.Kind() == reflect.Ptr {
+                                               if subv.IsNil() {
+                                                       subv.Set(reflect.New(subv.Type().Elem()))
+                                               }
+                                               subv = subv.Elem()
                                        }
-                                       // There was a tag, but it didn't match.
-                                       // Ignore field names.
-                                       continue
-                               }
-                               // Second, exact field name match
-                               if sf.Name == key {
-                                       f = sf
-                                       ok = true
-                               }
-                               // Third, case-insensitive field name match,
-                               // but only if a better match hasn't already been seen
-                               if !ok && strings.EqualFold(sf.Name, key) {
-                                       f = sf
-                                       ok = true
+                                       subv = subv.Field(i)
                                }
-                       }
-
-                       // Extract value; name must be exported.
-                       if ok {
-                               if f.PkgPath != "" {
-                                       d.saveError(&UnmarshalFieldError{key, st, f})
-                               } else {
-                                       subv = sv.FieldByIndex(f.Index)
+                       } else {
+                               // To give a good error, a quick scan for unexported fields in top level.
+                               st := sv.Type()
+                               for i := 0; i < st.NumField(); i++ {
+                                       f := st.Field(i)
+                                       if f.PkgPath != "" && strings.EqualFold(f.Name, key) {
+                                               d.saveError(&UnmarshalFieldError{key, st, f})
+                                       }
                                }
-                               _, opts := parseTag(f.Tag.Get("json"))
-                               destring = opts.Contains("string")
                        }
                }
 
@@ -1005,11 +988,3 @@ func unquoteBytes(s []byte) (t []byte, ok bool) {
        }
        return b[0:w], true
 }
-
-// The following is issue 3069.
-
-// BUG(rsc): This package ignores anonymous (embedded) struct fields
-// during encoding and decoding.  A future version may assign meaning
-// to them.  To force an anonymous field to be ignored in all future
-// versions of this package, use an explicit `json:"-"` tag in the struct
-// definition.
index e588b2853360b2f57523e0cfaee7fdb4260da09e..f2da141b8f999aa7ba05c2e920415f377b9fd3b2 100644 (file)
@@ -7,6 +7,7 @@ package json
 import (
        "bytes"
        "fmt"
+       "image"
        "reflect"
        "strings"
        "testing"
@@ -74,6 +75,100 @@ var (
        umstruct = ustruct{unmarshaler{true}}
 )
 
+// Test data structures for anonymous fields.
+
+type Point struct {
+       Z int
+}
+
+type Top struct {
+       Level0 int
+       Embed0
+       *Embed0a
+       *Embed0b `json:"e,omitempty"` // treated as named
+       Embed0c  `json:"-"`           // ignored
+       Loop
+       Embed0p // has Point with X, Y, used
+       Embed0q // has Point with Z, used
+}
+
+type Embed0 struct {
+       Level1a int // overridden by Embed0a's Level1a with json tag
+       Level1b int // used because Embed0a's Level1b is renamed
+       Level1c int // used because Embed0a's Level1c is ignored
+       Level1d int // annihilated by Embed0a's Level1d
+       Level1e int `json:"x"` // annihilated by Embed0a.Level1e
+}
+
+type Embed0a struct {
+       Level1a int `json:"Level1a,omitempty"`
+       Level1b int `json:"LEVEL1B,omitempty"`
+       Level1c int `json:"-"`
+       Level1d int // annihilated by Embed0's Level1d
+       Level1f int `json:"x"` // annihilated by Embed0's Level1e
+}
+
+type Embed0b Embed0
+
+type Embed0c Embed0
+
+type Embed0p struct {
+       image.Point
+}
+
+type Embed0q struct {
+       Point
+}
+
+type Loop struct {
+       Loop1 int `json:",omitempty"`
+       Loop2 int `json:",omitempty"`
+       *Loop
+}
+
+// From reflect test:
+// The X in S6 and S7 annihilate, but they also block the X in S8.S9.
+type S5 struct {
+       S6
+       S7
+       S8
+}
+
+type S6 struct {
+       X int
+}
+
+type S7 S6
+
+type S8 struct {
+       S9
+}
+
+type S9 struct {
+       X int
+       Y int
+}
+
+// From reflect test:
+// The X in S11.S6 and S12.S6 annihilate, but they also block the X in S13.S8.S9.
+type S10 struct {
+       S11
+       S12
+       S13
+}
+
+type S11 struct {
+       S6
+}
+
+type S12 struct {
+       S6
+}
+
+type S13 struct {
+       S8
+}
+
 type unmarshalTest struct {
        in        string
        ptr       interface{}
@@ -82,6 +177,12 @@ type unmarshalTest struct {
        useNumber bool
 }
 
+type Ambig struct {
+       // Given "hello", the first match should win.
+       First  int `json:"HELLO"`
+       Second int `json:"Hello"`
+}
+
 var unmarshalTests = []unmarshalTest{
        // basic types
        {in: `true`, ptr: new(bool), out: true},
@@ -137,6 +238,74 @@ var unmarshalTests = []unmarshalTest{
        {in: `[{"T":false}]`, ptr: &umslice, out: umslice},
        {in: `[{"T":false}]`, ptr: &umslicep, out: &umslice},
        {in: `{"M":{"T":false}}`, ptr: &umstruct, out: umstruct},
+
+       {
+               in: `{
+                       "Level0": 1,
+                       "Level1b": 2,
+                       "Level1c": 3,
+                       "x": 4,
+                       "Level1a": 5,
+                       "LEVEL1B": 6,
+                       "e": {
+                               "Level1a": 8,
+                               "Level1b": 9,
+                               "Level1c": 10,
+                               "Level1d": 11,
+                               "x": 12
+                       },
+                       "Loop1": 13,
+                       "Loop2": 14,
+                       "X": 15,
+                       "Y": 16,
+                       "Z": 17
+               }`,
+               ptr: new(Top),
+               out: Top{
+                       Level0: 1,
+                       Embed0: Embed0{
+                               Level1b: 2,
+                               Level1c: 3,
+                       },
+                       Embed0a: &Embed0a{
+                               Level1a: 5,
+                               Level1b: 6,
+                       },
+                       Embed0b: &Embed0b{
+                               Level1a: 8,
+                               Level1b: 9,
+                               Level1c: 10,
+                               Level1d: 11,
+                               Level1e: 12,
+                       },
+                       Loop: Loop{
+                               Loop1: 13,
+                               Loop2: 14,
+                       },
+                       Embed0p: Embed0p{
+                               Point: image.Point{X: 15, Y: 16},
+                       },
+                       Embed0q: Embed0q{
+                               Point: Point{Z: 17},
+                       },
+               },
+       },
+       {
+               in:  `{"hello": 1}`,
+               ptr: new(Ambig),
+               out: Ambig{First: 1},
+       },
+
+       {
+               in:  `{"X": 1,"Y":2}`,
+               ptr: new(S5),
+               out: S5{S8: S8{S9: S9{Y: 2}}},
+       },
+       {
+               in:  `{"X": 1,"Y":2}`,
+               ptr: new(S10),
+               out: S10{S13: S13{S8: S8{S9: S9{Y: 2}}}},
+       },
 }
 
 func TestMarshal(t *testing.T) {
@@ -720,35 +889,6 @@ func TestRefUnmarshal(t *testing.T) {
        }
 }
 
-// Test that anonymous fields are ignored.
-// We may assign meaning to them later.
-func TestAnonymous(t *testing.T) {
-       type S struct {
-               T
-               N int
-       }
-
-       data, err := Marshal(new(S))
-       if err != nil {
-               t.Fatalf("Marshal: %v", err)
-       }
-       want := `{"N":0}`
-       if string(data) != want {
-               t.Fatalf("Marshal = %#q, want %#q", string(data), want)
-       }
-
-       var s S
-       if err := Unmarshal([]byte(`{"T": 1, "T": {"Y": 1}, "N": 2}`), &s); err != nil {
-               t.Fatalf("Unmarshal: %v", err)
-       }
-       if s.N != 2 {
-               t.Fatal("Unmarshal: did not set N")
-       }
-       if s.T.Y != 0 {
-               t.Fatal("Unmarshal: did set T.Y")
-       }
-}
-
 // Test that the empty string doesn't panic decoding when ,string is specified
 // Issue 3450
 func TestEmptyString(t *testing.T) {
index 49ab13c79f2b40d248beb463b9803e3eb32bbf9c..c8535ef79d6f5e832cc77fa8bf47d5ed6aca9689 100644 (file)
@@ -84,6 +84,16 @@ import (
 // only Unicode letters, digits, dollar signs, percent signs, hyphens,
 // underscores and slashes.
 //
+// Anonymous struct fields are usually marshaled as if their inner exported fields
+// were fields in the outer struct, subject to the usual Go visibility rules.
+// An anonymous struct field with a name given in its JSON tag is treated as 
+// having that name instead of as anonymous.
+//
+// Handling of anonymous struct fields is new in Go 1.1.
+// Prior to Go 1.1, anonymous struct fields were ignored. To force ignoring of
+// an anonymous struct field in both current and earlier versions, give the field
+// a JSON tag of "-".
+//
 // Map values encode as JSON objects.
 // The map's key type must be string; the object keys are used directly
 // as map keys.
@@ -333,9 +343,9 @@ func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) {
        case reflect.Struct:
                e.WriteByte('{')
                first := true
-               for _, ef := range encodeFields(v.Type()) {
-                       fieldValue := v.Field(ef.i)
-                       if ef.omitEmpty && isEmptyValue(fieldValue) {
+               for _, f := range cachedTypeFields(v.Type()) {
+                       fv := fieldByIndex(v, f.index)
+                       if !fv.IsValid() || f.omitEmpty && isEmptyValue(fv) {
                                continue
                        }
                        if first {
@@ -343,9 +353,9 @@ func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) {
                        } else {
                                e.WriteByte(',')
                        }
-                       e.string(ef.tag)
+                       e.string(f.name)
                        e.WriteByte(':')
-                       e.reflectValueQuoted(fieldValue, ef.quoted)
+                       e.reflectValueQuoted(fv, f.quoted)
                }
                e.WriteByte('}')
 
@@ -440,6 +450,19 @@ func isValidTag(s string) bool {
        return true
 }
 
+func fieldByIndex(v reflect.Value, index []int) reflect.Value {
+       for _, i := range index {
+               if v.Kind() == reflect.Ptr {
+                       if v.IsNil() {
+                               return reflect.Value{}
+                       }
+                       v = v.Elem()
+               }
+               v = v.Field(i)
+       }
+       return v
+}
+
 // stringValues is a slice of reflect.Value holding *reflect.StringValue.
 // It implements the methods to sort by string.
 type stringValues []reflect.Value
@@ -498,67 +521,183 @@ func (e *encodeState) string(s string) (int, error) {
        return e.Len() - len0, nil
 }
 
-// encodeField contains information about how to encode a field of a
-// struct.
-type encodeField struct {
-       i         int // field index in struct
-       tag       string
-       quoted    bool
+// A field represents a single field found in a struct.
+type field struct {
+       name      string
+       tag       bool
+       index     []int
+       typ       reflect.Type
        omitEmpty bool
+       quoted    bool
 }
 
-var (
-       typeCacheLock     sync.RWMutex
-       encodeFieldsCache = make(map[reflect.Type][]encodeField)
-)
+// byName sorts field by name, breaking ties with depth,
+// then breaking ties with "name came from json tag", then
+// breaking ties with index sequence.
+type byName []field
 
-// encodeFields returns a slice of encodeField for a given
-// struct type.
-func encodeFields(t reflect.Type) []encodeField {
-       typeCacheLock.RLock()
-       fs, ok := encodeFieldsCache[t]
-       typeCacheLock.RUnlock()
-       if ok {
-               return fs
-       }
+func (x byName) Len() int { return len(x) }
 
-       typeCacheLock.Lock()
-       defer typeCacheLock.Unlock()
-       fs, ok = encodeFieldsCache[t]
-       if ok {
-               return fs
+func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
+
+func (x byName) Less(i, j int) bool {
+       if x[i].name != x[j].name {
+               return x[i].name < x[j].name
+       }
+       if len(x[i].index) != len(x[j].index) {
+               return len(x[i].index) < len(x[j].index)
+       }
+       if x[i].tag != x[j].tag {
+               return x[i].tag
        }
+       return byIndex(x).Less(i, j)
+}
 
-       v := reflect.Zero(t)
-       n := v.NumField()
-       for i := 0; i < n; i++ {
-               f := t.Field(i)
-               if f.PkgPath != "" {
-                       continue
+// byIndex sorts field by index sequence.
+type byIndex []field
+
+func (x byIndex) Len() int { return len(x) }
+
+func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
+
+func (x byIndex) Less(i, j int) bool {
+       for k, xik := range x[i].index {
+               if k >= len(x[j].index) {
+                       return false
                }
-               if f.Anonymous {
-                       // We want to do a better job with these later,
-                       // so for now pretend they don't exist.
-                       continue
+               if xik != x[j].index[k] {
+                       return xik < x[j].index[k]
                }
-               var ef encodeField
-               ef.i = i
-               ef.tag = f.Name
+       }
+       return len(x[i].index) < len(x[j].index)
+}
 
-               tv := f.Tag.Get("json")
-               if tv != "" {
-                       if tv == "-" {
+// typeFields returns a list of fields that JSON should recognize for the given type.
+// The algorithm is breadth-first search over the set of structs to include - the top struct
+// and then any reachable anonymous structs.
+func typeFields(t reflect.Type) []field {
+       // Anonymous fields to explore at the current level and the next.
+       current := []field{}
+       next := []field{{typ: t}}
+
+       // Count of queued names for current level and the next.
+       count := map[reflect.Type]int{}
+       nextCount := map[reflect.Type]int{}
+
+       // Types already visited at an earlier level.
+       visited := map[reflect.Type]bool{}
+
+       // Fields found.
+       var fields []field
+
+       for len(next) > 0 {
+               current, next = next, current[:0]
+               count, nextCount = nextCount, map[reflect.Type]int{}
+
+               for _, f := range current {
+                       if visited[f.typ] {
                                continue
                        }
-                       name, opts := parseTag(tv)
-                       if isValidTag(name) {
-                               ef.tag = name
+                       visited[f.typ] = true
+
+                       // Scan f.typ for fields to include.
+                       for i := 0; i < f.typ.NumField(); i++ {
+                               sf := f.typ.Field(i)
+                               if sf.PkgPath != "" { // unexported
+                                       continue
+                               }
+                               tag := sf.Tag.Get("json")
+                               if tag == "-" {
+                                       continue
+                               }
+                               name, opts := parseTag(tag)
+                               if !isValidTag(name) {
+                                       name = ""
+                               }
+                               index := make([]int, len(f.index)+1)
+                               copy(index, f.index)
+                               index[len(f.index)] = i
+                               // Record found field and index sequence.
+                               if name != "" || !sf.Anonymous {
+                                       tagged := name != ""
+                                       if name == "" {
+                                               name = sf.Name
+                                       }
+                                       fields = append(fields, field{name, tagged, index, sf.Type,
+                                               opts.Contains("omitempty"), opts.Contains("string")})
+                                       if count[f.typ] > 1 {
+                                               // If there were multiple instances, add a second,
+                                               // so that the annihilation code will see a duplicate.
+                                               // It only cares about the distinction between 1 or 2,
+                                               // so don't bother generating any more copies.
+                                               fields = append(fields, fields[len(fields)-1])
+                                       }
+                                       continue
+                               }
+
+                               // Record new anonymous struct to explore in next round.
+                               ft := sf.Type
+                               if ft.Name() == "" {
+                                       // Must be pointer.
+                                       ft = ft.Elem()
+                               }
+                               nextCount[ft]++
+                               if nextCount[ft] == 1 {
+                                       next = append(next, field{name: ft.Name(), index: index, typ: ft})
+                               }
                        }
-                       ef.omitEmpty = opts.Contains("omitempty")
-                       ef.quoted = opts.Contains("string")
                }
-               fs = append(fs, ef)
        }
-       encodeFieldsCache[t] = fs
-       return fs
+
+       sort.Sort(byName(fields))
+
+       // Remove fields with annihilating name collisions
+       // and also fields shadowed by fields with explicit JSON tags.
+       name := ""
+       out := fields[:0]
+       for _, f := range fields {
+               if f.name != name {
+                       name = f.name
+                       out = append(out, f)
+                       continue
+               }
+               if n := len(out); n > 0 && out[n-1].name == name && (!out[n-1].tag || f.tag) {
+                       out = out[:n-1]
+               }
+       }
+       fields = out
+
+       sort.Sort(byIndex(fields))
+
+       return fields
+}
+
+var fieldCache struct {
+       sync.RWMutex
+       m map[reflect.Type][]field
+}
+
+// cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
+func cachedTypeFields(t reflect.Type) []field {
+       fieldCache.RLock()
+       f := fieldCache.m[t]
+       fieldCache.RUnlock()
+       if f != nil {
+               return f
+       }
+
+       // Compute fields without lock.
+       // Might duplicate effort but won't hold other computations back.
+       f = typeFields(t)
+       if f == nil {
+               f = []field{}
+       }
+
+       fieldCache.Lock()
+       if fieldCache.m == nil {
+               fieldCache.m = map[reflect.Type][]field{}
+       }
+       fieldCache.m[t] = f
+       fieldCache.Unlock()
+       return f
 }