]> Cypherpunks repositories - gostls13.git/commitdiff
- allow wire type and receive type to differ.
authorRob Pike <r@golang.org>
Fri, 17 Jul 2009 00:55:16 +0000 (17:55 -0700)
committerRob Pike <r@golang.org>
Fri, 17 Jul 2009 00:55:16 +0000 (17:55 -0700)
- still TODO: ignoring struct fields.

R=rsc
DELTA=309  (240 added, 2 deleted, 67 changed)
OCL=31750
CL=31750

src/pkg/gob/codec_test.go
src/pkg/gob/decode.go
src/pkg/gob/decoder.go
src/pkg/gob/encoder.go
src/pkg/gob/encoder_test.go
src/pkg/gob/type.go

index 44ebf3a113ad50bc880ed31f0c68e0e043492fe3..294506589d40462714900d7da9124031d7f00676 100644 (file)
@@ -523,7 +523,6 @@ func TestScalarDecInstructions(t *testing.T) {
        }
 }
 
-
 func TestEndToEnd(t *testing.T) {
        type T2 struct {
                t string
@@ -553,7 +552,7 @@ func TestEndToEnd(t *testing.T) {
        b := new(bytes.Buffer);
        encode(b, t1);
        var _t1 T1;
-       decode(b, &_t1);
+       decode(b, getTypeInfo(reflect.Typeof(_t1)).typeId, &_t1);
        if !reflect.DeepEqual(t1, &_t1) {
                t.Errorf("encode expected %v got %v", *t1, _t1);
        }
@@ -571,7 +570,7 @@ func TestNesting(t *testing.T) {
        b := new(bytes.Buffer);
        encode(b, rt);
        var drt RT;
-       decode(b, &drt);
+       decode(b, getTypeInfo(reflect.Typeof(drt)).typeId, &drt);
        if drt.a != rt.a {
                t.Errorf("nesting: encode expected %v got %v", *rt, drt);
        }
@@ -613,7 +612,8 @@ func TestAutoIndirection(t *testing.T) {
        b := new(bytes.Buffer);
        encode(b, t1);
        var t0 T0;
-       decode(b, &t0);
+       t0Id := getTypeInfo(reflect.Typeof(t0)).typeId;
+       decode(b, t0Id, &t0);
        if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 {
                t.Errorf("t1->t0: expected {17 177 1777 17777}; got %v", t0);
        }
@@ -627,7 +627,7 @@ func TestAutoIndirection(t *testing.T) {
        b.Reset();
        encode(b, t2);
        t0 = T0{};
-       decode(b, &t0);
+       decode(b, t0Id, &t0);
        if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 {
                t.Errorf("t2->t0 expected {17 177 1777 17777}; got %v", t0);
        }
@@ -637,7 +637,8 @@ func TestAutoIndirection(t *testing.T) {
        b.Reset();
        encode(b, t0);
        t1 = T1{};
-       decode(b, &t1);
+       t1Id := getTypeInfo(reflect.Typeof(t1)).typeId;
+       decode(b, t1Id, &t1);
        if t1.a != 17 || *t1.b != 177 || **t1.c != 1777 || ***t1.d != 17777 {
                t.Errorf("t0->t1 expected {17 177 1777 17777}; got {%d %d %d %d}", t1.a, *t1.b, **t1.c, ***t1.d);
        }
@@ -646,7 +647,8 @@ func TestAutoIndirection(t *testing.T) {
        b.Reset();
        encode(b, t0);
        t2 = T2{};
-       decode(b, &t2);
+       t2Id := getTypeInfo(reflect.Typeof(t2)).typeId;
+       decode(b, t2Id, &t2);
        if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 {
                t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d);
        }
@@ -658,8 +660,72 @@ func TestAutoIndirection(t *testing.T) {
        **t2.b = 0;
        *t2.c = 0;
        t2.d = 0;
-       decode(b, &t2);
+       decode(b, t2Id, &t2);
        if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 {
                t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d);
        }
 }
+
+type RT0 struct {
+       a int;
+       b string;
+       c float;
+}
+type RT1 struct {
+       c float;
+       b string;
+       a int;
+       notSet string;
+}
+
+func TestReorderedFields(t *testing.T) {
+       var rt0 RT0;
+       rt0.a = 17;
+       rt0.b = "hello";
+       rt0.c = 3.14159;
+       b := new(bytes.Buffer);
+       encode(b, rt0);
+       rt0Id := getTypeInfo(reflect.Typeof(rt0)).typeId;
+       var rt1 RT1;
+       // Wire type is RT0, local type is RT1.
+       decode(b, rt0Id, &rt1);
+       if rt0.a != rt1.a || rt0.b != rt1.b || rt0.c != rt1.c {
+               t.Errorf("rt1->rt0: expected %v; got %v", rt0, rt1);
+       }
+}
+
+// Like an RT0 but with fields we'll ignore on the decode side.
+type IT0 struct {
+       a int64;
+       b string;
+       ignore_d []int;
+       ignore_e [3]float;
+       ignore_f bool;
+       ignore_g string;
+       ignore_h []byte;
+       c float;
+}
+
+func TestIgnoredFields(t *testing.T) {
+       var it0 IT0;
+       it0.a = 17;
+       it0.b = "hello";
+       it0.c = 3.14159;
+       it0.ignore_d = []int{ 1, 2, 3 };
+       it0.ignore_e[0]  = 1.0;
+       it0.ignore_e[1]  = 2.0;
+       it0.ignore_e[2]  = 3.0;
+       it0.ignore_f = true;
+       it0.ignore_g = "pay no attention";
+       it0.ignore_h = strings.Bytes("to the curtain");
+
+       b := new(bytes.Buffer);
+       encode(b, it0);
+       rt0Id := getTypeInfo(reflect.Typeof(it0)).typeId;
+       var rt1 RT1;
+       // Wire type is IT0, local type is RT1.
+       decode(b, rt0Id, &rt1);
+       if int(it0.a) != rt1.a || it0.b != rt1.b || it0.c != rt1.c {
+               t.Errorf("rt1->rt0: expected %v; got %v", it0, rt1);
+       }
+}
index ec2ed66a5ac5ee7d94509b772307f72c74c69572..2fbcf68b3e4a4192d03d6a6e54b897ccc6a14d1b 100644 (file)
@@ -83,7 +83,7 @@ type decOp func(i *decInstr, state *decodeState, p unsafe.Pointer);
 // The 'instructions' of the decoding machine
 type decInstr struct {
        op      decOp;
-       field           int;    // field number
+       field           int;    // field number of the wire type
        indir   int;    // how many pointer indirections to reach the value in the struct
        offset  uintptr;        // offset in the structure of the field to encode
 }
@@ -107,6 +107,10 @@ func decIndirect(p unsafe.Pointer, indir int) unsafe.Pointer {
        return p
 }
 
+func ignoreUint(i *decInstr, state *decodeState, p unsafe.Pointer) {
+       decodeUint(state);
+}
+
 func decBool(i *decInstr, state *decodeState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
@@ -298,12 +302,18 @@ func decString(i *decInstr, state *decodeState, p unsafe.Pointer) {
        *(*string)(p) = string(b);
 }
 
+func ignoreUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) {
+       b := make([]byte, decodeUint(state));
+       state.b.Read(b);
+}
+
 // Execution engine
 
 // The encoder engine is an array of instructions indexed by field number of the incoming
 // data.  It is executed with random access according to field number.
 type decEngine struct {
-       instr   []decInstr
+       instr   []decInstr;
+       numInstr        int;    // the number of active instructions
 }
 
 func decodeStruct(engine *decEngine, rtyp *reflect.StructType, b *bytes.Buffer, p uintptr, indir int) os.Error {
@@ -332,7 +342,7 @@ func decodeStruct(engine *decEngine, rtyp *reflect.StructType, b *bytes.Buffer,
                }
                fieldnum := state.fieldnum + delta;
                if fieldnum >= len(engine.instr) {
-                       panicln("TODO(r): need to handle unknown data");
+                       panicln("TODO(r): field number out of range", fieldnum, len(engine.instr));
                }
                instr := &engine.instr[fieldnum];
                p := unsafe.Pointer(basep+instr.offset);
@@ -375,6 +385,21 @@ func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp
        return decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir);
 }
 
+func ignoreArrayHelper(state *decodeState, elemOp decOp, length int) os.Error {
+       instr := &decInstr{elemOp, 0, 0, 0};
+       for i := 0; i < length && state.err == nil; i++ {
+               elemOp(instr, state, nil);
+       }
+       return state.err
+}
+
+func ignoreArray(state *decodeState, elemOp decOp, length int) os.Error {
+       if n := decodeUint(state); n != uint64(length) {
+               return os.ErrorString("gob: length mismatch in ignoreArray");
+       }
+       return ignoreArrayHelper(state, elemOp, length);
+}
+
 func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int) os.Error {
        length := uintptr(decodeUint(state));
        if indir > 0 {
@@ -395,30 +420,43 @@ func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp
        return decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, int(length), elemIndir);
 }
 
+func ignoreSlice(state *decodeState, elemOp decOp) os.Error {
+       return ignoreArrayHelper(state, elemOp, int(decodeUint(state)));
+}
+
 var decOpMap = map[reflect.Type] decOp {
-        reflect.Typeof((*reflect.BoolType)(nil)): decBool,
-        reflect.Typeof((*reflect.IntType)(nil)): decInt,
-        reflect.Typeof((*reflect.Int8Type)(nil)): decInt8,
-        reflect.Typeof((*reflect.Int16Type)(nil)): decInt16,
-        reflect.Typeof((*reflect.Int32Type)(nil)): decInt32,
-        reflect.Typeof((*reflect.Int64Type)(nil)): decInt64,
-        reflect.Typeof((*reflect.UintType)(nil)): decUint,
-        reflect.Typeof((*reflect.Uint8Type)(nil)): decUint8,
-        reflect.Typeof((*reflect.Uint16Type)(nil)): decUint16,
-        reflect.Typeof((*reflect.Uint32Type)(nil)): decUint32,
-        reflect.Typeof((*reflect.Uint64Type)(nil)): decUint64,
-        reflect.Typeof((*reflect.UintptrType)(nil)): decUintptr,
-        reflect.Typeof((*reflect.FloatType)(nil)): decFloat,
-        reflect.Typeof((*reflect.Float32Type)(nil)): decFloat32,
-        reflect.Typeof((*reflect.Float64Type)(nil)): decFloat64,
-        reflect.Typeof((*reflect.StringType)(nil)): decString,
-}
-
-func getDecEngine(rt reflect.Type) *decEngine
+       reflect.Typeof((*reflect.BoolType)(nil)): decBool,
+       reflect.Typeof((*reflect.IntType)(nil)): decInt,
+       reflect.Typeof((*reflect.Int8Type)(nil)): decInt8,
+       reflect.Typeof((*reflect.Int16Type)(nil)): decInt16,
+       reflect.Typeof((*reflect.Int32Type)(nil)): decInt32,
+       reflect.Typeof((*reflect.Int64Type)(nil)): decInt64,
+       reflect.Typeof((*reflect.UintType)(nil)): decUint,
+       reflect.Typeof((*reflect.Uint8Type)(nil)): decUint8,
+       reflect.Typeof((*reflect.Uint16Type)(nil)): decUint16,
+       reflect.Typeof((*reflect.Uint32Type)(nil)): decUint32,
+       reflect.Typeof((*reflect.Uint64Type)(nil)): decUint64,
+       reflect.Typeof((*reflect.UintptrType)(nil)): decUintptr,
+       reflect.Typeof((*reflect.FloatType)(nil)): decFloat,
+       reflect.Typeof((*reflect.Float32Type)(nil)): decFloat32,
+       reflect.Typeof((*reflect.Float64Type)(nil)): decFloat64,
+       reflect.Typeof((*reflect.StringType)(nil)): decString,
+}
+
+var decIgnoreOpMap = map[TypeId] decOp {
+       tBool: ignoreUint,
+       tInt: ignoreUint,
+       tUint: ignoreUint,
+       tFloat: ignoreUint,
+       tBytes: ignoreUint8Array,
+       tString: ignoreUint8Array,
+}
+
+func getDecEnginePtr(wireId TypeId, rt reflect.Type) **decEngine
 
 // Return the decoding op for the base type under rt and
 // the indirection count to reach it.
-func decOpFor(rt reflect.Type) (decOp, int) {
+func decOpFor(wireId TypeId, rt reflect.Type) (decOp, int) {
        typ, indir := indirect(rt);
        op, ok := decOpMap[reflect.Typeof(typ)];
        if !ok {
@@ -429,24 +467,25 @@ func decOpFor(rt reflect.Type) (decOp, int) {
                                op = decUint8Array;
                                break;
                        }
-                       elemOp, elemIndir := decOpFor(t.Elem());
+                       elemId := wireId.gobType().(*sliceType).Elem;
+                       elemOp, elemIndir := decOpFor(elemId, t.Elem());
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
                                state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir);
                        };
 
                case *reflect.ArrayType:
-                       elemOp, elemIndir := decOpFor(t.Elem());
+                       elemId := wireId.gobType().(*arrayType).Elem;
+                       elemOp, elemIndir := decOpFor(elemId, t.Elem());
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
                                state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir);
                        };
 
                case *reflect.StructType:
                        // Generate a closure that calls out to the engine for the nested type.
-                       engine := getDecEngine(typ);
-                       info := getTypeInfo(typ);
+                       enginePtr := getDecEnginePtr(wireId, typ);
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
                                // indirect through info to delay evaluation for recursive structs
-                               state.err = decodeStruct(info.decoder, t, state.b, uintptr(p), i.indir)
+                               state.err = decodeStruct(*enginePtr, t, state.b, uintptr(p), i.indir)
                        };
                }
        }
