]> Cypherpunks repositories - gostls13.git/commitdiff
gob: allow transmission of things other than structs at the top level.
authorRob Pike <r@golang.org>
Mon, 28 Jun 2010 21:09:47 +0000 (14:09 -0700)
committerRob Pike <r@golang.org>
Mon, 28 Jun 2010 21:09:47 +0000 (14:09 -0700)
also fix a bug handling nil maps: before, would needlessly send empty map

R=rsc
CC=golang-dev
https://golang.org/cl/1739043

src/pkg/gob/codec_test.go
src/pkg/gob/decode.go
src/pkg/gob/decoder.go
src/pkg/gob/encode.go
src/pkg/gob/encoder.go
src/pkg/gob/encoder_test.go

index d8bdf2d2f467e491fa7792ef7f559fdee509fcd7..49a13e84d71299d4032ac0b11bc5ae1260b20911 100644 (file)
@@ -1039,7 +1039,7 @@ func TestInvalidField(t *testing.T) {
 type Indirect struct {
        a ***[3]int
        s ***[]int
-       m ***map[string]int
+       m ****map[string]int
 }
 
 type Direct struct {
@@ -1059,10 +1059,11 @@ func TestIndirectSliceMapArray(t *testing.T) {
        *i.s = new(*[]int)
        **i.s = new([]int)
        ***i.s = []int{4, 5, 6}
-       i.m = new(**map[string]int)
-       *i.m = new(*map[string]int)
-       **i.m = new(map[string]int)
-       ***i.m = map[string]int{"one": 1, "two": 2, "three": 3}
+       i.m = new(***map[string]int)
+       *i.m = new(**map[string]int)
+       **i.m = new(*map[string]int)
+       ***i.m = new(map[string]int)
+       ****i.m = map[string]int{"one": 1, "two": 2, "three": 3}
        b := new(bytes.Buffer)
        NewEncoder(b).Encode(i)
        dec := NewDecoder(b)
@@ -1093,12 +1094,12 @@ func TestIndirectSliceMapArray(t *testing.T) {
                t.Error("error: ", err)
        }
        if len(***i.a) != 3 || (***i.a)[0] != 11 || (***i.a)[1] != 22 || (***i.a)[2] != 33 {
-               t.Errorf("indirect to direct: ***i.a is %v not %v", ***i.a, d.a)
+               t.Errorf("direct to indirect: ***i.a is %v not %v", ***i.a, d.a)
        }
        if len(***i.s) != 3 || (***i.s)[0] != 44 || (***i.s)[1] != 55 || (***i.s)[2] != 66 {
-               t.Errorf("indirect to direct: ***i.s is %v not %v", ***i.s, ***i.s)
+               t.Errorf("direct to indirect: ***i.s is %v not %v", ***i.s, ***i.s)
        }
-       if len(***i.m) != 3 || (***i.m)["four"] != 4 || (***i.m)["five"] != 5 || (***i.m)["six"] != 6 {
-               t.Errorf("indirect to direct: ***i.m is %v not %v", ***i.m, d.m)
+       if len(****i.m) != 3 || (****i.m)["four"] != 4 || (****i.m)["five"] != 5 || (****i.m)["six"] != 6 {
+               t.Errorf("direct to indirect: ****i.m is %v not %v", ****i.m, d.m)
        }
 }
index 0dbf81488770ff9a774d2304cde296e3f47167b0..51e439900035929d64c45b785490b53a87640a5b 100644 (file)
@@ -13,15 +13,13 @@ import (
        "math"
        "os"
        "reflect"
-       "runtime"
        "unsafe"
 )
 
 var (
-       errBadUint   = os.ErrorString("gob: encoded unsigned integer out of range")
-       errBadType   = os.ErrorString("gob: unknown type id or corrupted data")
-       errRange     = os.ErrorString("gob: internal error: field numbers out of bounds")
-       errNotStruct = os.ErrorString("gob: TODO: can only handle structs")
+       errBadUint = os.ErrorString("gob: encoded unsigned integer out of range")
+       errBadType = os.ErrorString("gob: unknown type id or corrupted data")
+       errRange   = os.ErrorString("gob: internal error: field numbers out of bounds")
 )
 
 // The global execution state of an instance of the decoder.
@@ -389,18 +387,44 @@ type decEngine struct {
        numInstr int // the number of active instructions
 }
 
-func decodeStruct(engine *decEngine, rtyp *reflect.StructType, b *bytes.Buffer, p uintptr, indir int) os.Error {
-       if indir > 0 {
-               up := unsafe.Pointer(p)
-               if indir > 1 {
-                       up = decIndirect(up, indir)
-               }
-               if *(*unsafe.Pointer)(up) == nil {
-                       // Allocate object.
-                       *(*unsafe.Pointer)(up) = unsafe.New((*runtime.StructType)(unsafe.Pointer(rtyp)))
-               }
-               p = *(*uintptr)(up)
+// allocate makes sure storage is available for an object of underlying type rtyp
+// that is indir levels of indirection through p.
+func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
+       if indir == 0 {
+               return p
+       }
+       up := unsafe.Pointer(p)
+       if indir > 1 {
+               up = decIndirect(up, indir)
        }
+       if *(*unsafe.Pointer)(up) == nil {
+               // Allocate object.
+               *(*unsafe.Pointer)(up) = unsafe.New(rtyp)
+       }
+       return *(*uintptr)(up)
+}
+
+func decodeSingle(engine *decEngine, rtyp reflect.Type, b *bytes.Buffer, p uintptr, indir int) os.Error {
+       p = allocate(rtyp, p, indir)
+       state := newDecodeState(b)
+       state.fieldnum = singletonField
+       basep := p
+       delta := int(decodeUint(state))
+       if delta != 0 {
+               state.err = os.ErrorString("gob decode: corrupted data: non-zero delta for singleton")
+               return state.err
+       }
+       instr := &engine.instr[singletonField]
+       ptr := unsafe.Pointer(basep) // offset will be zero
+       if instr.indir > 1 {
+               ptr = decIndirect(ptr, instr.indir)
+       }
+       instr.op(instr, state, ptr)
+       return state.err
+}
+
+func decodeStruct(engine *decEngine, rtyp *reflect.StructType, b *bytes.Buffer, p uintptr, indir int) os.Error {
+       p = allocate(rtyp, p, indir)
        state := newDecodeState(b)
        state.fieldnum = -1
        basep := p
@@ -468,12 +492,7 @@ func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uint
 
 func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) os.Error {
        if indir > 0 {
-               up := unsafe.Pointer(p)
-               if *(*unsafe.Pointer)(up) == nil {
-                       // Allocate object.
-                       *(*unsafe.Pointer)(up) = unsafe.New(atyp)
-               }
-               p = *(*uintptr)(up)
+               p = allocate(atyp, p, 1) // All but the last level has been allocated by dec.Indirect
        }
        if n := decodeUint(state); n != uint64(length) {
                return os.ErrorString("gob: length mismatch in decodeArray")
@@ -493,12 +512,7 @@ func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, o
 
 func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) os.Error {
        if indir > 0 {
-               up := unsafe.Pointer(p)
-               if *(*unsafe.Pointer)(up) == nil {
-                       // Allocate object.
-                       *(*unsafe.Pointer)(up) = unsafe.New(mtyp)
-               }
-               p = *(*uintptr)(up)
+               p = allocate(mtyp, p, 1) // All but the last level has been allocated by dec.Indirect
        }
        up := unsafe.Pointer(p)
        if *(*unsafe.Pointer)(up) == nil { // maps are represented as a pointer in the runtime
@@ -806,18 +820,34 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
        return true
 }
 
+func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
+       engine = new(decEngine)
+       engine.instr = make([]decInstr, 1) // one item
+       name := rt.String()                // best we can do
+       if !dec.compatibleType(rt, remoteId) {
+               return nil, os.ErrorString("gob: wrong type received for local value " + name)
+       }
+       op, indir, err := dec.decOpFor(remoteId, rt, name)
+       if err != nil {
+               return nil, err
+       }
+       ovfl := os.ErrorString(`value for "` + name + `" out of range`)
+       engine.instr[singletonField] = decInstr{op, singletonField, indir, 0, ovfl}
+       engine.numInstr = 1
+       return
+}
+
 func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
-       srt, ok1 := rt.(*reflect.StructType)
+       srt, ok := rt.(*reflect.StructType)
+       if !ok {
+               return dec.compileSingle(remoteId, rt)
+       }
        var wireStruct *structType
        // Builtin types can come from global pool; the rest must be defined by the decoder
        if t, ok := builtinIdToType[remoteId]; ok {
                wireStruct = t.(*structType)
        } else {
-               w, ok2 := dec.wireType[remoteId]
-               if !ok1 || !ok2 {
-                       return nil, errNotStruct
-               }
-               wireStruct = w.structT
+               wireStruct = dec.wireType[remoteId].structT
        }
        engine = new(decEngine)
        engine.instr = make([]decInstr, len(wireStruct.field))
@@ -891,20 +921,19 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er
 func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error {
        // Dereference down to the underlying struct type.
        rt, indir := indirect(reflect.Typeof(e))
-       st, ok := rt.(*reflect.StructType)
-       if !ok {
-               return os.ErrorString("gob: decode can't handle " + rt.String())
-       }
        enginePtr, err := dec.getDecEnginePtr(wireId, rt)
        if err != nil {
                return err
        }
        engine := *enginePtr
-       if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].structT.field) > 0 {
-               name := rt.Name()
-               return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name)
+       if st, ok := rt.(*reflect.StructType); ok {
+               if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].structT.field) > 0 {
+                       name := rt.Name()
+                       return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name)
+               }
+               return decodeStruct(engine, st, dec.state.b, uintptr(reflect.NewValue(e).Addr()), indir)
        }
-       return decodeStruct(engine, st, dec.state.b, uintptr(reflect.NewValue(e).Addr()), indir)
+       return decodeSingle(engine, rt, dec.state.b, uintptr(reflect.NewValue(e).Addr()), indir)
 }
 
 func init() {
index 90dc2e34c864000ca2a41112fe698312fd873a59..caec517121c98d072164fb90df08cc816b228549 100644 (file)
@@ -108,8 +108,9 @@ func (dec *Decoder) Decode(e interface{}) os.Error {
                }
 
                // No, it's a value.
-               // Make sure the type has been defined already.
-               if dec.wireType[id] == nil {
+               // Make sure the type has been defined already or is a builtin type (for
+               // top-level singleton values).
+               if dec.wireType[id] == nil && builtinIdToType[id] == nil {
                        dec.state.err = errBadType
                        break
                }
index b48c1f698a9a7cd172c54a0d50ec5296f26ae6b8..a7d44ecc2b0ca5549ed26471a6a57bf32595358a 100644 (file)
@@ -271,7 +271,7 @@ const uint64Size = unsafe.Sizeof(uint64(0))
 type encoderState struct {
        b        *bytes.Buffer
        err      os.Error             // error encountered during encoding.
-       inArray  bool                 // encoding an array element or map key/value pair
+       sendZero bool                 // encoding an array element or map key/value pair; send zero values
        fieldnum int                  // the last field number written.
        buf      [1 + uint64Size]byte // buffer used by the encoder; here to avoid allocation.
 }
@@ -352,7 +352,7 @@ func encIndirect(p unsafe.Pointer, indir int) unsafe.Pointer {
 
 func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) {
        b := *(*bool)(p)
-       if b || state.inArray {
+       if b || state.sendZero {
                state.update(i)
                if b {
                        encodeUint(state, 1)
@@ -364,7 +364,7 @@ func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) {
 
 func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := int64(*(*int)(p))
-       if v != 0 || state.inArray {
+       if v != 0 || state.sendZero {
                state.update(i)
                encodeInt(state, v)
        }
@@ -372,7 +372,7 @@ func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) {
 
 func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := uint64(*(*uint)(p))
-       if v != 0 || state.inArray {
+       if v != 0 || state.sendZero {
                state.update(i)
                encodeUint(state, v)
        }
@@ -380,7 +380,7 @@ func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) {
 
 func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := int64(*(*int8)(p))
-       if v != 0 || state.inArray {
+       if v != 0 || state.sendZero {
                state.update(i)
                encodeInt(state, v)
        }
@@ -388,7 +388,7 @@ func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) {
 
 func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := uint64(*(*uint8)(p))
-       if v != 0 || state.inArray {
+       if v != 0 || state.sendZero {
                state.update(i)
                encodeUint(state, v)
        }
@@ -396,7 +396,7 @@ func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) {
 
 func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := int64(*(*int16)(p))
-       if v != 0 || state.inArray {
+       if v != 0 || state.sendZero {
                state.update(i)
                encodeInt(state, v)
        }
@@ -404,7 +404,7 @@ func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) {
 
 func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := uint64(*(*uint16)(p))
-       if v != 0 || state.inArray {
+       if v != 0 || state.sendZero {
                state.update(i)
                encodeUint(state, v)
        }
@@ -412,7 +412,7 @@ func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) {
 
 func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := int64(*(*int32)(p))
-       if v != 0 || state.inArray {
+       if v != 0 || state.sendZero {
                state.update(i)
                encodeInt(state, v)
        }
@@ -420,7 +420,7 @@ func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) {
 
 func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := uint64(*(*uint32)(p))
-       if v != 0 || state.inArray {
+       if v != 0 || state.sendZero {
                state.update(i)
                encodeUint(state, v)
        }
@@ -428,7 +428,7 @@ func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) {
 
 func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := *(*int64)(p)
-       if v != 0 || state.inArray {
+       if v != 0 || state.sendZero {
                state.update(i)
                encodeInt(state, v)
        }
@@ -436,7 +436,7 @@ func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) {
 
 func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := *(*uint64)(p)
-       if v != 0 || state.inArray {
+       if v != 0 || state.sendZero {
                state.update(i)
                encodeUint(state, v)
        }
@@ -444,7 +444,7 @@ func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) {
 
 func encUintptr(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := uint64(*(*uintptr)(p))
-       if v != 0 || state.inArray {
+       if v != 0 || state.sendZero {
                state.update(i)
                encodeUint(state, v)
        }
@@ -468,7 +468,7 @@ func floatBits(f float64) uint64 {
 
 func encFloat(i *encInstr, state *encoderState, p unsafe.Pointer) {
        f := *(*float)(p)
-       if f != 0 || state.inArray {
+       if f != 0 || state.sendZero {
                v := floatBits(float64(f))
                state.update(i)
                encodeUint(state, v)
@@ -477,7 +477,7 @@ func encFloat(i *encInstr, state *encoderState, p unsafe.Pointer) {
 
 func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) {
        f := *(*float32)(p)
-       if f != 0 || state.inArray {
+       if f != 0 || state.sendZero {
                v := floatBits(float64(f))
                state.update(i)
                encodeUint(state, v)
@@ -486,7 +486,7 @@ func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) {
 
 func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) {
        f := *(*float64)(p)
-       if f != 0 || state.inArray {
+       if f != 0 || state.sendZero {
                state.update(i)
                v := floatBits(f)
                encodeUint(state, v)
@@ -496,7 +496,7 @@ func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) {
 // Complex numbers are just a pair of floating-point numbers, real part first.
 func encComplex(i *encInstr, state *encoderState, p unsafe.Pointer) {
        c := *(*complex)(p)
-       if c != 0+0i || state.inArray {
+       if c != 0+0i || state.sendZero {
                rpart := floatBits(float64(real(c)))
                ipart := floatBits(float64(imag(c)))
                state.update(i)
@@ -507,7 +507,7 @@ func encComplex(i *encInstr, state *encoderState, p unsafe.Pointer) {
 
 func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) {
        c := *(*complex64)(p)
-       if c != 0+0i || state.inArray {
+       if c != 0+0i || state.sendZero {
                rpart := floatBits(float64(real(c)))
                ipart := floatBits(float64(imag(c)))
                state.update(i)
@@ -518,7 +518,7 @@ func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) {
 
 func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) {
        c := *(*complex128)(p)
-       if c != 0+0i || state.inArray {
+       if c != 0+0i || state.sendZero {
                rpart := floatBits(real(c))
                ipart := floatBits(imag(c))
                state.update(i)
@@ -530,7 +530,7 @@ func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) {
 // Byte arrays are encoded as an unsigned count followed by the raw bytes.
 func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) {
        b := *(*[]byte)(p)
-       if len(b) > 0 || state.inArray {
+       if len(b) > 0 || state.sendZero {
                state.update(i)
                encodeUint(state, uint64(len(b)))
                state.b.Write(b)
@@ -540,7 +540,7 @@ func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) {
 // Strings are encoded as an unsigned count followed by the raw bytes.
 func encString(i *encInstr, state *encoderState, p unsafe.Pointer) {
        s := *(*string)(p)
-       if len(s) > 0 || state.inArray {
+       if len(s) > 0 || state.sendZero {
                state.update(i)
                encodeUint(state, uint64(len(s)))
                io.WriteString(state.b, s)
@@ -560,6 +560,26 @@ type encEngine struct {
        instr []encInstr
 }
 
+const singletonField = 0
+
+func encodeSingle(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error {
+       state := new(encoderState)
+       state.b = b
+       state.fieldnum = singletonField
+       // There is no surrounding struct to frame the transmission, so we must
+       // generate data even if the item is zero.  To do this, set sendZero.
+       state.sendZero = true
+       instr := &engine.instr[singletonField]
+       p := unsafe.Pointer(basep) // offset will be zero
+       if instr.indir > 0 {
+               if p = encIndirect(p, instr.indir); p == nil {
+                       return nil
+               }
+       }
+       instr.op(instr, state, p)
+       return state.err
+}
+
 func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error {
        state := new(encoderState)
        state.b = b
@@ -584,7 +604,7 @@ func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndi
        state := new(encoderState)
        state.b = b
        state.fieldnum = -1
-       state.inArray = true
+       state.sendZero = true
        encodeUint(state, uint64(length))
        for i := 0; i < length && state.err == nil; i++ {
                elemp := p
@@ -607,22 +627,17 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in
                v = reflect.Indirect(v)
        }
        if v == nil {
-               state.err = os.ErrorString("gob: encodeMap: nil element")
+               state.err = os.ErrorString("gob: encodeReflectValue: nil element")
                return
        }
        op(nil, state, unsafe.Pointer(v.Addr()))
 }
 
-func encodeMap(b *bytes.Buffer, rt reflect.Type, p uintptr, keyOp, elemOp encOp, keyIndir, elemIndir int) os.Error {
+func encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) os.Error {
        state := new(encoderState)
        state.b = b
        state.fieldnum = -1
-       state.inArray = true
-       // Maps cannot be accessed by moving addresses around the way
-       // that slices etc. can.  We must recover a full reflection value for
-       // the iteration.
-       v := reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer((p))))
-       mv := reflect.Indirect(v).(*reflect.MapValue)
+       state.sendZero = true
        keys := mv.Keys()
        encodeUint(state, uint64(len(keys)))
        for _, key := range keys {
@@ -694,6 +709,10 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
                                return nil, 0, err
                        }
                        op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
+                               slice := (*reflect.SliceHeader)(p)
+                               if slice.Len == 0 {
+                                       return
+                               }
                                state.update(i)
                                state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len())
                        }
@@ -707,8 +726,16 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
                                return nil, 0, err
                        }
                        op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
+                               // Maps cannot be accessed by moving addresses around the way
+                               // that slices etc. can.  We must recover a full reflection value for
+                               // the iteration.
+                               v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer((p))))
+                               mv := reflect.Indirect(v).(*reflect.MapValue)
+                               if mv.Len() == 0 {
+                                       return
+                               }
                                state.update(i)
-                               state.err = encodeMap(state.b, typ, uintptr(p), keyOp, elemOp, keyIndir, elemIndir)
+                               state.err = encodeMap(state.b, mv, keyOp, elemOp, keyIndir, elemIndir)
                        }
                case *reflect.StructType:
                        // Generate a closure that calls out to the engine for the nested type.
@@ -732,21 +759,27 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
 
 // The local Type was compiled from the actual value, so we know it's compatible.
 func compileEnc(rt reflect.Type) (*encEngine, os.Error) {
-       srt, ok := rt.(*reflect.StructType)
-       if !ok {
-               panic("can't happen: non-struct")
-       }
+       srt, isStruct := rt.(*reflect.StructType)
        engine := new(encEngine)
-       engine.instr = make([]encInstr, srt.NumField()+1) // +1 for terminator
-       for fieldnum := 0; fieldnum < srt.NumField(); fieldnum++ {
-               f := srt.Field(fieldnum)
-               op, indir, err := encOpFor(f.Type)
+       if isStruct {
+               engine.instr = make([]encInstr, srt.NumField()+1) // +1 for terminator
+               for fieldnum := 0; fieldnum < srt.NumField(); fieldnum++ {
+                       f := srt.Field(fieldnum)
+                       op, indir, err := encOpFor(f.Type)
+                       if err != nil {
+                               return nil, err
+                       }
+                       engine.instr[fieldnum] = encInstr{op, fieldnum, indir, uintptr(f.Offset)}
+               }
+               engine.instr[srt.NumField()] = encInstr{encStructTerminator, 0, 0, 0}
+       } else {
+               engine.instr = make([]encInstr, 1)
+               op, indir, err := encOpFor(rt)
                if err != nil {
                        return nil, err
                }
-               engine.instr[fieldnum] = encInstr{op, fieldnum, indir, uintptr(f.Offset)}
+               engine.instr[0] = encInstr{op, singletonField, indir, 0} // offset is zero
        }
-       engine.instr[srt.NumField()] = encInstr{encStructTerminator, 0, 0, 0}
        return engine, nil
 }
 
@@ -772,14 +805,14 @@ func encode(b *bytes.Buffer, e interface{}) os.Error {
        for i := 0; i < indir; i++ {
                v = reflect.Indirect(v)
        }
-       if _, ok := v.(*reflect.StructValue); !ok {
-               return os.ErrorString("gob: encode can't handle " + v.Type().String())
-       }
        typeLock.Lock()
        engine, err := getEncEngine(rt)
        typeLock.Unlock()
        if err != nil {
                return err
        }
-       return encodeStruct(engine, b, v.Addr())
+       if _, ok := v.(*reflect.StructValue); ok {
+               return encodeStruct(engine, b, v.Addr())
+       }
+       return encodeSingle(engine, b, v.Addr())
 }
index e24c18d2069f7c615d8fcb50233cb6f8f0860bd6..28cf6f6e0cf022aaccc3da347869c5bdf320eba7 100644 (file)
@@ -68,7 +68,7 @@ func (enc *Encoder) send() {
        }
 }
 
-func (enc *Encoder) sendType(origt reflect.Type) {
+func (enc *Encoder) sendType(origt reflect.Type) (sent bool) {
        // Drill down to the base type.
        rt, _ := indirect(origt)
 
@@ -147,11 +147,6 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
 
        enc.state.err = nil
        rt, _ := indirect(reflect.Typeof(e))
-       // Must be a struct
-       if _, ok := rt.(*reflect.StructType); !ok {
-               enc.badType(rt)
-               return enc.state.err
-       }
 
        // Sanity check only: encoder should never come in with data present.
        if enc.state.b.Len() > 0 || enc.countState.b.Len() > 0 {
@@ -163,10 +158,23 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
        // First, have we already sent this type?
        if _, alreadySent := enc.sent[rt]; !alreadySent {
                // No, so send it.
-               enc.sendType(rt)
+               sent := enc.sendType(rt)
                if enc.state.err != nil {
                        return enc.state.err
                }
+               // If the type info has still not been transmitted, it means we have
+               // a singleton basic type (int, []byte etc.) at top level.  We don't
+               // need to send the type info but we do need to update enc.sent.
+               if !sent {
+                       typeLock.Lock()
+                       info, err := getTypeInfo(rt)
+                       typeLock.Unlock()
+                       if err != nil {
+                               enc.setError(err)
+                               return err
+                       }
+                       enc.sent[rt] = info.id
+               }
        }
 
        // Identify the type of this top-level value.
index 4250b8a9d7765213aed3474155e9a1b44cc4fb4e..b578cd0f8717058f594e1cf5216ba2ec6f5b4ee1 100644 (file)
@@ -131,17 +131,10 @@ func TestBadData(t *testing.T) {
        corruptDataCheck("\x03now is the time for all good men", errBadType, t)
 }
 
-// Types not supported by the Encoder (only structs work at the top level).
-// Basic types work implicitly.
+// Types not supported by the Encoder.
 var unsupportedValues = []interface{}{
-       3,
-       "hi",
-       7.2,
-       []int{1, 2, 3},
-       [3]int{1, 2, 3},
        make(chan int),
        func(a int) bool { return true },
-       make(map[string]int),
        new(interface{}),
 }
 
@@ -275,3 +268,59 @@ func TestDefaultsInArray(t *testing.T) {
                t.Error(err)
        }
 }
+
+var testInt int
+var testFloat32 float32
+var testString string
+var testSlice []string
+var testMap map[string]int
+
+type SingleTest struct {
+       in  interface{}
+       out interface{}
+       err string
+}
+
+var singleTests = []SingleTest{
+       SingleTest{17, &testInt, ""},
+       SingleTest{float32(17.5), &testFloat32, ""},
+       SingleTest{"bike shed", &testString, ""},
+       SingleTest{[]string{"bike", "shed", "paint", "color"}, &testSlice, ""},
+       SingleTest{map[string]int{"seven": 7, "twelve": 12}, &testMap, ""},
+
+       // Decode errors
+       SingleTest{172, &testFloat32, "wrong type"},
+}
+
+func TestSingletons(t *testing.T) {
+       b := new(bytes.Buffer)
+       enc := NewEncoder(b)
+       dec := NewDecoder(b)
+       for _, test := range singleTests {
+               b.Reset()
+               err := enc.Encode(test.in)
+               if err != nil {
+                       t.Errorf("error encoding %v: %s", test.in, err)
+                       continue
+               }
+               err = dec.Decode(test.out)
+               switch {
+               case err != nil && test.err == "":
+                       t.Errorf("error decoding %v: %s", test.in, err)
+                       continue
+               case err == nil && test.err != "":
+                       t.Errorf("expected error decoding %v: %s", test.in, test.err)
+                       continue
+               case err != nil && test.err != "":
+                       if strings.Index(err.String(), test.err) < 0 {
+                               t.Errorf("wrong error decoding %v: wanted %s, got %v", test.in, test.err, err)
+                       }
+                       continue
+               }
+               // Get rid of the pointer in the rhs
+               val := reflect.NewValue(test.out).(*reflect.PtrValue).Elem().Interface()
+               if !reflect.DeepEqual(test.in, val) {
+                       t.Errorf("decoding int: expected %v got %v", test.in, val)
+               }
+       }
+}