]> Cypherpunks repositories - keks.git/commitdiff
Split long strings tests
authorSergey Matveev <stargrave@stargrave.org>
Sat, 30 Nov 2024 18:15:37 +0000 (21:15 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Sat, 30 Nov 2024 19:38:56 +0000 (22:38 +0300)
pyac/tests/test_str.py

index 2b808554b5142b6f5b13beab6f6799f5c8e30f171b343dbf92b0c3bfaedcc4e8..37504d71493b3dd4c56686f18effd587ebf8dde7e7cbdb16e081ee9373180952 100644 (file)
@@ -2,12 +2,12 @@ from unittest import TestCase
 
 from hypothesis import given
 from hypothesis.strategies import integers
-from hypothesis.strategies import text
 
 from pyac import DecodeError
 from pyac import dumps
 from pyac import loads
 from pyac import NotEnoughData
+from tests.strategies import junk_st
 
 
 TagStr: int = 0x80
@@ -26,21 +26,63 @@ invalid_utf8_3byte_s = integers(min_value=0, max_value=(1 << 10) - 1).map(
 
 
 class TestStr(TestCase):
-    @given(text(max_size=60))
-    def test_encode(self, test_str: str) -> None:
-        if(len(test_str.encode("utf-8")) > 60):
-            return
-        encoded = dumps(test_str)
-        tag = (TagStr | TagUTF8 | len(test_str.encode("utf-8"))).to_bytes(1, "big")
-        self.assertSequenceEqual(encoded, tag + test_str.encode("utf-8"))
-
-    def test_long(self) -> None:
-        long_strings = ["a" * 62, "a" * 318, "a" * 65853]
-        for long_string in long_strings:
-            encoded = dumps(long_string)
-            decoded, tail = loads(encoded)
-            self.assertSequenceEqual(decoded, long_string)
-            self.assertSequenceEqual(tail, b"")
+    @given(junk_st)
+    def test_empty(self, junk: bytes) -> None:
+        encoded: bytes = dumps("")
+        self.assertSequenceEqual(encoded, b"\xc0")
+        decoded, tail = loads(encoded + junk)
+        self.assertSequenceEqual(decoded, "")
+        self.assertSequenceEqual(tail, junk)
+
+    @given(junk_st)
+    def test_hello_world(self, junk: bytes) -> None:
+        s: str = "hello world"
+        encoded: bytes = dumps(s)
+        self.assertSequenceEqual(
+            encoded,
+            b"\xcb\x68\x65\x6c\x6c\x6f\x20\x77\x6f\x72\x6c\x64",
+        )
+        decoded, tail = loads(encoded + junk)
+        self.assertSequenceEqual(decoded, s)
+        self.assertSequenceEqual(tail, junk)
+
+    @given(junk_st)
+    def test_привет_мир(self, junk: bytes) -> None:
+        s: str = "привет мир"
+        encoded: bytes = dumps(s)
+        self.assertSequenceEqual(
+            encoded,
+            b"\xd3\xd0\xbf\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82\x20\xd0\xbc\xd0\xb8\xd1\x80",
+        )
+        decoded, tail = loads(encoded + junk)
+        self.assertSequenceEqual(decoded, s)
+        self.assertSequenceEqual(tail, junk)
+
+    @given(junk_st)
+    def test_len62(self, junk: bytes) -> None:
+        s: str = "a" * 62
+        encoded = dumps(s)
+        decoded, tail = loads(encoded + junk)
+        self.assertSequenceEqual(decoded, s)
+        self.assertSequenceEqual(tail, junk)
+
+    @given(junk_st)
+    def test_len318(self, junk: bytes) -> None:
+        assert 318 == 62 + 255 + 1
+        s: str = "a" * 318
+        encoded = dumps(s)
+        decoded, tail = loads(encoded + junk)
+        self.assertSequenceEqual(decoded, s)
+        self.assertSequenceEqual(tail, junk)
+
+    @given(junk_st)
+    def test_len65853(self, junk: bytes) -> None:
+        assert 65853 == 62 + 255 + 65535 + 1
+        s: str = "a" * 65853
+        encoded = dumps(s)
+        decoded, tail = loads(encoded + junk)
+        self.assertSequenceEqual(decoded, s)
+        self.assertSequenceEqual(tail, junk)
 
     def test_throws_when_null_byte_in_utf(self) -> None:
         with self.assertRaises(DecodeError) as err:
@@ -70,21 +112,48 @@ class TestStr(TestCase):
 
 
 class TestBin(TestCase):
-    def test_encode(self) -> None:
-        bs = b"\x00\x01\x02"
-        encoded = dumps(bs)
-        self.assertSequenceEqual(encoded, b"\x83\x00\x01\x02")
-        decoded, tail = loads(encoded)
-        self.assertSequenceEqual(decoded, b"\x00\x01\x02")
-        self.assertSequenceEqual(tail, b"")
-
-    def test_long(self) -> None:
-        long_bss = [b"\x01" * 62, b"\x01" * 318, b"\x01" * 65853]
-        for long_bs in long_bss:
-            encoded = dumps(long_bs)
-            decoded, tail = loads(encoded)
-            self.assertSequenceEqual(decoded, long_bs)
-            self.assertSequenceEqual(tail, b"")
+    @given(junk_st)
+    def test_empty(self, junk: bytes) -> None:
+        encoded: bytes = dumps(b"")
+        self.assertSequenceEqual(encoded, b"\x80")
+        decoded, tail = loads(encoded + junk)
+        self.assertSequenceEqual(decoded, b"")
+        self.assertSequenceEqual(tail, junk)
+
+    @given(junk_st)
+    def test_1234(self, junk: bytes) -> None:
+        s: bytes = b"\x01\x02\x03\x04"
+        encoded = dumps(s)
+        self.assertSequenceEqual(encoded, b"\x84\x01\x02\x03\x04")
+        decoded, tail = loads(encoded + junk)
+        self.assertSequenceEqual(decoded, s)
+        self.assertSequenceEqual(tail, junk)
+
+    @given(junk_st)
+    def test_len62(self, junk: bytes) -> None:
+        s: bytes = b"a" * 62
+        encoded = dumps(s)
+        decoded, tail = loads(encoded + junk)
+        self.assertSequenceEqual(decoded, s)
+        self.assertSequenceEqual(tail, junk)
+
+    @given(junk_st)
+    def test_len318(self, junk: bytes) -> None:
+        assert 318 == 62 + 255 + 1
+        s: bytes = b"a" * 318
+        encoded = dumps(s)
+        decoded, tail = loads(encoded + junk)
+        self.assertSequenceEqual(decoded, s)
+        self.assertSequenceEqual(tail, junk)
+
+    @given(junk_st)
+    def test_len65853(self, junk: bytes) -> None:
+        assert 65853 == 62 + 255 + 65535 + 1
+        s: bytes = b"a" * 65853
+        encoded = dumps(s)
+        decoded, tail = loads(encoded + junk)
+        self.assertSequenceEqual(decoded, s)
+        self.assertSequenceEqual(tail, junk)
 
     def test_throws_when_not_enough_data(self) -> None:
         encoded = b"\x85he"