]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/gob: custom array/slice decoders
authorRob Pike <r@golang.org>
Fri, 17 Oct 2014 19:37:41 +0000 (12:37 -0700)
committerRob Pike <r@golang.org>
Fri, 17 Oct 2014 19:37:41 +0000 (12:37 -0700)
Use go generate to write better loops for decoding arrays,
just as we did for encoding. It doesn't help as much,
relatively speaking, but it's still noticeable.

benchmark                          old ns/op     new ns/op     delta
BenchmarkDecodeComplex128Slice     202348        184529        -8.81%
BenchmarkDecodeFloat64Slice        135800        120979        -10.91%
BenchmarkDecodeInt32Slice          121200        105149        -13.24%
BenchmarkDecodeStringSlice         288129        278214        -3.44%

LGTM=rsc
R=rsc
CC=golang-codereviews
https://golang.org/cl/154420044

src/encoding/gob/dec_helpers.go [new file with mode: 0644]
src/encoding/gob/decgen.go [new file with mode: 0644]
src/encoding/gob/decode.go
src/encoding/gob/enc_helpers.go
src/encoding/gob/encgen.go
src/encoding/gob/encode.go
src/encoding/gob/timing_test.go

diff --git a/src/encoding/gob/dec_helpers.go b/src/encoding/gob/dec_helpers.go
new file mode 100644 (file)
index 0000000..ae59ef0
--- /dev/null
@@ -0,0 +1,468 @@
+// Created by decgen --output dec_helpers.go; DO NOT EDIT
+
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+       "math"
+       "reflect"
+)
+
+var decArrayHelper = map[reflect.Kind]decHelper{
+       reflect.Bool:       decBoolArray,
+       reflect.Complex64:  decComplex64Array,
+       reflect.Complex128: decComplex128Array,
+       reflect.Float32:    decFloat32Array,
+       reflect.Float64:    decFloat64Array,
+       reflect.Int:        decIntArray,
+       reflect.Int16:      decInt16Array,
+       reflect.Int32:      decInt32Array,
+       reflect.Int64:      decInt64Array,
+       reflect.Int8:       decInt8Array,
+       reflect.String:     decStringArray,
+       reflect.Uint:       decUintArray,
+       reflect.Uint16:     decUint16Array,
+       reflect.Uint32:     decUint32Array,
+       reflect.Uint64:     decUint64Array,
+       reflect.Uintptr:    decUintptrArray,
+}
+
+var decSliceHelper = map[reflect.Kind]decHelper{
+       reflect.Bool:       decBoolSlice,
+       reflect.Complex64:  decComplex64Slice,
+       reflect.Complex128: decComplex128Slice,
+       reflect.Float32:    decFloat32Slice,
+       reflect.Float64:    decFloat64Slice,
+       reflect.Int:        decIntSlice,
+       reflect.Int16:      decInt16Slice,
+       reflect.Int32:      decInt32Slice,
+       reflect.Int64:      decInt64Slice,
+       reflect.Int8:       decInt8Slice,
+       reflect.String:     decStringSlice,
+       reflect.Uint:       decUintSlice,
+       reflect.Uint16:     decUint16Slice,
+       reflect.Uint32:     decUint32Slice,
+       reflect.Uint64:     decUint64Slice,
+       reflect.Uintptr:    decUintptrSlice,
+}
+
+func decBoolArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decBoolSlice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decBoolSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]bool)
+       if !ok {
+               // It is kind bool but not type bool. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding bool array or slice: length exceeds input size (%!d(string=Bool) elements)", length)
+               }
+               slice[i] = state.decodeUint() != 0
+       }
+       return true
+}
+
+func decComplex64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decComplex64Slice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decComplex64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]complex64)
+       if !ok {
+               // It is kind complex64 but not type complex64. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding complex64 array or slice: length exceeds input size (%!d(string=Complex64) elements)", length)
+               }
+               real := float32FromBits(state.decodeUint(), ovfl)
+               imag := float32FromBits(state.decodeUint(), ovfl)
+               slice[i] = complex(float32(real), float32(imag))
+       }
+       return true
+}
+
+func decComplex128Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decComplex128Slice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decComplex128Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]complex128)
+       if !ok {
+               // It is kind complex128 but not type complex128. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding complex128 array or slice: length exceeds input size (%!d(string=Complex128) elements)", length)
+               }
+               real := float64FromBits(state.decodeUint())
+               imag := float64FromBits(state.decodeUint())
+               slice[i] = complex(real, imag)
+       }
+       return true
+}
+
+func decFloat32Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decFloat32Slice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decFloat32Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]float32)
+       if !ok {
+               // It is kind float32 but not type float32. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding float32 array or slice: length exceeds input size (%!d(string=Float32) elements)", length)
+               }
+               slice[i] = float32(float32FromBits(state.decodeUint(), ovfl))
+       }
+       return true
+}
+
+func decFloat64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decFloat64Slice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decFloat64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]float64)
+       if !ok {
+               // It is kind float64 but not type float64. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding float64 array or slice: length exceeds input size (%!d(string=Float64) elements)", length)
+               }
+               slice[i] = float64FromBits(state.decodeUint())
+       }
+       return true
+}
+
+func decIntArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decIntSlice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decIntSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]int)
+       if !ok {
+               // It is kind int but not type int. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding int array or slice: length exceeds input size (%!d(string=Int) elements)", length)
+               }
+               x := state.decodeInt()
+               // MinInt and MaxInt
+               if x < ^int64(^uint(0)>>1) || int64(^uint(0)>>1) < x {
+                       error_(ovfl)
+               }
+               slice[i] = int(x)
+       }
+       return true
+}
+
+func decInt16Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decInt16Slice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decInt16Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]int16)
+       if !ok {
+               // It is kind int16 but not type int16. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding int16 array or slice: length exceeds input size (%!d(string=Int16) elements)", length)
+               }
+               x := state.decodeInt()
+               if x < math.MinInt16 || math.MaxInt16 < x {
+                       error_(ovfl)
+               }
+               slice[i] = int16(x)
+       }
+       return true
+}
+
+func decInt32Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decInt32Slice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decInt32Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]int32)
+       if !ok {
+               // It is kind int32 but not type int32. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding int32 array or slice: length exceeds input size (%!d(string=Int32) elements)", length)
+               }
+               x := state.decodeInt()
+               if x < math.MinInt32 || math.MaxInt32 < x {
+                       error_(ovfl)
+               }
+               slice[i] = int32(x)
+       }
+       return true
+}
+
+func decInt64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decInt64Slice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decInt64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]int64)
+       if !ok {
+               // It is kind int64 but not type int64. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding int64 array or slice: length exceeds input size (%!d(string=Int64) elements)", length)
+               }
+               slice[i] = state.decodeInt()
+       }
+       return true
+}
+
+func decInt8Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decInt8Slice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decInt8Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]int8)
+       if !ok {
+               // It is kind int8 but not type int8. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding int8 array or slice: length exceeds input size (%!d(string=Int8) elements)", length)
+               }
+               x := state.decodeInt()
+               if x < math.MinInt8 || math.MaxInt8 < x {
+                       error_(ovfl)
+               }
+               slice[i] = int8(x)
+       }
+       return true
+}
+
+func decStringArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decStringSlice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decStringSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]string)
+       if !ok {
+               // It is kind string but not type string. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding string array or slice: length exceeds input size (%!d(string=String) elements)", length)
+               }
+               u := state.decodeUint()
+               n := int(u)
+               if n < 0 || uint64(n) != u || n > state.b.Len() {
+                       errorf("length of string exceeds input size (%d bytes)", u)
+               }
+               if n > state.b.Len() {
+                       errorf("string data too long for buffer: %d", n)
+               }
+               // Read the data.
+               data := make([]byte, n)
+               if _, err := state.b.Read(data); err != nil {
+                       errorf("error decoding string: %s", err)
+               }
+               slice[i] = string(data)
+       }
+       return true
+}
+
+func decUintArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decUintSlice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decUintSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]uint)
+       if !ok {
+               // It is kind uint but not type uint. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding uint array or slice: length exceeds input size (%!d(string=Uint) elements)", length)
+               }
+               x := state.decodeUint()
+               /*TODO if math.MaxUint32 < x {
+                       error_(ovfl)
+               }*/
+               slice[i] = uint(x)
+       }
+       return true
+}
+
+func decUint16Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decUint16Slice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decUint16Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]uint16)
+       if !ok {
+               // It is kind uint16 but not type uint16. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding uint16 array or slice: length exceeds input size (%!d(string=Uint16) elements)", length)
+               }
+               x := state.decodeUint()
+               if math.MaxUint16 < x {
+                       error_(ovfl)
+               }
+               slice[i] = uint16(x)
+       }
+       return true
+}
+
+func decUint32Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decUint32Slice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decUint32Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]uint32)
+       if !ok {
+               // It is kind uint32 but not type uint32. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding uint32 array or slice: length exceeds input size (%!d(string=Uint32) elements)", length)
+               }
+               x := state.decodeUint()
+               if math.MaxUint32 < x {
+                       error_(ovfl)
+               }
+               slice[i] = uint32(x)
+       }
+       return true
+}
+
+func decUint64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decUint64Slice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decUint64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]uint64)
+       if !ok {
+               // It is kind uint64 but not type uint64. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding uint64 array or slice: length exceeds input size (%!d(string=Uint64) elements)", length)
+               }
+               slice[i] = state.decodeUint()
+       }
+       return true
+}
+
+func decUintptrArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return decUintptrSlice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+
+func decUintptrSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]uintptr)
+       if !ok {
+               // It is kind uintptr but not type uintptr. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding uintptr array or slice: length exceeds input size (%!d(string=Uintptr) elements)", length)
+               }
+               x := state.decodeUint()
+               if uint64(^uintptr(0)) < x {
+                       error_(ovfl)
+               }
+               slice[i] = uintptr(x)
+       }
+       return true
+}
diff --git a/src/encoding/gob/decgen.go b/src/encoding/gob/decgen.go
new file mode 100644 (file)
index 0000000..1cd1fb0
--- /dev/null
@@ -0,0 +1,240 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build ignore
+
+// encgen writes the helper functions for encoding. Intended to be
+// used with go generate; see the invocation in encode.go.
+
+// TODO: We could do more by being unsafe. Add a -unsafe flag?
+
+package main
+
+import (
+       "bytes"
+       "flag"
+       "fmt"
+       "go/format"
+       "log"
+       "os"
+)
+
+var output = flag.String("output", "dec_helpers.go", "file name to write")
+
+type Type struct {
+       lower   string
+       upper   string
+       decoder string
+}
+
+var types = []Type{
+       {
+               "bool",
+               "Bool",
+               `slice[i] = state.decodeUint() != 0`,
+       },
+       {
+               "complex64",
+               "Complex64",
+               `real := float32FromBits(state.decodeUint(), ovfl)
+               imag := float32FromBits(state.decodeUint(), ovfl)
+               slice[i] = complex(float32(real), float32(imag))`,
+       },
+       {
+               "complex128",
+               "Complex128",
+               `real := float64FromBits(state.decodeUint())
+               imag := float64FromBits(state.decodeUint())
+               slice[i] = complex(real, imag)`,
+       },
+       {
+               "float32",
+               "Float32",
+               `slice[i] = float32(float32FromBits(state.decodeUint(), ovfl))`,
+       },
+       {
+               "float64",
+               "Float64",
+               `slice[i] = float64FromBits(state.decodeUint())`,
+       },
+       {
+               "int",
+               "Int",
+               `x := state.decodeInt()
+               // MinInt and MaxInt
+               if x < ^int64(^uint(0)>>1) || int64(^uint(0)>>1) < x {
+                       error_(ovfl)
+               }
+               slice[i] = int(x)`,
+       },
+       {
+               "int16",
+               "Int16",
+               `x := state.decodeInt()
+               if x < math.MinInt16 || math.MaxInt16 < x {
+                       error_(ovfl)
+               }
+               slice[i] = int16(x)`,
+       },
+       {
+               "int32",
+               "Int32",
+               `x := state.decodeInt()
+               if x < math.MinInt32 || math.MaxInt32 < x {
+                       error_(ovfl)
+               }
+               slice[i] = int32(x)`,
+       },
+       {
+               "int64",
+               "Int64",
+               `slice[i] = state.decodeInt()`,
+       },
+       {
+               "int8",
+               "Int8",
+               `x := state.decodeInt()
+               if x < math.MinInt8 || math.MaxInt8 < x {
+                       error_(ovfl)
+               }
+               slice[i] = int8(x)`,
+       },
+       {
+               "string",
+               "String",
+               `u := state.decodeUint()
+               n := int(u)
+               if n < 0 || uint64(n) != u || n > state.b.Len() {
+                       errorf("length of string exceeds input size (%d bytes)", u)
+               }
+               if n > state.b.Len() {
+                       errorf("string data too long for buffer: %d", n)
+               }
+               // Read the data.
+               data := make([]byte, n)
+               if _, err := state.b.Read(data); err != nil {
+                       errorf("error decoding string: %s", err)
+               }
+               slice[i] = string(data)`,
+       },
+       {
+               "uint",
+               "Uint",
+               `x := state.decodeUint()
+               /*TODO if math.MaxUint32 < x {
+                       error_(ovfl)
+               }*/
+               slice[i] = uint(x)`,
+       },
+       {
+               "uint16",
+               "Uint16",
+               `x := state.decodeUint()
+               if math.MaxUint16 < x {
+                       error_(ovfl)
+               }
+               slice[i] = uint16(x)`,
+       },
+       {
+               "uint32",
+               "Uint32",
+               `x := state.decodeUint()
+               if math.MaxUint32 < x {
+                       error_(ovfl)
+               }
+               slice[i] = uint32(x)`,
+       },
+       {
+               "uint64",
+               "Uint64",
+               `slice[i] = state.decodeUint()`,
+       },
+       {
+               "uintptr",
+               "Uintptr",
+               `x := state.decodeUint()
+               if uint64(^uintptr(0)) < x {
+                       error_(ovfl)
+               }
+               slice[i] = uintptr(x)`,
+       },
+       // uint8 Handled separately.
+}
+
+func main() {
+       log.SetFlags(0)
+       log.SetPrefix("decgen: ")
+       flag.Parse()
+       if flag.NArg() != 0 {
+               log.Fatal("usage: decgen [--output filename]")
+       }
+       var b bytes.Buffer
+       fmt.Fprintf(&b, "// Created by decgen --output %s; DO NOT EDIT\n", *output)
+       fmt.Fprint(&b, header)
+       printMaps(&b, "Array")
+       fmt.Fprint(&b, "\n")
+       printMaps(&b, "Slice")
+       for _, t := range types {
+               fmt.Fprintf(&b, arrayHelper, t.lower, t.upper)
+               fmt.Fprintf(&b, sliceHelper, t.lower, t.upper, t.decoder)
+       }
+       source, err := format.Source(b.Bytes())
+       if err != nil {
+               log.Fatal("source format error:", err)
+       }
+       fd, err := os.Create(*output)
+       _, err = fd.Write(source)
+       if err != nil {
+               log.Fatal(err)
+       }
+}
+
+func printMaps(b *bytes.Buffer, upperClass string) {
+       fmt.Fprintf(b, "var dec%sHelper = map[reflect.Kind]decHelper{\n", upperClass)
+       for _, t := range types {
+               fmt.Fprintf(b, "reflect.%s: dec%s%s,\n", t.upper, t.upper, upperClass)
+       }
+       fmt.Fprintf(b, "}\n")
+}
+
+const header = `
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+       "math"
+       "reflect"
+)
+
+`
+
+const arrayHelper = `
+func dec%[2]sArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       // Can only slice if it is addressable.
+       if !v.CanAddr() {
+               return false
+       }
+       return dec%[2]sSlice(state, v.Slice(0, v.Len()), length, ovfl)
+}
+`
+
+const sliceHelper = `
+func dec%[2]sSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
+       slice, ok := v.Interface().([]%[1]s)
+       if !ok {
+               // It is kind %[1]s but not type %[1]s. TODO: We can handle this unsafely.
+               return false
+       }
+       for i := 0; i < length; i++ {
+               if state.b.Len() == 0 {
+                       errorf("decoding %[1]s array or slice: length exceeds input size (%d elements)", length)
+               }
+               %[3]s
+       }
+       return true
+}
+`
index 6a9213fb3cb6c7f8dcb21efafaee1de4006ca1d3..f44838e4cf44014ca759d87efee45522eb40c090 100644 (file)
@@ -2,6 +2,8 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
+//go:generate go run decgen.go -output dec_helpers.go
+
 package gob
 
 import (
@@ -19,6 +21,8 @@ var (
        errRange   = errors.New("gob: bad data: field numbers out of bounds")
 )
 
+type decHelper func(state *decoderState, v reflect.Value, length int, ovfl error) bool
+
 // decoderState is the execution state of an instance of the decoder. A new state
 // is created for nested objects.
 type decoderState struct {
@@ -257,7 +261,7 @@ func float64FromBits(u uint64) float64 {
 // number, and returns it. It's a helper function for float32 and complex64.
 // It returns a float64 because that's what reflection needs, but its return
 // value is known to be accurately representable in a float32.
-func float32FromBits(i *decInstr, u uint64) float64 {
+func float32FromBits(u uint64, ovfl error) float64 {
        v := float64FromBits(u)
        av := v
        if av < 0 {
@@ -265,7 +269,7 @@ func float32FromBits(i *decInstr, u uint64) float64 {
        }
        // +Inf is OK in both 32- and 64-bit floats.  Underflow is always OK.
        if math.MaxFloat32 < av && av <= math.MaxFloat64 {
-               error_(i.ovfl)
+               error_(ovfl)
        }
        return v
 }
@@ -273,7 +277,7 @@ func float32FromBits(i *decInstr, u uint64) float64 {
 // decFloat32 decodes an unsigned integer, treats it as a 32-bit floating-point
 // number, and stores it in value.
 func decFloat32(i *decInstr, state *decoderState, value reflect.Value) {
-       value.SetFloat(float32FromBits(i, state.decodeUint()))
+       value.SetFloat(float32FromBits(state.decodeUint(), i.ovfl))
 }
 
 // decFloat64 decodes an unsigned integer, treats it as a 64-bit floating-point
@@ -286,8 +290,8 @@ func decFloat64(i *decInstr, state *decoderState, value reflect.Value) {
 // pair of floating point numbers, and stores them as a complex64 in value.
 // The real part comes first.
 func decComplex64(i *decInstr, state *decoderState, value reflect.Value) {
-       real := float32FromBits(i, state.decodeUint())
-       imag := float32FromBits(i, state.decodeUint())
+       real := float32FromBits(state.decodeUint(), i.ovfl)
+       imag := float32FromBits(state.decodeUint(), i.ovfl)
        value.SetComplex(complex(real, imag))
 }
 
@@ -450,7 +454,10 @@ func (dec *Decoder) ignoreSingle(engine *decEngine) {
 }
 
 // decodeArrayHelper does the work for decoding arrays and slices.
-func (dec *Decoder) decodeArrayHelper(state *decoderState, value reflect.Value, elemOp decOp, length int, ovfl error) {
+func (dec *Decoder) decodeArrayHelper(state *decoderState, value reflect.Value, elemOp decOp, length int, ovfl error, helper decHelper) {
+       if helper != nil && helper(state, value, length, ovfl) {
+               return
+       }
        instr := &decInstr{elemOp, 0, nil, ovfl}
        isPtr := value.Type().Elem().Kind() == reflect.Ptr
        for i := 0; i < length; i++ {
@@ -468,11 +475,11 @@ func (dec *Decoder) decodeArrayHelper(state *decoderState, value reflect.Value,
 // decodeArray decodes an array and stores it in value.
 // The length is an unsigned integer preceding the elements.  Even though the length is redundant
 // (it's part of the type), it's a useful check and is included in the encoding.
-func (dec *Decoder) decodeArray(atyp reflect.Type, state *decoderState, value reflect.Value, elemOp decOp, length int, ovfl error) {
+func (dec *Decoder) decodeArray(atyp reflect.Type, state *decoderState, value reflect.Value, elemOp decOp, length int, ovfl error, helper decHelper) {
        if n := state.decodeUint(); n != uint64(length) {
                errorf("length mismatch in decodeArray")
        }
-       dec.decodeArrayHelper(state, value, elemOp, length, ovfl)
+       dec.decodeArrayHelper(state, value, elemOp, length, ovfl, helper)
 }
 
 // decodeIntoValue is a helper for map decoding.
@@ -534,7 +541,7 @@ func (dec *Decoder) ignoreMap(state *decoderState, keyOp, elemOp decOp) {
 
 // decodeSlice decodes a slice and stores it in value.
 // Slices are encoded as an unsigned length followed by the elements.
-func (dec *Decoder) decodeSlice(state *decoderState, value reflect.Value, elemOp decOp, ovfl error) {
+func (dec *Decoder) decodeSlice(state *decoderState, value reflect.Value, elemOp decOp, ovfl error, helper decHelper) {
        u := state.decodeUint()
        typ := value.Type()
        size := uint64(typ.Elem().Size())
@@ -551,7 +558,7 @@ func (dec *Decoder) decodeSlice(state *decoderState, value reflect.Value, elemOp
        } else {
                value.Set(value.Slice(0, n))
        }
-       dec.decodeArrayHelper(state, value, elemOp, n, ovfl)
+       dec.decodeArrayHelper(state, value, elemOp, n, ovfl, helper)
 }
 
 // ignoreSlice skips over the data for a slice value with no destination.
@@ -720,8 +727,9 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
                        elemId := dec.wireType[wireId].ArrayT.Elem
                        elemOp := dec.decOpFor(elemId, t.Elem(), name, inProgress)
                        ovfl := overflow(name)
+                       helper := decArrayHelper[t.Elem().Kind()]
                        op = func(i *decInstr, state *decoderState, value reflect.Value) {
-                               state.dec.decodeArray(t, state, value, *elemOp, t.Len(), ovfl)
+                               state.dec.decodeArray(t, state, value, *elemOp, t.Len(), ovfl, helper)
                        }
 
                case reflect.Map:
@@ -748,8 +756,9 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
                        }
                        elemOp := dec.decOpFor(elemId, t.Elem(), name, inProgress)
                        ovfl := overflow(name)
+                       helper := decSliceHelper[t.Elem().Kind()]
                        op = func(i *decInstr, state *decoderState, value reflect.Value) {
-                               state.dec.decodeSlice(state, value, *elemOp, ovfl)
+                               state.dec.decodeSlice(state, value, *elemOp, ovfl, helper)
                        }
 
                case reflect.Struct:
index 1e6f30718472d57a33206901815b4c66f4a860c8..804e539d84d386f5b7d3854c8c7eee2954977950 100644 (file)
@@ -10,7 +10,7 @@ import (
        "reflect"
 )
 
-var arrayHelper = map[reflect.Kind]encHelper{
+var encArrayHelper = map[reflect.Kind]encHelper{
        reflect.Bool:       encBoolArray,
        reflect.Complex64:  encComplex64Array,
        reflect.Complex128: encComplex128Array,
@@ -29,7 +29,7 @@ var arrayHelper = map[reflect.Kind]encHelper{
        reflect.Uintptr:    encUintptrArray,
 }
 
-var sliceHelper = map[reflect.Kind]encHelper{
+var encSliceHelper = map[reflect.Kind]encHelper{
        reflect.Bool:       encBoolSlice,
        reflect.Complex64:  encComplex64Slice,
        reflect.Complex128: encComplex128Slice,
index fa500e3dab3ee02089663be3bf14636cd7fc0100..efdd9282921da612c2d43afc45bb79f0742fcc42 100644 (file)
@@ -144,7 +144,7 @@ var types = []Type{
 
 func main() {
        log.SetFlags(0)
-       log.SetPrefix("helpergen: ")
+       log.SetPrefix("encgen: ")
        flag.Parse()
        if flag.NArg() != 0 {
                log.Fatal("usage: encgen [--output filename]")
@@ -152,9 +152,9 @@ func main() {
        var b bytes.Buffer
        fmt.Fprintf(&b, "// Created by encgen --output %s; DO NOT EDIT\n", *output)
        fmt.Fprint(&b, header)
-       printMaps(&b, "array", "Array")
+       printMaps(&b, "Array")
        fmt.Fprint(&b, "\n")
-       printMaps(&b, "slice", "Slice")
+       printMaps(&b, "Slice")
        for _, t := range types {
                fmt.Fprintf(&b, arrayHelper, t.lower, t.upper)
                fmt.Fprintf(&b, sliceHelper, t.lower, t.upper, t.zero, t.encoder)
@@ -170,8 +170,8 @@ func main() {
        }
 }
 
-func printMaps(b *bytes.Buffer, lowerClass, upperClass string) {
-       fmt.Fprintf(b, "var %sHelper = map[reflect.Kind]encHelper{\n", lowerClass)
+func printMaps(b *bytes.Buffer, upperClass string) {
+       fmt.Fprintf(b, "var enc%sHelper = map[reflect.Kind]encHelper{\n", upperClass)
        for _, t := range types {
                fmt.Fprintf(b, "reflect.%s: enc%s%s,\n", t.upper, t.upper, upperClass)
        }
index 3b8d0b42712359bb3a1d1c48284e92afd1253793..3da848c851583012fb955f4c464da6e3500ef4ac 100644 (file)
@@ -508,7 +508,7 @@ func encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp, building map[
                        }
                        // Slices have a header; we decode it to find the underlying array.
                        elemOp, elemIndir := encOpFor(t.Elem(), inProgress, building)
-                       helper := sliceHelper[t.Elem().Kind()]
+                       helper := encSliceHelper[t.Elem().Kind()]
                        op = func(i *encInstr, state *encoderState, slice reflect.Value) {
                                if !state.sendZero && slice.Len() == 0 {
                                        return
@@ -519,7 +519,7 @@ func encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp, building map[
                case reflect.Array:
                        // True arrays have size in the type.
                        elemOp, elemIndir := encOpFor(t.Elem(), inProgress, building)
-                       helper := arrayHelper[t.Elem().Kind()]
+                       helper := encArrayHelper[t.Elem().Kind()]
                        op = func(i *encInstr, state *encoderState, array reflect.Value) {
                                state.update(i)
                                state.enc.encodeArray(state.b, array, *elemOp, elemIndir, array.Len(), helper)
index abfe936e835fb25821f110c6b8cbf4b305c77c32..940e5ad4126a7f025de32624d62eede74759548b 100644 (file)
@@ -132,13 +132,14 @@ func TestCountDecodeMallocs(t *testing.T) {
        }
 }
 
-func BenchmarkComplex128Slice(b *testing.B) {
+func BenchmarkEncodeComplex128Slice(b *testing.B) {
        var buf bytes.Buffer
        enc := NewEncoder(&buf)
        a := make([]complex128, 1000)
        for i := range a {
                a[i] = 1.2 + 3.4i
        }
+       b.ResetTimer()
        for i := 0; i < b.N; i++ {
                buf.Reset()
                err := enc.Encode(a)
@@ -148,13 +149,14 @@ func BenchmarkComplex128Slice(b *testing.B) {
        }
 }
 
-func BenchmarkInt32Slice(b *testing.B) {
+func BenchmarkEncodeFloat64Slice(b *testing.B) {
        var buf bytes.Buffer
        enc := NewEncoder(&buf)
-       a := make([]int32, 1000)
+       a := make([]float64, 1000)
        for i := range a {
-               a[i] = 1234
+               a[i] = 1.23e4
        }
+       b.ResetTimer()
        for i := 0; i < b.N; i++ {
                buf.Reset()
                err := enc.Encode(a)
@@ -164,13 +166,14 @@ func BenchmarkInt32Slice(b *testing.B) {
        }
 }
 
-func BenchmarkFloat64Slice(b *testing.B) {
+func BenchmarkEncodeInt32Slice(b *testing.B) {
        var buf bytes.Buffer
        enc := NewEncoder(&buf)
-       a := make([]float64, 1000)
+       a := make([]int32, 1000)
        for i := range a {
-               a[i] = 1.23e4
+               a[i] = 1234
        }
+       b.ResetTimer()
        for i := 0; i < b.N; i++ {
                buf.Reset()
                err := enc.Encode(a)
@@ -180,13 +183,14 @@ func BenchmarkFloat64Slice(b *testing.B) {
        }
 }
 
-func BenchmarkStringSlice(b *testing.B) {
+func BenchmarkEncodeStringSlice(b *testing.B) {
        var buf bytes.Buffer
        enc := NewEncoder(&buf)
        a := make([]string, 1000)
        for i := range a {
                a[i] = "now is the time"
        }
+       b.ResetTimer()
        for i := 0; i < b.N; i++ {
                buf.Reset()
                err := enc.Encode(a)
@@ -195,3 +199,127 @@ func BenchmarkStringSlice(b *testing.B) {
                }
        }
 }
+
+// benchmarkBuf is a read buffer we can reset
+type benchmarkBuf struct {
+       offset int
+       data   []byte
+}
+
+func (b *benchmarkBuf) Read(p []byte) (n int, err error) {
+       n = copy(p, b.data[b.offset:])
+       if n == 0 {
+               return 0, io.EOF
+       }
+       b.offset += n
+       return
+}
+
+func (b *benchmarkBuf) ReadByte() (c byte, err error) {
+       if b.offset >= len(b.data) {
+               return 0, io.EOF
+       }
+       c = b.data[b.offset]
+       b.offset++
+       return
+}
+
+func (b *benchmarkBuf) reset() {
+       b.offset = 0
+}
+
+func BenchmarkDecodeComplex128Slice(b *testing.B) {
+       var buf bytes.Buffer
+       enc := NewEncoder(&buf)
+       a := make([]complex128, 1000)
+       for i := range a {
+               a[i] = 1.2 + 3.4i
+       }
+       err := enc.Encode(a)
+       if err != nil {
+               b.Fatal(err)
+       }
+       x := make([]complex128, 1000)
+       bbuf := benchmarkBuf{data: buf.Bytes()}
+       b.ResetTimer()
+       for i := 0; i < b.N; i++ {
+               bbuf.reset()
+               dec := NewDecoder(&bbuf)
+               err := dec.Decode(&x)
+               if err != nil {
+                       b.Fatal(i, err)
+               }
+       }
+}
+
+func BenchmarkDecodeFloat64Slice(b *testing.B) {
+       var buf bytes.Buffer
+       enc := NewEncoder(&buf)
+       a := make([]float64, 1000)
+       for i := range a {
+               a[i] = 1.23e4
+       }
+       err := enc.Encode(a)
+       if err != nil {
+               b.Fatal(err)
+       }
+       x := make([]float64, 1000)
+       bbuf := benchmarkBuf{data: buf.Bytes()}
+       b.ResetTimer()
+       for i := 0; i < b.N; i++ {
+               bbuf.reset()
+               dec := NewDecoder(&bbuf)
+               err := dec.Decode(&x)
+               if err != nil {
+                       b.Fatal(i, err)
+               }
+       }
+}
+
+func BenchmarkDecodeInt32Slice(b *testing.B) {
+       var buf bytes.Buffer
+       enc := NewEncoder(&buf)
+       a := make([]int32, 1000)
+       for i := range a {
+               a[i] = 1234
+       }
+       err := enc.Encode(a)
+       if err != nil {
+               b.Fatal(err)
+       }
+       x := make([]int32, 1000)
+       bbuf := benchmarkBuf{data: buf.Bytes()}
+       b.ResetTimer()
+       for i := 0; i < b.N; i++ {
+               bbuf.reset()
+               dec := NewDecoder(&bbuf)
+               err := dec.Decode(&x)
+               if err != nil {
+                       b.Fatal(i, err)
+               }
+       }
+}
+
+func BenchmarkDecodeStringSlice(b *testing.B) {
+       var buf bytes.Buffer
+       enc := NewEncoder(&buf)
+       a := make([]string, 1000)
+       for i := range a {
+               a[i] = "now is the time"
+       }
+       err := enc.Encode(a)
+       if err != nil {
+               b.Fatal(err)
+       }
+       x := make([]string, 1000)
+       bbuf := benchmarkBuf{data: buf.Bytes()}
+       b.ResetTimer()
+       for i := 0; i < b.N; i++ {
+               bbuf.reset()
+               dec := NewDecoder(&bbuf)
+               err := dec.Decode(&x)
+               if err != nil {
+                       b.Fatal(i, err)
+               }
+       }
+}