Cypherpunks repositories - keks.git/commitdiff
Honest bytewise map's key ordering
author: Sergey Matveev <stargrave@stargrave.org>
Tue, 14 Jan 2025 08:22:38 +0000 (11:22 +0300)
committer: Sergey Matveev <stargrave@stargrave.org>
Tue, 14 Jan 2025 08:22:38 +0000 (11:22 +0300)
py3/keks.py
py3/tests/test_map.py

index c8dc7eb8ecfaa87cf0f5b1e1d81b218679906791f37b16c9ee19c6ac782f6eba..973d3a6d65ea88aa72c624e790325e00d971988a17aba82d76910bc36eba55fe 100755 (executable)
@@ -117,7 +117,8 @@ class Raw:
 Blob = namedtuple("Blob", ("l", "v"))
 
 
-def LenFirstSort(v):
+def LenFirstUTF8Sort(v):
+    v = v.encode("utf-8")
     return (len(v), v)
 
 
@@ -252,7 +253,7 @@ def dumps(v):
         keys = v.keys()
         if not all(isinstance(k, str) for k in keys):
             raise ValueError("map keys can be only strings")
-        keys = sorted(keys, key=LenFirstSort)
+        keys = sorted(keys, key=LenFirstUTF8Sort)
         if (len(keys) > 0) and len(keys[0]) == 0:
             raise ValueError("map keys can not be empty")
         for k in keys:
@@ -396,7 +397,7 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True):
     if b == TagMap:
         ret = {}
         v = v[1:]
-        kPrev = ""
+        kPrev = b""
         allNILs = True
         while True:
             k, v = _loads(v, _allowContainers=False)
@@ -404,7 +405,11 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True):
                 break
             if not isinstance(k, str):
                 raise DecodeError("non-string key")
-            if (len(k) < len(kPrev)) or ((len(k) == len(kPrev)) and (k <= kPrev)):
+            kUTF8 = k.encode("utf-8")
+            if (
+                (len(kUTF8) < len(kPrev)) or
+                ((len(kUTF8) == len(kPrev)) and (kUTF8 <= kPrev))
+            ):
                 if len(k) == 0:
                     raise DecodeError("empty key")
                 raise DecodeError("unsorted keys")
@@ -412,7 +417,7 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True):
             if i == _EOC:
                 raise DecodeError("unexpected EOC")
             ret[k] = i
-            kPrev = k
+            kPrev = kUTF8
             if i is not None:
                 allNILs = False
         if sets and allNILs:
index b43768d1d436b0b5bd2a2f4f5c3435d9341205240c67f99bb8138ca28b3ffdd3..6324b1a77ba9f600f31f0bf7ee35087a00c6113b5d403a5107c89c3e4a831591 100644 (file)
@@ -40,7 +40,10 @@ class TestMap(TestCase):
             b"".join(
                 [
                     b"".join([dumps(key), dumps(test_map[key])])
-                    for key in sorted(test_map.keys(), key=lambda x: [len(x), x])
+                    for key in sorted(test_map.keys(), key=lambda x: (
+                        len(x.encode("utf-8")),
+                        x.encode("utf-8"),
+                    ))
                 ]
             ) +
             b"\x00"
@@ -151,3 +154,12 @@ class TestSet(TestCase):
         encoded = dumps(test)
         decoded, tail = loads(encoded, sets=True)
         self.assertEqual(decoded, test)
+
+    def test_equal_len_codepoints(self):
+        s1 = bytes.fromhex("f1ad9bb3c2997c6dc391c2a0c2845a").decode("utf-8")
+        s2 = bytes.fromhex("f1aaab9ec3adc2bcc3b4c3bec38a0cc3ac").decode("utf-8")
+        encoded = dumps(set((s1, s2)))
+        self.assertSequenceEqual(
+            encoded.hex(),
+            (b"".join((b"\x09", dumps(s1), b"\x01", dumps(s2), b"\x01", b"\x00"))).hex(),
+        )