]> Cypherpunks repositories - gostls13.git/commitdiff
store ids rather than Types in the structs so they can be encoded.
authorRob Pike <r@golang.org>
Thu, 9 Jul 2009 21:33:43 +0000 (14:33 -0700)
committerRob Pike <r@golang.org>
Thu, 9 Jul 2009 21:33:43 +0000 (14:33 -0700)
change Type to gobType.
fix some bugs around recursive structures.
lots of cleanup.
add the first cut at a type encoder.

R=rsc
DELTA=400  (287 added, 11 deleted, 102 changed)
OCL=31401
CL=31406

src/pkg/gob/Makefile
src/pkg/gob/codec_test.go
src/pkg/gob/decode.go
src/pkg/gob/encode.go
src/pkg/gob/encoder.go [new file with mode: 0644]
src/pkg/gob/encoder_test.go [new file with mode: 0644]
src/pkg/gob/type.go
src/pkg/gob/type_test.go

index e41eac3a95524bb020bb7518d28ea7564abb056b..42383ba05ccf22c9db075313eec9b427b38e1ec4 100644 (file)
@@ -36,10 +36,11 @@ O1=\
        type.$O\
 
 O2=\
+       decode.$O\
        encode.$O\
 
 O3=\
-       decode.$O\
+       encoder.$O\
 
 
 phases: a1 a2 a3
@@ -50,11 +51,11 @@ a1: $(O1)
        rm -f $(O1)
 
 a2: $(O2)
-       $(AR) grc _obj$D/gob.a encode.$O
+       $(AR) grc _obj$D/gob.a decode.$O encode.$O
        rm -f $(O2)
 
 a3: $(O3)
-       $(AR) grc _obj$D/gob.a decode.$O
+       $(AR) grc _obj$D/gob.a encoder.$O
        rm -f $(O3)
 
 
index e25a719fad80657a596ecd6178413ee38e0b432c..23ff885f0e040036758a38651d1a5cf54a7216f9 100644 (file)
@@ -13,7 +13,6 @@ import (
        "testing";
        "unsafe";
 )
-import "fmt" // TODO DELETE
 
 // Guarantee encoding format by comparing some encodings to hand-written values
 type EncodeT struct {
@@ -560,6 +559,30 @@ func TestEndToEnd(t *testing.T) {
        }
 }
 
+func TestNesting(t *testing.T) {
+       type RT struct {
+               a string;
+               next *RT
+       }
+       rt := new(RT);
+       rt.a = "level1";
+       rt.next = new(RT);
+       rt.next.a = "level2";
+       b := new(bytes.Buffer);
+       Encode(b, rt);
+       var drt RT;
+       Decode(b, &drt);
+       if drt.a != rt.a {
+               t.Errorf("nesting: encode expected %v got %v", *rt, drt);
+       }
+       if drt.next == nil {
+               t.Errorf("nesting: recursion failed");
+       }
+       if drt.next.a != rt.next.a {
+               t.Errorf("nesting: encode expected %v got %v", *rt.next, *drt.next);
+       }
+}
+
 // These three structures have the same data with different indirections
 type T0 struct {
        a int;
index 1b3e3104a80e50d680f84f83cf91f2f86ec12648..4735f6ba1ced40681e7cbd2baeebef283566f2df 100644 (file)
@@ -199,6 +199,16 @@ func decUint64(i *decInstr, state *DecState, p unsafe.Pointer) {
        *(*uint64)(p) = uint64(DecodeUint(state));
 }
 
+func decUintptr(i *decInstr, state *DecState, p unsafe.Pointer) {
+       if i.indir > 0 {
+               if *(*unsafe.Pointer)(p) == nil {
+                       *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uintptr));
+               }
+               p = *(*unsafe.Pointer)(p);
+       }
+       *(*uintptr)(p) = uintptr(DecodeUint(state));
+}
+
 // Floating-point numbers are transmitted as uint64s holding the bits
 // of the underlying representation.  They are sent byte-reversed, with
 // the exponent end coming out first, so integer floating point numbers
