]> Cypherpunks repositories - gostls13.git/commitdiff
encode and decode for nested structures.
authorRob Pike <r@golang.org>
Thu, 2 Jul 2009 20:43:47 +0000 (13:43 -0700)
committerRob Pike <r@golang.org>
Thu, 2 Jul 2009 20:43:47 +0000 (13:43 -0700)
fix a bug in delta encoding: only update the delta-base if something is marshaled.

R=rsc
DELTA=154  (94 added, 56 deleted, 4 changed)
OCL=31069
CL=31071

src/pkg/gob/codec_test.go
src/pkg/gob/decode.go
src/pkg/gob/encode.go

index 4b5169eb07c284a4c998e49c100aca0232cc0732..339774cd0178d25cf43e1380ea8269293d4d221b 100644 (file)
@@ -569,10 +569,14 @@ func TestScalarDecInstructions(t *testing.T) {
 
 
 func TestEncode(t *testing.T) {
+       type T2 struct {
+               t string
+       }
        type T1 struct {
                a, b,c int;
                s string;
                y []byte;
+               t *T2;
        }
        t1 := &T1{
                a: 17,
@@ -580,6 +584,7 @@ func TestEncode(t *testing.T) {
                c: -5,
                s: "Now is the time",
                y: strings.Bytes("hello, sailor"),
+               t: &T2{"this is T2"},
        };
        b := new(bytes.Buffer);
        Encode(b, t1);
index 79440b2401e270b4c595170cc226f62102f66ec6..7a4918a2f46329d4596e635004b88dff58077d3d 100644 (file)
@@ -279,6 +279,31 @@ type decEngine struct {
        instr   []decInstr
 }
 
+func (engine *decEngine) decodeStruct(r io.Reader, p uintptr) os.Error {
+       state := new(DecState);
+       state.r = r;
+       state.base = p;
+       state.fieldnum = -1;
+       for state.err == nil {
+               delta := int(DecodeUint(state));
+               if state.err != nil || delta == 0 {     // struct terminator is zero delta fieldnum
+                       break
+               }
+               fieldnum := state.fieldnum + delta;
+               if fieldnum >= len(engine.instr) {
+                       panicln("TODO(r): need to handle unknown data");
+               }
+               instr := &engine.instr[fieldnum];
+               p := unsafe.Pointer(state.base+instr.offset);
+               if instr.indir > 1 {
+                       p = decIndirect(p, instr.indir);
+               }
+               instr.op(instr, state, p);
+               state.fieldnum = fieldnum;
+       }
+       return state.err
+}
+
 var decEngineMap = make(map[reflect.Type] *decEngine)
 var decOpMap = map[int] decOp {
         reflect.BoolKind: decBool,
@@ -298,6 +323,8 @@ var decOpMap = map[int] decOp {
         reflect.StringKind: decString,
 }
 
+func getDecEngine(rt reflect.Type) *decEngine
+
 func decOpFor(typ reflect.Type) decOp {
        op, ok := decOpMap[typ.Kind()];
        if !ok {
@@ -309,6 +336,13 @@ func decOpFor(typ reflect.Type) decOp {
                                op = decUint8Array
                        }
                }
+               if typ.Kind() == reflect.StructKind {
+                       // Generate a closure that calls out to the engine for the nested type.
+                       engine := getDecEngine(typ);
+                       op = func(i *decInstr, state *DecState, p unsafe.Pointer) {
+                               state.err = engine.decodeStruct(state.r, uintptr(p))
+                       };
+               }
        }
        if op == nil {
                panicln("decode can't handle type", typ.String());
@@ -355,35 +389,6 @@ func getDecEngine(rt reflect.Type) *decEngine {
        return engine;
 }
 
-func (engine *decEngine) decode(r io.Reader, v reflect.Value) os.Error {
-       sv, ok := v.(reflect.StructValue);
-       if !ok {
-               panicln("decoder can't handle non-struct values yet");
-       }
-       state := new(DecState);
-       state.r = r;
-       state.base = uintptr(sv.Addr());
-       state.fieldnum = -1;
-       for state.err == nil {
-               delta := int(DecodeUint(state));
-               if state.err != nil || delta == 0 {     // struct terminator is zero delta fieldnum
-                       break
-               }
-               fieldnum := state.fieldnum + delta;
-               if fieldnum >= len(engine.instr) {
-                       panicln("TODO(r): need to handle unknown data");
-               }
-               instr := &engine.instr[fieldnum];
-               p := unsafe.Pointer(state.base+instr.offset);
-               if instr.indir > 1 {
-                       p = decIndirect(p, instr.indir);
-               }
-               instr.op(instr, state, p);
-               state.fieldnum = fieldnum;
-       }
-       return state.err
-}
-
 func Decode(r io.Reader, e interface{}) os.Error {
        // Dereference down to the underlying object.
        rt := reflect.Typeof(e);
@@ -396,8 +401,11 @@ func Decode(r io.Reader, e interface{}) os.Error {
                rt = pt.Sub();
                v = reflect.Indirect(v);
        }
+       if v.Kind() != reflect.StructKind {
+               return os.ErrorString("decode can't handle " + v.Type().String())
+       }
        typeLock.Lock();
        engine := getDecEngine(rt);
        typeLock.Unlock();
-       return engine.decode(r, v);
+       return engine.decodeStruct(r, uintptr(v.Addr()));
 }
index e046d6c83d693244baefaa9782b3855b46d3bf6e..a2ff8cbab23d5045e99a341b645067cf9ed008b2 100644 (file)
@@ -92,6 +92,7 @@ func encBool(i *encInstr, state *EncState, p unsafe.Pointer) {
        if b {
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeUint(state, 1);
+               state.fieldnum = i.field;
        }
 }
 
@@ -100,6 +101,7 @@ func encInt(i *encInstr, state *EncState, p unsafe.Pointer) {
        if v != 0 {
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeInt(state, v);
+               state.fieldnum = i.field;
        }
 }
 
@@ -108,6 +110,7 @@ func encUint(i *encInstr, state *EncState, p unsafe.Pointer) {
        if v != 0 {
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeUint(state, v);
+               state.fieldnum = i.field;
        }
 }
 
@@ -116,6 +119,7 @@ func encInt8(i *encInstr, state *EncState, p unsafe.Pointer) {
        if v != 0 {
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeInt(state, v);
+               state.fieldnum = i.field;
        }
 }
 
@@ -124,6 +128,7 @@ func encUint8(i *encInstr, state *EncState, p unsafe.Pointer) {
        if v != 0 {
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeUint(state, v);
+               state.fieldnum = i.field;
        }
 }
 
@@ -132,6 +137,7 @@ func encInt16(i *encInstr, state *EncState, p unsafe.Pointer) {
        if v != 0 {
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeInt(state, v);
+               state.fieldnum = i.field;
        }
 }
 
@@ -140,6 +146,7 @@ func encUint16(i *encInstr, state *EncState, p unsafe.Pointer) {
        if v != 0 {
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeUint(state, v);
+               state.fieldnum = i.field;
        }
 }
 
@@ -148,6 +155,7 @@ func encInt32(i *encInstr, state *EncState, p unsafe.Pointer) {
        if v != 0 {
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeInt(state, v);
+               state.fieldnum = i.field;
        }
 }
 
@@ -156,6 +164,7 @@ func encUint32(i *encInstr, state *EncState, p unsafe.Pointer) {
        if v != 0 {
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeUint(state, v);
+               state.fieldnum = i.field;
        }
 }
 
@@ -164,6 +173,7 @@ func encInt64(i *encInstr, state *EncState, p unsafe.Pointer) {
        if v != 0 {
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeInt(state, v);
+               state.fieldnum = i.field;
        }
 }
 
@@ -172,6 +182,7 @@ func encUint64(i *encInstr, state *EncState, p unsafe.Pointer) {
        if v != 0 {
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeUint(state, v);
+               state.fieldnum = i.field;
        }
 }
 
@@ -197,6 +208,7 @@ func encFloat(i *encInstr, state *EncState, p unsafe.Pointer) {
                v := floatBits(float64(f));
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeUint(state, v);
+               state.fieldnum = i.field;
        }
 }
 
@@ -206,6 +218,7 @@ func encFloat32(i *encInstr, state *EncState, p unsafe.Pointer) {
                v := floatBits(float64(f));
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeUint(state, v);
+               state.fieldnum = i.field;
        }
 }
 
@@ -215,6 +228,7 @@ func encFloat64(i *encInstr, state *EncState, p unsafe.Pointer) {
                v := floatBits(f);
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeUint(state, v);
+               state.fieldnum = i.field;
        }
 }
 
@@ -225,6 +239,7 @@ func encUint8Array(i *encInstr, state *EncState, p unsafe.Pointer) {
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeUint(state, uint64(len(b)));
                state.w.Write(b);
+               state.fieldnum = i.field;
        }
 }
 
@@ -235,6 +250,7 @@ func encString(i *encInstr, state *EncState, p unsafe.Pointer) {
                EncodeUint(state, uint64(i.field - state.fieldnum));
                EncodeUint(state, uint64(len(s)));
                io.WriteString(state.w, s);
+               state.fieldnum = i.field;
        }
 }
 
