]> Cypherpunks repositories - keks.git/commitdiff
Trivial tiny performance optimisations
authorSergey Matveev <stargrave@stargrave.org>
Thu, 5 Dec 2024 09:56:37 +0000 (12:56 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Thu, 5 Dec 2024 09:56:37 +0000 (12:56 +0300)
pyac/pyac.py

index 25c8da64c8ef50e1fb632b3acdfc159f019751b4967bfe24607b72f184e7aba3..300bc6754fa36cee288f77882509bd31570a1b2bd3c3b8d535dd64490685aa9b 100755 (executable)
@@ -62,6 +62,24 @@ TagStr = 0x80
 TagUTF8 = 0x40
 
 
+def _byte(v):
+    return v.to_bytes(1, "big")
+
+
+TagEOCb = _byte(TagEOC)
+TagNILb = _byte(TagNIL)
+TagFalseb = _byte(TagFalse)
+TagTrueb = _byte(TagTrue)
+TagUUIDb = _byte(TagUUID)
+TagListb = _byte(TagList)
+TagMapb = _byte(TagMap)
+TagBlobb = _byte(TagBlob)
+TagPIntb = _byte(TagPInt)
+TagNIntb = _byte(TagNInt)
+TagTAI64b = _byte(TagTAI64)
+TagTAI64Nb = _byte(TagTAI64N)
+
+
 class DecodeError(ValueError):
     pass
 
@@ -124,10 +142,6 @@ Leapsecs = tuple(
 )
 
 
-def _byte(v):
-    return v.to_bytes(1, "big")
-
-
 def _str(v, utf8):
     l = len(v)
     if l >= (63 + ((1 << 8)-1) + ((1 << 16)-1)):
@@ -163,13 +177,13 @@ def utc2tai(secs):
 
 def dumps(v):
     if v is None:
-        return _byte(TagNIL)
+        return TagNILb
     if v is False:
-        return _byte(TagFalse)
+        return TagFalseb
     if v is True:
-        return _byte(TagTrue)
+        return TagTrueb
     if isinstance(v, UUID):
-        return _byte(TagUUID) + v.bytes
+        return TagUUIDb + v.bytes
     if isinstance(v, float):
         raise NotImplementedError("no FLOAT* support")
     if isinstance(v, datetime):
@@ -178,9 +192,9 @@ def dumps(v):
         secs = int(v.replace(tzinfo=timezone.utc).timestamp())
         secs = utc2tai(secs) + TAI64Base
         if ms == 0:
-            return _byte(TagTAI64) + secs.to_bytes(8, "big")
+            return TagTAI64b + secs.to_bytes(8, "big")
         return (
-            _byte(TagTAI64N) + secs.to_bytes(8, "big") +
+            TagTAI64Nb + secs.to_bytes(8, "big") +
             (ms * 1000).to_bytes(4, "big")
         )
     if isinstance(v, Raw):
@@ -190,33 +204,35 @@ def dumps(v):
     if isinstance(v, str):
         return _str(v.encode("utf-8"), utf8=True)
     if isinstance(v, int):
-        t = TagPInt
+        t = TagPIntb
         if v < 0:
-            t = TagNInt
+            t = TagNIntb
             v = (-v) - 1
         if v == 0:
-            return _byte(t) + dumps(b"")
-        return _byte(t) + dumps(v.to_bytes(_ceil(v.bit_length() / 8), "big"))
+            return t + dumps(b"")
+        return t + dumps(v.to_bytes(_ceil(v.bit_length() / 8), "big"))
     if isinstance(v, Blob):
         assert (v.l > 0) and (v.l <= (1 << 64))
         l, v = v.l, v.v
-        raws = [_byte(TagBlob), (l-1).to_bytes(8, "big")]
+        raws = [TagBlobb, (l-1).to_bytes(8, "big")]
+        append = raws.append
         chunks = len(v) // l
         for i in range(chunks):
-            raws.append(dumps(None))
-            raws.append(v[i*l:(i+1)*l])
+            append(dumps(None))
+            append(v[i*l:(i+1)*l])
         left = len(v) - chunks*l
         assert left < l
-        raws.append(dumps(b"") if (left == 0) else dumps(v[-left:]))
+        append(dumps(b"") if (left == 0) else dumps(v[-left:]))
         return b"".join(raws)
     if isinstance(v, (list, tuple)):
-        return b"".join([_byte(TagList)] + [dumps(i) for i in v] + [_byte(TagEOC)])
+        return b"".join([TagListb] + [dumps(i) for i in v] + [TagEOCb])
     if isinstance(v, set):
         if not all(isinstance(i, str) for i in v):
             raise ValueError("set can contain only strings")
         return dumps({i: None for i in v})
     if isinstance(v, dict):
-        raws = [_byte(TagMap)]
+        raws = [TagMapb]
+        append = raws.append
         keys = v.keys()
         if not all(isinstance(k, str) for k in keys):
             raise ValueError("map keys can be only strings")
@@ -224,9 +240,9 @@ def dumps(v):
         if (len(keys) > 0) and len(keys[0]) == 0:
             raise ValueError("map keys can not be empty")
         for k in keys:
-            raws.append(dumps(k))
-            raws.append(dumps(v[k]))
-        raws.append(_byte(TagEOC))
+            append(dumps(k))
+            append(dumps(v[k]))
+        append(TagEOCb)
         return b"".join(raws)
     raise NotImplementedError("unsupported type")
 
@@ -267,25 +283,26 @@ def tai2utc(secs, leapsecUTCAllow=False):
 def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True):
     if len(v) == 0:
         raise NotEnoughData(1)