@@ -456,53 +495,173 @@ func decOpFor(rt reflect.Type) (decOp, int) {
        return op, indir
 }
 
-func compileDec(rt reflect.Type, typ gobType) *decEngine {
+// Return the decoding op for a field that has no destination.
+func decIgnoreOpFor(wireId TypeId) decOp {
+       op, ok := decIgnoreOpMap[wireId];
+       if !ok {
+               // Special cases
+               switch t := wireId.gobType().(type) {
+               case *sliceType:
+                       elemId := wireId.gobType().(*sliceType).Elem;
+                       elemOp := decIgnoreOpFor(elemId);
+                       op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+                               state.err = ignoreSlice(state, elemOp);
+                       };
+
+               case *arrayType:
+                       elemId := wireId.gobType().(*arrayType).Elem;
+                       elemOp := decIgnoreOpFor(elemId);
+                       op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+                               state.err = ignoreArray(state, elemOp, t.Len);
+                       };
+
+               case *structType:
+                       // TODO: write an ignore engine for structs
+               }
+       }
+       if op == nil {
+               panicln("decode can't handle type", wireId.gobType().String());
+       }
+       return op;
+}
+
+// Are these two gob Types compatible?
+// Answers the question for basic types, arrays, and slices.  Defers for structs.
+func compatibleType(fr reflect.Type, fw TypeId) bool {
+       for {
+               if pt, ok := fr.(*reflect.PtrType); ok {
+                       fr = pt.Elem();
+                       continue;
+               }
+               break;
+       }
+       switch t := fr.(type) {
+       default:
+               // interface, map, chan, etc: cannot handle.
+               return false;
+       case *reflect.BoolType:
+               return fw == tBool;
+       case *reflect.IntType:
+               return fw == tInt;
+       case *reflect.Int8Type:
+               return fw == tInt;
+       case *reflect.Int16Type:
+               return fw == tInt;
+       case *reflect.Int32Type:
+               return fw == tInt;
+       case *reflect.Int64Type:
+               return fw == tInt;
+       case *reflect.UintType:
+               return fw == tUint;
+       case *reflect.Uint8Type:
+               return fw == tUint;
+       case *reflect.Uint16Type:
+               return fw == tUint;
+       case *reflect.Uint32Type:
+               return fw == tUint;
+       case *reflect.Uint64Type:
+               return fw == tUint;
+       case *reflect.UintptrType:
+               return fw == tUint;
+       case *reflect.FloatType:
+               return fw == tFloat;
+       case *reflect.Float32Type:
+               return fw == tFloat;
+       case *reflect.Float64Type:
+               return fw == tFloat;
+       case *reflect.StringType:
+               return fw == tString;
+       case *reflect.StructType:
+               return true;    // defer for now
+       case *reflect.ArrayType:
+               aw, ok := fw.gobType().(*arrayType);
+               return ok && t.Len() == aw.Len && compatibleType(t.Elem(), aw.Elem);
+       case *reflect.SliceType:
+               // Is it an array of bytes?
+               et := t.Elem();
+               if _, ok := et.(*reflect.Uint8Type); ok {
+                       return fw == tBytes
+               }
+               sw, ok := fw.gobType().(*sliceType);
+               return ok && compatibleType(t.Elem(), sw.Elem);
+       }
+       return true;
+}
+
+func compileDec(wireId TypeId, rt reflect.Type) *decEngine {
        srt, ok1 := rt.(*reflect.StructType);
-       styp, ok2 := typ.(*structType);
+       wireStruct, ok2 := wireId.gobType().(*structType);
        if !ok1 || !ok2 {
-               panicln("TODO: can't handle non-structs");
+               panicln("gob: TODO: can't handle non-structs");
        }
        engine := new(decEngine);
-       engine.instr = make([]decInstr, len(styp.field));
-       for fieldnum := 0; fieldnum < len(styp.field); fieldnum++ {
-               field := styp.field[fieldnum];
-               // Assumes perfect correspondence between struct and gob,
-               // which is safe to assume since typ was compiled from rt.
-               f := srt.Field(fieldnum);
-               op, indir := decOpFor(f.Type);
-               engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(f.Offset)};
+       engine.instr = make([]decInstr, len(wireStruct.field));
+       // Loop over the fields of the wire type.
+       for fieldnum := 0; fieldnum < len(wireStruct.field); fieldnum++ {
+               wireField := wireStruct.field[fieldnum];
+               // Find the field of the local type with the same name.
+               // TODO: put this as a method in reflect
+               var localField reflect.StructField;
+               for lfn := 0; lfn < srt.NumField(); lfn++ {
+                       if srt.Field(lfn).Name == wireField.name {
+                               localField = srt.Field(lfn);
+                               break;
+                       }
+               }
+               // TODO(r): anonymous names
+               if localField.Anonymous || localField.Name == "" {
+                       println("no matching field", wireField.name, "in type", wireId.String());
+                       op := decIgnoreOpFor(wireField.typeId);
+                       engine.instr[fieldnum] = decInstr{op, fieldnum, 0, 0};
+                       continue;
+               }
+               if !compatibleType(localField.Type, wireField.typeId) {
+                       panicln("TODO: wrong type for field", wireField.name, "in type", wireId.String());
+               }
+               op, indir := decOpFor(wireField.typeId, localField.Type);
+               engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(localField.Offset)};
+               engine.numInstr++;
        }
        return engine;
 }
 
 
 // typeLock must be held.
