From 428509402b03c608e625a4844ab0cce75e4bead2 Mon Sep 17 00:00:00 2001 From: lujjjh Date: Thu, 17 Sep 2020 14:39:13 +0000 Subject: [PATCH] encoding/json: detect cyclic maps and slices MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Now reports an error if cyclic maps and slices are to be encoded instead of an infinite recursion. This case wasn't handled in CL 187920. Fixes #40745. Change-Id: Ia34b014ecbb71fd2663bb065ba5355a307dbcc15 GitHub-Last-Rev: 6f874944f4065b5237babbb0fdce14c1c74a3c97 GitHub-Pull-Request: golang/go#40756 Reviewed-on: https://go-review.googlesource.com/c/go/+/248358 Reviewed-by: Daniel Martí Trust: Daniel Martí Trust: Bryan C. Mills Run-TryBot: Daniel Martí TryBot-Result: Go Bot --- src/encoding/json/encode.go | 27 +++++++++++++++++++++++++++ src/encoding/json/encode_test.go | 27 ++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/src/encoding/json/encode.go b/src/encoding/json/encode.go index c2d191442c..ea5eca51ef 100644 --- a/src/encoding/json/encode.go +++ b/src/encoding/json/encode.go @@ -779,6 +779,16 @@ func (me mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { e.WriteString("null") return } + if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter { + // We're a large number of nested ptrEncoder.encode calls deep; + // start checking if we've run into a pointer cycle. + ptr := v.Pointer() + if _, ok := e.ptrSeen[ptr]; ok { + e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())}) + } + e.ptrSeen[ptr] = struct{}{} + defer delete(e.ptrSeen, ptr) + } e.WriteByte('{') // Extract and sort the keys. @@ -801,6 +811,7 @@ func (me mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { me.elemEnc(e, v.MapIndex(kv.v), opts) } e.WriteByte('}') + e.ptrLevel-- } func newMapEncoder(t reflect.Type) encoderFunc { @@ -857,7 +868,23 @@ func (se sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { e.WriteString("null") return } + if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter { + // We're a large number of nested ptrEncoder.encode calls deep; + // start checking if we've run into a pointer cycle. + // Here we use a struct to memorize the pointer to the first element of the slice + // and its length. + ptr := struct { + ptr uintptr + len int + }{v.Pointer(), v.Len()} + if _, ok := e.ptrSeen[ptr]; ok { + e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())}) + } + e.ptrSeen[ptr] = struct{}{} + defer delete(e.ptrSeen, ptr) + } se.arrayEnc(e, v, opts) + e.ptrLevel-- } func newSliceEncoder(t reflect.Type) encoderFunc { diff --git a/src/encoding/json/encode_test.go b/src/encoding/json/encode_test.go index 7290eca06f..42bb09d5cd 100644 --- a/src/encoding/json/encode_test.go +++ b/src/encoding/json/encode_test.go @@ -183,7 +183,15 @@ type PointerCycleIndirect struct { Ptrs []interface{} } -var pointerCycleIndirect = &PointerCycleIndirect{} +type RecursiveSlice []RecursiveSlice + +var ( + pointerCycleIndirect = &PointerCycleIndirect{} + mapCycle = make(map[string]interface{}) + sliceCycle = []interface{}{nil} + sliceNoCycle = []interface{}{nil, nil} + recursiveSliceCycle = []RecursiveSlice{nil} +) func init() { ptr := &SamePointerNoCycle{} @@ -192,6 +200,14 @@ func init() { pointerCycle.Ptr = pointerCycle pointerCycleIndirect.Ptrs = []interface{}{pointerCycleIndirect} + + mapCycle["x"] = mapCycle + sliceCycle[0] = sliceCycle + sliceNoCycle[1] = sliceNoCycle[:1] + for i := startDetectingCyclesAfter; i > 0; i-- { + sliceNoCycle = []interface{}{sliceNoCycle} + } + recursiveSliceCycle[0] = recursiveSliceCycle } func TestSamePointerNoCycle(t *testing.T) { @@ -200,12 +216,21 @@ func TestSamePointerNoCycle(t *testing.T) { } } +func TestSliceNoCycle(t *testing.T) { + if _, err := Marshal(sliceNoCycle); err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + var unsupportedValues = []interface{}{ math.NaN(), math.Inf(-1), math.Inf(1), pointerCycle, pointerCycleIndirect, + mapCycle, + sliceCycle, + recursiveSliceCycle, } func TestUnsupportedValues(t *testing.T) { -- 2.51.0