]> Cypherpunks repositories - keks.git/commitdiff
Simplify pyac
authorSergey Matveev <stargrave@stargrave.org>
Wed, 13 Nov 2024 12:54:36 +0000 (15:54 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Wed, 13 Nov 2024 13:30:12 +0000 (16:30 +0300)
Current classes are just useless. Initially it was probably expected to
use pyac as a generic tool for examining YAC data, like PyDERASN, that
keeps various offsets and lengths. But I doubt it would be used that way.

pyac/pyac.py
pyac/test-vector.py

index 0e909faa2775d1422ead679e5816deb53286692e257cbb51f8746f0ac02792ba..85a30a59028ee1572ec66165ddbc1b3d1dfe9a65aa35764a66d04114871bd4c5 100644 (file)
 #
 # You should have received a copy of the GNU Lesser General Public
 # License along with this program.  If not, see <http://www.gnu.org/licenses/>.
+"""Python YAC encoder/decoder implementation
+
+`YAC <http://www.yac.cypherpunks.su>`__ is yet another binary
+serialisation encoding format. It is aimed to be lightweight in
+terms of CPU, memory, storage and codec implementation size usage.
+YAC is deterministic and streamable. It supports wide range of data
+types, making it able to transparently replace JSON.
+
+It has :py:func:`loads` and :py:func:`dumps` functions, similar to
+native :py:module:`json` library's. YAC supports dictionaries, lists,
+None, booleans, UUID, floats (currently not implemented!), integers
+(including big ones), datetime, Unicode and binary strings.
+
+There is special :py:func:`pyac.Raw` namedtuple, that holds arbitrary
+YAC encoded data, that can not be represented in native Python types.
+Also there is :py:func:`pyac.Blob` namedtuple, that holds the data, that
+is encoded streamingly in chunks.
+"""
+
+from collections import namedtuple
+from datetime import datetime
+from datetime import timedelta
+from datetime import timezone
+from math import ceil as _ceil
+from uuid import UUID
+
+
+TagEOC = 0x00
+TagNIL = 0x01
+TagFalse = 0x02
+TagTrue = 0x03
+TagUUID = 0x04
+TagList = 0x08
+TagMap = 0x09
+TagBlob = 0x0B
+TagPInt = 0x0C
+TagNInt = 0x0D
+TagFloat16 = 0x10
+TagFloat32 = 0x11
+TagFloat64 = 0x12
+TagFloat128 = 0x13
+TagFloat256 = 0x14
+TagTAI64 = 0x18
+TagTAI64N = 0x19
+TagTAI64NA = 0x1A
+TagStr = 0x80
+TagUTF8 = 0x40
 
 
 class DecodeError(ValueError):
     pass
 
 
-WrongTag = DecodeError("wrong tag")
-
-
 class NotEnoughData(DecodeError):
     def __init__(self, n):
         self.n = n
@@ -33,339 +77,12 @@ class NotEnoughData(DecodeError):
         return "%s(%s)" % (self.__class__.__name__, self)
 
 
-class EOC:
-    tags = (0x00,)
-
-    def encode(self):
-        return self.tags[0].to_bytes(1, "big")
-
-    @classmethod
-    def decode(klass, data):
-        if data[0] != klass.tags[0]:
-            raise WrongTag
-        return EOC(), data[1:]
-
-    def __repr__(self):
-        return "EOC"
-
-
-class Nil:
-    tags = (0x01,)
-
-    def encode(self):
-        return self.tags[0].to_bytes(1, "big")
-
-    @classmethod
-    def decode(klass, data):
-        if data[0] != klass.tags[0]:
-            raise WrongTag
-        return Nil(), data[1:]
-
-    def __repr__(self):
-        return "NIL"
-
-    def py(self):
-        return None
-
-
-class Bool:
-    tags = (0x02, 0x03)
-
-    def __init__(self, v):
-        if isinstance(v, Bool):
-            v = v.v
-        self.v = v
-
-    def encode(self):
-        if self.v is True:
-            return self.tags[1].to_bytes(1, "big")
-        return self.tags[0].to_bytes(1, "big")
-
-    @classmethod
-    def decode(klass, data):
-        if data[0] == klass.tags[0]:
-            return klass(False), data[1:]
-        if data[0] == klass.tags[1]:
-            return klass(True), data[1:]
-        raise WrongTag
-
-    def __repr__(self):
-        return "TRUE" if self.v is True else "FALSE"
-
-    def py(self):
-        return self.v
-
-
-from uuid import UUID as pyUUID
-
-
-class UUID:
-    tags = (0x04,)
-
-    def __init__(self, v):
-        if isinstance(v, UUID):
-            v = v.v
-        if isinstance(v, pyUUID):
-            self.v = v
-        else:
-            self.v = pyUUID(v)
-
-    def encode(self):
-        return self.tags[0].to_bytes(1, "big") + self.v.bytes
-
-    @classmethod
-    def decode(klass, data):
-        if data[0] != klass.tags[0]:
-            raise WrongTag
-        if len(data) < 1+16:
-            raise NotEnoughData(1+16)
-        return klass(pyUUID(bytes=data[1:1+16])), data[1+16:]
-
-    def __repr__(self):
-        return "UUID[%s]" % str(self.v)
-
-    def __eq__(self, their):
-        if isinstance(their, pyUUID):
-            return self.v == their
-        return self.v == their.v
-
-    def __bytes__(self):
-        return self.v.bytes
-
-    def py(self):
-        return self.v
-
-
-from math import ceil
-
-
-class Int:
-    tagPositive = 0x0C
-    tagNegative = 0x0D
-    tags = (tagPositive, tagNegative)
-
-    def __init__(self, v=0):
-        if isinstance(v, Int):
-            v = v.v
-        self.v = v
-
-    def encode(self):
-        tag = self.tagPositive
-        v = self.v
-        if v < 0:
-            tag = self.tagNegative
-            v = (-v) - 1
-        if v == 0:
-            return tag.to_bytes(1, "big") + Bin(b"").encode()
-        return tag.to_bytes(1, "big") + Bin(
-            v.to_bytes(ceil(v.bit_length() / 8), "big")
-        ).encode()
-
-    @classmethod
-    def decode(klass, data):
-        if data[0] not in klass.tags:
-            raise WrongTag
-        neg = data[0] == klass.tagNegative
-        raw, data = Bin.decode(data[1:])
-        raw = bytes(raw)
-        if raw == b"":
-            return (Int(-1) if neg else Int(0)), data
-        if raw[0] == 0:
-            raise DecodeError("non-miminal encoding")
-        v = int.from_bytes(raw, "big")
-        if neg:
-            v = -1 - v
-        return klass(v), data
-
-    def __repr__(self):
-        return "INT(%d)" % self.v
-
-    def __int__(self):
-        return self.v
-
-    def py(self):
-        return self.v
-
-
-class List:
-    tags = (0x08,)
-
-    def __init__(self, v=()):
-        if isinstance(v, List):
-            v = v.v
-        self.v = v
-
-    def encode(self):
-        raws = [self.tags[0].to_bytes(1, "big")]
-        for v in self.v:
-            raws.append(Encode(v))
-        raws.append(EOC.tags[0].to_bytes(1, "big"))
-        return b"".join(raws)
-
-    @classmethod
-    def decode(klass, data):
-        if data[0] != klass.tags[0]:
-            raise WrongTag
-        data = data[1:]
-        vs = []
-        while True:
-            v, data = Decode(data)
-            if isinstance(v, EOC):
-                break
-            vs.append(v)
-        return klass(vs), data
-
-    def __repr__(self):
-        return "LIST[" + ", ".join(repr(v) for v in self.v) + "]"
-
-    def py(self):
-        return [v.py() for v in self.v]
-
-
-LenFirstSort = lambda x: (len(x), x)
-
-
-class Map:
-    tags = (0x09,)
-
-    def __init__(self, v=()):
-        if isinstance(v, Map):
-            v = v.v
-        self.v = v
-
-    def encode(self):
-        raws = [self.tags[0].to_bytes(1, "big")]
-        for k in sorted(self.v.keys(), key=LenFirstSort):
-            assert isinstance(k, (str, Str))
-            raws.append(Str(k).encode())
-            raws.append(Encode(self.v[k]))
-        raws.append(EOC.tags[0].to_bytes(1, "big"))
-        return b"".join(raws)
-
-    @classmethod
-    def decode(klass, data):
-        if data[0] != klass.tags[0]:
-            raise WrongTag
-        data = data[1:]
-        vs = {}
-        kPrev = ""
-        while True:
-            k, data = Decode(data)
-            if isinstance(k, EOC):
-                break
-            if not isinstance(k, Str):
-                raise DecodeError("non-string key")
-            k = str(k)
-            if len(k) == 0:
-                raise DecodeError("empty key")
-            if (len(k) < len(kPrev)) or ((len(k) == len(kPrev)) and (k <= kPrev)):
-                raise DecodeError("unsorted keys")
-            v, data = Decode(data)
-            if isinstance(v, EOC):
-                raise DecodeError("unexpected EOC")
-            vs[k] = v
-            kPrev = k
-        return klass(vs), data
-
-    def __repr__(self):
-        return "MAP[" + "; ".join("%s: %r" % (k, self.v[k])
-            for k in sorted(self.v.keys(), key=LenFirstSort)
-        ) + "]"
-
-    def py(self):
-        return {str(k): v.py() for k, v in self.v.items()}
+Raw = namedtuple("Raw", ("t", "v"))
+Blob = namedtuple("Blob", ("l", "v"))
 
 
-class Blob:
-    tags = (0x0B,)
-
-    def __init__(self, l, v=b""):
-        if isinstance(v, Blob):
-            v = v.v
-            l = v.l
-        assert (l > 0) and (l <= (1<<64))
-        self.v = v
-        self.l = l
-
-    def encode(self):
-        raws = [self.tags[0].to_bytes(1, "big"), (self.l - 1).to_bytes(8, "big")]
-        chunks = len(self.v) // (self.l)
-        for i in range(chunks):
-            raws.append(Nil().encode())
-            raws.append(self.v[i*(self.l):(i+1)*(self.l)])
-        left = len(self.v) - chunks*(self.l)
-        assert left < (self.l)
-        if left == 0:
-            raws.append(Bin(b"").encode())
-        else:
-            raws.append(Bin(self.v[-left:]).encode())
-        return b"".join(raws)
-
-    @classmethod
-    def decode(klass, data):
-        if data[0] != klass.tags[0]:
-            raise WrongTag
-        data = data[1:]
-        if len(data) < 8:
-            raise NotEnoughData(8)
-        l = 1 + int.from_bytes(data[:8], "big")
-        data = data[8:]
-        vs = []
-        while True:
-            v, data = Decode(data)
-            if isinstance(v, Nil):
-                if len(data) < l:
-                    raise NotEnoughData(l)
-                vs.append(data[:l])
-                data = data[l:]
-            elif isinstance(v, Bin):
-                v = bytes(v)
-                if len(v) >= l:
-                    raise DecodeError("wrong terminator len")
-                vs.append(v)
-                break
-            else:
-                raise DecodeError("unexpected tag")
-        return klass(l, b"".join(vs)), data
-
-    def __repr__(self):
-        return "BLOB(%d, %d)" % (self.l, len(self.v))
-
-    def __bytes__(self):
-        return self.v
-
-    def __eq__(self, their):
-        return (self.l == their.l) and (self.v == their.v)
-
-    def py(self):
-        return self
-
-
-class Float:
-    tags = (0x10, 0x11, 0x12, 0x13, 0x14)
-
-    @classmethod
-    def decode(klass, data):
-        t = data[0]
-        data = data[1:]
-        if t == klass.tags[0]:
-            l = 2
-        elif t == klass.tags[1]:
-            l = 4
-        elif t == klass.tags[2]:
-            l = 8
-        elif t == klass.tags[3]:
-            l = 16
-        elif t == klass.tags[4]:
-            l = 32
-        if len(data) < l:
-            raise NotEnoughData(l)
-        return Raw((t, data[:l])), data[l:]
-
-
-from datetime import datetime
-from datetime import timedelta
-from datetime import timezone
+def LenFirstSort(v):
+    return (len(v), v)
 
 
 TAI64Base = 0x4000000000000000
@@ -404,48 +121,140 @@ Leapsecs = tuple(
 )
 
 
-class TAI64:
-    tags = (0x18, 0x19, 0x1A)
-
-    def __init__(self, v):
-        if isinstance(v, TAI64):
-            v = v.v
-        self.v = v
-
-    def encode(self):
-        v = int(self.v.replace(tzinfo=timezone.utc).timestamp())
+def _byte(v):
+    return v.to_bytes(1, "big")
+
+
+def _str(v, utf8):
+    l = len(v)
+    if l >= (63 + ((1 << 8)-1) + ((1 << 16)-1)):
+        lv = 63
+        l -= (lv + ((1 << 8)-1) + ((1 << 16)-1))
+        lb = l.to_bytes(8, "big")
+    elif l >= (62 + ((1 << 8)-1)):
+        lv = 62
+        l -= (lv + ((1 << 8)-1))
+        lb = l.to_bytes(2, "big")
+    elif l >= 61:
+        lv = 61
+        l -= lv
+        lb = l.to_bytes(1, "big")
+    else:
+        lv = l
+        lb = b""
+    t = TagStr
+    if utf8:
+        t |= TagUTF8
+    return _byte(t | lv) + lb + v
+
+
+def dumps(v):
+    if v is None:
+        return _byte(TagNIL)
+    if v is False:
+        return _byte(TagFalse)
+    if v is True:
+        return _byte(TagTrue)
+    if isinstance(v, UUID):
+        return _byte(TagUUID) + v.bytes
+    if isinstance(v, float):
+        raise NotImplementedError("no FLOAT* support")
+    if isinstance(v, datetime):
+        secs = int(v.replace(tzinfo=timezone.utc).timestamp())
         diff = Leapsecs1972
         for n, leapsec in enumerate(Leapsecs):
-            if v > leapsec:
+            if secs > leapsec:
                 diff += len(Leapsecs) - n
                 break
-        v += TAI64Base + diff
-        if self.v.microsecond == 0:
-            return self.tags[0].to_bytes(1, "big") + v.to_bytes(8, "big")
+        secs += TAI64Base + diff
+        if v.microsecond == 0:
+            return _byte(TagTAI64) + secs.to_bytes(8, "big")
         return (
-            self.tags[1].to_bytes(1, "big") +
-            v.to_bytes(8, "big") +
-            (self.v.microsecond * 1000).to_bytes(4, "big")
+            _byte(TagTAI64N) + secs.to_bytes(8, "big") +
+            (v.microsecond * 1000).to_bytes(4, "big")
         )
-
-    @classmethod
-    def decode(klass, data):
-        if data[0] not in (klass.tags[0], klass.tags[1], klass.tags[2]):
-            raise WrongTag
-        hdr = data[0]
-        data = data[1:]
-        if hdr == klass.tags[0]:
-            l = 8
-        elif hdr == klass.tags[1]:
-            l = 12
-        else:
-            l = 16
-        if len(data) < l:
-            raise NotEnoughData(l)
-
-        secs = int.from_bytes(data[:8], "big")
-        if secs > (1<<63):
-            raise DecodeError("reserved TAI64 values in use")
+    if isinstance(v, Raw):
+        return _byte(v.t) + v.v
+    if isinstance(v, bytes):
+        return _str(v, utf8=False)
+    if isinstance(v, str):
+        return _str(v.encode("utf-8"), utf8=True)
+    if isinstance(v, int):
+        t = TagPInt
+        if v < 0:
+            t = TagNInt
+            v = (-v) - 1
+        if v == 0:
+            return _byte(t) + dumps(b"")
+        return _byte(t) + dumps(v.to_bytes(_ceil(v.bit_length() / 8), "big"))
+    if isinstance(v, Blob):
+        assert (v.l > 0) and (v.l <= (1 << 64))
+        l, v = v.l, v.v
+        raws = [_byte(TagBlob), (l-1).to_bytes(8, "big")]
+        chunks = len(v) // l
+        for i in range(chunks):
+            raws.append(dumps(None))
+            raws.append(v[i*l:(i+1)*l])
+        left = len(v) - chunks*l
+        assert left < l
+        raws.append(dumps(b"") if (left == 0) else dumps(v[-left:]))
+        return b"".join(raws)
+    if isinstance(v, (list, tuple)):
+        return b"".join([_byte(TagList)] + [dumps(i) for i in v] + [_byte(TagEOC)])
+    if isinstance(v, dict):
+        raws = [_byte(TagMap)]
+        for k in sorted(v.keys(), key=LenFirstSort):
+            assert isinstance(k, str)
+            raws.append(dumps(k))
+            raws.append(dumps(v[k]))
+        raws.append(_byte(TagEOC))
+        return b"".join(raws)
+    raise NotImplementedError("unsupported type")
+
+
+def _int(v):
+    s, tail = loads(v)
+    if not isinstance(s, bytes):
+        raise DecodeError("non-BIN in INT")
+    if s == b"":
+        return 0, tail
+    if s[0] == 0:
+        raise DecodeError("non-minimal encoding")
+    return int.from_bytes(s, "big"), tail
+
+
+_EOC = object()
+
+
+def loads(v):
+    if len(v) == 0:
+        raise NotEnoughData(1)
+    if v[0] == TagEOC:
+        return _EOC, v[1:]
+    if v[0] == TagNIL:
+        return None, v[1:]
+    if v[0] == TagFalse:
+        return False, v[1:]
+    if v[0] == TagTrue:
+        return True, v[1:]
+    if v[0] == TagUUID:
+        if len(v) < 1+16:
+            raise NotEnoughData(1+16)
+        return UUID(bytes=v[1:1+16]), v[1+16:]
+    floats = {TagFloat16: 2, TagFloat32: 4, TagFloat64: 8, TagFloat128: 16, TagFloat256: 32}
+    if v[0] in floats:
+        l = floats[v[0]]
+        if len(v) < 1+l:
+            raise NotEnoughData(1+l)
+        return Raw(v[0], v[1:1+l]), v[1+l:]
+    tais = {TagTAI64: 8, TagTAI64N: 12, TagTAI64NA: 16}
+    if v[0] in tais:
+        l = tais[v[0]]
+        if len(v) < 1+l:
+            raise NotEnoughData(1+l)
+        secs = int.from_bytes(v[1:1+8], "big")
+        if secs > (1 << 63):
+            raise DecodeError("reserved TAI64 value is in use")
         secs -= TAI64Base
         diff = 0
         for n, leapsec in enumerate(Leapsecs):
@@ -453,220 +262,100 @@ class TAI64:
                 diff = 10 + len(Leapsecs) - n
                 break
         secs -= diff
-
         nsecs = 0
         if l > 8:
-            nsecs = int.from_bytes(data[8:8+4], "big")
+            nsecs = int.from_bytes(v[1+8:1+8+4], "big")
             if nsecs > 999999999:
                 raise DecodeError("too many nanoseconds")
         asecs = 0
         if l > 12:
-            asecs = int.from_bytes(data[8+4:8+4+4], "big")
+            asecs = int.from_bytes(v[1+8+4:1+8+4+4], "big")
             if asecs > 999999999:
                 raise DecodeError("too many attoseconds")
-
-        if (abs(secs) > (1<<60)) or (asecs > 0) or ((nsecs % 1000) > 0):
+        if (abs(secs) > (1 << 60)) or (asecs > 0) or ((nsecs % 1000) > 0):
             # Python can represent neither big values, nor nanoseconds
-            return Raw((hdr, data[:l])), data[l:]
-
+            return Raw(v[0], v[1:1+l]), v[1+l:]
         dt = datetime(1970, 1, 1) + timedelta(seconds=secs)
-        if nsecs > 0:
-            dt += timedelta(microseconds=nsecs // 1000)
-        return klass(dt), data[l:]
-
-    def __repr__(self):
-        if self.v.microsecond > 0:
-            return "TAI64N(%s)" % str(self.v)
-        return "TAI64(%s)" % str(self.v)
-
-    def py(self):
-        return self.v
-
-
-class BaseString:
-    def __init__(self, v, utf8):
-        if isinstance(v, BaseString):
-            v = v.v
-        self.v = v
-        self.utf8 = utf8
-
-    def __lt__(self, their):
-        return self.v < their.v
-
-    def __gt__(self, their):
-        return self.v > their.v
-
-    def __eq__(self, their):
-        return self.v == their.v
-
-    def __hash__(self):
-        return hash(self.v)
-
-    def encode(self):
-        l = len(self.v)
-        if l >= (63 + ((1<<8)-1) + ((1<<16)-1)):
-            lv = 63
-            l -= (lv + ((1<<8)-1) + ((1<<16)-1))
-            lb = l.to_bytes(8, "big")
-        elif l >= (62 + ((1<<8)-1)):
-            lv = 62
-            l -= (lv + ((1<<8)-1))
-            lb = l.to_bytes(2, "big")
-        elif l >= 61:
-            lv = 61
-            l -= lv
-            lb = l.to_bytes(1, "big")
-        else:
-            lv = l
-            lb = b""
-        v = 0x80
-        if self.utf8:
-            v |= 0x40
-        return b"".join(((v | lv).to_bytes(1, "big"), lb, self.v))
-
-    @classmethod
-    def decode(klass, data):
-        if (data[0] & 0x80) == 0:
-            raise ValueError("wrong tag")
-        utf8 = (data[0] & 0x40) > 0
-        l = data[0] & 0b00111111
-        orig = data
+        dt += timedelta(microseconds=nsecs // 1000)
+        return dt, v[1+l:]
+    if (v[0] & TagStr) > 0:
+        l = v[0] & 0b00111111
         if l < 61:
             llen = 0
         elif l == 61:
             llen = 1
         elif l == 62:
             llen = 2
-            l += ((1<<8)-1)
+            l += ((1 << 8)-1)
         elif l == 63:
             llen = 8
-            l += ((1<<8)-1) + ((1<<16)-1)
-        data = data[1:]
+            l += ((1 << 8)-1) + ((1 << 16)-1)
         if llen > 0:
-            if len(data) < llen:
-                raise NotEnoughData(llen)
-            l += int.from_bytes(data[:llen], "big")
-            data = data[llen:]
-        if len(data) < l:
-            raise NotEnoughData(l)
-        return klass(data[:l], utf8=utf8), data[l:]
-
-
-class Str(BaseString):
-    def __init__(self, v=""):
-        if isinstance(v, Str):
-            super().__init__(v.v, utf8=True)
-        else:
-            super().__init__(v.encode("utf-8"), utf8=True)
-
-    @classmethod
-    def decode(klass, data):
-        obj, tail = BaseString.decode(data)
-        assert obj.utf8 is True
-        try:
-            v = obj.v.decode("utf-8")
-        except UnicodeDecodeError as err:
-            raise DecodeError("invalid UTF-8") from err
-        if "\x00" in v:
-            raise DecodeError("null byte in UTF-8")
-        return klass(v), tail
-
-    def __repr__(self):
-        return "STR(" + self.v.decode("utf-8") + ")"
-
-    def __str__(self):
-        return self.v.decode("utf-8")
-
-    def py(self):
-        return self.v.decode("utf-8")
-
-
-class Bin(BaseString):
-    def __init__(self, v=b""):
-        if isinstance(v, Bin):
-            v = v.v
-        super().__init__(v, utf8=False)
-
-    @classmethod
-    def decode(klass, data):
-        obj, tail = BaseString.decode(data)
-        assert obj.utf8 is False
-        return klass(obj.v), tail
-
-    def __repr__(self):
-        return "BIN(" + self.v.hex() + ")"
-
-    def __bytes__(self):
-        return self.v
-
-    def py(self):
-        return self.v
-
-
-class Raw:
-    def __init__(self, v=(0, b"")):
-        if isinstance(v, Raw):
-            v = v.v
-        self.v = v
-
-    def encode(self):
-        t, v = self.v
-        return t.to_bytes(1, "big") + v
-
-    def __eq__(self, their):
-        return self.v == their.v
-
-    def __repr__(self):
-        return "RAW(%d, %s)" % self.v
-
-    def __bytes__(self):
-        return self.v[1]
-
-    def py(self):
-        return self
-
-
-_tags = {}
-for klass in (EOC, Nil, Bool, UUID, Int, List, Map, Blob, Float, TAI64):
-    for tag in klass.tags:
-        _tags[tag] = klass
-
-
-def Decode(data):
-    hdr = data[0]
-    if hdr >= 0x80: # strings
-        klass = Str if ((hdr & 0x40) == 0x40) else Bin
-    else:
-        klass = _tags[hdr]
-    if klass is None:
-        raise ValueError("unknown tag")
-    v, data = klass.decode(data)
-    return v, data
-
-
-def Encode(v):
-    if (v is None) or isinstance(v, Nil):
-        return Nil().encode()
-    if isinstance(v, (bool, Bool)):
-        return Bool(v).encode()
-    if isinstance(v, (pyUUID, UUID)):
-        return UUID(v).encode()
-    if isinstance(v, (int, Int)):
-        return Int(v).encode()
-    if isinstance(v, (list, tuple, List)):
-        return List(v).encode()
-    if isinstance(v, (dict, Map)):
-        return Map(v).encode()
-    if isinstance(v, Blob):
-        return v.encode()
-    if isinstance(v, float):
-        return Float(v).encode()
-    if isinstance(v, (datetime, TAI64)):
-        return TAI64(v).encode()
-    if isinstance(v, (bytes, Bin)):
-        return Bin(v).encode()
-    if isinstance(v, (str, Str)):
-        return Str(v).encode()
-    if isinstance(v, Raw):
-        return v.encode()
-    raise ValueError("unknown type", type(v))
+            if len(v) < 1+llen:
+                raise NotEnoughData(1+llen)
+            l += int.from_bytes(v[1:1+llen], "big")
+        if len(v) < 1+llen+l:
+            raise NotEnoughData(1+llen+l)
+        s = v[1+llen:1+llen+l]
+        if (v[0] & TagUTF8) > 0:
+            try:
+                s = s.decode("utf-8")
+            except UnicodeDecodeError as err:
+                raise DecodeError("invalid UTF-8") from err
+            if "\x00" in s:
+                raise DecodeError("null byte in UTF-8")
+        return s, v[1+llen+l:]
+    if v[0] == TagPInt:
+        return _int(v[1:])
+    if v[0] == TagNInt:
+        i, v = _int(v[1:])
+        return (-1 - i), v
+    if v[0] == TagList:
+        ret = []
+        v = v[1:]
+        while True:
+            i, v = loads(v)
+            if i == _EOC:
+                break
+            ret.append(i)
+        return ret, v
+    if v[0] == TagMap:
+        ret = {}
+        v = v[1:]
+        kPrev = ""
+        while True:
+            k, v = loads(v)
+            if k == _EOC:
+                break
+            if not isinstance(k, str):
+                raise DecodeError("non-string key")
+            if (len(k) < len(kPrev)) or ((len(k) == len(kPrev)) and (k <= kPrev)):
+                raise DecodeError("unsorted keys")
+            i, v = loads(v)
+            if k == _EOC:
+                raise DecodeError("unexpected EOC")
+            ret[k] = i
+            kPrev = k
+        return ret, v
+    if v[0] == TagBlob:
+        if len(v) < 1+8:
+            raise NotEnoughData(1+8)
+        l = 1 + int.from_bytes(v[1:1+8], "big")
+        v = v[1+8:]
+        raws = []
+        while True:
+            i, v = loads(v)
+            if i is None:
+                if len(v) < l:
+                    raise NotEnoughData(l)
+                raws.append(v[:l])
+                v = v[l:]
+            elif isinstance(i, bytes):
+                if len(i) >= l:
+                    raise DecodeError("wrong terminator len")
+                raws.append(i)
+                break
+            else:
+                raise DecodeError("unexpected tag")
+        return Blob(l, b"".join(raws)), v
+    raise DecodeError("unknown tag")
index 7fd79b0926d7535c8145eba21d27bebc2f15eff4b2796a85fefe9f9e85284146..97eeaee610a7165e3020aae704f22931aba1fa31508050dc455e995716ad1b25 100644 (file)
@@ -1,19 +1,22 @@
-from pyac import *
+from datetime import datetime
+from datetime import timedelta
+from uuid import UUID
+import pyac
 
 
 data = {
     "ints": {
         "pos": [
-            0, 1, 31, 32, 123, 1234, 12345678, 1<<80,
+            0, 1, 31, 32, 123, 1234, 12345678, 1 << 80,
             0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF + 1,
         ],
         "neg": [
-            -1, -2, -32, -33, -123, -1234, -12345678, -1<<80,
+            -1, -2, -32, -33, -123, -1234, -12345678, -1 << 80,
             -(0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF + 2),
             -123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789,
         ],
     },
-    "floats": [Raw((Float.tags[1], b"\x01\x02\x03\x04"))],
+    "floats": [pyac.Raw(pyac.TagFloat32, b"\x01\x02\x03\x04")],
     "nil": None,
     "bool": [True, False],
     "str": {
@@ -30,30 +33,29 @@ data = {
         "utf8": "привет мир",
     },
     "blob": [
-        Blob(12, 1 * b"5"),
-        Blob(12, 12 * b"6"),
-        Blob(12, 13 * b"7"),
-        Blob(5, b"1234567890-"),
+        pyac.Blob(12, 1 * b"5"),
+        pyac.Blob(12, 12 * b"6"),
+        pyac.Blob(12, 13 * b"7"),
+        pyac.Blob(5, b"1234567890-"),
     ],
     "empties": [
         [],
         {},
-        Blob(123, b""),
+        pyac.Blob(123, b""),
         UUID("00000000-0000-0000-0000-000000000000"),
-        Raw((TAI64.tags[0], bytes.fromhex("0000000000000000"))),
+        pyac.Raw(pyac.TagTAI64, bytes.fromhex("0000000000000000")),
     ],
     "uuid": UUID("0e875e3f-d385-49eb-87b4-be42d641c367"),
 }
-from datetime import datetime
 data["dates"] = [
     (datetime(1970, 1, 1) + timedelta(seconds=1234567890)),
     (datetime(1970, 1, 1) + timedelta(seconds=1234567890)).replace(microsecond=456),
-    Raw((TAI64.tags[1], bytes.fromhex("40000000499602F40006F855"))),
-    Raw((TAI64.tags[2], bytes.fromhex("40000000499602F40006F855075BCD15"))),
+    pyac.Raw(pyac.TagTAI64N, bytes.fromhex("40000000499602F40006F855")),
+    pyac.Raw(pyac.TagTAI64NA, bytes.fromhex("40000000499602F40006F855075BCD15")),
 ]
-raw = Encode(data)
-dec, tail = Decode(raw)
+raw = pyac.dumps(data)
+dec, tail = pyac.loads(raw)
 assert tail == b""
-assert Encode(dec) == raw
-assert dec.py() == data
+assert pyac.dumps(dec) == raw
+assert dec == data
 print(raw.hex())