-func getDecEngine(rt reflect.Type) *decEngine {
-       info := getTypeInfo(rt);
-       if info.decoder == nil {
+func getDecEnginePtr(wireId TypeId, rt reflect.Type) **decEngine {
+       info := getTypeInfo(rt);        // TODO: eliminate this; creates a gobType you don't need.
+       var enginePtr **decEngine;
+       var ok bool;
+       if enginePtr, ok = info.decoderPtr[wireId]; !ok {
                if info.typeId.gobType() == nil {
                        _pkg, name := rt.Name();
                        info.typeId = newType(name, rt).id();
                }
                // mark this engine as underway before compiling to handle recursive types.
-               info.decoder = new(decEngine);
-               info.decoder = compileDec(rt, info.typeId.gobType());
+               enginePtr = new(*decEngine);
+               info.decoderPtr[wireId] = enginePtr;
+               *enginePtr = compileDec(wireId, rt);
        }
-       return info.decoder;
+       return enginePtr
 }
 
-func decode(b *bytes.Buffer, e interface{}) os.Error {
+func decode(b *bytes.Buffer, wireId TypeId, e interface{}) os.Error {
        // Dereference down to the underlying object.
        rt, indir := indirect(reflect.Typeof(e));
        v := reflect.NewValue(e);
        for i := 0; i < indir; i++ {
                v = reflect.Indirect(v);
        }
-       if _, ok := v.(*reflect.StructValue); !ok {
+       var st *reflect.StructValue;
+       var ok bool;
+       if st, ok = v.(*reflect.StructValue); !ok {
                return os.ErrorString("gob: decode can't handle " + rt.String())
        }
        typeLock.Lock();
-       engine := getDecEngine(rt);
+       engine := *getDecEnginePtr(wireId, rt);
        typeLock.Unlock();
+       if engine.numInstr == 0 && st.NumField() > 0 {
+               path, name := rt.Name();
+               return os.ErrorString("no fields matched compiling decoder for " + name)
+       }
        return decodeStruct(engine, rt.(*reflect.StructType), b, uintptr(v.Addr()), 0);
 }
index 9257f7c23a28c18f18575378e3ed41bd61e28ad1..9c5a755542810020374434aed62d5f8791157974 100644 (file)
@@ -42,7 +42,7 @@ func (dec *Decoder) recvType(id TypeId) {
 
        // Type:
        wire := new(wireType);
-       decode(dec.state.b, wire);
+       decode(dec.state.b, tWireType, wire);
        // Remember we've seen this type.
        dec.seen[id] = wire;
 }
@@ -86,7 +86,7 @@ func (dec *Decoder) Decode(e interface{}) os.Error {
                        return dec.state.err
                }
 
-               // Is it a type?
+               // Is it a new type?
                if id < 0 {     // 0 is the error state, handled above
                        // If the id is negative, we have a type.
                        dec.recvType(-id);
@@ -97,7 +97,9 @@ func (dec *Decoder) Decode(e interface{}) os.Error {
                }
 
                // No, it's a value.
+               typeLock.Lock();
                info := getTypeInfo(rt);
+               typeLock.Unlock();
 
                // Check type compatibility.
                // TODO(r): need to make the decoder work correctly if the wire type is compatible
@@ -108,7 +110,7 @@ func (dec *Decoder) Decode(e interface{}) os.Error {
                }
 
                // Receive a value.
-               decode(dec.state.b, e);
+               decode(dec.state.b, id, e);
 
                return dec.state.err
        }
index b3a420a86a2ebf6b080f7de16731cd6b3489f044..f75eccd958f54c008ca955e782951f8518cb350d 100644 (file)
@@ -86,7 +86,9 @@ func (enc *Encoder) sendType(origt reflect.Type) {
        }
 
        // Need to send it.
+       typeLock.Lock();
        info := getTypeInfo(rt);
+       typeLock.Unlock();
        // Send the pair (-id, type)
        // Id:
        encodeInt(enc.state, -int64(info.typeId));
index 6a4a53734171a14105a29243a2dd52d3a9dc6be3..5acb310c75c36962cf36b43150bb0675220ce2f7 100644 (file)
@@ -72,7 +72,7 @@ func TestBasicEncoder(t *testing.T) {
        }
        // 2) The wireType for ET1
        wire1 := new(wireType);
-       err := decode(b, wire1);
+       err := decode(b, tWireType, wire1);
        if err != nil {
                t.Fatal("error decoding ET1 type:", err);
        }
@@ -93,7 +93,7 @@ func TestBasicEncoder(t *testing.T) {
        }
        // 5) The wireType for ET2
        wire2 := new(wireType);
-       err = decode(b, wire2);
+       err = decode(b, tWireType, wire2);
        if err != nil {
                t.Fatal("error decoding ET2 type:", err);
        }
@@ -114,7 +114,8 @@ func TestBasicEncoder(t *testing.T) {
        }
        // 8) The value of et1
        newEt1 := new(ET1);
-       err = decode(b, newEt1);
+       et1Id := getTypeInfo(reflect.Typeof(*newEt1)).typeId;
+       err = decode(b, et1Id, newEt1);
        if err != nil {
                t.Fatal("error decoding ET1 value:", err);
        }
@@ -144,7 +145,7 @@ func TestBasicEncoder(t *testing.T) {
        }
        // 6a) The value of et1
        newEt1 = new(ET1);
-       err = decode(b, newEt1);
+       err = decode(b, et1Id, newEt1);
        if err != nil {
                t.Fatal("2nd round: error decoding ET1 value:", err);
        }
index 00ff82494dafd95d1cb1cf6d28edf1cde5b35675..6f84e7bcf87aa09ad27d2596f1d28ba6e071beb3 100644 (file)
@@ -81,6 +81,9 @@ var tFloat TypeId
 var tString TypeId
 var tBytes TypeId
 
+// Predefined because it's needed by the Decoder
+var tWireType TypeId
+
 // Array type
 type arrayType struct {
        commonType;
@@ -322,7 +325,9 @@ type decEngine struct       // defined in decode.go
 type encEngine struct  // defined in encode.go
 type typeInfo struct {
        typeId  TypeId;
-       decoder *decEngine;
+       // Decoder engine to convert TypeId.Type() to this type.  Stored as a pointer to a
+       // pointer to aid construction of recursive types.  Protected by typeLock.
+       decoderPtr      map[TypeId] **decEngine;
        encoder *encEngine;
        wire    *wireType;
 }
@@ -330,6 +335,7 @@ type typeInfo struct {
 var typeInfoMap = make(map[reflect.Type] *typeInfo)    // protected by typeLock
 
 // The reflection type must have all its indirections processed out.
+// typeLock must be held.
 func getTypeInfo(rt reflect.Type) *typeInfo {
        if pt, ok := rt.(*reflect.PtrType); ok {
                panicln("pointer type in getTypeInfo:", rt.String())
@@ -339,6 +345,7 @@ func getTypeInfo(rt reflect.Type) *typeInfo {
                info = new(typeInfo);
                path, name := rt.Name();
                info.typeId = getType(name, rt).id();
+               info.decoderPtr = make(map[TypeId] **decEngine);
                // assume it's a struct type
                info.wire = &wireType{info.typeId.gobType().(*structType)};
                typeInfoMap[rt] = info;
@@ -347,11 +354,12 @@ func getTypeInfo(rt reflect.Type) *typeInfo {
 }
 
 func init() {
-       tBool= bootstrapType("bool", false);
+       tBool = bootstrapType("bool", false);
        tInt = bootstrapType("int", int(0));
        tUint = bootstrapType("uint", uint(0));
        tFloat = bootstrapType("float", float64(0));
        // The string for tBytes is "bytes" not "[]byte" to signify its specialness.
        tBytes = bootstrapType("bytes", make([]byte, 0));
        tString= bootstrapType("string", "");
+       tWireType = getTypeInfo(reflect.Typeof(wireType{})).typeId;
 }