]> Cypherpunks repositories - gostls13.git/commitdiff
gob: protect against pure recursive types.
authorRob Pike <r@golang.org>
Wed, 23 Feb 2011 17:49:35 +0000 (09:49 -0800)
committerRob Pike <r@golang.org>
Wed, 23 Feb 2011 17:49:35 +0000 (09:49 -0800)
There are further changes required for things like
recursive map types.  Recursive struct types work
but the mechanism needs generalization.  The
case handled in this CL is pathological since it
cannot be represented at all by gob, so it should
be handled separately. (Prior to this CL, encode
would recur forever.)

R=rsc
CC=golang-dev
https://golang.org/cl/4206041

src/pkg/gob/codec_test.go
src/pkg/gob/decode.go
src/pkg/gob/encoder.go
src/pkg/gob/type.go

index c09736221ed05a0b0c277ca6df7ce88243ed58f7..480d3df075c0e9b809451abf6efae259987b815f 100644 (file)
@@ -973,17 +973,31 @@ func TestIgnoredFields(t *testing.T) {
        }
 }
 
+
+func TestBadRecursiveType(t *testing.T) {
+       type Rec ***Rec
+       var rec Rec
+       b := new(bytes.Buffer)
+       err := NewEncoder(b).Encode(&rec)
+       if err == nil {
+               t.Error("expected error; got none")
+       } else if strings.Index(err.String(), "recursive") < 0 {
+               t.Error("expected recursive type error; got", err)
+       }
+       // Can't test decode easily because we can't encode one, so we can't pass one to a Decoder.
+}
+
 type Bad0 struct {
-       ch chan int
-       c  float64
+       CH chan int
+       C  float64
 }
 
-var nilEncoder *Encoder
 
 func TestInvalidField(t *testing.T) {
        var bad0 Bad0
-       bad0.ch = make(chan int)
+       bad0.CH = make(chan int)
        b := new(bytes.Buffer)
+       var nilEncoder *Encoder
        err := nilEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0)))
        if err == nil {
                t.Error("expected error; got none")
index d3f87144dac274e20501f3b3d39bef19a87e2cff..655a28bfe1f7dfdc0b2d532abe112782e519c407 100644 (file)
@@ -410,7 +410,6 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
 }
 
 func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) {
-       defer catchError(&err)
        p = allocate(ut.base, p, ut.indir)
        state := newDecodeState(dec, &dec.buf)
        state.fieldnum = singletonField
@@ -433,7 +432,6 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr)
 // This state cannot arise for decodeSingle, which is called directly
 // from the user's value, not from the innards of an engine.
 func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, indir int) (err os.Error) {
-       defer catchError(&err)
        p = allocate(ut.base.(*reflect.StructType), p, indir)
        state := newDecodeState(dec, &dec.buf)
        state.fieldnum = -1
@@ -463,7 +461,6 @@ func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr,
 }
 
 func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