@@ -251,6 +267,28 @@ type encEngine struct {
        instr   []encInstr
 }
 
+func (engine *encEngine) encodeStruct(w io.Writer, p uintptr) os.Error {
+       state := new(EncState);
+       state.w = w;
+       state.base = p;
+       state.fieldnum = -1;
+       for i := 0; i < len(engine.instr); i++ {
+               instr := &engine.instr[i];
+               p := unsafe.Pointer(state.base+instr.offset);
+               if instr.indir > 0 {
+                       if p = encIndirect(p, instr.indir); p == nil {
+                               state.fieldnum = i;
+                               continue
+                       }
+               }
+               instr.op(instr, state, p);
+               if state.err != nil {
+                       break
+               }
+       }
+       return state.err
+}
+
 var encEngineMap = make(map[reflect.Type] *encEngine)
 var encOpMap = map[int] encOp {
         reflect.BoolKind: encBool,
@@ -270,6 +308,8 @@ var encOpMap = map[int] encOp {
         reflect.StringKind: encString,
 }
 
+func getEncEngine(rt reflect.Type) *encEngine
+
 func encOpFor(typ reflect.Type) encOp {
        op, ok := encOpMap[typ.Kind()];
        if !ok {
@@ -281,6 +321,15 @@ func encOpFor(typ reflect.Type) encOp {
                                op = encUint8Array
                        }
                }
+               if typ.Kind() == reflect.StructKind {
+                       // Generate a closure that calls out to the engine for the nested type.
+                       engine := getEncEngine(typ);
+                       op = func(i *encInstr, state *EncState, p unsafe.Pointer) {
+                               EncodeUint(state, uint64(i.field - state.fieldnum));
+                               state.err = engine.encodeStruct(state.w, uintptr(p));
+                               state.fieldnum = i.field;
+                       };
+               }
        }
        if op == nil {
                panicln("encode can't handle type", typ.String());
@@ -327,33 +376,6 @@ func getEncEngine(rt reflect.Type) *encEngine {
        return engine
 }
 
-func (engine *encEngine) encode(w io.Writer, v reflect.Value) os.Error {
-       sv, ok := v.(reflect.StructValue);
-       if !ok {
-               panicln("encoder can't handle non-struct values yet");
-       }
-       state := new(EncState);
-       state.w = w;
-       state.base = uintptr(sv.Addr());
-       state.fieldnum = -1;
-       for i := 0; i < len(engine.instr); i++ {
-               instr := &engine.instr[i];
-               p := unsafe.Pointer(state.base+instr.offset);
-               if instr.indir > 0 {
-                       if p = encIndirect(p, instr.indir); p == nil {
-                               state.fieldnum = i;
-                               continue
-                       }
-               }
-               instr.op(instr, state, p);
-               if state.err != nil {
-                       break
-               }
-               state.fieldnum = i;
-       }
-       return state.err
-}
-
 func Encode(w io.Writer, e interface{}) os.Error {
        // Dereference down to the underlying object.
        rt := reflect.Typeof(e);
@@ -366,8 +388,11 @@ func Encode(w io.Writer, e interface{}) os.Error {
                rt = pt.Sub();
                v = reflect.Indirect(v);
        }
+       if v.Kind() != reflect.StructKind {
+               return os.ErrorString("decode can't handle " + v.Type().String())
+       }
        typeLock.Lock();
        engine := getEncEngine(rt);
        typeLock.Unlock();
-       return engine.encode(w, v);
+       return engine.encodeStruct(w, uintptr(v.(reflect.StructValue).Addr()));
 }