]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/json: error when encoding a pointer cycle
authorDaniel Martí <mvdan@mvdan.cc>
Mon, 29 Jul 2019 03:16:14 +0000 (20:16 -0700)
committerAndrew Bonventre <andybons@golang.org>
Mon, 11 Nov 2019 16:24:21 +0000 (16:24 +0000)
Otherwise we'd panic with a stack overflow.

Most programs are in control of the data being encoded and can ensure
there are no cycles, but sometimes it's not that simple. For example,
running a user's html template with script tags can easily result in
crashes if the user can find a pointer cycle.

Adding the checks via a map to every ptrEncoder.encode call slowed down
the benchmarks below by a noticeable 13%. Instead, only start doing the
relatively expensive pointer cycle checks if we're many levels of
pointers deep in an encode state.

A threshold of 1000 is small enough to capture pointer cycles before
they're a problem (the goroutine stack limit is currently 1GB, and I
needed close to a million levels to reach it). Yet it's large enough
that reasonable uses of the json encoder only see a tiny 1% slow-down
due to the added ptrLevel field and check.

name           old time/op    new time/op    delta
CodeEncoder-8    2.34ms ± 1%    2.37ms ± 0%  +1.05%  (p=0.000 n=10+10)
CodeMarshal-8    2.42ms ± 1%    2.44ms ± 0%  +1.10%  (p=0.000 n=10+10)

name           old speed      new speed      delta
CodeEncoder-8   829MB/s ± 1%   820MB/s ± 0%  -1.04%  (p=0.000 n=10+10)
CodeMarshal-8   803MB/s ± 1%   795MB/s ± 0%  -1.09%  (p=0.000 n=10+10)

name           old alloc/op   new alloc/op   delta
CodeEncoder-8    43.1kB ± 8%    42.5kB ±10%    ~     (p=0.989 n=10+10)
CodeMarshal-8    1.99MB ± 0%    1.99MB ± 0%    ~     (p=0.254 n=9+6)

name           old allocs/op  new allocs/op  delta
CodeEncoder-8      0.00           0.00         ~     (all equal)
CodeMarshal-8      1.00 ± 0%      1.00 ± 0%    ~     (all equal)

Finally, add a few tests to ensure that the code handles the edge cases
properly.

Fixes #10769.

Change-Id: I73d48e0cf6ea140127ea031f2dbae6e6a55e58b8
Reviewed-on: https://go-review.googlesource.com/c/go/+/187920
Run-TryBot: Daniel Martí <mvdan@mvdan.cc>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Bjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Reviewed-by: Andrew Bonventre <andybons@golang.org>
src/encoding/json/encode.go
src/encoding/json/encode_test.go

index b81e5058660f73dbd4bfaf10c853e973562fcf86..39cdaebde7b9c94e37ffec3f3c59a16ac8138382 100644 (file)
@@ -153,7 +153,7 @@ import (
 //
 // JSON cannot represent cyclic data structures and Marshal does not
 // handle them. Passing cyclic structures to Marshal will result in
-// an infinite recursion.
+// an error.
 //
 func Marshal(v interface{}) ([]byte, error) {
        e := newEncodeState()
@@ -285,17 +285,31 @@ var hex = "0123456789abcdef"
 type encodeState struct {
        bytes.Buffer // accumulated output
        scratch      [64]byte
+
+       // Keep track of what pointers we've seen in the current recursive call
+       // path, to avoid cycles that could lead to a stack overflow. Only do
+       // the relatively expensive map operations if ptrLevel is larger than
+       // startDetectingCyclesAfter, so that we skip the work if we're within a
+       // reasonable amount of nested pointers deep.
+       ptrLevel uint
+       ptrSeen  map[interface{}]struct{}
 }
 
+const startDetectingCyclesAfter = 1000
+
 var encodeStatePool sync.Pool
 
 func newEncodeState() *encodeState {
        if v := encodeStatePool.Get(); v != nil {
                e := v.(*encodeState)
                e.Reset()
+               if len(e.ptrSeen) > 0 {
+                       panic("ptrEncoder.encode should have emptied ptrSeen via defers")
+               }
+               e.ptrLevel = 0
                return e
        }
-       return new(encodeState)
+       return &encodeState{ptrSeen: make(map[interface{}]struct{})}
 }
 
 // jsonError is an error wrapper type for internal use only.
@@ -887,7 +901,18 @@ func (pe ptrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
                e.WriteString("null")
                return
        }
+       if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter {
+               // We're a large number of nested ptrEncoder.encode calls deep;
+               // start checking if we've run into a pointer cycle.
+               ptr := v.Interface()
+               if _, ok := e.ptrSeen[ptr]; ok {
+                       e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())})
+               }
+               e.ptrSeen[ptr] = struct{}{}
+               defer delete(e.ptrSeen, ptr)
+       }
        pe.elemEnc(e, v.Elem(), opts)
+       e.ptrLevel--
 }
 
 func newPtrEncoder(t reflect.Type) encoderFunc {
index 40f16d86ff87522b83a6b0347d720adcc2143b39..5110c7de9b123387aaf5cee9bb2e71b046c94276 100644 (file)
@@ -138,10 +138,45 @@ func TestEncodeRenamedByteSlice(t *testing.T) {
        }
 }
 
+type SamePointerNoCycle struct {
+       Ptr1, Ptr2 *SamePointerNoCycle
+}
+
+var samePointerNoCycle = &SamePointerNoCycle{}
+
+type PointerCycle struct {
+       Ptr *PointerCycle
+}
+
+var pointerCycle = &PointerCycle{}
+
+type PointerCycleIndirect struct {
+       Ptrs []interface{}
+}
+
+var pointerCycleIndirect = &PointerCycleIndirect{}
+
+func init() {
+       ptr := &SamePointerNoCycle{}
+       samePointerNoCycle.Ptr1 = ptr
+       samePointerNoCycle.Ptr2 = ptr
+
+       pointerCycle.Ptr = pointerCycle
+       pointerCycleIndirect.Ptrs = []interface{}{pointerCycleIndirect}
+}
+
+func TestSamePointerNoCycle(t *testing.T) {
+       if _, err := Marshal(samePointerNoCycle); err != nil {
+               t.Fatalf("unexpected error: %v", err)
+       }
+}
+
 var unsupportedValues = []interface{}{
        math.NaN(),
        math.Inf(-1),
        math.Inf(1),
+       pointerCycle,
+       pointerCycleIndirect,
 }
 
 func TestUnsupportedValues(t *testing.T) {