]> Cypherpunks repositories - gostls13.git/commitdiff
Rework gobs to fix bad bug related to sharing of id's between encoder and decoder...
authorRob Pike <r@golang.org>
Tue, 17 Nov 2009 07:32:30 +0000 (23:32 -0800)
committerRob Pike <r@golang.org>
Tue, 17 Nov 2009 07:32:30 +0000 (23:32 -0800)
Fix is to move all decoder state into the decoder object.

Fixes #215.

R=rsc
CC=golang-dev
https://golang.org/cl/155077

src/pkg/gob/codec_test.go
src/pkg/gob/decode.go
src/pkg/gob/decoder.go
src/pkg/gob/encoder.go
src/pkg/gob/encoder_test.go
src/pkg/gob/type.go

index a1491a4a10e7028ea970025f3ec1c90adec5e8a9..c5d070155c61378b95a8931eb6c6e0c447ccef24 100644 (file)
@@ -37,7 +37,6 @@ var encodeT = []EncodeT{
        EncodeT{1 << 63, []byte{0xF8, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
 }
 
-
 // Test basic encode/decode routines for unsigned integers
 func TestUintCodec(t *testing.T) {
        b := new(bytes.Buffer);
@@ -592,9 +591,15 @@ func TestEndToEnd(t *testing.T) {
                t: &T2{"this is T2"},
        };
        b := new(bytes.Buffer);
-       encode(b, t1);
+       err := NewEncoder(b).Encode(t1);
+       if err != nil {
+               t.Error("encode:", err)
+       }
        var _t1 T1;
-       decode(b, getTypeInfoNoError(reflect.Typeof(_t1)).id, &_t1);
+       err = NewDecoder(b).Decode(&_t1);
+       if err != nil {
+               t.Fatal("decode:", err)
+       }
        if !reflect.DeepEqual(t1, &_t1) {
                t.Errorf("encode expected %v got %v", *t1, _t1)
        }
@@ -610,8 +615,9 @@ func TestOverflow(t *testing.T) {
        }
        var it inputT;
        var err os.Error;
-       id := getTypeInfoNoError(reflect.Typeof(it)).id;
        b := new(bytes.Buffer);
+       enc := NewEncoder(b);
+       dec := NewDecoder(b);
 
        // int8
        b.Reset();
@@ -623,8 +629,8 @@ func TestOverflow(t *testing.T) {
                mini    int8;
        }
        var o1 outi8;
-       encode(b, it);
-       err = decode(b, id, &o1);
+       enc.Encode(it);
+       err = dec.Decode(&o1);
        if err == nil || err.String() != `value for "maxi" out of range` {
                t.Error("wrong overflow error for int8:", err)
        }
@@ -632,8 +638,8 @@ func TestOverflow(t *testing.T) {
                mini: math.MinInt8 - 1,
        };
        b.Reset();
-       encode(b, it);
-       err = decode(b, id, &o1);
+       enc.Encode(it);
+       err = dec.Decode(&o1);
        if err == nil || err.String() != `value for "mini" out of range` {
                t.Error("wrong underflow error for int8:", err)
        }
@@ -648,8 +654,8 @@ func TestOverflow(t *testing.T) {
                mini    int16;
        }
        var o2 outi16;
-       encode(b, it);
-       err = decode(b, id, &o2);
+       enc.Encode(it);
+       err = dec.Decode(&o2);
        if err == nil || err.String() != `value for "maxi" out of range` {
                t.Error("wrong overflow error for int16:", err)
        }
@@ -657,8 +663,8 @@ func TestOverflow(t *testing.T) {
                mini: math.MinInt16 - 1,
        };
        b.Reset();
-       encode(b, it);
-       err = decode(b, id, &o2);
+       enc.Encode(it);
+       err = dec.Decode(&o2);
        if err == nil || err.String() != `value for "mini" out of range` {
                t.Error("wrong underflow error for int16:", err)
        }
@@ -673,8 +679,8 @@ func TestOverflow(t *testing.T) {
                mini    int32;
        }
        var o3 outi32;
-       encode(b, it);
-       err = decode(b, id, &o3);
+       enc.Encode(it);
+       err = dec.Decode(&o3);
        if err == nil || err.String() != `value for "maxi" out of range` {
                t.Error("wrong overflow error for int32:", err)
        }
@@ -682,8 +688,8 @@ func TestOverflow(t *testing.T) {
                mini: math.MinInt32 - 1,
        };
        b.Reset();
-       encode(b, it);
-       err = decode(b, id, &o3);
+       enc.Encode(it);
+       err = dec.Decode(&o3);
        if err == nil || err.String() != `value for "mini" out of range` {
                t.Error("wrong underflow error for int32:", err)
        }
@@ -697,8 +703,8 @@ func TestOverflow(t *testing.T) {
                maxu uint8;
        }
        var o4 outu8;
-       encode(b, it);
-       err = decode(b, id, &o4);
+       enc.Encode(it);
+       err = dec.Decode(&o4);
        if err == nil || err.String() != `value for "maxu" out of range` {
                t.Error("wrong overflow error for uint8:", err)
        }
@@ -712,8 +718,8 @@ func TestOverflow(t *testing.T) {
                maxu uint16;
        }
        var o5 outu16;
-       encode(b, it);
-       err = decode(b, id, &o5);
+       enc.Encode(it);
+       err = dec.Decode(&o5);
        if err == nil || err.String() != `value for "maxu" out of range` {
                t.Error("wrong overflow error for uint16:", err)
        }
@@ -727,8 +733,8 @@ func TestOverflow(t *testing.T) {
                maxu uint32;
        }
        var o6 outu32;
-       encode(b, it);
-       err = decode(b, id, &o6);
+       enc.Encode(it);
+       err = dec.Decode(&o6);
        if err == nil || err.String() != `value for "maxu" out of range` {
                t.Error("wrong overflow error for uint32:", err)
        }
@@ -743,8 +749,8 @@ func TestOverflow(t *testing.T) {
                minf    float32;
        }
        var o7 outf32;
-       encode(b, it);
-       err = decode(b, id, &o7);
+       enc.Encode(it);
+       err = dec.Decode(&o7);
        if err == nil || err.String() != `value for "maxf" out of range` {
                t.Error("wrong overflow error for float32:", err)
        }
@@ -761,9 +767,13 @@ func TestNesting(t *testing.T) {
        rt.next = new(RT);
        rt.next.a = "level2";
        b := new(bytes.Buffer);
-       encode(b, rt);
+       NewEncoder(b).Encode(rt);
        var drt RT;
-       decode(b, getTypeInfoNoError(reflect.Typeof(drt)).id, &drt);
+       dec := NewDecoder(b);
+       err := dec.Decode(&drt);
+       if err != nil {
+               t.Errorf("decoder error:", err)
+       }
        if drt.a != rt.a {
                t.Errorf("nesting: encode expected %v got %v", *rt, drt)
        }
@@ -809,10 +819,11 @@ func TestAutoIndirection(t *testing.T) {
        **t1.d = new(int);
        ***t1.d = 17777;
        b := new(bytes.Buffer);
-       encode(b, t1);
+       enc := NewEncoder(b);
+       enc.Encode(t1);
+       dec := NewDecoder(b);
        var t0 T0;
-       t0Id := getTypeInfoNoError(reflect.Typeof(t0)).id;
-       decode(b, t0Id, &t0);
+       dec.Decode(&t0);
        if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 {
                t.Errorf("t1->t0: expected {17 177 1777 17777}; got %v", t0)
        }
@@ -830,9 +841,9 @@ func TestAutoIndirection(t *testing.T) {
        **t2.a = new(int);
        ***t2.a = 17;
        b.Reset();
-       encode(b, t2);
+       enc.Encode(t2);
        t0 = T0{};
-       decode(b, t0Id, &t0);
+       dec.Decode(&t0);
        if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 {
                t.Errorf("t2->t0 expected {17 177 1777 17777}; got %v", t0)
        }
@@ -840,32 +851,30 @@ func TestAutoIndirection(t *testing.T) {
        // Now transfer t0 into t1
        t0 = T0{17, 177, 1777, 17777};
        b.Reset();
-       encode(b, t0);
+       enc.Encode(t0);
        t1 = T1{};
-       t1Id := getTypeInfoNoError(reflect.Typeof(t1)).id;
-       decode(b, t1Id, &t1);
+       dec.Decode(&t1);
        if t1.a != 17 || *t1.b != 177 || **t1.c != 1777 || ***t1.d != 17777 {
                t.Errorf("t0->t1 expected {17 177 1777 17777}; got {%d %d %d %d}", t1.a, *t1.b, **t1.c, ***t1.d)
        }
 
        // Now transfer t0 into t2
        b.Reset();
-       encode(b, t0);
+       enc.Encode(t0);
        t2 = T2{};
-       t2Id := getTypeInfoNoError(reflect.Typeof(t2)).id;
-       decode(b, t2Id, &t2);
+       dec.Decode(&t2);
        if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 {
                t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d)
        }
 
        // Now do t2 again but without pre-allocated pointers.
        b.Reset();
-       encode(b, t0);
+       enc.Encode(t0);
        ***t2.a = 0;
        **t2.b = 0;
        *t2.c = 0;
        t2.d = 0;
-       decode(b, t2Id, &t2);
+       dec.Decode(&t2);
        if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 {
                t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d)
        }
@@ -889,11 +898,14 @@ func TestReorderedFields(t *testing.T) {
        rt0.b = "hello";
        rt0.c = 3.14159;
        b := new(bytes.Buffer);
-       encode(b, rt0);
-       rt0Id := getTypeInfoNoError(reflect.Typeof(rt0)).id;
+       NewEncoder(b).Encode(rt0);
+       dec := NewDecoder(b);
        var rt1 RT1;
        // Wire type is RT0, local type is RT1.
-       decode(b, rt0Id, &rt1);
+       err := dec.Decode(&rt1);
+       if err != nil {
+               t.Error("decode error:", err)
+       }
        if rt0.a != rt1.a || rt0.b != rt1.b || rt0.c != rt1.c {
                t.Errorf("rt1->rt0: expected %v; got %v", rt0, rt1)
        }
@@ -927,11 +939,11 @@ func TestIgnoredFields(t *testing.T) {
        it0.ignore_i = &RT1{3.1, "hi", 7, "hello"};
 
        b := new(bytes.Buffer);
-       encode(b, it0);
-       rt0Id := getTypeInfoNoError(reflect.Typeof(it0)).id;
+       NewEncoder(b).Encode(it0);
+       dec := NewDecoder(b);
        var rt1 RT1;
        // Wire type is IT0, local type is RT1.
-       err := decode(b, rt0Id, &rt1);
+       err := dec.Decode(&rt1);
        if err != nil {
                t.Error("error: ", err)
        }
index a9cdbe684d27f5ca30677323c3aaec3791cae2a7..b8400480cfe2f20d8e8c5eea9211e418f8dcc25e 100644 (file)
@@ -348,7 +348,7 @@ func ignoreUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) {
 // Execution engine
 
 // The encoder engine is an array of instructions indexed by field number of the incoming
-// data.  It is executed with random access according to field number.
+// decoder.  It is executed with random access according to field number.
 type decEngine struct {
        instr           []decInstr;
        numInstr        int;    // the number of active instructions
@@ -515,7 +515,7 @@ var decIgnoreOpMap = map[typeId]decOp{
 
 // Return the decoding op for the base type under rt and
 // the indirection count to reach it.
-func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error) {
+func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error) {
        typ, indir := indirect(rt);
        op, ok := decOpMap[reflect.Typeof(typ)];
        if !ok {
@@ -528,7 +528,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
                                break;
                        }
                        elemId := wireId.gobType().(*sliceType).Elem;
-                       elemOp, elemIndir, err := decOpFor(elemId, t.Elem(), name);
+                       elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name);
                        if err != nil {
                                return nil, 0, err
                        }
@@ -540,7 +540,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
                case *reflect.ArrayType:
                        name = "element of " + name;
                        elemId := wireId.gobType().(*arrayType).Elem;
-                       elemOp, elemIndir, err := decOpFor(elemId, t.Elem(), name);
+                       elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name);
                        if err != nil {
                                return nil, 0, err
                        }
@@ -551,7 +551,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
 
                case *reflect.StructType:
                        // Generate a closure that calls out to the engine for the nested type.
-                       enginePtr, err := getDecEnginePtr(wireId, typ);
+                       enginePtr, err := dec.getDecEnginePtr(wireId, typ);
                        if err != nil {
                                return nil, 0, err
                        }
@@ -568,14 +568,14 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
 }
 
 // Return the decoding op for a field that has no destination.
-func decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
+func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
        op, ok := decIgnoreOpMap[wireId];
        if !ok {
                // Special cases
                switch t := wireId.gobType().(type) {
                case *sliceType:
                        elemId := wireId.gobType().(*sliceType).Elem;
-                       elemOp, err := decIgnoreOpFor(elemId);
+                       elemOp, err := dec.decIgnoreOpFor(elemId);
                        if err != nil {
                                return nil, err
                        }
@@ -585,7 +585,7 @@ func decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
 
                case *arrayType:
                        elemId := wireId.gobType().(*arrayType).Elem;
-                       elemOp, err := decIgnoreOpFor(elemId);
+                       elemOp, err := dec.decIgnoreOpFor(elemId);
                        if err != nil {
                                return nil, err
                        }
@@ -595,7 +595,7 @@ func decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
 
                case *structType:
                        // Generate a closure that calls out to the engine for the nested type.
-                       enginePtr, err := getIgnoreEnginePtr(wireId);
+                       enginePtr, err := dec.getIgnoreEnginePtr(wireId);
                        if err != nil {
                                return nil, err
                        }
@@ -676,11 +676,18 @@ func compatibleType(fr reflect.Type, fw typeId) bool {
        return true;
 }
 
-func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
+func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
        srt, ok1 := rt.(*reflect.StructType);
-       wireStruct, ok2 := wireId.gobType().(*structType);
-       if !ok1 || !ok2 {
-               return nil, errNotStruct
+       var wireStruct *structType;
+       // Builtin types can come from global pool; the rest must be defined by the decoder
+       if t, ok := builtinIdToType[remoteId]; ok {
+               wireStruct = t.(*structType)
+       } else {
+               w, ok2 := dec.wireType[remoteId];
+               if !ok1 || !ok2 {
+                       return nil, errNotStruct
+               }
+               wireStruct = w.s;
        }
        engine = new(decEngine);
        engine.instr = make([]decInstr, len(wireStruct.field));
@@ -692,7 +699,7 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error
                ovfl := overflow(wireField.name);
                // TODO(r): anonymous names
                if !present {
-                       op, err := decIgnoreOpFor(wireField.id);
+                       op, err := dec.decIgnoreOpFor(wireField.id);
                        if err != nil {
                                return nil, err
                        }
@@ -700,10 +707,10 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error
                        continue;
                }
                if !compatibleType(localField.Type, wireField.id) {
-                       details := " (" + wireField.id.String() + " incompatible with " + localField.Type.String() + ") in type " + wireId.Name();
+                       details := " (" + wireField.id.String() + " incompatible with " + localField.Type.String() + ") in type " + remoteId.Name();
                        return nil, os.ErrorString("gob: wrong type for field " + wireField.name + details);
                }
-               op, indir, err := decOpFor(wireField.id, localField.Type, localField.Name);
+               op, indir, err := dec.decOpFor(wireField.id, localField.Type, localField.Name);
                if err != nil {
                        return nil, err
                }
@@ -713,23 +720,19 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error
        return;
 }
 
-var decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
-var ignorerCache = make(map[typeId]**decEngine)
-
-// typeLock must be held.
-func getDecEnginePtr(wireId typeId, rt reflect.Type) (enginePtr **decEngine, err os.Error) {
-       decoderMap, ok := decoderCache[rt];
+func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr **decEngine, err os.Error) {
+       decoderMap, ok := dec.decoderCache[rt];
        if !ok {
                decoderMap = make(map[typeId]**decEngine);
-               decoderCache[rt] = decoderMap;
+               dec.decoderCache[rt] = decoderMap;
        }
-       if enginePtr, ok = decoderMap[wireId]; !ok {
+       if enginePtr, ok = decoderMap[remoteId]; !ok {
                // To handle recursive types, mark this engine as underway before compiling.
                enginePtr = new(*decEngine);
-               decoderMap[wireId] = enginePtr;
-               *enginePtr, err = compileDec(wireId, rt);
+               decoderMap[remoteId] = enginePtr;
+               *enginePtr, err = dec.compileDec(remoteId, rt);
                if err != nil {
-                       decoderMap[wireId] = nil, false
+                       decoderMap[remoteId] = nil, false
                }
        }
        return;
@@ -740,35 +743,28 @@ type emptyStruct struct{}
 
 var emptyStructType = reflect.Typeof(emptyStruct{})
 
-// typeLock must be held.
-func getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) {
+func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) {
        var ok bool;
-       if enginePtr, ok = ignorerCache[wireId]; !ok {
+       if enginePtr, ok = dec.ignorerCache[wireId]; !ok {
                // To handle recursive types, mark this engine as underway before compiling.
                enginePtr = new(*decEngine);
-               ignorerCache[wireId] = enginePtr;
-               *enginePtr, err = compileDec(wireId, emptyStructType);
+               dec.ignorerCache[wireId] = enginePtr;
+               *enginePtr, err = dec.compileDec(wireId, emptyStructType);
                if err != nil {
-                       ignorerCache[wireId] = nil, false
+                       dec.ignorerCache[wireId] = nil, false
                }
        }
        return;
 }
 
-func decode(b *bytes.Buffer, wireId typeId, e interface{}) os.Error {
+func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error {
        // Dereference down to the underlying struct type.
        rt, indir := indirect(reflect.Typeof(e));
        st, ok := rt.(*reflect.StructType);
        if !ok {
                return os.ErrorString("gob: decode can't handle " + rt.String())
        }
-       typeLock.Lock();
-       if _, ok := idToType[wireId]; !ok {
-               typeLock.Unlock();
-               return errBadType;
-       }
-       enginePtr, err := getDecEnginePtr(wireId, rt);
-       typeLock.Unlock();
+       enginePtr, err := dec.getDecEnginePtr(wireId, rt);
        if err != nil {
                return err
        }
@@ -777,7 +773,7 @@ func decode(b *bytes.Buffer, wireId typeId, e interface{}) os.Error {
                name := rt.Name();
                return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name);
        }
-       return decodeStruct(engine, st, b, uintptr(reflect.NewValue(e).Addr()), indir);
+       return decodeStruct(engine, st, dec.state.b, uintptr(reflect.NewValue(e).Addr()), indir);
 }
 
 func init() {
index dde5d823f983e1bba9c3df347fd5439821f59162..1713a3e59fa8ddd7dfdfd6b4f18e07728986c1a5 100644 (file)
@@ -8,17 +8,20 @@ import (
        "bytes";
        "io";
        "os";
+       "reflect";
        "sync";
 )
 
 // A Decoder manages the receipt of type and data information read from the
 // remote side of a connection.
 type Decoder struct {
-       mutex           sync.Mutex;             // each item must be received atomically
-       r               io.Reader;              // source of the data
-       seen            map[typeId]*wireType;   // which types we've already seen described
-       state           *decodeState;           // reads data from in-memory buffer
-       countState      *decodeState;           // reads counts from wire
+       mutex           sync.Mutex;                                     // each item must be received atomically
+       r               io.Reader;                                      // source of the data
+       wireType        map[typeId]*wireType;                           // map from remote ID to local description
+       decoderCache    map[reflect.Type]map[typeId]**decEngine;        // cache of compiled engines
+       ignorerCache    map[typeId]**decEngine;                         // ditto for ignored objects
+       state           *decodeState;                                   // reads data from in-memory buffer
+       countState      *decodeState;                                   // reads counts from wire
        buf             []byte;
        oneByte         []byte;
 }
@@ -27,8 +30,10 @@ type Decoder struct {
 func NewDecoder(r io.Reader) *Decoder {
        dec := new(Decoder);
        dec.r = r;
-       dec.seen = make(map[typeId]*wireType);
+       dec.wireType = make(map[typeId]*wireType);
        dec.state = newDecodeState(nil);        // buffer set in Decode(); rest is unimportant
+       dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine);
+       dec.ignorerCache = make(map[typeId]**decEngine);
        dec.oneByte = make([]byte, 1);
 
        return dec;
@@ -36,16 +41,16 @@ func NewDecoder(r io.Reader) *Decoder {
 
 func (dec *Decoder) recvType(id typeId) {
        // Have we already seen this type?  That's an error
-       if _, alreadySeen := dec.seen[id]; alreadySeen {
+       if _, alreadySeen := dec.wireType[id]; alreadySeen {
                dec.state.err = os.ErrorString("gob: duplicate type received");
                return;
        }
 
        // Type:
        wire := new(wireType);
-       decode(dec.state.b, tWireType, wire);
+       dec.state.err = dec.decode(tWireType, wire);
        // Remember we've seen this type.
-       dec.seen[id] = wire;
+       dec.wireType[id] = wire;
 }
 
 // Decode reads the next value from the connection and stores
@@ -97,7 +102,13 @@ func (dec *Decoder) Decode(e interface{}) os.Error {
                }
 
                // No, it's a value.
-               dec.state.err = decode(dec.state.b, id, e);
+               // Make sure the type has been defined already.
+               _, ok := dec.wireType[id];
+               if !ok {
+                       dec.state.err = errBadType;
+                       break;
+               }
+               dec.state.err = dec.decode(id, e);
                break;
        }
        return dec.state.err;
index a59309e3355fc730b7c24a702eef4288b5040851..28ecaec93f73e463456fd4cb2287ca6840b643d2 100644 (file)
@@ -235,19 +235,15 @@ func (enc *Encoder) send() {
        enc.w.Write(enc.buf[0:total]);
 }
 
-func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) {
+func (enc *Encoder) sendType(origt reflect.Type) {
        // Drill down to the base type.
        rt, _ := indirect(origt);
 
        // We only send structs - everything else is basic or an error
-       switch rt.(type) {
+       switch rt := rt.(type) {
        default:
-               // Basic types do not need to be described, but if this is a top-level
-               // type, it's a user error, at least for now.
-               if topLevel {
-                       enc.badType(rt)
-               }
-               return;
+               // Basic types do not need to be described.
+               return
        case *reflect.StructType:
                // Structs do need to be described.
                break
@@ -255,10 +251,9 @@ func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) {
                // Probably a bad field in a struct.
                enc.badType(rt);
                return;
-       case *reflect.ArrayType, *reflect.SliceType:
-               // Array and slice types are not sent, only their element types.
-               // If we see one here it's user error; probably a bad top-level value.
-               enc.badType(rt);
+       // Array and slice types are not sent, only their element types.
+       case reflect.ArrayOrSliceType:
+               enc.sendType(rt.Elem());
                return;
        }
 
@@ -289,7 +284,7 @@ func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) {
        // Now send the inner types
        st := rt.(*reflect.StructType);
        for i := 0; i < st.NumField(); i++ {
-               enc.sendType(st.Field(i).Type, false)
+               enc.sendType(st.Field(i).Type)
        }
        return;
 }
@@ -301,6 +296,12 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
                panicln("Encoder: buffer not empty")
        }
        rt, _ := indirect(reflect.Typeof(e));
+       // Must be a struct
+       if _, ok := rt.(*reflect.StructType); !ok {
+               enc.badType(rt);
+               return enc.state.err;
+       }
+
 
        // Make sure we're single-threaded through here.
        enc.mutex.Lock();
@@ -310,7 +311,7 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
        // First, have we already sent this type?
        if _, alreadySent := enc.sent[rt]; !alreadySent {
                // No, so send it.
-               enc.sendType(rt, true);
+               enc.sendType(rt);
                if enc.state.err != nil {
                        enc.state.b.Reset();
                        enc.countState.b.Reset();
index e850bceae53adf97144b00e0b52465678e062b65..43d3e72ed95cd1601843315ccdc6603b1f0604c5 100644 (file)
@@ -32,122 +32,10 @@ type ET3 struct {
 // Like ET1 but with a different type for a field
 type ET4 struct {
        a       int;
-       et2     *ET1;
+       et2     float;
        next    int;
 }
 
-func TestBasicEncoder(t *testing.T) {
-       b := new(bytes.Buffer);
-       enc := NewEncoder(b);
-       et1 := new(ET1);
-       et1.a = 7;
-       et1.et2 = new(ET2);
-       enc.Encode(et1);
-       if enc.state.err != nil {
-               t.Error("encoder fail:", enc.state.err)
-       }
-
-       // Decode the result by hand to verify;
-       state := newDecodeState(b);
-       // The output should be:
-       // 0) The length, 38.
-       length := decodeUint(state);
-       if length != 38 {
-               t.Fatal("0. expected length 38; got", length)
-       }
-       // 1) -7: the type id of ET1
-       id1 := decodeInt(state);
-       if id1 >= 0 {
-               t.Fatal("expected ET1 negative id; got", id1)
-       }
-       // 2) The wireType for ET1
-       wire1 := new(wireType);
-       err := decode(b, tWireType, wire1);
-       if err != nil {
-               t.Fatal("error decoding ET1 type:", err)
-       }
-       info := getTypeInfoNoError(reflect.Typeof(ET1{}));
-       trueWire1 := &wireType{s: info.id.gobType().(*structType)};
-       if !reflect.DeepEqual(wire1, trueWire1) {
-               t.Fatalf("invalid wireType for ET1: expected %+v; got %+v\n", *trueWire1, *wire1)
-       }
-       // 3) The length, 21.
-       length = decodeUint(state);
-       if length != 21 {
-               t.Fatal("3. expected length 21; got", length)
-       }
-       // 4) -8: the type id of ET2
-       id2 := decodeInt(state);
-       if id2 >= 0 {
-               t.Fatal("expected ET2 negative id; got", id2)
-       }
-       // 5) The wireType for ET2
-       wire2 := new(wireType);
-       err = decode(b, tWireType, wire2);
-       if err != nil {
-               t.Fatal("error decoding ET2 type:", err)
-       }
-       info = getTypeInfoNoError(reflect.Typeof(ET2{}));
-       trueWire2 := &wireType{s: info.id.gobType().(*structType)};
-       if !reflect.DeepEqual(wire2, trueWire2) {
-               t.Fatalf("invalid wireType for ET2: expected %+v; got %+v\n", *trueWire2, *wire2)
-       }
-       // 6) The length, 6.
-       length = decodeUint(state);
-       if length != 6 {
-               t.Fatal("6. expected length 6; got", length)
-       }
-       // 7) The type id for the et1 value
-       newId1 := decodeInt(state);
-       if newId1 != -id1 {
-               t.Fatal("expected Et1 id", -id1, "got", newId1)
-       }
-       // 8) The value of et1
-       newEt1 := new(ET1);
-       et1Id := getTypeInfoNoError(reflect.Typeof(*newEt1)).id;
-       err = decode(b, et1Id, newEt1);
-       if err != nil {
-               t.Fatal("error decoding ET1 value:", err)
-       }
-       if !reflect.DeepEqual(et1, newEt1) {
-               t.Fatalf("invalid data for et1: expected %+v; got %+v\n", *et1, *newEt1)
-       }
-       // 9) EOF
-       if b.Len() != 0 {
-               t.Error("not at eof;", b.Len(), "bytes left")
-       }
-
-       // Now do it again. This time we should see only the type id and value.
-       b.Reset();
-       enc.Encode(et1);
-       if enc.state.err != nil {
-               t.Error("2nd round: encoder fail:", enc.state.err)
-       }
-       // The length.
-       length = decodeUint(state);
-       if length != 6 {
-               t.Fatal("6. expected length 6; got", length)
-       }
-       // 5a) The type id for the et1 value
-       newId1 = decodeInt(state);
-       if newId1 != -id1 {
-               t.Fatal("2nd round: expected Et1 id", -id1, "got", newId1)
-       }
-       // 6a) The value of et1
-       newEt1 = new(ET1);
-       err = decode(b, et1Id, newEt1);
-       if err != nil {
-               t.Fatal("2nd round: error decoding ET1 value:", err)
-       }
-       if !reflect.DeepEqual(et1, newEt1) {
-               t.Fatalf("2nd round: invalid data for et1: expected %+v; got %+v\n", *et1, *newEt1)
-       }
-       // 7a) EOF
-       if b.Len() != 0 {
-               t.Error("2nd round: not at eof;", b.Len(), "bytes left")
-       }
-}
-
 func TestEncoderDecoder(t *testing.T) {
        b := new(bytes.Buffer);
        enc := NewEncoder(b);
@@ -215,7 +103,7 @@ func badTypeCheck(e interface{}, shouldFail bool, msg string, t *testing.T) {
                t.Error("expected error for", msg)
        }
        if !shouldFail && (dec.state.err != nil) {
-               t.Error("unexpected error for", msg)
+               t.Error("unexpected error for", msg, dec.state.err)
        }
 }
 
index 92db7cef382bd700c65f90f1fb408b4be9a33b94..ffff0541df2dcf9d818777405c2a3a05df0971d6 100644 (file)
@@ -48,6 +48,7 @@ type gobType interface {
 
 var types = make(map[reflect.Type]gobType)
 var idToType = make(map[typeId]gobType)
+var builtinIdToType map[typeId]gobType // set in init() after builtins are established
 
 func setTypeId(typ gobType) {
        nextId++;
@@ -104,6 +105,10 @@ func init() {
        checkId(8, getTypeInfoNoError(reflect.Typeof(structType{})).id);
        checkId(9, getTypeInfoNoError(reflect.Typeof(commonType{})).id);
        checkId(10, getTypeInfoNoError(reflect.Typeof(fieldType{})).id);
+       builtinIdToType = make(map[typeId]gobType);
+       for k, v := range idToType {
+               builtinIdToType[k] = v
+       }
 }
 
 // Array type