]> Cypherpunks repositories - keks.git/commitdiff
No need in Raw's tag separation from the body
authorSergey Matveev <stargrave@stargrave.org>
Wed, 18 Dec 2024 08:02:55 +0000 (11:02 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Wed, 18 Dec 2024 08:02:55 +0000 (11:02 +0300)
py3/keks.py
py3/test-vector.py
py3/tests/test_float.py
py3/tests/test_fuzz_inputs.py
py3/tests/test_generic.py
py3/tests/test_tai.py

index 58d2593df8b1607c7c4b1b783b07d21a4ab8c3fcaf05e3d5a2927d971e10e114..72e3163a6ed290789d49087c1f72f4871b1dfa4b6036354c9a16bc363a100b50 100755 (executable)
@@ -96,7 +96,24 @@ class NotEnoughData(DecodeError):
         return "%s(%s)" % (self.__class__.__name__, self)
 
 
-Raw = namedtuple("Raw", ("t", "v"))
+class Raw:
+    __slots__ = ("v",)
+
+    def __init__(self, v: bytes):
+        self.v = v
+
+    def __bytes__(self) -> bytes:
+        return self.v
+
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, self.__class__):
+            return False
+        return self.v == other.v
+
+    def __repr__(self) -> str:
+        return "Raw(%s)" % self.v.hex()
+
+
 Blob = namedtuple("Blob", ("l", "v"))
 
 
@@ -198,7 +215,7 @@ def dumps(v):
             (ms * 1000).to_bytes(4, "big")
         )
     if isinstance(v, Raw):
-        return _byte(v.t) + v.v
+        return v.v
     if isinstance(v, bytes):
         return _str(v, utf8=False)
     if isinstance(v, str):
@@ -300,7 +317,7 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True):
     if l is not None:
         if len(v) < 1+l:
             raise NotEnoughData(1+l-len(v))
-        return Raw(v[0], v[1:1+l]), v[1+l:]
+        return Raw(v[:1+l]), v[1+l:]
     l = _tais.get(b)
     if l is not None:
         if len(v) < 1+l:
@@ -324,14 +341,14 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True):
                 raise DecodeError("too many attoseconds")
         secs = tai2utc(secs - TAI64Base, leapsecUTCAllow)
         if secs is None:
-            return Raw(v[0], v[1:1+l]), v[1+l:]
+            return Raw(v[:1+l]), v[1+l:]
         if (asecs > 0) or ((nsecs % 1000) > 0):
             # Python can represent neither attoseconds, nor nanoseconds
-            return Raw(v[0], v[1:1+l]), v[1+l:]
+            return Raw(v[:1+l]), v[1+l:]
         try:
             dt = datetime(1970, 1, 1) + timedelta(seconds=secs)
         except OverflowError:
-            return Raw(v[0], v[1:1+l]), v[1+l:]
+            return Raw(v[:1+l]), v[1+l:]
         dt += timedelta(microseconds=nsecs // 1000)
         return dt, v[1+l:]
     if (b & TagStr) > 0:
index e36061e94d62c288e7ca8c5a4070419a8bd1f2a7095805d588fee282e6bf4164..60156f98c0ea4540702a6767165db730b752a8cec989200edc23bd5909c0b495 100644 (file)
@@ -16,7 +16,7 @@ data = {
             -123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789,
         ],
     },
-    "floats": [keks.Raw(keks.TagFloat32, bytes.fromhex("01020304"))],
+    "floats": [keks.Raw(keks._byte(keks.TagFloat32) + bytes.fromhex("01020304"))],
     "nil": None,
     "bool": [True, False],
     "str": {
@@ -43,15 +43,15 @@ data = {
         {},
         keks.Blob(123, b""),
         UUID("00000000-0000-0000-0000-000000000000"),
-        keks.Raw(keks.TagTAI64, bytes.fromhex("0000000000000000")),
+        keks.Raw(keks._byte(keks.TagTAI64) + bytes.fromhex("0000000000000000")),
     ],
     "uuid": UUID("0e875e3f-d385-49eb-87b4-be42d641c367"),
 }
 data["dates"] = [
     (datetime(1970, 1, 1) + timedelta(seconds=1234567890)),
     (datetime(1970, 1, 1) + timedelta(seconds=1234567890)).replace(microsecond=456),
-    keks.Raw(keks.TagTAI64N, bytes.fromhex("40000000499602F40006F855")),
-    keks.Raw(keks.TagTAI64NA, bytes.fromhex("40000000499602F40006F855075BCD15")),
+    keks.Raw(keks._byte(keks.TagTAI64N) + bytes.fromhex("40000000499602F40006F855")),
+    keks.Raw(keks._byte(keks.TagTAI64NA) + bytes.fromhex("40000000499602F40006F855075BCD15")),
 ]
 raw = keks.dumps(data)
 dec, tail = keks.loads(raw)