-    if v[0] == TagEOC:
+    b = v[0]
+    if b == TagEOC:
         return _EOC, v[1:]
-    if v[0] == TagNIL:
+    if b == TagNIL:
         return None, v[1:]
-    if v[0] == TagFalse:
+    if b == TagFalse:
         return False, v[1:]
-    if v[0] == TagTrue:
+    if b == TagTrue:
         return True, v[1:]
-    if v[0] == TagUUID:
+    if b == TagUUID:
         if len(v) < 1+16:
             raise NotEnoughData(1+16-len(v))
         return UUID(bytes=v[1:1+16]), v[1+16:]
-    if v[0] in _floats:
-        l = _floats[v[0]]
+    l = _floats.get(b)
+    if l is not None:
         if len(v) < 1+l:
             raise NotEnoughData(1+l-len(v))
         return Raw(v[0], v[1:1+l]), v[1+l:]
-    if v[0] in _tais:
-        l = _tais[v[0]]
+    l = _tais.get(b)
+    if l is not None:
         if len(v) < 1+l:
             raise NotEnoughData(1+l - len(v))
         secs = int.from_bytes(v[1:1+8], "big")
@@ -317,8 +334,8 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True):
             return Raw(v[0], v[1:1+l]), v[1+l:]
         dt += timedelta(microseconds=nsecs // 1000)
         return dt, v[1+l:]
-    if (v[0] & TagStr) > 0:
-        l = v[0] & 0b00111111
+    if (b & TagStr) > 0:
+        l = b & 0b00111111
         if l < 61:
             llen = 0
         elif l == 61:
@@ -336,7 +353,7 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True):
         if len(v) < 1+llen+l:
             raise NotEnoughData(1+llen+l-len(v))
         s = v[1+llen:1+llen+l]
-        if (v[0] & TagUTF8) > 0:
+        if (b & TagUTF8) > 0:
             try:
                 s = s.decode("utf-8")
             except UnicodeDecodeError as err:
@@ -344,12 +361,12 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True):
             if "\x00" in s:
                 raise DecodeError("null byte in UTF-8")
         return s, v[1+llen+l:]
-    if v[0] == TagPInt:
+    if b == TagPInt:
         return _int(v[1:])
-    if v[0] == TagNInt:
+    if b == TagNInt:
         i, v = _int(v[1:])
         return (-1 - i), v
-    if (v[0] == TagList) and _allowContainers:
+    if (b == TagList) and _allowContainers:
         ret = []
         v = v[1:]
         while True:
@@ -358,7 +375,7 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True):
                 break
             ret.append(i)
         return ret, v
-    if (v[0] == TagMap) and _allowContainers:
+    if (b == TagMap) and _allowContainers:
         ret = {}
         v = v[1:]
         kPrev = ""
@@ -383,7 +400,7 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True):
         if sets and allNILs:
             ret = set(ret.keys())
         return ret, v
-    if (v[0] == TagBlob) and _allowContainers:
+    if (b == TagBlob) and _allowContainers:
         if len(v) < 1+8:
             raise NotEnoughData(1+8-len(v))
         l = 1 + int.from_bytes(v[1:1+8], "big")