]> Cypherpunks repositories - gostls13.git/commitdiff
[release-branch.go1.17] encoding/xml: limit depth of nesting in unmarshal
authorRoland Shoemaker <roland@golang.org>
Tue, 29 Mar 2022 22:52:09 +0000 (15:52 -0700)
committerMichael Knyszek <mknyszek@google.com>
Tue, 12 Jul 2022 15:20:25 +0000 (15:20 +0000)
Prevent exhausting the stack limit when unmarshalling extremely deeply
nested structures into nested types.

Fixes #53715
Updates #53611
Fixes CVE-2022-30633

Change-Id: Ic6c5d41674c93cfc9a316135a408db9156d39c59
Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1421319
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Julie Qiu <julieqiu@google.com>
(cherry picked from commit ebee00a55e28931b2cad0e76207a73712b000432)
Reviewed-on: https://go-review.googlesource.com/c/go/+/417069
Reviewed-by: Heschi Kreinick <heschi@google.com>
Run-TryBot: Michael Knyszek <mknyszek@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>

src/encoding/xml/read.go
src/encoding/xml/read_test.go

index e9f9d2efa9cc2ba23a3addf3564c7f253d35e2bf..c77579880cbb09f00df8cbc584debafd536a246b 100644 (file)
@@ -148,7 +148,7 @@ func (d *Decoder) DecodeElement(v interface{}, start *StartElement) error {
        if val.Kind() != reflect.Ptr {
                return errors.New("non-pointer passed to Unmarshal")
        }
-       return d.unmarshal(val.Elem(), start)
+       return d.unmarshal(val.Elem(), start, 0)
 }
 
 // An UnmarshalError represents an error in the unmarshaling process.
@@ -304,8 +304,15 @@ var (
        textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
 )
 
+const maxUnmarshalDepth = 10000
+
+var errExeceededMaxUnmarshalDepth = errors.New("exceeded max depth")
+
 // Unmarshal a single XML element into val.
-func (d *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
+func (d *Decoder) unmarshal(val reflect.Value, start *StartElement, depth int) error {
+       if depth >= maxUnmarshalDepth {
+               return errExeceededMaxUnmarshalDepth
+       }
        // Find start element if we need it.
        if start == nil {
                for {
@@ -398,7 +405,7 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
                v.Set(reflect.Append(val, reflect.Zero(v.Type().Elem())))
 
                // Recur to read element into slice.
-               if err := d.unmarshal(v.Index(n), start); err != nil {
+               if err := d.unmarshal(v.Index(n), start, depth+1); err != nil {
                        v.SetLen(n)
                        return err
                }
@@ -521,13 +528,15 @@ Loop:
                case StartElement:
                        consumed := false
                        if sv.IsValid() {
-                               consumed, err = d.unmarshalPath(tinfo, sv, nil, &t)
+                               // unmarshalPath can call unmarshal, so we need to pass the depth through so that
+                               // we can continue to enforce the maximum recusion limit.
+                               consumed, err = d.unmarshalPath(tinfo, sv, nil, &t, depth)
                                if err != nil {
                                        return err
                                }
                                if !consumed && saveAny.IsValid() {
                                        consumed = true
-                                       if err := d.unmarshal(saveAny, &t); err != nil {
+                                       if err := d.unmarshal(saveAny, &t, depth+1); err != nil {
                                                return err
                                        }
                                }
@@ -672,7 +681,7 @@ func copyValue(dst reflect.Value, src []byte) (err error) {
 // The consumed result tells whether XML elements have been consumed
 // from the Decoder until start's matching end element, or if it's
 // still untouched because start is uninteresting for sv's fields.
-func (d *Decoder) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []string, start *StartElement) (consumed bool, err error) {
+func (d *Decoder) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []string, start *StartElement, depth int) (consumed bool, err error) {
        recurse := false
 Loop:
        for i := range tinfo.fields {
@@ -687,7 +696,7 @@ Loop:
                }
                if len(finfo.parents) == len(parents) && finfo.name == start.Name.Local {
                        // It's a perfect match, unmarshal the field.
-                       return true, d.unmarshal(finfo.value(sv, initNilPointers), start)
+                       return true, d.unmarshal(finfo.value(sv, initNilPointers), start, depth+1)
                }
                if len(finfo.parents) > len(parents) && finfo.parents[len(parents)] == start.Name.Local {
                        // It's a prefix for the field. Break and recurse
@@ -716,7 +725,9 @@ Loop:
                }
                switch t := tok.(type) {
                case StartElement:
-                       consumed2, err := d.unmarshalPath(tinfo, sv, parents, &t)
+                       // the recursion depth of unmarshalPath is limited to the path length specified
+                       // by the struct field tag, so we don't increment the depth here.
+                       consumed2, err := d.unmarshalPath(tinfo, sv, parents, &t, depth)
                        if err != nil {
                                return true, err
                        }
index 4ccab3d0106ebc4b4df62fc3010c9747ad303955..8c940aefb81079d616cd05baf75b8791bf26aacd 100644 (file)
@@ -6,6 +6,7 @@ package xml
 
 import (
        "bytes"
+       "errors"
        "io"
        "reflect"
        "runtime"
@@ -1097,3 +1098,16 @@ func TestCVE202230633(t *testing.T) {
        }
        Unmarshal(bytes.Repeat([]byte("<a>"), 17_000_000), &example)
 }
+
+func TestCVE202228131(t *testing.T) {
+       type nested struct {
+               Parent *nested `xml:",any"`
+       }
+       var n nested
+       err := Unmarshal(bytes.Repeat([]byte("<a>"), maxUnmarshalDepth+1), &n)
+       if err == nil {
+               t.Fatal("Unmarshal did not fail")
+       } else if !errors.Is(err, errExeceededMaxUnmarshalDepth) {
+               t.Fatalf("Unmarshal unexpected error: got %q, want %q", err, errExeceededMaxUnmarshalDepth)
+       }
+}