]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/json: faster encoding
authorBrad Fitzpatrick <bradfitz@golang.org>
Fri, 9 Aug 2013 16:46:47 +0000 (09:46 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 9 Aug 2013 16:46:47 +0000 (09:46 -0700)
The old code was caching per-type struct field info. Instead,
cache type-specific encoding funcs, tailored for that
particular type to avoid unnecessary reflection at runtime.
Once the machine is built once, future encodings of that type
just run the func.

benchmark               old ns/op    new ns/op    delta
BenchmarkCodeEncoder     48424939     36975320  -23.64%

benchmark                old MB/s     new MB/s  speedup
BenchmarkCodeEncoder        40.07        52.48    1.31x

Additionally, the numbers seem stable now at ~52 MB/s, whereas
the numbers for the old code were all over the place: 11 MB/s,
40 MB/s, 13 MB/s, 39 MB/s, etc.  In the benchmark above I compared
against the best I saw the old code do.

R=rsc, adg
CC=gobot, golang-dev, r
https://golang.org/cl/9129044

src/pkg/encoding/json/decode_test.go
src/pkg/encoding/json/encode.go

index 3e16c4aec0dee86e6eaef20a3a2f1631b514247a..3fa366500ff08b3f4ab23aba2fa47672d25f12a5 100644 (file)
@@ -421,6 +421,45 @@ func TestMarshalNumberZeroVal(t *testing.T) {
        }
 }
 
+func TestMarshalEmbeds(t *testing.T) {
+       top := &Top{
+               Level0: 1,
+               Embed0: Embed0{
+                       Level1b: 2,
+                       Level1c: 3,
+               },
+               Embed0a: &Embed0a{
+                       Level1a: 5,
+                       Level1b: 6,
+               },
+               Embed0b: &Embed0b{
+                       Level1a: 8,
+                       Level1b: 9,
+                       Level1c: 10,
+                       Level1d: 11,
+                       Level1e: 12,
+               },
+               Loop: Loop{
+                       Loop1: 13,
+                       Loop2: 14,
+               },
+               Embed0p: Embed0p{
+                       Point: image.Point{X: 15, Y: 16},
+               },
+               Embed0q: Embed0q{
+                       Point: Point{Z: 17},
+               },
+       }
+       b, err := Marshal(top)
+       if err != nil {
+               t.Fatal(err)
+       }
+       want := "{\"Level0\":1,\"Level1b\":2,\"Level1c\":3,\"Level1a\":5,\"LEVEL1B\":6,\"e\":{\"Level1a\":8,\"Level1b\":9,\"Level1c\":10,\"Level1d\":11,\"x\":12},\"Loop1\":13,\"Loop2\":14,\"X\":15,\"Y\":16,\"Z\":17}"
+       if string(b) != want {
+               t.Errorf("Wrong marshal result.\n got: %q\nwant: %q", b, want)
+       }
+}
+
 func TestUnmarshal(t *testing.T) {
        for i, tt := range unmarshalTests {
                var scan scanner
@@ -436,7 +475,7 @@ func TestUnmarshal(t *testing.T) {
                }
                // v = new(right-type)
                v := reflect.New(reflect.TypeOf(tt.ptr).Elem())
-               dec := NewDecoder(bytes.NewBuffer(in))
+               dec := NewDecoder(bytes.NewReader(in))
                if tt.useNumber {
                        dec.UseNumber()
                }
@@ -461,7 +500,7 @@ func TestUnmarshal(t *testing.T) {
                                continue
                        }
                        vv := reflect.New(reflect.TypeOf(tt.ptr).Elem())
-                       dec = NewDecoder(bytes.NewBuffer(enc))
+                       dec = NewDecoder(bytes.NewReader(enc))
                        if tt.useNumber {
                                dec.UseNumber()
                        }
@@ -471,6 +510,8 @@ func TestUnmarshal(t *testing.T) {
                        }
                        if !reflect.DeepEqual(v.Elem().Interface(), vv.Elem().Interface()) {
                                t.Errorf("#%d: mismatch\nhave: %#+v\nwant: %#+v", i, v.Elem().Interface(), vv.Elem().Interface())
+                               t.Errorf("     In: %q", strings.Map(noSpace, string(in)))
+                               t.Errorf("Marshal: %q", strings.Map(noSpace, string(enc)))
                                continue
                        }
                }
index be4a7b1fa66c82b7650863fbe498628f7acf6427..5e8020502f7e7de66455b57e86ea2b15b3af3fa4 100644 (file)
@@ -266,6 +266,9 @@ func (e *encodeState) marshal(v interface{}) (err error) {
                        if _, ok := r.(runtime.Error); ok {
                                panic(r)
                        }
+                       if s, ok := r.(string); ok {
+                               panic(s)
+                       }
                        err = r.(error)
                }
        }()