index af90c552cf6c07bf2eab5df798d6d496fc7f06f5d2f8645b11c456a837c17b28..5ff56857fcd91d6d94947835633458933c2e294f4f1b15fb6a3a808e0bfa01d6 100644 (file)
@@ -18,6 +18,7 @@ from unittest import TestCase
 
 from hypothesis import given
 
+from keks import _byte
 from keks import dumps
 from keks import loads
 from keks import NotEnoughData
@@ -35,31 +36,31 @@ class TestFloat(TestCase):
     @given(junk_st)
     def test_loads_16(self, junk: bytes) -> None:
         decoded, tail = loads((b"\x10" + b"\x11" * 2) + junk)
-        self.assertEqual(decoded, Raw(0x10, b"\x11" * 2))
+        self.assertEqual(decoded, Raw(_byte(0x10) + b"\x11" * 2))
         self.assertSequenceEqual(tail, junk)
 
     @given(junk_st)
     def test_loads_32(self, junk: bytes) -> None:
         decoded, tail = loads((b"\x11" + b"\x11" * 4) + junk)
-        self.assertEqual(decoded, Raw(0x11, b"\x11" * 4))
+        self.assertEqual(decoded, Raw(_byte(0x11) + b"\x11" * 4))
         self.assertSequenceEqual(tail, junk)
 
     @given(junk_st)
     def test_loads_64(self, junk: bytes) -> None:
         decoded, tail = loads((b"\x12" + b"\x11" * 8) + junk)
-        self.assertEqual(decoded, Raw(0x12, b"\x11" * 8))
+        self.assertEqual(decoded, Raw(_byte(0x12) + b"\x11" * 8))
         self.assertSequenceEqual(tail, junk)
 
     @given(junk_st)
     def test_loads_128(self, junk: bytes) -> None:
         decoded, tail = loads((b"\x13" + b"\x11" * 16) + junk)
-        self.assertEqual(decoded, Raw(0x13, b"\x11" * 16))
+        self.assertEqual(decoded, Raw(_byte(0x13) + b"\x11" * 16))
         self.assertSequenceEqual(tail, junk)
 
     @given(junk_st)
     def test_loads_256(self, junk: bytes) -> None:
         decoded, tail = loads((b"\x14" + b"\x11" * 32) + junk)
-        self.assertEqual(decoded, Raw(0x14, b"\x11" * 32))
+        self.assertEqual(decoded, Raw(_byte(0x14) + b"\x11" * 32))
         self.assertSequenceEqual(tail, junk)
 
     def test_not_enough_data_16(self) -> None:
index b95b8557f58ebb22db07d4b9f01e01500c64a46a419f4377027dd583ad67f161..0b78db7d65f0b873eabd92a722147a51d1e3f2d61d465f7098eb6ce91249da00 100644 (file)
@@ -6,6 +6,7 @@ from os.path import join as path_join
 from unittest import skipIf
 from unittest import TestCase
 
+from keks import _byte
 from keks import Blob
 from keks import Leapsecs1972
 from keks import loads
@@ -67,13 +68,13 @@ class TestFuzzInputs(TestCase):
         self.assertEqual(readInput("tai-utc0"), datetime(1970, 1, 1, 0, 0))
         self.assertEqual(readInput("tai-before"), datetime(1969, 12, 31, 23, 59, 49))
         self.assertEqual(readInput("tai-leap"), Raw(
-            t=TagTAI64, v=bytes.fromhex("40000000586846A4"),
+            _byte(TagTAI64) + bytes.fromhex("40000000586846A4"),
         ))
