From: Brad Fitzpatrick Date: Tue, 24 Sep 2013 02:57:19 +0000 (-0700) Subject: encoding/json: don't cache value addressability when building first encoder X-Git-Tag: go1.2rc2~127 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=0f3ea75020cf7dda64805fe9aeef26be60cf16cd;p=gostls13.git encoding/json: don't cache value addressability when building first encoder newTypeEncoder (called once per type and then cached) was looking at the first value seen of that type's addressability and caching the encoder decision. If the first value seen was addressable and a future one wasn't, it would panic. Instead, introduce a new wrapper encoder type that checks CanAddr at runtime. Fixes #6458 R=golang-dev, rsc CC=golang-dev https://golang.org/cl/13839045 --- diff --git a/src/pkg/encoding/json/encode.go b/src/pkg/encoding/json/encode.go index f951250e98..7d6c71d7a9 100644 --- a/src/pkg/encoding/json/encode.go +++ b/src/pkg/encoding/json/encode.go @@ -305,7 +305,7 @@ func (e *encodeState) reflectValue(v reflect.Value) { valueEncoder(v)(e, v, false) } -type encoderFunc func(e *encodeState, v reflect.Value, _ bool) +type encoderFunc func(e *encodeState, v reflect.Value, quoted bool) var encoderCache struct { sync.RWMutex @@ -316,11 +316,10 @@ func valueEncoder(v reflect.Value) encoderFunc { if !v.IsValid() { return invalidValueEncoder } - t := v.Type() - return typeEncoder(t, v) + return typeEncoder(v.Type()) } -func typeEncoder(t reflect.Type, vx reflect.Value) encoderFunc { +func typeEncoder(t reflect.Type) encoderFunc { encoderCache.RLock() f := encoderCache.m[t] encoderCache.RUnlock() @@ -346,7 +345,7 @@ func typeEncoder(t reflect.Type, vx reflect.Value) encoderFunc { // Compute fields without lock. // Might duplicate effort but won't hold other computations back. - f = newTypeEncoder(t, vx) + f = newTypeEncoder(t, true) wg.Done() encoderCache.Lock() encoderCache.m[t] = f @@ -354,38 +353,33 @@ func typeEncoder(t reflect.Type, vx reflect.Value) encoderFunc { return f } -// newTypeEncoder constructs an encoderFunc for a type. -// The provided vx is an example value of type t. It's the first seen -// value of that type and should not be used to encode. It may be -// zero. -func newTypeEncoder(t reflect.Type, vx reflect.Value) encoderFunc { - if !vx.IsValid() { - vx = reflect.New(t).Elem() - } +var ( + marshalerType = reflect.TypeOf(new(Marshaler)).Elem() + textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() +) - _, ok := vx.Interface().(Marshaler) - if ok { +// newTypeEncoder constructs an encoderFunc for a type. +// The returned encoder only checks CanAddr when allowAddr is true. +func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { + if t.Implements(marshalerType) { return marshalerEncoder } - if vx.Kind() != reflect.Ptr && vx.CanAddr() { - _, ok = vx.Addr().Interface().(Marshaler) - if ok { - return addrMarshalerEncoder + if t.Kind() != reflect.Ptr && allowAddr { + if reflect.PtrTo(t).Implements(marshalerType) { + return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false)) } } - _, ok = vx.Interface().(encoding.TextMarshaler) - if ok { + if t.Implements(textMarshalerType) { return textMarshalerEncoder } - if vx.Kind() != reflect.Ptr && vx.CanAddr() { - _, ok = vx.Addr().Interface().(encoding.TextMarshaler) - if ok { - return addrTextMarshalerEncoder + if t.Kind() != reflect.Ptr && allowAddr { + if reflect.PtrTo(t).Implements(textMarshalerType) { + return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false)) } } - switch vx.Kind() { + switch t.Kind() { case reflect.Bool: return boolEncoder case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -401,15 +395,15 @@ func newTypeEncoder(t reflect.Type, vx reflect.Value) encoderFunc { case reflect.Interface: return interfaceEncoder case reflect.Struct: - return newStructEncoder(t, vx) + return newStructEncoder(t) case reflect.Map: - return newMapEncoder(t, vx) + return newMapEncoder(t) case reflect.Slice: - return newSliceEncoder(t, vx) + return newSliceEncoder(t) case reflect.Array: - return newArrayEncoder(t, vx) + return newArrayEncoder(t) case reflect.Ptr: - return newPtrEncoder(t, vx) + return newPtrEncoder(t) default: return unsupportedTypeEncoder } @@ -593,27 +587,19 @@ func (se *structEncoder) encode(e *encodeState, v reflect.Value, quoted bool) { } e.string(f.name) e.WriteByte(':') - if tenc := se.fieldEncs[i]; tenc != nil { - tenc(e, fv, f.quoted) - } else { - // Slower path. - e.reflectValue(fv) - } + se.fieldEncs[i](e, fv, f.quoted) } e.WriteByte('}') } -func newStructEncoder(t reflect.Type, vx reflect.Value) encoderFunc { +func newStructEncoder(t reflect.Type) encoderFunc { fields := cachedTypeFields(t) se := &structEncoder{ fields: fields, fieldEncs: make([]encoderFunc, len(fields)), } for i, f := range fields { - vxf := fieldByIndex(vx, f.index) - if vxf.IsValid() { - se.fieldEncs[i] = typeEncoder(vxf.Type(), vxf) - } + se.fieldEncs[i] = typeEncoder(typeByIndex(t, f.index)) } return se.encode } @@ -641,11 +627,11 @@ func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) { e.WriteByte('}') } -func newMapEncoder(t reflect.Type, vx reflect.Value) encoderFunc { +func newMapEncoder(t reflect.Type) encoderFunc { if t.Key().Kind() != reflect.String { return unsupportedTypeEncoder } - me := &mapEncoder{typeEncoder(vx.Type().Elem(), reflect.Value{})} + me := &mapEncoder{typeEncoder(t.Elem())} return me.encode } @@ -684,12 +670,12 @@ func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, _ bool) { se.arrayEnc(e, v, false) } -func newSliceEncoder(t reflect.Type, vx reflect.Value) encoderFunc { +func newSliceEncoder(t reflect.Type) encoderFunc { // Byte slices get special treatment; arrays don't. - if vx.Type().Elem().Kind() == reflect.Uint8 { + if t.Elem().Kind() == reflect.Uint8 { return encodeByteSlice } - enc := &sliceEncoder{newArrayEncoder(t, vx)} + enc := &sliceEncoder{newArrayEncoder(t)} return enc.encode } @@ -709,8 +695,8 @@ func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, _ bool) { e.WriteByte(']') } -func newArrayEncoder(t reflect.Type, vx reflect.Value) encoderFunc { - enc := &arrayEncoder{typeEncoder(t.Elem(), reflect.Value{})} +func newArrayEncoder(t reflect.Type) encoderFunc { + enc := &arrayEncoder{typeEncoder(t.Elem())} return enc.encode } @@ -726,8 +712,27 @@ func (pe *ptrEncoder) encode(e *encodeState, v reflect.Value, _ bool) { pe.elemEnc(e, v.Elem(), false) } -func newPtrEncoder(t reflect.Type, vx reflect.Value) encoderFunc { - enc := &ptrEncoder{typeEncoder(t.Elem(), reflect.Value{})} +func newPtrEncoder(t reflect.Type) encoderFunc { + enc := &ptrEncoder{typeEncoder(t.Elem())} + return enc.encode +} + +type condAddrEncoder struct { + canAddrEnc, elseEnc encoderFunc +} + +func (ce *condAddrEncoder) encode(e *encodeState, v reflect.Value, quoted bool) { + if v.CanAddr() { + ce.canAddrEnc(e, v, quoted) + } else { + ce.elseEnc(e, v, quoted) + } +} + +// newCondAddrEncoder returns an encoder that checks whether its value +// CanAddr and delegates to canAddrEnc if so, else to elseEnc. +func newCondAddrEncoder(canAddrEnc, elseEnc encoderFunc) encoderFunc { + enc := &condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc} return enc.encode } @@ -763,6 +768,16 @@ func fieldByIndex(v reflect.Value, index []int) reflect.Value { return v } +func typeByIndex(t reflect.Type, index []int) reflect.Type { + for _, i := range index { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + t = t.Field(i).Type + } + return t +} + // stringValues is a slice of reflect.Value holding *reflect.StringValue. // It implements the methods to sort by string. type stringValues []reflect.Value diff --git a/src/pkg/encoding/json/encode_test.go b/src/pkg/encoding/json/encode_test.go index 7052e1db7c..9395db7cb6 100644 --- a/src/pkg/encoding/json/encode_test.go +++ b/src/pkg/encoding/json/encode_test.go @@ -401,3 +401,27 @@ func TestStringBytes(t *testing.T) { t.Errorf("encodings differ at %#q vs %#q", enc, encBytes) } } + +func TestIssue6458(t *testing.T) { + type Foo struct { + M RawMessage + } + x := Foo{RawMessage(`"foo"`)} + + b, err := Marshal(&x) + if err != nil { + t.Fatal(err) + } + if want := `{"M":"foo"}`; string(b) != want { + t.Errorf("Marshal(&x) = %#q; want %#q", b, want) + } + + b, err = Marshal(x) + if err != nil { + t.Fatal(err) + } + + if want := `{"M":"ImZvbyI="}`; string(b) != want { + t.Errorf("Marshal(x) = %#q; want %#q", b, want) + } +}