@@ -367,7 +377,6 @@ func decodeSlice(atyp *reflect.SliceType, state *DecState, p uintptr, elemOp dec
        return decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, int(length), elemIndir);
 }
 
-var decEngineMap = make(map[reflect.Type] *decEngine)
 var decOpMap = map[reflect.Type] decOp {
         reflect.Typeof((*reflect.BoolType)(nil)): decBool,
         reflect.Typeof((*reflect.IntType)(nil)): decInt,
@@ -380,6 +389,7 @@ var decOpMap = map[reflect.Type] decOp {
         reflect.Typeof((*reflect.Uint16Type)(nil)): decUint16,
         reflect.Typeof((*reflect.Uint32Type)(nil)): decUint32,
         reflect.Typeof((*reflect.Uint64Type)(nil)): decUint64,
+        reflect.Typeof((*reflect.UintptrType)(nil)): decUintptr,
         reflect.Typeof((*reflect.FloatType)(nil)): decFloat,
         reflect.Typeof((*reflect.Float32Type)(nil)): decFloat32,
         reflect.Typeof((*reflect.Float64Type)(nil)): decFloat64,
@@ -415,8 +425,10 @@ func decOpFor(rt reflect.Type) (decOp, int) {
                case *reflect.StructType:
                        // Generate a closure that calls out to the engine for the nested type.
                        engine := getDecEngine(typ);
+                       info := getTypeInfo(typ);
                        op = func(i *decInstr, state *DecState, p unsafe.Pointer) {
-                               state.err = decodeStruct(engine, t, state.r, uintptr(p), i.indir)
+                               // indirect through info to delay evaluation for recursive structs
+                               state.err = decodeStruct(info.decoder, t, state.r, uintptr(p), i.indir)
                        };
                }
        }
@@ -426,7 +438,7 @@ func decOpFor(rt reflect.Type) (decOp, int) {
        return op, indir
 }
 
-func compileDec(rt reflect.Type, typ Type) *decEngine {
+func compileDec(rt reflect.Type, typ gobType) *decEngine {
        srt, ok1 := rt.(*reflect.StructType);
        styp, ok2 := typ.(*structType);
        if !ok1 || !ok2 {
@@ -436,8 +448,8 @@ func compileDec(rt reflect.Type, typ Type) *decEngine {
        engine.instr = make([]decInstr, len(styp.field));
        for fieldnum := 0; fieldnum < len(styp.field); fieldnum++ {
                field := styp.field[fieldnum];
-               // TODO(r): verify compatibility with corresponding field of data.
-               // For now, assume perfect correspondence between struct and gob.
+               // Assumes perfect correspondence between struct and gob,
+               // which is safe to assume since typ was compiled from rt.
                f := srt.Field(fieldnum);
                op, indir := decOpFor(f.Type);
                engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(f.Offset)};
@@ -446,14 +458,19 @@ func compileDec(rt reflect.Type, typ Type) *decEngine {
 }
 
 
+// typeLock must be held.
 func getDecEngine(rt reflect.Type) *decEngine {
-       engine, ok := decEngineMap[rt];
-       if !ok {
-               pkg, name := rt.Name();
-               engine = compileDec(rt, newType(name, rt));
-               decEngineMap[rt] = engine;
+       info := getTypeInfo(rt);
+       if info.decoder == nil {
+               if info.typeId.gobType() == nil {
+                       _pkg, name := rt.Name();
+                       info.typeId = newType(name, rt).id();
+               }
+               // mark this engine as underway before compiling to handle recursive types.
+               info.decoder = new(decEngine);
+               info.decoder = compileDec(rt, info.typeId.gobType());
        }
-       return engine;
+       return info.decoder;
 }
 
 func Decode(r io.Reader, e interface{}) os.Error {
index f32180c3a97064e666c70a8ec6e1118353ac35d4..dac8097518fc663d59f3051582aa6824dd076558 100644 (file)
@@ -183,6 +183,14 @@ func encUint64(i *encInstr, state *EncState, p unsafe.Pointer) {
        }
 }
 
+func encUintptr(i *encInstr, state *EncState, p unsafe.Pointer) {
+       v := uint64(*(*uintptr)(p));
+       if v != 0 {
+               state.update(i);
+               EncodeUint(state, v);
+       }
+}
+
 // Floating-point numbers are transmitted as uint64s holding the bits
 // of the underlying representation.  They are sent byte-reversed, with
 // the exponent end coming out first, so integer floating point numbers
@@ -285,21 +293,21 @@ func encodeArray(w io.Writer, p uintptr, op encOp, elemWid uintptr, length int,
        state.fieldnum = -1;
        EncodeUint(state, uint64(length));
        for i := 0; i < length && state.err == nil; i++ {
-               up := unsafe.Pointer(p);
+               elemp := p;
+               up := unsafe.Pointer(elemp);
                if elemIndir > 0 {
                        if up = encIndirect(up, elemIndir); up == nil {
                                state.err = os.ErrorString("encodeArray: nil element");
                                break
                        }
-                       p = uintptr(up);
+                       elemp = uintptr(up);
                }
-               op(nil, state, unsafe.Pointer(p));
+               op(nil, state, unsafe.Pointer(elemp));
                p += uintptr(elemWid);
        }
        return state.err
 }
 
-var encEngineMap = make(map[reflect.Type] *encEngine)
 var encOpMap = map[reflect.Type] encOp {
        reflect.Typeof((*reflect.BoolType)(nil)): encBool,
        reflect.Typeof((*reflect.IntType)(nil)): encInt,
@@ -312,6 +320,7 @@ var encOpMap = map[reflect.Type] encOp {
        reflect.Typeof((*reflect.Uint16Type)(nil)): encUint16,
        reflect.Typeof((*reflect.Uint32Type)(nil)): encUint32,
        reflect.Typeof((*reflect.Uint64Type)(nil)): encUint64,
+       reflect.Typeof((*reflect.UintptrType)(nil)): encUintptr,
        reflect.Typeof((*reflect.FloatType)(nil)): encFloat,
        reflect.Typeof((*reflect.Float32Type)(nil)): encFloat32,
        reflect.Typeof((*reflect.Float64Type)(nil)): encFloat64,
@@ -354,9 +363,11 @@ func encOpFor(rt reflect.Type) (encOp, int) {
                case *reflect.StructType:
                        // Generate a closure that calls out to the engine for the nested type.
                        engine := getEncEngine(typ);
+                       info := getTypeInfo(typ);
                        op = func(i *encInstr, state *EncState, p unsafe.Pointer) {
                                state.update(i);
-                               state.err = encodeStruct(engine, state.w, uintptr(p));
+                               // indirect through info to delay evaluation for recursive structs
+                               state.err = encodeStruct(info.encoder, state.w, uintptr(p));
                        };
                }
        }
@@ -366,10 +377,8 @@ func encOpFor(rt reflect.Type) (encOp, int) {
        return op, indir
 }
 
-// The local Type was compiled from the actual value, so we know
-// it's compatible.
-// TODO(r): worth checking?  typ is unused here.
-func compileEnc(rt reflect.Type, typ Type) *encEngine {
+// The local Type was compiled from the actual value, so we know it's compatible.
+func compileEnc(rt reflect.Type) *encEngine {
        srt, ok := rt.(*reflect.StructType);
        if !ok {
                panicln("TODO: can't handle non-structs");
@@ -385,15 +394,16 @@ func compileEnc(rt reflect.Type, typ Type) *encEngine {
        return engine;
 }
 
-// typeLock must be held.
+// typeLock must be held (or we're in initialization and guaranteed single-threaded).
+// The reflection type must have all its indirections processed out.
 func getEncEngine(rt reflect.Type) *encEngine {
-       engine, ok := encEngineMap[rt];
-       if !ok {
-               pkg, name := rt.Name();
-               engine = compileEnc(rt, newType(name, rt));
-               encEngineMap[rt] = engine;
+       info := getTypeInfo(rt);
+       if info.encoder == nil {
+               // mark this engine as underway before compiling to handle recursive types.
+               info.encoder = new(encEngine);
+               info.encoder = compileEnc(rt);
        }
-       return engine
+       return info.encoder;
 }
 
 func Encode(w io.Writer, e interface{}) os.Error {
diff --git a/src/pkg/gob/encoder.go b/src/pkg/gob/encoder.go
new file mode 100644 (file)
index 0000000..775a881
--- /dev/null
@@ -0,0 +1,108 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+       "gob";
+       "io";
+       "os";
+       "reflect";
+       "sync";
+)
+
+import "fmt"   // TODO DELETE
+
+type Encoder struct {
+       sync.Mutex;     // each item must be sent atomically
+       sent    map[reflect.Type] uint; // which types we've already sent
+       state   *EncState;      // so we can encode integers, strings directly
+}
+
+func NewEncoder(w io.Writer) *Encoder {
+       enc := new(Encoder);
+       enc.sent = make(map[reflect.Type] uint);
+       enc.state = new(EncState);
+       enc.state.w = w;        // the rest isn't important; all we need is buffer and writer
+       return enc;
+}
+
+func (enc *Encoder) badType(rt reflect.Type) {
+       enc.state.err = os.ErrorString("can't encode type " + rt.String());
+}
+
+func (enc *Encoder) sendType(rt reflect.Type) {
+       // Drill down to the base type.
+       for {
+               pt, ok := rt.(*reflect.PtrType);
+               if !ok {
+                       break
+               }
+               rt = pt.Elem();
+       }
+
+       // We only send structs - everything else is basic or an error
+       switch t := rt.(type) {
+       case *reflect.StructType:
+               break;  // we handle these
+       case *reflect.ChanType:
+               enc.badType(rt);
+               return;
+       case *reflect.MapType:
+               enc.badType(rt);
+               return;
+       case *reflect.FuncType:
+               enc.badType(rt);
+               return;
+       case *reflect.InterfaceType:
+               enc.badType(rt);
+               return;
+       default:
+               return; // basic, array, etc; not a type to be sent.
+       }
+
+       // Have we already sent this type?
+       id, alreadySent := enc.sent[rt];
+       if alreadySent {
+               return
+       }
+
+       // Need to send it.
+       info := getTypeInfo(rt);
+       // Send the pair (-id, type)
+       // Id:
+       EncodeInt(enc.state, -int64(info.typeId));
+       // Type:
+       Encode(enc.state.w, info.wire);
+       // Remember we've sent this type.
+       enc.sent[rt] = id;
+       // Now send the inner types
+       st := rt.(*reflect.StructType);
+       for i := 0; i < st.NumField(); i++ {
+               enc.sendType(st.Field(i).Type);
+       }
+}
+
+func (enc *Encoder) Encode(e interface{}) os.Error {
+       rt, indir := indirect(reflect.Typeof(e));
+
+       // Make sure we're single-threaded through here.
+       enc.Lock();
+       defer enc.Unlock();
+
+       // Make sure the type is known to the other side.
+       enc.sendType(rt);
+       if enc.state.err != nil {
+               return enc.state.err
+       }
+
+       // Identify the type of this top-level value.
+       EncodeInt(enc.state, int64(enc.sent[rt]));
+
+       // Finally, send the data
+       Encode(enc.state.w, e);
+
+       // Release and return.
+       return enc.state.err
+}
diff --git a/src/pkg/gob/encoder_test.go b/src/pkg/gob/encoder_test.go
new file mode 100644 (file)
index 0000000..71287ad
--- /dev/null
@@ -0,0 +1,37 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+       "bytes";
+       "gob";
+       "os";
+       "reflect";
+       "strings";
+       "testing";
+       "unsafe";
+)
+
+type ET2 struct {
+       x string;
+}
+
+type ET1 struct {
+       a int;
+       et2 *ET2;
+       next *ET1;
+}
+
+func TestBasicEncoder(t *testing.T) {
+       b := new(bytes.Buffer);
+       enc := NewEncoder(b);
+       et1 := new(ET1);
+       et1.a = 7;
+       et1.et2 = new(ET2);
+       enc.Encode(et1);
+       if enc.state.err != nil {
+               t.Error("encoder fail:", enc.state.err)
+       }
+}
index ed221b9b3640796426eb0a7e4670d234dbf94923..cb0ca02329d8cdcb09d697750dfe6610f4444fcf 100644 (file)
@@ -13,28 +13,51 @@ import (
        "unicode";
 )
 
-var id uint32  // incremented for each new type we build
+// Types are identified by an integer TypeId.  These can be passed on the wire.
+// Internally, they are used as keys to a map to recover the underlying type info.
+type TypeId uint32
+
+var id TypeId  // incremented for each new type we build
 var typeLock   sync.Mutex      // set while building a type
 
-type Type interface {
-       id()    uint32;
-       setId(id uint32);
+type gobType interface {
+       id()    TypeId;
+       setId(id TypeId);
        String()        string;
-       safeString(seen map[uint32] bool)       string;
+       safeString(seen map[TypeId] bool)       string;
+}
+
+var types = make(map[reflect.Type] gobType)
+var idToType = make(map[TypeId] gobType)
+
+func setTypeId(typ gobType) {
+       id++;
+       typ.setId(id);
+       idToType[id] = typ;
+}
+
+func (t TypeId) gobType() gobType {
+       if t == 0 {
+               return nil
+       }
+       return idToType[t]
+}
+
+func (t TypeId) String() string {
+       return t.gobType().String()
 }
-var types = make(map[reflect.Type] Type)
 
 // Common elements of all types.
 type commonType struct {
        name    string;
-       _id     uint32;
+       _id     TypeId;
 }
 
-func (t *commonType) id() uint32 {
+func (t *commonType) id() TypeId {
        return t._id
 }
 
-func (t *commonType) setId(id uint32) {
+func (t *commonType) setId(id TypeId) {
        t._id = id
 }
 
@@ -51,31 +74,32 @@ func (t *commonType) Name() string {
 }
 
 // Basic type identifiers, predefined.
-var tBool Type
-var tInt Type
-var tUint Type
-var tFloat Type
-var tString Type
-var tBytes Type
+var tBool TypeId
+var tInt TypeId
+var tUint TypeId
+var tFloat TypeId
+var tString TypeId
+var tBytes TypeId
 
 // Array type
 type arrayType struct {
        commonType;
-       Elem    Type;
+       Elem    TypeId;
        Len     int;
 }
 
-func newArrayType(name string, elem Type, length int) *arrayType {
-       a := &arrayType{ commonType{ name: name }, elem, length };
+func newArrayType(name string, elem gobType, length int) *arrayType {
+       a := &arrayType{ commonType{ name: name }, elem.id(), length };
+       setTypeId(a);
        return a;
 }
 
-func (a *arrayType) safeString(seen map[uint32] bool) string {
+func (a *arrayType) safeString(seen map[TypeId] bool) string {
        if _, ok := seen[a._id]; ok {
                return a.name
        }
        seen[a._id] = true;
-       return fmt.Sprintf("[%d]%s", a.Len, a.Elem.safeString(seen));
+       return fmt.Sprintf("[%d]%s", a.Len, a.Elem.gobType().safeString(seen));
 }
 
 func (a *arrayType) String() string {
@@ -85,30 +109,31 @@ func (a *arrayType) String() string {
 // Slice type
 type sliceType struct {
        commonType;
-       Elem    Type;
+       Elem    TypeId;
 }
 
-func newSliceType(name string, elem Type) *sliceType {
-       s := &sliceType{ commonType{ name: name }, elem };
+func newSliceType(name string, elem gobType) *sliceType {
+       s := &sliceType{ commonType{ name: name }, elem.id() };
+       setTypeId(s);
        return s;
 }
 
-func (s *sliceType) safeString(seen map[uint32] bool) string {
+func (s *sliceType) safeString(seen map[TypeId] bool) string {
        if _, ok := seen[s._id]; ok {
                return s.name
        }
        seen[s._id] = true;
-       return fmt.Sprintf("[]%s", s.Elem.safeString(seen));
+       return fmt.Sprintf("[]%s", s.Elem.gobType().safeString(seen));
 }
 
 func (s *sliceType) String() string {
-       return s.safeString(make(map[uint32] bool))
+       return s.safeString(make(map[TypeId] bool))
 }
 
 // Struct type
 type fieldType struct {
        name    string;
-       typ     Type;
+       typeId  TypeId;
 }
 
 type structType struct {
@@ -116,30 +141,31 @@ type structType struct {
        field   []*fieldType;
 }
 
-func (s *structType) safeString(seen map[uint32] bool) string {
+func (s *structType) safeString(seen map[TypeId] bool) string {
        if _, ok := seen[s._id]; ok {
                return s.name
        }
        seen[s._id] = true;
        str := s.name + " = struct { ";
        for _, f := range s.field {
-               str += fmt.Sprintf("%s %s; ", f.name, f.typ.safeString(seen));
+               str += fmt.Sprintf("%s %s; ", f.name, f.typeId.gobType().safeString(seen));
        }
        str += "}";
        return str;
 }
 
 func (s *structType) String() string {
-       return s.safeString(make(map[uint32] bool))
+       return s.safeString(make(map[TypeId] bool))
 }
 
 func newStructType(name string) *structType {
        s := &structType{ commonType{ name: name }, nil };
+       setTypeId(s);
        return s;
 }
 
 // Construction
-func newType(name string, rt reflect.Type) Type
+func newType(name string, rt reflect.Type) gobType
 
 // Step through the indirections on a type to discover the base type.
 // Return the number of indirections.
@@ -156,35 +182,45 @@ func indirect(t reflect.Type) (rt reflect.Type, count int) {
        return;
 }
 
-func newTypeObject(name string, rt reflect.Type) Type {
+func newTypeObject(name string, rt reflect.Type) gobType {
        switch t := rt.(type) {
        // All basic types are easy: they are predefined.
        case *reflect.BoolType:
-               return tBool
+               return tBool.gobType()
 
        case *reflect.IntType:
-               return tInt
+               return tInt.gobType()
+       case *reflect.Int8Type:
+               return tInt.gobType()
+       case *reflect.Int16Type:
+               return tInt.gobType()
        case *reflect.Int32Type:
-               return tInt
+               return tInt.gobType()
        case *reflect.Int64Type:
-               return tInt
+               return tInt.gobType()
 
        case *reflect.UintType:
-               return tUint
+               return tUint.gobType()
+       case *reflect.Uint8Type:
+               return tUint.gobType()
+       case *reflect.Uint16Type:
+               return tUint.gobType()
        case *reflect.Uint32Type:
-               return tUint
+               return tUint.gobType()
        case *reflect.Uint64Type:
-               return tUint
+               return tUint.gobType()
+       case *reflect.UintptrType:
+               return tUint.gobType()
 
        case *reflect.FloatType:
-               return tFloat
+               return tFloat.gobType()
        case *reflect.Float32Type:
-               return tFloat
+               return tFloat.gobType()
        case *reflect.Float64Type:
-               return tFloat
+               return tFloat.gobType()
 
        case *reflect.StringType:
-               return tString
+               return tString.gobType()
 
        case *reflect.ArrayType:
                return newArrayType(name, newType("", t.Elem()), t.Len());
@@ -192,7 +228,7 @@ func newTypeObject(name string, rt reflect.Type) Type {
        case *reflect.SliceType:
                // []byte == []uint8 is a special case
                if _, ok := t.Elem().(*reflect.Uint8Type); ok {
-                       return tBytes
+                       return tBytes.gobType()
                }
                return newSliceType(name, newType("", t.Elem()));
 
@@ -201,6 +237,7 @@ func newTypeObject(name string, rt reflect.Type) Type {
                // structures can be constructed safely.
                strType := newStructType(name);
                types[rt] = strType;
+               idToType[strType.id()] = strType;
                field := make([]*fieldType, t.NumField());
                for i := 0; i < t.NumField(); i++ {
                        f := t.Field(i);
@@ -209,7 +246,7 @@ func newTypeObject(name string, rt reflect.Type) Type {
                        if tname == "" {
                                tname = f.Type.String();
                        }
-                       field[i] =  &fieldType{ f.Name, newType(tname, f.Type) };
+                       field[i] =  &fieldType{ f.Name, newType(tname, f.Type).id() };
                }
                strType.field = field;
                return strType;
@@ -220,7 +257,7 @@ func newTypeObject(name string, rt reflect.Type) Type {
        return nil
 }
 
-func newType(name string, rt reflect.Type) Type {
+func newType(name string, rt reflect.Type) gobType {
        // Flatten the data structure by collapsing out pointers
        for {
                pt, ok := rt.(*reflect.PtrType);
@@ -234,34 +271,72 @@ func newType(name string, rt reflect.Type) Type {
                return typ
        }
        typ = newTypeObject(name, rt);
-       id++;
-       typ.setId(id);
        types[rt] = typ;
        return typ
 }
 
-// GetType returns the Gob type describing the interface value.
-func GetType(name string, e interface{}) Type {
-       rt := reflect.Typeof(e);
+// getType returns the Gob type describing the given reflect.Type.
+// typeLock must be held.
+func getType(name string, rt reflect.Type) gobType {
        // Set lock; all code running under here is synchronized.
-       typeLock.Lock();
        t := newType(name, rt);
-       typeLock.Unlock();
        return t;
 }
 
 // used for building the basic types; called only from init()
-func bootstrapType(name string, e interface{}) Type {
+func bootstrapType(name string, e interface{}) TypeId {
        rt := reflect.Typeof(e);
        _, present := types[rt];
        if present {
                panicln("bootstrap type already present:", name);
        }
        typ := &commonType{ name: name };
-       id++;
-       typ.setId(id);
        types[rt] = typ;
-       return typ
+       setTypeId(typ);
+       return id
+}
+
+// Representation of the information we send and receive about this type.
+// Each value we send is preceded by its type definition: an encoded int.
+// However, the very first time we send the value, we first send the pair
+// (-id, wireType).
+// For bootstrapping purposes, we assume that the recipient knows how
+// to decode a wireType; it is exactly the wireType struct here, interpreted
+// using the gob rules for sending a structure, except that we assume the
+// ids for wireType and structType are known.  The relevant pieces
+// are built in encode.go's init() function.
+
+type wireType struct {
+       name    string;
+       s       *structType;
+}
+
+type decEngine struct  // defined in decode.go
+type encEngine struct  // defined in encode.go
+type typeInfo struct {
+       typeId  TypeId;
+       decoder *decEngine;
+       encoder *encEngine;
+       wire    *wireType;
+}
+
+var typeInfoMap = make(map[reflect.Type] *typeInfo)    // protected by typeLock
+
+// The reflection type must have all its indirections processed out.
+func getTypeInfo(rt reflect.Type) *typeInfo {
+       if pt, ok := rt.(*reflect.PtrType); ok {
+               panicln("pointer type in getTypeInfo:", rt.String())
+       }
+       info, ok := typeInfoMap[rt];
+       if !ok {
+               info = new(typeInfo);
+               path, name := rt.Name();
+               info.typeId = getType(name, rt).id();
+               // assume it's a struct type
+               info.wire = &wireType{name, info.typeId.gobType().(*structType)};
+               typeInfoMap[rt] = info;
+       }
+       return info;
 }
 
 func init() {
index e62bd6415a73b89d6ffbe4cd1cc71765862423b7..d190a3045e278f5d1d30d812a56d05d39e50a9fb 100644 (file)
@@ -7,11 +7,12 @@ package gob
 import (
        "gob";
        "os";
+       "reflect";
        "testing";
 )
 
 type typeT struct {
-       typ     Type;
+       typeId  TypeId;
        str     string;
 }
 var basicTypes = []typeT {
@@ -23,13 +24,19 @@ var basicTypes = []typeT {
        typeT { tString, "string" },
 }
 
+func getTypeUnlocked(name string, rt reflect.Type) gobType {
+       typeLock.Lock();
+       defer typeLock.Unlock();
+       return getType(name, rt);
+}
+
 // Sanity checks
 func TestBasic(t *testing.T) {
        for _, tt := range basicTypes {
-               if tt.typ.String() != tt.str {
-                       t.Errorf("checkType: expected %q got %s", tt.str, tt.typ.String())
+               if tt.typeId.String() != tt.str {
+                       t.Errorf("checkType: expected %q got %s", tt.str, tt.typeId.String())
                }
-               if tt.typ.id() == 0 {
+               if tt.typeId == 0 {
                        t.Errorf("id for %q is zero", tt.str)
                }
        }
@@ -37,35 +44,35 @@ func TestBasic(t *testing.T) {
 
 // Reregister some basic types to check registration is idempotent.
 func TestReregistration(t *testing.T) {
-       newtyp := GetType("int", 0);
-       if newtyp != tInt {
+       newtyp := getTypeUnlocked("int", reflect.Typeof(int(0)));
+       if newtyp != tInt.gobType() {
                t.Errorf("reregistration of %s got new type", newtyp.String())
        }
-       newtyp = GetType("uint", uint(0));
-       if newtyp != tUint {
+       newtyp = getTypeUnlocked("uint", reflect.Typeof(uint(0)));
+       if newtyp != tUint.gobType() {
                t.Errorf("reregistration of %s got new type", newtyp.String())
        }
-       newtyp = GetType("string", "hello");
-       if newtyp != tString {
+       newtyp = getTypeUnlocked("string", reflect.Typeof("hello"));
+       if newtyp != tString.gobType() {
                t.Errorf("reregistration of %s got new type", newtyp.String())
        }
 }
 
 func TestArrayType(t *testing.T) {
        var a3 [3]int;
-       a3int := GetType("foo", a3);
+       a3int := getTypeUnlocked("foo", reflect.Typeof(a3));
        var newa3 [3]int;
-       newa3int := GetType("bar", a3);
+       newa3int := getTypeUnlocked("bar", reflect.Typeof(a3));
        if a3int != newa3int {
                t.Errorf("second registration of [3]int creates new type");
        }
        var a4 [4]int;
-       a4int := GetType("goo", a4);
+       a4int := getTypeUnlocked("goo", reflect.Typeof(a4));
        if a3int == a4int {
                t.Errorf("registration of [3]int creates same type as [4]int");
        }
        var b3 [3]bool;
-       a3bool := GetType("", b3);
+       a3bool := getTypeUnlocked("", reflect.Typeof(b3));
        if a3int == a3bool {
                t.Errorf("registration of [3]bool creates same type as [3]int");
        }
@@ -78,14 +85,14 @@ func TestArrayType(t *testing.T) {
 
 func TestSliceType(t *testing.T) {
        var s []int;
-       sint := GetType("slice", s);
+       sint := getTypeUnlocked("slice", reflect.Typeof(s));
        var news []int;
-       newsint := GetType("slice1", news);
+       newsint := getTypeUnlocked("slice1", reflect.Typeof(news));
        if sint != newsint {
                t.Errorf("second registration of []int creates new type");
        }
        var b []bool;
-       sbool := GetType("", b);
+       sbool := getTypeUnlocked("", reflect.Typeof(b));
        if sbool == sint {
                t.Errorf("registration of []bool creates same type as []int");
        }
@@ -114,7 +121,7 @@ type Foo struct {
 }
 
 func TestStructType(t *testing.T) {
-       sstruct := GetType("Foo", Foo{});
+       sstruct := getTypeUnlocked("Foo", reflect.Typeof(Foo{}));
        str := sstruct.String();
        // If we can print it correctly, we built it correctly.
        expected := "Foo = struct { a int; b int; c string; d bytes; e float; f float; g Bar = struct { x string; }; h Bar; i Foo; }";