]> Cypherpunks repositories - keks.git/commitdiff
Remove excess expectEOC state
authorSergey Matveev <stargrave@stargrave.org>
Sun, 15 Dec 2024 08:58:26 +0000 (11:58 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Sun, 15 Dec 2024 09:25:45 +0000 (12:25 +0300)
cyac/lib/items.c
gyac/dec.go
pyac/pyac.py

index 1f4ee1144f9fc454beb2f537b489fd3cc5d9aef74e291bfedfcecf8f52e9b7b7..278c681c1187676c053e6e4da4e31afad6e87ee08eba0ada427858ab48333106 100644 (file)
@@ -121,7 +121,6 @@ yacItemsParse( // NOLINT(misc-no-recursion)
     const unsigned char *buf,
     const size_t len,
     const bool allowContainers,
-    const bool expectEOC,
     const size_t recursionDepth)
 {
     if (recursionDepth >= parseMaxRecursionDepth) {
@@ -136,11 +135,6 @@ yacItemsParse( // NOLINT(misc-no-recursion)
 #pragma clang diagnostic ignored "-Wswitch-enum"
     switch (items->list[item].atom.typ) {
 #pragma clang diagnostic pop
-    case YACItemEOC:
-        if (!expectEOC) {
-            return YACErrUnexpectedEOC;
-        }
-        break;
     case YACItemList: {
         if (!allowContainers) {
             return YACErrUnknownType;
@@ -151,7 +145,7 @@ yacItemsParse( // NOLINT(misc-no-recursion)
         size_t cur = 0;
         size_t idx = item;
         for (;;) {
-            err = yacItemsParse(items, off, buf, len, true, true, recursionDepth + 1);
+            err = yacItemsParse(items, off, buf, len, true, recursionDepth + 1);
             if (err != YACErrNo) {
                 return err;
             }
@@ -182,7 +176,7 @@ yacItemsParse( // NOLINT(misc-no-recursion)
         size_t prevKeyLen = 0;
         const unsigned char *prevKey = NULL;
         for (;;) {
-            err = yacItemsParse(items, off, buf, len, false, true, recursionDepth + 1);
+            err = yacItemsParse(items, off, buf, len, false, recursionDepth + 1);
             if (err != YACErrNo) {
                 return err;
             }
@@ -218,11 +212,14 @@ yacItemsParse( // NOLINT(misc-no-recursion)
             }
             prev = cur;
             idx = (items->len) - 1;
-            err = yacItemsParse(items, off, buf, len, true, false, recursionDepth + 1);
+            err = yacItemsParse(items, off, buf, len, true, recursionDepth + 1);
             if (err != YACErrNo) {
                 return err;
             }
             cur = idx + 1;
+            if (items->list[cur].atom.typ == YACItemEOC) {
+                return YACErrUnexpectedEOC;
+            }
             items->list[prev].next = cur;
             prev = cur;
             idx = (items->len) - 1;
@@ -240,7 +237,7 @@ yacItemsParse( // NOLINT(misc-no-recursion)
         size_t cur = 0;
         bool eoc = false;
         while (!eoc) {
-            err = yacItemsParse(items, off, buf, len, false, true, recursionDepth + 1);
+            err = yacItemsParse(items, off, buf, len, false, recursionDepth + 1);
             if (err != YACErrNo) {
                 return err;
             }
@@ -294,7 +291,11 @@ YACItemsParse( // NOLINT(misc-no-recursion)
     const unsigned char *buf,
     const size_t len)
 {
-    return yacItemsParse(items, off, buf, len, true, false, 0);
+    enum YACErr err = yacItemsParse(items, off, buf, len, true, 0);
+    if (items->list[0].atom.typ == YACItemEOC) {
+        err = YACErrUnexpectedEOC;
+    }
+    return err;
 }
 
 bool
index 6bf959123ba6d36c465779aa8168b2e171dcf5b31a52bfc0667934b6005aafc0..d8f12cb9e78dd1c7e30e505482bd5a0218e13ea4627f134f867d7e72949822eb 100644 (file)
@@ -56,7 +56,7 @@ type Item struct {
 
 func decode(
        buf []byte,
-       allowContainers, expectEOC bool,
+       allowContainers bool,
        recursionDepth int,
 ) (item Item, tail []byte, err error) {
        if recursionDepth > parseMaxRecursionDepth {
@@ -71,11 +71,6 @@ func decode(
        buf = buf[off:]
        tail = buf
        switch item.T {
-       case types.EOC:
-               if !expectEOC {
-                       err = ErrUnexpectedEOC
-                       return
-               }
        case types.List:
                if !allowContainers {
                        err = atom.ErrUnknownType
@@ -84,7 +79,7 @@ func decode(
                var sub Item
                var v []Item
                for {
-                       sub, buf, err = decode(buf, true, true, recursionDepth+1)
+                       sub, buf, err = decode(buf, true, recursionDepth+1)
                        tail = buf
                        if err != nil {
                                tail = buf
@@ -106,7 +101,7 @@ func decode(
                var sub Item
                var keyPrev string
                for {
-                       sub, buf, err = decode(buf, false, true, recursionDepth+1)
+                       sub, buf, err = decode(buf, false, recursionDepth+1)
                        tail = buf
                        if err != nil {
                                return
@@ -133,11 +128,15 @@ func decode(
                                }
                                keyPrev = s
                        }
-                       sub, buf, err = decode(buf, true, false, recursionDepth+1)
+                       sub, buf, err = decode(buf, true, recursionDepth+1)
                        tail = buf
                        if err != nil {
                                return
                        }
+                       if sub.T == types.EOC {
+                               err = ErrUnexpectedEOC
+                               return
+                       }
                        v[keyPrev] = sub
                }
                item.V = v
@@ -152,7 +151,7 @@ func decode(
                var sub Item
        BlobCycle:
                for {
-                       sub, buf, err = decode(buf, false, true, recursionDepth+1)
+                       sub, buf, err = decode(buf, false, recursionDepth+1)
                        tail = buf
                        if err != nil {
                                return
@@ -189,5 +188,9 @@ func decode(
 
 // Decode single YAC-encoded data item. Remaining data will be kept in tail.
 func Decode(buf []byte) (item Item, tail []byte, err error) {
-       return decode(buf, true, false, 0)
+       item, tail, err = decode(buf, true, 0)
+       if item.T == types.EOC {
+               err = ErrUnexpectedEOC
+       }
+       return item, tail, err
 }
index f9a64aa8d16d829e9400fd4fee0b1f35280a9a6347cdf27aa0fe0ddf60a0886b..39c49bb344fffa7683d87ef566f1eef08c9d45af11ea300a2ccf2b1be6e15ee8 100755 (executable)
@@ -280,13 +280,11 @@ def tai2utc(secs, leapsecUTCAllow=False):
     return secs - diff
 
 
-def _loads(v, sets=False, leapsecUTCAllow=False, _expectEOC=False, _allowContainers=True):
+def _loads(v, sets=False, leapsecUTCAllow=False, _allowContainers=True):
     if len(v) == 0:
         raise NotEnoughData(1)
     b = v[0]
     if b == TagEOC:
-        if not _expectEOC:
-            raise DecodeError("unexpected EOC")
         return _EOC, v[1:]
     if b == TagNIL:
         return None, v[1:]
@@ -374,7 +372,7 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _expectEOC=False, _allowContain
         ret = []
         v = v[1:]
         while True:
-            i, v = loads(v, sets=sets, leapsecUTCAllow=leapsecUTCAllow, _expectEOC=True)
+            i, v = _loads(v, sets=sets, leapsecUTCAllow=leapsecUTCAllow)
             if i == _EOC:
                 break
             ret.append(i)
@@ -385,7 +383,7 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _expectEOC=False, _allowContain
         kPrev = ""
         allNILs = True
         while True:
-            k, v = _loads(v, _expectEOC=True, _allowContainers=False)
+            k, v = _loads(v, _allowContainers=False)
             if k == _EOC:
                 break
             if not isinstance(k, str):
@@ -394,7 +392,9 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _expectEOC=False, _allowContain
                 if len(k) == 0:
                     raise DecodeError("empty key")
                 raise DecodeError("unsorted keys")
-            i, v = loads(v, sets=sets, leapsecUTCAllow=leapsecUTCAllow, _expectEOC=False)
+            i, v = _loads(v, sets=sets, leapsecUTCAllow=leapsecUTCAllow)
+            if i == _EOC:
+                raise DecodeError("unexpected EOC")
             ret[k] = i
             kPrev = k
             if i is not None:
@@ -409,7 +409,7 @@ def _loads(v, sets=False, leapsecUTCAllow=False, _expectEOC=False, _allowContain
         v = v[1+8:]
         raws = []
         while True:
-            i, v = _loads(v, _expectEOC=True, _allowContainers=False)
+            i, v = _loads(v, _allowContainers=False)
             if i is None:
                 if len(v) < l:
                     raise NotEnoughData(l-len(v)+1)
@@ -437,9 +437,12 @@ def loads(v, **kwargs):
     :rtype: (any, bytes)
     """
     try:
-        return _loads(v, **kwargs)
+        ret, tail = _loads(v, **kwargs)
     except RecursionError as err:
         raise DecodeError("deep recursion") from err
+    if ret == _EOC:
+        raise DecodeError("unexpected EOC")
+    return ret, tail
 
 
 if __name__ == "__main__":