]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/json: detect cyclic maps and slices
authorlujjjh <lujjjh@gmail.com>
Thu, 17 Sep 2020 14:39:13 +0000 (14:39 +0000)
committerDaniel Martí <mvdan@mvdan.cc>
Thu, 24 Sep 2020 18:18:20 +0000 (18:18 +0000)
Now reports an error if cyclic maps and slices are to be encoded
instead of an infinite recursion. This case wasn't handled in CL 187920.

Fixes #40745.

Change-Id: Ia34b014ecbb71fd2663bb065ba5355a307dbcc15
GitHub-Last-Rev: 6f874944f4065b5237babbb0fdce14c1c74a3c97
GitHub-Pull-Request: golang/go#40756
Reviewed-on: https://go-review.googlesource.com/c/go/+/248358
Reviewed-by: Daniel Martí <mvdan@mvdan.cc>
Trust: Daniel Martí <mvdan@mvdan.cc>
Trust: Bryan C. Mills <bcmills@google.com>
Run-TryBot: Daniel Martí <mvdan@mvdan.cc>
TryBot-Result: Go Bot <gobot@golang.org>

src/encoding/json/encode.go
src/encoding/json/encode_test.go

index c2d191442c83cd864f1fa2921c4539fa7d3a9291..ea5eca51ef89d8b753bc3b33fb2bde733c899696 100644 (file)
@@ -779,6 +779,16 @@ func (me mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
                e.WriteString("null")
                return
        }
+       if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter {
+               // We're a large number of nested ptrEncoder.encode calls deep;
+               // start checking if we've run into a pointer cycle.
+               ptr := v.Pointer()
+               if _, ok := e.ptrSeen[ptr]; ok {
+                       e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())})
+               }
+               e.ptrSeen[ptr] = struct{}{}
+               defer delete(e.ptrSeen, ptr)
+       }
        e.WriteByte('{')
 
        // Extract and sort the keys.
@@ -801,6 +811,7 @@ func (me mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
                me.elemEnc(e, v.MapIndex(kv.v), opts)
        }
        e.WriteByte('}')
+       e.ptrLevel--
 }
 
 func newMapEncoder(t reflect.Type) encoderFunc {
@@ -857,7 +868,23 @@ func (se sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
                e.WriteString("null")
                return
        }
+       if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter {
+               // We're a large number of nested ptrEncoder.encode calls deep;
+               // start checking if we've run into a pointer cycle.
+               // Here we use a struct to memorize the pointer to the first element of the slice
+               // and its length.
+               ptr := struct {
+                       ptr uintptr
+                       len int
+               }{v.Pointer(), v.Len()}
+               if _, ok := e.ptrSeen[ptr]; ok {
+                       e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())})
+               }
+               e.ptrSeen[ptr] = struct{}{}
+               defer delete(e.ptrSeen, ptr)
+       }
        se.arrayEnc(e, v, opts)
+       e.ptrLevel--
 }
 
 func newSliceEncoder(t reflect.Type) encoderFunc {
index 7290eca06f070769bf46225123ebde9c728294d6..42bb09d5cded0babe31a089369f325ff6b3aaf7e 100644 (file)
@@ -183,7 +183,15 @@ type PointerCycleIndirect struct {
        Ptrs []interface{}
 }
 
-var pointerCycleIndirect = &PointerCycleIndirect{}
+type RecursiveSlice []RecursiveSlice
+
+var (
+       pointerCycleIndirect = &PointerCycleIndirect{}
+       mapCycle             = make(map[string]interface{})
+       sliceCycle           = []interface{}{nil}
+       sliceNoCycle         = []interface{}{nil, nil}
+       recursiveSliceCycle  = []RecursiveSlice{nil}
+)
 
 func init() {
        ptr := &SamePointerNoCycle{}
@@ -192,6 +200,14 @@ func init() {
 
        pointerCycle.Ptr = pointerCycle
        pointerCycleIndirect.Ptrs = []interface{}{pointerCycleIndirect}
+
+       mapCycle["x"] = mapCycle
+       sliceCycle[0] = sliceCycle
+       sliceNoCycle[1] = sliceNoCycle[:1]
+       for i := startDetectingCyclesAfter; i > 0; i-- {
+               sliceNoCycle = []interface{}{sliceNoCycle}
+       }
+       recursiveSliceCycle[0] = recursiveSliceCycle
 }
 
 func TestSamePointerNoCycle(t *testing.T) {
@@ -200,12 +216,21 @@ func TestSamePointerNoCycle(t *testing.T) {
        }
 }
 
+func TestSliceNoCycle(t *testing.T) {
+       if _, err := Marshal(sliceNoCycle); err != nil {
+               t.Fatalf("unexpected error: %v", err)
+       }
+}
+
 var unsupportedValues = []interface{}{
        math.NaN(),
        math.Inf(-1),
        math.Inf(1),
        pointerCycle,
        pointerCycleIndirect,
+       mapCycle,
+       sliceCycle,
+       recursiveSliceCycle,
 }
 
 func TestUnsupportedValues(t *testing.T) {