]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/json: don't cache value addressability when building first encoder
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 24 Sep 2013 02:57:19 +0000 (19:57 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 24 Sep 2013 02:57:19 +0000 (19:57 -0700)
newTypeEncoder (called once per type and then cached) was
looking at the first value seen of that type's addressability
and caching the encoder decision.  If the first value seen was
addressable and a future one wasn't, it would panic.

Instead, introduce a new wrapper encoder type that checks
CanAddr at runtime.

Fixes #6458

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/13839045

src/pkg/encoding/json/encode.go
src/pkg/encoding/json/encode_test.go

index f951250e980b7eb19c86cb69624027bc9e1ce346..7d6c71d7a9016a276c3199b4054111005e1d20f4 100644 (file)
@@ -305,7 +305,7 @@ func (e *encodeState) reflectValue(v reflect.Value) {
        valueEncoder(v)(e, v, false)
 }
 
-type encoderFunc func(e *encodeState, v reflect.Value, _ bool)
+type encoderFunc func(e *encodeState, v reflect.Value, quoted bool)
 
 var encoderCache struct {
        sync.RWMutex
@@ -316,11 +316,10 @@ func valueEncoder(v reflect.Value) encoderFunc {
        if !v.IsValid() {
                return invalidValueEncoder
        }
-       t := v.Type()
-       return typeEncoder(t, v)
+       return typeEncoder(v.Type())
 }
 
-func typeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
+func typeEncoder(t reflect.Type) encoderFunc {
        encoderCache.RLock()
        f := encoderCache.m[t]
        encoderCache.RUnlock()
@@ -346,7 +345,7 @@ func typeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
 
        // Compute fields without lock.
        // Might duplicate effort but won't hold other computations back.
-       f = newTypeEncoder(t, vx)
+       f = newTypeEncoder(t, true)
        wg.Done()
        encoderCache.Lock()
        encoderCache.m[t] = f
@@ -354,38 +353,33 @@ func typeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
        return f
 }
 
-// newTypeEncoder constructs an encoderFunc for a type.
-// The provided vx is an example value of type t. It's the first seen
-// value of that type and should not be used to encode. It may be
-// zero.
-func newTypeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
-       if !vx.IsValid() {
-               vx = reflect.New(t).Elem()
-       }
+var (
+       marshalerType     = reflect.TypeOf(new(Marshaler)).Elem()
+       textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
+)
 
-       _, ok := vx.Interface().(Marshaler)
-       if ok {
+// newTypeEncoder constructs an encoderFunc for a type.
+// The returned encoder only checks CanAddr when allowAddr is true.
+func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
+       if t.Implements(marshalerType) {
                return marshalerEncoder
        }
-       if vx.Kind() != reflect.Ptr && vx.CanAddr() {
-               _, ok = vx.Addr().Interface().(Marshaler)
-               if ok {
-                       return addrMarshalerEncoder
+       if t.Kind() != reflect.Ptr && allowAddr {
+               if reflect.PtrTo(t).Implements(marshalerType) {
+                       return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
                }
        }
 
-       _, ok = vx.Interface().(encoding.TextMarshaler)
-       if ok {
+       if t.Implements(textMarshalerType) {
                return textMarshalerEncoder
        }
-       if vx.Kind() != reflect.Ptr && vx.CanAddr() {
-               _, ok = vx.Addr().Interface().(encoding.TextMarshaler)
-               if ok {
-                       return addrTextMarshalerEncoder
+       if t.Kind() != reflect.Ptr && allowAddr {
+               if reflect.PtrTo(t).Implements(textMarshalerType) {
+                       return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
                }
        }
 
-       switch vx.Kind() {
+       switch t.Kind() {
        case reflect.Bool:
                return boolEncoder
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@@ -401,15 +395,15 @@ func newTypeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
        case reflect.Interface:
                return interfaceEncoder
        case reflect.Struct:
-               return newStructEncoder(t, vx)
+               return newStructEncoder(t)
        case reflect.Map:
-               return newMapEncoder(t, vx)
+               return newMapEncoder(t)
        case reflect.Slice:
-               return newSliceEncoder(t, vx)
+               return newSliceEncoder(t)
        case reflect.Array:
-               return newArrayEncoder(t, vx)
+               return newArrayEncoder(t)
        case reflect.Ptr:
-               return newPtrEncoder(t, vx)
+               return newPtrEncoder(t)
        default:
                return unsupportedTypeEncoder
        }
@@ -593,27 +587,19 @@ func (se *structEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
                }
                e.string(f.name)
                e.WriteByte(':')
-               if tenc := se.fieldEncs[i]; tenc != nil {
-                       tenc(e, fv, f.quoted)
-               } else {
-                       // Slower path.
-                       e.reflectValue(fv)
-               }
+               se.fieldEncs[i](e, fv, f.quoted)
        }
        e.WriteByte('}')
 }
 
-func newStructEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
+func newStructEncoder(t reflect.Type) encoderFunc {
        fields := cachedTypeFields(t)
        se := &structEncoder{
                fields:    fields,
                fieldEncs: make([]encoderFunc, len(fields)),
        }
        for i, f := range fields {
-               vxf := fieldByIndex(vx, f.index)
-               if vxf.IsValid() {
-                       se.fieldEncs[i] = typeEncoder(vxf.Type(), vxf)
-               }
+               se.fieldEncs[i] = typeEncoder(typeByIndex(t, f.index))
        }
        return se.encode
 }
@@ -641,11 +627,11 @@ func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
        e.WriteByte('}')
 }
 
-func newMapEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
+func newMapEncoder(t reflect.Type) encoderFunc {
        if t.Key().Kind() != reflect.String {
                return unsupportedTypeEncoder
        }
-       me := &mapEncoder{typeEncoder(vx.Type().Elem(), reflect.Value{})}
+       me := &mapEncoder{typeEncoder(t.Elem())}
        return me.encode
 }
 
@@ -684,12 +670,12 @@ func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
        se.arrayEnc(e, v, false)
 }
 
-func newSliceEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
+func newSliceEncoder(t reflect.Type) encoderFunc {
        // Byte slices get special treatment; arrays don't.
-       if vx.Type().Elem().Kind() == reflect.Uint8 {
+       if t.Elem().Kind() == reflect.Uint8 {
                return encodeByteSlice
        }
-       enc := &sliceEncoder{newArrayEncoder(t, vx)}
+       enc := &sliceEncoder{newArrayEncoder(t)}
        return enc.encode
 }
 
@@ -709,8 +695,8 @@ func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
        e.WriteByte(']')
 }
 
-func newArrayEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
-       enc := &arrayEncoder{typeEncoder(t.Elem(), reflect.Value{})}
+func newArrayEncoder(t reflect.Type) encoderFunc {
+       enc := &arrayEncoder{typeEncoder(t.Elem())}
        return enc.encode
 }
 
@@ -726,8 +712,27 @@ func (pe *ptrEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
        pe.elemEnc(e, v.Elem(), false)
 }
 
-func newPtrEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
-       enc := &ptrEncoder{typeEncoder(t.Elem(), reflect.Value{})}
+func newPtrEncoder(t reflect.Type) encoderFunc {
+       enc := &ptrEncoder{typeEncoder(t.Elem())}
+       return enc.encode
+}
+
+type condAddrEncoder struct {
+       canAddrEnc, elseEnc encoderFunc
+}
+
+func (ce *condAddrEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
+       if v.CanAddr() {
+               ce.canAddrEnc(e, v, quoted)
+       } else {
+               ce.elseEnc(e, v, quoted)
+       }
+}
+
+// newCondAddrEncoder returns an encoder that checks whether its value
+// CanAddr and delegates to canAddrEnc if so, else to elseEnc.
+func newCondAddrEncoder(canAddrEnc, elseEnc encoderFunc) encoderFunc {
+       enc := &condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc}
        return enc.encode
 }
 
@@ -763,6 +768,16 @@ func fieldByIndex(v reflect.Value, index []int) reflect.Value {
        return v
 }
 
+func typeByIndex(t reflect.Type, index []int) reflect.Type {
+       for _, i := range index {
+               if t.Kind() == reflect.Ptr {
+                       t = t.Elem()
+               }
+               t = t.Field(i).Type
+       }
+       return t
+}
+
 // stringValues is a slice of reflect.Value holding *reflect.StringValue.
 // It implements the methods to sort by string.
 type stringValues []reflect.Value
index 7052e1db7c93228e53d479056dadd13f5f6980ba..9395db7cb6fd543bae8e340ed59c4f77e726bebb 100644 (file)
@@ -401,3 +401,27 @@ func TestStringBytes(t *testing.T) {
                t.Errorf("encodings differ at %#q vs %#q", enc, encBytes)
        }
 }
+
+func TestIssue6458(t *testing.T) {
+       type Foo struct {
+               M RawMessage
+       }
+       x := Foo{RawMessage(`"foo"`)}
+
+       b, err := Marshal(&x)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if want := `{"M":"foo"}`; string(b) != want {
+               t.Errorf("Marshal(&x) = %#q; want %#q", b, want)
+       }
+
+       b, err = Marshal(x)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       if want := `{"M":"ImZvbyI="}`; string(b) != want {
+               t.Errorf("Marshal(x) = %#q; want %#q", b, want)
+       }
+}