From 08c84420bc40d1cd5eb71b85cbe3a36f707bdb3f Mon Sep 17 00:00:00 2001 From: Roland Shoemaker Date: Fri, 3 May 2024 09:21:39 -0400 Subject: [PATCH] encoding/gob: cover missed cases when checking ignore depth This change makes sure that we are properly checking the ignored field recursion depth in decIgnoreOpFor consistently. This prevents stack exhaustion when attempting to decode a message that contains an extremely deeply nested struct which is ignored. Thanks to Md Sakib Anwar of The Ohio State University (anwar.40@osu.edu) for reporting this issue. Fixes #69139 Fixes CVE-2024-34156 Change-Id: Iacce06be95a5892b3064f1c40fcba2e2567862d6 Reviewed-on: https://go-internal-review.googlesource.com/c/go/+/1440 Reviewed-by: Russ Cox Reviewed-by: Damien Neil Reviewed-on: https://go-review.googlesource.com/c/go/+/611239 LUCI-TryBot-Result: Go LUCI Reviewed-by: Dmitri Shuralyov Reviewed-by: Roland Shoemaker Auto-Submit: Dmitri Shuralyov --- src/encoding/gob/decode.go | 19 +++++++++++-------- src/encoding/gob/decoder.go | 2 ++ src/encoding/gob/gobencdec_test.go | 14 ++++++++++++++ 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/encoding/gob/decode.go b/src/encoding/gob/decode.go index d178b2b2fb..26b5f6d62b 100644 --- a/src/encoding/gob/decode.go +++ b/src/encoding/gob/decode.go @@ -911,8 +911,11 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg var maxIgnoreNestingDepth = 10000 // decIgnoreOpFor returns the decoding op for a field that has no destination. -func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp, depth int) *decOp { - if depth > maxIgnoreNestingDepth { +func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp) *decOp { + // Track how deep we've recursed trying to skip nested ignored fields. + dec.ignoreDepth++ + defer func() { dec.ignoreDepth-- }() + if dec.ignoreDepth > maxIgnoreNestingDepth { error_(errors.New("invalid nesting depth")) } // If this type is already in progress, it's a recursive type (e.g. map[string]*T). @@ -938,7 +941,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp, errorf("bad data: undefined type %s", wireId.string()) case wire.ArrayT != nil: elemId := wire.ArrayT.Elem - elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1) + elemOp := dec.decIgnoreOpFor(elemId, inProgress) op = func(i *decInstr, state *decoderState, value reflect.Value) { state.dec.ignoreArray(state, *elemOp, wire.ArrayT.Len) } @@ -946,15 +949,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp, case wire.MapT != nil: keyId := dec.wireType[wireId].MapT.Key elemId := dec.wireType[wireId].MapT.Elem - keyOp := dec.decIgnoreOpFor(keyId, inProgress, depth+1) - elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1) + keyOp := dec.decIgnoreOpFor(keyId, inProgress) + elemOp := dec.decIgnoreOpFor(elemId, inProgress) op = func(i *decInstr, state *decoderState, value reflect.Value) { state.dec.ignoreMap(state, *keyOp, *elemOp) } case wire.SliceT != nil: elemId := wire.SliceT.Elem - elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1) + elemOp := dec.decIgnoreOpFor(elemId, inProgress) op = func(i *decInstr, state *decoderState, value reflect.Value) { state.dec.ignoreSlice(state, *elemOp) } @@ -1115,7 +1118,7 @@ func (dec *Decoder) compileSingle(remoteId typeId, ut *userTypeInfo) (engine *de func (dec *Decoder) compileIgnoreSingle(remoteId typeId) *decEngine { engine := new(decEngine) engine.instr = make([]decInstr, 1) // one item - op := dec.decIgnoreOpFor(remoteId, make(map[typeId]*decOp), 0) + op := dec.decIgnoreOpFor(remoteId, make(map[typeId]*decOp)) ovfl := overflow(dec.typeString(remoteId)) engine.instr[0] = decInstr{*op, 0, nil, ovfl} engine.numInstr = 1 @@ -1160,7 +1163,7 @@ func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEn localField, present := srt.FieldByName(wireField.Name) // TODO(r): anonymous names if !present || !isExported(wireField.Name) { - op := dec.decIgnoreOpFor(wireField.Id, make(map[typeId]*decOp), 0) + op := dec.decIgnoreOpFor(wireField.Id, make(map[typeId]*decOp)) engine.instr[fieldnum] = decInstr{*op, fieldnum, nil, ovfl} continue } diff --git a/src/encoding/gob/decoder.go b/src/encoding/gob/decoder.go index c4b6088013..eae307838e 100644 --- a/src/encoding/gob/decoder.go +++ b/src/encoding/gob/decoder.go @@ -35,6 +35,8 @@ type Decoder struct { freeList *decoderState // list of free decoderStates; avoids reallocation countBuf []byte // used for decoding integers while parsing messages err error + // ignoreDepth tracks the depth of recursively parsed ignored fields + ignoreDepth int } // NewDecoder returns a new decoder that reads from the [io.Reader]. diff --git a/src/encoding/gob/gobencdec_test.go b/src/encoding/gob/gobencdec_test.go index ae806fc39a..d30e622aa2 100644 --- a/src/encoding/gob/gobencdec_test.go +++ b/src/encoding/gob/gobencdec_test.go @@ -806,6 +806,8 @@ func TestIgnoreDepthLimit(t *testing.T) { defer func() { maxIgnoreNestingDepth = oldNestingDepth }() b := new(bytes.Buffer) enc := NewEncoder(b) + + // Nested slice typ := reflect.TypeFor[int]() nested := reflect.ArrayOf(1, typ) for i := 0; i < 100; i++ { @@ -819,4 +821,16 @@ func TestIgnoreDepthLimit(t *testing.T) { if err := dec.Decode(&output); err == nil || err.Error() != expectedErr { t.Errorf("Decode didn't fail with depth limit of 100: want %q, got %q", expectedErr, err) } + + // Nested struct + nested = reflect.StructOf([]reflect.StructField{{Name: "F", Type: typ}}) + for i := 0; i < 100; i++ { + nested = reflect.StructOf([]reflect.StructField{{Name: "F", Type: nested}}) + } + badStruct = reflect.New(reflect.StructOf([]reflect.StructField{{Name: "F", Type: nested}})) + enc.Encode(badStruct.Interface()) + dec = NewDecoder(b) + if err := dec.Decode(&output); err == nil || err.Error() != expectedErr { + t.Errorf("Decode didn't fail with depth limit of 100: want %q, got %q", expectedErr, err) + } } -- 2.48.1