-       defer catchError(&err)
        state := newDecodeState(dec, &dec.buf)
        state.fieldnum = -1
        for state.b.Len() > 0 {
@@ -486,7 +483,6 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
 }
 
 func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) {
-       defer catchError(&err)
        state := newDecodeState(dec, &dec.buf)
        state.fieldnum = singletonField
        delta := int(state.decodeUint())
@@ -937,7 +933,6 @@ func isExported(name string) bool {
 }
 
 func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
-       defer catchError(&err)
        srt, ok := rt.(*reflect.StructType)
        if !ok {
                return dec.compileSingle(remoteId, rt)
@@ -1026,7 +1021,8 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er
        return
 }
 
-func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) os.Error {
+func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) {
+       defer catchError(&err)
        // If the value is nil, it means we should just ignore this item.
        if val == nil {
                return dec.decodeIgnoredValue(wireId)
index 1419a278445f4fcad4fb50e7881ced90a58d171f..92d036c11c3578846dbbaaf6e637d890ee3676ee 100644 (file)
@@ -200,9 +200,12 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
        // Remove any nested writers remaining due to previous errors.
        enc.w = enc.w[0:1]
 
-       enc.err = nil
-       ut := userType(value.Type())
+       ut, err := validUserType(value.Type())
+       if err != nil {
+               return err
+       }
 
+       enc.err = nil
        state := newEncoderState(enc, new(bytes.Buffer))
 
        enc.sendTypeDescriptor(enc.writer(), state, ut)
@@ -212,7 +215,7 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
        }
 
        // Encode the object.
-       err := enc.encode(state.b, value, ut)
+       err = enc.encode(state.b, value, ut)
        if err != nil {
                enc.setError(err)
        } else {
index c9c116abf830acd8782d9fb01453ff5e87a5cbd0..3ed4cce924ff2b137056ee8a855c81ad2afc9924 100644 (file)
@@ -27,28 +27,63 @@ var (
        userTypeCache = make(map[reflect.Type]*userTypeInfo)
 )
 
-// userType returns, and saves, the information associated with user-provided type rt
-func userType(rt reflect.Type) *userTypeInfo {
+// validType returns, and saves, the information associated with user-provided type rt.
+// If the user type is not valid, err will be non-nil.  To be used when the error handler
+// is not set up.
+func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) {
        userTypeLock.RLock()
-       ut := userTypeCache[rt]
+       ut = userTypeCache[rt]
        userTypeLock.RUnlock()
        if ut != nil {
-               return ut
+               return
        }
        // Now set the value under the write lock.
        userTypeLock.Lock()
        defer userTypeLock.Unlock()
        if ut = userTypeCache[rt]; ut != nil {
                // Lost the race; not a problem.
-               return ut
+               return
        }
        ut = new(userTypeInfo)
+       ut.base = rt
        ut.user = rt
-       ut.base, ut.indir = indirect(rt)
+       // A type that is just a cycle of pointers (such as type T *T) cannot
+       // be represented in gobs, which need some concrete data.  We use a
+       // cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6,
+       // pp 539-540.  As we step through indirections, run another type at
+       // half speed. If they meet up, there's a cycle.
+       // TODO: still need to deal with self-referential non-structs such
+       // as type T map[string]T but that is a larger undertaking - and can
+       // be useful, not always erroneous.
+       slowpoke := ut.base // walks half as fast as ut.base
+       for {
+               pt, ok := ut.base.(*reflect.PtrType)
+               if !ok {
+                       break
+               }
+               ut.base = pt.Elem()
+               if ut.base == slowpoke { // ut.base lapped slowpoke
+                       // recursive pointer type.
+                       return nil, os.ErrorString("can't represent recursive pointer type " + ut.base.String())
+               }
+               if ut.indir%2 == 0 {
+                       slowpoke = slowpoke.(*reflect.PtrType).Elem()
+               }
+               ut.indir++
+       }
        userTypeCache[rt] = ut
-       return ut
+       return
 }
 
+// userType returns, and saves, the information associated with user-provided type rt.
+// If the user type is not valid, it calls error.
+func userType(rt reflect.Type) *userTypeInfo {
+       ut, err := validUserType(rt)
+       if err != nil {
+               error(err)
+       }
+       return ut
+}
 // A typeId represents a gob Type as an integer that can be passed on the wire.
 // Internally, typeIds are used as keys to a map to recover the underlying type info.
 type typeId int32
@@ -273,21 +308,6 @@ func newStructType(name string) *structType {
        return s
 }
 
-// Step through the indirections on a type to discover the base type.
-// Return the base type and the number of indirections.
-func indirect(t reflect.Type) (rt reflect.Type, count int) {
-       rt = t
-       for {
-               pt, ok := rt.(*reflect.PtrType)
-               if !ok {
-                       break
-               }
-               rt = pt.Elem()
-               count++
-       }
-       return
-}
-
 func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
        switch t := rt.(type) {
        // All basic types are easy: they are predefined.