From: Sergey Matveev Date: Mon, 16 Dec 2024 13:39:26 +0000 (+0300) Subject: Decode context options X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=7fa69120116a449c0c52137f10c4d6eed392b17515f29a2e8d31b35c8ae4edae;p=keks.git Decode context options --- diff --git a/go/atom/ctx.go b/go/atom/ctx.go new file mode 100644 index 0000000..138d442 --- /dev/null +++ b/go/atom/ctx.go @@ -0,0 +1,50 @@ +// GoKEKS -- Go KEKS codec implementation +// Copyright (C) 2024-2025 Sergey Matveev +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as +// published by the Free Software Foundation, version 3 of the License. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this program. If not, see <http://www.gnu.org/licenses/>. + +package atom + +import "io" + +type Decoder struct { + // You have to set one of R or B as a data source. Decoding from the + // B buffer takes less allocations, it is faster. + R io.Reader + B []byte + + // Maximal allowable string length. 0 means no limits, but pay + // attention that if there is no sufficient memory available, + // then Go may panic. + MaxStrLen int64 + + // Disable UTF-8 codepoints validation check. + DisableUTF8Check bool +} + +// Read n bytes from the data source. If data source is R, then buf is +// allocated for each new read. If data source is B, then buf is a slice +// of the original B buffer. 
+func (ctx *Decoder) Want(n int) (buf []byte, err error) { + if ctx.B == nil { + buf = make([]byte, n) + _, err = io.ReadFull(ctx.R, buf) + return + } + if len(ctx.B) < n { + err = io.ErrUnexpectedEOF + return + } + buf, ctx.B = ctx.B[:n], ctx.B[n:] + return +} diff --git a/go/atom/dec.go b/go/atom/dec.go index 48f287e..971c96c 100644 --- a/go/atom/dec.go +++ b/go/atom/dec.go @@ -17,7 +17,6 @@ package atom import ( "errors" - "io" "math/big" "strings" "unicode/utf8" @@ -37,9 +36,9 @@ var ( ErrBadInt = errors.New("bad int value") ) -func strDecode(r io.Reader, tag byte) (read int64, v []byte, err error) { +func (ctx *Decoder) strDecode(tag byte) (read int64, v []byte, err error) { l := int64(tag & 63) - var ll int64 + var ll int switch l { case 61: ll = 1 @@ -51,9 +50,8 @@ func strDecode(r io.Reader, tag byte) (read int64, v []byte, err error) { l += ((1 << 8) - 1) + ((1 << 16) - 1) } if ll != 0 { - read += ll - v = make([]byte, ll) - _, err = io.ReadFull(r, v) + read += int64(ll) + v, err = ctx.Want(ll) if err != nil { return } @@ -69,17 +67,19 @@ func strDecode(r io.Reader, tag byte) (read int64, v []byte, err error) { err = ErrLenTooBig return } - // TODO: check if it is too large for memory - v = make([]byte, l) - _, err = io.ReadFull(r, v) + if ctx.MaxStrLen > 0 && l > ctx.MaxStrLen { + err = ErrLenTooBig + return + } + v, err = ctx.Want(int(l)) return } // Decode a single KEKS-encoded atom. Atom means that it does not decode // full lists, maps, blobs and may return types.EOC. 
-func Decode(r io.Reader) (t types.Type, v any, read int64, err error) { - buf := make([]byte, 1) - _, err = io.ReadFull(r, buf) +func (ctx *Decoder) Decode() (t types.Type, v any, read int64, err error) { + var buf []byte + buf, err = ctx.Want(1) if err != nil { return } @@ -92,7 +92,7 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) { t = types.Str } var strRead int64 - strRead, buf, err = strDecode(r, tag) + strRead, buf, err = ctx.strDecode(tag) read += strRead if err != nil { return @@ -102,11 +102,13 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) { } else { s := unsafe.String(unsafe.SliceData(buf), len(buf)) v = s - if !utf8.ValidString(s) { - err = ErrBadUTF8 - } - if strings.Contains(s, "\x00") { - err = ErrBadUTF8 + if !ctx.DisableUTF8Check { + if !utf8.ValidString(s) { + err = ErrBadUTF8 + } + if strings.Contains(s, "\x00") { + err = ErrBadUTF8 + } } } return @@ -125,8 +127,7 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) { case UUID: t = types.UUID read += 16 - buf = make([]byte, 16) - _, err = io.ReadFull(r, buf) + buf, err = ctx.Want(16) if err != nil { return } @@ -138,8 +139,7 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) { case Blob: t = types.Blob read += 8 - buf = make([]byte, 8) - _, err = io.ReadFull(r, buf) + buf, err = ctx.Want(8) if err != nil { return } @@ -158,7 +158,7 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) { t = types.Int } read += 1 - _, err = io.ReadFull(r, buf) + buf, err = ctx.Want(1) if err != nil { return } @@ -167,7 +167,7 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) { return } var binRead int64 - binRead, buf, err = strDecode(r, buf[0]) + binRead, buf, err = ctx.strDecode(buf[0]) read += binRead if err != nil { return @@ -227,8 +227,7 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) { l = 32 } read += int64(l) - buf = make([]byte, l) - _, err = 
io.ReadFull(r, buf) + buf, err = ctx.Want(l) if err != nil { t = types.Float return @@ -248,8 +247,7 @@ func Decode(r io.Reader) (t types.Type, v any, read int64, err error) { } t = types.TAI64 read += int64(l) - buf = make([]byte, l) - _, err = io.ReadFull(r, buf) + buf, err = ctx.Want(l) if err != nil { return } diff --git a/go/cmd/print/main.go b/go/cmd/print/main.go index 57a1cce..36258d8 100644 --- a/go/cmd/print/main.go +++ b/go/cmd/print/main.go @@ -7,10 +7,11 @@ import ( "os" "go.cypherpunks.su/keks" + "go.cypherpunks.su/keks/atom" ) func main() { - item, read, err := keks.Decode(bufio.NewReader(os.Stdin)) + item, read, err := keks.Decode(&atom.Decoder{R: bufio.NewReader(os.Stdin)}) if err != nil { log.Fatal(err) } diff --git a/go/dec.go b/go/dec.go index 0d28a4a..0d0fd23 100644 --- a/go/dec.go +++ b/go/dec.go @@ -57,7 +57,7 @@ type Item struct { } func decode( - r io.Reader, + ctx *atom.Decoder, allowContainers bool, recursionDepth int, ) (item Item, read int64, err error) { @@ -65,7 +65,7 @@ func decode( err = errors.New("deep recursion") return } - item.T, item.V, read, err = atom.Decode(r) + item.T, item.V, read, err = ctx.Decode() if err != nil { return } @@ -79,7 +79,7 @@ func decode( var subRead int64 var v []Item for { - sub, subRead, err = decode(r, true, recursionDepth+1) + sub, subRead, err = decode(ctx, true, recursionDepth+1) read += subRead if err != nil { return @@ -101,7 +101,7 @@ func decode( var subRead int64 var keyPrev string for { - sub, subRead, err = decode(r, false, recursionDepth+1) + sub, subRead, err = decode(ctx, false, recursionDepth+1) read += subRead if err != nil { return @@ -128,7 +128,7 @@ func decode( } keyPrev = s } - sub, subRead, err = decode(r, true, recursionDepth+1) + sub, subRead, err = decode(ctx, true, recursionDepth+1) read += subRead if err != nil { return @@ -146,27 +146,30 @@ func decode( err = atom.ErrUnknownType return } - // TODO: check if it is too large for memory chunkLen := int(item.V.(uint64)) + if 
ctx.MaxStrLen != 0 && int64(chunkLen) > ctx.MaxStrLen { + err = atom.ErrLenTooBig + return + } v := Blob{ChunkLen: chunkLen} var sub Item var subRead int64 var chunks []io.Reader BlobCycle: for { - sub, subRead, err = decode(r, false, recursionDepth+1) + sub, subRead, err = decode(ctx, false, recursionDepth+1) read += subRead if err != nil { return } switch sub.T { case types.NIL: - buf := make([]byte, chunkLen) - read += int64(chunkLen) - _, err = io.ReadFull(r, buf) + var buf []byte + buf, err = ctx.Want(chunkLen) if err != nil { return } + read += int64(chunkLen) chunks = append(chunks, bytes.NewReader(buf)) v.DecodedLen += int64(chunkLen) case types.Bin: @@ -193,8 +196,8 @@ func decode( } // Decode single KEKS-encoded data item. -func Decode(r io.Reader) (item Item, read int64, err error) { - item, read, err = decode(r, true, 0) +func Decode(ctx *atom.Decoder) (item Item, read int64, err error) { + item, read, err = decode(ctx, true, 0) if item.T == types.EOC { err = ErrUnexpectedEOC } diff --git a/go/fuzz_test.go b/go/fuzz_test.go index c9bc4c6..1640a65 100644 --- a/go/fuzz_test.go +++ b/go/fuzz_test.go @@ -3,6 +3,8 @@ package keks import ( "bytes" "testing" + + "go.cypherpunks.su/keks/atom" ) func FuzzItemDecode(f *testing.F) { @@ -11,7 +13,7 @@ func FuzzItemDecode(f *testing.F) { var e any var buf bytes.Buffer f.Fuzz(func(t *testing.T, b []byte) { - item, _, err = Decode(bytes.NewReader(b)) + item, _, err = Decode(&atom.Decoder{B: b, MaxStrLen: 1 << 20}) if err == nil { e, err = item.ToGo() if err != nil { diff --git a/go/mapstruct/dec.go b/go/mapstruct/dec.go index 81202b5..42ac5d5 100644 --- a/go/mapstruct/dec.go +++ b/go/mapstruct/dec.go @@ -17,17 +17,17 @@ package mapstruct import ( "errors" - "io" "go.cypherpunks.su/keks" + "go.cypherpunks.su/keks/atom" "go.cypherpunks.su/keks/types" ) // Decode KEKS-encoded data to the dst structure. // It will return an error if decoded data is not map. 
-func Decode(dst any, src io.Reader) (err error) { +func Decode(dst any, ctx *atom.Decoder) (err error) { var item keks.Item - item, _, err = keks.Decode(src) + item, _, err = keks.Decode(ctx) if err != nil { return } diff --git a/go/pki/prv.go b/go/pki/prv.go index 0f9a0a3..527e33f 100644 --- a/go/pki/prv.go +++ b/go/pki/prv.go @@ -16,11 +16,11 @@ package pki import ( - "bytes" "crypto" "errors" "fmt" + "go.cypherpunks.su/keks/atom" "go.cypherpunks.su/keks/mapstruct" ed25519blake2b "go.cypherpunks.su/keks/pki/ed25519-blake2b" "go.cypherpunks.su/keks/pki/gost" @@ -30,7 +30,7 @@ import ( func PrvParse(data []byte) (prv crypto.Signer, pub []byte, err error) { var av AV var tail []byte - err = mapstruct.Decode(&av, bytes.NewReader(data)) + err = mapstruct.Decode(&av, &atom.Decoder{B: data, MaxStrLen: 1<<16}) if err != nil { return } diff --git a/go/pki/signed-data.go b/go/pki/signed-data.go index fd2ea44..3e81ff2 100644 --- a/go/pki/signed-data.go +++ b/go/pki/signed-data.go @@ -16,7 +16,6 @@ package pki import ( - "bytes" "crypto" "crypto/rand" "errors" @@ -25,6 +24,7 @@ import ( "github.com/google/uuid" "go.cypherpunks.su/keks" + "go.cypherpunks.su/keks/atom" "go.cypherpunks.su/keks/mapstruct" "go.cypherpunks.su/keks/types" ) @@ -141,7 +141,7 @@ func SignedDataParseItem(item keks.Item) (sd *SignedData, err error) { // SignedDataParseItem. func SignedDataParse(data []byte) (sd *SignedData, err error) { var item keks.Item - item, _, err = keks.Decode(bytes.NewReader(data)) + item, _, err = keks.Decode(&atom.Decoder{B: data}) if err != nil { return }