]> Cypherpunks repositories - keks.git/commitdiff
Unify returned NotEnoughData sizes
authorSergey Matveev <stargrave@stargrave.org>
Sat, 30 Nov 2024 20:02:05 +0000 (23:02 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Sat, 30 Nov 2024 20:02:06 +0000 (23:02 +0300)
NotEnoughData.n is decided to show how many *more* bytes we require,
not how many bytes at least at all is needed.

pyac/pyac.py
pyac/tests/test_blob.py
pyac/tests/test_float.py
pyac/tests/test_int.py
pyac/tests/test_str.py
pyac/tests/test_tai.py
pyac/tests/test_uuid.py

index be55f31d928c0dba8acf81e37a193e4c8dfac0d188a2bd7ef932e47824bb27b6..59c038ff976b4ce8939775cb866d8deb25e1c40a2dec303d11744141f851d217 100755 (executable)
@@ -286,17 +286,17 @@ def loads(v, sets=False, leapsecUTCAllow=False):
         return True, v[1:]
     if v[0] == TagUUID:
         if len(v) < 1+16:
-            raise NotEnoughData(1+16)
+            raise NotEnoughData(1+16-len(v))
         return UUID(bytes=v[1:1+16]), v[1+16:]
     if v[0] in _floats:
         l = _floats[v[0]]
         if len(v) < 1+l:
-            raise NotEnoughData(1+l)
+            raise NotEnoughData(1+l-len(v))
         return Raw(v[0], v[1:1+l]), v[1+l:]
     if v[0] in _tais:
         l = _tais[v[0]]
         if len(v) < 1+l:
-            raise NotEnoughData(1+l)
+            raise NotEnoughData(1+l - len(v))
         secs = int.from_bytes(v[1:1+8], "big")
         if secs >= (1 << 63):
             raise DecodeError("reserved TAI64 value is in use")
@@ -337,10 +337,10 @@ def loads(v, sets=False, leapsecUTCAllow=False):
             l += ((1 << 8)-1) + ((1 << 16)-1)
         if llen > 0:
             if len(v) < 1+llen:
-                raise NotEnoughData(1+llen)
+                raise NotEnoughData(1+llen-len(v))
             l += int.from_bytes(v[1:1+llen], "big")
         if len(v) < 1+llen+l:
-            raise NotEnoughData(1+llen+l)
+            raise NotEnoughData(1+llen+l-len(v))
         s = v[1+llen:1+llen+l]
         if (v[0] & TagUTF8) > 0:
             try:
@@ -391,7 +391,7 @@ def loads(v, sets=False, leapsecUTCAllow=False):
         return ret, v
     if v[0] == TagBlob:
         if len(v) < 1+8:
-            raise NotEnoughData(1+8)
+            raise NotEnoughData(1+8-len(v))
         l = 1 + int.from_bytes(v[1:1+8], "big")
         v = v[1+8:]
         raws = []
@@ -399,7 +399,7 @@ def loads(v, sets=False, leapsecUTCAllow=False):
             i, v = loads(v)
             if i is None:
                 if len(v) < l:
-                    raise NotEnoughData(l)
+                    raise NotEnoughData(l-len(v)+1)
                 raws.append(v[:l])
                 v = v[l:]
             elif isinstance(i, bytes):
index e50c91af9ef6a1d107f4e327895a4f382f8128b7147e571c9ddd76831b5a2b00..5c6cc344a9f84b9f79d8c49bbf6bcc82506185aa35ea168801d6123b067bb585 100644 (file)
@@ -75,13 +75,13 @@ class TestBlob(TestCase):
         encoded = b"\x0b\x00\x00\x00\x00\x00\x00\x00\x03\x01test\x01da"
         with self.assertRaises(NotEnoughData) as err:
             loads(encoded)
-        self.assertEqual(err.exception.n, 4)
+        self.assertEqual(err.exception.n, 3)
 
     def test_throws_when_not_enough_data_for_length(self) -> None:
         encoded = b"\x0b\x00\x00\x00\x00"
         with self.assertRaises(NotEnoughData) as err:
             loads(encoded)
-        self.assertEqual(err.exception.n, 9)
+        self.assertEqual(err.exception.n, 8-4)
 
     def test_throws_when_wrong_terminator_length(self) -> None:
         encoded = b"\x0b\x00\x00\x00\x00\x00\x00\x00\x03\x01test\x01data\x8Aterminator"
index 96f440ce2f3cce840751d4d34d447b27989dc547347dc13a41fbc51634847ee2..1a279bbe9e98c1c8d9a6c491a78e4a4409e5b08c3435a247b45dab61381e5e14 100644 (file)
@@ -49,24 +49,24 @@ class TestFloat(TestCase):
     def test_not_enough_data_16(self) -> None:
         with self.assertRaises(NotEnoughData) as err:
             loads(b"\x10" + b"\x11" * (2-1))
-        self.assertEqual(err.exception.n, 1+2)
+        self.assertEqual(err.exception.n, 1)
 
     def test_not_enough_data_32(self) -> None:
         with self.assertRaises(NotEnoughData) as err:
             loads(b"\x11" + b"\x11" * (4-1))
-        self.assertEqual(err.exception.n, 1+4)
+        self.assertEqual(err.exception.n, 1)
 
     def test_not_enough_data_64(self) -> None:
         with self.assertRaises(NotEnoughData) as err:
             loads(b"\x12" + b"\x11" * (8-1))
-        self.assertEqual(err.exception.n, 1+8)
+        self.assertEqual(err.exception.n, 1)
 
     def test_not_enough_data_128(self) -> None:
         with self.assertRaises(NotEnoughData) as err:
             loads(b"\x13" + b"\x11" * (16-1))
-        self.assertEqual(err.exception.n, 1+16)
+        self.assertEqual(err.exception.n, 1)
 
     def test_not_enough_data_256(self) -> None:
         with self.assertRaises(NotEnoughData) as err:
             loads(b"\x14" + b"\x11" * (32-1))
-        self.assertEqual(err.exception.n, 1+32)
+        self.assertEqual(err.exception.n, 1)
index 8aa6583adbf3992ff80fb130eb5d8476dedb4f6ec603637cdf5239b6b2b94408..16d3629520aab43fd04f1d249247923418122b9e1cc4d4375df4bd2d8fcc4346 100644 (file)
@@ -84,9 +84,11 @@ class TestInt(TestCase):
         self.assertSequenceEqual(tail, junk)
 
     def test_decode_not_enough_data(self) -> None:
-        encoded: bytes = b"\x0c\x81"
         with self.assertRaises(NotEnoughData) as err:
-            loads(encoded)
+            loads(b"\x0c\x83\x01\x02")
+        self.assertEqual(err.exception.n, 1)
+        with self.assertRaises(NotEnoughData) as err:
+            loads(b"\x0c\x83\x01")
         self.assertEqual(err.exception.n, 2)
 
     def test_throws_when_unminimal_int(self) -> None:
index 25d6287f24d200af4154681cd9c6fdc948ecfd55b139a196cebbf80291f05628..f4cf6228a8a3e915360269bbd235e8802fa6fbf5eb95a3cf42d54d47256da6e5 100644 (file)
@@ -105,12 +105,32 @@ class TestStr(TestCase):
             result, tail = loads(encoded)
         self.assertEqual(str(err.exception), "invalid UTF-8")
 
-    def test_throws_when_not_enough_data_for_length(self) -> None:
+    def test_throws_when_not_enough_data(self) -> None:
+        encoded = b"\xc5he"
+        with self.assertRaises(NotEnoughData) as err:
+            loads(encoded)
+        self.assertEqual(err.exception.n, 5-2)
+
+    def test_throws_when_not_enough_data_for_length_8(self) -> None:
+        long_string = "a" * 100
+        encoded = dumps(long_string)[:1]
+        with self.assertRaises(NotEnoughData) as err:
+            loads(encoded)
+        self.assertEqual(err.exception.n, 1)
+
+    def test_throws_when_not_enough_data_for_length_16(self) -> None:
         long_string = "a" * 318
         encoded = dumps(long_string)[:2]
         with self.assertRaises(NotEnoughData) as err:
             loads(encoded)
-        self.assertEqual(err.exception.n, 3)
+        self.assertEqual(err.exception.n, 1)
+
+    def test_throws_when_not_enough_data_for_length_64(self) -> None:
+        long_string = "a" * 65853
+        encoded = dumps(long_string)[:2]
+        with self.assertRaises(NotEnoughData) as err:
+            loads(encoded)
+        self.assertEqual(err.exception.n, 7)
 
     @given(unicode_allowed)
     def test_symmetric(self, s: str):
@@ -166,14 +186,28 @@ class TestBin(TestCase):
         encoded = b"\x85he"
         with self.assertRaises(NotEnoughData) as err:
             loads(encoded)
-        self.assertEqual(err.exception.n, 6)
+        self.assertEqual(err.exception.n, 5-2)
+
+    def test_throws_when_not_enough_data_for_length_8(self) -> None:
+        long_string = b"a" * 100
+        encoded = dumps(long_string)[:1]
+        with self.assertRaises(NotEnoughData) as err:
+            loads(encoded)
+        self.assertEqual(err.exception.n, 1)
 
-    def test_throws_when_not_enough_data_for_length(self) -> None:
+    def test_throws_when_not_enough_data_for_length_16(self) -> None:
         long_string = b"a" * 318
         encoded = dumps(long_string)[:2]
         with self.assertRaises(NotEnoughData) as err:
             loads(encoded)
-        self.assertEqual(err.exception.n, 3)
+        self.assertEqual(err.exception.n, 1)
+
+    def test_throws_when_not_enough_data_for_length_64(self) -> None:
+        long_string = b"a" * 65853
+        encoded = dumps(long_string)[:2]
+        with self.assertRaises(NotEnoughData) as err:
+            loads(encoded)
+        self.assertEqual(err.exception.n, 7)
 
     @given(binary())
     def test_symmetric(self, s: bytes):
index cc30cf8f71efa2defb0a78d47c04cc64c4311f92950fab29d95620ac5d1e4e2b..d40e2d0c03d4d6c87ab70cbcb0f991b00276d1b6d593806579b88d230781c39c 100644 (file)
@@ -60,8 +60,9 @@ class TestTAI64(TestCase):
         self.assertSequenceEqual(tail, junk)
 
     def test_throws_when_not_enough_data(self) -> None:
-        with self.assertRaises(NotEnoughData):
-            loads(b"\x18" + b"\x00" * 7)
+        with self.assertRaises(NotEnoughData) as err:
+            loads(b"\x18" + b"\x00" * (8-1))
+        self.assertEqual(err.exception.n, 1)
 
     def test_large_number_of_secs(self) -> None:
         decoded, tail = loads(b"\x18\x70\x00\x00\x00\x65\x19\x5f\x65")
@@ -85,8 +86,9 @@ class TestTAI64N(TestCase):
         self.assertSequenceEqual(tail, junk)
 
     def test_throws_when_not_enough_data(self) -> None:
-        with self.assertRaises(NotEnoughData):
-            loads(b"\x19" + b"\x00" * 11)
+        with self.assertRaises(NotEnoughData) as err:
+            loads(b"\x19" + b"\x00" * (12-2))
+        self.assertEqual(err.exception.n, 2)
 
     def test_nanoseconds_not_convertible_to_microseconds(self) -> None:
         decoded, tail = loads(
@@ -149,8 +151,9 @@ class TestTAI64NA(TestCase):
         self.assertSequenceEqual(tail, junk)
 
     def test_throws_when_not_enough_data(self) -> None:
-        with self.assertRaises(NotEnoughData):
-            loads(b"\x1a" + b"\x00" * 15)
+        with self.assertRaises(NotEnoughData) as err:
+            loads(b"\x1a" + b"\x00" * (16-3))
+        self.assertEqual(err.exception.n, 3)
 
     def test_throws_when_too_many_attosecs(self) -> None:
         with self.assertRaises(DecodeError) as err:
index 7722d95975dbf54933d49d8809c260d82114d235b6756e1e669619771b5c7f94..e95002ff905fb0e5093e67b6b0ef2034bee5c99fafe82fca152deaa16739af32 100644 (file)
@@ -33,7 +33,7 @@ class TestUUID(TestCase):
         encoded: bytes = b"\x04\x12\x34\x56\x78\x12\x34\x56\x78\x12\x34\x56\x78\x12\x34\x56\x78"
         with self.assertRaises(NotEnoughData) as err:
             loads(encoded[:-4])
-        self.assertEqual(err.exception.n, 1+16)
+        self.assertEqual(err.exception.n, 4)
 
     @given(uuids())
     def test_symmetric(self, u: UUID) -> None: