]> Cypherpunks repositories - gostls13.git/commitdiff
compress/flate: implement Flush
authorRuss Cox <rsc@golang.org>
Tue, 14 Dec 2010 19:57:12 +0000 (14:57 -0500)
committerRuss Cox <rsc@golang.org>
Tue, 14 Dec 2010 19:57:12 +0000 (14:57 -0500)
This Flush is equivalent to zlib's Z_SYNC_FLUSH.
The addition of the explicit Writer type opens the
door to adding a PartialFlush if needed for SSH
and maybe even FullFlush.  It also opens the door
for a SetDictionary method to be added.

http://www.bolet.org/~pornin/deflate-flush.html
documents the various intricacies of flushing a
DEFLATE stream.

R=agl1, r
CC=golang-dev
https://golang.org/cl/3637041

src/pkg/compress/flate/deflate.go
src/pkg/compress/flate/deflate_test.go
src/pkg/compress/flate/inflate.go

index 509c8debd1e9605702f670bfd9bfb1dd15c672e4..591b35c4463e44421d547a7e1fdb2859e4c88790 100644 (file)
@@ -89,6 +89,10 @@ type compressor struct {
        // (1 << logWindowSize) - 1.
        windowMask int
 
+       eof      bool // has eof been reached on input?
+       sync     bool // writer wants to flush
+       syncChan chan os.Error
+
        // hashHead[hashValue] contains the largest inputIndex with the specified hash value
        hashHead []int
 
@@ -124,6 +128,9 @@ func (d *compressor) flush() os.Error {
 }
 
 func (d *compressor) fillWindow(index int) (int, os.Error) {
+       if d.sync {
+               return index, nil
+       }
        wSize := d.windowMask + 1
        if index >= wSize+wSize-(minMatchLength+maxMatchLength) {
                // shift the window by wSize
@@ -142,12 +149,14 @@ func (d *compressor) fillWindow(index int) (int, os.Error) {
                        d.hashPrev[i] = max(h-wSize, -1)
                }
        }
-       var count int
-       var err os.Error
-       count, err = io.ReadAtLeast(d.r, d.window[d.windowEnd:], 1)
+       count, err := d.r.Read(d.window[d.windowEnd:])
        d.windowEnd += count
+       if count == 0 && err == nil {
+               d.sync = true
+       }
        if err == os.EOF {
-               return index, nil
+               d.eof = true
+               err = nil
        }
        return index, err
 }
@@ -227,10 +236,17 @@ func (d *compressor) storedDeflate() os.Error {
        buf := make([]byte, maxStoreBlockSize)
        for {
                n, err := d.r.Read(buf)
-               if n > 0 {
+               if n == 0 && err == nil {
+                       d.sync = true
+               }
+               if n > 0 || d.sync {
                        if err := d.writeStoredBlock(buf[0:n]); err != nil {
                                return err
                        }
+                       if d.sync {
+                               d.syncChan <- nil
+                               d.sync = false
+                       }
                }
                if err != nil {
                        if err == os.EOF {
@@ -275,6 +291,7 @@ func (d *compressor) doDeflate() (err os.Error) {
                hash = int(d.window[index])<<hashShift + int(d.window[index+1])
        }
        chainHead := -1
+Loop:
        for {
                if index > windowEnd {
                        panic("index > windowEnd")
@@ -291,7 +308,31 @@ func (d *compressor) doDeflate() (err os.Error) {
                        maxInsertIndex = windowEnd - (minMatchLength - 1)
                        lookahead = windowEnd - index
                        if lookahead == 0 {
-                               break
+                               // Flush current output block if any.
+                               if byteAvailable {
+                                       // There is still one pending token that needs to be flushed
+                                       tokens[ti] = literalToken(uint32(d.window[index-1]) & 0xFF)
+                                       ti++
+                                       byteAvailable = false
+                               }
+                               if ti > 0 {
+                                       if err = d.writeBlock(tokens[0:ti], index, false); err != nil {
+                                               return
+                                       }
+                                       ti = 0
+                               }
+                               if d.sync {
+                                       d.w.writeStoredHeader(0, false)
+                                       d.w.flush()
+                                       d.syncChan <- d.w.err
+                                       d.sync = false
+                               }
+
+                               // If this was only a sync (not at EOF) keep going.
+                               if !d.eof {
+                                       continue
+                               }
+                               break Loop
                        }
                }
                if index < maxInsertIndex {
@@ -383,23 +424,11 @@ func (d *compressor) doDeflate() (err os.Error) {
                                byteAvailable = true
                        }
                }
-
-       }
-       if byteAvailable {
-               // There is still one pending token that needs to be flushed
-               tokens[ti] = literalToken(uint32(d.window[index-1]) & 0xFF)
-               ti++
-       }
-
-       if ti > 0 {
-               if err = d.writeBlock(tokens[0:ti], index, false); err != nil {
-                       return
-               }
        }
        return
 }
 
-func (d *compressor) compressor(r io.Reader, w io.Writer, level int, logWindowSize uint) (err os.Error) {
+func (d *compressor) compress(r io.Reader, w io.Writer, level int, logWindowSize uint) (err os.Error) {
        d.r = r
        d.w = newHuffmanBitWriter(w)
        d.level = level
@@ -417,6 +446,10 @@ func (d *compressor) compressor(r io.Reader, w io.Writer, level int, logWindowSi
                return WrongValueError{"level", 0, 9, int32(level)}
        }
 
+       if d.sync {
+               d.syncChan <- err
+               d.sync = false
+       }
        if err != nil {
                return err
        }
@@ -426,16 +459,63 @@ func (d *compressor) compressor(r io.Reader, w io.Writer, level int, logWindowSi
        return d.flush()
 }
 
-func newCompressor(w io.Writer, level int, logWindowSize uint) io.WriteCloser {
+// NewWriter returns a new Writer compressing
+// data at the given level.  Following zlib, levels
+// range from 1 (BestSpeed) to 9 (BestCompression);
+// higher levels typically run slower but compress more.
+// Level 0 (NoCompression) does not attempt any
+// compression; it only adds the necessary DEFLATE framing.
+func NewWriter(w io.Writer, level int) *Writer {
+       const logWindowSize = logMaxOffsetSize
        var d compressor
+       d.syncChan = make(chan os.Error, 1)
        pr, pw := syncPipe()
        go func() {
-               err := d.compressor(pr, w, level, logWindowSize)
+               err := d.compress(pr, w, level, logWindowSize)
                pr.CloseWithError(err)
        }()
-       return pw
+       return &Writer{pw, &d}
+}
+
+// A Writer takes data written to it and writes the compressed
+// form of that data to an underlying writer (see NewWriter).
+type Writer struct {
+       w *syncPipeWriter
+       d *compressor
+}
+
+// Write writes data to w, which will eventually write the
+// compressed form of data to its underlying writer.
+func (w *Writer) Write(data []byte) (n int, err os.Error) {
+       if len(data) == 0 {
+               // no point, and nil interferes with sync
+               return
+       }
+       return w.w.Write(data)
+}
+
+// Flush flushes any pending compressed data to the underlying writer.
+// It is useful mainly in compressed network protocols, to ensure that
+// a remote reader has enough data to reconstruct a packet.
+// Flush does not return until the data has been written.
+// If the underlying writer returns an error, Flush returns that error.
+//
+// In the terminology of the zlib library, Flush is equivalent to Z_SYNC_FLUSH.
+func (w *Writer) Flush() os.Error {
+       // For more about flushing:
+       // http://www.bolet.org/~pornin/deflate-flush.html
+       if w.d.sync {
+               panic("compress/flate: double Flush")
+       }
+       _, err := w.w.Write(nil)
+       err1 := <-w.d.syncChan
+       if err == nil {
+               err = err1
+       }
+       return err
 }
 
-func NewWriter(w io.Writer, level int) io.WriteCloser {
-       return newCompressor(w, level, logMaxOffsetSize)
+// Close flushes and closes the writer.
+func (w *Writer) Close() os.Error {
+       return w.w.Close()
 }
index 9718d2f5abdb78b2150de96d4fabbec3e7e4642d..3db955609d7eb85c1b9ea2904b74ba0a92f5407e 100644 (file)
@@ -7,8 +7,10 @@ package flate
 import (
        "bytes"
        "fmt"
+       "io"
        "io/ioutil"
        "os"
+       "sync"
        "testing"
 )
 
@@ -79,7 +81,7 @@ func getLargeDataChunk() []byte {
 
 func TestDeflate(t *testing.T) {
        for _, h := range deflateTests {
-               buffer := bytes.NewBuffer([]byte{})
+               buffer := bytes.NewBuffer(nil)
                w := NewWriter(buffer, h.level)
                w.Write(h.in)
                w.Close()
@@ -90,21 +92,144 @@ func TestDeflate(t *testing.T) {
        }
 }
 
+type syncBuffer struct {
+       buf    bytes.Buffer
+       mu     sync.RWMutex
+       closed bool
+       ready  chan bool
+}
+
+func newSyncBuffer() *syncBuffer {
+       return &syncBuffer{ready: make(chan bool, 1)}
+}
+
+func (b *syncBuffer) Read(p []byte) (n int, err os.Error) {
+       for {
+               b.mu.RLock()
+               n, err = b.buf.Read(p)
+               b.mu.RUnlock()
+               if n > 0 || b.closed {
+                       return
+               }
+               <-b.ready
+       }
+       panic("unreachable")
+}
+
+func (b *syncBuffer) Write(p []byte) (n int, err os.Error) {
+       n, err = b.buf.Write(p)
+       _ = b.ready <- true
+       return
+}
+
+func (b *syncBuffer) WriteMode() {
+       b.mu.Lock()
+}
+
+func (b *syncBuffer) ReadMode() {
+       b.mu.Unlock()
+       _ = b.ready <- true
+}
+
+func (b *syncBuffer) Close() os.Error {
+       b.closed = true
+       _ = b.ready <- true
+       return nil
+}
+
+func testSync(t *testing.T, level int, input []byte, name string) {
+       if len(input) == 0 {
+               return
+       }
+
+       t.Logf("--testSync %d, %d, %s", level, len(input), name)
+       buf := newSyncBuffer()
+       buf1 := new(bytes.Buffer)
+       buf.WriteMode()
+       w := NewWriter(io.MultiWriter(buf, buf1), level)
+       r := NewReader(buf)
+
+       // Write half the input and read back.
+       for i := 0; i < 2; i++ {
+               var lo, hi int
+               if i == 0 {
+                       lo, hi = 0, (len(input)+1)/2
+               } else {
+                       lo, hi = (len(input)+1)/2, len(input)
+               }
+               t.Logf("#%d: write %d-%d", i, lo, hi)
+               if _, err := w.Write(input[lo:hi]); err != nil {
+                       t.Errorf("testSync: write: %v", err)
+                       return
+               }
+               if i == 0 {
+                       if err := w.Flush(); err != nil {
+                               t.Errorf("testSync: flush: %v", err)
+                               return
+                       }
+               } else {
+                       if err := w.Close(); err != nil {
+                               t.Errorf("testSync: close: %v", err)
+                       }
+               }
+               buf.ReadMode()
+               out := make([]byte, hi-lo+1)
+               m, err := io.ReadAtLeast(r, out, hi-lo)
+               t.Logf("#%d: read %d", i, m)
+               if m != hi-lo || err != nil {
+                       t.Errorf("testSync/%d (%d, %d, %s): read %d: %d, %v (%d left)", i, level, len(input), name, hi-lo, m, err, buf.buf.Len())
+                       return
+               }
+               if !bytes.Equal(input[lo:hi], out[:hi-lo]) {
+                       t.Errorf("testSync/%d: read wrong bytes: %x vs %x", i, input[lo:hi], out[:hi-lo])
+                       return
+               }
+               if i == 0 && buf.buf.Len() != 0 {
+                       t.Errorf("testSync/%d (%d, %d, %s): extra data after %d", i, level, len(input), name, hi-lo)
+               }
+               buf.WriteMode()
+       }
+       buf.ReadMode()
+       out := make([]byte, 10)
+       if n, err := r.Read(out); n > 0 || err != os.EOF {
+               t.Errorf("testSync (%d, %d, %s): final Read: %d, %v (hex: %x)", level, len(input), name, n, err, out[0:n])
+       }
+       if buf.buf.Len() != 0 {
+               t.Errorf("testSync (%d, %d, %s): extra data at end", level, len(input), name)
+       }
+       r.Close()
+
+       // stream should work for ordinary reader too
+       r = NewReader(buf1)
+       out, err := ioutil.ReadAll(r)
+       if err != nil {
+               t.Errorf("testSync: read: %s", err)
+               return
+       }
+       r.Close()
+       if !bytes.Equal(input, out) {
+               t.Errorf("testSync: decompress(compress(data)) != data: level=%d input=%s", level, name)
+       }
+}
+
+
 func testToFromWithLevel(t *testing.T, level int, input []byte, name string) os.Error {
-       buffer := bytes.NewBuffer([]byte{})
+       buffer := bytes.NewBuffer(nil)
        w := NewWriter(buffer, level)
        w.Write(input)
        w.Close()
-       decompressor := NewReader(buffer)
-       decompressed, err := ioutil.ReadAll(decompressor)
+       r := NewReader(buffer)
+       out, err := ioutil.ReadAll(r)
        if err != nil {
-               t.Errorf("reading decompressor: %s", err)
+               t.Errorf("read: %s", err)
                return err
        }
-       decompressor.Close()
-       if bytes.Compare(input, decompressed) != 0 {
+       r.Close()
+       if !bytes.Equal(input, out) {
                t.Errorf("decompress(compress(data)) != data: level=%d input=%s", level, name)
        }
+
+       testSync(t, level, input, name)
        return nil
 }
 
index 5e2146320e887147ff2c3c5703bede3bb5f64fbb..7dc8cf93bd9f9dc9310c78cfde6d809d87fc660f 100644 (file)
@@ -217,6 +217,7 @@ type decompressor struct {
        // Output history, buffer.
        hist  [maxHist]byte
        hp    int  // current output position in buffer
+       hw    int  // have written hist[0:hw] already
        hfull bool // buffer has filled at least once
 
        // Temporary buffer (avoids repeated allocation).
@@ -497,6 +498,11 @@ func (f *decompressor) dataBlock() os.Error {
                return CorruptInputError(f.roffset)
        }
 
+       if n == 0 {
+               // 0-length block means sync
+               return f.flush()
+       }
+
        // Read len bytes into history,
        // writing as history fills.
        for n > 0 {
@@ -560,19 +566,23 @@ func (f *decompressor) huffSym(h *huffmanDecoder) (int, os.Error) {
 
 // Flush any buffered output to the underlying writer.
 func (f *decompressor) flush() os.Error {
-       if f.hp == 0 {
+       if f.hw == f.hp {
                return nil
        }
-       n, err := f.w.Write(f.hist[0:f.hp])
-       if n != f.hp && err == nil {
+       n, err := f.w.Write(f.hist[f.hw:f.hp])
+       if n != f.hp-f.hw && err == nil {
                err = io.ErrShortWrite
        }
        if err != nil {
                return &WriteError{f.woffset, err}
        }
-       f.woffset += int64(f.hp)
-       f.hp = 0
-       f.hfull = true
+       f.woffset += int64(f.hp - f.hw)
+       f.hw = f.hp
+       if f.hp == len(f.hist) {
+               f.hp = 0
+               f.hw = 0
+               f.hfull = true
+       }
        return nil
 }
 
@@ -583,9 +593,9 @@ func makeReader(r io.Reader) Reader {
        return bufio.NewReader(r)
 }
 
-// Inflate reads DEFLATE-compressed data from r and writes
+// decompress reads DEFLATE-compressed data from r and writes
 // the uncompressed data to w.
-func (f *decompressor) decompressor(r io.Reader, w io.Writer) os.Error {
+func (f *decompressor) decompress(r io.Reader, w io.Writer) os.Error {
        f.r = makeReader(r)
        f.w = w
        f.woffset = 0
@@ -605,6 +615,6 @@ func (f *decompressor) decompressor(r io.Reader, w io.Writer) os.Error {
 func NewReader(r io.Reader) io.ReadCloser {
        var f decompressor
        pr, pw := io.Pipe()
-       go func() { pw.CloseWithError(f.decompressor(r, pw)) }()
+       go func() { pw.CloseWithError(f.decompress(r, pw)) }()
        return pr
 }