@@ -298,186 +301,390 @@ func isEmptyValue(v reflect.Value) bool {
 }
 
 func (e *encodeState) reflectValue(v reflect.Value) {
-       e.reflectValueQuoted(v, false)
+       valueEncoder(v)(e, v, false)
+}
+
+type encoderFunc func(e *encodeState, v reflect.Value, _ bool)
+
+var encoderCache struct {
+       sync.RWMutex
+       m map[reflect.Type]encoderFunc
 }
 
-// reflectValueQuoted writes the value in v to the output.
-// If quoted is true, the serialization is wrapped in a JSON string.
-func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) {
+func valueEncoder(v reflect.Value) encoderFunc {
        if !v.IsValid() {
-               e.WriteString("null")
-               return
+               return invalidValueEncoder
        }
+       t := v.Type()
+       return typeEncoder(t, v)
+}
 
-       m, ok := v.Interface().(Marshaler)
-       if !ok {
-               // T doesn't match the interface. Check against *T too.
-               if v.Kind() != reflect.Ptr && v.CanAddr() {
-                       m, ok = v.Addr().Interface().(Marshaler)
-                       if ok {
-                               v = v.Addr()
-                       }
-               }
+func typeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
+       encoderCache.RLock()
+       f := encoderCache.m[t]
+       encoderCache.RUnlock()
+       if f != nil {
+               return f
        }
-       if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
-               b, err := m.MarshalJSON()
-               if err == nil {
-                       // copy JSON into buffer, checking validity.
-                       err = compact(&e.Buffer, b, true)
-               }
-               if err != nil {
-                       e.error(&MarshalerError{v.Type(), err})
+
+       // To deal with recursive types, populate the map with an
+       // indirect func before we build it. This type waits on the
+       // real func (f) to be ready and then calls it.  This indirect
+       // func is only used for recursive types.
+       encoderCache.Lock()
+       if encoderCache.m == nil {
+               encoderCache.m = make(map[reflect.Type]encoderFunc)
+       }
+       var wg sync.WaitGroup
+       wg.Add(1)
+       encoderCache.m[t] = func(e *encodeState, v reflect.Value, quoted bool) {
+               wg.Wait()
+               f(e, v, quoted)
+       }
+       encoderCache.Unlock()
+
+       // Compute fields without lock.
+       // Might duplicate effort but won't hold other computations back.
+       f = newTypeEncoder(t, vx)
+       wg.Done()
+       encoderCache.Lock()
+       encoderCache.m[t] = f
+       encoderCache.Unlock()
+       return f
+}
+
+// newTypeEncoder constructs an encoderFunc for a type.
+// The provided vx is an example value of type t. It's the first seen
+// value of that type and should not be used to encode. It may be
+// zero.
+func newTypeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
+       if !vx.IsValid() {
+               vx = reflect.New(t).Elem()
+       }
+       _, ok := vx.Interface().(Marshaler)
+       if ok {
+               return valueIsMarshallerEncoder
+       }
+       // T doesn't match the interface. Check against *T too.
+       if vx.Kind() != reflect.Ptr && vx.CanAddr() {
+               _, ok = vx.Addr().Interface().(Marshaler)
+               if ok {
+                       return valueAddrIsMarshallerEncoder
                }
+       }
+       switch vx.Kind() {
+       case reflect.Bool:
+               return boolEncoder
+       case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+               return intEncoder
+       case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+               return uintEncoder
+       case reflect.Float32:
+               return float32Encoder
+       case reflect.Float64:
+               return float64Encoder
+       case reflect.String:
+               return stringEncoder
+       case reflect.Interface:
+               return interfaceEncoder
+       case reflect.Struct:
+               return newStructEncoder(t, vx)
+       case reflect.Map:
+               return newMapEncoder(t, vx)
+       case reflect.Slice:
+               return newSliceEncoder(t, vx)
+       case reflect.Array:
+               return newArrayEncoder(t, vx)
+       case reflect.Ptr:
+               return newPtrEncoder(t, vx)
+       default:
+               return unsupportedTypeEncoder
+       }
+}
+
+func invalidValueEncoder(e *encodeState, v reflect.Value, quoted bool) {
+       e.WriteString("null")
+}
+
+func valueIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+       if v.Kind() == reflect.Ptr && v.IsNil() {
+               e.WriteString("null")
+               return
+       }
+       m := v.Interface().(Marshaler)
+       b, err := m.MarshalJSON()
+       if err == nil {
+               // copy JSON into buffer, checking validity.
+               err = compact(&e.Buffer, b, true)
+       }
+       if err != nil {
+               e.error(&MarshalerError{v.Type(), err})
+       }
+}
+
+func valueAddrIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) {
+       va := v.Addr()
+       if va.Kind() == reflect.Ptr && va.IsNil() {
+               e.WriteString("null")
                return
        }
+       m := va.Interface().(Marshaler)
+       b, err := m.MarshalJSON()
+       if err == nil {
+               // copy JSON into buffer, checking validity.
+               err = compact(&e.Buffer, b, true)
+       }
+       if err != nil {
+               e.error(&MarshalerError{v.Type(), err})
+       }
+}
 