-        self.assertEqual(readInput("tai-ns"), Raw(t=TagTAI64N, v=(
+        self.assertEqual(readInput("tai-ns"), Raw(_byte(TagTAI64N) + (
             (TAI64Base + Leapsecs1972 + 1234).to_bytes(8, "big") +
             (1234).to_bytes(4, "big")
         )))
-        self.assertEqual(readInput("tai-as"), Raw(t=TagTAI64NA, v=(
+        self.assertEqual(readInput("tai-as"), Raw(_byte(TagTAI64NA) + (
             (TAI64Base + Leapsecs1972 + 1234).to_bytes(8, "big") +
             2 * (1234).to_bytes(4, "big")
         )))
index 2cde6f5810cc7ef547c9f5132488994cc061047b987f9413a9ae6180d113b483..386882300361a5c2c95bb965c0cee21af57fe0cbdda836261ed711adab987989 100644 (file)
@@ -50,9 +50,9 @@ class TestEmptyData(TestCase):
 
 
 class TestRaw(TestCase):
-    @given(binary(min_size=1, max_size=1), binary(max_size=8))
-    def runTest(self, hdr: bytes, body: bytes) -> None:
-        self.assertSequenceEqual(dumps(Raw(hdr[0], body)), hdr + body)
+    @given(binary(max_size=8))
+    def runTest(self, body: bytes) -> None:
+        self.assertSequenceEqual(dumps(Raw(body)), body)
 
 
 class TestLonelyEOC(TestCase):
index 868d34245d669796bdde8e005eb43e3fc0a781b24c88cc245217a05ffdd2fab9..b77208914be45f99706fa5b85a48edfdd6d0236bd8b18455e91a593fd569446e 100644 (file)
@@ -83,7 +83,9 @@ class TestTAI64(TestCase):
 
     def test_large_number_of_secs(self) -> None:
         decoded, tail = loads(bytes.fromhex("187000000065195F65"))
-        self.assertEqual(decoded, Raw(t=0x18, v=bytes.fromhex("7000000065195F65")))
+        self.assertEqual(decoded, Raw(
+            _byte(0x18) + bytes.fromhex("7000000065195F65")),
+        )
         self.assertSequenceEqual(tail, b"")
 
     def test_throws_when_msb_is_set(self) -> None:
@@ -120,7 +122,7 @@ class TestTAI64N(TestCase):
         decoded, tail = loads(bytes.fromhex("194000000065195F65075BCA01"))
         self.assertEqual(
             decoded,
-            Raw(t=0x19, v=bytes.fromhex("4000000065195F65075BCA01")),
+            Raw(_byte(0x19) + bytes.fromhex("4000000065195F65075BCA01")),
         )
         self.assertSequenceEqual(tail, b"")
 
@@ -145,10 +147,11 @@ class TestTAI64N(TestCase):
             bytes.fromhex("197000000065195F65") +
             bytes.fromhex("00010000")
         )
-        self.assertEqual(
-            decoded,
-            Raw(t=0x19, v=bytes.fromhex("7000000065195F65") + bytes.fromhex("00010000")),
-        )
+        self.assertEqual(decoded, Raw(
+            _byte(0x19) +
+            bytes.fromhex("7000000065195F65") +
+            bytes.fromhex("00010000"),
+        ))
         self.assertSequenceEqual(tail, b"")
 
     def test_throws_when_msb_is_set(self) -> None:
@@ -164,7 +167,7 @@ class TestTAI64NA(TestCase):
     @given(junk_st)
     def test_decode(self, junk: bytes) -> None:
         encoded = bytes.fromhex("1A4000000065195F65075BCA00075BCA00") + junk
-        expected = Raw(t=0x1A, v=bytes.fromhex("4000000065195F65075BCA00075BCA00"))
+        expected = Raw(_byte(0x1A) + bytes.fromhex("4000000065195F65075BCA00075BCA00"))
         decoded, tail = loads(encoded)
         self.assertEqual(decoded, expected)
         self.assertSequenceEqual(tail, junk)
@@ -202,8 +205,9 @@ class TestTAI64NA(TestCase):
             bytes.fromhex("00010000")
         )
         self.assertEqual(decoded, Raw(
-            t=0x1A,
-            v=bytes.fromhex("7000000065195F65") + 2 * bytes.fromhex("00010000"),
+            _byte(0x1A) +
+            bytes.fromhex("7000000065195F65") +
+            2 * bytes.fromhex("00010000"),
         ))
         self.assertSequenceEqual(tail, b"")
 
@@ -265,4 +269,4 @@ class TestLeapsecs(TestCase):
                 leapsecUTCAllow=False,
             )
             self.assertIsInstance(decoded, Raw)
-            self.assertSequenceEqual(decoded.v, bytes.fromhex(leapsec))
+            self.assertSequenceEqual(decoded.v[1:], bytes.fromhex(leapsec))