From 42f591df728ed0debb131e9a46ebbc59b395cbe8753a18629acc42b82a01abdb Mon Sep 17 00:00:00 2001 From: Sergey Matveev Date: Sat, 30 Nov 2024 23:02:05 +0300 Subject: [PATCH] Unify returned NotEnoughData sizes NotEnoughData.n is decided to show how many *more* bytes we require, not how many bytes at least at all is needed. --- pyac/pyac.py | 14 ++++++------- pyac/tests/test_blob.py | 4 ++-- pyac/tests/test_float.py | 10 ++++----- pyac/tests/test_int.py | 6 ++++-- pyac/tests/test_str.py | 44 +++++++++++++++++++++++++++++++++++----- pyac/tests/test_tai.py | 15 ++++++++------ pyac/tests/test_uuid.py | 2 +- 7 files changed, 67 insertions(+), 28 deletions(-) diff --git a/pyac/pyac.py b/pyac/pyac.py index be55f31..59c038f 100755 --- a/pyac/pyac.py +++ b/pyac/pyac.py @@ -286,17 +286,17 @@ def loads(v, sets=False, leapsecUTCAllow=False): return True, v[1:] if v[0] == TagUUID: if len(v) < 1+16: - raise NotEnoughData(1+16) + raise NotEnoughData(1+16-len(v)) return UUID(bytes=v[1:1+16]), v[1+16:] if v[0] in _floats: l = _floats[v[0]] if len(v) < 1+l: - raise NotEnoughData(1+l) + raise NotEnoughData(1+l-len(v)) return Raw(v[0], v[1:1+l]), v[1+l:] if v[0] in _tais: l = _tais[v[0]] if len(v) < 1+l: - raise NotEnoughData(1+l) + raise NotEnoughData(1+l - len(v)) secs = int.from_bytes(v[1:1+8], "big") if secs >= (1 << 63): raise DecodeError("reserved TAI64 value is in use") @@ -337,10 +337,10 @@ def loads(v, sets=False, leapsecUTCAllow=False): l += ((1 << 8)-1) + ((1 << 16)-1) if llen > 0: if len(v) < 1+llen: - raise NotEnoughData(1+llen) + raise NotEnoughData(1+llen-len(v)) l += int.from_bytes(v[1:1+llen], "big") if len(v) < 1+llen+l: - raise NotEnoughData(1+llen+l) + raise NotEnoughData(1+llen+l-len(v)) s = v[1+llen:1+llen+l] if (v[0] & TagUTF8) > 0: try: @@ -391,7 +391,7 @@ def loads(v, sets=False, leapsecUTCAllow=False): return ret, v if v[0] == TagBlob: if len(v) < 1+8: - raise NotEnoughData(1+8) + raise NotEnoughData(1+8-len(v)) l = 1 + int.from_bytes(v[1:1+8], "big") v = v[1+8:] raws = [] @@ -399,7 +399,7 @@ def loads(v, sets=False, leapsecUTCAllow=False): i, v = loads(v) if i is None: if len(v) < l: - raise NotEnoughData(l) + raise NotEnoughData(l-len(v)+1) raws.append(v[:l]) v = v[l:] elif isinstance(i, bytes): diff --git a/pyac/tests/test_blob.py b/pyac/tests/test_blob.py index e50c91a..5c6cc34 100644 --- a/pyac/tests/test_blob.py +++ b/pyac/tests/test_blob.py @@ -75,13 +75,13 @@ class TestBlob(TestCase): encoded = b"\x0b\x00\x00\x00\x00\x00\x00\x00\x03\x01test\x01da" with self.assertRaises(NotEnoughData) as err: loads(encoded) - self.assertEqual(err.exception.n, 4) + self.assertEqual(err.exception.n, 3) def test_throws_when_not_enough_data_for_length(self) -> None: encoded = b"\x0b\x00\x00\x00\x00" with self.assertRaises(NotEnoughData) as err: loads(encoded) - self.assertEqual(err.exception.n, 9) + self.assertEqual(err.exception.n, 8-4) def test_throws_when_wrong_terminator_length(self) -> None: encoded = b"\x0b\x00\x00\x00\x00\x00\x00\x00\x03\x01test\x01data\x8Aterminator" diff --git a/pyac/tests/test_float.py b/pyac/tests/test_float.py index 96f440c..1a279bb 100644 --- a/pyac/tests/test_float.py +++ b/pyac/tests/test_float.py @@ -49,24 +49,24 @@ class TestFloat(TestCase): def test_not_enough_data_16(self) -> None: with self.assertRaises(NotEnoughData) as err: loads(b"\x10" + b"\x11" * (2-1)) - self.assertEqual(err.exception.n, 1+2) + self.assertEqual(err.exception.n, 1) def test_not_enough_data_32(self) -> None: with self.assertRaises(NotEnoughData) as err: loads(b"\x11" + b"\x11" * (4-1)) - self.assertEqual(err.exception.n, 1+4) + self.assertEqual(err.exception.n, 1) def test_not_enough_data_64(self) -> None: with self.assertRaises(NotEnoughData) as err: loads(b"\x12" + b"\x11" * (8-1)) - self.assertEqual(err.exception.n, 1+8) + self.assertEqual(err.exception.n, 1) def test_not_enough_data_128(self) -> None: with self.assertRaises(NotEnoughData) as err: loads(b"\x13" + b"\x11" * (16-1)) - self.assertEqual(err.exception.n, 1+16) + self.assertEqual(err.exception.n, 1) def test_not_enough_data_256(self) -> None: with self.assertRaises(NotEnoughData) as err: loads(b"\x14" + b"\x11" * (32-1)) - self.assertEqual(err.exception.n, 1+32) + self.assertEqual(err.exception.n, 1) diff --git a/pyac/tests/test_int.py b/pyac/tests/test_int.py index 8aa6583..16d3629 100644 --- a/pyac/tests/test_int.py +++ b/pyac/tests/test_int.py @@ -84,9 +84,11 @@ class TestInt(TestCase): self.assertSequenceEqual(tail, junk) def test_decode_not_enough_data(self) -> None: - encoded: bytes = b"\x0c\x81" with self.assertRaises(NotEnoughData) as err: - loads(encoded) + loads(b"\x0c\x83\x01\x02") + self.assertEqual(err.exception.n, 1) + with self.assertRaises(NotEnoughData) as err: + loads(b"\x0c\x83\x01") self.assertEqual(err.exception.n, 2) def test_throws_when_unminimal_int(self) -> None: diff --git a/pyac/tests/test_str.py b/pyac/tests/test_str.py index 25d6287..f4cf622 100644 --- a/pyac/tests/test_str.py +++ b/pyac/tests/test_str.py @@ -105,12 +105,32 @@ class TestStr(TestCase): result, tail = loads(encoded) self.assertEqual(str(err.exception), "invalid UTF-8") - def test_throws_when_not_enough_data_for_length(self) -> None: + def test_throws_when_not_enough_data(self) -> None: + encoded = b"\xc5he" + with self.assertRaises(NotEnoughData) as err: + loads(encoded) + self.assertEqual(err.exception.n, 5-2) + + def test_throws_when_not_enough_data_for_length_8(self) -> None: + long_string = "a" * 100 + encoded = dumps(long_string)[:1] + with self.assertRaises(NotEnoughData) as err: + loads(encoded) + self.assertEqual(err.exception.n, 1) + + def test_throws_when_not_enough_data_for_length_16(self) -> None: long_string = "a" * 318 encoded = dumps(long_string)[:2] with self.assertRaises(NotEnoughData) as err: loads(encoded) - self.assertEqual(err.exception.n, 3) + self.assertEqual(err.exception.n, 1) + + def test_throws_when_not_enough_data_for_length_64(self) -> None: + long_string = "a" * 65853 + encoded = dumps(long_string)[:2] + with self.assertRaises(NotEnoughData) as err: + loads(encoded) + self.assertEqual(err.exception.n, 7) @given(unicode_allowed) def test_symmetric(self, s: str): @@ -166,14 +186,28 @@ class TestBin(TestCase): encoded = b"\x85he" with self.assertRaises(NotEnoughData) as err: loads(encoded) - self.assertEqual(err.exception.n, 6) + self.assertEqual(err.exception.n, 5-2) + + def test_throws_when_not_enough_data_for_length_8(self) -> None: + long_string = b"a" * 100 + encoded = dumps(long_string)[:1] + with self.assertRaises(NotEnoughData) as err: + loads(encoded) + self.assertEqual(err.exception.n, 1) - def test_throws_when_not_enough_data_for_length(self) -> None: + def test_throws_when_not_enough_data_for_length_16(self) -> None: long_string = b"a" * 318 encoded = dumps(long_string)[:2] with self.assertRaises(NotEnoughData) as err: loads(encoded) - self.assertEqual(err.exception.n, 3) + self.assertEqual(err.exception.n, 1) + + def test_throws_when_not_enough_data_for_length_64(self) -> None: + long_string = b"a" * 65853 + encoded = dumps(long_string)[:2] + with self.assertRaises(NotEnoughData) as err: + loads(encoded) + self.assertEqual(err.exception.n, 7) @given(binary()) def test_symmetric(self, s: bytes): diff --git a/pyac/tests/test_tai.py b/pyac/tests/test_tai.py index cc30cf8..d40e2d0 100644 --- a/pyac/tests/test_tai.py +++ b/pyac/tests/test_tai.py @@ -60,8 +60,9 @@ class TestTAI64(TestCase): self.assertSequenceEqual(tail, junk) def test_throws_when_not_enough_data(self) -> None: - with self.assertRaises(NotEnoughData): - loads(b"\x18" + b"\x00" * 7) + with self.assertRaises(NotEnoughData) as err: + loads(b"\x18" + b"\x00" * (8-1)) + self.assertEqual(err.exception.n, 1) def test_large_number_of_secs(self) -> None: decoded, tail = loads(b"\x18\x70\x00\x00\x00\x65\x19\x5f\x65") @@ -85,8 +86,9 @@ class TestTAI64N(TestCase): self.assertSequenceEqual(tail, junk) def test_throws_when_not_enough_data(self) -> None: - with self.assertRaises(NotEnoughData): - loads(b"\x19" + b"\x00" * 11) + with self.assertRaises(NotEnoughData) as err: + loads(b"\x19" + b"\x00" * (12-2)) + self.assertEqual(err.exception.n, 2) def test_nanoseconds_not_convertible_to_microseconds(self) -> None: decoded, tail = loads( @@ -149,8 +151,9 @@ class TestTAI64NA(TestCase): self.assertSequenceEqual(tail, junk) def test_throws_when_not_enough_data(self) -> None: - with self.assertRaises(NotEnoughData): - loads(b"\x1a" + b"\x00" * 15) + with self.assertRaises(NotEnoughData) as err: + loads(b"\x1a" + b"\x00" * (16-3)) + self.assertEqual(err.exception.n, 3) def test_throws_when_too_many_attosecs(self) -> None: with self.assertRaises(DecodeError) as err: diff --git a/pyac/tests/test_uuid.py b/pyac/tests/test_uuid.py index 7722d95..e95002f 100644 --- a/pyac/tests/test_uuid.py +++ b/pyac/tests/test_uuid.py @@ -33,7 +33,7 @@ class TestUUID(TestCase): encoded: bytes = b"\x04\x12\x34\x56\x78\x12\x34\x56\x78\x12\x34\x56\x78\x12\x34\x56\x78" with self.assertRaises(NotEnoughData) as err: loads(encoded[:-4]) - self.assertEqual(err.exception.n, 1+16) + self.assertEqual(err.exception.n, 4) @given(uuids()) def test_symmetric(self, u: UUID) -> None: -- 2.48.1