var maxIgnoreNestingDepth = 10000
// decIgnoreOpFor returns the decoding op for a field that has no destination.
-func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp, depth int) *decOp {
- if depth > maxIgnoreNestingDepth {
+func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp) *decOp {
+ // Track how deep we've recursed trying to skip nested ignored fields.
+ dec.ignoreDepth++
+ defer func() { dec.ignoreDepth-- }()
+ if dec.ignoreDepth > maxIgnoreNestingDepth {
error_(errors.New("invalid nesting depth"))
}
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
errorf("bad data: undefined type %s", wireId.string())
case wire.ArrayT != nil:
elemId := wire.ArrayT.Elem
- elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1)
+ elemOp := dec.decIgnoreOpFor(elemId, inProgress)
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreArray(state, *elemOp, wire.ArrayT.Len)
}
case wire.MapT != nil:
keyId := dec.wireType[wireId].MapT.Key
elemId := dec.wireType[wireId].MapT.Elem
- keyOp := dec.decIgnoreOpFor(keyId, inProgress, depth+1)
- elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1)
+ keyOp := dec.decIgnoreOpFor(keyId, inProgress)
+ elemOp := dec.decIgnoreOpFor(elemId, inProgress)
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreMap(state, *keyOp, *elemOp)
}
case wire.SliceT != nil:
elemId := wire.SliceT.Elem
- elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1)
+ elemOp := dec.decIgnoreOpFor(elemId, inProgress)
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreSlice(state, *elemOp)
}
func (dec *Decoder) compileIgnoreSingle(remoteId typeId) *decEngine {
engine := new(decEngine)
engine.instr = make([]decInstr, 1) // one item
- op := dec.decIgnoreOpFor(remoteId, make(map[typeId]*decOp), 0)
+ op := dec.decIgnoreOpFor(remoteId, make(map[typeId]*decOp))
ovfl := overflow(dec.typeString(remoteId))
engine.instr[0] = decInstr{*op, 0, nil, ovfl}
engine.numInstr = 1
localField, present := srt.FieldByName(wireField.Name)
// TODO(r): anonymous names
if !present || !isExported(wireField.Name) {
- op := dec.decIgnoreOpFor(wireField.Id, make(map[typeId]*decOp), 0)
+ op := dec.decIgnoreOpFor(wireField.Id, make(map[typeId]*decOp))
engine.instr[fieldnum] = decInstr{*op, fieldnum, nil, ovfl}
continue
}
defer func() { maxIgnoreNestingDepth = oldNestingDepth }()
b := new(bytes.Buffer)
enc := NewEncoder(b)
+
+ // Nested slice
typ := reflect.TypeFor[int]()
nested := reflect.ArrayOf(1, typ)
for i := 0; i < 100; i++ {
if err := dec.Decode(&output); err == nil || err.Error() != expectedErr {
t.Errorf("Decode didn't fail with depth limit of 100: want %q, got %q", expectedErr, err)
}
+
+ // Nested struct
+ nested = reflect.StructOf([]reflect.StructField{{Name: "F", Type: typ}})
+ for i := 0; i < 100; i++ {
+ nested = reflect.StructOf([]reflect.StructField{{Name: "F", Type: nested}})
+ }
+ badStruct = reflect.New(reflect.StructOf([]reflect.StructField{{Name: "F", Type: nested}}))
+ enc.Encode(badStruct.Interface())
+ dec = NewDecoder(b)
+ if err := dec.Decode(&output); err == nil || err.Error() != expectedErr {
+ t.Errorf("Decode didn't fail with depth limit of 100: want %q, got %q", expectedErr, err)
+ }
}