]> Cypherpunks repositories - gostls13.git/commitdiff
handle unsupported types safely.
authorRob Pike <r@golang.org>
Thu, 30 Jul 2009 00:24:25 +0000 (17:24 -0700)
committerRob Pike <r@golang.org>
Thu, 30 Jul 2009 00:24:25 +0000 (17:24 -0700)
R=rsc
DELTA=154  (71 added, 6 deleted, 77 changed)
OCL=32483
CL=32492

src/pkg/gob/codec_test.go
src/pkg/gob/encode.go
src/pkg/gob/encoder.go
src/pkg/gob/encoder_test.go
src/pkg/gob/type.go
src/pkg/gob/type_test.go

index 8263a9286caf4d34307594ee667bb1def91f1f94..9ad3b4727867fc48710750b1eaebae55a1ee2bc4 100644 (file)
@@ -564,7 +564,7 @@ func TestEndToEnd(t *testing.T) {
        b := new(bytes.Buffer);
        encode(b, t1);
        var _t1 T1;
-       decode(b, getTypeInfo(reflect.Typeof(_t1)).id, &_t1);
+       decode(b, getTypeInfoNoError(reflect.Typeof(_t1)).id, &_t1);
        if !reflect.DeepEqual(t1, &_t1) {
                t.Errorf("encode expected %v got %v", *t1, _t1);
        }
@@ -580,7 +580,7 @@ func TestOverflow(t *testing.T) {
        }
        var it inputT;
        var err os.Error;
-       id := getTypeInfo(reflect.Typeof(it)).id;
+       id := getTypeInfoNoError(reflect.Typeof(it)).id;
        b := new(bytes.Buffer);
 
        // int8
@@ -733,7 +733,7 @@ func TestNesting(t *testing.T) {
        b := new(bytes.Buffer);
        encode(b, rt);
        var drt RT;
-       decode(b, getTypeInfo(reflect.Typeof(drt)).id, &drt);
+       decode(b, getTypeInfoNoError(reflect.Typeof(drt)).id, &drt);
        if drt.a != rt.a {
                t.Errorf("nesting: encode expected %v got %v", *rt, drt);
        }
@@ -775,7 +775,7 @@ func TestAutoIndirection(t *testing.T) {
        b := new(bytes.Buffer);
        encode(b, t1);
        var t0 T0;
-       t0Id := getTypeInfo(reflect.Typeof(t0)).id;
+       t0Id := getTypeInfoNoError(reflect.Typeof(t0)).id;
        decode(b, t0Id, &t0);
        if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 {
                t.Errorf("t1->t0: expected {17 177 1777 17777}; got %v", t0);
@@ -800,7 +800,7 @@ func TestAutoIndirection(t *testing.T) {
        b.Reset();
        encode(b, t0);
        t1 = T1{};
-       t1Id := getTypeInfo(reflect.Typeof(t1)).id;
+       t1Id := getTypeInfoNoError(reflect.Typeof(t1)).id;
        decode(b, t1Id, &t1);
        if t1.a != 17 || *t1.b != 177 || **t1.c != 1777 || ***t1.d != 17777 {
                t.Errorf("t0->t1 expected {17 177 1777 17777}; got {%d %d %d %d}", t1.a, *t1.b, **t1.c, ***t1.d);
@@ -810,7 +810,7 @@ func TestAutoIndirection(t *testing.T) {
        b.Reset();
        encode(b, t0);
        t2 = T2{};
-       t2Id := getTypeInfo(reflect.Typeof(t2)).id;
+       t2Id := getTypeInfoNoError(reflect.Typeof(t2)).id;
        decode(b, t2Id, &t2);
        if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 {
                t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d);
@@ -848,7 +848,7 @@ func TestReorderedFields(t *testing.T) {
        rt0.c = 3.14159;
        b := new(bytes.Buffer);
        encode(b, rt0);
-       rt0Id := getTypeInfo(reflect.Typeof(rt0)).id;
+       rt0Id := getTypeInfoNoError(reflect.Typeof(rt0)).id;
        var rt1 RT1;
        // Wire type is RT0, local type is RT1.
        decode(b, rt0Id, &rt1);
@@ -886,7 +886,7 @@ func TestIgnoredFields(t *testing.T) {
 
        b := new(bytes.Buffer);
        encode(b, it0);
-       rt0Id := getTypeInfo(reflect.Typeof(it0)).id;
+       rt0Id := getTypeInfoNoError(reflect.Typeof(it0)).id;
        var rt1 RT1;
        // Wire type is IT0, local type is RT1.
        err := decode(b, rt0Id, &rt1);
@@ -897,3 +897,20 @@ func TestIgnoredFields(t *testing.T) {
                t.Errorf("rt1->rt0: expected %v; got %v", it0, rt1);
        }
 }
+
+type Bad0 struct {
+       inter interface{};
+       c float;
+}
+
+func TestInvalidField(t *testing.T) {
+       var bad0 Bad0;
+       bad0.inter = 17;
+       b := new(bytes.Buffer);
+       err := encode(b, &bad0);
+       if err == nil {
+               t.Error("expected error; got none")
+       } else if strings.Index(err.String(), "interface") < 0 {
+               t.Error("expected type error; got", err)
+       }
+}
index be3599770f3fceaac7af002eca4a73c027bd1aa8..1af79b1fdbb97e6b202c475fde93f77d90fc762c 100644 (file)
@@ -335,11 +335,11 @@ var encOpMap = map[reflect.Type] encOp {
        valueKind("x"): encString,
 }
 
-func getEncEngine(rt reflect.Type) *encEngine
+func getEncEngine(rt reflect.Type) (*encEngine, os.Error)
 
 // Return the encoding op for the base type under rt and
 // the indirection count to reach it.
-func encOpFor(rt reflect.Type) (encOp, int) {
+func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
        typ, indir := indirect(rt);
        op, ok := encOpMap[reflect.Typeof(typ)];
        if !ok {
@@ -352,7 +352,10 @@ func encOpFor(rt reflect.Type) (encOp, int) {
                                break;
                        }
                        // Slices have a header; we decode it to find the underlying array.
-                       elemOp, indir := encOpFor(t.Elem());
+                       elemOp, indir, err := encOpFor(t.Elem());
+                       if err != nil {
+                               return nil, 0, err
+                       }
                        op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
                                slice := (*reflect.SliceHeader)(p);
                                if slice.Len == 0 {
@@ -363,15 +366,21 @@ func encOpFor(rt reflect.Type) (encOp, int) {
                        };
                case *reflect.ArrayType:
                        // True arrays have size in the type.
-                       elemOp, indir := encOpFor(t.Elem());
+                       elemOp, indir, err := encOpFor(t.Elem());
+                       if err != nil {
+                               return nil, 0, err
+                       }
                        op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
                                state.update(i);
                                state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), t.Len(), indir);
                        };
                case *reflect.StructType:
                        // Generate a closure that calls out to the engine for the nested type.
-                       engine := getEncEngine(typ);
-                       info := getTypeInfo(typ);
+                       engine, err := getEncEngine(typ);
+                       if err != nil {
+                               return nil, 0, err
+                       }
+                       info := getTypeInfoNoError(typ);
                        op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
                                state.update(i);
                                // indirect through info to delay evaluation for recursive structs
@@ -380,13 +389,13 @@ func encOpFor(rt reflect.Type) (encOp, int) {
                }
        }
        if op == nil {
-               panicln("can't happen: encode type", rt.String());
+               return op, indir, os.ErrorString("gob enc: can't happen: encode type" + rt.String());
        }
-       return op, indir
+       return op, indir, nil
 }
 
 // The local Type was compiled from the actual value, so we know it's compatible.
-func compileEnc(rt reflect.Type) *encEngine {
+func compileEnc(rt reflect.Type) (*encEngine, os.Error) {
        srt, ok := rt.(*reflect.StructType);
        if !ok {
                panicln("can't happen: non-struct");
@@ -395,23 +404,29 @@ func compileEnc(rt reflect.Type) *encEngine {
        engine.instr = make([]encInstr, srt.NumField()+1);      // +1 for terminator
        for fieldnum := 0; fieldnum < srt.NumField(); fieldnum++ {
                f := srt.Field(fieldnum);
-               op, indir := encOpFor(f.Type);
+               op, indir, err := encOpFor(f.Type);
+               if err != nil {
+                       return nil, err
+               }
                engine.instr[fieldnum] = encInstr{op, fieldnum, indir, uintptr(f.Offset)};
        }
        engine.instr[srt.NumField()] = encInstr{encStructTerminator, 0, 0, 0};
-       return engine;
+       return engine, nil;
 }
 
 // typeLock must be held (or we're in initialization and guaranteed single-threaded).
 // The reflection type must have all its indirections processed out.
-func getEncEngine(rt reflect.Type) *encEngine {
-       info := getTypeInfo(rt);
+func getEncEngine(rt reflect.Type) (*encEngine, os.Error) {
+       info, err := getTypeInfo(rt);
+       if err != nil {
+               return nil, err
+       }
        if info.encoder == nil {
                // mark this engine as underway before compiling to handle recursive types.
                info.encoder = new(encEngine);
-               info.encoder = compileEnc(rt);
+               info.encoder, err = compileEnc(rt);
        }
-       return info.encoder;
+       return info.encoder, err;
 }
 
 func encode(b *bytes.Buffer, e interface{}) os.Error {
@@ -425,7 +440,10 @@ func encode(b *bytes.Buffer, e interface{}) os.Error {
                return os.ErrorString("gob: encode can't handle " + v.Type().String())
        }
        typeLock.Lock();
-       engine := getEncEngine(rt);
+       engine, err := getEncEngine(rt);
        typeLock.Unlock();
+       if err != nil {
+               return err
+       }
        return encodeStruct(engine, b, v.Addr());
 }
index 3d8f3928cbd85d29c07b3c6ebbd71efba661bc6e..e3ee8b3cb12cf8013230d13b24ab402a389ac543 100644 (file)
@@ -73,7 +73,7 @@
 
        Maps are not supported yet, but they will be.  Interfaces, functions, and channels
        cannot be sent in a gob.  Attempting to encode a value that contains one will
-       fail.  (TODO(r): fix this - it panics now.)
+       fail.
 
        The rest of this comment documents the encoding, details that are not important
        for most users.  Details are presented bottom-up.
@@ -267,8 +267,12 @@ func (enc *Encoder) sendType(origt reflect.Type) {
 
        // Need to send it.
        typeLock.Lock();
-       info := getTypeInfo(rt);
+       info, err := getTypeInfo(rt);
        typeLock.Unlock();
+       if err != nil {
+               enc.state.err = err;
+               return;
+       }
        // Send the pair (-id, type)
        // Id:
        encodeInt(enc.state, -int64(info.id));
@@ -285,6 +289,7 @@ func (enc *Encoder) sendType(origt reflect.Type) {
        for i := 0; i < st.NumField(); i++ {
                enc.sendType(st.Field(i).Type);
        }
+       return
 }
 
 // Encode transmits the data item represented by the empty interface value,
index e1c64e4dd70e034f08f761808ffe2037baff4667..6033fbb3fef6a77b5d20f66b84974b8f2a4c6659 100644 (file)
@@ -69,7 +69,7 @@ func TestBasicEncoder(t *testing.T) {
        if err != nil {
                t.Fatal("error decoding ET1 type:", err);
        }
-       info := getTypeInfo(reflect.Typeof(ET1{}));
+       info := getTypeInfoNoError(reflect.Typeof(ET1{}));
        trueWire1 := &wireType{s: info.id.gobType().(*structType)};
        if !reflect.DeepEqual(wire1, trueWire1) {
                t.Fatalf("invalid wireType for ET1: expected %+v; got %+v\n", *trueWire1, *wire1);
@@ -90,7 +90,7 @@ func TestBasicEncoder(t *testing.T) {
        if err != nil {
                t.Fatal("error decoding ET2 type:", err);
        }
-       info = getTypeInfo(reflect.Typeof(ET2{}));
+       info = getTypeInfoNoError(reflect.Typeof(ET2{}));
        trueWire2 := &wireType{s: info.id.gobType().(*structType)};
        if !reflect.DeepEqual(wire2, trueWire2) {
                t.Fatalf("invalid wireType for ET2: expected %+v; got %+v\n", *trueWire2, *wire2);
@@ -107,7 +107,7 @@ func TestBasicEncoder(t *testing.T) {
        }
        // 8) The value of et1
        newEt1 := new(ET1);
-       et1Id := getTypeInfo(reflect.Typeof(*newEt1)).id;
+       et1Id := getTypeInfoNoError(reflect.Typeof(*newEt1)).id;
        err = decode(b, et1Id, newEt1);
        if err != nil {
                t.Fatal("error decoding ET1 value:", err);
index f54746c323682ae81b1f85bdf4e1b1b125a4a572..585d4f7472dd227105ca28a2b7691f8aa3dd1f92 100644 (file)
@@ -202,7 +202,7 @@ func newStructType(name string) *structType {
 }
 
 // Construction
-func newType(name string, rt reflect.Type) gobType
+func getType(name string, rt reflect.Type) (gobType, os.Error)
 
 // Step through the indirections on a type to discover the base type.
 // Return the number of indirections.
@@ -219,55 +219,63 @@ func indirect(t reflect.Type) (rt reflect.Type, count int) {
        return;
 }
 
-func newTypeObject(name string, rt reflect.Type) gobType {
+func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
        switch t := rt.(type) {
        // All basic types are easy: they are predefined.
        case *reflect.BoolType:
-               return tBool.gobType()
+               return tBool.gobType(), nil
 
        case *reflect.IntType:
-               return tInt.gobType()
+               return tInt.gobType(), nil
        case *reflect.Int8Type:
-               return tInt.gobType()
+               return tInt.gobType(), nil
        case *reflect.Int16Type:
-               return tInt.gobType()
+               return tInt.gobType(), nil
        case *reflect.Int32Type:
-               return tInt.gobType()
+               return tInt.gobType(), nil
        case *reflect.Int64Type:
-               return tInt.gobType()
+               return tInt.gobType(), nil
 
        case *reflect.UintType:
-               return tUint.gobType()
+               return tUint.gobType(), nil
        case *reflect.Uint8Type:
-               return tUint.gobType()
+               return tUint.gobType(), nil
        case *reflect.Uint16Type:
-               return tUint.gobType()
+               return tUint.gobType(), nil
        case *reflect.Uint32Type:
-               return tUint.gobType()
+               return tUint.gobType(), nil
        case *reflect.Uint64Type:
-               return tUint.gobType()
+               return tUint.gobType(), nil
        case *reflect.UintptrType:
-               return tUint.gobType()
+               return tUint.gobType(), nil
 
        case *reflect.FloatType:
-               return tFloat.gobType()
+               return tFloat.gobType(), nil
        case *reflect.Float32Type:
-               return tFloat.gobType()
+               return tFloat.gobType(), nil
        case *reflect.Float64Type:
-               return tFloat.gobType()
+               return tFloat.gobType(), nil
 
        case *reflect.StringType:
-               return tString.gobType()
+               return tString.gobType(), nil
 
        case *reflect.ArrayType:
-               return newArrayType(name, newType("", t.Elem()), t.Len());
+               gt, err := getType("", t.Elem());
+               if err != nil {
+                       return nil, err
+               }
+               return newArrayType(name, gt, t.Len()), nil;
 
        case *reflect.SliceType:
                // []byte == []uint8 is a special case
                if _, ok := t.Elem().(*reflect.Uint8Type); ok {
-                       return tBytes.gobType()
+                       return tBytes.gobType(), nil
+               }
+               gt, err := getType(t.Elem().Name(), t.Elem());
+               if err != nil {
+                       return nil, err
                }
-               return newSliceType(name, newType(t.Elem().Name(), t.Elem()));
+               return newSliceType(name, gt), nil;
 
        case *reflect.StructType:
                // Install the struct type itself before the fields so recursive
@@ -283,18 +291,24 @@ func newTypeObject(name string, rt reflect.Type) gobType {
                        if tname == "" {
                                tname = f.Type.String();
                        }
-                       field[i] =  &fieldType{ f.Name, newType(tname, f.Type).id() };
+                       gt, err := getType(tname, f.Type);
+                       if err != nil {
+                               return nil, err
+                       }
+                       field[i] =  &fieldType{ f.Name, gt.id() };
                }
                strType.field = field;
-               return strType;
+               return strType, nil;
 
        default:
-               panicln("gob NewTypeObject can't handle type", rt.String());    // TODO(r): panic?
+               return nil, os.ErrorString("gob NewTypeObject can't handle type: " + rt.String());
        }
-       return nil
+       return nil, nil
 }
 
-func newType(name string, rt reflect.Type) gobType {
+// getType returns the Gob type describing the given reflect.Type.
+// typeLock must be held.
+func getType(name string, rt reflect.Type) (gobType, os.Error) {
        // Flatten the data structure by collapsing out pointers
        for {
                pt, ok := rt.(*reflect.PtrType);
@@ -305,19 +319,13 @@ func newType(name string, rt reflect.Type) gobType {
        }
        typ, present := types[rt];
        if present {
-               return typ
+               return typ, nil
        }
-       typ = newTypeObject(name, rt);
-       types[rt] = typ;
-       return typ
-}
-
-// getType returns the Gob type describing the given reflect.Type.
-// typeLock must be held.
-func getType(name string, rt reflect.Type) gobType {
-       // Set lock; all code running under here is synchronized.
-       t := newType(name, rt);
-       return t;
+       typ, err := newTypeObject(name, rt);
+       if err == nil {
+               types[rt] = typ
+       }
+       return typ, err
 }
 
 func checkId(want, got typeId) {
@@ -371,7 +379,7 @@ var typeInfoMap = make(map[reflect.Type] *typeInfo) // protected by typeLock
 
 // The reflection type must have all its indirections processed out.
 // typeLock must be held.
-func getTypeInfo(rt reflect.Type) *typeInfo {
+func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) {
        if pt, ok := rt.(*reflect.PtrType); ok {
                panicln("pointer type in getTypeInfo:", rt.String())
        }
@@ -379,12 +387,25 @@ func getTypeInfo(rt reflect.Type) *typeInfo {
        if !ok {
                info = new(typeInfo);
                name := rt.Name();
-               info.id = getType(name, rt).id();
+               gt, err := getType(name, rt);
+               if err != nil {
+                       return nil, err
+               }
+               info.id = gt.id();
                // assume it's a struct type
                info.wire = &wireType{info.id.gobType().(*structType)};
                typeInfoMap[rt] = info;
        }
-       return info;
+       return info, nil;
+}
+
+// Called only when a panic is acceptable and unexpected.
+func getTypeInfoNoError(rt reflect.Type) *typeInfo {
+       t, err := getTypeInfo(rt);
+       if err != nil {
+               panicln("getTypeInfo:", err.String());
+       }
+       return t
 }
 
 func init() {
@@ -396,9 +417,9 @@ func init() {
        // The string for tBytes is "bytes" not "[]byte" to signify its specialness.
        tBytes = bootstrapType("bytes", make([]byte, 0), 5);
        tString= bootstrapType("string", "", 6);
-       tWireType = getTypeInfo(reflect.Typeof(wireType{})).id;
+       tWireType = getTypeInfoNoError(reflect.Typeof(wireType{})).id;
        checkId(7, tWireType);
-       checkId(8, getTypeInfo(reflect.Typeof(structType{})).id);
-       checkId(9, getTypeInfo(reflect.Typeof(commonType{})).id);
-       checkId(10, getTypeInfo(reflect.Typeof(fieldType{})).id);
+       checkId(8, getTypeInfoNoError(reflect.Typeof(structType{})).id);
+       checkId(9, getTypeInfoNoError(reflect.Typeof(commonType{})).id);
+       checkId(10, getTypeInfoNoError(reflect.Typeof(fieldType{})).id);
 }
index 2f11ba3fea3518b0d8b48525d0fa7aab412a8400..b86fffa21aab24e8c24cfe8e6d2e33094091517c 100644 (file)
@@ -27,7 +27,11 @@ var basicTypes = []typeT {
 func getTypeUnlocked(name string, rt reflect.Type) gobType {
        typeLock.Lock();
        defer typeLock.Unlock();
-       return getType(name, rt);
+       t, err := getType(name, rt);
+       if err != nil {
+               panicln("getTypeUnlocked:", err.String())
+       }
+       return t;
 }
 
 // Sanity checks