From: Rémy Oudompheng Date: Thu, 29 Aug 2013 19:09:23 +0000 (+0200) Subject: compress/flate: implement Reset method on Writer. X-Git-Tag: go1.2rc2~389 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=f5f0e40e803125e47c64372fc7d808cbd8b9577a;p=gostls13.git compress/flate: implement Reset method on Writer. Fixes #6138. R=golang-dev, bradfitz CC=golang-dev https://golang.org/cl/12953048 --- diff --git a/src/pkg/compress/flate/deflate.go b/src/pkg/compress/flate/deflate.go index d357fe361a..b3e079150a 100644 --- a/src/pkg/compress/flate/deflate.go +++ b/src/pkg/compress/flate/deflate.go @@ -416,6 +416,50 @@ func (d *compressor) init(w io.Writer, level int) (err error) { return nil } +var zeroes [32]int +var bzeroes [256]byte + +func (d *compressor) reset(w io.Writer) { + d.w.reset(w) + d.sync = false + d.err = nil + switch d.compressionLevel.chain { + case 0: + // level was NoCompression. + for i := range d.window { + d.window[i] = 0 + } + d.windowEnd = 0 + default: + d.chainHead = -1 + for s := d.hashHead; len(s) > 0; { + n := copy(s, zeroes[:]) + s = s[n:] + } + for s := d.hashPrev; len(s) > 0; s = s[len(zeroes):] { + copy(s, zeroes[:]) + } + d.hashOffset = 1 + + d.index, d.windowEnd = 0, 0 + for s := d.window; len(s) > 0; { + n := copy(s, bzeroes[:]) + s = s[n:] + } + d.blockStart, d.byteAvailable = 0, false + + d.tokens = d.tokens[:maxFlateBlockTokens+1] + for i := 0; i <= maxFlateBlockTokens; i++ { + d.tokens[i] = 0 + } + d.tokens = d.tokens[:0] + d.length = minMatchLength - 1 + d.offset = 0 + d.hash = 0 + d.maxInsertIndex = 0 + } +} + func (d *compressor) close() error { d.sync = true d.step(d) @@ -439,7 +483,6 @@ func (d *compressor) close() error { // If level is in the range [-1, 9] then the error returned will be nil. // Otherwise the error returned will be non-nil. func NewWriter(w io.Writer, level int) (*Writer, error) { - const logWindowSize = logMaxOffsetSize var dw Writer if err := dw.d.init(w, level); err != nil { return nil, err @@ -462,6 +505,7 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) { zw.Write(dict) zw.Flush() dw.enabled = true + zw.dict = append(zw.dict, dict...) // duplicate dictionary for Reset method. return zw, err } @@ -480,7 +524,8 @@ func (w *dictWriter) Write(b []byte) (n int, err error) { // A Writer takes data written to it and writes the compressed // form of that data to an underlying writer (see NewWriter). type Writer struct { - d compressor + d compressor + dict []byte } // Write writes data to w, which will eventually write the @@ -506,3 +551,21 @@ func (w *Writer) Flush() error { func (w *Writer) Close() error { return w.d.close() } + +// Reset discards the writer's state and makes it equivalent to +// the result of NewWriter or NewWriterDict called with w +// and w's level and dictionary. +func (w *Writer) Reset(dst io.Writer) { + if dw, ok := w.d.w.w.(*dictWriter); ok { + // w was created with NewWriterDict + dw.w = dst + w.d.reset(dw) + dw.enabled = false + w.Write(w.dict) + w.Flush() + dw.enabled = true + } else { + // w was created with NewWriter + w.d.reset(dst) + } +} diff --git a/src/pkg/compress/flate/deflate_test.go b/src/pkg/compress/flate/deflate_test.go index 8c4a6d6b36..730234c385 100644 --- a/src/pkg/compress/flate/deflate_test.go +++ b/src/pkg/compress/flate/deflate_test.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "io/ioutil" + "reflect" "sync" "testing" ) @@ -424,3 +425,66 @@ func TestRegression2508(t *testing.T) { } w.Close() } + +func TestWriterReset(t *testing.T) { + for level := 0; level <= 9; level++ { + if testing.Short() && level > 1 { + break + } + w, err := NewWriter(ioutil.Discard, level) + if err != nil { + t.Fatalf("NewWriter: %v", err) + } + buf := []byte("hello world") + for i := 0; i < 1024; i++ { + w.Write(buf) + } + w.Reset(ioutil.Discard) + + wref, err := NewWriter(ioutil.Discard, level) + if err != nil { + t.Fatalf("NewWriter: %v", err) + } + + // DeepEqual doesn't compare functions. + w.d.fill, wref.d.fill = nil, nil + w.d.step, wref.d.step = nil, nil + if !reflect.DeepEqual(w, wref) { + t.Errorf("level %d Writer not reset after Reset", level) + } + } + testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, NoCompression) }) + testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, DefaultCompression) }) + testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, BestCompression) }) + dict := []byte("we are the world") + testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, NoCompression, dict) }) + testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, DefaultCompression, dict) }) + testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, BestCompression, dict) }) +} + +func testResetOutput(t *testing.T, newWriter func(w io.Writer) (*Writer, error)) { + buf := new(bytes.Buffer) + w, err := newWriter(buf) + if err != nil { + t.Fatalf("NewWriter: %v", err) + } + b := []byte("hello world") + for i := 0; i < 1024; i++ { + w.Write(b) + } + w.Close() + out1 := buf.String() + + buf2 := new(bytes.Buffer) + w.Reset(buf2) + for i := 0; i < 1024; i++ { + w.Write(b) + } + w.Close() + out2 := buf2.String() + + if out1 != out2 { + t.Errorf("got %q, expected %q", out2, out1) + } + t.Logf("got %d bytes", len(out1)) +} diff --git a/src/pkg/compress/flate/huffman_bit_writer.go b/src/pkg/compress/flate/huffman_bit_writer.go index 25e1da336a..b182a710b9 100644 --- a/src/pkg/compress/flate/huffman_bit_writer.go +++ b/src/pkg/compress/flate/huffman_bit_writer.go @@ -97,6 +97,31 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter { } } +func (w *huffmanBitWriter) reset(writer io.Writer) { + w.w = writer + w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil + w.bytes = [64]byte{} + for i := range w.codegen { + w.codegen[i] = 0 + } + for _, s := range [...][]int32{w.literalFreq, w.offsetFreq, w.codegenFreq} { + for i := range s { + s[i] = 0 + } + } + for _, enc := range [...]*huffmanEncoder{ + w.literalEncoding, + w.offsetEncoding, + w.codegenEncoding} { + for i := range enc.code { + enc.code[i] = 0 + } + for i := range enc.codeBits { + enc.codeBits[i] = 0 + } + } +} + func (w *huffmanBitWriter) flushBits() { if w.err != nil { w.nbits = 0