From c54b5d032f593c5379e2286da11fb78fe33fd7a3 Mon Sep 17 00:00:00 2001 From: Rob Pike Date: Fri, 25 Feb 2011 09:45:06 -0800 Subject: [PATCH] gob: make recursive map and slice types work. Before this fix, types such as type T map[string]T caused infinite recursion in the gob implementation. Now they just work. Fixes #1518. R=rsc CC=golang-dev https://golang.org/cl/4230045 --- src/pkg/gob/codec_test.go | 10 ++-- src/pkg/gob/decode.go | 63 ++++++++++++++---------- src/pkg/gob/encode.go | 43 +++++++++------- src/pkg/gob/encoder_test.go | 18 +++++++ src/pkg/gob/type.go | 98 +++++++++++++++++++++++++++---------- 5 files changed, 156 insertions(+), 76 deletions(-) diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go index 480d3df075..c822d6863a 100644 --- a/src/pkg/gob/codec_test.go +++ b/src/pkg/gob/codec_test.go @@ -342,7 +342,7 @@ func TestScalarDecInstructions(t *testing.T) { var data struct { a int } - instr := &decInstr{decOpMap[reflect.Int], 6, 0, 0, ovfl} + instr := &decInstr{decOpTable[reflect.Int], 6, 0, 0, ovfl} state := newDecodeStateFromData(signedResult) execDec("int", instr, state, t, unsafe.Pointer(&data)) if data.a != 17 { @@ -355,7 +355,7 @@ func TestScalarDecInstructions(t *testing.T) { var data struct { a uint } - instr := &decInstr{decOpMap[reflect.Uint], 6, 0, 0, ovfl} + instr := &decInstr{decOpTable[reflect.Uint], 6, 0, 0, ovfl} state := newDecodeStateFromData(unsignedResult) execDec("uint", instr, state, t, unsafe.Pointer(&data)) if data.a != 17 { @@ -446,7 +446,7 @@ func TestScalarDecInstructions(t *testing.T) { var data struct { a uintptr } - instr := &decInstr{decOpMap[reflect.Uintptr], 6, 0, 0, ovfl} + instr := &decInstr{decOpTable[reflect.Uintptr], 6, 0, 0, ovfl} state := newDecodeStateFromData(unsignedResult) execDec("uintptr", instr, state, t, unsafe.Pointer(&data)) if data.a != 17 { @@ -511,7 +511,7 @@ func TestScalarDecInstructions(t *testing.T) { var data struct { a complex64 } - instr := &decInstr{decOpMap[reflect.Complex64], 6, 0, 0, ovfl} + instr := &decInstr{decOpTable[reflect.Complex64], 6, 0, 0, ovfl} state := newDecodeStateFromData(complexResult) execDec("complex", instr, state, t, unsafe.Pointer(&data)) if data.a != 17+19i { @@ -524,7 +524,7 @@ func TestScalarDecInstructions(t *testing.T) { var data struct { a complex128 } - instr := &decInstr{decOpMap[reflect.Complex128], 6, 0, 0, ovfl} + instr := &decInstr{decOpTable[reflect.Complex128], 6, 0, 0, ovfl} state := newDecodeStateFromData(complexResult) execDec("complex", instr, state, t, unsafe.Pointer(&data)) if data.a != 17+19i { diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index 655a28bfe1..ad03d176b8 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -671,7 +671,7 @@ func (dec *Decoder) ignoreInterface(state *decodeState) { } // Index by Go types. -var decOpMap = []decOp{ +var decOpTable = [...]decOp{ reflect.Bool: decBool, reflect.Int8: decInt8, reflect.Int16: decInt16, @@ -701,37 +701,43 @@ var decIgnoreOpMap = map[typeId]decOp{ // Return the decoding op for the base type under rt and // the indirection count to reach it. -func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int) { +func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) { ut := userType(rt) + // If this type is already in progress, it's a recursive type (e.g. map[string]*T). + // Return the pointer to the op we're already building. + if opPtr := inProgress[rt]; opPtr != nil { + return opPtr, ut.indir + } typ := ut.base indir := ut.indir var op decOp k := typ.Kind() - if int(k) < len(decOpMap) { - op = decOpMap[k] + if int(k) < len(decOpTable) { + op = decOpTable[k] } if op == nil { + inProgress[rt] = &op // Special cases switch t := typ.(type) { case *reflect.ArrayType: name = "element of " + name elemId := dec.wireType[wireId].ArrayT.Elem - elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name) + elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.dec.decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) + state.dec.decodeArray(t, state, uintptr(p), *elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) } case *reflect.MapType: name = "element of " + name keyId := dec.wireType[wireId].MapT.Key elemId := dec.wireType[wireId].MapT.Elem - keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name) - elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name) + keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name, inProgress) + elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { up := unsafe.Pointer(p) - state.dec.decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl) + state.dec.decodeMap(t, state, uintptr(up), *keyOp, *elemOp, i.indir, keyIndir, elemIndir, ovfl) } case *reflect.SliceType: @@ -746,10 +752,10 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp } else { elemId = dec.wireType[wireId].SliceT.Elem } - elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name) + elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.dec.decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) + state.dec.decodeSlice(t, state, uintptr(p), *elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) } case *reflect.StructType: @@ -774,7 +780,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp if op == nil { errorf("gob: decode can't handle type %s", rt.String()) } - return op, indir + return &op, indir } // Return the decoding op for a field that has no destination. @@ -838,11 +844,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { // Are these two gob Types compatible? // Answers the question for basic types, arrays, and slices. // Structs are considered ok; fields will be checked later. -func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { +func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool { + if rhs, ok := inProgress[fr]; ok { + return rhs == fw + } + inProgress[fr] = fw fr = userType(fr).base switch t := fr.(type) { default: - // map, chan, etc: cannot handle. + // chan, etc: cannot handle. return false case *reflect.BoolType: return fw == tBool @@ -864,14 +874,14 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { return false } array := wire.ArrayT - return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem) + return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress) case *reflect.MapType: wire, ok := dec.wireType[fw] if !ok || wire.MapT == nil { return false } MapType := wire.MapT - return dec.compatibleType(t.Key(), MapType.Key) && dec.compatibleType(t.Elem(), MapType.Elem) + return dec.compatibleType(t.Key(), MapType.Key, inProgress) && dec.compatibleType(t.Elem(), MapType.Elem, inProgress) case *reflect.SliceType: // Is it an array of bytes? if t.Elem().Kind() == reflect.Uint8 { @@ -885,7 +895,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { sw = dec.wireType[fw].SliceT } elem := userType(t.Elem()).base - return sw != nil && dec.compatibleType(elem, sw.Elem) + return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress) case *reflect.StructType: return true } @@ -906,12 +916,12 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec engine = new(decEngine) engine.instr = make([]decInstr, 1) // one item name := rt.String() // best we can do - if !dec.compatibleType(rt, remoteId) { + if !dec.compatibleType(rt, remoteId, make(map[reflect.Type]typeId)) { return nil, os.ErrorString("gob: wrong type received for local value " + name + ": " + dec.typeString(remoteId)) } - op, indir := dec.decOpFor(remoteId, rt, name) + op, indir := dec.decOpFor(remoteId, rt, name, make(map[reflect.Type]*decOp)) ovfl := os.ErrorString(`value for "` + name + `" out of range`) - engine.instr[singletonField] = decInstr{op, singletonField, indir, 0, ovfl} + engine.instr[singletonField] = decInstr{*op, singletonField, indir, 0, ovfl} engine.numInstr = 1 return } @@ -954,6 +964,7 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng } engine = new(decEngine) engine.instr = make([]decInstr, len(wireStruct.Field)) + seen := make(map[reflect.Type]*decOp) // Loop over the fields of the wire type. for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ { wireField := wireStruct.Field[fieldnum] @@ -969,11 +980,11 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng engine.instr[fieldnum] = decInstr{op, fieldnum, 0, 0, ovfl} continue } - if !dec.compatibleType(localField.Type, wireField.Id) { + if !dec.compatibleType(localField.Type, wireField.Id, make(map[reflect.Type]typeId)) { errorf("gob: wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name) } - op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name) - engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(localField.Offset), ovfl} + op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name, seen) + engine.instr[fieldnum] = decInstr{*op, fieldnum, indir, uintptr(localField.Offset), ovfl} engine.numInstr++ } return @@ -1070,8 +1081,8 @@ func init() { default: panic("gob: unknown size of int/uint") } - decOpMap[reflect.Int] = iop - decOpMap[reflect.Uint] = uop + decOpTable[reflect.Int] = iop + decOpTable[reflect.Uint] = uop // Finally uintptr switch reflect.Typeof(uintptr(0)).Bits() { @@ -1082,5 +1093,5 @@ func init() { default: panic("gob: unknown size of uintptr") } - decOpMap[reflect.Uintptr] = uop + decOpTable[reflect.Uintptr] = uop } diff --git a/src/pkg/gob/encode.go b/src/pkg/gob/encode.go index c5570409b4..5f4fc6f34b 100644 --- a/src/pkg/gob/encode.go +++ b/src/pkg/gob/encode.go @@ -414,7 +414,7 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) } } -var encOpMap = []encOp{ +var encOpTable = [...]encOp{ reflect.Bool: encBool, reflect.Int: encInt, reflect.Int8: encInt8, @@ -434,18 +434,24 @@ var encOpMap = []encOp{ reflect.String: encString, } -// Return the encoding op for the base type under rt and +// Return (a pointer to) the encoding op for the base type under rt and // the indirection count to reach it. -func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) { +func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp) (*encOp, int) { ut := userType(rt) + // If this type is already in progress, it's a recursive type (e.g. map[string]*T). + // Return the pointer to the op we're already building. + if opPtr := inProgress[rt]; opPtr != nil { + return opPtr, ut.indir + } typ := ut.base indir := ut.indir - var op encOp k := typ.Kind() - if int(k) < len(encOpMap) { - op = encOpMap[k] + var op encOp + if int(k) < len(encOpTable) { + op = encOpTable[k] } if op == nil { + inProgress[rt] = &op // Special cases switch t := typ.(type) { case *reflect.SliceType: @@ -454,25 +460,25 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) { break } // Slices have a header; we decode it to find the underlying array. - elemOp, indir := enc.encOpFor(t.Elem()) + elemOp, indir := enc.encOpFor(t.Elem(), inProgress) op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { slice := (*reflect.SliceHeader)(p) if !state.sendZero && slice.Len == 0 { return } state.update(i) - state.enc.encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), indir, int(slice.Len)) + state.enc.encodeArray(state.b, slice.Data, *elemOp, t.Elem().Size(), indir, int(slice.Len)) } case *reflect.ArrayType: // True arrays have size in the type. - elemOp, indir := enc.encOpFor(t.Elem()) + elemOp, indir := enc.encOpFor(t.Elem(), inProgress) op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { state.update(i) - state.enc.encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len()) + state.enc.encodeArray(state.b, uintptr(p), *elemOp, t.Elem().Size(), indir, t.Len()) } case *reflect.MapType: - keyOp, keyIndir := enc.encOpFor(t.Key()) - elemOp, elemIndir := enc.encOpFor(t.Elem()) + keyOp, keyIndir := enc.encOpFor(t.Key(), inProgress) + elemOp, elemIndir := enc.encOpFor(t.Elem(), inProgress) op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { // Maps cannot be accessed by moving addresses around the way // that slices etc. can. We must recover a full reflection value for @@ -483,7 +489,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) { return } state.update(i) - state.enc.encodeMap(state.b, mv, keyOp, elemOp, keyIndir, elemIndir) + state.enc.encodeMap(state.b, mv, *keyOp, *elemOp, keyIndir, elemIndir) } case *reflect.StructType: // Generate a closure that calls out to the engine for the nested type. @@ -511,21 +517,22 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) { if op == nil { errorf("gob enc: can't happen: encode type %s", rt.String()) } - return op, indir + return &op, indir } // The local Type was compiled from the actual value, so we know it's compatible. func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine { srt, isStruct := rt.(*reflect.StructType) engine := new(encEngine) + seen := make(map[reflect.Type]*encOp) if isStruct { for fieldNum := 0; fieldNum < srt.NumField(); fieldNum++ { f := srt.Field(fieldNum) if !isExported(f.Name) { continue } - op, indir := enc.encOpFor(f.Type) - engine.instr = append(engine.instr, encInstr{op, fieldNum, indir, uintptr(f.Offset)}) + op, indir := enc.encOpFor(f.Type, seen) + engine.instr = append(engine.instr, encInstr{*op, fieldNum, indir, uintptr(f.Offset)}) } if srt.NumField() > 0 && len(engine.instr) == 0 { errorf("type %s has no exported fields", rt) @@ -533,8 +540,8 @@ func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine { engine.instr = append(engine.instr, encInstr{encStructTerminator, 0, 0, 0}) } else { engine.instr = make([]encInstr, 1) - op, indir := enc.encOpFor(rt) - engine.instr[0] = encInstr{op, singletonField, indir, 0} // offset is zero + op, indir := enc.encOpFor(rt, seen) + engine.instr[0] = encInstr{*op, singletonField, indir, 0} // offset is zero } return engine } diff --git a/src/pkg/gob/encoder_test.go b/src/pkg/gob/encoder_test.go index 3e06db7272..a0c713b81d 100644 --- a/src/pkg/gob/encoder_test.go +++ b/src/pkg/gob/encoder_test.go @@ -249,6 +249,24 @@ func TestArray(t *testing.T) { } } +func TestRecursiveMapType(t *testing.T) { + type recursiveMap map[string]recursiveMap + r1 := recursiveMap{"A": recursiveMap{"B": nil, "C": nil}, "D": nil} + r2 := make(recursiveMap) + if err := encAndDec(r1, &r2); err != nil { + t.Error(err) + } +} + +func TestRecursiveSliceType(t *testing.T) { + type recursiveSlice []recursiveSlice + r1 := recursiveSlice{0: recursiveSlice{0: nil}, 1: nil} + r2 := make(recursiveSlice, 0) + if err := encAndDec(r1, &r2); err != nil { + t.Error(err) + } +} + // Regression test for bug: must send zero values inside arrays func TestDefaultsInArray(t *testing.T) { type Type7 struct { diff --git a/src/pkg/gob/type.go b/src/pkg/gob/type.go index 3ed4cce924..6e3f148b4e 100644 --- a/src/pkg/gob/type.go +++ b/src/pkg/gob/type.go @@ -52,9 +52,6 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) { // cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6, // pp 539-540. As we step through indirections, run another type at // half speed. If they meet up, there's a cycle. - // TODO: still need to deal with self-referential non-structs such - // as type T map[string]T but that is a larger undertaking - and can - // be useful, not always erroneous. slowpoke := ut.base // walks half as fast as ut.base for { pt, ok := ut.base.(*reflect.PtrType) @@ -210,12 +207,18 @@ type arrayType struct { Len int } -func newArrayType(name string, elem gobType, length int) *arrayType { - a := &arrayType{CommonType{Name: name}, elem.id(), length} - setTypeId(a) +func newArrayType(name string) *arrayType { + a := &arrayType{CommonType{Name: name}, 0, 0} return a } +func (a *arrayType) init(elem gobType, len int) { + // Set our type id before evaluating the element's, in case it's our own. + setTypeId(a) + a.Elem = elem.id() + a.Len = len +} + func (a *arrayType) safeString(seen map[typeId]bool) string { if seen[a.Id] { return a.Name @@ -233,12 +236,18 @@ type mapType struct { Elem typeId } -func newMapType(name string, key, elem gobType) *mapType { - m := &mapType{CommonType{Name: name}, key.id(), elem.id()} - setTypeId(m) +func newMapType(name string) *mapType { + m := &mapType{CommonType{Name: name}, 0, 0} return m } +func (m *mapType) init(key, elem gobType) { + // Set our type id before evaluating the element's, in case it's our own. + setTypeId(m) + m.Key = key.id() + m.Elem = elem.id() +} + func (m *mapType) safeString(seen map[typeId]bool) string { if seen[m.Id] { return m.Name @@ -257,12 +266,17 @@ type sliceType struct { Elem typeId } -func newSliceType(name string, elem gobType) *sliceType { - s := &sliceType{CommonType{Name: name}, elem.id()} - setTypeId(s) +func newSliceType(name string) *sliceType { + s := &sliceType{CommonType{Name: name}, 0} return s } +func (s *sliceType) init(elem gobType) { + // Set our type id before evaluating the element's, in case it's our own. + setTypeId(s) + s.Elem = elem.id() +} + func (s *sliceType) safeString(seen map[typeId]bool) string { if seen[s.Id] { return s.Name @@ -304,11 +318,26 @@ func (s *structType) string() string { return s.safeString(make(map[typeId]bool) func newStructType(name string) *structType { s := &structType{CommonType{Name: name}, nil} + // For historical reasons we set the id here rather than init. + // Se the comment in newTypeObject for details. setTypeId(s) return s } +func (s *structType) init(field []*fieldType) { + s.Field = field +} + func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { + var err os.Error + var type0, type1 gobType + defer func() { + if err != nil { + types[rt] = nil, false + } + }() + // Install the top-level type before the subtypes (e.g. struct before + // fields) so recursive types can be constructed safely. switch t := rt.(type) { // All basic types are easy: they are predefined. case *reflect.BoolType: @@ -333,40 +362,55 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { return tInterface.gobType(), nil case *reflect.ArrayType: - gt, err := getType("", t.Elem()) + at := newArrayType(name) + types[rt] = at + type0, err = getType("", t.Elem()) if err != nil { return nil, err } - return newArrayType(name, gt, t.Len()), nil + // Historical aside: + // For arrays, maps, and slices, we set the type id after the elements + // are constructed. This is to retain the order of type id allocation after + // a fix made to handle recursive types, which changed the order in + // which types are built. Delaying the setting in this way preserves + // type ids while allowing recursive types to be described. Structs, + // done below, were already handling recursion correctly so they + // assign the top-level id before those of the field. + at.init(type0, t.Len()) + return at, nil case *reflect.MapType: - kt, err := getType("", t.Key()) + mt := newMapType(name) + types[rt] = mt + type0, err = getType("", t.Key()) if err != nil { return nil, err } - vt, err := getType("", t.Elem()) + type1, err = getType("", t.Elem()) if err != nil { return nil, err } - return newMapType(name, kt, vt), nil + mt.init(type0, type1) + return mt, nil case *reflect.SliceType: // []byte == []uint8 is a special case if t.Elem().Kind() == reflect.Uint8 { return tBytes.gobType(), nil } - gt, err := getType(t.Elem().Name(), t.Elem()) + st := newSliceType(name) + types[rt] = st + type0, err = getType(t.Elem().Name(), t.Elem()) if err != nil { return nil, err } - return newSliceType(name, gt), nil + st.init(type0) + return st, nil case *reflect.StructType: - // Install the struct type itself before the fields so recursive - // structures can be constructed safely. - strType := newStructType(name) - types[rt] = strType - idToType[strType.id()] = strType + st := newStructType(name) + types[rt] = st + idToType[st.id()] = st field := make([]*fieldType, t.NumField()) for i := 0; i < t.NumField(); i++ { f := t.Field(i) @@ -382,8 +426,8 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { } field[i] = &fieldType{f.Name, gt.id()} } - strType.Field = field - return strType, nil + st.init(field) + return st, nil default: return nil, os.ErrorString("gob NewTypeObject can't handle type: " + rt.String()) @@ -435,7 +479,7 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId { // For bootstrapping purposes, we assume that the recipient knows how // to decode a wireType; it is exactly the wireType struct here, interpreted // using the gob rules for sending a structure, except that we assume the -// ids for wireType and structType are known. The relevant pieces +// ids for wireType and structType etc. are known. The relevant pieces // are built in encode.go's init() function. // To maintain binary compatibility, if you extend this type, always put // the new fields last. -- 2.48.1