]> Cypherpunks repositories - keks.git/commitdiff
Decode context options
authorSergey Matveev <stargrave@stargrave.org>
Mon, 16 Dec 2024 13:39:26 +0000 (16:39 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Mon, 16 Dec 2024 13:39:26 +0000 (16:39 +0300)
go/atom/ctx.go [new file with mode: 0644]
go/atom/dec.go
go/cmd/print/main.go
go/dec.go
go/fuzz_test.go
go/mapstruct/dec.go
go/pki/prv.go
go/pki/signed-data.go

diff --git a/go/atom/ctx.go b/go/atom/ctx.go
new file mode 100644 (file)
index 0000000..138d442
--- /dev/null
@@ -0,0 +1,50 @@
+// GoKEKS -- Go KEKS codec implementation
+// Copyright (C) 2024-2025 Sergey Matveev <stargrave@stargrave.org>
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as
+// published by the Free Software Foundation, version 3 of the License.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+package atom
+
+import "io"
+
+type Decoder struct {
+       // You have to set one of R or B as a data source. Decoding from the
+       // B buffer takes less allocations, it is faster.
+       R io.Reader
+       B []byte
+
+       // Maximal allowable string length. 0 means no limits, but pay
+       // attention that if there is no sufficient memory available,
+       // then Go may panic.
+       MaxStrLen int64
+
+       // Disable UTF-8 codepoints validation check.
+       DisableUTF8Check bool
+}
+
+// Read n bytes from the data source. If data source is R, then buf is
+// allocated for each new read. If data source is B, then buf is a slice
+// of the original B buffer.
+func (ctx *Decoder) Want(n int) (buf []byte, err error) {
+       if ctx.B == nil {
+               buf = make([]byte, n)
+               _, err = io.ReadFull(ctx.R, buf)
+               return
+       }
+       if len(ctx.B) < n {
+               err = io.ErrUnexpectedEOF
+               return
+       }
+       buf, ctx.B = ctx.B[:n], ctx.B[n:]
+       return
+}
index 48f287e35c05af6e4643013ffe45e3ffc6e152cca994d5b82ab55b5573625288..971c96c5d9dc312ca1c30e410f9070725225cef711ffd89b5cd60bc4eb826171 100644 (file)
@@ -17,7 +17,6 @@ package atom
 
 import (
        "errors"
-       "io"
        "math/big"
        "strings"
        "unicode/utf8"
@@ -37,9 +36,9 @@ var (
        ErrBadInt        = errors.New("bad int value")
 )
 
-func strDecode(r io.Reader, tag byte) (read int64, v []byte, err error) {
+func (ctx *Decoder) strDecode(tag byte) (read int64, v []byte, err error) {
        l := int64(tag & 63)
-       var ll int64
+       var ll int
        switch l {
        case 61:
                ll = 1
@@ -51,9 +50,8 @@ func strDecode(r io.Reader, tag byte) (read int64, v []byte, err error) {
                l += ((1 << 8) - 1) + ((1 << 16) - 1)
        }
        if ll != 0 {
-               read += ll
-               v = make([]byte, ll)
-               _, err = io.ReadFull(r, v)
+               read += int64(ll)
+               v, err = ctx.Want(ll)
                if err != nil {
                        return
                }
@@ -69,17 +67,19 @@ func strDecode(r io.Reader, tag byte) (read int64, v []byte, err error) {
                err = ErrLenTooBig
                return
        }
-       // TODO: check if it is too large for memory
-       v = make([]byte, l)
-       _, err = io.ReadFull(r, v)
+       if ctx.MaxStrLen > 0 && l > ctx.MaxStrLen {
+               err = ErrLenTooBig
+               return
+       }
+       v, err = ctx.Want(int(l))
        return
 }
 
 // Decode a single KEKS-encoded atom. Atom means that it does not decode
 // full lists, maps, blobs and may return types.EOC.
-func Decode(r io.Reader) (t types.Type, v any, read int64, err error) {
-       buf := make([]byte, 1)
-       _, err = io.ReadFull(r, buf)
+func (ctx *Decoder) Decode() (t types.Type, v any, read int64, err error) {
+       var buf []byte
+       buf, err = ctx.Want(1)
        if err != nil {
                return
        }
@@ -92,7 +92,7 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) {
                        t = types.Str
                }
                var strRead int64
-               strRead, buf, err = strDecode(r, tag)
+               strRead, buf, err = ctx.strDecode(tag)
                read += strRead
                if err != nil {
                        return
@@ -102,11 +102,13 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) {
                } else {
                        s := unsafe.String(unsafe.SliceData(buf), len(buf))
                        v = s
-                       if !utf8.ValidString(s) {
-                               err = ErrBadUTF8
-                       }
-                       if strings.Contains(s, "\x00") {
-                               err = ErrBadUTF8
+                       if !ctx.DisableUTF8Check {
+                               if !utf8.ValidString(s) {
+                                       err = ErrBadUTF8
+                               }
+                               if strings.Contains(s, "\x00") {
+                                       err = ErrBadUTF8
+                               }
                        }
                }
                return
@@ -125,8 +127,7 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) {
        case UUID:
                t = types.UUID
                read += 16
-               buf = make([]byte, 16)
-               _, err = io.ReadFull(r, buf)
+               buf, err = ctx.Want(16)
                if err != nil {
                        return
                }
@@ -138,8 +139,7 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) {
        case Blob:
                t = types.Blob
                read += 8
-               buf = make([]byte, 8)
-               _, err = io.ReadFull(r, buf)
+               buf, err = ctx.Want(8)
                if err != nil {
                        return
                }
@@ -158,7 +158,7 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) {
                        t = types.Int
                }
                read += 1
-               _, err = io.ReadFull(r, buf)
+               buf, err = ctx.Want(1)
                if err != nil {
                        return
                }
@@ -167,7 +167,7 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) {
                        return
                }
                var binRead int64
-               binRead, buf, err = strDecode(r, buf[0])
+               binRead, buf, err = ctx.strDecode(buf[0])
                read += binRead
                if err != nil {
                        return
@@ -227,8 +227,7 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) {
                        l = 32
                }
                read += int64(l)
-               buf = make([]byte, l)
-               _, err = io.ReadFull(r, buf)
+               buf, err = ctx.Want(l)
                if err != nil {
                        t = types.Float
                        return
@@ -248,8 +247,7 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) {
                }
                t = types.TAI64
                read += int64(l)
-               buf = make([]byte, l)
-               _, err = io.ReadFull(r, buf)
+               buf, err = ctx.Want(l)
                if err != nil {
                        return
                }
index 57a1cce851168c045c810a3ff9c65f7ed55d2848cc996f44ee85cce120b57212..36258d815b84e820b0cfbaef0edb3de6dc66c5c2f92a4825e22df16a5f4db219 100644 (file)
@@ -7,10 +7,11 @@ import (
        "os"
 
        "go.cypherpunks.su/keks"
+       "go.cypherpunks.su/keks/atom"
 )
 
 func main() {
-       item, read, err := keks.Decode(bufio.NewReader(os.Stdin))
+       item, read, err := keks.Decode(&atom.Decoder{R: bufio.NewReader(os.Stdin)})
        if err != nil {
                log.Fatal(err)
        }
index 0d28a4abdda141f91b84e03428dd3392bfb0ae3cafc794416d3e4f9138c71178..0d0fd2391c2f87dafbc1f50769ed6bbac6b18ee4237f7bbdcd8eb22abff92702 100644 (file)
--- a/go/dec.go
+++ b/go/dec.go
@@ -57,7 +57,7 @@ type Item struct {
 }
 
 func decode(
-       r io.Reader,
+       ctx *atom.Decoder,
        allowContainers bool,
        recursionDepth int,
 ) (item Item, read int64, err error) {
@@ -65,7 +65,7 @@ func decode(
                err = errors.New("deep recursion")
                return
        }
-       item.T, item.V, read, err = atom.Decode(r)
+       item.T, item.V, read, err = ctx.Decode()
        if err != nil {
                return
        }
@@ -79,7 +79,7 @@ func decode(
                var subRead int64
                var v []Item
                for {
-                       sub, subRead, err = decode(r, true, recursionDepth+1)
+                       sub, subRead, err = decode(ctx, true, recursionDepth+1)
                        read += subRead
                        if err != nil {
                                return
@@ -101,7 +101,7 @@ func decode(
                var subRead int64
                var keyPrev string
                for {
-                       sub, subRead, err = decode(r, false, recursionDepth+1)
+                       sub, subRead, err = decode(ctx, false, recursionDepth+1)
                        read += subRead
                        if err != nil {
                                return
@@ -128,7 +128,7 @@ func decode(
                                }
                                keyPrev = s
                        }
-                       sub, subRead, err = decode(r, true, recursionDepth+1)
+                       sub, subRead, err = decode(ctx, true, recursionDepth+1)
                        read += subRead
                        if err != nil {
                                return
@@ -146,27 +146,30 @@ func decode(
                        err = atom.ErrUnknownType
                        return
                }
-               // TODO: check if it is too large for memory
                chunkLen := int(item.V.(uint64))
+               if ctx.MaxStrLen != 0 && int64(chunkLen) > ctx.MaxStrLen {
+                       err = atom.ErrLenTooBig
+                       return
+               }
                v := Blob{ChunkLen: chunkLen}
                var sub Item
                var subRead int64
                var chunks []io.Reader
        BlobCycle:
                for {
-                       sub, subRead, err = decode(r, false, recursionDepth+1)
+                       sub, subRead, err = decode(ctx, false, recursionDepth+1)
                        read += subRead
                        if err != nil {
                                return
                        }
                        switch sub.T {
                        case types.NIL:
-                               buf := make([]byte, chunkLen)
-                               read += int64(chunkLen)
-                               _, err = io.ReadFull(r, buf)
+                               var buf []byte
+                               buf, err = ctx.Want(chunkLen)
                                if err != nil {
                                        return
                                }
+                               read += int64(chunkLen)
                                chunks = append(chunks, bytes.NewReader(buf))
                                v.DecodedLen += int64(chunkLen)
                        case types.Bin:
@@ -193,8 +196,8 @@ func decode(
 }
 
 // Decode single KEKS-encoded data item.
-func Decode(r io.Reader) (item Item, read int64, err error) {
-       item, read, err = decode(r, true, 0)
+func Decode(ctx *atom.Decoder) (item Item, read int64, err error) {
+       item, read, err = decode(ctx, true, 0)
        if item.T == types.EOC {
                err = ErrUnexpectedEOC
        }
index c9bc4c68ceece2724ad61ba2d0fa7b118524a9af8c7711ad0802e47759cfc7fa..1640a657878319a3dc084868457b73cd7a1a29e5daccb78d626f66a7b66b19f4 100644 (file)
@@ -3,6 +3,8 @@ package keks
 import (
        "bytes"
        "testing"
+
+       "go.cypherpunks.su/keks/atom"
 )
 
 func FuzzItemDecode(f *testing.F) {
@@ -11,7 +13,7 @@ func FuzzItemDecode(f *testing.F) {
        var e any
        var buf bytes.Buffer
        f.Fuzz(func(t *testing.T, b []byte) {
-               item, _, err = Decode(bytes.NewReader(b))
+               item, _, err = Decode(&atom.Decoder{B: b, MaxStrLen: 1 << 20})
                if err == nil {
                        e, err = item.ToGo()
                        if err != nil {
index 81202b592bd739cfe485acdc327b16a2922274459be1cbc0b8d804d257d2f819..42ac5d5c9c066c0ea3d84234c01695b5b8bdce0dee0141fb3ba9f665570a8b09 100644 (file)
@@ -17,17 +17,17 @@ package mapstruct
 
 import (
        "errors"
-       "io"
 
        "go.cypherpunks.su/keks"
+       "go.cypherpunks.su/keks/atom"
        "go.cypherpunks.su/keks/types"
 )
 
 // Decode KEKS-encoded data to the dst structure.
 // It will return an error if decoded data is not map.
-func Decode(dst any, src io.Reader) (err error) {
+func Decode(dst any, ctx *atom.Decoder) (err error) {
        var item keks.Item
-       item, _, err = keks.Decode(src)
+       item, _, err = keks.Decode(ctx)
        if err != nil {
                return
        }
index 0f9a0a370b831c16cce43454e52bb41495788a38302c3b7d382174fadccf5385..527e33f89cd62dd4eec6f08998ab0d544b02809de5bbca06ec317032e1664b7f 100644 (file)
 package pki
 
 import (
-       "bytes"
        "crypto"
        "errors"
        "fmt"
 
+       "go.cypherpunks.su/keks/atom"
        "go.cypherpunks.su/keks/mapstruct"
        ed25519blake2b "go.cypherpunks.su/keks/pki/ed25519-blake2b"
        "go.cypherpunks.su/keks/pki/gost"
@@ -30,7 +30,7 @@ import (
 func PrvParse(data []byte) (prv crypto.Signer, pub []byte, err error) {
        var av AV
        var tail []byte
-       err = mapstruct.Decode(&av, bytes.NewReader(data))
+       err = mapstruct.Decode(&av, &atom.Decoder{B: data, MaxStrLen: 1<<16})
        if err != nil {
                return
        }
index fd2ea441ecd3f6c9095d5f5b6327baf6d38da9c9b4636e1fa8e54dbcfee1a501..3e81ff2e46e924bfcacf6c46a5182ef3566060c76b04d67709199634e72f9889 100644 (file)
@@ -16,7 +16,6 @@
 package pki
 
 import (
-       "bytes"
        "crypto"
        "crypto/rand"
        "errors"
@@ -25,6 +24,7 @@ import (
        "github.com/google/uuid"
 
        "go.cypherpunks.su/keks"
+       "go.cypherpunks.su/keks/atom"
        "go.cypherpunks.su/keks/mapstruct"
        "go.cypherpunks.su/keks/types"
 )
@@ -141,7 +141,7 @@ func SignedDataParseItem(item keks.Item) (sd *SignedData, err error) {
 // SignedDataParseItem.
 func SignedDataParse(data []byte) (sd *SignedData, err error) {
        var item keks.Item
-       item, _, err = keks.Decode(bytes.NewReader(data))
+       item, _, err = keks.Decode(&atom.Decoder{B: data})
        if err != nil {
                return
        }