From: Sergey Matveev Date: Sun, 15 Dec 2024 13:30:24 +0000 (+0300) Subject: Streaming encoding/decoding X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=c4cc2cca398b8d90b0d856cd15881b16afc3f9054b6234b3afb46caaa3ed0826;p=keks.git Streaming encoding/decoding --- diff --git a/gyac/atom/dec.go b/gyac/atom/dec.go index 8f36ad2..59dca13 100644 --- a/gyac/atom/dec.go +++ b/gyac/atom/dec.go @@ -17,6 +17,7 @@ package atom import ( "errors" + "io" "math/big" "strings" "unicode/utf8" @@ -29,67 +30,77 @@ import ( ) var ( - ErrNotEnough = errors.New("not enough data") - ErrLenTooBig = errors.New("string len >1<<60") + ErrLenTooBig = errors.New("too big string") ErrIntNonMinimal = errors.New("int non minimal") ErrUnknownType = errors.New("unknown type") ErrBadUTF8 = errors.New("invalid UTF-8") ErrBadInt = errors.New("bad int value") ) +func strDecode(r io.Reader, tag byte) (read int64, v []byte, err error) { + l := int64(tag & 63) + var ll int64 + switch l { + case 61: + ll = 1 + case 62: + ll = 2 + l += ((1 << 8) - 1) + case 63: + ll = 8 + l += ((1 << 8) - 1) + ((1 << 16) - 1) + } + if ll != 0 { + read += ll + v = make([]byte, ll) + _, err = io.ReadFull(r, v) + if err != nil { + return + } + ul := be.Get(v) + if ul > (1<<63)-(63+((1<<8)-1)+((1<<16)-1)) { + err = ErrLenTooBig + return + } + l += int64(ul) + } + read += l + if read < 0 { // overflowed + err = ErrLenTooBig + return + } + // TODO: check if it is too large for memory + v = make([]byte, l) + _, err = io.ReadFull(r, v) + return +} + // Decode a single YAC-encoded atom. Atom means that it does not decode // full lists, maps, blobs and may return types.EOC. -func Decode(buf []byte) (t types.Type, v any, off int, err error) { - off = 1 - if len(buf) < 1 { - err = ErrNotEnough +func Decode(r io.Reader) (t types.Type, v any, read int64, err error) { + buf := make([]byte, 1) + _, err = io.ReadFull(r, buf) + if err != nil { return } + read = 1 tag := buf[0] if (tag & Strings) > 0 { - l := int(tag & 63) if (tag & IsUTF8) == 0 { t = types.Bin } else { t = types.Str } - ll := 0 - switch l { - case 61: - ll = 1 - case 62: - ll = 2 - l += ((1 << 8) - 1) - case 63: - ll = 8 - l += ((1 << 8) - 1) + ((1 << 16) - 1) - } - if ll != 0 { - off += ll - if len(buf) < off { - err = ErrNotEnough - return - } - ul := be.Get(buf[1 : 1+ll]) - if ul > (1<<63)-(63+((1<<8)-1)+((1<<16)-1)) { - err = ErrLenTooBig - return - } - l += int(ul) - } - off += l - if off <= 0 { - err = ErrLenTooBig - return - } - if len(buf) < off { - err = ErrNotEnough + var strRead int64 + strRead, buf, err = strDecode(r, tag) + read += strRead + if err != nil { return } if t == types.Bin { - v = buf[1+ll : 1+ll+l] + v = buf } else { - s := unsafe.String(unsafe.SliceData(buf[1+ll:]), l) + s := unsafe.String(unsafe.SliceData(buf), len(buf)) v = s if !utf8.ValidString(s) { err = ErrBadUTF8 @@ -112,26 +123,27 @@ func Decode(buf []byte) (t types.Type, v any, off int, err error) { t = types.Bool v = true case UUID: - off += 16 t = types.UUID - if len(buf) < off { - err = ErrNotEnough + read += 16 + buf = make([]byte, 16) + _, err = io.ReadFull(r, buf) + if err != nil { return } - v, err = uuid.FromBytes(buf[1 : 1+16]) - + v, err = uuid.FromBytes(buf) case List: t = types.List case Map: t = types.Map case Blob: t = types.Blob - off += 8 - if len(buf) < off { - err = ErrNotEnough + read += 8 + buf = make([]byte, 8) + _, err = io.ReadFull(r, buf) + if err != nil { return } - chunkLen := be.Get(buf[1 : 1+8]) + chunkLen := be.Get(buf) if chunkLen >= (1<<63)-1 { err = ErrLenTooBig return @@ -145,28 +157,22 @@ func Decode(buf []byte) (t types.Type, v any, off int, err error) { } else { t = types.Int } - var binOff int - if len(buf) < 2 { - err = ErrNotEnough + read += 1 + _, err = io.ReadFull(r, buf) + if err != nil { return } - if buf[1]&Strings == 0 { + if buf[0]&Strings == 0 || buf[0]&IsUTF8 != 0 { err = ErrBadInt return } - var binT types.Type - var binV any - binT, binV, binOff, err = Decode(buf[1:]) - off += binOff + var binRead int64 + binRead, buf, err = strDecode(r, buf[0]) + read += binRead if err != nil { return } - if binT != types.Bin { - err = ErrBadInt - return - } - raw := binV.([]byte) - if len(raw) == 0 { + if len(buf) == 0 { if t == types.UInt { v = uint64(0) } else { @@ -174,13 +180,13 @@ func Decode(buf []byte) (t types.Type, v any, off int, err error) { } return } - if raw[0] == 0 { + if buf[0] == 0 { err = ErrIntNonMinimal return } - if len(raw) > 8 { + if len(buf) > 8 { bi := big.NewInt(0) - bi = bi.SetBytes(raw) + bi = bi.SetBytes(buf) if t == types.Int { n1 := big.NewInt(-1) bi = bi.Sub(n1, bi) @@ -189,13 +195,13 @@ func Decode(buf []byte) (t types.Type, v any, off int, err error) { v = bi return } - i := be.Get(raw) + i := be.Get(buf) if t == types.UInt { v = i } else { if i >= (1 << 63) { bi := big.NewInt(0) - bi = bi.SetBytes(raw) + bi = bi.SetBytes(buf) n1 := big.NewInt(-1) bi = bi.Sub(n1, bi) t = types.BigInt @@ -220,14 +226,15 @@ func Decode(buf []byte) (t types.Type, v any, off int, err error) { case Float256: l = 32 } - off += l - if len(buf) < off { + read += int64(l) + buf = make([]byte, l) + _, err = io.ReadFull(r, buf) + if err != nil { t = types.Float - err = ErrNotEnough return } t = types.Raw - v = Raw{T: Type(tag), V: buf[1 : 1+l]} + v = Raw{T: Type(tag), V: buf} case TAI64, TAI64N, TAI64NA: var l int @@ -239,19 +246,20 @@ func Decode(buf []byte) (t types.Type, v any, off int, err error) { case TAI64NA: l = 16 } - off += l - if len(buf) < off { - err = ErrNotEnough + t = types.TAI64 + read += int64(l) + buf = make([]byte, l) + _, err = io.ReadFull(r, buf) + if err != nil { return } - t = types.TAI64 - v = buf[1 : 1+l] - if be.Get(buf[1:1+8]) > (1 << 63) { + v = buf + if be.Get(buf[:8]) > (1 << 63) { err = errors.New("reserved TAI64 values in use") return } if l > 8 { - nsecs := be.Get(buf[1+8 : 1+8+4]) + nsecs := be.Get(buf[8 : 8+4]) if l == 12 && nsecs == 0 { err = errors.New("non-minimal TAI64N") return @@ -262,7 +270,7 @@ func Decode(buf []byte) (t types.Type, v any, off int, err error) { } } if l > 12 { - asecs := be.Get(buf[1+8+4 : 1+8+4+4]) + asecs := be.Get(buf[8+4 : 8+4+4]) if asecs == 0 { err = errors.New("non-minimal TAI64NA") return @@ -275,7 +283,6 @@ func Decode(buf []byte) (t types.Type, v any, off int, err error) { default: err = ErrUnknownType - return } return } diff --git a/gyac/atom/enc.go b/gyac/atom/enc.go index a6b757b..d9979f4 100644 --- a/gyac/atom/enc.go +++ b/gyac/atom/enc.go @@ -16,6 +16,8 @@ package atom import ( + "bytes" + "io" "math/big" "github.com/google/uuid" @@ -52,32 +54,33 @@ const ( IsUTF8 = 0x40 ) -// Append an encoded EOC atom to the buf. -func EOCEncode(buf []byte) []byte { - return append(buf, byte(EOC)) +// Write an encoded EOC atom. +func EOCEncode(w io.Writer) (written int64, err error) { + return io.Copy(w, bytes.NewReader([]byte{byte(EOC)})) } -// Append an encoded NIL atom to the buf. -func NILEncode(buf []byte) []byte { - return append(buf, byte(NIL)) +// Write an encoded NIL atom. +func NILEncode(w io.Writer) (written int64, err error) { + return io.Copy(w, bytes.NewReader([]byte{byte(NIL)})) } -// Append an encoded TRUE/FALSE atom to the buf. -func BoolEncode(buf []byte, v bool) []byte { +// Write an encoded TRUE/FALSE atom. +func BoolEncode(w io.Writer, v bool) (written int64, err error) { + data := []byte{byte(False)} if v { - return append(buf, byte(True)) + data[0] = byte(True) } - return append(buf, byte(False)) + return io.Copy(w, bytes.NewReader(data)) } -// Append an encoded UUID atom to the buf. -func UUIDEncode(buf []byte, v uuid.UUID) []byte { - return append(append(buf, byte(UUID)), v[:]...) +// Write an encoded UUID atom. +func UUIDEncode(w io.Writer, v uuid.UUID) (written int64, err error) { + return io.Copy(w, bytes.NewReader(append([]byte{byte(UUID)}, v[:]...))) } -func atomUintEncode(v uint64) (buf []byte) { +func atomUintEncode(w io.Writer, v uint64) (written int64, err error) { if v == 0 { - return BinEncode(nil, []byte{}) + return BinEncode(w, []byte{}) } l := 0 for ; l < 7; l++ { @@ -85,114 +88,163 @@ func atomUintEncode(v uint64) (buf []byte) { break } } - buf = make([]byte, l+1) + buf := make([]byte, l+1) be.Put(buf, v) - return BinEncode(nil, buf) + return BinEncode(w, buf) } -// Append an encoded +INT atom to the buf. -func UIntEncode(buf []byte, v uint64) []byte { - return append(buf, append([]byte{byte(PInt)}, atomUintEncode(v)...)...) +// Write an encoded +INT atom. +func UIntEncode(w io.Writer, v uint64) (written int64, err error) { + written, err = io.Copy(w, bytes.NewReader([]byte{byte(PInt)})) + if err != nil { + return + } + written, err = atomUintEncode(w, v) + written++ + return } -// Append an encoded -INT atom to the buf if v is negative. +// Write an encoded -INT atom if v is negative. // +INT atom otherwise, same as UIntEncode. -func IntEncode(buf []byte, v int64) []byte { +func IntEncode(w io.Writer, v int64) (written int64, err error) { if v >= 0 { - return UIntEncode(buf, uint64(v)) + return UIntEncode(w, uint64(v)) + } + written, err = io.Copy(w, bytes.NewReader([]byte{byte(NInt)})) + if err != nil { + return } - return append(buf, append([]byte{byte(NInt)}, - atomUintEncode(uint64(-(v+1)))...)...) + written, err = atomUintEncode(w, uint64(-(v + 1))) + written++ + return } -// Append an encoded ±INT atom to the buf. -func BigIntEncode(buf []byte, v *big.Int) []byte { +// Write an encoded ±INT atom. +func BigIntEncode(w io.Writer, v *big.Int) (written int64, err error) { if v.Cmp(bigIntZero) >= 0 { - return append(buf, BinEncode([]byte{byte(PInt)}, v.Bytes())...) + written, err = io.Copy(w, bytes.NewReader([]byte{byte(PInt)})) + if err != nil { + return + } + written, err = BinEncode(w, v.Bytes()) + written++ + return } n1 := big.NewInt(-1) v = v.Abs(v) v = v.Add(v, n1) - return append(buf, BinEncode([]byte{byte(NInt)}, v.Bytes())...) + written, err = io.Copy(w, bytes.NewReader([]byte{byte(NInt)})) + if err != nil { + return + } + written, err = BinEncode(w, v.Bytes()) + written++ + return } -// Append an encoded LIST atom to the buf. +// Write an encoded LIST atom. // You have to manually terminate it with EOCEncode. -func ListEncode(buf []byte) []byte { - return append(buf, byte(List)) +func ListEncode(w io.Writer) (written int64, err error) { + return io.Copy(w, bytes.NewReader([]byte{byte(List)})) } -// Append an encoded MAP atom to the buf. +// Write an encoded MAP atom. // You have to manually terminate it with EOCEncode. -func MapEncode(buf []byte) []byte { - return append(buf, byte(Map)) +func MapEncode(w io.Writer) (written int64, err error) { + return io.Copy(w, bytes.NewReader([]byte{byte(Map)})) } -// Append an encoded BLOB atom to the buf. +// Write an encoded BLOB atom. // You have to manually provide necessary chunks and // properly terminate it with BinEncode. -func BlobEncode(buf []byte, chunkLen int) []byte { +func BlobEncode(w io.Writer, chunkLen int) (written int64, err error) { l := make([]byte, 9) l[0] = byte(Blob) be.Put(l[1:], uint64(chunkLen-1)) - return append(buf, l...) + return io.Copy(w, bytes.NewReader(l)) } -func atomStrEncode(buf, data []byte, utf8 bool) []byte { - var lv int - var l []byte +func atomStrEncode(w io.Writer, data []byte, utf8 bool) (written int64, err error) { + tag := byte(Strings) + if utf8 { + tag |= IsUTF8 + } + var hdr []byte if len(data) >= 63+((1<<8)-1)+((1<<16)-1) { - lv = 63 - l = make([]byte, 8) - be.Put(l, uint64(len(data)-(lv+((1<<8)-1)+((1<<16)-1)))) + hdr = make([]byte, 8+1) + hdr[0] = tag | 63 + be.Put(hdr[1:], uint64(len(data)-(63+((1<<8)-1)+((1<<16)-1)))) } else if len(data) >= 62+255 { - lv = 62 - l = make([]byte, 2) - be.Put(l, uint64(len(data)-(lv+((1<<8)-1)))) + hdr = make([]byte, 2+1) + hdr[0] = tag | 62 + be.Put(hdr[1:], uint64(len(data)-(62+((1<<8)-1)))) } else if len(data) >= 61 { - lv = 61 - l = []byte{byte(len(data) - lv)} + hdr = []byte{tag | 61, byte(len(data) - 61)} } else { - lv = len(data) + hdr = []byte{tag | byte(len(data))} } - b := byte(Strings | lv) - if utf8 { - b |= IsUTF8 + var hdrLen int64 + hdrLen, err = io.Copy(w, bytes.NewReader(hdr)) + if err != nil { + written = hdrLen + return } - return append(append(append(buf, b), l...), data...) + written, err = io.Copy(w, bytes.NewReader(data)) + written += hdrLen + return } -// Append an encoded STR atom to the buf. -func StrEncode(buf []byte, str string) []byte { - return atomStrEncode(buf, []byte(str), true) +// Write an encoded STR atom. +func StrEncode(w io.Writer, str string) (written int64, err error) { + return atomStrEncode(w, []byte(str), true) } -// Append an encoded BIN atom to the buf. -func BinEncode(buf, bin []byte) []byte { - return atomStrEncode(buf, bin, false) +// Write an encoded BIN atom. +func BinEncode(w io.Writer, bin []byte) (written int64, err error) { + return atomStrEncode(w, bin, false) } -// Append an encoded CHUNK atom to the buf. -// That is basically an appended NIL with the chunk value. -func ChunkEncode(buf, chunk []byte) []byte { - return append(append(buf, byte(NIL)), chunk...) +// Write an encoded CHUNK atom. +// That is basically NIL with the chunk value. +func ChunkEncode(w io.Writer, chunk []byte) (written int64, err error) { + written, err = NILEncode(w) + if err != nil { + return + } + written, err = io.Copy(w, bytes.NewReader(chunk)) + written++ + return } -// Append an encoded TAI64* atom to the buf. -func TAI64Encode(buf, tai []byte) []byte { +// Write an encoded TAI64* atom. +func TAI64Encode(w io.Writer, tai []byte) (written int64, err error) { + var tag []byte switch len(tai) { case 8: - return append(append(buf, byte(TAI64)), tai...) + tag = []byte{byte(TAI64)} case 12: - return append(append(buf, byte(TAI64N)), tai...) + tag = []byte{byte(TAI64N)} case 16: - return append(append(buf, byte(TAI64NA)), tai...) + tag = []byte{byte(TAI64NA)} default: panic("wrong TAI64 value") } + written, err = io.Copy(w, bytes.NewReader(tag)) + if err != nil { + return + } + written, err = io.Copy(w, bytes.NewReader(tai)) + written++ + return } -// Append an encoded raw atom's value to the buf. -func RawEncode(buf []byte, raw Raw) []byte { - return append(append(buf, byte(raw.T)), raw.V...) +// Write an encoded raw atom's value. +func RawEncode(w io.Writer, raw Raw) (written int64, err error) { + written, err = io.Copy(w, bytes.NewReader([]byte{byte(raw.T)})) + if err != nil { + return + } + written, err = io.Copy(w, bytes.NewReader(raw.V)) + written++ + return } diff --git a/gyac/cmd/print/main.go b/gyac/cmd/print/main.go index c4bfee5..2d33fb4 100644 --- a/gyac/cmd/print/main.go +++ b/gyac/cmd/print/main.go @@ -3,7 +3,6 @@ package main import ( "bufio" "fmt" - "io" "log" "os" @@ -11,20 +10,13 @@ import ( ) func main() { - data, err := io.ReadAll(bufio.NewReader(os.Stdin)) + item, read, err := gyac.Decode(bufio.NewReader(os.Stdin)) if err != nil { log.Fatal(err) } - item, tail, err := gyac.Decode(data) - if err != nil { - log.Fatal(err) - } - if len(tail) > 0 { - log.Fatalln("trailing data:", tail) - } e, err := item.ToGo() if err != nil { log.Fatal(err) } - fmt.Printf("%v\n", e) + fmt.Printf("%v\n%d bytes\n", e, read) } diff --git a/gyac/cmd/test-vector-anys/main.go b/gyac/cmd/test-vector-anys/main.go index a5fd7f5..ccb4e9d 100644 --- a/gyac/cmd/test-vector-anys/main.go +++ b/gyac/cmd/test-vector-anys/main.go @@ -112,5 +112,10 @@ func main() { if err != nil { log.Fatal(err) } - fmt.Println(hex.EncodeToString(item.Encode(nil))) + var buf bytes.Buffer + _, err = item.Encode(&buf) + if err != nil { + log.Fatal(err) + } + fmt.Println(hex.EncodeToString(buf.Bytes())) } diff --git a/gyac/cmd/test-vector-manual/main.go b/gyac/cmd/test-vector-manual/main.go index c070c26..151a7e4 100644 --- a/gyac/cmd/test-vector-manual/main.go +++ b/gyac/cmd/test-vector-manual/main.go @@ -21,183 +21,192 @@ func mustHexDec(s string) []byte { return b } +var Size int64 + +func mustEncode(n int64, err error) { + Size += n + if err != nil { + panic(err) + } +} + func main() { - buf := make([]byte, 0, 68*1024) + var buf bytes.Buffer { - buf = atom.MapEncode(buf) + mustEncode(atom.MapEncode(&buf)) { - buf = atom.StrEncode(buf, "nil") - buf = atom.NILEncode(buf) + mustEncode(atom.StrEncode(&buf, "nil")) + mustEncode(atom.NILEncode(&buf)) } { - buf = atom.StrEncode(buf, "str") - buf = atom.MapEncode(buf) + mustEncode(atom.StrEncode(&buf, "str")) + mustEncode(atom.MapEncode(&buf)) { - buf = atom.StrEncode(buf, "bin") - buf = atom.ListEncode(buf) + mustEncode(atom.StrEncode(&buf, "bin")) + mustEncode(atom.ListEncode(&buf)) { - buf = atom.BinEncode(buf, []byte("")) - buf = atom.BinEncode(buf, bytes.Repeat([]byte{'0'}, 60)) - buf = atom.BinEncode(buf, bytes.Repeat([]byte{'1'}, 61)) - buf = atom.BinEncode(buf, bytes.Repeat([]byte{'2'}, 255)) - buf = atom.BinEncode(buf, bytes.Repeat([]byte{'A'}, 61+255)) - buf = atom.BinEncode(buf, bytes.Repeat([]byte{'B'}, 62+255)) - buf = atom.BinEncode(buf, bytes.Repeat([]byte{'3'}, 1024)) - buf = atom.BinEncode(buf, bytes.Repeat([]byte{'4'}, 63+255+65535+1)) + mustEncode(atom.BinEncode(&buf, []byte(""))) + mustEncode(atom.BinEncode(&buf, bytes.Repeat([]byte{'0'}, 60))) + mustEncode(atom.BinEncode(&buf, bytes.Repeat([]byte{'1'}, 61))) + mustEncode(atom.BinEncode(&buf, bytes.Repeat([]byte{'2'}, 255))) + mustEncode(atom.BinEncode(&buf, bytes.Repeat([]byte{'A'}, 61+255))) + mustEncode(atom.BinEncode(&buf, bytes.Repeat([]byte{'B'}, 62+255))) + mustEncode(atom.BinEncode(&buf, bytes.Repeat([]byte{'3'}, 1024))) + mustEncode(atom.BinEncode(&buf, bytes.Repeat([]byte{'4'}, 63+255+65535+1))) } - buf = atom.EOCEncode(buf) + mustEncode(atom.EOCEncode(&buf)) { - buf = atom.StrEncode(buf, "utf8") - buf = atom.StrEncode(buf, "привет мир") + mustEncode(atom.StrEncode(&buf, "utf8")) + mustEncode(atom.StrEncode(&buf, "привет мир")) } } - buf = atom.EOCEncode(buf) + mustEncode(atom.EOCEncode(&buf)) } { - buf = atom.StrEncode(buf, "blob") - buf = atom.ListEncode(buf) + mustEncode(atom.StrEncode(&buf, "blob")) + mustEncode(atom.ListEncode(&buf)) { - buf = atom.BlobEncode(buf, 12) - buf = atom.BinEncode(buf, []byte{'5'}) + mustEncode(atom.BlobEncode(&buf, 12)) + mustEncode(atom.BinEncode(&buf, []byte{'5'})) } { - buf = atom.BlobEncode(buf, 12) - buf = atom.ChunkEncode(buf, bytes.Repeat([]byte{'6'}, 12)) - buf = atom.BinEncode(buf, []byte{}) + mustEncode(atom.BlobEncode(&buf, 12)) + mustEncode(atom.ChunkEncode(&buf, bytes.Repeat([]byte{'6'}, 12))) + mustEncode(atom.BinEncode(&buf, []byte{})) } { - buf = atom.BlobEncode(buf, 12) - buf = atom.ChunkEncode(buf, bytes.Repeat([]byte{'7'}, 12)) - buf = atom.BinEncode(buf, []byte{'7'}) + mustEncode(atom.BlobEncode(&buf, 12)) + mustEncode(atom.ChunkEncode(&buf, bytes.Repeat([]byte{'7'}, 12))) + mustEncode(atom.BinEncode(&buf, []byte{'7'})) } { - buf = atom.BlobEncode(buf, 5) - buf = atom.ChunkEncode(buf, []byte("12345")) - buf = atom.ChunkEncode(buf, []byte("67890")) - buf = atom.BinEncode(buf, []byte{'-'}) + mustEncode(atom.BlobEncode(&buf, 5)) + mustEncode(atom.ChunkEncode(&buf, []byte("12345"))) + mustEncode(atom.ChunkEncode(&buf, []byte("67890"))) + mustEncode(atom.BinEncode(&buf, []byte{'-'})) } - buf = atom.EOCEncode(buf) + mustEncode(atom.EOCEncode(&buf)) } { - buf = atom.StrEncode(buf, "bool") - buf = atom.ListEncode(buf) - buf = atom.BoolEncode(buf, true) - buf = atom.BoolEncode(buf, false) - buf = atom.EOCEncode(buf) + mustEncode(atom.StrEncode(&buf, "bool")) + mustEncode(atom.ListEncode(&buf)) + mustEncode(atom.BoolEncode(&buf, true)) + mustEncode(atom.BoolEncode(&buf, false)) + mustEncode(atom.EOCEncode(&buf)) } { - buf = atom.StrEncode(buf, "ints") - buf = atom.MapEncode(buf) + mustEncode(atom.StrEncode(&buf, "ints")) + mustEncode(atom.MapEncode(&buf)) { - buf = atom.StrEncode(buf, "neg") - buf = atom.ListEncode(buf) - buf = atom.IntEncode(buf, -1) - buf = atom.IntEncode(buf, -2) - buf = atom.IntEncode(buf, -32) - buf = atom.IntEncode(buf, -33) - buf = atom.IntEncode(buf, -123) - buf = atom.IntEncode(buf, -1234) - buf = atom.IntEncode(buf, -12345678) + mustEncode(atom.StrEncode(&buf, "neg")) + mustEncode(atom.ListEncode(&buf)) + mustEncode(atom.IntEncode(&buf, -1)) + mustEncode(atom.IntEncode(&buf, -2)) + mustEncode(atom.IntEncode(&buf, -32)) + mustEncode(atom.IntEncode(&buf, -33)) + mustEncode(atom.IntEncode(&buf, -123)) + mustEncode(atom.IntEncode(&buf, -1234)) + mustEncode(atom.IntEncode(&buf, -12345678)) b := big.NewInt(0) b.SetBytes(mustHexDec("0100000000000000000000")) b = b.Neg(b) - buf = atom.BigIntEncode(buf, b) + mustEncode(atom.BigIntEncode(&buf, b)) b.SetBytes(mustHexDec("0100000000000000000000000000000001")) b = b.Neg(b) - buf = atom.BigIntEncode(buf, b) + mustEncode(atom.BigIntEncode(&buf, b)) b.SetBytes(mustHexDec("e5a461280341856d4ad908a69ea5f3ccc10c7882142bb7d801cc380f26b6b4d69632024ee521f8cfafb443d49a2a3d0cc73bb4757e882f5396ed302b418210d0d49d71be86ca699cf5ee3bd6d57ed658e69316229644ba650c92d7f0d4db29c3ad1dfa9979166f4c6e79561a58f8e2c63d08df4e2246ed1f64d2d613a19d8c9a6870e6188e2f3ad40c038fda30452f8ddfcd212a6a974bc25ec6a0564c66a7d28750ff9db458b74441e49ee5e82dbf4974d645678e0ad031f97aaba855451eef17a89b42821e530816dd5793a83b7a82e8ede81e7f3395691f761784f8bc627961cd40845ee908a40b9d1f01927b38eb1a7d4efd60db0944f7ec1b832b7e6eb1833f9a351576ad5de571fae8865da7514f06b0fbf38c1f2a8538f5d38b4e18001ccbb9ddcb488530f6086d14744d8b5672166e48e9ef93772575db66b6f257c6ffad6e2c291510c5ed02e1a8b24b44ec1e2a91686238e8defd18c01998634a5076a6b7f85fc81a1d61a15b2c528dfa082ce3e3e2ca649ac04817ec5c123e0b761ab103f780c014f021bbeb7ea3b86e0ca1c833e38ef5c897a6d7e1f4a2398c490b3d65e2f45c7fae402d1df1698b6fddb185481664871c2664bfd1686b2b3372783f1856f6247a3f8437a2818f68b7c4ea13a5f57b73c72870b684045f15")) b = b.Neg(b) - buf = atom.BigIntEncode(buf, b) - buf = atom.EOCEncode(buf) + mustEncode(atom.BigIntEncode(&buf, b)) + mustEncode(atom.EOCEncode(&buf)) } { - buf = atom.StrEncode(buf, "pos") - buf = atom.ListEncode(buf) - buf = atom.UIntEncode(buf, 0) - buf = atom.UIntEncode(buf, 1) - buf = atom.UIntEncode(buf, 31) - buf = atom.UIntEncode(buf, 32) - buf = atom.UIntEncode(buf, 123) - buf = atom.UIntEncode(buf, 1234) - buf = atom.UIntEncode(buf, 12345678) + mustEncode(atom.StrEncode(&buf, "pos")) + mustEncode(atom.ListEncode(&buf)) + mustEncode(atom.UIntEncode(&buf, 0)) + mustEncode(atom.UIntEncode(&buf, 1)) + mustEncode(atom.UIntEncode(&buf, 31)) + mustEncode(atom.UIntEncode(&buf, 32)) + mustEncode(atom.UIntEncode(&buf, 123)) + mustEncode(atom.UIntEncode(&buf, 1234)) + mustEncode(atom.UIntEncode(&buf, 12345678)) b := big.NewInt(0) b.SetBytes(mustHexDec("0100000000000000000000")) - buf = atom.BigIntEncode(buf, b) + mustEncode(atom.BigIntEncode(&buf, b)) b.SetBytes(mustHexDec("0100000000000000000000000000000000")) - buf = atom.BigIntEncode(buf, b) - buf = atom.EOCEncode(buf) + mustEncode(atom.BigIntEncode(&buf, b)) + mustEncode(atom.EOCEncode(&buf)) } - buf = atom.EOCEncode(buf) + mustEncode(atom.EOCEncode(&buf)) } { - buf = atom.StrEncode(buf, "uuid") - buf = atom.UUIDEncode(buf, - uuid.MustParse("0e875e3f-d385-49eb-87b4-be42d641c367")) + mustEncode(atom.StrEncode(&buf, "uuid")) + mustEncode(atom.UUIDEncode(&buf, + uuid.MustParse("0e875e3f-d385-49eb-87b4-be42d641c367"))) } { - buf = atom.StrEncode(buf, "dates") - buf = atom.ListEncode(buf) + mustEncode(atom.StrEncode(&buf, "dates")) + mustEncode(atom.ListEncode(&buf)) { var tai tai64n.TAI64 t := time.Unix(1234567890, 0) t = tai64n.Leapsecs.Add(t) tai.FromTime(t) - buf = atom.TAI64Encode(buf, tai[:]) + mustEncode(atom.TAI64Encode(&buf, tai[:])) } { var tai tai64n.TAI64N t := time.Unix(1234567890, 456*1000) t = tai64n.Leapsecs.Add(t) tai.FromTime(t) - buf = atom.TAI64Encode(buf, tai[:]) + mustEncode(atom.TAI64Encode(&buf, tai[:])) } { var tai tai64n.TAI64N t := time.Unix(1234567890, 456789) t = tai64n.Leapsecs.Add(t) tai.FromTime(t) - buf = atom.TAI64Encode(buf, tai[:]) + mustEncode(atom.TAI64Encode(&buf, tai[:])) } - buf = atom.RawEncode(buf, atom.Raw{ + mustEncode(atom.RawEncode(&buf, atom.Raw{ T: atom.TAI64NA, V: []byte("\x40\x00\x00\x00\x49\x96\x02\xF4\x00\x06\xF8\x55\x07\x5B\xCD\x15"), - }) - buf = atom.EOCEncode(buf) + })) + mustEncode(atom.EOCEncode(&buf)) } { - buf = atom.StrEncode(buf, "floats") - buf = atom.ListEncode(buf) - buf = atom.RawEncode(buf, atom.Raw{ + mustEncode(atom.StrEncode(&buf, "floats")) + mustEncode(atom.ListEncode(&buf)) + mustEncode(atom.RawEncode(&buf, atom.Raw{ T: atom.Float32, V: []byte("\x01\x02\x03\x04"), - }) - buf = atom.EOCEncode(buf) + })) + mustEncode(atom.EOCEncode(&buf)) } { - buf = atom.StrEncode(buf, "empties") - buf = atom.ListEncode(buf) + mustEncode(atom.StrEncode(&buf, "empties")) + mustEncode(atom.ListEncode(&buf)) { - buf = atom.ListEncode(buf) - buf = atom.EOCEncode(buf) + mustEncode(atom.ListEncode(&buf)) + mustEncode(atom.EOCEncode(&buf)) } { - buf = atom.MapEncode(buf) - buf = atom.EOCEncode(buf) + mustEncode(atom.MapEncode(&buf)) + mustEncode(atom.EOCEncode(&buf)) } { - buf = atom.BlobEncode(buf, 123) - buf = atom.BinEncode(buf, []byte{}) + mustEncode(atom.BlobEncode(&buf, 123)) + mustEncode(atom.BinEncode(&buf, []byte{})) } - buf = atom.UUIDEncode(buf, uuid.Nil) - buf = atom.RawEncode(buf, atom.Raw{ + mustEncode(atom.UUIDEncode(&buf, uuid.Nil)) + mustEncode(atom.RawEncode(&buf, atom.Raw{ T: atom.TAI64, V: []byte("\x00\x00\x00\x00\x00\x00\x00\x00"), - }) - buf = atom.EOCEncode(buf) + })) + mustEncode(atom.EOCEncode(&buf)) } - buf = atom.EOCEncode(buf) + mustEncode(atom.EOCEncode(&buf)) } - fmt.Println(hex.EncodeToString(buf)) + fmt.Println(hex.EncodeToString(buf.Bytes())) } diff --git a/gyac/dec.go b/gyac/dec.go index d8f12cb..1ddafb5 100644 --- a/gyac/dec.go +++ b/gyac/dec.go @@ -17,6 +17,7 @@ package gyac import ( "errors" + "io" "go.cypherpunks.su/yac/gyac/atom" "go.cypherpunks.su/yac/gyac/types" @@ -55,21 +56,18 @@ type Item struct { } func decode( - buf []byte, + r io.Reader, allowContainers bool, recursionDepth int, -) (item Item, tail []byte, err error) { +) (item Item, read int64, err error) { if recursionDepth > parseMaxRecursionDepth { err = errors.New("deep recursion") return } - var off int - item.T, item.V, off, err = atom.Decode(buf) + item.T, item.V, read, err = atom.Decode(r) if err != nil { return } - buf = buf[off:] - tail = buf switch item.T { case types.List: if !allowContainers { @@ -77,12 +75,12 @@ func decode( return } var sub Item + var subRead int64 var v []Item for { - sub, buf, err = decode(buf, true, recursionDepth+1) - tail = buf + sub, subRead, err = decode(r, true, recursionDepth+1) + read += subRead if err != nil { - tail = buf return } if sub.T == types.EOC { @@ -99,10 +97,11 @@ func decode( } v := make(map[string]Item) var sub Item + var subRead int64 var keyPrev string for { - sub, buf, err = decode(buf, false, recursionDepth+1) - tail = buf + sub, subRead, err = decode(r, false, recursionDepth+1) + read += subRead if err != nil { return } @@ -128,8 +127,8 @@ func decode( } keyPrev = s } - sub, buf, err = decode(buf, true, recursionDepth+1) - tail = buf + sub, subRead, err = decode(r, true, recursionDepth+1) + read += subRead if err != nil { return } @@ -146,25 +145,27 @@ func decode( err = atom.ErrUnknownType return } + // TODO: check if it is too large for memory chunkLen := int(item.V.(uint64)) v := Blob{ChunkLen: chunkLen} var sub Item + var subRead int64 BlobCycle: for { - sub, buf, err = decode(buf, false, recursionDepth+1) - tail = buf + sub, subRead, err = decode(r, false, recursionDepth+1) + read += subRead if err != nil { return } switch sub.T { case types.NIL: - if len(buf) <= chunkLen { - err = atom.ErrNotEnough + buf := make([]byte, chunkLen) + read += int64(chunkLen) + _, err = io.ReadFull(r, buf) + if err != nil { return } - v.Chunks = append(v.Chunks, buf[:chunkLen]) - buf = buf[chunkLen:] - tail = buf + v.Chunks = append(v.Chunks, buf) case types.Bin: b := sub.V.([]byte) if len(b) >= chunkLen { @@ -186,11 +187,11 @@ func decode( return } -// Decode single YAC-encoded data item. Remaining data will be kept in tail. -func Decode(buf []byte) (item Item, tail []byte, err error) { - item, tail, err = decode(buf, true, 0) +// Decode single YAC-encoded data item. +func Decode(r io.Reader) (item Item, read int64, err error) { + item, read, err = decode(r, true, 0) if item.T == types.EOC { err = ErrUnexpectedEOC } - return item, tail, err + return } diff --git a/gyac/enc.go b/gyac/enc.go index a1f06d4..6565077 100644 --- a/gyac/enc.go +++ b/gyac/enc.go @@ -16,7 +16,9 @@ package gyac import ( + "bytes" "fmt" + "io" "math/big" "sort" @@ -26,29 +28,38 @@ import ( "go.cypherpunks.su/yac/gyac/types" ) -// Encode an item appending to the buf. -func (item Item) Encode(buf []byte) []byte { +// Encode an item. +func (item Item) Encode(w io.Writer) (written int64, err error) { switch item.T { case types.Invalid: panic("invalid item's type met") case types.NIL: - return atom.NILEncode(buf) + return atom.NILEncode(w) case types.Bool: - return atom.BoolEncode(buf, item.V.(bool)) + return atom.BoolEncode(w, item.V.(bool)) case types.UUID: - return atom.UUIDEncode(buf, item.V.(uuid.UUID)) + return atom.UUIDEncode(w, item.V.(uuid.UUID)) case types.UInt: - return atom.UIntEncode(buf, item.V.(uint64)) + return atom.UIntEncode(w, item.V.(uint64)) case types.Int: - return atom.IntEncode(buf, item.V.(int64)) + return atom.IntEncode(w, item.V.(int64)) case types.BigInt: - return atom.BigIntEncode(buf, item.V.(*big.Int)) + return atom.BigIntEncode(w, item.V.(*big.Int)) case types.List: - buf = atom.ListEncode(buf) + written, err = atom.ListEncode(w) + if err != nil { + return + } + var n int64 for _, v := range item.V.([]Item) { - buf = v.Encode(buf) + n, err = v.Encode(w) + written += n + if err != nil { + return + } } - buf = atom.EOCEncode(buf) + n, err = atom.EOCEncode(w) + written += n case types.Map: m := item.V.(map[string]Item) keys := make([]string, 0, len(m)) @@ -56,40 +67,72 @@ func (item Item) Encode(buf []byte) []byte { keys = append(keys, k) } sort.Sort(ByLenFirst(keys)) - buf = atom.MapEncode(buf) + written, err = atom.MapEncode(w) + if err != nil { + return + } + var n int64 for _, k := range keys { - buf = atom.StrEncode(buf, k) - buf = m[k].Encode(buf) + n, err = atom.StrEncode(w, k) + written += n + if err != nil { + return + } + n, err = m[k].Encode(w) + written += n + if err != nil { + return + } } - buf = atom.EOCEncode(buf) + n, err = atom.EOCEncode(w) + written += n case types.Blob: blob := item.V.(Blob) - buf = atom.BlobEncode(buf, blob.ChunkLen) + written, err = atom.BlobEncode(w, blob.ChunkLen) + if err != nil { + return + } + var n int64 for _, chunk := range blob.Chunks { if len(chunk) == blob.ChunkLen { - buf = atom.ChunkEncode(buf, chunk) + n, err = atom.ChunkEncode(w, chunk) + written += n + if err != nil { + return + } } } if len(blob.Chunks) == 0 { - buf = atom.BinEncode(buf, []byte{}) + n, err = atom.BinEncode(w, []byte{}) } else { last := blob.Chunks[len(blob.Chunks)-1] if len(last) == blob.ChunkLen { - buf = atom.BinEncode(buf, []byte{}) + n, err = atom.BinEncode(w, []byte{}) } else { - buf = atom.BinEncode(buf, last) + n, err = atom.BinEncode(w, last) } } + written += n case types.TAI64: - return atom.TAI64Encode(buf, item.V.([]byte)) + return atom.TAI64Encode(w, item.V.([]byte)) case types.Bin: - return atom.BinEncode(buf, item.V.([]byte)) + return atom.BinEncode(w, item.V.([]byte)) case types.Str: - return atom.StrEncode(buf, item.V.(string)) + return atom.StrEncode(w, item.V.(string)) case types.Raw: - return atom.RawEncode(buf, item.V.(atom.Raw)) + return atom.RawEncode(w, item.V.(atom.Raw)) default: panic(fmt.Errorf("unhandled type: %v", item.T)) } - return buf + return +} + +// Append an encoded item to the provided buf. +func (item Item) EncodeBuf(buf []byte) ([]byte, error) { + var b bytes.Buffer + _, err := item.Encode(&b) + if err != nil { + return nil, err + } + return append(buf, b.Bytes()...), nil } diff --git a/gyac/fuzz_test.go b/gyac/fuzz_test.go index f009115..b46304b 100644 --- a/gyac/fuzz_test.go +++ b/gyac/fuzz_test.go @@ -8,10 +8,10 @@ import ( func FuzzItemDecode(f *testing.F) { var item Item var err error - var tail []byte var e any + var buf bytes.Buffer f.Fuzz(func(t *testing.T, b []byte) { - item, tail, err = Decode(b) + item, _, err = Decode(bytes.NewReader(b)) if err == nil { e, err = item.ToGo() if err != nil { @@ -21,7 +21,12 @@ func FuzzItemDecode(f *testing.F) { if err != nil { t.Fail() } - if !bytes.Equal(append(item.Encode(nil), tail...), b) { + buf.Reset() + _, err = item.Encode(&buf) + if err != nil { + t.Fail() + } + if !bytes.Equal(buf.Bytes(), b[:buf.Len()]) { t.Fail() } } diff --git a/gyac/mapstruct/dec.go b/gyac/mapstruct/dec.go index 53eaca2..08415e8 100644 --- a/gyac/mapstruct/dec.go +++ b/gyac/mapstruct/dec.go @@ -17,6 +17,7 @@ package mapstruct import ( "errors" + "io" "go.cypherpunks.su/yac/gyac" "go.cypherpunks.su/yac/gyac/types" @@ -24,9 +25,9 @@ import ( // Decode YAC-encoded data to the dst structure. // It will return an error if decoded data is not map. -func Decode(dst any, raw []byte) (tail []byte, err error) { +func Decode(dst any, src io.Reader) (err error) { var item gyac.Item - item, tail, err = gyac.Decode(raw) + item, _, err = gyac.Decode(src) if err != nil { return } diff --git a/gyac/pki/av.go b/gyac/pki/av.go index 8995419..83c9e48 100644 --- a/gyac/pki/av.go +++ b/gyac/pki/av.go @@ -8,7 +8,6 @@ import ( "go.cypherpunks.su/yac/gyac" pkihash "go.cypherpunks.su/yac/gyac/pki/hash" - "go.cypherpunks.su/yac/gyac/pki/utils" ) // Algorithm-value often used structure. @@ -35,7 +34,7 @@ func (av *AV) Id() (id uuid.UUID) { if err != nil { panic(err) } - utils.MustWrite(hasher, item.Encode(nil)) + item.Encode(hasher) id, err = uuid.NewRandomFromReader(bytes.NewReader(hasher.Sum(nil))) if err != nil { panic(err) diff --git a/gyac/pki/cer.go b/gyac/pki/cer.go index 2093f57..bd48070 100644 --- a/gyac/pki/cer.go +++ b/gyac/pki/cer.go @@ -97,8 +97,8 @@ func (sd *SignedData) CerParse() error { } // Parse YAC-encoded data as SignedData with the CerLoad (certificate) contents. -func CerParse(data []byte) (sd *SignedData, tail []byte, err error) { - sd, tail, err = SignedDataParse(data) +func CerParse(data []byte) (sd *SignedData, err error) { + sd, err = SignedDataParse(data) if err != nil { return } @@ -181,7 +181,11 @@ func (sd *SignedData) CerCheckSignatureFrom(parent *CerLoad) (err error) { if err != nil { return } - return parent.CheckSignature(item.Encode(nil), sig.Sign.V) + buf, err := item.EncodeBuf(nil) + if err != nil { + return + } + return parent.CheckSignature(buf, sig.Sign.V) } // Get CerLoad from SignedData. diff --git a/gyac/pki/cmd/yacertool/main.go b/gyac/pki/cmd/yacertool/main.go index 39c5f98..9b9ee48 100644 --- a/gyac/pki/cmd/yacertool/main.go +++ b/gyac/pki/cmd/yacertool/main.go @@ -88,7 +88,7 @@ func main() { var caCers []*pki.SignedData for _, issuingCer := range issuingCers { var sd *pki.SignedData - sd, _, err = pki.CerParse(utils.MustReadFile(issuingCer)) + sd, err = pki.CerParse(utils.MustReadFile(issuingCer)) if err != nil { log.Fatal(err) } @@ -106,7 +106,7 @@ func main() { if *verify { var sd *pki.SignedData - sd, _, err = pki.CerParse(utils.MustReadFile(*cerPath)) + sd, err = pki.CerParse(utils.MustReadFile(*cerPath)) if err != nil { log.Fatal(err) } @@ -146,7 +146,12 @@ func main() { if err != nil { log.Fatal(err) } - err = os.WriteFile(*prvPath, item.Encode(nil), 0o600) + var data []byte + data, err = item.EncodeBuf(nil) + if err != nil { + log.Fatal(err) + } + err = os.WriteFile(*prvPath, data, 0o600) if err != nil { log.Fatal(err) } @@ -178,7 +183,12 @@ func main() { if err != nil { log.Fatal(err) } - err = os.WriteFile(*cerPath, item.Encode(nil), 0o666) + var data []byte + data, err = item.EncodeBuf(nil) + if err != nil { + log.Fatal(err) + } + err = os.WriteFile(*cerPath, data, 0o666) if err != nil { log.Fatal(err) } diff --git a/gyac/pki/cmd/yacsdtool/main.go b/gyac/pki/cmd/yacsdtool/main.go index 929356b..38830c3 100644 --- a/gyac/pki/cmd/yacsdtool/main.go +++ b/gyac/pki/cmd/yacsdtool/main.go @@ -30,7 +30,7 @@ func main() { if *cerPath == "" { log.Fatal("no -cer is set") } - cer, _, err := pki.CerParse(utils.MustReadFile(*cerPath)) + cer, err := pki.CerParse(utils.MustReadFile(*cerPath)) if err != nil { log.Fatal(err) } @@ -56,7 +56,7 @@ func main() { } if *verify { var sd *pki.SignedData - sd, _, err = pki.SignedDataParse(utils.MustReadFile(*sdPath)) + sd, err = pki.SignedDataParse(utils.MustReadFile(*sdPath)) if err != nil { log.Fatal(err) } @@ -103,7 +103,12 @@ func main() { if err != nil { log.Fatal(err) } - err = os.WriteFile(*sdPath, item.Encode(nil), 0o666) + var data []byte + data, err = item.EncodeBuf(nil) + if err != nil { + log.Fatal(err) + } + err = os.WriteFile(*sdPath, data, 0o666) if err != nil { log.Fatal(err) } diff --git a/gyac/pki/prv.go b/gyac/pki/prv.go index e6faa42..b8aca7e 100644 --- a/gyac/pki/prv.go +++ b/gyac/pki/prv.go @@ -1,6 +1,7 @@ package pki import ( + "bytes" "crypto" "errors" "fmt" @@ -14,7 +15,7 @@ import ( func PrvParse(data []byte) (prv crypto.Signer, pub []byte, err error) { var av AV var tail []byte - tail, err = mapstruct.Decode(&av, data) + err = mapstruct.Decode(&av, bytes.NewReader(data)) if err != nil { return } diff --git a/gyac/pki/signed-data.go b/gyac/pki/signed-data.go index e35cffc..bf346ab 100644 --- a/gyac/pki/signed-data.go +++ b/gyac/pki/signed-data.go @@ -1,6 +1,7 @@ package pki import ( + "bytes" "crypto" "crypto/rand" "errors" @@ -123,9 +124,9 @@ func SignedDataParseItem(item gyac.Item) (sd *SignedData, err error) { // Parse signed-data from YAC-encoded data. This is just a wrapper over // SignedDataParseItem. -func SignedDataParse(data []byte) (sd *SignedData, tail []byte, err error) { +func SignedDataParse(data []byte) (sd *SignedData, err error) { var item gyac.Item - item, tail, err = gyac.Decode(data) + item, _, err = gyac.Decode(bytes.NewReader(data)) if err != nil { return } @@ -153,11 +154,11 @@ func (sd *SignedData) SignWith( if err != nil { return } - sig.Sign.V, err = prv.Sign( - rand.Reader, - item.Encode(nil), - crypto.Hash(0), - ) + buf, err := item.EncodeBuf(nil) + if err != nil { + return + } + sig.Sign.V, err = prv.Sign(rand.Reader, buf, crypto.Hash(0)) if err != nil { return err }