From eda68a8c24422877b3b9a1e54a021c44256af6fcaee0de3950ed1851445d8a87 Mon Sep 17 00:00:00 2001 From: Sergey Matveev Date: Tue, 14 Jan 2025 11:22:38 +0300 Subject: [PATCH] Honest bytewise map's key ordering --- py3/keks.py | 15 ++++++++++----- py3/tests/test_map.py | 14 +++++++++++++- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/py3/keks.py b/py3/keks.py index c8dc7eb..973d3a6 100755 --- a/py3/keks.py +++ b/py3/keks.py @@ -117,7 +117,8 @@ class Raw: Blob = namedtuple("Blob", ("l", "v")) -def LenFirstSort(v): +def LenFirstUTF8Sort(v): + v = v.encode("utf-8") return (len(v), v) @@ -252,7 +253,7 @@ def dumps(v): keys = v.keys() if not all(isinstance(k, str) for k in keys): raise ValueError("map keys can be only strings") - keys = sorted(keys, key=LenFirstSort) + keys = sorted(keys, key=LenFirstUTF8Sort) if (len(keys) > 0) and len(keys[0]) == 0: raise ValueError("map keys can not be empty") for k in keys: @@ -396,7 +397,7 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True): if b == TagMap: ret = {} v = v[1:] - kPrev = "" + kPrev = b"" allNILs = True while True: k, v = _loads(v, _allowContainers=False) @@ -404,7 +405,11 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True): break if not isinstance(k, str): raise DecodeError("non-string key") - if (len(k) < len(kPrev)) or ((len(k) == len(kPrev)) and (k <= kPrev)): + kUTF8 = k.encode("utf-8") + if ( + (len(kUTF8) < len(kPrev)) or + ((len(kUTF8) == len(kPrev)) and (kUTF8 <= kPrev)) + ): if len(k) == 0: raise DecodeError("empty key") raise DecodeError("unsorted keys") @@ -412,7 +417,7 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True): if i == _EOC: raise DecodeError("unexpected EOC") ret[k] = i - kPrev = k + kPrev = kUTF8 if i is not None: allNILs = False if sets and allNILs: diff --git a/py3/tests/test_map.py b/py3/tests/test_map.py index b43768d..6324b1a 100644 --- a/py3/tests/test_map.py +++ b/py3/tests/test_map.py @@ -40,7 +40,10 @@ class TestMap(TestCase): b"".join( [ b"".join([dumps(key), dumps(test_map[key])]) - for key in sorted(test_map.keys(), key=lambda x: [len(x), x]) + for key in sorted(test_map.keys(), key=lambda x: ( + len(x.encode("utf-8")), + x.encode("utf-8"), + )) ] ) + b"\x00" @@ -151,3 +154,12 @@ class TestSet(TestCase): encoded = dumps(test) decoded, tail = loads(encoded, sets=True) self.assertEqual(decoded, test) + + def test_equal_len_codepoints(self): + s1 = bytes.fromhex("f1ad9bb3c2997c6dc391c2a0c2845a").decode("utf-8") + s2 = bytes.fromhex("f1aaab9ec3adc2bcc3b4c3bec38a0cc3ac").decode("utf-8") + encoded = dumps(set((s1, s2))) + self.assertSequenceEqual( + encoded.hex(), + (b"".join((b"\x09", dumps(s1), b"\x01", dumps(s2), b"\x01", b"\x00"))).hex(), + ) -- 2.48.1