From: Rob Pike Date: Wed, 23 Feb 2011 17:49:35 +0000 (-0800) Subject: gob: protect against pure recursive types. X-Git-Tag: weekly.2011-02-24~24 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=c9b90c9d70da05269a0de0e528c13b6b76299846;p=gostls13.git gob: protect against pure recursive types. There are further changes required for things like recursive map types. Recursive struct types work but the mechanism needs generalization. The case handled in this CL is pathological since it cannot be represented at all by gob, so it should be handled separately. (Prior to this CL, encode would recur forever.) R=rsc CC=golang-dev https://golang.org/cl/4206041 --- diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go index c09736221e..480d3df075 100644 --- a/src/pkg/gob/codec_test.go +++ b/src/pkg/gob/codec_test.go @@ -973,17 +973,31 @@ func TestIgnoredFields(t *testing.T) { } } + +func TestBadRecursiveType(t *testing.T) { + type Rec ***Rec + var rec Rec + b := new(bytes.Buffer) + err := NewEncoder(b).Encode(&rec) + if err == nil { + t.Error("expected error; got none") + } else if strings.Index(err.String(), "recursive") < 0 { + t.Error("expected recursive type error; got", err) + } + // Can't test decode easily because we can't encode one, so we can't pass one to a Decoder. +} + type Bad0 struct { - ch chan int - c float64 + CH chan int + C float64 } -var nilEncoder *Encoder func TestInvalidField(t *testing.T) { var bad0 Bad0 - bad0.ch = make(chan int) + bad0.CH = make(chan int) b := new(bytes.Buffer) + var nilEncoder *Encoder err := nilEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0))) if err == nil { t.Error("expected error; got none") diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index d3f87144da..655a28bfe1 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -410,7 +410,6 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr { } func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) { - defer catchError(&err) p = allocate(ut.base, p, ut.indir) state := newDecodeState(dec, &dec.buf) state.fieldnum = singletonField @@ -433,7 +432,6 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) // This state cannot arise for decodeSingle, which is called directly // from the user's value, not from the innards of an engine. func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, indir int) (err os.Error) { - defer catchError(&err) p = allocate(ut.base.(*reflect.StructType), p, indir) state := newDecodeState(dec, &dec.buf) state.fieldnum = -1 @@ -463,7 +461,6 @@ func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, } func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) { - defer catchError(&err) state := newDecodeState(dec, &dec.buf) state.fieldnum = -1 for state.b.Len() > 0 { @@ -486,7 +483,6 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) { } func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) { - defer catchError(&err) state := newDecodeState(dec, &dec.buf) state.fieldnum = singletonField delta := int(state.decodeUint()) @@ -937,7 +933,6 @@ func isExported(name string) bool { } func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { - defer catchError(&err) srt, ok := rt.(*reflect.StructType) if !ok { return dec.compileSingle(remoteId, rt) @@ -1026,7 +1021,8 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er return } -func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) os.Error { +func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) { + defer catchError(&err) // If the value is nil, it means we should just ignore this item. if val == nil { return dec.decodeIgnoredValue(wireId) diff --git a/src/pkg/gob/encoder.go b/src/pkg/gob/encoder.go index 1419a27844..92d036c11c 100644 --- a/src/pkg/gob/encoder.go +++ b/src/pkg/gob/encoder.go @@ -200,9 +200,12 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { // Remove any nested writers remaining due to previous errors. enc.w = enc.w[0:1] - enc.err = nil - ut := userType(value.Type()) + ut, err := validUserType(value.Type()) + if err != nil { + return err + } + enc.err = nil state := newEncoderState(enc, new(bytes.Buffer)) enc.sendTypeDescriptor(enc.writer(), state, ut) @@ -212,7 +215,7 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { } // Encode the object. - err := enc.encode(state.b, value, ut) + err = enc.encode(state.b, value, ut) if err != nil { enc.setError(err) } else { diff --git a/src/pkg/gob/type.go b/src/pkg/gob/type.go index c9c116abf8..3ed4cce924 100644 --- a/src/pkg/gob/type.go +++ b/src/pkg/gob/type.go @@ -27,28 +27,63 @@ var ( userTypeCache = make(map[reflect.Type]*userTypeInfo) ) -// userType returns, and saves, the information associated with user-provided type rt -func userType(rt reflect.Type) *userTypeInfo { +// validType returns, and saves, the information associated with user-provided type rt. +// If the user type is not valid, err will be non-nil. To be used when the error handler +// is not set up. +func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) { userTypeLock.RLock() - ut := userTypeCache[rt] + ut = userTypeCache[rt] userTypeLock.RUnlock() if ut != nil { - return ut + return } // Now set the value under the write lock. userTypeLock.Lock() defer userTypeLock.Unlock() if ut = userTypeCache[rt]; ut != nil { // Lost the race; not a problem. - return ut + return } ut = new(userTypeInfo) + ut.base = rt ut.user = rt - ut.base, ut.indir = indirect(rt) + // A type that is just a cycle of pointers (such as type T *T) cannot + // be represented in gobs, which need some concrete data. We use a + // cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6, + // pp 539-540. As we step through indirections, run another type at + // half speed. If they meet up, there's a cycle. + // TODO: still need to deal with self-referential non-structs such + // as type T map[string]T but that is a larger undertaking - and can + // be useful, not always erroneous. + slowpoke := ut.base // walks half as fast as ut.base + for { + pt, ok := ut.base.(*reflect.PtrType) + if !ok { + break + } + ut.base = pt.Elem() + if ut.base == slowpoke { // ut.base lapped slowpoke + // recursive pointer type. + return nil, os.ErrorString("can't represent recursive pointer type " + ut.base.String()) + } + if ut.indir%2 == 0 { + slowpoke = slowpoke.(*reflect.PtrType).Elem() + } + ut.indir++ + } userTypeCache[rt] = ut - return ut + return } +// userType returns, and saves, the information associated with user-provided type rt. +// If the user type is not valid, it calls error. +func userType(rt reflect.Type) *userTypeInfo { + ut, err := validUserType(rt) + if err != nil { + error(err) + } + return ut +} // A typeId represents a gob Type as an integer that can be passed on the wire. // Internally, typeIds are used as keys to a map to recover the underlying type info. type typeId int32 @@ -273,21 +308,6 @@ func newStructType(name string) *structType { return s } -// Step through the indirections on a type to discover the base type. -// Return the base type and the number of indirections. -func indirect(t reflect.Type) (rt reflect.Type, count int) { - rt = t - for { - pt, ok := rt.(*reflect.PtrType) - if !ok { - break - } - rt = pt.Elem() - count++ - } - return -} - func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { switch t := rt.(type) { // All basic types are easy: they are predefined.