]> Cypherpunks repositories - gostls13.git/commitdiff
compress/flate: implement Reset method on Writer.
authorRémy Oudompheng <oudomphe@phare.normalesup.org>
Thu, 29 Aug 2013 19:09:23 +0000 (21:09 +0200)
committerRémy Oudompheng <oudomphe@phare.normalesup.org>
Thu, 29 Aug 2013 19:09:23 +0000 (21:09 +0200)
Fixes #6138.

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/12953048

src/pkg/compress/flate/deflate.go
src/pkg/compress/flate/deflate_test.go
src/pkg/compress/flate/huffman_bit_writer.go

index d357fe361a55ed0bad4f1637e7c014ffc32e2146..b3e079150a1521c3893ebcb161d1d07a346f061b 100644 (file)
@@ -416,6 +416,50 @@ func (d *compressor) init(w io.Writer, level int) (err error) {
        return nil
 }
 
+var zeroes [32]int
+var bzeroes [256]byte
+
+func (d *compressor) reset(w io.Writer) {
+       d.w.reset(w)
+       d.sync = false
+       d.err = nil
+       switch d.compressionLevel.chain {
+       case 0:
+               // level was NoCompression.
+               for i := range d.window {
+                       d.window[i] = 0
+               }
+               d.windowEnd = 0
+       default:
+               d.chainHead = -1
+               for s := d.hashHead; len(s) > 0; {
+                       n := copy(s, zeroes[:])
+                       s = s[n:]
+               }
+               for s := d.hashPrev; len(s) > 0; s = s[len(zeroes):] {
+                       copy(s, zeroes[:])
+               }
+               d.hashOffset = 1
+
+               d.index, d.windowEnd = 0, 0
+               for s := d.window; len(s) > 0; {
+                       n := copy(s, bzeroes[:])
+                       s = s[n:]
+               }
+               d.blockStart, d.byteAvailable = 0, false
+
+               d.tokens = d.tokens[:maxFlateBlockTokens+1]
+               for i := 0; i <= maxFlateBlockTokens; i++ {
+                       d.tokens[i] = 0
+               }
+               d.tokens = d.tokens[:0]
+               d.length = minMatchLength - 1
+               d.offset = 0
+               d.hash = 0
+               d.maxInsertIndex = 0
+       }
+}
+
 func (d *compressor) close() error {
        d.sync = true
        d.step(d)
@@ -439,7 +483,6 @@ func (d *compressor) close() error {
 // If level is in the range [-1, 9] then the error returned will be nil.
 // Otherwise the error returned will be non-nil.
 func NewWriter(w io.Writer, level int) (*Writer, error) {
-       const logWindowSize = logMaxOffsetSize
        var dw Writer
        if err := dw.d.init(w, level); err != nil {
                return nil, err
@@ -462,6 +505,7 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
        zw.Write(dict)
        zw.Flush()
        dw.enabled = true
+       zw.dict = append(zw.dict, dict...) // duplicate dictionary for Reset method.
        return zw, err
 }
 
@@ -480,7 +524,8 @@ func (w *dictWriter) Write(b []byte) (n int, err error) {
 // A Writer takes data written to it and writes the compressed
 // form of that data to an underlying writer (see NewWriter).
 type Writer struct {
-       d compressor
+       d    compressor
+       dict []byte
 }
 
 // Write writes data to w, which will eventually write the
@@ -506,3 +551,21 @@ func (w *Writer) Flush() error {
 func (w *Writer) Close() error {
        return w.d.close()
 }
+
+// Reset discards the writer's state and makes it equivalent to
+// the result of NewWriter or NewWriterDict called with w
+// and w's level and dictionary.
+func (w *Writer) Reset(dst io.Writer) {
+       if dw, ok := w.d.w.w.(*dictWriter); ok {
+               // w was created with NewWriterDict
+               dw.w = dst
+               w.d.reset(dw)
+               dw.enabled = false
+               w.Write(w.dict)
+               w.Flush()
+               dw.enabled = true
+       } else {
+               // w was created with NewWriter
+               w.d.reset(dst)
+       }
+}
index 8c4a6d6b36fab8a6526d335d15836c4f17ae00b0..730234c385039510249bd0aa482a793c2871fb5f 100644 (file)
@@ -9,6 +9,7 @@ import (
        "fmt"
        "io"
        "io/ioutil"
+       "reflect"
        "sync"
        "testing"
 )
@@ -424,3 +425,66 @@ func TestRegression2508(t *testing.T) {
        }
        w.Close()
 }
+
+func TestWriterReset(t *testing.T) {
+       for level := 0; level <= 9; level++ {
+               if testing.Short() && level > 1 {
+                       break
+               }
+               w, err := NewWriter(ioutil.Discard, level)
+               if err != nil {
+                       t.Fatalf("NewWriter: %v", err)
+               }
+               buf := []byte("hello world")
+               for i := 0; i < 1024; i++ {
+                       w.Write(buf)
+               }
+               w.Reset(ioutil.Discard)
+
+               wref, err := NewWriter(ioutil.Discard, level)
+               if err != nil {
+                       t.Fatalf("NewWriter: %v", err)
+               }
+
+               // DeepEqual doesn't compare functions.
+               w.d.fill, wref.d.fill = nil, nil
+               w.d.step, wref.d.step = nil, nil
+               if !reflect.DeepEqual(w, wref) {
+                       t.Errorf("level %d Writer not reset after Reset", level)
+               }
+       }
+       testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, NoCompression) })
+       testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, DefaultCompression) })
+       testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, BestCompression) })
+       dict := []byte("we are the world")
+       testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, NoCompression, dict) })
+       testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, DefaultCompression, dict) })
+       testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, BestCompression, dict) })
+}
+
+func testResetOutput(t *testing.T, newWriter func(w io.Writer) (*Writer, error)) {
+       buf := new(bytes.Buffer)
+       w, err := newWriter(buf)
+       if err != nil {
+               t.Fatalf("NewWriter: %v", err)
+       }
+       b := []byte("hello world")
+       for i := 0; i < 1024; i++ {
+               w.Write(b)
+       }
+       w.Close()
+       out1 := buf.String()
+
+       buf2 := new(bytes.Buffer)
+       w.Reset(buf2)
+       for i := 0; i < 1024; i++ {
+               w.Write(b)
+       }
+       w.Close()
+       out2 := buf2.String()
+
+       if out1 != out2 {
+               t.Errorf("got %q, expected %q", out2, out1)
+       }
+       t.Logf("got %d bytes", len(out1))
+}
index 25e1da336aa469bc5560e57b5476488f40471021..b182a710b9af9ab3767c8684d68c9ae2a7fab48c 100644 (file)
@@ -97,6 +97,31 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
        }
 }
 
+func (w *huffmanBitWriter) reset(writer io.Writer) {
+       w.w = writer
+       w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil
+       w.bytes = [64]byte{}
+       for i := range w.codegen {
+               w.codegen[i] = 0
+       }
+       for _, s := range [...][]int32{w.literalFreq, w.offsetFreq, w.codegenFreq} {
+               for i := range s {
+                       s[i] = 0
+               }
+       }
+       for _, enc := range [...]*huffmanEncoder{
+               w.literalEncoding,
+               w.offsetEncoding,
+               w.codegenEncoding} {
+               for i := range enc.code {
+                       enc.code[i] = 0
+               }
+               for i := range enc.codeBits {
+                       enc.codeBits[i] = 0
+               }
+       }
+}
+
 func (w *huffmanBitWriter) flushBits() {
        if w.err != nil {
                w.nbits = 0