From 730f0e1516783b0a2e15195476afacd279a18a4d72bed42b22a5811d2402bb55 Mon Sep 17 00:00:00 2001 From: Sergey Matveev Date: Wed, 4 Dec 2024 19:14:46 +0300 Subject: [PATCH] Prevent uncaught recursion limits during decoding --- pyac/pyac.py | 39 ++++++++++++++++++------------- pyac/tests/test_recursion.py | 45 ++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 16 deletions(-) create mode 100644 pyac/tests/test_recursion.py diff --git a/pyac/pyac.py b/pyac/pyac.py index 6763243..e847a5e 100755 --- a/pyac/pyac.py +++ b/pyac/pyac.py @@ -232,7 +232,7 @@ def dumps(v): def _int(v): - s, tail = loads(v) + s, tail = _loads(v, _allowContainers=False) if not isinstance(s, bytes): raise DecodeError("non-BIN in INT") if s == b"": @@ -264,16 +264,7 @@ def tai2utc(secs, leapsecUTCAllow=False): return secs - diff -def loads(v, sets=False, leapsecUTCAllow=False): - """Decode YAC-encoded data. - - :param bool sets: transform maps with NIL-only values to set()s - :param bool leapsecUTCAllow: allow TAI64 values equal to leap seconds, - to be decoded as datetime UTC value. Raw() - value will be returned instead - :returns: decoded data and the undecoded tail - :rtype: (any, bytes) - """ +def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True): if len(v) == 0: raise NotEnoughData(1) if v[0] == TagEOC: @@ -358,7 +349,7 @@ def loads(v, sets=False, leapsecUTCAllow=False): if v[0] == TagNInt: i, v = _int(v[1:]) return (-1 - i), v - if v[0] == TagList: + if (v[0] == TagList) and _allowContainers: ret = [] v = v[1:] while True: @@ -367,13 +358,13 @@ def loads(v, sets=False, leapsecUTCAllow=False): break ret.append(i) return ret, v - if v[0] == TagMap: + if (v[0] == TagMap) and _allowContainers: ret = {} v = v[1:] kPrev = "" allNILs = True while True: - k, v = loads(v) + k, v = _loads(v, _allowContainers=False) if k == _EOC: break if not isinstance(k, str): @@ -392,14 +383,14 @@ def loads(v, sets=False, leapsecUTCAllow=False): if sets and allNILs: ret = set(ret.keys()) return ret, v - if v[0] == TagBlob: + if (v[0] == TagBlob) and _allowContainers: if len(v) < 1+8: raise NotEnoughData(1+8-len(v)) l = 1 + int.from_bytes(v[1:1+8], "big") v = v[1+8:] raws = [] while True: - i, v = loads(v) + i, v = _loads(v, _allowContainers=False) if i is None: if len(v) < l: raise NotEnoughData(l-len(v)+1) @@ -416,6 +407,22 @@ def loads(v, sets=False, leapsecUTCAllow=False): raise DecodeError("unknown tag") +def loads(v, **kwargs): + """Decode YAC-encoded data. + + :param bool sets: transform maps with NIL-only values to set()s + :param bool leapsecUTCAllow: allow TAI64 values equal to leap seconds, + to be decoded as datetime UTC value. Raw() + value will be returned instead + :returns: decoded data and the undecoded tail + :rtype: (any, bytes) + """ + try: + return _loads(v, **kwargs) + except RecursionError as err: + raise DecodeError("deep recursion") from err + + if __name__ == "__main__": from argparse import ArgumentParser from argparse import FileType diff --git a/pyac/tests/test_recursion.py b/pyac/tests/test_recursion.py new file mode 100644 index 0000000..902ad3c --- /dev/null +++ b/pyac/tests/test_recursion.py @@ -0,0 +1,45 @@ +# pyac -- Python YAC implementation +# Copyright (C) 2024-2025 Antont Rudenko +# 2024-2025 Sergey Matveev +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, version 3 of the License. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this program. If not, see . + +from unittest import TestCase + +from pyac import _byte +from pyac import DecodeError +from pyac import loads +from pyac import TagBlob +from pyac import TagList +from pyac import TagPInt + + +class TestTooDeepInt(TestCase): + def runTest(self) -> None: + with self.assertRaises(DecodeError) as err: + loads(_byte(TagPInt) + _byte(TagList) * 1000) + self.assertEqual(str(err.exception), "unknown tag") + + +class TestTooDeepBlob(TestCase): + def runTest(self) -> None: + with self.assertRaises(DecodeError) as err: + loads(_byte(TagBlob) + (8 * b"\x01") + _byte(TagList) * 1000) + self.assertEqual(str(err.exception), "unknown tag") + + +class TestTooDeepList(TestCase): + def runTest(self) -> None: + with self.assertRaises(DecodeError) as err: + loads(_byte(TagList) * 1000) + self.assertEqual(str(err.exception), "deep recursion") -- 2.50.0