]> Cypherpunks repositories - gostls13.git/commitdiff
fix bugs in gob.
authorRob Pike <r@golang.org>
Mon, 12 Oct 2009 00:37:22 +0000 (17:37 -0700)
committerRob Pike <r@golang.org>
Mon, 12 Oct 2009 00:37:22 +0000 (17:37 -0700)
1) didn't handle attempts to encode non-structs properly.
2) if there were multiple indirections involving allocation, didn't allocate the
intermediate cells.
tests added.

R=rsc
DELTA=82  (65 added, 5 deleted, 12 changed)
OCL=35582
CL=35582

src/pkg/gob/decode.go
src/pkg/gob/encoder.go
src/pkg/gob/encoder_test.go

index dddc1ec0540d18aa9bf4abe6449769b74283ed82..415b4b6779d16d494c6bc0c5d6cd2633554418fd 100644 (file)
@@ -355,13 +355,18 @@ type decEngine struct {
 }
 
 func decodeStruct(engine *decEngine, rtyp *reflect.StructType, b *bytes.Buffer, p uintptr, indir int) os.Error {
-       if indir > 0 {
+       for ; indir > 0; indir-- {
                up := unsafe.Pointer(p);
                if *(*unsafe.Pointer)(up) == nil {
-                       // Allocate the structure by making a slice of bytes and recording the
+                       // Allocate object by making a slice of bytes and recording the
                        // address of the beginning of the array. TODO(rsc).
-                       b := make([]byte, rtyp.Size());
-                       *(*unsafe.Pointer)(up) = unsafe.Pointer(&b[0]);
+                       if indir > 1 {  // allocate a pointer
+                               b := make([]byte, unsafe.Sizeof((*int)(nil)));
+                               *(*unsafe.Pointer)(up) = unsafe.Pointer(&b[0]);
+                       } else {        // allocate a struct
+                               b := make([]byte, rtyp.Size());
+                               *(*unsafe.Pointer)(up) = unsafe.Pointer(&b[0]);
+                       }
                }
                p = *(*uintptr)(up);
        }
@@ -753,15 +758,10 @@ func getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) {
 }
 
 func decode(b *bytes.Buffer, wireId typeId, e interface{}) os.Error {
-       // Dereference down to the underlying object.
+       // Dereference down to the underlying struct type.
        rt, indir := indirect(reflect.Typeof(e));
-       v := reflect.NewValue(e);
-       for i := 0; i < indir; i++ {
-               v = reflect.Indirect(v);
-       }
-       var st *reflect.StructValue;
-       var ok bool;
-       if st, ok = v.(*reflect.StructValue); !ok {
+       st, ok := rt.(*reflect.StructType);
+       if !ok {
                return os.ErrorString("gob: decode can't handle " + rt.String());
        }
        typeLock.Lock();
@@ -779,7 +779,7 @@ func decode(b *bytes.Buffer, wireId typeId, e interface{}) os.Error {
                name := rt.Name();
                return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name);
        }
-       return decodeStruct(engine, rt.(*reflect.StructType), b, uintptr(v.Addr()), 0);
+       return decodeStruct(engine, st, b, uintptr(reflect.NewValue(e).Addr()), indir);
 }
 
 func init() {
index 207785982a3f35f360523a820476dc473965264c..c567b507e47e2345eafea524b1a821728c45ce1b 100644 (file)
@@ -235,14 +235,18 @@ func (enc *Encoder) send() {
        enc.w.Write(enc.buf[0:total]);
 }
 
-func (enc *Encoder) sendType(origt reflect.Type) {
+func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) {
        // Drill down to the base type.
        rt, _ := indirect(origt);
 
        // We only send structs - everything else is basic or an error
        switch rt.(type) {
        default:
-               // Basic types do not need to be described.
+               // Basic types do not need to be described, but if this is a top-level
+               // type, it's a user error, at least for now.
+               if topLevel {
+                       enc.badType(rt);
+               }
                return;
        case *reflect.StructType:
                // Structs do need to be described.
@@ -285,7 +289,7 @@ func (enc *Encoder) sendType(origt reflect.Type) {
        // Now send the inner types
        st := rt.(*reflect.StructType);
        for i := 0; i < st.NumField(); i++ {
-               enc.sendType(st.Field(i).Type);
+               enc.sendType(st.Field(i).Type, false);
        }
        return;
 }
@@ -306,7 +310,7 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
        // First, have we already sent this type?
        if _, alreadySent := enc.sent[rt]; !alreadySent {
                // No, so send it.
-               enc.sendType(rt);
+               enc.sendType(rt, true);
                if enc.state.err != nil {
                        enc.state.b.Reset();
                        enc.countState.b.Reset();
index e34d961bae622e19e0aab1a41c63378ea411e7ba..9efd00a602fb9daf0e7f4528c72a2c9869e653c5 100644 (file)
@@ -245,6 +245,9 @@ func TestBadData(t *testing.T) {
 // Types not supported by the Encoder (only structs work at the top level).
 // Basic types work implicitly.
 var unsupportedValues = []interface{} {
+       3,
+       "hi",
+       7.2,
        []int{ 1, 2, 3 },
        [3]int{ 1, 2, 3 },
        make(chan int),
@@ -263,3 +266,56 @@ func TestUnsupported(t *testing.T) {
                }
        }
 }
+
+func encAndDec(in, out interface{}) os.Error {
+       b := new(bytes.Buffer);
+       enc := NewEncoder(b);
+       enc.Encode(in);
+       if enc.state.err != nil {
+               return enc.state.err
+       }
+       dec := NewDecoder(b);
+       dec.Decode(out);
+       if dec.state.err != nil {
+               return dec.state.err
+       }
+       return nil;
+}
+
+func TestTypeToPtrType(t *testing.T) {
+       // Encode a T, decode a *T
+       type Type0 struct { a int }
+       t0 := Type0{7};
+       t0p := (*Type0)(nil);
+       if err := encAndDec(t0, t0p); err != nil {
+               t.Error(err)
+       }
+}
+
+func TestPtrTypeToType(t *testing.T) {
+       // Encode a *T, decode a T
+       type Type1 struct { a uint }
+       t1p := &Type1{17};
+       var t1 Type1;
+       if err := encAndDec(t1, t1p); err != nil {
+               t.Error(err)
+       }
+}
+
+func TestTypeToPtrPtrPtrPtrType(t *testing.T) {
+       // Encode a *T, decode a T
+       type Type2 struct { a ****float }
+       t2 := Type2{};
+       t2.a = new(***float);
+       *t2.a = new(**float);
+       **t2.a = new(*float);
+       ***t2.a = new(float);
+       ****t2.a = 27.4;
+       t2pppp := new(***Type2);
+       if err := encAndDec(t2, t2pppp); err != nil {
+               t.Error(err)
+       }
+       if ****(****t2pppp).a != ****t2.a {
+               t.Errorf("wrong value after decode: %g not %g", ****(****t2pppp).a, ****t2.a);
+       }
+}