]> Cypherpunks repositories - gostls13.git/commitdiff
compress/zlib: add FDICT flag in Reader/Writer
authorRoss Light <rlight2@gmail.com>
Fri, 15 Apr 2011 22:32:03 +0000 (15:32 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 15 Apr 2011 22:32:03 +0000 (15:32 -0700)
R=bradfitzgo, rsc, bradfitzwork
CC=golang-dev
https://golang.org/cl/4406046

src/pkg/compress/zlib/reader.go
src/pkg/compress/zlib/reader_test.go
src/pkg/compress/zlib/writer.go
src/pkg/compress/zlib/writer_test.go

index 721f6ec55958e417e8e367a2e5e023121d5070b3..b191c27330616b114b39f6c2ed4c92bf60ea5c6d 100644 (file)
@@ -36,7 +36,7 @@ const zlibDeflate = 8
 
 var ChecksumError os.Error = os.ErrorString("zlib checksum error")
 var HeaderError os.Error = os.ErrorString("invalid zlib header")
-var UnsupportedError os.Error = os.ErrorString("unsupported zlib format")
+var DictionaryError os.Error = os.ErrorString("invalid zlib dictionary")
 
 type reader struct {
        r            flate.Reader
@@ -50,6 +50,12 @@ type reader struct {
 // The implementation buffers input and may read more data than necessary from r.
 // It is the caller's responsibility to call Close on the ReadCloser when done.
 func NewReader(r io.Reader) (io.ReadCloser, os.Error) {
+       return NewReaderDict(r, nil)
+}
+
+// NewReaderDict is like NewReader but uses a preset dictionary.
+// NewReaderDict ignores the dictionary if the compressed data does not refer to it.
+func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, os.Error) {
        z := new(reader)
        if fr, ok := r.(flate.Reader); ok {
                z.r = fr
@@ -65,11 +71,19 @@ func NewReader(r io.Reader) (io.ReadCloser, os.Error) {
                return nil, HeaderError
        }
        if z.scratch[1]&0x20 != 0 {
-               // BUG(nigeltao): The zlib package does not implement the FDICT flag.
-               return nil, UnsupportedError
+               _, err = io.ReadFull(z.r, z.scratch[0:4])
+               if err != nil {
+                       return nil, err
+               }
+               checksum := uint32(z.scratch[0])<<24 | uint32(z.scratch[1])<<16 | uint32(z.scratch[2])<<8 | uint32(z.scratch[3])
+               if checksum != adler32.Checksum(dict) {
+                       return nil, DictionaryError
+               }
+               z.decompressor = flate.NewReaderDict(z.r, dict)
+       } else {
+               z.decompressor = flate.NewReader(z.r)
        }
        z.digest = adler32.New()
-       z.decompressor = flate.NewReader(z.r)
        return z, nil
 }
 
index eaefc3a361a36bc544358406824b775400bd56c6..195db446c9fff0ae708cc55989ba1136c249eb79 100644 (file)
@@ -15,6 +15,7 @@ type zlibTest struct {
        desc       string
        raw        string
        compressed []byte
+       dict       []byte
        err        os.Error
 }
 
@@ -27,6 +28,7 @@ var zlibTests = []zlibTest{
                "",
                []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01},
                nil,
+               nil,
        },
        {
                "goodbye",
@@ -37,23 +39,27 @@ var zlibTests = []zlibTest{
                        0x01, 0x00, 0x28, 0xa5, 0x05, 0x5e,
                },
                nil,
+               nil,
        },
        {
                "bad header",
                "",
                []byte{0x78, 0x9f, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01},
+               nil,
                HeaderError,
        },
        {
                "bad checksum",
                "",
                []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00, 0x00, 0xff},
+               nil,
                ChecksumError,
        },
        {
                "not enough data",
                "",
                []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00},
+               nil,
                io.ErrUnexpectedEOF,
        },
        {
@@ -64,6 +70,33 @@ var zlibTests = []zlibTest{
                        0x78, 0x9c, 0xff,
                },
                nil,
+               nil,
+       },
+       {
+               "dictionary",
+               "Hello, World!\n",
+               []byte{
+                       0x78, 0xbb, 0x1c, 0x32, 0x04, 0x27, 0xf3, 0x00,
+                       0xb1, 0x75, 0x20, 0x1c, 0x45, 0x2e, 0x00, 0x24,
+                       0x12, 0x04, 0x74,
+               },
+               []byte{
+                       0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x0a,
+               },
+               nil,
+       },
+       {
+               "wrong dictionary",
+               "",
+               []byte{
+                       0x78, 0xbb, 0x1c, 0x32, 0x04, 0x27, 0xf3, 0x00,
+                       0xb1, 0x75, 0x20, 0x1c, 0x45, 0x2e, 0x00, 0x24,
+                       0x12, 0x04, 0x74,
+               },
+               []byte{
+                       0x48, 0x65, 0x6c, 0x6c,
+               },
+               DictionaryError,
        },
 }
 
