]> Cypherpunks repositories - gostls13.git/commitdiff
a better encoder test, with a couple of fixes for bugs it uncovered.
authorRob Pike <r@golang.org>
Fri, 10 Jul 2009 20:44:37 +0000 (13:44 -0700)
committerRob Pike <r@golang.org>
Fri, 10 Jul 2009 20:44:37 +0000 (13:44 -0700)
R=rsc
DELTA=84  (65 added, 9 deleted, 10 changed)
OCL=31458
CL=31458

src/pkg/gob/encoder.go
src/pkg/gob/encoder_test.go
src/pkg/gob/type.go

index 775a881aa6eca9bd4bf08610cf815e4edd717f27..30ec819c77f9f5dafcf1229e1f0badd11554e85b 100644 (file)
@@ -12,17 +12,15 @@ import (
        "sync";
 )
 
-import "fmt"   // TODO DELETE
-
 type Encoder struct {
        sync.Mutex;     // each item must be sent atomically
-       sent    map[reflect.Type] uint; // which types we've already sent
+       sent    map[reflect.Type] TypeId;       // which types we've already sent
        state   *EncState;      // so we can encode integers, strings directly
 }
 
 func NewEncoder(w io.Writer) *Encoder {
        enc := new(Encoder);
-       enc.sent = make(map[reflect.Type] uint);
+       enc.sent = make(map[reflect.Type] TypeId);
        enc.state = new(EncState);
        enc.state.w = w;        // the rest isn't important; all we need is buffer and writer
        return enc;
@@ -32,15 +30,9 @@ func (enc *Encoder) badType(rt reflect.Type) {
        enc.state.err = os.ErrorString("can't encode type " + rt.String());
 }
 
-func (enc *Encoder) sendType(rt reflect.Type) {
+func (enc *Encoder) sendType(origt reflect.Type) {
        // Drill down to the base type.
-       for {
-               pt, ok := rt.(*reflect.PtrType);
-               if !ok {
-                       break
-               }
-               rt = pt.Elem();
-       }
+       rt, indir_ := indirect(origt);
 
        // We only send structs - everything else is basic or an error
        switch t := rt.(type) {
@@ -62,9 +54,8 @@ func (enc *Encoder) sendType(rt reflect.Type) {
                return; // basic, array, etc; not a type to be sent.
        }
 
-       // Have we already sent this type?
-       id, alreadySent := enc.sent[rt];
-       if alreadySent {
+       // Have we already sent this type?  This time we ask about the base type.
+       if id_, alreadySent := enc.sent[rt]; alreadySent {
                return
        }
 
@@ -76,7 +67,9 @@ func (enc *Encoder) sendType(rt reflect.Type) {
        // Type:
        Encode(enc.state.w, info.wire);
        // Remember we've sent this type.
-       enc.sent[rt] = id;
+       enc.sent[rt] = info.typeId;
+       // Remember we've sent the top-level, possibly indirect type too.
+       enc.sent[origt] = info.typeId;
        // Now send the inner types
        st := rt.(*reflect.StructType);
        for i := 0; i < st.NumField(); i++ {
@@ -92,9 +85,13 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
        defer enc.Unlock();
 
        // Make sure the type is known to the other side.
-       enc.sendType(rt);
-       if enc.state.err != nil {
-               return enc.state.err
+       // First, have we already sent this type?
+       if id_, alreadySent := enc.sent[rt]; !alreadySent {
+               // No, so send it.
+               enc.sendType(rt);
+               if enc.state.err != nil {
+                       return enc.state.err
+               }
        }
 
        // Identify the type of this top-level value.
index 71287ad15aefc586e927bb8a189ef7529199fc6c..c762a1876353383bd2a22a394cf92c7765861746 100644 (file)
@@ -6,6 +6,7 @@ package gob
 
 import (
        "bytes";
+"fmt";         // DELETE
        "gob";
        "os";
        "reflect";
@@ -34,4 +35,59 @@ func TestBasicEncoder(t *testing.T) {
        if enc.state.err != nil {
                t.Error("encoder fail:", enc.state.err)
        }
+
+       // Decode the result by hand to verify;
+       state := new(DecState);
+       state.r = b;
+       // The output should be:
+       // 1) -7: the type id of ET1
+       id1 := DecodeInt(state);
+       if id1 >= 0 {
+               t.Fatal("expected ET1 negative id; got", id1);
+       }
+       // 2) The wireType for ET1
+       wire1 := new(wireType);
+       err := Decode(b, wire1);
+       if err != nil {
+               t.Fatal("error decoding ET1 type:", err);
+       }
+       info := getTypeInfo(reflect.Typeof(ET1{}));
+       trueWire1 := &wireType{name:"ET1", s: info.typeId.gobType().(*structType)};
+       if !reflect.DeepEqual(wire1, trueWire1) {
+               t.Fatalf("invalid wireType for ET1: expected %+v; got %+v\n", *trueWire1, *wire1);
+       }
+       // 3) -8: the type id of ET2
+       id2 := DecodeInt(state);
+       if id2 >= 0 {
+               t.Fatal("expected ET2 negative id; got", id2);
+       }
+       // 4) The wireType for ET2
+       wire2 := new(wireType);
+       err = Decode(b, wire2);
+       if err != nil {
+               t.Fatal("error decoding ET2 type:", err);
+       }
+       info = getTypeInfo(reflect.Typeof(ET2{}));
+       trueWire2 := &wireType{name:"ET2", s: info.typeId.gobType().(*structType)};
+       if !reflect.DeepEqual(wire2, trueWire2) {
+               t.Fatalf("invalid wireType for ET2: expected %+v; got %+v\n", *trueWire2, *wire2);
+       }
+       // 5) The type id for the et1 value
+       newId1 := DecodeInt(state);
+       if newId1 != -id1 {
+               t.Fatal("expected Et1 id", -id1, "got", newId1);
+       }
+       // 6) The value of et1
+       newEt1 := new(ET1);
+       err = Decode(b, newEt1);
+       if err != nil {
+               t.Fatal("error decoding ET1 value:", err);
+       }
+       if !reflect.DeepEqual(et1, newEt1) {
+               t.Fatalf("invalid data for et1: expected %+v; got %+v\n", *et1, *newEt1);
+       }
+       // 7) EOF
+       if b.Len() != 0 {
+               t.Error("not at eof;", b.Len(), "bytes left")
+       }
 }
index cb0ca02329d8cdcb09d697750dfe6610f4444fcf..66636a4d44b2a58a8980dd3b82400043795acd5d 100644 (file)
@@ -142,6 +142,9 @@ type structType struct {
 }
 
 func (s *structType) safeString(seen map[TypeId] bool) string {
+       if s == nil {
+               return "<nil>"
+       }
        if _, ok := seen[s._id]; ok {
                return s.name
        }