internal/zstd: new internal package for zstd decompression
author Ian Lance Taylor <iant@golang.org>
Fri, 3 Mar 2023 19:42:07 +0000 (11:42 -0800)
committer Gopher Robot <gobot@golang.org>
Tue, 18 Apr 2023 20:34:13 +0000 (20:34 +0000)
This package only does zstd decompression, which is starting to
be used for ELF debug sections. If we need zstd compression we
should use github.com/klauspost/compress/zstd. But for now that
is a very large package to vendor into the standard library.
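
A rough sketch of the intended use (readZstd and the example package
are hypothetical; NewReader and the io.Reader behavior are what the
tests below exercise):

    // Hypothetical example; internal/zstd can only be imported from
    // within the standard library itself.
    package example

    import (
        "internal/zstd"
        "io"
    )

    // readZstd decompresses a zstd stream into memory.
    func readZstd(compressed io.Reader) ([]byte, error) {
        return io.ReadAll(zstd.NewReader(compressed))
    }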

For #55107

Change-Id: I60ede735357d491be653477ed419cf5f2f0d3f71
Reviewed-on: https://go-review.googlesource.com/c/go/+/473356
Reviewed-by: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Joseph Tsai <joetsai@digital-static.net>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Bryan Mills <bcmills@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>

12 files changed:
src/go/build/deps_test.go
src/internal/zstd/bits.go [new file with mode: 0644]
src/internal/zstd/block.go [new file with mode: 0644]
src/internal/zstd/fse.go [new file with mode: 0644]
src/internal/zstd/fse_test.go [new file with mode: 0644]
src/internal/zstd/fuzz_test.go [new file with mode: 0644]
src/internal/zstd/huff.go [new file with mode: 0644]
src/internal/zstd/literals.go [new file with mode: 0644]
src/internal/zstd/xxhash.go [new file with mode: 0644]
src/internal/zstd/xxhash_test.go [new file with mode: 0644]
src/internal/zstd/zstd.go [new file with mode: 0644]
src/internal/zstd/zstd_test.go [new file with mode: 0644]

diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go
index 3238d96b9da42afd9bfe93d48d84bd2ada52fd85..24ad2def85c46e22f39553af9c5e6b8272267d0c 100644 (file)
@@ -226,7 +226,7 @@ var depsRules = `
 
        # compression
        FMT, encoding/binary, hash/adler32, hash/crc32
-       < compress/bzip2, compress/flate, compress/lzw
+       < compress/bzip2, compress/flate, compress/lzw, internal/zstd
        < archive/zip, compress/gzip, compress/zlib;
 
        # templates
diff --git a/src/internal/zstd/bits.go b/src/internal/zstd/bits.go
new file mode 100644 (file)
index 0000000..c9a2f70
--- /dev/null
@@ -0,0 +1,130 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+       "math/bits"
+)
+
+// block is the data for a single compressed block.
+// The data starts immediately after the 3 byte block header,
+// and is Block_Size bytes long.
+type block []byte
+
+// bitReader reads a bit stream going forward.
+type bitReader struct {
+       r    *Reader // for error reporting
+       data block   // the bits to read
+       off  uint32  // current offset into data
+       bits uint32  // bits ready to be returned
+       cnt  uint32  // number of valid bits in the bits field
+}
+
+// makeBitReader makes a bit reader starting at off.
+func (r *Reader) makeBitReader(data block, off int) bitReader {
+       return bitReader{
+               r:    r,
+               data: data,
+               off:  uint32(off),
+       }
+}
+
+// moreBits is called to read more bits.
+// This ensures that at least 16 bits are available.
+func (br *bitReader) moreBits() error {
+       for br.cnt < 16 {
+               if br.off >= uint32(len(br.data)) {
+                       return br.r.makeEOFError(int(br.off))
+               }
+               c := br.data[br.off]
+               br.off++
+               br.bits |= uint32(c) << br.cnt
+               br.cnt += 8
+       }
+       return nil
+}
+
+// val is called to fetch a value of b bits.
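+// val does not refill the buffer, so callers ensure enough bits are
+// available (via moreBits) first; for example, val(4) returns the low
+// 4 bits of the buffered stream.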
+func (br *bitReader) val(b uint8) uint32 {
+       r := br.bits & ((1 << b) - 1)
+       br.bits >>= b
+       br.cnt -= uint32(b)
+       return r
+}
+
+// backup steps back to the last byte we used.
+func (br *bitReader) backup() {
+       for br.cnt >= 8 {
+               br.off--
+               br.cnt -= 8
+       }
+}
+
+// makeError returns an error at the current offset wrapping a string.
+func (br *bitReader) makeError(msg string) error {
+       return br.r.makeError(int(br.off), msg)
+}
+
+// reverseBitReader reads a bit stream in reverse.
+type reverseBitReader struct {
+       r     *Reader // for error reporting
+       data  block   // the bits to read
+       off   uint32  // current offset into data
+       start uint32  // start in data; we read backward to start
+       bits  uint32  // bits ready to be returned
+       cnt   uint32  // number of valid bits in bits field
+}
+
+// makeReverseBitReader makes a reverseBitReader reading backward
+// from off to start. The bitstream starts with a 1 bit in the last
+// byte, at off.
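+// For example, if data[off] is 0b0001_0110, the topmost 1 bit is the
+// padding marker, so cnt starts at 4 and the bits below the marker
+// (0b0110) are the first bits available to val.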
+func (r *Reader) makeReverseBitReader(data block, off, start int) (reverseBitReader, error) {
+       streamStart := data[off]
+       if streamStart == 0 {
+               return reverseBitReader{}, r.makeError(off, "zero byte at reverse bit stream start")
+       }
+       rbr := reverseBitReader{
+               r:     r,
+               data:  data,
+               off:   uint32(off),
+               start: uint32(start),
+               bits:  uint32(streamStart),
+               cnt:   uint32(7 - bits.LeadingZeros8(streamStart)),
+       }
+       return rbr, nil
+}
+
+// val is called to fetch a value of b bits.
+func (rbr *reverseBitReader) val(b uint8) (uint32, error) {
+       if !rbr.fetch(b) {
+               return 0, rbr.r.makeEOFError(int(rbr.off))
+       }
+
+       rbr.cnt -= uint32(b)
+       v := (rbr.bits >> rbr.cnt) & ((1 << b) - 1)
+       return v, nil
+}
+
+// fetch is called to ensure that at least b bits are available.
+// It reports false if this can't be done,
+// in which case only rbr.cnt bits are available.
+func (rbr *reverseBitReader) fetch(b uint8) bool {
+       for rbr.cnt < uint32(b) {
+               if rbr.off <= rbr.start {
+                       return false
+               }
+               rbr.off--
+               c := rbr.data[rbr.off]
+               rbr.bits <<= 8
+               rbr.bits |= uint32(c)
+               rbr.cnt += 8
+       }
+       return true
+}
+
+// makeError returns an error at the current offset wrapping a string.
+func (rbr *reverseBitReader) makeError(msg string) error {
+       return rbr.r.makeError(int(rbr.off), msg)
+}
diff --git a/src/internal/zstd/block.go b/src/internal/zstd/block.go
new file mode 100644 (file)
index 0000000..bd3040c
--- /dev/null
@@ -0,0 +1,436 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+       "io"
+)
+
+// debug can be set in the source to print debug info using println.
+const debug = false
+
+// compressedBlock decompresses a compressed block, storing the decompressed
+// data in r.buffer. The blockSize argument is the compressed size.
+// RFC 3.1.1.3.
+func (r *Reader) compressedBlock(blockSize int) error {
+       if len(r.compressedBuf) >= blockSize {
+               r.compressedBuf = r.compressedBuf[:blockSize]
+       } else {
+               // We know that blockSize <= 128K,
+               // so this won't allocate an enormous amount.
+               need := blockSize - len(r.compressedBuf)
+               r.compressedBuf = append(r.compressedBuf, make([]byte, need)...)
+       }
+
+       if _, err := io.ReadFull(r.r, r.compressedBuf); err != nil {
+               return r.wrapNonEOFError(0, err)
+       }
+
+       data := block(r.compressedBuf)
+       off := 0
+       r.buffer = r.buffer[:0]
+
+       litoff, litbuf, err := r.readLiterals(data, off, r.literals[:0])
+       if err != nil {
+               return err
+       }
+       r.literals = litbuf
+
+       off = litoff
+
+       seqCount, off, err := r.initSeqs(data, off)
+       if err != nil {
+               return err
+       }
+
+       if seqCount == 0 {
+               // No sequences, just literals.
+               if off < len(data) {
+                       return r.makeError(off, "extraneous data after no sequences")
+               }
+               if len(litbuf) == 0 {
+                       return r.makeError(off, "no sequences and no literals")
+               }
+               r.buffer = append(r.buffer, litbuf...)
+               return nil
+       }
+
+       return r.execSeqs(data, off, litbuf, seqCount)
+}
+
+// seqCode is the kind of sequence code we have to handle.
+type seqCode int
+
+const (
+       seqLiteral seqCode = iota
+       seqOffset
+       seqMatch
+)
+
+// seqCodeInfoData is the information needed to set up seqTables and
+// seqTableBits for a particular kind of sequence code.
+type seqCodeInfoData struct {
+       predefTable     []fseBaselineEntry // predefined FSE
+       predefTableBits int                // number of bits in predefTable
+       maxSym          int                // max symbol value in FSE
+       maxBits         int                // max bits for FSE
+
+       // toBaseline converts from an FSE table to an FSE baseline table.
+       toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error
+}
+
+// seqCodeInfo is the seqCodeInfoData for each kind of sequence code.
+var seqCodeInfo = [3]seqCodeInfoData{
+       seqLiteral: {
+               predefTable:     predefinedLiteralTable[:],
+               predefTableBits: 6,
+               maxSym:          35,
+               maxBits:         9,
+               toBaseline:      (*Reader).makeLiteralBaselineFSE,
+       },
+       seqOffset: {
+               predefTable:     predefinedOffsetTable[:],
+               predefTableBits: 5,
+               maxSym:          31,
+               maxBits:         8,
+               toBaseline:      (*Reader).makeOffsetBaselineFSE,
+       },
+       seqMatch: {
+               predefTable:     predefinedMatchTable[:],
+               predefTableBits: 6,
+               maxSym:          52,
+               maxBits:         9,
+               toBaseline:      (*Reader).makeMatchBaselineFSE,
+       },
+}
+
+// initSeqs reads the Sequences_Section_Header and sets up the FSE
+// tables used to read the sequence codes. It returns the number of
+// sequences and the new offset. RFC 3.1.1.3.2.1.
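+// The sequence count itself takes 1, 2, or 3 bytes; for example, a
+// first byte of 0x85 followed by 0x10 encodes
+// ((0x85-128)<<8)+0x10 = 1296 sequences.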
+func (r *Reader) initSeqs(data block, off int) (int, int, error) {
+       if off >= len(data) {
+               return 0, 0, r.makeEOFError(off)
+       }
+
+       seqHdr := data[off]
+       off++
+       if seqHdr == 0 {
+               return 0, off, nil
+       }
+
+       var seqCount int
+       if seqHdr < 128 {
+               seqCount = int(seqHdr)
+       } else if seqHdr < 255 {
+               if off >= len(data) {
+                       return 0, 0, r.makeEOFError(off)
+               }
+               seqCount = ((int(seqHdr) - 128) << 8) + int(data[off])
+               off++
+       } else {
+               if off+1 >= len(data) {
+                       return 0, 0, r.makeEOFError(off)
+               }
+               seqCount = int(data[off]) + (int(data[off+1]) << 8) + 0x7f00
+               off += 2
+       }
+
+       // Read the Symbol_Compression_Modes byte.
+
+       if off >= len(data) {
+               return 0, 0, r.makeEOFError(off)
+       }
+       symMode := data[off]
+       if symMode&3 != 0 {
+               return 0, 0, r.makeError(off, "invalid symbol compression mode")
+       }
+       off++
+
+       // Set up the FSE tables used to decode the sequence codes.
+
+       var err error
+       off, err = r.setSeqTable(data, off, seqLiteral, (symMode>>6)&3)
+       if err != nil {
+               return 0, 0, err
+       }
+
+       off, err = r.setSeqTable(data, off, seqOffset, (symMode>>4)&3)
+       if err != nil {
+               return 0, 0, err
+       }
+
+       off, err = r.setSeqTable(data, off, seqMatch, (symMode>>2)&3)
+       if err != nil {
+               return 0, 0, err
+       }
+
+       return seqCount, off, nil
+}
+
+// setSeqTable uses the Compression_Mode in mode to set up r.seqTables and
+// r.seqTableBits for kind. We store these in the Reader because one of
+// the modes simply reuses the value from the last block in the frame.
+func (r *Reader) setSeqTable(data block, off int, kind seqCode, mode byte) (int, error) {
+       info := &seqCodeInfo[kind]
+       switch mode {
+       case 0:
+               // Predefined_Mode
+               r.seqTables[kind] = info.predefTable
+               r.seqTableBits[kind] = uint8(info.predefTableBits)
+               return off, nil
+
+       case 1:
+               // RLE_Mode
+               if off >= len(data) {
+                       return 0, r.makeEOFError(off)
+               }
+               rle := data[off]
+               off++
+
+               // Build a simple baseline table that always returns rle.
+
+               entry := []fseEntry{
+                       {
+                               sym:  rle,
+                               bits: 0,
+                               base: 0,
+                       },
+               }
+               if cap(r.seqTableBuffers[kind]) == 0 {
+                       r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits)
+               }
+               r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1]
+               if err := info.toBaseline(r, off, entry, r.seqTableBuffers[kind]); err != nil {
+                       return 0, err
+               }
+
+               r.seqTables[kind] = r.seqTableBuffers[kind]
+               r.seqTableBits[kind] = 0
+               return off, nil
+
+       case 2:
+               // FSE_Compressed_Mode
+               if cap(r.fseScratch) < 1<<info.maxBits {
+                       r.fseScratch = make([]fseEntry, 1<<info.maxBits)
+               }
+               r.fseScratch = r.fseScratch[:1<<info.maxBits]
+
+               tableBits, roff, err := r.readFSE(data, off, info.maxSym, info.maxBits, r.fseScratch)
+               if err != nil {
+                       return 0, err
+               }
+               r.fseScratch = r.fseScratch[:1<<tableBits]
+
+               if cap(r.seqTableBuffers[kind]) == 0 {
+                       r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits)
+               }
+               r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1<<tableBits]
+
+               if err := info.toBaseline(r, roff, r.fseScratch, r.seqTableBuffers[kind]); err != nil {
+                       return 0, err
+               }
+
+               r.seqTables[kind] = r.seqTableBuffers[kind]
+               r.seqTableBits[kind] = uint8(tableBits)
+               return roff, nil
+
+       case 3:
+               // Repeat_Mode
+               if len(r.seqTables[kind]) == 0 {
+                       return 0, r.makeError(off, "missing repeat sequence FSE table")
+               }
+               return off, nil
+       }
+       panic("unreachable")
+}
+
+// execSeqs reads and executes the sequences. RFC 3.1.1.3.2.1.2.
+func (r *Reader) execSeqs(data block, off int, litbuf []byte, seqCount int) error {
+       // Set up the initial states for the sequence code readers.
+
+       rbr, err := r.makeReverseBitReader(data, len(data)-1, off)
+       if err != nil {
+               return err
+       }
+
+       literalState, err := rbr.val(r.seqTableBits[seqLiteral])
+       if err != nil {
+               return err
+       }
+
+       offsetState, err := rbr.val(r.seqTableBits[seqOffset])
+       if err != nil {
+               return err
+       }
+
+       matchState, err := rbr.val(r.seqTableBits[seqMatch])
+       if err != nil {
+               return err
+       }
+
+       // Read and perform all the sequences. RFC 3.1.1.4.
+
+       seq := 0
+       for seq < seqCount {
+               if len(r.buffer)+len(litbuf) > 128<<10 {
+                       return rbr.makeError("uncompressed size too big")
+               }
+
+               ptoffset := &r.seqTables[seqOffset][offsetState]
+               ptmatch := &r.seqTables[seqMatch][matchState]
+               ptliteral := &r.seqTables[seqLiteral][literalState]
+
+               add, err := rbr.val(ptoffset.basebits)
+               if err != nil {
+                       return err
+               }
+               offset := ptoffset.baseline + add
+
+               add, err = rbr.val(ptmatch.basebits)
+               if err != nil {
+                       return err
+               }
+               match := ptmatch.baseline + add
+
+               add, err = rbr.val(ptliteral.basebits)
+               if err != nil {
+                       return err
+               }
+               literal := ptliteral.baseline + add
+
+               // Handle repeat offsets. RFC 3.1.1.5.
+               // See the comment in makeOffsetBaselineFSE.
+               if ptoffset.basebits > 1 {
+                       r.repeatedOffset3 = r.repeatedOffset2
+                       r.repeatedOffset2 = r.repeatedOffset1
+                       r.repeatedOffset1 = offset
+               } else {
+                       if literal == 0 {
+                               offset++
+                       }
+                       switch offset {
+                       case 1:
+                               offset = r.repeatedOffset1
+                       case 2:
+                               offset = r.repeatedOffset2
+                               r.repeatedOffset2 = r.repeatedOffset1
+                               r.repeatedOffset1 = offset
+                       case 3:
+                               offset = r.repeatedOffset3
+                               r.repeatedOffset3 = r.repeatedOffset2
+                               r.repeatedOffset2 = r.repeatedOffset1
+                               r.repeatedOffset1 = offset
+                       case 4:
+                               offset = r.repeatedOffset1 - 1
+                               r.repeatedOffset3 = r.repeatedOffset2
+                               r.repeatedOffset2 = r.repeatedOffset1
+                               r.repeatedOffset1 = offset
+                       }
+               }
+
+               seq++
+               if seq < seqCount {
+                       // Update the states.
+                       add, err = rbr.val(ptliteral.bits)
+                       if err != nil {
+                               return err
+                       }
+                       literalState = uint32(ptliteral.base) + add
+
+                       add, err = rbr.val(ptmatch.bits)
+                       if err != nil {
+                               return err
+                       }
+                       matchState = uint32(ptmatch.base) + add
+
+                       add, err = rbr.val(ptoffset.bits)
+                       if err != nil {
+                               return err
+                       }
+                       offsetState = uint32(ptoffset.base) + add
+               }
+
+               // The next sequence is now in literal, offset, match.
+
+               if debug {
+                       println("literal", literal, "offset", offset, "match", match)
+               }
+
+               // Copy literal bytes from litbuf.
+               if literal > uint32(len(litbuf)) {
+                       return rbr.makeError("literal byte overflow")
+               }
+               if literal > 0 {
+                       r.buffer = append(r.buffer, litbuf[:literal]...)
+                       litbuf = litbuf[literal:]
+               }
+
+               if match > 0 {
+                       if err := r.copyFromWindow(&rbr, offset, match); err != nil {
+                               return err
+                       }
+               }
+       }
+
+       if len(litbuf) > 0 {
+               r.buffer = append(r.buffer, litbuf...)
+       }
+
+       if rbr.cnt != 0 {
+               return r.makeError(off, "extraneous data after sequences")
+       }
+
+       return nil
+}
+
+// copyFromWindow copies match bytes from the decoded output, or the window, at offset.
+func (r *Reader) copyFromWindow(rbr *reverseBitReader, offset, match uint32) error {
+       if offset == 0 {
+               return rbr.makeError("invalid zero offset")
+       }
+
+       lenBlock := uint32(len(r.buffer))
+       if lenBlock < offset {
+               lenWindow := uint32(len(r.window))
+               windowOffset := offset - lenBlock
+               if windowOffset > lenWindow {
+                       return rbr.makeError("offset past window")
+               }
+               from := lenWindow - windowOffset
+               if from+match <= lenWindow {
+                       r.buffer = append(r.buffer, r.window[from:from+match]...)
+                       return nil
+               }
+               r.buffer = append(r.buffer, r.window[from:]...)
+               copied := lenWindow - from
+               offset -= copied
+               match -= copied
+
+               if offset == 0 && match > 0 {
+                       return rbr.makeError("invalid offset")
+               }
+       }
+
+       from := lenBlock - offset
+       if offset >= match {
+               r.buffer = append(r.buffer, r.buffer[from:from+match]...)
+               return nil
+       }
+
+       // We are being asked to copy data that we are adding to the
+       // buffer in the same copy.
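+       // For example, offset 1 with match 5 repeats the previous byte
+       // five times, so the copy proceeds in chunks of at most offset bytes.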
+       for match > 0 {
+               var copy uint32
+               if offset >= match {
+                       copy = match
+               } else {
+                       copy = offset
+               }
+               r.buffer = append(r.buffer, r.buffer[from:from+copy]...)
+               match -= copy
+               from += copy
+       }
+       return nil
+}
diff --git a/src/internal/zstd/fse.go b/src/internal/zstd/fse.go
new file mode 100644 (file)
index 0000000..ea661d4
--- /dev/null
@@ -0,0 +1,437 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+       "math/bits"
+)
+
+// fseEntry is one entry in an FSE table.
+type fseEntry struct {
+       sym  uint8  // value that this entry records
+       bits uint8  // number of bits to read to determine next state
+       base uint16 // add those bits to this state to get the next state
+}
+
+// readFSE reads an FSE table from data starting at off.
+// maxSym is the maximum symbol value.
+// maxBits is the maximum number of bits permitted for symbols in the table.
+// The FSE is written into table, which must be at least 1<<maxBits in size.
+// This returns the number of bits in the FSE table and the new offset.
+// RFC 4.1.1.
+func (r *Reader) readFSE(data block, off, maxSym, maxBits int, table []fseEntry) (tableBits, roff int, err error) {
+       br := r.makeBitReader(data, off)
+       if err := br.moreBits(); err != nil {
+               return 0, 0, err
+       }
+
+       accuracyLog := int(br.val(4)) + 5
+       if accuracyLog > maxBits {
+               return 0, 0, br.makeError("FSE accuracy log too large")
+       }
+
+       // The number of remaining probabilities, plus 1.
+       // This determines the number of bits to be read for the next value.
+       remaining := (1 << accuracyLog) + 1
+
+       // The current difference between small and large values,
+       // which depends on the number of remaining values.
+       // Small values use 1 less bit.
+       threshold := 1 << accuracyLog
+
+       // The number of bits needed to compute threshold.
+       bitsNeeded := accuracyLog + 1
+
+       // The next character value.
+       sym := 0
+
+       // Whether the last count was 0.
+       prev0 := false
+
+       var norm [256]int16
+
+       for remaining > 1 && sym <= maxSym {
+               if err := br.moreBits(); err != nil {
+                       return 0, 0, err
+               }
+
+               if prev0 {
+                       // Previous count was 0, so there is a 2-bit
+                       // repeat flag. If the 2-bit flag is 0b11,
+                       // it adds 3 and then there is another repeat flag.
+                       zsym := sym
+                       for (br.bits & 0xfff) == 0xfff {
+                               zsym += 3 * 6
+                               br.bits >>= 12
+                               br.cnt -= 12
+                               if err := br.moreBits(); err != nil {
+                                       return 0, 0, err
+                               }
+                       }
+                       for (br.bits & 3) == 3 {
+                               zsym += 3
+                               br.bits >>= 2
+                               br.cnt -= 2
+                               if err := br.moreBits(); err != nil {
+                                       return 0, 0, err
+                               }
+                       }
+
+                       // We have at least 14 bits here,
+                       // no need to call moreBits.
+
+                       zsym += int(br.val(2))
+
+                       if zsym > maxSym {
+                               return 0, 0, br.makeError("FSE symbol index overflow")
+                       }
+
+                       for ; sym < zsym; sym++ {
+                               norm[uint8(sym)] = 0
+                       }
+
+                       prev0 = false
+                       continue
+               }
+
+               max := (2*threshold - 1) - remaining
+               var count int
+               if int(br.bits&uint32(threshold-1)) < max {
+                       // A small value.
+                       count = int(br.bits & uint32((threshold - 1)))
+                       br.bits >>= bitsNeeded - 1
+                       br.cnt -= uint32(bitsNeeded - 1)
+               } else {
+                       // A large value.
+                       count = int(br.bits & uint32((2*threshold - 1)))
+                       if count >= threshold {
+                               count -= max
+                       }
+                       br.bits >>= bitsNeeded
+                       br.cnt -= uint32(bitsNeeded)
+               }
+
+               count--
+               if count >= 0 {
+                       remaining -= count
+               } else {
+                       remaining--
+               }
+               if sym >= 256 {
+                       return 0, 0, br.makeError("FSE sym overflow")
+               }
+               norm[uint8(sym)] = int16(count)
+               sym++
+
+               prev0 = count == 0
+
+               for remaining < threshold {
+                       bitsNeeded--
+                       threshold >>= 1
+               }
+       }
+
+       if remaining != 1 {
+               return 0, 0, br.makeError("too many symbols in FSE table")
+       }
+
+       for ; sym <= maxSym; sym++ {
+               norm[uint8(sym)] = 0
+       }
+
+       br.backup()
+
+       if err := r.buildFSE(off, norm[:maxSym+1], table, accuracyLog); err != nil {
+               return 0, 0, err
+       }
+
+       return accuracyLog, int(br.off), nil
+}
+
+// buildFSE builds an FSE decoding table from a list of probabilities.
+// The probabilities are in norm. The number of bits in the table is tableBits.
+func (r *Reader) buildFSE(off int, norm []int16, table []fseEntry, tableBits int) error {
+       tableSize := 1 << tableBits
+       highThreshold := tableSize - 1
+
+       var next [256]uint16
+
+       for i, n := range norm {
+               if n >= 0 {
+                       next[uint8(i)] = uint16(n)
+               } else {
+                       table[highThreshold].sym = uint8(i)
+                       highThreshold--
+                       next[uint8(i)] = 1
+               }
+       }
+
+       pos := 0
+       step := (tableSize >> 1) + (tableSize >> 3) + 3
+       mask := tableSize - 1
+       for i, n := range norm {
+               for j := 0; j < int(n); j++ {
+                       table[pos].sym = uint8(i)
+                       pos = (pos + step) & mask
+                       for pos > highThreshold {
+                               pos = (pos + step) & mask
+                       }
+               }
+       }
+       if pos != 0 {
+               return r.makeError(off, "FSE count error")
+       }
+
+       for i := 0; i < tableSize; i++ {
+               sym := table[i].sym
+               nextState := next[sym]
+               next[sym]++
+
+               if nextState == 0 {
+                       return r.makeError(off, "FSE state error")
+               }
+
+               highBit := 15 - bits.LeadingZeros16(nextState)
+
+               bits := tableBits - highBit
+               table[i].bits = uint8(bits)
+               table[i].base = (nextState << bits) - uint16(tableSize)
+       }
+
+       return nil
+}
+
+// fseBaselineEntry is an entry in an FSE baseline table.
+// We use these for literal/match/length values.
+// Those require mapping the symbol to a baseline value,
+// and then reading zero or more bits and adding the value to the baseline.
+// Rather than looking these up in separate tables,
+// we convert the FSE table to an FSE baseline table.
+type fseBaselineEntry struct {
+       baseline uint32 // baseline for value that this entry represents
+       basebits uint8  // number of bits to read to add to baseline
+       bits     uint8  // number of bits to read to determine next state
+       base     uint16 // add the bits to this base to get the next state
+}
+
+// Given a literal length code, we need to read a number of bits and
+// add that to a baseline. For states 0 to 15 the baseline is the
+// state and the number of bits is zero. RFC 3.1.1.3.2.1.1.
+
+const literalLengthOffset = 16
+
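+// Each entry packs the baseline in the low 24 bits and the number of
+// extra bits to read in the top byte. For example, 16 | (1 << 24) means
+// that literal length code 16 has baseline 16 and reads 1 extra bit,
+// covering lengths 16 and 17.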
+var literalLengthBase = []uint32{
+       16 | (1 << 24),
+       18 | (1 << 24),
+       20 | (1 << 24),
+       22 | (1 << 24),
+       24 | (2 << 24),
+       28 | (2 << 24),
+       32 | (3 << 24),
+       40 | (3 << 24),
+       48 | (4 << 24),
+       64 | (6 << 24),
+       128 | (7 << 24),
+       256 | (8 << 24),
+       512 | (9 << 24),
+       1024 | (10 << 24),
+       2048 | (11 << 24),
+       4096 | (12 << 24),
+       8192 | (13 << 24),
+       16384 | (14 << 24),
+       32768 | (15 << 24),
+       65536 | (16 << 24),
+}
+
+// makeLiteralBaselineFSE converts the literal length fseTable to baselineTable.
+func (r *Reader) makeLiteralBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
+       for i, e := range fseTable {
+               be := fseBaselineEntry{
+                       bits: e.bits,
+                       base: e.base,
+               }
+               if e.sym < literalLengthOffset {
+                       be.baseline = uint32(e.sym)
+                       be.basebits = 0
+               } else {
+                       if e.sym > 35 {
+                               return r.makeError(off, "FSE baseline symbol overflow")
+                       }
+                       idx := e.sym - literalLengthOffset
+                       basebits := literalLengthBase[idx]
+                       be.baseline = basebits & 0xffffff
+                       be.basebits = uint8(basebits >> 24)
+               }
+               baselineTable[i] = be
+       }
+       return nil
+}
+
+// makeOffsetBaselineFSE converts the offset length fseTable to baselineTable.
+func (r *Reader) makeOffsetBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
+       for i, e := range fseTable {
+               be := fseBaselineEntry{
+                       bits: e.bits,
+                       base: e.base,
+               }
+               if e.sym > 31 {
+                       return r.makeError(off, "FSE offset symbol overflow")
+               }
+
+               // The simple way to write this is
+               //     be.baseline = 1 << e.sym
+               //     be.basebits = e.sym
+               // That would give us an offset value that corresponds to
+               // the one described in the RFC. However, for offsets > 3
+               // we have to subtract 3. And for offset values 1, 2, 3
+               // we use a repeated offset.
+               //
+               // The baseline is always a power of 2, and is never 0,
+               // so for those low values we will see one entry that is
+               // baseline 1, basebits 0, and one entry that is baseline 2,
+               // basebits 1. All other entries will have baseline >= 4
+               // and basebits >= 2.
+               //
+               // So we can check for RFC offset <= 3 by checking for
+               // basebits <= 1. That means that we can subtract 3 here
+               // and not worry about doing it in the hot loop.
+
+               be.baseline = 1 << e.sym
+               if e.sym >= 2 {
+                       be.baseline -= 3
+               }
+               be.basebits = e.sym
+               baselineTable[i] = be
+       }
+       return nil
+}
+
+// Given a match length code, we need to read a number of bits and add
+// that to a baseline. For states 0 to 31 the baseline is state+3 and
+// the number of bits is zero. RFC 3.1.1.3.2.1.1.
+
+const matchLengthOffset = 32
+
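+// As with literalLengthBase, the low 24 bits are the baseline and the
+// top byte is the number of extra bits. For example, 35 | (1 << 24)
+// means match length code 32 has baseline 35 and reads 1 extra bit,
+// covering lengths 35 and 36.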
+var matchLengthBase = []uint32{
+       35 | (1 << 24),
+       37 | (1 << 24),
+       39 | (1 << 24),
+       41 | (1 << 24),
+       43 | (2 << 24),
+       47 | (2 << 24),
+       51 | (3 << 24),
+       59 | (3 << 24),
+       67 | (4 << 24),
+       83 | (4 << 24),
+       99 | (5 << 24),
+       131 | (7 << 24),
+       259 | (8 << 24),
+       515 | (9 << 24),
+       1027 | (10 << 24),
+       2051 | (11 << 24),
+       4099 | (12 << 24),
+       8195 | (13 << 24),
+       16387 | (14 << 24),
+       32771 | (15 << 24),
+       65539 | (16 << 24),
+}
+
+// makeMatchBaselineFSE converts the match length fseTable to baselineTable.
+func (r *Reader) makeMatchBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
+       for i, e := range fseTable {
+               be := fseBaselineEntry{
+                       bits: e.bits,
+                       base: e.base,
+               }
+               if e.sym < matchLengthOffset {
+                       be.baseline = uint32(e.sym) + 3
+                       be.basebits = 0
+               } else {
+                       if e.sym > 52 {
+                               return r.makeError(off, "FSE baseline symbol overflow")
+                       }
+                       idx := e.sym - matchLengthOffset
+                       basebits := matchLengthBase[idx]
+                       be.baseline = basebits & 0xffffff
+                       be.basebits = uint8(basebits >> 24)
+               }
+               baselineTable[i] = be
+       }
+       return nil
+}
+
+// predefinedLiteralTable is the predefined table to use for literal lengths.
+// Generated from table in RFC 3.1.1.3.2.2.1.
+// Checked by TestPredefinedTables.
+var predefinedLiteralTable = [...]fseBaselineEntry{
+       {0, 0, 4, 0}, {0, 0, 4, 16}, {1, 0, 5, 32},
+       {3, 0, 5, 0}, {4, 0, 5, 0}, {6, 0, 5, 0},
+       {7, 0, 5, 0}, {9, 0, 5, 0}, {10, 0, 5, 0},
+       {12, 0, 5, 0}, {14, 0, 6, 0}, {16, 1, 5, 0},
+       {20, 1, 5, 0}, {22, 1, 5, 0}, {28, 2, 5, 0},
+       {32, 3, 5, 0}, {48, 4, 5, 0}, {64, 6, 5, 32},
+       {128, 7, 5, 0}, {256, 8, 6, 0}, {1024, 10, 6, 0},
+       {4096, 12, 6, 0}, {0, 0, 4, 32}, {1, 0, 4, 0},
+       {2, 0, 5, 0}, {4, 0, 5, 32}, {5, 0, 5, 0},
+       {7, 0, 5, 32}, {8, 0, 5, 0}, {10, 0, 5, 32},
+       {11, 0, 5, 0}, {13, 0, 6, 0}, {16, 1, 5, 32},
+       {18, 1, 5, 0}, {22, 1, 5, 32}, {24, 2, 5, 0},
+       {32, 3, 5, 32}, {40, 3, 5, 0}, {64, 6, 4, 0},
+       {64, 6, 4, 16}, {128, 7, 5, 32}, {512, 9, 6, 0},
+       {2048, 11, 6, 0}, {0, 0, 4, 48}, {1, 0, 4, 16},
+       {2, 0, 5, 32}, {3, 0, 5, 32}, {5, 0, 5, 32},
+       {6, 0, 5, 32}, {8, 0, 5, 32}, {9, 0, 5, 32},
+       {11, 0, 5, 32}, {12, 0, 5, 32}, {15, 0, 6, 0},
+       {18, 1, 5, 32}, {20, 1, 5, 32}, {24, 2, 5, 32},
+       {28, 2, 5, 32}, {40, 3, 5, 32}, {48, 4, 5, 32},
+       {65536, 16, 6, 0}, {32768, 15, 6, 0}, {16384, 14, 6, 0},
+       {8192, 13, 6, 0},
+}
+
+// predefinedOffsetTable is the predefined table to use for offsets.
+// Generated from table in RFC 3.1.1.3.2.2.3.
+// Checked by TestPredefinedTables.
+var predefinedOffsetTable = [...]fseBaselineEntry{
+       {1, 0, 5, 0}, {61, 6, 4, 0}, {509, 9, 5, 0},
+       {32765, 15, 5, 0}, {2097149, 21, 5, 0}, {5, 3, 5, 0},
+       {125, 7, 4, 0}, {4093, 12, 5, 0}, {262141, 18, 5, 0},
+       {8388605, 23, 5, 0}, {29, 5, 5, 0}, {253, 8, 4, 0},
+       {16381, 14, 5, 0}, {1048573, 20, 5, 0}, {1, 2, 5, 0},
+       {125, 7, 4, 16}, {2045, 11, 5, 0}, {131069, 17, 5, 0},
+       {4194301, 22, 5, 0}, {13, 4, 5, 0}, {253, 8, 4, 16},
+       {8189, 13, 5, 0}, {524285, 19, 5, 0}, {2, 1, 5, 0},
+       {61, 6, 4, 16}, {1021, 10, 5, 0}, {65533, 16, 5, 0},
+       {268435453, 28, 5, 0}, {134217725, 27, 5, 0}, {67108861, 26, 5, 0},
+       {33554429, 25, 5, 0}, {16777213, 24, 5, 0},
+}
+
+// predefinedMatchTable is the predefined table to use for match lengths.
+// Generated from table in RFC 3.1.1.3.2.2.2.
+// Checked by TestPredefinedTables.
+var predefinedMatchTable = [...]fseBaselineEntry{
+       {3, 0, 6, 0}, {4, 0, 4, 0}, {5, 0, 5, 32},
+       {6, 0, 5, 0}, {8, 0, 5, 0}, {9, 0, 5, 0},
+       {11, 0, 5, 0}, {13, 0, 6, 0}, {16, 0, 6, 0},
+       {19, 0, 6, 0}, {22, 0, 6, 0}, {25, 0, 6, 0},
+       {28, 0, 6, 0}, {31, 0, 6, 0}, {34, 0, 6, 0},
+       {37, 1, 6, 0}, {41, 1, 6, 0}, {47, 2, 6, 0},
+       {59, 3, 6, 0}, {83, 4, 6, 0}, {131, 7, 6, 0},
+       {515, 9, 6, 0}, {4, 0, 4, 16}, {5, 0, 4, 0},
+       {6, 0, 5, 32}, {7, 0, 5, 0}, {9, 0, 5, 32},
+       {10, 0, 5, 0}, {12, 0, 6, 0}, {15, 0, 6, 0},
+       {18, 0, 6, 0}, {21, 0, 6, 0}, {24, 0, 6, 0},
+       {27, 0, 6, 0}, {30, 0, 6, 0}, {33, 0, 6, 0},
+       {35, 1, 6, 0}, {39, 1, 6, 0}, {43, 2, 6, 0},
+       {51, 3, 6, 0}, {67, 4, 6, 0}, {99, 5, 6, 0},
+       {259, 8, 6, 0}, {4, 0, 4, 32}, {4, 0, 4, 48},
+       {5, 0, 4, 16}, {7, 0, 5, 32}, {8, 0, 5, 32},
+       {10, 0, 5, 32}, {11, 0, 5, 32}, {14, 0, 6, 0},
+       {17, 0, 6, 0}, {20, 0, 6, 0}, {23, 0, 6, 0},
+       {26, 0, 6, 0}, {29, 0, 6, 0}, {32, 0, 6, 0},
+       {65539, 16, 6, 0}, {32771, 15, 6, 0}, {16387, 14, 6, 0},
+       {8195, 13, 6, 0}, {4099, 12, 6, 0}, {2051, 11, 6, 0},
+       {1027, 10, 6, 0},
+}
diff --git a/src/internal/zstd/fse_test.go b/src/internal/zstd/fse_test.go
new file mode 100644 (file)
index 0000000..6f106b6
--- /dev/null
@@ -0,0 +1,89 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+       "slices"
+       "testing"
+)
+
+// literalPredefinedDistribution is the predefined distribution table
+// for literal lengths. RFC 3.1.1.3.2.2.1.
+var literalPredefinedDistribution = []int16{
+       4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1,
+       2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1,
+       -1, -1, -1, -1,
+}
+
+// offsetPredefinedDistribution is the predefined distribution table
+// for offsets. RFC 3.1.1.3.2.2.3.
+var offsetPredefinedDistribution = []int16{
+       1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1,
+       1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
+}
+
+// matchPredefinedDistribution is the predefined distribution table
+// for match lengths. RFC 3.1.1.3.2.2.2.
+var matchPredefinedDistribution = []int16{
+       1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1,
+       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1,
+       -1, -1, -1, -1, -1,
+}
+
+// TestPredefinedTables verifies that we can generate the predefined
+// literal/offset/match tables from the input data in RFC 8878.
+// This serves as a test of the predefined tables, and also of buildFSE
+// and the functions that make baseline FSE tables.
+func TestPredefinedTables(t *testing.T) {
+       tests := []struct {
+               name         string
+               distribution []int16
+               tableBits    int
+               toBaseline   func(*Reader, int, []fseEntry, []fseBaselineEntry) error
+               predef       []fseBaselineEntry
+       }{
+               {
+                       name:         "literal",
+                       distribution: literalPredefinedDistribution,
+                       tableBits:    6,
+                       toBaseline:   (*Reader).makeLiteralBaselineFSE,
+                       predef:       predefinedLiteralTable[:],
+               },
+               {
+                       name:         "offset",
+                       distribution: offsetPredefinedDistribution,
+                       tableBits:    5,
+                       toBaseline:   (*Reader).makeOffsetBaselineFSE,
+                       predef:       predefinedOffsetTable[:],
+               },
+               {
+                       name:         "match",
+                       distribution: matchPredefinedDistribution,
+                       tableBits:    6,
+                       toBaseline:   (*Reader).makeMatchBaselineFSE,
+                       predef:       predefinedMatchTable[:],
+               },
+       }
+       for _, test := range tests {
+               test := test
+               t.Run(test.name, func(t *testing.T) {
+                       var r Reader
+                       table := make([]fseEntry, 1<<test.tableBits)
+                       if err := r.buildFSE(0, test.distribution, table, test.tableBits); err != nil {
+                               t.Fatal(err)
+                       }
+
+                       baselineTable := make([]fseBaselineEntry, len(table))
+                       if err := test.toBaseline(&r, 0, table, baselineTable); err != nil {
+                               t.Fatal(err)
+                       }
+
+                       if !slices.Equal(baselineTable, test.predef) {
+                               t.Errorf("got %v, want %v", baselineTable, test.predef)
+                       }
+               })
+       }
+}
diff --git a/src/internal/zstd/fuzz_test.go b/src/internal/zstd/fuzz_test.go
new file mode 100644 (file)
index 0000000..bb6f0a9
--- /dev/null
@@ -0,0 +1,140 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+       "bytes"
+       "io"
+       "os"
+       "os/exec"
+       "testing"
+)
+
+// badStrings is a list of inputs that FuzzReader failed on earlier.
+var badStrings = []string{
+       "(\xb5/\xfdd00,\x05\x00\xc4\x0400000000000000000000000000000000000000000000000000000000000000000000000000000 \xa07100000000000000000000000000000000000000000000000000000000000000000000000000aM\x8a2y0B\b",
+       "(\xb5/\xfd00$\x05\x0020 00X70000a70000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
+       "(\xb5/\xfd00$\x05\x0020 00B00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
+       "(\xb5/\xfd00}\x00\x0020\x00\x9000000000000",
+       "(\xb5/\xfd00}\x00\x00&0\x02\x830!000000000",
+       "(\xb5/\xfd\x1002000$\x05\x0010\xcc0\xa8100000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
+       "(\xb5/\xfd\x1002000$\x05\x0000\xcc0\xa8100d\x0000001000000000000000000000000000000000000000000000000000000000000000000000000\x000000000000000000000000000000000000000000000000000000000000000000000000000000",
+       "(\xb5/\xfd001\x00\x0000000000000000000",
+}
+
+// This is a simple fuzzer to see if the decompressor panics.
+func FuzzReader(f *testing.F) {
+       for _, test := range tests {
+               f.Add([]byte(test.compressed))
+       }
+       for _, s := range badStrings {
+               f.Add([]byte(s))
+       }
+       f.Fuzz(func(t *testing.T, b []byte) {
+               r := NewReader(bytes.NewReader(b))
+               io.Copy(io.Discard, r)
+       })
+}
+
+// Fuzz test to verify that what we decompress is what we compress.
+// This isn't a great fuzz test because the fuzzer can't efficiently
+// explore the space of decompressor behavior, since it can't see
+// what the compressor is doing. But it's better than nothing.
+func FuzzDecompressor(f *testing.F) {
+       if _, err := os.Stat("/usr/bin/zstd"); err != nil {
+               f.Skip("skipping because /usr/bin/zstd does not exist")
+       }
+
+       for _, test := range tests {
+               f.Add([]byte(test.uncompressed))
+       }
+
+       // Add some larger data, as that has more interesting compression.
+       f.Add(bytes.Repeat([]byte("abcdefghijklmnop"), 256))
+       var buf bytes.Buffer
+       for i := 0; i < 256; i++ {
+               buf.WriteByte(byte(i))
+       }
+       f.Add(bytes.Repeat(buf.Bytes(), 64))
+       f.Add(bigData(f))
+
+       f.Fuzz(func(t *testing.T, b []byte) {
+               cmd := exec.Command("/usr/bin/zstd", "-z")
+               cmd.Stdin = bytes.NewReader(b)
+               var compressed bytes.Buffer
+               cmd.Stdout = &compressed
+               cmd.Stderr = os.Stderr
+               if err := cmd.Run(); err != nil {
+                       t.Errorf("running zstd failed: %v", err)
+               }
+
+               r := NewReader(bytes.NewReader(compressed.Bytes()))
+               got, err := io.ReadAll(r)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if !bytes.Equal(got, b) {
+                       showDiffs(t, got, b)
+               }
+       })
+}
+
+// Fuzz test to check that if we can decompress some data,
+// so can zstd, and that we get the same result.
+func FuzzReverse(f *testing.F) {
+       if _, err := os.Stat("/usr/bin/zstd"); err != nil {
+               f.Skip("skipping because /usr/bin/zstd does not exist")
+       }
+
+       for _, test := range tests {
+               f.Add([]byte(test.compressed))
+       }
+
+       // Set a hook to reject some cases where we don't match zstd.
+       fuzzing = true
+       defer func() { fuzzing = false }()
+
+       f.Fuzz(func(t *testing.T, b []byte) {
+               r := NewReader(bytes.NewReader(b))
+               goExp, goErr := io.ReadAll(r)
+
+               cmd := exec.Command("/usr/bin/zstd", "-d")
+               cmd.Stdin = bytes.NewReader(b)
+               var uncompressed bytes.Buffer
+               cmd.Stdout = &uncompressed
+               cmd.Stderr = os.Stderr
+               zstdErr := cmd.Run()
+               zstdExp := uncompressed.Bytes()
+
+               if goErr == nil && zstdErr == nil {
+                       if !bytes.Equal(zstdExp, goExp) {
+                               showDiffs(t, zstdExp, goExp)
+                       }
+               } else {
+                       // Ideally we should check that this package and
+                       // the zstd program both fail or both succeed,
+                       // and that if they both fail one byte sequence
+                       // is an exact prefix of the other.
+                       // Actually trying this proved to be frustrating,
+                       // as the zstd program appears to accept invalid
+                       // byte sequences using rules that are difficult
+                       // to determine.
+                       // So we just check the prefix.
+
+                       c := len(goExp)
+                       if c > len(zstdExp) {
+                               c = len(zstdExp)
+                       }
+                       goExp = goExp[:c]
+                       zstdExp = zstdExp[:c]
+                       if !bytes.Equal(goExp, zstdExp) {
+                               t.Error("byte mismatch after error")
+                               t.Logf("Go error: %v\n", goErr)
+                               t.Logf("zstd error: %v\n", zstdErr)
+                               showDiffs(t, zstdExp, goExp)
+                       }
+               }
+       })
+}
diff --git a/src/internal/zstd/huff.go b/src/internal/zstd/huff.go
new file mode 100644 (file)
index 0000000..452e24b
--- /dev/null
@@ -0,0 +1,204 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+       "io"
+       "math/bits"
+)
+
+// maxHuffmanBits is the largest possible Huffman table bits.
+const maxHuffmanBits = 11
+
+// readHuff reads a Huffman table from data starting at off into table.
+// Each entry in a Huffman table is a pair of bytes.
+// The high byte is the encoded value. The low byte is the number
+// of bits used to encode that value. We index into the table
+// with a value of size tableBits. A value that requires fewer bits
+// appears in the table multiple times.
+// This returns the number of bits in the Huffman table and the new offset.
+// RFC 4.2.1.
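+// For example, with tableBits 8, a symbol that needs only 5 bits fills
+// 8 consecutive table entries, each holding the symbol in the high byte
+// and 5 in the low byte.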
+func (r *Reader) readHuff(data block, off int, table []uint16) (tableBits, roff int, err error) {
+       if off >= len(data) {
+               return 0, 0, r.makeEOFError(off)
+       }
+
+       hdr := data[off]
+       off++
+
+       var weights [256]uint8
+       var count int
+       if hdr < 128 {
+               // The table is compressed using an FSE. RFC 4.2.1.2.
+               if len(r.fseScratch) < 1<<6 {
+                       r.fseScratch = make([]fseEntry, 1<<6)
+               }
+               fseBits, noff, err := r.readFSE(data, off, 255, 6, r.fseScratch)
+               if err != nil {
+                       return 0, 0, err
+               }
+               fseTable := r.fseScratch
+
+               if off+int(hdr) > len(data) {
+                       return 0, 0, r.makeEOFError(off)
+               }
+
+               rbr, err := r.makeReverseBitReader(data, off+int(hdr)-1, noff)
+               if err != nil {
+                       return 0, 0, err
+               }
+
+               state1, err := rbr.val(uint8(fseBits))
+               if err != nil {
+                       return 0, 0, err
+               }
+
+               state2, err := rbr.val(uint8(fseBits))
+               if err != nil {
+                       return 0, 0, err
+               }
+
+               // There are two independent FSE streams, tracked by
+               // state1 and state2. We decode them alternately.
+
+               for {
+                       pt := &fseTable[state1]
+                       if !rbr.fetch(pt.bits) {
+                               if count >= 254 {
+                                       return 0, 0, rbr.makeError("Huffman count overflow")
+                               }
+                               weights[count] = pt.sym
+                               weights[count+1] = fseTable[state2].sym
+                               count += 2
+                               break
+                       }
+
+                       v, err := rbr.val(pt.bits)
+                       if err != nil {
+                               return 0, 0, err
+                       }
+                       state1 = uint32(pt.base) + v
+
+                       if count >= 255 {
+                               return 0, 0, rbr.makeError("Huffman count overflow")
+                       }
+
+                       weights[count] = pt.sym
+                       count++
+
+                       pt = &fseTable[state2]
+
+                       if !rbr.fetch(pt.bits) {
+                               if count >= 254 {
+                                       return 0, 0, rbr.makeError("Huffman count overflow")
+                               }
+                               weights[count] = pt.sym
+                               weights[count+1] = fseTable[state1].sym
+                               count += 2
+                               break
+                       }
+
+                       v, err = rbr.val(pt.bits)
+                       if err != nil {
+                               return 0, 0, err
+                       }
+                       state2 = uint32(pt.base) + v
+
+                       if count >= 255 {
+                               return 0, 0, rbr.makeError("Huffman count overflow")
+                       }
+
+                       weights[count] = pt.sym
+                       count++
+               }
+
+               off += int(hdr)
+       } else {
+               // The table is not compressed. Each weight is 4 bits.
+
+               count = int(hdr) - 127
+               if off+((count+1)/2) >= len(data) {
+                       return 0, 0, io.ErrUnexpectedEOF
+               }
+               for i := 0; i < count; i += 2 {
+                       b := data[off]
+                       off++
+                       weights[i] = b >> 4
+                       weights[i+1] = b & 0xf
+               }
+       }
+
+       // RFC 4.2.1.3.
+
+       var weightMark [13]uint32
+       weightMask := uint32(0)
+       for _, w := range weights[:count] {
+               if w > 12 {
+                       return 0, 0, r.makeError(off, "Huffman weight overflow")
+               }
+               weightMark[w]++
+               if w > 0 {
+                       weightMask += 1 << (w - 1)
+               }
+       }
+       if weightMask == 0 {
+               return 0, 0, r.makeError(off, "bad Huffman weights")
+       }
+
+       tableBits = 32 - bits.LeadingZeros32(weightMask)
+       if tableBits > maxHuffmanBits {
+               return 0, 0, r.makeError(off, "bad Huffman weights")
+       }
+
+       if len(table) < 1<<tableBits {
+               return 0, 0, r.makeError(off, "Huffman table too small")
+       }
+
+       // Work out the last weight value, which is omitted because
+       // the weights must sum to a power of two.
+       left := (uint32(1) << tableBits) - weightMask
+       if left == 0 {
+               return 0, 0, r.makeError(off, "bad Huffman weights")
+       }
+       highBit := 31 - bits.LeadingZeros32(left)
+       if uint32(1)<<highBit != left {
+               return 0, 0, r.makeError(off, "bad Huffman weights")
+       }
+       if count >= 256 {
+               return 0, 0, r.makeError(off, "Huffman weight overflow")
+       }
+       weights[count] = uint8(highBit + 1)
+       count++
+       weightMark[highBit+1]++
+
+       if weightMark[1] < 2 || weightMark[1]&1 != 0 {
+               return 0, 0, r.makeError(off, "bad Huffman weights")
+       }
+
+       // Change weightMark from a count of weights to the index of
+       // the first symbol for that weight. We shift the indexes to
+// also store how many we have seen so far.
+       next := uint32(0)
+       for i := 0; i < tableBits; i++ {
+               cur := next
+               next += weightMark[i+1] << i
+               weightMark[i+1] = cur
+       }
+
+       for i, w := range weights[:count] {
+               if w == 0 {
+                       continue
+               }
+               length := uint32(1) << (w - 1)
+               tval := uint16(i)<<8 | (uint16(tableBits) + 1 - uint16(w))
+               start := weightMark[w]
+               for j := uint32(0); j < length; j++ {
+                       table[start+j] = tval
+               }
+               weightMark[w] += length
+       }
+
+       return tableBits, off, nil
+}
diff --git a/src/internal/zstd/literals.go b/src/internal/zstd/literals.go
new file mode 100644 (file)
index 0000000..b46d668
--- /dev/null
@@ -0,0 +1,330 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+       "encoding/binary"
+)
+
+// readLiterals reads and decompresses the literals from data at off.
+// The literals are appended to outbuf, which is returned.
+// Also returns the new input offset. RFC 3.1.1.3.1.
+func (r *Reader) readLiterals(data block, off int, outbuf []byte) (int, []byte, error) {
+       if off >= len(data) {
+               return 0, nil, r.makeEOFError(off)
+       }
+
+       // Literals section header. RFC 3.1.1.3.1.1.
+       hdr := data[off]
+       off++
+
+       if (hdr&3) == 0 || (hdr&3) == 1 {
+               return r.readRawRLELiterals(data, off, hdr, outbuf)
+       } else {
+               return r.readHuffLiterals(data, off, hdr, outbuf)
+       }
+}
+
+// readRawRLELiterals reads and decompresses a Raw_Literals_Block or
+// a RLE_Literals_Block. RFC 3.1.1.3.1.1.
+func (r *Reader) readRawRLELiterals(data block, off int, hdr byte, outbuf []byte) (int, []byte, error) {
+       raw := (hdr & 3) == 0
+
+       var regeneratedSize int
+       switch (hdr >> 2) & 3 {
+       case 0, 2:
+               regeneratedSize = int(hdr >> 3)
+       case 1:
+               if off >= len(data) {
+                       return 0, nil, r.makeEOFError(off)
+               }
+               regeneratedSize = int(hdr>>4) + (int(data[off]) << 4)
+               off++
+       case 3:
+               if off+1 >= len(data) {
+                       return 0, nil, r.makeEOFError(off)
+               }
+               regeneratedSize = int(hdr>>4) + (int(data[off]) << 4) + (int(data[off+1]) << 12)
+               off += 2
+       }
+
+       // We are going to use the entire literal block in the output.
+       // The maximum size of one decompressed block is 128K,
+       // so we can't have more literals than that.
+       if regeneratedSize > 128<<10 {
+               return 0, nil, r.makeError(off, "literal size too large")
+       }
+
+       if raw {
+               // RFC 3.1.1.3.1.2.
+               if off+regeneratedSize > len(data) {
+                       return 0, nil, r.makeError(off, "raw literal size too large")
+               }
+               outbuf = append(outbuf, data[off:off+regeneratedSize]...)
+               off += regeneratedSize
+       } else {
+               // RFC 3.1.1.3.1.3.
+               if off >= len(data) {
+                       return 0, nil, r.makeError(off, "RLE literal missing")
+               }
+               rle := data[off]
+               off++
+               for i := 0; i < regeneratedSize; i++ {
+                       outbuf = append(outbuf, rle)
+               }
+       }
+
+       return off, outbuf, nil
+}
+
+// readHuffLiterals reads and decompresses a Compressed_Literals_Block or
+// a Treeless_Literals_Block. RFC 3.1.1.3.1.4.
+func (r *Reader) readHuffLiterals(data block, off int, hdr byte, outbuf []byte) (int, []byte, error) {
+       var (
+               regeneratedSize int
+               compressedSize  int
+               streams         int
+       )
+       switch (hdr >> 2) & 3 {
+       case 0, 1:
+               if off+1 >= len(data) {
+                       return 0, nil, r.makeEOFError(off)
+               }
+               regeneratedSize = (int(hdr) >> 4) | ((int(data[off]) & 0x3f) << 4)
+               compressedSize = (int(data[off]) >> 6) | (int(data[off+1]) << 2)
+               off += 2
+               if ((hdr >> 2) & 3) == 0 {
+                       streams = 1
+               } else {
+                       streams = 4
+               }
+       case 2:
+               if off+2 >= len(data) {
+                       return 0, nil, r.makeEOFError(off)
+               }
+               regeneratedSize = (int(hdr) >> 4) | (int(data[off]) << 4) | ((int(data[off+1]) & 3) << 12)
+               compressedSize = (int(data[off+1]) >> 2) | (int(data[off+2]) << 6)
+               off += 3
+               streams = 4
+       case 3:
+               if off+3 >= len(data) {
+                       return 0, nil, r.makeEOFError(off)
+               }
+               regeneratedSize = (int(hdr) >> 4) | (int(data[off]) << 4) | ((int(data[off+1]) & 0x3f) << 12)
+               compressedSize = (int(data[off+1]) >> 6) | (int(data[off+2]) << 2) | (int(data[off+3]) << 10)
+               off += 4
+               streams = 4
+       }
+
+       // We are going to use the entire literal block in the output.
+       // The maximum size of one decompressed block is 128K,
+       // so we can't have more literals than that.
+       if regeneratedSize > 128<<10 {
+               return 0, nil, r.makeError(off, "literal size too large")
+       }
+
+       roff := off + compressedSize
+       if roff > len(data) || roff < 0 {
+               return 0, nil, r.makeEOFError(off)
+       }
+
+       totalStreamsSize := compressedSize
+       if (hdr & 3) == 2 {
+               // Compressed_Literals_Block.
+               // Read the new Huffman tree.
+
+               if len(r.huffmanTable) < 1<<maxHuffmanBits {
+                       r.huffmanTable = make([]uint16, 1<<maxHuffmanBits)
+               }
+
+               huffmanTableBits, hoff, err := r.readHuff(data, off, r.huffmanTable)
+               if err != nil {
+                       return 0, nil, err
+               }
+               r.huffmanTableBits = huffmanTableBits
+
+               if totalStreamsSize < hoff-off {
+                       return 0, nil, r.makeError(off, "Huffman table too big")
+               }
+               totalStreamsSize -= hoff - off
+               off = hoff
+       } else {
+               // Treeless_Literals_Block
+               // Reuse previous Huffman tree.
+               if r.huffmanTableBits == 0 {
+                       return 0, nil, r.makeError(off, "missing literals Huffman tree")
+               }
+       }
+
+       // Decompress compressedSize bytes of data at off using the
+       // Huffman tree.
+
+       var err error
+       if streams == 1 {
+               outbuf, err = r.readLiteralsOneStream(data, off, totalStreamsSize, regeneratedSize, outbuf)
+       } else {
+               outbuf, err = r.readLiteralsFourStreams(data, off, totalStreamsSize, regeneratedSize, outbuf)
+       }
+
+       if err != nil {
+               return 0, nil, err
+       }
+
+       return roff, outbuf, nil
+}
+
+// readLiteralsOneStream reads a single stream of compressed literals.
+func (r *Reader) readLiteralsOneStream(data block, off, compressedSize, regeneratedSize int, outbuf []byte) ([]byte, error) {
+       // We let the reverse bit reader read earlier bytes,
+       // because the Huffman table ignores bits that it doesn't need.
+       rbr, err := r.makeReverseBitReader(data, off+compressedSize-1, off-2)
+       if err != nil {
+               return nil, err
+       }
+
+       huffTable := r.huffmanTable
+       huffBits := uint32(r.huffmanTableBits)
+       huffMask := (uint32(1) << huffBits) - 1
+
+       for i := 0; i < regeneratedSize; i++ {
+               if !rbr.fetch(uint8(huffBits)) {
+                       return nil, rbr.makeError("literals Huffman stream out of bits")
+               }
+
+               var t uint16
+               idx := (rbr.bits >> (rbr.cnt - huffBits)) & huffMask
+               t = huffTable[idx]
+               outbuf = append(outbuf, byte(t>>8))
+               rbr.cnt -= uint32(t & 0xff)
+       }
+
+       return outbuf, nil
+}
+
+// readLiteralsFourStreams reads four interleaved streams of
+// compressed literals.
+func (r *Reader) readLiteralsFourStreams(data block, off, totalStreamsSize, regeneratedSize int, outbuf []byte) ([]byte, error) {
+       // Read the jump table to find out where the streams are.
+       // RFC 3.1.1.3.1.6.
+       if off+5 >= len(data) {
+               return nil, r.makeEOFError(off)
+       }
+       if totalStreamsSize < 6 {
+               return nil, r.makeError(off, "total streams size too small for jump table")
+       }
+
+       streamSize1 := binary.LittleEndian.Uint16(data[off:])
+       streamSize2 := binary.LittleEndian.Uint16(data[off+2:])
+       streamSize3 := binary.LittleEndian.Uint16(data[off+4:])
+       off += 6
+
+       tot := uint64(streamSize1) + uint64(streamSize2) + uint64(streamSize3)
+       if tot > uint64(totalStreamsSize)-6 {
+               return nil, r.makeEOFError(off)
+       }
+       streamSize4 := uint32(totalStreamsSize) - 6 - uint32(tot)
+
+       off--
+       off1 := off + int(streamSize1)
+       start1 := off + 1
+
+       off2 := off1 + int(streamSize2)
+       start2 := off1 + 1
+
+       off3 := off2 + int(streamSize3)
+       start3 := off2 + 1
+
+       off4 := off3 + int(streamSize4)
+       start4 := off3 + 1
+
+       // We let the reverse bit readers read earlier bytes,
+       // because the Huffman tables ignore bits that they don't need.
+
+       rbr1, err := r.makeReverseBitReader(data, off1, start1-2)
+       if err != nil {
+               return nil, err
+       }
+
+       rbr2, err := r.makeReverseBitReader(data, off2, start2-2)
+       if err != nil {
+               return nil, err
+       }
+
+       rbr3, err := r.makeReverseBitReader(data, off3, start3-2)
+       if err != nil {
+               return nil, err
+       }
+
+       rbr4, err := r.makeReverseBitReader(data, off4, start4-2)
+       if err != nil {
+               return nil, err
+       }
+
+       regeneratedStreamSize := (regeneratedSize + 3) / 4
+
+       out1 := len(outbuf)
+       out2 := out1 + regeneratedStreamSize
+       out3 := out2 + regeneratedStreamSize
+       out4 := out3 + regeneratedStreamSize
+
+       regeneratedStreamSize4 := regeneratedSize - regeneratedStreamSize*3
+
+       outbuf = append(outbuf, make([]byte, regeneratedSize)...)
+
+       huffTable := r.huffmanTable
+       huffBits := uint32(r.huffmanTableBits)
+       huffMask := (uint32(1) << huffBits) - 1
+
+       for i := 0; i < regeneratedStreamSize; i++ {
+               use4 := i < regeneratedStreamSize4
+
+               fetchHuff := func(rbr *reverseBitReader) (uint16, error) {
+                       if !rbr.fetch(uint8(huffBits)) {
+                               return 0, rbr.makeError("literals Huffman stream out of bits")
+                       }
+                       idx := (rbr.bits >> (rbr.cnt - huffBits)) & huffMask
+                       return huffTable[idx], nil
+               }
+
+               t1, err := fetchHuff(&rbr1)
+               if err != nil {
+                       return nil, err
+               }
+
+               t2, err := fetchHuff(&rbr2)
+               if err != nil {
+                       return nil, err
+               }
+
+               t3, err := fetchHuff(&rbr3)
+               if err != nil {
+                       return nil, err
+               }
+
+               if use4 {
+                       t4, err := fetchHuff(&rbr4)
+                       if err != nil {
+                               return nil, err
+                       }
+                       outbuf[out4] = byte(t4 >> 8)
+                       out4++
+                       rbr4.cnt -= uint32(t4 & 0xff)
+               }
+
+               outbuf[out1] = byte(t1 >> 8)
+               out1++
+               rbr1.cnt -= uint32(t1 & 0xff)
+
+               outbuf[out2] = byte(t2 >> 8)
+               out2++
+               rbr2.cnt -= uint32(t2 & 0xff)
+
+               outbuf[out3] = byte(t3 >> 8)
+               out3++
+               rbr3.cnt -= uint32(t3 & 0xff)
+       }
+
+       return outbuf, nil
+}
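
The four-stream case in readLiteralsFourStreams splits the literal data using
a 6-byte jump table: three little-endian uint16 values give the sizes of the
first three streams, and the fourth stream takes whatever remains of
Total_Streams_Size. A short sketch of that arithmetic as a hypothetical
helper (splitStreams is not part of the patch):

    package zstd

    import "encoding/binary"

    // splitStreams mirrors the jump-table arithmetic at the start of
    // readLiteralsFourStreams, reporting the four stream sizes.
    func splitStreams(jump []byte, totalStreamsSize int) (s1, s2, s3, s4 int, ok bool) {
        if len(jump) < 6 || totalStreamsSize < 6 {
            return 0, 0, 0, 0, false
        }
        s1 = int(binary.LittleEndian.Uint16(jump))
        s2 = int(binary.LittleEndian.Uint16(jump[2:]))
        s3 = int(binary.LittleEndian.Uint16(jump[4:]))
        s4 = totalStreamsSize - 6 - s1 - s2 - s3
        if s4 < 0 {
            return 0, 0, 0, 0, false
        }
        return s1, s2, s3, s4, true
    }

Each of the first three streams then regenerates (regeneratedSize+3)/4 bytes,
and the last stream regenerates the remainder, which is why the decode loop
above treats the fourth stream specially.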
diff --git a/src/internal/zstd/xxhash.go b/src/internal/zstd/xxhash.go
new file mode 100644 (file)
index 0000000..4d579ee
--- /dev/null
@@ -0,0 +1,148 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+       "encoding/binary"
+       "math/bits"
+)
+
+const (
+       xxhPrime64c1 = 0x9e3779b185ebca87
+       xxhPrime64c2 = 0xc2b2ae3d27d4eb4f
+       xxhPrime64c3 = 0x165667b19e3779f9
+       xxhPrime64c4 = 0x85ebca77c2b2ae63
+       xxhPrime64c5 = 0x27d4eb2f165667c5
+)
+
+// xxhash64 is the state of an xxHash-64 checksum.
+type xxhash64 struct {
+       len uint64    // total length hashed
+       v   [4]uint64 // accumulators
+       buf [32]byte  // buffer
+       cnt int       // number of bytes in buffer
+}
+
+// reset discards the current state and prepares to compute a new hash.
+// We assume a seed of 0 since that is what zstd uses.
+func (xh *xxhash64) reset() {
+       xh.len = 0
+
+       // Add in two steps, because the untyped constant
+       // xxhPrime64c1 + xxhPrime64c2 overflows uint64.
+       xh.v[0] = xxhPrime64c1
+       xh.v[0] += xxhPrime64c2
+
+       xh.v[1] = xxhPrime64c2
+       xh.v[2] = 0
+
+       // Negate in a separate statement, because the untyped
+       // constant -xxhPrime64c1 overflows uint64.
+       xh.v[3] = xxhPrime64c1
+       xh.v[3] = -xh.v[3]
+
+       for i := range xh.buf {
+               xh.buf[i] = 0
+       }
+       xh.cnt = 0
+}
+
+// update adds a buffer to the hash.
+func (xh *xxhash64) update(b []byte) {
+       xh.len += uint64(len(b))
+
+       if xh.cnt+len(b) < len(xh.buf) {
+               copy(xh.buf[xh.cnt:], b)
+               xh.cnt += len(b)
+               return
+       }
+
+       if xh.cnt > 0 {
+               n := copy(xh.buf[xh.cnt:], b)
+               b = b[n:]
+               xh.v[0] = xh.round(xh.v[0], binary.LittleEndian.Uint64(xh.buf[:]))
+               xh.v[1] = xh.round(xh.v[1], binary.LittleEndian.Uint64(xh.buf[8:]))
+               xh.v[2] = xh.round(xh.v[2], binary.LittleEndian.Uint64(xh.buf[16:]))
+               xh.v[3] = xh.round(xh.v[3], binary.LittleEndian.Uint64(xh.buf[24:]))
+               xh.cnt = 0
+       }
+
+       for len(b) >= 32 {
+               xh.v[0] = xh.round(xh.v[0], binary.LittleEndian.Uint64(b))
+               xh.v[1] = xh.round(xh.v[1], binary.LittleEndian.Uint64(b[8:]))
+               xh.v[2] = xh.round(xh.v[2], binary.LittleEndian.Uint64(b[16:]))
+               xh.v[3] = xh.round(xh.v[3], binary.LittleEndian.Uint64(b[24:]))
+               b = b[32:]
+       }
+
+       if len(b) > 0 {
+               copy(xh.buf[:], b)
+               xh.cnt = len(b)
+       }
+}
+
+// digest returns the final hash value.
+func (xh *xxhash64) digest() uint64 {
+       var h64 uint64
+       if xh.len < 32 {
+               h64 = xh.v[2] + xxhPrime64c5
+       } else {
+               h64 = bits.RotateLeft64(xh.v[0], 1) +
+                       bits.RotateLeft64(xh.v[1], 7) +
+                       bits.RotateLeft64(xh.v[2], 12) +
+                       bits.RotateLeft64(xh.v[3], 18)
+               h64 = xh.mergeRound(h64, xh.v[0])
+               h64 = xh.mergeRound(h64, xh.v[1])
+               h64 = xh.mergeRound(h64, xh.v[2])
+               h64 = xh.mergeRound(h64, xh.v[3])
+       }
+
+       h64 += xh.len
+
+       len := xh.len
+       len &= 31
+       buf := xh.buf[:]
+       for len >= 8 {
+               k1 := xh.round(0, binary.LittleEndian.Uint64(buf))
+               buf = buf[8:]
+               h64 ^= k1
+               h64 = bits.RotateLeft64(h64, 27)*xxhPrime64c1 + xxhPrime64c4
+               len -= 8
+       }
+       if len >= 4 {
+               h64 ^= uint64(binary.LittleEndian.Uint32(buf)) * xxhPrime64c1
+               buf = buf[4:]
+               h64 = bits.RotateLeft64(h64, 23)*xxhPrime64c2 + xxhPrime64c3
+               len -= 4
+       }
+       for len > 0 {
+               h64 ^= uint64(buf[0]) * xxhPrime64c5
+               buf = buf[1:]
+               h64 = bits.RotateLeft64(h64, 11) * xxhPrime64c1
+               len--
+       }
+
+       h64 ^= h64 >> 33
+       h64 *= xxhPrime64c2
+       h64 ^= h64 >> 29
+       h64 *= xxhPrime64c3
+       h64 ^= h64 >> 32
+
+       return h64
+}
+
+// round updates a value.
+func (xh *xxhash64) round(v, n uint64) uint64 {
+       v += n * xxhPrime64c2
+       v = bits.RotateLeft64(v, 31)
+       v *= xxhPrime64c1
+       return v
+}
+
+// mergeRound updates a value in the final round.
+func (xh *xxhash64) mergeRound(v, n uint64) uint64 {
+       n = xh.round(0, n)
+       v ^= n
+       v = v*xxhPrime64c1 + xxhPrime64c4
+       return v
+}
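
The update method buffers partial 32-byte groups, so feeding data in
arbitrary chunks must give the same digest as hashing it in a single call;
that is what TestLargeXXHash below relies on when it varies its write sizes.
A sketch of a hypothetical additional test (it would sit in xxhash_test.go)
making that property explicit:

    package zstd

    import "testing"

    // TestXXHashChunked is a hypothetical test: chunked updates must
    // produce the same digest as a single update over the same bytes.
    func TestXXHashChunked(t *testing.T) {
        data := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789$")

        var whole xxhash64
        whole.reset()
        whole.update(data)

        var chunked xxhash64
        chunked.reset()
        for i := 0; i < len(data); i += 7 {
            end := i + 7
            if end > len(data) {
                end = len(data)
            }
            chunked.update(data[i:end])
        }

        if got, want := chunked.digest(), whole.digest(); got != want {
            t.Errorf("chunked digest %#x != whole digest %#x", got, want)
        }
    }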
diff --git a/src/internal/zstd/xxhash_test.go b/src/internal/zstd/xxhash_test.go
new file mode 100644 (file)
index 0000000..646cee8
--- /dev/null
@@ -0,0 +1,105 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+       "bytes"
+       "os"
+       "os/exec"
+       "strconv"
+       "testing"
+)
+
+var xxHashTests = []struct {
+       data string
+       hash uint64
+}{
+       {
+               "hello, world",
+               0xb33a384e6d1b1242,
+       },
+       {
+               "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789$",
+               0x1032d841e824f998,
+       },
+}
+
+func TestXXHash(t *testing.T) {
+       var xh xxhash64
+       for i, test := range xxHashTests {
+               xh.reset()
+               xh.update([]byte(test.data))
+               if got := xh.digest(); got != test.hash {
+                       t.Errorf("#%d: got %#x want %#x", i, got, test.hash)
+               }
+       }
+}
+
+func TestLargeXXHash(t *testing.T) {
+       if testing.Short() {
+               t.Skip("skipping expensive test in short mode")
+       }
+
+       data := bigData(t)
+       var xh xxhash64
+       xh.reset()
+       i := 0
+       for i < len(data) {
+               // Write varying amounts to test buffering.
+               c := i%4094 + 1
+               if i+c > len(data) {
+                       c = len(data) - i
+               }
+               xh.update(data[i : i+c])
+               i += c
+       }
+
+       got := xh.digest()
+       want := uint64(0xf0dd39fd7e063f82)
+       if got != want {
+               t.Errorf("got %#x want %#x", got, want)
+       }
+}
+
+func FuzzXXHash(f *testing.F) {
+       if _, err := os.Stat("/usr/bin/xxhsum"); err != nil {
+               f.Skip("skipping because /usr/bin/xxhsum does not exist")
+       }
+
+       for _, test := range xxHashTests {
+               f.Add([]byte(test.data))
+       }
+       f.Add(bytes.Repeat([]byte("abcdefghijklmnop"), 256))
+       var buf bytes.Buffer
+       for i := 0; i < 256; i++ {
+               buf.WriteByte(byte(i))
+       }
+       f.Add(bytes.Repeat(buf.Bytes(), 64))
+       f.Add(bigData(f))
+
+       f.Fuzz(func(t *testing.T, b []byte) {
+               cmd := exec.Command("/usr/bin/xxhsum", "-H64")
+               cmd.Stdin = bytes.NewReader(b)
+               var hhsumHash bytes.Buffer
+               cmd.Stdout = &hhsumHash
+               if err := cmd.Run(); err != nil {
+                       t.Fatalf("running xxhsum failed: %v", err)
+               }
+               hhHashBytes := bytes.Fields(bytes.TrimSpace(hhsumHash.Bytes()))[0]
+               hhHash, err := strconv.ParseUint(string(hhHashBytes), 16, 64)
+               if err != nil {
+                       t.Fatalf("could not parse hash %q: %v", hhHashBytes, err)
+               }
+
+               var xh xxhash64
+               xh.reset()
+               xh.update(b)
+               goHash := xh.digest()
+
+               if goHash != hhHash {
+                       t.Errorf("Go hash %#x != xxhsum hash %#x", goHash, hhHash)
+               }
+       })
+}
diff --git a/src/internal/zstd/zstd.go b/src/internal/zstd/zstd.go
new file mode 100644 (file)
index 0000000..a860789
--- /dev/null
@@ -0,0 +1,508 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package zstd provides a decompressor for zstd streams,
+// described in RFC 8878. It does not support dictionaries.
+package zstd
+
+import (
+       "encoding/binary"
+       "errors"
+       "fmt"
+       "io"
+)
+
+// fuzzing is a fuzzer hook set to true when fuzzing.
+// This is used to reject inputs that the zstd tool itself rejects,
+// so that fuzz comparisons against zstd remain meaningful.
+var fuzzing = false
+
+// Reader implements [io.Reader] to read a zstd compressed stream.
+type Reader struct {
+       // The underlying Reader.
+       r io.Reader
+
+       // Whether we have read the frame header.
+       // This is of interest when buffer is empty.
+       // This is of interest when the buffer is empty.
+       sawFrameHeader bool
+
+       // Whether the current frame expects a checksum.
+       hasChecksum bool
+
+       // Whether we have read at least one frame.
+       readOneFrame bool
+
+       // True if the frame size is not known.
+       frameSizeUnknown bool
+
+       // The number of uncompressed bytes remaining in the current frame.
+       // If frameSizeUnknown is true, this is not valid.
+       remainingFrameSize uint64
+
+       // The number of bytes read from r up to the start of the current
+       // block, for error reporting.
+       blockOffset int64
+
+       // Buffered decompressed data.
+       buffer []byte
+       // Current read offset in buffer.
+       off int
+
+       // The current repeated offsets.
+       repeatedOffset1 uint32
+       repeatedOffset2 uint32
+       repeatedOffset3 uint32
+
+       // The current Huffman tree used for decompressing literals.
+       huffmanTable     []uint16
+       huffmanTableBits int
+
+       // The window for back references.
+       windowSize int    // maximum required window size
+       window     []byte // window data
+
+       // A buffer available to hold a compressed block.
+       compressedBuf []byte
+
+       // A buffer for literals.
+       literals []byte
+
+       // Sequence decode FSE tables.
+       seqTables    [3][]fseBaselineEntry
+       seqTableBits [3]uint8
+
+       // Buffers for sequence decode FSE tables.
+       seqTableBuffers [3][]fseBaselineEntry
+
+       // Scratch space used for small reads, to avoid allocation.
+       scratch [16]byte
+
+       // A scratch table for reading an FSE. Only temporarily valid.
+       fseScratch []fseEntry
+
+       // For checksum computation.
+       checksum xxhash64
+}
+
+// NewReader creates a new Reader that decompresses data from the given reader.
+func NewReader(input io.Reader) *Reader {
+       r := new(Reader)
+       r.Reset(input)
+       return r
+}
+
+// Reset discards the current state and starts reading a new stream from r.
+// This permits reusing a Reader rather than allocating a new one.
+func (r *Reader) Reset(input io.Reader) {
+       r.r = input
+
+       // Several fields are preserved to avoid allocation.
+       // Others are always set before they are used.
+       r.sawFrameHeader = false
+       r.hasChecksum = false
+       r.readOneFrame = false
+       r.frameSizeUnknown = false
+       r.remainingFrameSize = 0
+       r.blockOffset = 0
+       // buffer
+       r.off = 0
+       // repeatedOffset1
+       // repeatedOffset2
+       // repeatedOffset3
+       // huffmanTable
+       // huffmanTableBits
+       // windowSize
+       // window
+       // compressedBuf
+       // literals
+       // seqTables
+       // seqTableBits
+       // seqTableBuffers
+       // scratch
+       // fseScratch
+}
+
+// Read implements [io.Reader].
+func (r *Reader) Read(p []byte) (int, error) {
+       if err := r.refillIfNeeded(); err != nil {
+               return 0, err
+       }
+       n := copy(p, r.buffer[r.off:])
+       r.off += n
+       return n, nil
+}
+
+// ReadByte implements [io.ByteReader].
+func (r *Reader) ReadByte() (byte, error) {
+       if err := r.refillIfNeeded(); err != nil {
+               return 0, err
+       }
+       ret := r.buffer[r.off]
+       r.off++
+       return ret, nil
+}
+
+// refillIfNeeded reads the next block if necessary.
+func (r *Reader) refillIfNeeded() error {
+       for r.off >= len(r.buffer) {
+               if err := r.refill(); err != nil {
+                       return err
+               }
+               r.off = 0
+       }
+       return nil
+}
+
+// refill reads and decompresses the next block.
+func (r *Reader) refill() error {
+       if !r.sawFrameHeader {
+               if err := r.readFrameHeader(); err != nil {
+                       return err
+               }
+       }
+       return r.readBlock()
+}
+
+// readFrameHeader reads the frame header and prepares to read a block.
+func (r *Reader) readFrameHeader() error {
+retry:
+       relativeOffset := 0
+
+       // Read magic number. RFC 3.1.1.
+       if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
+               // We require that the stream contain at least one frame.
+               if err == io.EOF && !r.readOneFrame {
+                       err = io.ErrUnexpectedEOF
+               }
+               return r.wrapError(relativeOffset, err)
+       }
+
+       if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 {
+               if magic >= 0x184d2a50 && magic <= 0x184d2a5f {
+                       // This is a skippable frame.
+                       r.blockOffset += int64(relativeOffset) + 4
+                       if err := r.skipFrame(); err != nil {
+                               return err
+                       }
+                       goto retry
+               }
+
+               return r.makeError(relativeOffset, "invalid magic number")
+       }
+
+       relativeOffset += 4
+
+       // Read Frame_Header_Descriptor. RFC 3.1.1.1.1.
+       if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
+               return r.wrapNonEOFError(relativeOffset, err)
+       }
+       descriptor := r.scratch[0]
+
+       singleSegment := descriptor&(1<<5) != 0
+
+       fcsFieldSize := 1 << (descriptor >> 6)
+       if fcsFieldSize == 1 && !singleSegment {
+               fcsFieldSize = 0
+       }
+
+       var windowDescriptorSize int
+       if singleSegment {
+               windowDescriptorSize = 0
+       } else {
+               windowDescriptorSize = 1
+       }
+
+       if descriptor&(1<<3) != 0 {
+               return r.makeError(relativeOffset, "reserved bit set in frame header descriptor")
+       }
+
+       r.hasChecksum = descriptor&(1<<2) != 0
+       if r.hasChecksum {
+               r.checksum.reset()
+       }
+
+       if descriptor&3 != 0 {
+               return r.makeError(relativeOffset, "dictionaries are not supported")
+       }
+
+       relativeOffset++
+
+       headerSize := windowDescriptorSize + fcsFieldSize
+
+       if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil {
+               return r.wrapNonEOFError(relativeOffset, err)
+       }
+
+       // Figure out the maximum amount of data we need to retain
+       // for backreferences.
+
+       if singleSegment {
+               // No window required, as all the data is in a single buffer.
+               r.windowSize = 0
+       } else {
+               // Window descriptor. RFC 3.1.1.1.2.
+               windowDescriptor := r.scratch[0]
+               exponent := uint64(windowDescriptor >> 3)
+               mantissa := uint64(windowDescriptor & 7)
+               windowLog := exponent + 10
+               windowBase := uint64(1) << windowLog
+               windowAdd := (windowBase / 8) * mantissa
+               windowSize := windowBase + windowAdd
+
+               // Default zstd sets limits on the window size.
+               if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
+                       return r.makeError(relativeOffset, "windowSize too large")
+               }
+
+               // RFC 8878 permits us to set an 8M max on window size.
+               if windowSize > 8<<20 {
+                       windowSize = 8 << 20
+               }
+
+               r.windowSize = int(windowSize)
+       }
+
+       // Frame_Content_Size. RFC 3.1.1.4.
+       r.frameSizeUnknown = false
+       r.remainingFrameSize = 0
+       fb := r.scratch[windowDescriptorSize:]
+       switch fcsFieldSize {
+       case 0:
+               r.frameSizeUnknown = true
+       case 1:
+               r.remainingFrameSize = uint64(fb[0])
+       case 2:
+               r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb))
+       case 4:
+               r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb))
+       case 8:
+               r.remainingFrameSize = binary.LittleEndian.Uint64(fb)
+       default:
+               panic("unreachable")
+       }
+
+       relativeOffset += headerSize
+
+       r.sawFrameHeader = true
+       r.readOneFrame = true
+       r.blockOffset += int64(relativeOffset)
+
+       // Prepare to read blocks from the frame.
+       r.repeatedOffset1 = 1
+       r.repeatedOffset2 = 4
+       r.repeatedOffset3 = 8
+       r.huffmanTableBits = 0
+       r.window = r.window[:0]
+       r.seqTables[0] = nil
+       r.seqTables[1] = nil
+       r.seqTables[2] = nil
+
+       return nil
+}
+
+// skipFrame skips a skippable frame. RFC 3.1.2.
+func (r *Reader) skipFrame() error {
+       relativeOffset := 0
+
+       if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
+               return r.wrapNonEOFError(relativeOffset, err)
+       }
+
+       relativeOffset += 4
+
+       size := binary.LittleEndian.Uint32(r.scratch[:4])
+
+       if seeker, ok := r.r.(io.Seeker); ok {
+               if _, err := seeker.Seek(int64(size), io.SeekCurrent); err != nil {
+                       return err
+               }
+               r.blockOffset += int64(relativeOffset) + int64(size)
+               return nil
+       }
+
+       var skip []byte
+       const chunk = 1 << 20 // 1M
+       for size >= chunk {
+               if len(skip) == 0 {
+                       skip = make([]byte, chunk)
+               }
+               if _, err := io.ReadFull(r.r, skip); err != nil {
+                       return r.wrapNonEOFError(relativeOffset, err)
+               }
+               relativeOffset += chunk
+               size -= chunk
+       }
+       if size > 0 {
+               if len(skip) == 0 {
+                       skip = make([]byte, size)
+               }
+               if _, err := io.ReadFull(r.r, skip); err != nil {
+                       return r.wrapNonEOFError(relativeOffset, err)
+               }
+               relativeOffset += int(size)
+       }
+
+       r.blockOffset += int64(relativeOffset)
+
+       return nil
+}
+
+// readBlock reads the next block from a frame.
+func (r *Reader) readBlock() error {
+       relativeOffset := 0
+
+       // Read Block_Header. RFC 3.1.1.2.
+       if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil {
+               return r.wrapNonEOFError(relativeOffset, err)
+       }
+
+       relativeOffset += 3
+
+       header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16)
+
+       lastBlock := header&1 != 0
+       blockType := (header >> 1) & 3
+       blockSize := int(header >> 3)
+
+       // The maximum block size is the smaller of the window size and 128K.
+       // We don't record the window size for a single segment frame,
+       // so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4.
+       if blockSize > 128<<10 || (r.windowSize > 0 && blockSize > r.windowSize) {
+               return r.makeError(relativeOffset, "block size too large")
+       }
+
+       // Handle different block types. RFC 3.1.1.2.2.
+       switch blockType {
+       case 0:
+               r.setBufferSize(blockSize)
+               if _, err := io.ReadFull(r.r, r.buffer); err != nil {
+                       return r.wrapNonEOFError(relativeOffset, err)
+               }
+               relativeOffset += blockSize
+               r.blockOffset += int64(relativeOffset)
+       case 1:
+               r.setBufferSize(blockSize)
+               if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
+                       return r.wrapNonEOFError(relativeOffset, err)
+               }
+               relativeOffset++
+               v := r.scratch[0]
+               for i := range r.buffer {
+                       r.buffer[i] = v
+               }
+               r.blockOffset += int64(relativeOffset)
+       case 2:
+               r.blockOffset += int64(relativeOffset)
+               if err := r.compressedBlock(blockSize); err != nil {
+                       return err
+               }
+               r.blockOffset += int64(blockSize)
+       case 3:
+               return r.makeError(relativeOffset, "invalid block type")
+       }
+
+       if !r.frameSizeUnknown {
+               if uint64(len(r.buffer)) > r.remainingFrameSize {
+                       return r.makeError(relativeOffset, "too many uncompressed bytes in frame")
+               }
+               r.remainingFrameSize -= uint64(len(r.buffer))
+       }
+
+       if r.hasChecksum {
+               r.checksum.update(r.buffer)
+       }
+
+       if !lastBlock {
+               r.saveWindow(r.buffer)
+       } else {
+               if !r.frameSizeUnknown && r.remainingFrameSize != 0 {
+                       return r.makeError(relativeOffset, "not enough uncompressed bytes for frame")
+               }
+               // Check for checksum at end of frame. RFC 3.1.1.
+               if r.hasChecksum {
+                       if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
+                               return r.wrapNonEOFError(0, err)
+                       }
+
+                       inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4])
+                       dataChecksum := uint32(r.checksum.digest())
+                       if inputChecksum != dataChecksum {
+                               return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum))
+                       }
+
+                       r.blockOffset += 4
+               }
+               r.sawFrameHeader = false
+       }
+
+       return nil
+}
+
+// setBufferSize sets the decompressed buffer size.
+// When this is called the buffer is empty.
+func (r *Reader) setBufferSize(size int) {
+       if cap(r.buffer) < size {
+               need := size - cap(r.buffer)
+               r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...)
+       }
+       r.buffer = r.buffer[:size]
+}
+
+// saveWindow saves bytes in the backreference window.
+// TODO: use a circular buffer for less data movement.
+func (r *Reader) saveWindow(buf []byte) {
+       if r.windowSize == 0 {
+               return
+       }
+
+       if len(buf) >= r.windowSize {
+               from := len(buf) - r.windowSize
+               r.window = append(r.window[:0], buf[from:]...)
+               return
+       }
+
+       keep := r.windowSize - len(buf) // must be positive
+       if keep < len(r.window) {
+               remove := len(r.window) - keep
+               copy(r.window[:], r.window[remove:])
+       }
+
+       r.window = append(r.window, buf...)
+}
+
+// zstdError is an error while decompressing.
+type zstdError struct {
+       offset int64
+       err    error
+}
+
+func (ze *zstdError) Error() string {
+       return fmt.Sprintf("zstd decompression error at %d: %v", ze.offset, ze.err)
+}
+
+func (ze *zstdError) Unwrap() error {
+       return ze.err
+}
+
+func (r *Reader) makeEOFError(off int) error {
+       return r.wrapError(off, io.ErrUnexpectedEOF)
+}
+
+func (r *Reader) wrapNonEOFError(off int, err error) error {
+       if err == io.EOF {
+               err = io.ErrUnexpectedEOF
+       }
+       return r.wrapError(off, err)
+}
+
+func (r *Reader) makeError(off int, msg string) error {
+       return r.wrapError(off, errors.New(msg))
+}
+
+func (r *Reader) wrapError(off int, err error) error {
+       if err == io.EOF {
+               return err
+       }
+       return &zstdError{r.blockOffset + int64(off), err}
+}
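
The window descriptor decoded in readFrameHeader packs an exponent in its top
five bits and a mantissa in its low three bits; the resulting size is then
capped at the 8M limit that RFC 8878 allows a decoder to impose. A small
worked sketch of that computation as a hypothetical helper in package zstd
(windowSizeFromDescriptor is not part of the patch):

    // windowSizeFromDescriptor mirrors the window descriptor math in
    // readFrameHeader. For example, a descriptor of 0 yields 1<<10 = 1K,
    // and very large descriptors are capped at 8M.
    func windowSizeFromDescriptor(windowDescriptor byte) int {
        exponent := uint64(windowDescriptor >> 3)
        mantissa := uint64(windowDescriptor & 7)
        windowLog := exponent + 10
        windowBase := uint64(1) << windowLog
        windowAdd := (windowBase / 8) * mantissa
        windowSize := windowBase + windowAdd
        if windowSize > 8<<20 {
            windowSize = 8 << 20
        }
        return int(windowSize)
    }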
diff --git a/src/internal/zstd/zstd_test.go b/src/internal/zstd/zstd_test.go
new file mode 100644 (file)
index 0000000..bc75e0f
--- /dev/null
@@ -0,0 +1,249 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+       "bytes"
+       "fmt"
+       "internal/race"
+       "internal/testenv"
+       "io"
+       "os"
+       "os/exec"
+       "strings"
+       "sync"
+       "testing"
+)
+
+// tests holds some simple test cases, including some found by fuzzing.
+var tests = []struct {
+       name, uncompressed, compressed string
+}{
+       {
+               "hello",
+               "hello, world\n",
+               "\x28\xb5\x2f\xfd\x24\x0d\x69\x00\x00\x68\x65\x6c\x6c\x6f\x2c\x20\x77\x6f\x72\x6c\x64\x0a\x4c\x1f\xf9\xf1",
+       },
+       {
+               // a small compressed .debug_ranges section.
+               "ranges",
+               "\xcc\x11\x00\x00\x00\x00\x00\x00\xd5\x13\x00\x00\x00\x00\x00\x00" +
+                       "\x1c\x14\x00\x00\x00\x00\x00\x00\x72\x14\x00\x00\x00\x00\x00\x00" +
+                       "\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" +
+                       "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+                       "\xfb\x12\x00\x00\x00\x00\x00\x00\x09\x13\x00\x00\x00\x00\x00\x00" +
+                       "\x0c\x13\x00\x00\x00\x00\x00\x00\xcb\x13\x00\x00\x00\x00\x00\x00" +
+                       "\x29\x14\x00\x00\x00\x00\x00\x00\x4e\x14\x00\x00\x00\x00\x00\x00" +
+                       "\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" +
+                       "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+                       "\xfb\x12\x00\x00\x00\x00\x00\x00\x09\x13\x00\x00\x00\x00\x00\x00" +
+                       "\x67\x13\x00\x00\x00\x00\x00\x00\xcb\x13\x00\x00\x00\x00\x00\x00" +
+                       "\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" +
+                       "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+                       "\x5f\x0b\x00\x00\x00\x00\x00\x00\x6c\x0b\x00\x00\x00\x00\x00\x00" +
+                       "\x7d\x0b\x00\x00\x00\x00\x00\x00\x7e\x0c\x00\x00\x00\x00\x00\x00" +
+                       "\x38\x0f\x00\x00\x00\x00\x00\x00\x5c\x0f\x00\x00\x00\x00\x00\x00" +
+                       "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+                       "\x83\x0c\x00\x00\x00\x00\x00\x00\xfa\x0c\x00\x00\x00\x00\x00\x00" +
+                       "\xfd\x0d\x00\x00\x00\x00\x00\x00\xef\x0e\x00\x00\x00\x00\x00\x00" +
+                       "\x14\x0f\x00\x00\x00\x00\x00\x00\x38\x0f\x00\x00\x00\x00\x00\x00" +
+                       "\x9f\x0f\x00\x00\x00\x00\x00\x00\xac\x0f\x00\x00\x00\x00\x00\x00" +
+                       "\xdb\x0f\x00\x00\x00\x00\x00\x00\xff\x0f\x00\x00\x00\x00\x00\x00" +
+                       "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+                       "\xfd\x0d\x00\x00\x00\x00\x00\x00\xd8\x0e\x00\x00\x00\x00\x00\x00" +
+                       "\x9f\x0f\x00\x00\x00\x00\x00\x00\xac\x0f\x00\x00\x00\x00\x00\x00" +
+                       "\xdb\x0f\x00\x00\x00\x00\x00\x00\xff\x0f\x00\x00\x00\x00\x00\x00" +
+                       "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+                       "\xfa\x0c\x00\x00\x00\x00\x00\x00\xea\x0d\x00\x00\x00\x00\x00\x00" +
+                       "\xef\x0e\x00\x00\x00\x00\x00\x00\x14\x0f\x00\x00\x00\x00\x00\x00" +
+                       "\x5c\x0f\x00\x00\x00\x00\x00\x00\x9f\x0f\x00\x00\x00\x00\x00\x00" +
+                       "\xac\x0f\x00\x00\x00\x00\x00\x00\xdb\x0f\x00\x00\x00\x00\x00\x00" +
+                       "\xff\x0f\x00\x00\x00\x00\x00\x00\x2c\x10\x00\x00\x00\x00\x00\x00" +
+                       "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+                       "\x60\x11\x00\x00\x00\x00\x00\x00\xd1\x16\x00\x00\x00\x00\x00\x00" +
+                       "\x40\x0b\x00\x00\x00\x00\x00\x00\x2c\x10\x00\x00\x00\x00\x00\x00" +
+                       "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+                       "\x7a\x00\x00\x00\x00\x00\x00\x00\xb6\x00\x00\x00\x00\x00\x00\x00" +
+                       "\x9f\x01\x00\x00\x00\x00\x00\x00\xa7\x01\x00\x00\x00\x00\x00\x00" +
+                       "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+                       "\x7a\x00\x00\x00\x00\x00\x00\x00\xa9\x00\x00\x00\x00\x00\x00\x00" +
+                       "\x9f\x01\x00\x00\x00\x00\x00\x00\xa7\x01\x00\x00\x00\x00\x00\x00" +
+                       "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+
+               "\x28\xb5\x2f\xfd\x64\xa0\x01\x2d\x05\x00\xc4\x04\xcc\x11\x00\xd5" +
+                       "\x13\x00\x1c\x14\x00\x72\x9d\xd5\xfb\x12\x00\x09\x0c\x13\xcb\x13" +
+                       "\x29\x4e\x67\x5f\x0b\x6c\x0b\x7d\x0b\x7e\x0c\x38\x0f\x5c\x0f\x83" +
+                       "\x0c\xfa\x0c\xfd\x0d\xef\x0e\x14\x38\x9f\x0f\xac\x0f\xdb\x0f\xff" +
+                       "\x0f\xd8\x9f\xac\xdb\xff\xea\x5c\x2c\x10\x60\xd1\x16\x40\x0b\x7a" +
+                       "\x00\xb6\x00\x9f\x01\xa7\x01\xa9\x36\x20\xa0\x83\x14\x34\x63\x4a" +
+                       "\x21\x70\x8c\x07\x46\x03\x4e\x10\x62\x3c\x06\x4e\xc8\x8c\xb0\x32" +
+                       "\x2a\x59\xad\xb2\xf1\x02\x82\x7c\x33\xcb\x92\x6f\x32\x4f\x9b\xb0" +
+                       "\xa2\x30\xf0\xc0\x06\x1e\x98\x99\x2c\x06\x1e\xd8\xc0\x03\x56\xd8" +
+                       "\xc0\x03\x0f\x6c\xe0\x01\xf1\xf0\xee\x9a\xc6\xc8\x97\x99\xd1\x6c" +
+                       "\xb4\x21\x45\x3b\x10\xe4\x7b\x99\x4d\x8a\x36\x64\x5c\x77\x08\x02" +
+                       "\xcb\xe0\xce",
+       },
+       {
+               "fuzz1",
+               "0\x00\x00\x00\x00\x000\x00\x00\x00\x00\x001\x00\x00\x00\x00\x000000",
+               "(\xb5/\xfd\x04X\x8d\x00\x00P0\x000\x001\x000000\x03T\x02\x00\x01\x01m\xf9\xb7G",
+       },
+}
+
+func TestSamples(t *testing.T) {
+       for _, test := range tests {
+               test := test
+               t.Run(test.name, func(t *testing.T) {
+                       r := NewReader(strings.NewReader(test.compressed))
+                       got, err := io.ReadAll(r)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       gotstr := string(got)
+                       if gotstr != test.uncompressed {
+                               t.Errorf("got %q want %q", gotstr, test.uncompressed)
+                       }
+               })
+       }
+}
+
+var (
+       bigDataOnce  sync.Once
+       bigDataBytes []byte
+       bigDataErr   error
+)
+
+// bigData returns the contents of our large test file.
+func bigData(t testing.TB) []byte {
+       bigDataOnce.Do(func() {
+               bigDataBytes, bigDataErr = os.ReadFile("../../testdata/Isaac.Newton-Opticks.txt")
+       })
+       if bigDataErr != nil {
+               t.Fatal(bigDataErr)
+       }
+       return bigDataBytes
+}
+
+var (
+       zstdBigOnce  sync.Once
+       zstdBigBytes []byte
+       zstdBigSkip  bool
+       zstdBigErr   error
+)
+
+// zstdBigData returns the compressed contents of our large test file.
+// This will only run on Unix systems with zstd installed.
+// That's OK as the package is GOOS-independent.
+func zstdBigData(t testing.TB) []byte {
+       input := bigData(t)
+
+       zstdBigOnce.Do(func() {
+               if _, err := os.Stat("/usr/bin/zstd"); err != nil {
+                       zstdBigSkip = true
+                       return
+               }
+
+               cmd := exec.Command("/usr/bin/zstd", "-z")
+               cmd.Stdin = bytes.NewReader(input)
+               var compressed bytes.Buffer
+               cmd.Stdout = &compressed
+               cmd.Stderr = os.Stderr
+               if err := cmd.Run(); err != nil {
+                       zstdBigErr = fmt.Errorf("running zstd failed: %v", err)
+                       return
+               }
+
+               zstdBigBytes = compressed.Bytes()
+       })
+       if zstdBigSkip {
+               t.Skip("skipping because /usr/bin/zstd does not exist")
+       }
+       if zstdBigErr != nil {
+               t.Fatal(zstdBigErr)
+       }
+       return zstdBigBytes
+}
+
+// Test decompressing a large file. We don't have a compressor,
+// so this test only runs on systems with zstd installed.
+func TestLarge(t *testing.T) {
+       if testing.Short() {
+               t.Skip("skipping expensive test in short mode")
+       }
+
+       data := bigData(t)
+       compressed := zstdBigData(t)
+
+       t.Logf("/usr/bin/zstd compressed %d bytes to %d", len(data), len(compressed))
+
+       r := NewReader(bytes.NewReader(compressed))
+       got, err := io.ReadAll(r)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       if !bytes.Equal(got, data) {
+               showDiffs(t, got, data)
+       }
+}
+
+// showDiffs reports the first few differences in two []byte.
+func showDiffs(t *testing.T, got, want []byte) {
+       t.Error("data mismatch")
+       if len(got) != len(want) {
+               t.Errorf("got data length %d, want %d", len(got), len(want))
+       }
+       diffs := 0
+       for i, b := range got {
+               if i >= len(want) {
+                       break
+               }
+               if b != want[i] {
+                       diffs++
+                       if diffs > 20 {
+                               break
+                       }
+                       t.Logf("%d: %#x != %#x", i, b, want[i])
+               }
+       }
+}
+
+func TestAlloc(t *testing.T) {
+       testenv.SkipIfOptimizationOff(t)
+       if race.Enabled {
+               t.Skip("skipping allocation test under race detector")
+       }
+
+       compressed := zstdBigData(t)
+       input := bytes.NewReader(compressed)
+       r := NewReader(input)
+       c := testing.AllocsPerRun(10, func() {
+               input.Reset(compressed)
+               r.Reset(input)
+               io.Copy(io.Discard, r)
+       })
+       if c != 0 {
+               t.Errorf("got %v allocs, want 0", c)
+       }
+}
+
+func BenchmarkLarge(b *testing.B) {
+       b.StopTimer()
+       b.ReportAllocs()
+
+       compressed := zstdBigData(b)
+
+       b.SetBytes(int64(len(compressed)))
+
+       input := bytes.NewReader(compressed)
+       r := NewReader(input)
+
+       b.StartTimer()
+       for i := 0; i < b.N; i++ {
+               input.Reset(compressed)
+               r.Reset(input)
+               io.Copy(io.Discard, r)
+       }
+}
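
For context, a consumer of this package inside the standard library (the ELF
debug sections mentioned in the commit message are the motivating case) would
use the reader roughly as sketched below. This is illustrative only: the
package path is internal, so it can only be imported from within the standard
library, and the package and function names here are hypothetical.

    package elfzstd // hypothetical standard-library package

    import (
        "bytes"
        "io"

        "internal/zstd"
    )

    // decompressSection expands a zstd-compressed section into memory.
    func decompressSection(compressed []byte) ([]byte, error) {
        r := zstd.NewReader(bytes.NewReader(compressed))
        return io.ReadAll(r)
    }

When many sections need decompressing, a single Reader can be reused with
Reset instead of allocating a new one each time, which is the pattern the
TestAlloc test above measures.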