]> Cypherpunks repositories - keks.git/commitdiff
Prevent uncaught recursion limits during decoding
authorSergey Matveev <stargrave@stargrave.org>
Wed, 4 Dec 2024 16:14:46 +0000 (19:14 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Wed, 4 Dec 2024 16:15:27 +0000 (19:15 +0300)
pyac/pyac.py
pyac/tests/test_recursion.py [new file with mode: 0644]

index 67632438c8a2226672a5f74da7710d305d9bfb597d471838114f1476c00f12d2..e847a5e44550de4d9ce69ae6b1700e2e6599115dfd570f7b54ea282b2135c5b2 100755 (executable)
@@ -232,7 +232,7 @@ def dumps(v):
 
 
 def _int(v):
-    s, tail = loads(v)
+    s, tail = _loads(v, _allowContainers=False)
     if not isinstance(s, bytes):
         raise DecodeError("non-BIN in INT")
     if s == b"":
@@ -264,16 +264,7 @@ def tai2utc(secs, leapsecUTCAllow=False):
     return secs - diff
 
 
-def loads(v, sets=False, leapsecUTCAllow=False):
-    """Decode YAC-encoded data.
-
-    :param bool sets: transform maps with NIL-only values to set()s
-    :param bool leapsecUTCAllow: allow TAI64 values equal to leap seconds,
-                                 to be decoded as datetime UTC value. Raw()
-                                 value will be returned instead
-    :returns: decoded data and the undecoded tail
-    :rtype: (any, bytes)
-    """
+def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True):
     if len(v) == 0:
         raise NotEnoughData(1)
     if v[0] == TagEOC:
@@ -358,7 +349,7 @@ def loads(v, sets=False, leapsecUTCAllow=False):
     if v[0] == TagNInt:
         i, v = _int(v[1:])
         return (-1 - i), v
-    if v[0] == TagList:
+    if (v[0] == TagList) and _allowContainers:
         ret = []
         v = v[1:]
         while True:
@@ -367,13 +358,13 @@ def loads(v, sets=False, leapsecUTCAllow=False):
                 break
             ret.append(i)
         return ret, v
-    if v[0] == TagMap:
+    if (v[0] == TagMap) and _allowContainers:
         ret = {}
         v = v[1:]
         kPrev = ""
         allNILs = True
         while True:
-            k, v = loads(v)
+            k, v = _loads(v, _allowContainers=False)
             if k == _EOC:
                 break
             if not isinstance(k, str):
@@ -392,14 +383,14 @@ def loads(v, sets=False, leapsecUTCAllow=False):
         if sets and allNILs:
             ret = set(ret.keys())
         return ret, v
-    if v[0] == TagBlob:
+    if (v[0] == TagBlob) and _allowContainers:
         if len(v) < 1+8:
             raise NotEnoughData(1+8-len(v))
         l = 1 + int.from_bytes(v[1:1+8], "big")
         v = v[1+8:]
         raws = []
         while True:
-            i, v = loads(v)
+            i, v = _loads(v, _allowContainers=False)
             if i is None:
                 if len(v) < l:
                     raise NotEnoughData(l-len(v)+1)
@@ -416,6 +407,22 @@ def loads(v, sets=False, leapsecUTCAllow=False):
     raise DecodeError("unknown tag")
 
 
+def loads(v, **kwargs):
+    """Decode YAC-encoded data.
+
+    :param bool sets: transform maps with NIL-only values to set()s
+    :param bool leapsecUTCAllow: allow TAI64 values equal to leap seconds,
+                                 to be decoded as datetime UTC value. Raw()
+                                 value will be returned instead
+    :returns: decoded data and the undecoded tail
+    :rtype: (any, bytes)
+    """
+    try:
+        return _loads(v, **kwargs)
+    except RecursionError as err:
+        raise DecodeError("deep recursion") from err
+
+
 if __name__ == "__main__":
     from argparse import ArgumentParser
     from argparse import FileType
diff --git a/pyac/tests/test_recursion.py b/pyac/tests/test_recursion.py
new file mode 100644 (file)
index 0000000..902ad3c
--- /dev/null
@@ -0,0 +1,45 @@
+# pyac -- Python YAC implementation
+# Copyright (C) 2024-2025 Antont Rudenko <rudenko.ad@phystech.edu>
+#               2024-2025 Sergey Matveev <stargrave@stargrave.org>
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation, version 3 of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from unittest import TestCase
+
+from pyac import _byte
+from pyac import DecodeError
+from pyac import loads
+from pyac import TagBlob
+from pyac import TagList
+from pyac import TagPInt
+
+
+class TestTooDeepInt(TestCase):
+    def runTest(self) -> None:
+        with self.assertRaises(DecodeError) as err:
+            loads(_byte(TagPInt) + _byte(TagList) * 1000)
+        self.assertEqual(str(err.exception), "unknown tag")
+
+
+class TestTooDeepBlob(TestCase):
+    def runTest(self) -> None:
+        with self.assertRaises(DecodeError) as err:
+            loads(_byte(TagBlob) + (8 * b"\x01") + _byte(TagList) * 1000)
+        self.assertEqual(str(err.exception), "unknown tag")
+
+
+class TestTooDeepList(TestCase):
+    def runTest(self) -> None:
+        with self.assertRaises(DecodeError) as err:
+            loads(_byte(TagList) * 1000)
+        self.assertEqual(str(err.exception), "deep recursion")