@@ -71,7 +104,7 @@ func TestDecompressor(t *testing.T) {
        b := new(bytes.Buffer)
        for _, tt := range zlibTests {
                in := bytes.NewBuffer(tt.compressed)
-               zlib, err := NewReader(in)
+               zlib, err := NewReaderDict(in, tt.dict)
                if err != nil {
                        if err != tt.err {
                                t.Errorf("%s: NewReader: %s", tt.desc, err)
index 031586cd2bbea56e604980ae8b79e701a9e8b09b..f1f9b285375685bbd01382624ae9571e2ad43d6a 100644 (file)
@@ -21,56 +21,80 @@ const (
        DefaultCompression = flate.DefaultCompression
 )
 
-type writer struct {
+// A Writer takes data written to it and writes the compressed
+// form of that data to an underlying writer (see NewWriter).
+type Writer struct {
        w          io.Writer
-       compressor io.WriteCloser
+       compressor *flate.Writer
        digest     hash.Hash32
        err        os.Error
        scratch    [4]byte
 }
 
 // NewWriter calls NewWriterLevel with the default compression level.
-func NewWriter(w io.Writer) (io.WriteCloser, os.Error) {
+func NewWriter(w io.Writer) (*Writer, os.Error) {
        return NewWriterLevel(w, DefaultCompression)
 }
 
-// NewWriterLevel creates a new io.WriteCloser that satisfies writes by compressing data written to w.
+// NewWriterLevel calls NewWriterDict with no dictionary.
+func NewWriterLevel(w io.Writer, level int) (*Writer, os.Error) {
+       return NewWriterDict(w, level, nil)
+}
+
+// NewWriterDict creates a new io.WriteCloser that satisfies writes by compressing data written to w.
 // It is the caller's responsibility to call Close on the WriteCloser when done.
 // level is the compression level, which can be DefaultCompression, NoCompression,
 // or any integer value between BestSpeed and BestCompression (inclusive).
-func NewWriterLevel(w io.Writer, level int) (io.WriteCloser, os.Error) {
-       z := new(writer)
+// dict is the preset dictionary to compress with, or nil to use no dictionary.
+func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, os.Error) {
+       z := new(Writer)
        // ZLIB has a two-byte header (as documented in RFC 1950).
        // The first four bits is the CINFO (compression info), which is 7 for the default deflate window size.
        // The next four bits is the CM (compression method), which is 8 for deflate.
        z.scratch[0] = 0x78
        // The next two bits is the FLEVEL (compression level). The four values are:
        // 0=fastest, 1=fast, 2=default, 3=best.
-       // The next bit, FDICT, is unused, in this implementation.
+       // The next bit, FDICT, is set if a dictionary is given.
        // The final five FCHECK bits form a mod-31 checksum.
        switch level {
        case 0, 1:
-               z.scratch[1] = 0x01
+               z.scratch[1] = 0 << 6
        case 2, 3, 4, 5:
-               z.scratch[1] = 0x5e
+               z.scratch[1] = 1 << 6
        case 6, -1:
-               z.scratch[1] = 0x9c
+               z.scratch[1] = 2 << 6
        case 7, 8, 9:
-               z.scratch[1] = 0xda
+               z.scratch[1] = 3 << 6
        default:
                return nil, os.NewError("level out of range")
        }
+       if dict != nil {
+               z.scratch[1] |= 1 << 5
+       }
+       z.scratch[1] += uint8(31 - (uint16(z.scratch[0])<<8+uint16(z.scratch[1]))%31)
        _, err := w.Write(z.scratch[0:2])
        if err != nil {
                return nil, err
        }
+       if dict != nil {
+               // The next four bytes are the Adler-32 checksum of the dictionary.
+               checksum := adler32.Checksum(dict)
+               z.scratch[0] = uint8(checksum >> 24)
+               z.scratch[1] = uint8(checksum >> 16)
+               z.scratch[2] = uint8(checksum >> 8)
+               z.scratch[3] = uint8(checksum >> 0)
+               _, err = w.Write(z.scratch[0:4])
+               if err != nil {
+                       return nil, err
+               }
+       }
        z.w = w
        z.compressor = flate.NewWriter(w, level)
        z.digest = adler32.New()
        return z, nil
 }
 
-func (z *writer) Write(p []byte) (n int, err os.Error) {
+func (z *Writer) Write(p []byte) (n int, err os.Error) {
        if z.err != nil {
                return 0, z.err
        }
@@ -86,8 +110,17 @@ func (z *writer) Write(p []byte) (n int, err os.Error) {
        return
 }
 
+// Flush flushes the underlying compressor.
+func (z *Writer) Flush() os.Error {
+       if z.err != nil {
+               return z.err
+       }
+       z.err = z.compressor.Flush()
+       return z.err
+}
+
 // Calling Close does not close the wrapped io.Writer originally passed to NewWriter.
-func (z *writer) Close() os.Error {
+func (z *Writer) Close() os.Error {
        if z.err != nil {
                return z.err
        }
index 7eb1cd494978d5e41cc997993980c114e2163aa2..f94f28470065effa9821e66e586f77e0a4736c03 100644 (file)
@@ -16,13 +16,19 @@ var filenames = []string{
        "../testdata/pi.txt",
 }
 
-// Tests that compressing and then decompressing the given file at the given compression level
+// Tests that compressing and then decompressing the given file at the given compression level and dictionary
 // yields equivalent bytes to the original file.
-func testFileLevel(t *testing.T, fn string, level int) {
+func testFileLevelDict(t *testing.T, fn string, level int, d string) {
+       // Read dictionary, if given.
+       var dict []byte
+       if d != "" {
+               dict = []byte(d)
+       }
+
        // Read the file, as golden output.
        golden, err := os.Open(fn)
        if err != nil {
-               t.Errorf("%s (level=%d): %v", fn, level, err)
+               t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
                return
        }
        defer golden.Close()
@@ -30,7 +36,7 @@ func testFileLevel(t *testing.T, fn string, level int) {
        // Read the file again, and push it through a pipe that compresses at the write end, and decompresses at the read end.
        raw, err := os.Open(fn)
        if err != nil {
-               t.Errorf("%s (level=%d): %v", fn, level, err)
+               t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
                return
        }
        piper, pipew := io.Pipe()
@@ -38,9 +44,9 @@ func testFileLevel(t *testing.T, fn string, level int) {
        go func() {
                defer raw.Close()
                defer pipew.Close()
-               zlibw, err := NewWriterLevel(pipew, level)
+               zlibw, err := NewWriterDict(pipew, level, dict)
                if err != nil {
-                       t.Errorf("%s (level=%d): %v", fn, level, err)
+                       t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
                        return
                }
                defer zlibw.Close()
@@ -48,7 +54,7 @@ func testFileLevel(t *testing.T, fn string, level int) {
                for {
                        n, err0 := raw.Read(b[0:])
                        if err0 != nil && err0 != os.EOF {
-                               t.Errorf("%s (level=%d): %v", fn, level, err0)
+                               t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0)
                                return
                        }
                        _, err1 := zlibw.Write(b[0:n])
@@ -57,7 +63,7 @@ func testFileLevel(t *testing.T, fn string, level int) {
                                return
                        }
                        if err1 != nil {
-                               t.Errorf("%s (level=%d): %v", fn, level, err1)
+                               t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1)
                                return
                        }
                        if err0 == os.EOF {
@@ -65,9 +71,9 @@ func testFileLevel(t *testing.T, fn string, level int) {
                        }
                }
        }()
-       zlibr, err := NewReader(piper)
+       zlibr, err := NewReaderDict(piper, dict)
        if err != nil {
-               t.Errorf("%s (level=%d): %v", fn, level, err)
+               t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
                return
        }
        defer zlibr.Close()
@@ -76,20 +82,20 @@ func testFileLevel(t *testing.T, fn string, level int) {
        b0, err0 := ioutil.ReadAll(golden)
        b1, err1 := ioutil.ReadAll(zlibr)
        if err0 != nil {
-               t.Errorf("%s (level=%d): %v", fn, level, err0)
+               t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0)
                return
        }
        if err1 != nil {
-               t.Errorf("%s (level=%d): %v", fn, level, err1)
+               t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1)
                return
        }
        if len(b0) != len(b1) {
-               t.Errorf("%s (level=%d): length mismatch %d versus %d", fn, level, len(b0), len(b1))
+               t.Errorf("%s (level=%d, dict=%q): length mismatch %d versus %d", fn, level, d, len(b0), len(b1))
                return
        }
        for i := 0; i < len(b0); i++ {
                if b0[i] != b1[i] {
-                       t.Errorf("%s (level=%d): mismatch at %d, 0x%02x versus 0x%02x\n", fn, level, i, b0[i], b1[i])
+                       t.Errorf("%s (level=%d, dict=%q): mismatch at %d, 0x%02x versus 0x%02x\n", fn, level, d, i, b0[i], b1[i])
                        return
                }
        }
@@ -97,10 +103,21 @@ func testFileLevel(t *testing.T, fn string, level int) {
 
 func TestWriter(t *testing.T) {
        for _, fn := range filenames {
-               testFileLevel(t, fn, DefaultCompression)
-               testFileLevel(t, fn, NoCompression)
+               testFileLevelDict(t, fn, DefaultCompression, "")
+               testFileLevelDict(t, fn, NoCompression, "")
+               for level := BestSpeed; level <= BestCompression; level++ {
+                       testFileLevelDict(t, fn, level, "")
+               }
+       }
+}
+
+func TestWriterDict(t *testing.T) {
+       const dictionary = "0123456789."
+       for _, fn := range filenames {
+               testFileLevelDict(t, fn, DefaultCompression, dictionary)
+               testFileLevelDict(t, fn, NoCompression, dictionary)
                for level := BestSpeed; level <= BestCompression; level++ {
-                       testFileLevel(t, fn, level)
+                       testFileLevelDict(t, fn, level, dictionary)
                }
        }
 }