-       writeString := (*encodeState).WriteString
+func boolEncoder(e *encodeState, v reflect.Value, quoted bool) {
+       if quoted {
+               e.WriteByte('"')
+       }
+       if v.Bool() {
+               e.WriteString("true")
+       } else {
+               e.WriteString("false")
+       }
        if quoted {
-               writeString = (*encodeState).string
+               e.WriteByte('"')
        }
+}
 
-       switch v.Kind() {
-       case reflect.Bool:
-               x := v.Bool()
-               if x {
-                       writeString(e, "true")
-               } else {
-                       writeString(e, "false")
-               }
+func intEncoder(e *encodeState, v reflect.Value, quoted bool) {
+       b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
+       if quoted {
+               e.WriteByte('"')
+       }
+       e.Write(b)
+       if quoted {
+               e.WriteByte('"')
+       }
+}
 
-       case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-               b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
-               if quoted {
-                       writeString(e, string(b))
-               } else {
-                       e.Write(b)
+func uintEncoder(e *encodeState, v reflect.Value, quoted bool) {
+       b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
+       if quoted {
+               e.WriteByte('"')
+       }
+       e.Write(b)
+       if quoted {
+               e.WriteByte('"')
+       }
+}
+
+type floatEncoder int // number of bits
+
+func (bits floatEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
+       f := v.Float()
+       if math.IsInf(f, 0) || math.IsNaN(f) {
+               e.error(&UnsupportedValueError{v, strconv.FormatFloat(f, 'g', -1, int(bits))})
+       }
+       b := strconv.AppendFloat(e.scratch[:0], f, 'g', -1, int(bits))
+       if quoted {
+               e.WriteByte('"')
+       }
+       e.Write(b)
+       if quoted {
+               e.WriteByte('"')
+       }
+}
+
+var (
+       float32Encoder = (floatEncoder(32)).encode
+       float64Encoder = (floatEncoder(64)).encode
+)
+
+func stringEncoder(e *encodeState, v reflect.Value, quoted bool) {
+       if v.Type() == numberType {
+               numStr := v.String()
+               if numStr == "" {
+                       numStr = "0" // Number's zero-val
                }
-       case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
-               b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
-               if quoted {
-                       writeString(e, string(b))
-               } else {
-                       e.Write(b)
+               e.WriteString(numStr)
+               return
+       }
+       if quoted {
+               sb, err := Marshal(v.String())
+               if err != nil {
+                       e.error(err)
                }
-       case reflect.Float32, reflect.Float64:
-               f := v.Float()
-               if math.IsInf(f, 0) || math.IsNaN(f) {
-                       e.error(&UnsupportedValueError{v, strconv.FormatFloat(f, 'g', -1, v.Type().Bits())})
+               e.string(string(sb))
+       } else {
+               e.string(v.String())
+       }
+}
+
+func interfaceEncoder(e *encodeState, v reflect.Value, quoted bool) {
+       if v.IsNil() {
+               e.WriteString("null")
+               return
+       }
+       e.reflectValue(v.Elem())
+}
+
+func unsupportedTypeEncoder(e *encodeState, v reflect.Value, quoted bool) {
+       e.error(&UnsupportedTypeError{v.Type()})
+}
+
+type structEncoder struct {
+       fields    []field
+       fieldEncs []encoderFunc
+}
+
+func (se *structEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
+       e.WriteByte('{')
+       first := true
+       for i, f := range se.fields {
+               fv := fieldByIndex(v, f.index)
+               if !fv.IsValid() || f.omitEmpty && isEmptyValue(fv) {
+                       continue
                }
-               b := strconv.AppendFloat(e.scratch[:0], f, 'g', -1, v.Type().Bits())
-               if quoted {
-                       writeString(e, string(b))
+               if first {
+                       first = false
                } else {
-                       e.Write(b)
+                       e.WriteByte(',')
                }
-       case reflect.String:
-               if v.Type() == numberType {
-                       numStr := v.String()
-                       if numStr == "" {
-                               numStr = "0" // Number's zero-val
-                       }
-                       e.WriteString(numStr)
-                       break
-               }
-               if quoted {
-                       sb, err := Marshal(v.String())
-                       if err != nil {
-                               e.error(err)
-                       }
-                       e.string(string(sb))
+               e.string(f.name)
+               e.WriteByte(':')
+               if tenc := se.fieldEncs[i]; tenc != nil {
+                       tenc(e, fv, f.quoted)
                } else {
-                       e.string(v.String())
+                       // Slower path.
+                       e.reflectValue(fv)
                }
+       }
+       e.WriteByte('}')
+}
 
-       case reflect.Struct:
-               e.WriteByte('{')
-               first := true
-               for _, f := range cachedTypeFields(v.Type()) {
-                       fv := fieldByIndex(v, f.index)
-                       if !fv.IsValid() || f.omitEmpty && isEmptyValue(fv) {
-                               continue
-                       }
-                       if first {
-                               first = false
-                       } else {
-                               e.WriteByte(',')
-                       }
-                       e.string(f.name)
-                       e.WriteByte(':')
-                       e.reflectValueQuoted(fv, f.quoted)
+func newStructEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
+       fields := cachedTypeFields(t)
+       se := &structEncoder{
+               fields:    fields,
+               fieldEncs: make([]encoderFunc, len(fields)),
+       }
+       for i, f := range fields {
+               vxf := fieldByIndex(vx, f.index)
+               if vxf.IsValid() {
+                       se.fieldEncs[i] = typeEncoder(vxf.Type(), vxf)
                }
-               e.WriteByte('}')
+       }
+       return se.encode
+}
 
-       case reflect.Map:
-               if v.Type().Key().Kind() != reflect.String {
-                       e.error(&UnsupportedTypeError{v.Type()})
-               }
-               if v.IsNil() {
-                       e.WriteString("null")
-                       break
-               }
-               e.WriteByte('{')
-               var sv stringValues = v.MapKeys()
-               sort.Sort(sv)
-               for i, k := range sv {
-                       if i > 0 {
-                               e.WriteByte(',')
-                       }
-                       e.string(k.String())
-                       e.WriteByte(':')
-                       e.reflectValue(v.MapIndex(k))
-               }
-               e.WriteByte('}')
+type mapEncoder struct {
+       elemEnc encoderFunc
+}
 
-       case reflect.Slice:
-               if v.IsNil() {
-                       e.WriteString("null")
-                       break
-               }
-               if v.Type().Elem().Kind() == reflect.Uint8 {
-                       // Byte slices get special treatment; arrays don't.
-                       s := v.Bytes()
-                       e.WriteByte('"')
-                       if len(s) < 1024 {
-                               // for small buffers, using Encode directly is much faster.
-                               dst := make([]byte, base64.StdEncoding.EncodedLen(len(s)))
-                               base64.StdEncoding.Encode(dst, s)
-                               e.Write(dst)
-                       } else {
-                               // for large buffers, avoid unnecessary extra temporary
-                               // buffer space.
-                               enc := base64.NewEncoder(base64.StdEncoding, e)
-                               enc.Write(s)
-                               enc.Close()
-                       }
-                       e.WriteByte('"')
-                       break
-               }
-               // Slices can be marshalled as nil, but otherwise are handled
-               // as arrays.
-               fallthrough
-       case reflect.Array:
-               e.WriteByte('[')
-               n := v.Len()
-               for i := 0; i < n; i++ {
-                       if i > 0 {
-                               e.WriteByte(',')
-                       }
-                       e.reflectValue(v.Index(i))
+func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
+       if v.IsNil() {
+               e.WriteString("null")
+               return
+       }
+       e.WriteByte('{')
+       var sv stringValues = v.MapKeys()
+       sort.Sort(sv)
+       for i, k := range sv {
+               if i > 0 {
+                       e.WriteByte(',')
                }
-               e.WriteByte(']')
+               e.string(k.String())
+               e.WriteByte(':')
+               me.elemEnc(e, v.MapIndex(k), false)
+       }
+       e.WriteByte('}')
+}
 
-       case reflect.Interface, reflect.Ptr:
-               if v.IsNil() {
-                       e.WriteString("null")
-                       return
+func newMapEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
+       if t.Key().Kind() != reflect.String {
+               return unsupportedTypeEncoder
+       }
+       me := &mapEncoder{typeEncoder(vx.Type().Elem(), reflect.Value{})}
+       return me.encode
+}
+
+func encodeByteSlice(e *encodeState, v reflect.Value, _ bool) {
+       if v.IsNil() {
+               e.WriteString("null")
+               return
+       }
+       s := v.Bytes()
+       e.WriteByte('"')
+       if len(s) < 1024 {
+               // for small buffers, using Encode directly is much faster.
+               dst := make([]byte, base64.StdEncoding.EncodedLen(len(s)))
+               base64.StdEncoding.Encode(dst, s)
+               e.Write(dst)
+       } else {
+               // for large buffers, avoid unnecessary extra temporary
+               // buffer space.
+               enc := base64.NewEncoder(base64.StdEncoding, e)
+               enc.Write(s)
+               enc.Close()
+       }
+       e.WriteByte('"')
+}
+
+// sliceEncoder just wraps an arrayEncoder, checking to make sure the value isn't nil.
+type sliceEncoder struct {
+       arrayEnc encoderFunc
+}
+
+func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
+       if v.IsNil() {
+               e.WriteString("null")
+               return
+       }
+       se.arrayEnc(e, v, false)
+}
+
+func newSliceEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
+       // Byte slices get special treatment; arrays don't.
+       if vx.Type().Elem().Kind() == reflect.Uint8 {
+               return encodeByteSlice
+       }
+       enc := &sliceEncoder{newArrayEncoder(t, vx)}
+       return enc.encode
+}
+
+type arrayEncoder struct {
+       elemEnc encoderFunc
+}
+
+func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
+       e.WriteByte('[')
+       n := v.Len()
+       for i := 0; i < n; i++ {
+               if i > 0 {
+                       e.WriteByte(',')
                }
-               e.reflectValue(v.Elem())
+               ae.elemEnc(e, v.Index(i), false)
+       }
+       e.WriteByte(']')
+}
 
-       default:
-               e.error(&UnsupportedTypeError{v.Type()})
+func newArrayEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
+       enc := &arrayEncoder{typeEncoder(t.Elem(), reflect.Value{})}
+       return enc.encode
+}
+
+type ptrEncoder struct {
+       elemEnc encoderFunc
+}
+
+func (pe *ptrEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
+       if v.IsNil() {
+               e.WriteString("null")
+               return
        }
-       return
+       pe.elemEnc(e, v.Elem(), false)
+}
+
+func newPtrEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
+       enc := &ptrEncoder{typeEncoder(t.Elem(), reflect.Value{})}
+       return enc.encode
 }
 
 func isValidTag(s string) bool {