]> Cypherpunks repositories - gostls13.git/commitdiff
compress/flate: do not use background goroutines
authorRuss Cox <rsc@golang.org>
Thu, 2 Jun 2011 13:32:38 +0000 (09:32 -0400)
committerRuss Cox <rsc@golang.org>
Thu, 2 Jun 2011 13:32:38 +0000 (09:32 -0400)
Programs expect that Read and Write are synchronous.
The background goroutines make the implementation
a little easier, but they introduce asynchrony that
trips up calling code.  Remove them.

R=golang-dev, krasin
CC=golang-dev
https://golang.org/cl/4548079

src/pkg/compress/flate/deflate.go
src/pkg/compress/flate/deflate_test.go
src/pkg/compress/flate/inflate.go

index a02a5e8d94bcac794234dbf0b18782440284a8c0..b1cee0b2f0f919aad91690c43bc2b50e844f8816 100644 (file)
@@ -11,16 +11,18 @@ import (
 )
 
 const (
-       NoCompression        = 0
-       BestSpeed            = 1
-       fastCompression      = 3
-       BestCompression      = 9
-       DefaultCompression   = -1
-       logMaxOffsetSize     = 15  // Standard DEFLATE
-       wideLogMaxOffsetSize = 22  // Wide DEFLATE
-       minMatchLength       = 3   // The smallest match that the compressor looks for
-       maxMatchLength       = 258 // The longest match for the compressor
-       minOffsetSize        = 1   // The shortest offset that makes any sence
+       NoCompression      = 0
+       BestSpeed          = 1
+       fastCompression    = 3
+       BestCompression    = 9
+       DefaultCompression = -1
+       logWindowSize      = 15
+       windowSize         = 1 << logWindowSize
+       windowMask         = windowSize - 1
+       logMaxOffsetSize   = 15  // Standard DEFLATE
+       minMatchLength     = 3   // The smallest match that the compressor looks for
+       maxMatchLength     = 258 // The longest match for the compressor
+       minOffsetSize      = 1   // The shortest offset that makes any sence
 
        // The maximum number of tokens we put into a single flat block, just too
        // stop things from getting too large.
@@ -32,22 +34,6 @@ const (
        hashShift           = (hashBits + minMatchLength - 1) / minMatchLength
 )
 
-type syncPipeReader struct {
-       *io.PipeReader
-       closeChan chan bool
-}
-
-func (sr *syncPipeReader) CloseWithError(err os.Error) os.Error {
-       retErr := sr.PipeReader.CloseWithError(err)
-       sr.closeChan <- true // finish writer close
-       return retErr
-}
-
-type syncPipeWriter struct {
-       *io.PipeWriter
-       closeChan chan bool
-}
-
 type compressionLevel struct {
        good, lazy, nice, chain, fastSkipHashing int
 }
@@ -68,105 +54,73 @@ var levels = []compressionLevel{
        {32, 258, 258, 4096, math.MaxInt32},
 }
 
-func (sw *syncPipeWriter) Close() os.Error {
-       err := sw.PipeWriter.Close()
-       <-sw.closeChan // wait for reader close
-       return err
-}
-
-func syncPipe() (*syncPipeReader, *syncPipeWriter) {
-       r, w := io.Pipe()
-       sr := &syncPipeReader{r, make(chan bool, 1)}
-       sw := &syncPipeWriter{w, sr.closeChan}
-       return sr, sw
-}
-
 type compressor struct {
-       level         int
-       logWindowSize uint
-       w             *huffmanBitWriter
-       r             io.Reader
-       // (1 << logWindowSize) - 1.
-       windowMask int
+       compressionLevel
 
-       eof      bool // has eof been reached on input?
-       sync     bool // writer wants to flush
-       syncChan chan os.Error
+       w *huffmanBitWriter
 
-       // hashHead[hashValue] contains the largest inputIndex with the specified hash value
-       hashHead []int
+       // compression algorithm
+       fill func(*compressor, []byte) int // copy data to window
+       step func(*compressor)             // process window
+       sync bool                          // requesting flush
 
+       // Input hash chains
+       // hashHead[hashValue] contains the largest inputIndex with the specified hash value
        // If hashHead[hashValue] is within the current window, then
        // hashPrev[hashHead[hashValue] & windowMask] contains the previous index
        // with the same hash value.
-       hashPrev []int
-
-       // If we find a match of length >= niceMatch, then we don't bother searching
-       // any further.
-       niceMatch int
-
-       // If we find a match of length >= goodMatch, we only do a half-hearted
-       // effort at doing lazy matching starting at the next character
-       goodMatch int
-
-       // The maximum number of chains we look at when finding a match
-       maxChainLength int
-
-       // The sliding window we use for matching
-       window []byte
-
-       // The index just past the last valid character
-       windowEnd int
-
-       // index in "window" at which current block starts
-       blockStart int
-}
-
-func (d *compressor) flush() os.Error {
-       d.w.flush()
-       return d.w.err
+       chainHead int
+       hashHead  []int
+       hashPrev  []int
+
+       // input window: unprocessed data is window[index:windowEnd]
+       index         int
+       window        []byte
+       windowEnd     int
+       blockStart    int  // window index where current tokens start
+       byteAvailable bool // if true, still need to process window[index-1].
+
+       // queued output tokens: tokens[:ti]
+       tokens []token
+       ti     int
+
+       // deflate state
+       length         int
+       offset         int
+       hash           int
+       maxInsertIndex int
+       err            os.Error
 }
 
-func (d *compressor) fillWindow(index int) (int, os.Error) {
-       if d.sync {
-               return index, nil
-       }
-       wSize := d.windowMask + 1
-       if index >= wSize+wSize-(minMatchLength+maxMatchLength) {
-               // shift the window by wSize
-               copy(d.window, d.window[wSize:2*wSize])
-               index -= wSize
-               d.windowEnd -= wSize
-               if d.blockStart >= wSize {
-                       d.blockStart -= wSize
+func (d *compressor) fillDeflate(b []byte) int {
+       if d.index >= 2*windowSize-(minMatchLength+maxMatchLength) {
+               // shift the window by windowSize
+               copy(d.window, d.window[windowSize:2*windowSize])
+               d.index -= windowSize
+               d.windowEnd -= windowSize
+               if d.blockStart >= windowSize {
+                       d.blockStart -= windowSize
                } else {
                        d.blockStart = math.MaxInt32
                }
                for i, h := range d.hashHead {
-                       v := h - wSize
+                       v := h - windowSize
                        if v < -1 {
                                v = -1
                        }
                        d.hashHead[i] = v
                }
                for i, h := range d.hashPrev {
-                       v := -h - wSize
+                       v := -h - windowSize
                        if v < -1 {
                                v = -1
                        }
                        d.hashPrev[i] = v
                }
        }
-       count, err := d.r.Read(d.window[d.windowEnd:])
-       d.windowEnd += count
-       if count == 0 && err == nil {
-               d.sync = true
-       }
-       if err == os.EOF {
-               d.eof = true
-               err = nil
-       }
-       return index, err
+       n := copy(d.window[d.windowEnd:], b)
+       d.windowEnd += n
+       return n
 }
 
 func (d *compressor) writeBlock(tokens []token, index int, eof bool) os.Error {
@@ -194,21 +148,21 @@ func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead
 
        // We quit when we get a match that's at least nice long
        nice := len(win) - pos
-       if d.niceMatch < nice {
-               nice = d.niceMatch
+       if d.nice < nice {
+               nice = d.nice
        }
 
        // If we've got a match that's good enough, only look in 1/4 the chain.
-       tries := d.maxChainLength
+       tries := d.chain
        length = prevLength
-       if length >= d.goodMatch {
+       if length >= d.good {
                tries >>= 2
        }
 
        w0 := win[pos]
        w1 := win[pos+1]
        wEnd := win[pos+length]
-       minIndex := pos - (d.windowMask + 1)
+       minIndex := pos - windowSize
 
        for i := prevHead; tries > 0; tries-- {
                if w0 == win[i] && w1 == win[i+1] && wEnd == win[i+length] {
@@ -233,7 +187,7 @@ func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead
                        // hashPrev[i & windowMask] has already been overwritten, so stop now.
                        break
                }
-               if i = d.hashPrev[i&d.windowMask]; i < minIndex || i < 0 {
+               if i = d.hashPrev[i&windowMask]; i < minIndex || i < 0 {
                        break
                }
        }
@@ -248,234 +202,224 @@ func (d *compressor) writeStoredBlock(buf []byte) os.Error {
        return d.w.err
 }
 
-func (d *compressor) storedDeflate() os.Error {
-       buf := make([]byte, maxStoreBlockSize)
-       for {
-               n, err := d.r.Read(buf)
-               if n == 0 && err == nil {
-                       d.sync = true
-               }
-               if n > 0 || d.sync {
-                       if err := d.writeStoredBlock(buf[0:n]); err != nil {
-                               return err
-                       }
-                       if d.sync {
-                               d.syncChan <- nil
-                               d.sync = false
-                       }
-               }
-               if err != nil {
-                       if err == os.EOF {
-                               break
-                       }
-                       return err
-               }
-       }
-       return nil
-}
-
-func (d *compressor) doDeflate() (err os.Error) {
-       // init
-       d.windowMask = 1<<d.logWindowSize - 1
+func (d *compressor) initDeflate() {
        d.hashHead = make([]int, hashSize)
-       d.hashPrev = make([]int, 1<<d.logWindowSize)
-       d.window = make([]byte, 2<<d.logWindowSize)
+       d.hashPrev = make([]int, windowSize)
+       d.window = make([]byte, 2*windowSize)
        fillInts(d.hashHead, -1)
-       tokens := make([]token, maxFlateBlockTokens, maxFlateBlockTokens+1)
-       l := levels[d.level]
-       d.goodMatch = l.good
-       d.niceMatch = l.nice
-       d.maxChainLength = l.chain
-       lazyMatch := l.lazy
-       length := minMatchLength - 1
-       offset := 0
-       byteAvailable := false
-       isFastDeflate := l.fastSkipHashing != 0
-       index := 0
-       // run
-       if index, err = d.fillWindow(index); err != nil {
+       d.tokens = make([]token, maxFlateBlockTokens, maxFlateBlockTokens+1)
+       d.length = minMatchLength - 1
+       d.offset = 0
+       d.byteAvailable = false
+       d.index = 0
+       d.ti = 0
+       d.hash = 0
+       d.chainHead = -1
+}
+
+func (d *compressor) deflate() {
+       if d.windowEnd-d.index < minMatchLength+maxMatchLength && !d.sync {
                return
        }
-       maxOffset := d.windowMask + 1 // (1 << logWindowSize);
-       // only need to change when you refill the window
-       windowEnd := d.windowEnd
-       maxInsertIndex := windowEnd - (minMatchLength - 1)
-       ti := 0
-
-       hash := int(0)
-       if index < maxInsertIndex {
-               hash = int(d.window[index])<<hashShift + int(d.window[index+1])
+
+       d.maxInsertIndex = d.windowEnd - (minMatchLength - 1)
+       if d.index < d.maxInsertIndex {
+               d.hash = int(d.window[d.index])<<hashShift + int(d.window[d.index+1])
        }
-       chainHead := -1
+
 Loop:
        for {
-               if index > windowEnd {
+               if d.index > d.windowEnd {
                        panic("index > windowEnd")
                }
-               lookahead := windowEnd - index
+               lookahead := d.windowEnd - d.index
                if lookahead < minMatchLength+maxMatchLength {
-                       if index, err = d.fillWindow(index); err != nil {
-                               return
+                       if !d.sync {
+                               break Loop
                        }
-                       windowEnd = d.windowEnd
-                       if index > windowEnd {
+                       if d.index > d.windowEnd {
                                panic("index > windowEnd")
                        }
-                       maxInsertIndex = windowEnd - (minMatchLength - 1)
-                       lookahead = windowEnd - index
                        if lookahead == 0 {
                                // Flush current output block if any.
-                               if byteAvailable {
+                               if d.byteAvailable {
                                        // There is still one pending token that needs to be flushed
-                                       tokens[ti] = literalToken(uint32(d.window[index-1]) & 0xFF)
-                                       ti++
-                                       byteAvailable = false
+                                       d.tokens[d.ti] = literalToken(uint32(d.window[d.index-1]))
+                                       d.ti++
+                                       d.byteAvailable = false
                                }
-                               if ti > 0 {
-                                       if err = d.writeBlock(tokens[0:ti], index, false); err != nil {
+                               if d.ti > 0 {
+                                       if d.err = d.writeBlock(d.tokens[0:d.ti], d.index, false); d.err != nil {
                                                return
                                        }
-                                       ti = 0
-                               }
-                               if d.sync {
-                                       d.w.writeStoredHeader(0, false)
-                                       d.w.flush()
-                                       d.syncChan <- d.w.err
-                                       d.sync = false
-                               }
-
-                               // If this was only a sync (not at EOF) keep going.
-                               if !d.eof {
-                                       continue
+                                       d.ti = 0
                                }
                                break Loop
                        }
                }
-               if index < maxInsertIndex {
+               if d.index < d.maxInsertIndex {
                        // Update the hash
-                       hash = (hash<<hashShift + int(d.window[index+2])) & hashMask
-                       chainHead = d.hashHead[hash]
-                       d.hashPrev[index&d.windowMask] = chainHead
-                       d.hashHead[hash] = index
+                       d.hash = (d.hash<<hashShift + int(d.window[d.index+2])) & hashMask
+                       d.chainHead = d.hashHead[d.hash]
+                       d.hashPrev[d.index&windowMask] = d.chainHead
+                       d.hashHead[d.hash] = d.index
                }
-               prevLength := length
-               prevOffset := offset
-               length = minMatchLength - 1
-               offset = 0
-               minIndex := index - maxOffset
+               prevLength := d.length
+               prevOffset := d.offset
+               d.length = minMatchLength - 1
+               d.offset = 0
+               minIndex := d.index - windowSize
                if minIndex < 0 {
                        minIndex = 0
                }
 
-               if chainHead >= minIndex &&
-                       (isFastDeflate && lookahead > minMatchLength-1 ||
-                               !isFastDeflate && lookahead > prevLength && prevLength < lazyMatch) {
-                       if newLength, newOffset, ok := d.findMatch(index, chainHead, minMatchLength-1, lookahead); ok {
-                               length = newLength
-                               offset = newOffset
+               if d.chainHead >= minIndex &&
+                       (d.fastSkipHashing != 0 && lookahead > minMatchLength-1 ||
+                               d.fastSkipHashing == 0 && lookahead > prevLength && prevLength < d.lazy) {
+                       if newLength, newOffset, ok := d.findMatch(d.index, d.chainHead, minMatchLength-1, lookahead); ok {
+                               d.length = newLength
+                               d.offset = newOffset
                        }
                }
-               if isFastDeflate && length >= minMatchLength ||
-                       !isFastDeflate && prevLength >= minMatchLength && length <= prevLength {
+               if d.fastSkipHashing != 0 && d.length >= minMatchLength ||
+                       d.fastSkipHashing == 0 && prevLength >= minMatchLength && d.length <= prevLength {
                        // There was a match at the previous step, and the current match is
                        // not better. Output the previous match.
-                       if isFastDeflate {
-                               tokens[ti] = matchToken(uint32(length-minMatchLength), uint32(offset-minOffsetSize))
+                       if d.fastSkipHashing != 0 {
+                               d.tokens[d.ti] = matchToken(uint32(d.length-minMatchLength), uint32(d.offset-minOffsetSize))
                        } else {
-                               tokens[ti] = matchToken(uint32(prevLength-minMatchLength), uint32(prevOffset-minOffsetSize))
+                               d.tokens[d.ti] = matchToken(uint32(prevLength-minMatchLength), uint32(prevOffset-minOffsetSize))
                        }
-                       ti++
+                       d.ti++
                        // Insert in the hash table all strings up to the end of the match.
                        // index and index-1 are already inserted. If there is not enough
                        // lookahead, the last two strings are not inserted into the hash
                        // table.
-                       if length <= l.fastSkipHashing {
+                       if d.length <= d.fastSkipHashing {
                                var newIndex int
-                               if isFastDeflate {
-                                       newIndex = index + length
+                               if d.fastSkipHashing != 0 {
+                                       newIndex = d.index + d.length
                                } else {
                                        newIndex = prevLength - 1
                                }
-                               for index++; index < newIndex; index++ {
-                                       if index < maxInsertIndex {
-                                               hash = (hash<<hashShift + int(d.window[index+2])) & hashMask
+                               for d.index++; d.index < newIndex; d.index++ {
+                                       if d.index < d.maxInsertIndex {
+                                               d.hash = (d.hash<<hashShift + int(d.window[d.index+2])) & hashMask
                                                // Get previous value with the same hash.
                                                // Our chain should point to the previous value.
-                                               d.hashPrev[index&d.windowMask] = d.hashHead[hash]
+                                               d.hashPrev[d.index&windowMask] = d.hashHead[d.hash]
                                                // Set the head of the hash chain to us.
-                                               d.hashHead[hash] = index
+                                               d.hashHead[d.hash] = d.index
                                        }
                                }
-                               if !isFastDeflate {
-                                       byteAvailable = false
-                                       length = minMatchLength - 1
+                               if d.fastSkipHashing == 0 {
+                                       d.byteAvailable = false
+                                       d.length = minMatchLength - 1
                                }
                        } else {
                                // For matches this long, we don't bother inserting each individual
                                // item into the table.
-                               index += length
-                               hash = (int(d.window[index])<<hashShift + int(d.window[index+1]))
+                               d.index += d.length
+                               d.hash = (int(d.window[d.index])<<hashShift + int(d.window[d.index+1]))
                        }
-                       if ti == maxFlateBlockTokens {
+                       if d.ti == maxFlateBlockTokens {
                                // The block includes the current character
-                               if err = d.writeBlock(tokens, index, false); err != nil {
+                               if d.err = d.writeBlock(d.tokens, d.index, false); d.err != nil {
                                        return
                                }
-                               ti = 0
+                               d.ti = 0
                        }
                } else {
-                       if isFastDeflate || byteAvailable {
-                               i := index - 1
-                               if isFastDeflate {
-                                       i = index
+                       if d.fastSkipHashing != 0 || d.byteAvailable {
+                               i := d.index - 1
+                               if d.fastSkipHashing != 0 {
+                                       i = d.index
                                }
-                               tokens[ti] = literalToken(uint32(d.window[i]) & 0xFF)
-                               ti++
-                               if ti == maxFlateBlockTokens {
-                                       if err = d.writeBlock(tokens, i+1, false); err != nil {
+                               d.tokens[d.ti] = literalToken(uint32(d.window[i]))
+                               d.ti++
+                               if d.ti == maxFlateBlockTokens {
+                                       if d.err = d.writeBlock(d.tokens, i+1, false); d.err != nil {
                                                return
                                        }
-                                       ti = 0
+                                       d.ti = 0
                                }
                        }
-                       index++
-                       if !isFastDeflate {
-                               byteAvailable = true
+                       d.index++
+                       if d.fastSkipHashing == 0 {
+                               d.byteAvailable = true
                        }
                }
        }
-       return
 }
 
-func (d *compressor) compress(r io.Reader, w io.Writer, level int, logWindowSize uint) (err os.Error) {
-       d.r = r
+func (d *compressor) fillStore(b []byte) int {
+       n := copy(d.window[d.windowEnd:], b)
+       d.windowEnd += n
+       return n
+}
+
+func (d *compressor) store() {
+       if d.windowEnd > 0 {
+               d.err = d.writeStoredBlock(d.window[:d.windowEnd])
+       }
+       d.windowEnd = 0
+}
+
+func (d *compressor) write(b []byte) (n int, err os.Error) {
+       n = len(b)
+       b = b[d.fill(d, b):]
+       for len(b) > 0 {
+               d.step(d)
+               b = b[d.fill(d, b):]
+       }
+       return n, d.err
+}
+
+func (d *compressor) syncFlush() os.Error {
+       d.sync = true
+       d.step(d)
+       if d.err == nil {
+               d.w.writeStoredHeader(0, false)
+               d.w.flush()
+               d.err = d.w.err
+       }
+       d.sync = false
+       return d.err
+}
+
+func (d *compressor) init(w io.Writer, level int) (err os.Error) {
        d.w = newHuffmanBitWriter(w)
-       d.level = level
-       d.logWindowSize = logWindowSize
 
        switch {
        case level == NoCompression:
-               err = d.storedDeflate()
+               d.window = make([]byte, maxStoreBlockSize)
+               d.fill = (*compressor).fillStore
+               d.step = (*compressor).store
        case level == DefaultCompression:
-               d.level = 6
+               level = 6
                fallthrough
        case 1 <= level && level <= 9:
-               err = d.doDeflate()
+               d.compressionLevel = levels[level]
+               d.initDeflate()
+               d.fill = (*compressor).fillDeflate
+               d.step = (*compressor).deflate
        default:
                return WrongValueError{"level", 0, 9, int32(level)}
        }
+       return nil
+}
 
-       if d.sync {
-               d.syncChan <- err
-               d.sync = false
-       }
-       if err != nil {
-               return err
+func (d *compressor) close() os.Error {
+       d.sync = true
+       d.step(d)
+       if d.err != nil {
+               return d.err
        }
        if d.w.writeStoredHeader(0, true); d.w.err != nil {
                return d.w.err
        }
-       return d.flush()
+       d.w.flush()
+       return d.w.err
 }
 
 // NewWriter returns a new Writer compressing
@@ -486,14 +430,9 @@ func (d *compressor) compress(r io.Reader, w io.Writer, level int, logWindowSize
 // compression; it only adds the necessary DEFLATE framing.
 func NewWriter(w io.Writer, level int) *Writer {
        const logWindowSize = logMaxOffsetSize
-       var d compressor
-       d.syncChan = make(chan os.Error, 1)
-       pr, pw := syncPipe()
-       go func() {
-               err := d.compress(pr, w, level, logWindowSize)
-               pr.CloseWithError(err)
-       }()
-       return &Writer{pw, &d}
+       var dw Writer
+       dw.d.init(w, level)
+       return &dw
 }
 
 // NewWriterDict is like NewWriter but initializes the new
@@ -526,18 +465,13 @@ func (w *dictWriter) Write(b []byte) (n int, err os.Error) {
 // A Writer takes data written to it and writes the compressed
 // form of that data to an underlying writer (see NewWriter).
 type Writer struct {
-       w *syncPipeWriter
-       d *compressor
+       d compressor
 }
 
 // Write writes data to w, which will eventually write the
 // compressed form of data to its underlying writer.
 func (w *Writer) Write(data []byte) (n int, err os.Error) {
-       if len(data) == 0 {
-               // no point, and nil interferes with sync
-               return
-       }
-       return w.w.Write(data)
+       return w.d.write(data)
 }
 
 // Flush flushes any pending compressed data to the underlying writer.
@@ -550,18 +484,10 @@ func (w *Writer) Write(data []byte) (n int, err os.Error) {
 func (w *Writer) Flush() os.Error {
        // For more about flushing:
        // http://www.bolet.org/~pornin/deflate-flush.html
-       if w.d.sync {
-               panic("compress/flate: double Flush")
-       }
-       _, err := w.w.Write(nil)
-       err1 := <-w.d.syncChan
-       if err == nil {
-               err = err1
-       }
-       return err
+       return w.d.syncFlush()
 }
 
 // Close flushes and closes the writer.
 func (w *Writer) Close() os.Error {
-       return w.w.Close()
+       return w.d.close()
 }
index 650a8059ace91c8e95e6a59e68f4ff1de79fb9d8..2ac811c389e1654e0aa70829c248384ce943759c 100644 (file)
@@ -57,7 +57,7 @@ var deflateInflateTests = []*deflateInflateTest{
        &deflateInflateTest{[]byte{0x11, 0x12}},
        &deflateInflateTest{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}},
        &deflateInflateTest{[]byte{0x11, 0x10, 0x13, 0x41, 0x21, 0x21, 0x41, 0x13, 0x87, 0x78, 0x13}},
-       &deflateInflateTest{getLargeDataChunk()},
+       &deflateInflateTest{largeDataChunk()},
 }
 
 var reverseBitsTests = []*reverseBitsTest{
@@ -71,23 +71,22 @@ var reverseBitsTests = []*reverseBitsTest{
        &reverseBitsTest{29, 5, 23},
 }
 
-func getLargeDataChunk() []byte {
+func largeDataChunk() []byte {
        result := make([]byte, 100000)
        for i := range result {
-               result[i] = byte(int64(i) * int64(i) & 0xFF)
+               result[i] = byte(i * i & 0xFF)
        }
        return result
 }
 
 func TestDeflate(t *testing.T) {
        for _, h := range deflateTests {
-               buffer := bytes.NewBuffer(nil)
-               w := NewWriter(buffer, h.level)
+               var buf bytes.Buffer
+               w := NewWriter(&buf, h.level)
                w.Write(h.in)
                w.Close()
-               if bytes.Compare(buffer.Bytes(), h.out) != 0 {
-                       t.Errorf("buffer is wrong; level = %v, buffer.Bytes() = %v, expected output = %v",
-                               h.level, buffer.Bytes(), h.out)
+               if !bytes.Equal(buf.Bytes(), h.out) {
+                       t.Errorf("Deflate(%d, %x) = %x, want %x", h.level, h.in, buf.Bytes(), h.out)
                }
        }
 }
index 64bbf24ff834a1ba3d2b502c90ddd1b63441abe9..3845f12041045f886f05bd9a17004512379b0a2a 100644 (file)
@@ -195,9 +195,8 @@ type Reader interface {
 
 // Decompress state.
 type decompressor struct {
-       // Input/output sources.
+       // Input source.
        r       Reader
-       w       io.Writer
        roffset int64
        woffset int64
 
@@ -220,38 +219,79 @@ type decompressor struct {
 
        // Temporary buffer (avoids repeated allocation).
        buf [4]byte
+
+       // Next step in the decompression,
+       // and decompression state.
+       step     func(*decompressor)
+       final    bool
+       err      os.Error
+       toRead   []byte
+       hl, hd   *huffmanDecoder
+       copyLen  int
+       copyDist int
 }
 
-func (f *decompressor) inflate() (err os.Error) {
-       final := false
-       for err == nil && !final {
-               for f.nb < 1+2 {
-                       if err = f.moreBits(); err != nil {
-                               return
-                       }
+func (f *decompressor) nextBlock() {
+       if f.final {
+               if f.hw != f.hp {
+                       f.flush((*decompressor).nextBlock)
+                       return
                }
-               final = f.b&1 == 1
-               f.b >>= 1
-               typ := f.b & 3
-               f.b >>= 2
-               f.nb -= 1 + 2
-               switch typ {
-               case 0:
-                       err = f.dataBlock()
-               case 1:
-                       // compressed, fixed Huffman tables
-                       err = f.decodeBlock(&fixedHuffmanDecoder, nil)
-               case 2:
-                       // compressed, dynamic Huffman tables
-                       if err = f.readHuffman(); err == nil {
-                               err = f.decodeBlock(&f.h1, &f.h2)
-                       }
-               default:
-                       // 3 is reserved.
-                       err = CorruptInputError(f.roffset)
+               f.err = os.EOF
+               return
+       }
+       for f.nb < 1+2 {
+               if f.err = f.moreBits(); f.err != nil {
+                       return
+               }
+       }
+       f.final = f.b&1 == 1
+       f.b >>= 1
+       typ := f.b & 3
+       f.b >>= 2
+       f.nb -= 1 + 2
+       switch typ {
+       case 0:
+               f.dataBlock()
+       case 1:
+               // compressed, fixed Huffman tables
+               f.hl = &fixedHuffmanDecoder
+               f.hd = nil
+               f.huffmanBlock()
+       case 2:
+               // compressed, dynamic Huffman tables
+               if f.err = f.readHuffman(); f.err != nil {
+                       break
                }
+               f.hl = &f.h1
+               f.hd = &f.h2
+               f.huffmanBlock()
+       default:
+               // 3 is reserved.
+               f.err = CorruptInputError(f.roffset)
        }
-       return
+}
+
+func (f *decompressor) Read(b []byte) (int, os.Error) {
+       for {
+               if len(f.toRead) > 0 {
+                       n := copy(b, f.toRead)
+                       f.toRead = f.toRead[n:]
+                       return n, nil
+               }
+               if f.err != nil {
+                       return 0, f.err
+               }
+               f.step(f)
+       }
+       panic("unreachable")
+}
+
+func (f *decompressor) Close() os.Error {
+       if f.err == os.EOF {
+               return nil
+       }
+       return f.err
 }
 
 // RFC 1951 section 3.2.7.
@@ -356,11 +396,12 @@ func (f *decompressor) readHuffman() os.Error {
 // hl and hd are the Huffman states for the lit/length values
 // and the distance values, respectively.  If hd == nil, using the
 // fixed distance encoding associated with fixed Huffman blocks.
-func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
+func (f *decompressor) huffmanBlock() {
        for {
-               v, err := f.huffSym(hl)
+               v, err := f.huffSym(f.hl)
                if err != nil {
-                       return err
+                       f.err = err
+                       return
                }
                var n uint // number of bits extra
                var length int
@@ -369,13 +410,15 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
                        f.hist[f.hp] = byte(v)
                        f.hp++
                        if f.hp == len(f.hist) {
-                               if err = f.flush(); err != nil {
-                                       return err
-                               }
+                               // After the flush, continue this loop.
+                               f.flush((*decompressor).huffmanBlock)
+                               return
                        }
                        continue
                case v == 256:
-                       return nil
+                       // Done with huffman block; read next block.
+                       f.step = (*decompressor).nextBlock
+                       return
                // otherwise, reference to older data
                case v < 265:
                        length = v - (257 - 3)
@@ -402,7 +445,8 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
                if n > 0 {
                        for f.nb < n {
                                if err = f.moreBits(); err != nil {
-                                       return err
+                                       f.err = err
+                                       return
                                }
                        }
                        length += int(f.b & uint32(1<<n-1))
@@ -411,18 +455,20 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
                }
 
                var dist int
-               if hd == nil {
+               if f.hd == nil {
                        for f.nb < 5 {
                                if err = f.moreBits(); err != nil {
-                                       return err
+                                       f.err = err
+                                       return
                                }
                        }
                        dist = int(reverseByte[(f.b&0x1F)<<3])
                        f.b >>= 5
                        f.nb -= 5
                } else {
-                       if dist, err = f.huffSym(hd); err != nil {
-                               return err
+                       if dist, err = f.huffSym(f.hd); err != nil {
+                               f.err = err
+                               return
                        }
                }
 
@@ -430,14 +476,16 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
                case dist < 4:
                        dist++
                case dist >= 30:
-                       return CorruptInputError(f.roffset)
+                       f.err = CorruptInputError(f.roffset)
+                       return
                default:
                        nb := uint(dist-2) >> 1
                        // have 1 bit in bottom of dist, need nb more.
                        extra := (dist & 1) << nb
                        for f.nb < nb {
                                if err = f.moreBits(); err != nil {
-                                       return err
+                                       f.err = err
+                                       return
                                }
                        }
                        extra |= int(f.b & uint32(1<<nb-1))
@@ -448,12 +496,14 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
 
                // Copy history[-dist:-dist+length] into output.
                if dist > len(f.hist) {
-                       return InternalError("bad history distance")
+                       f.err = InternalError("bad history distance")
+                       return
                }
 
                // No check on length; encoding can be prescient.
                if !f.hfull && dist > f.hp {
-                       return CorruptInputError(f.roffset)
+                       f.err = CorruptInputError(f.roffset)
+                       return
                }
 
                p := f.hp - dist
@@ -465,9 +515,11 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
                        f.hp++
                        p++
                        if f.hp == len(f.hist) {
-                               if err = f.flush(); err != nil {
-                                       return err
-                               }
+                               // After flush continue copying out of history.
+                               f.copyLen = length - (i + 1)
+                               f.copyDist = dist
+                               f.flush((*decompressor).copyHuff)
+                               return
                        }
                        if p == len(f.hist) {
                                p = 0
@@ -477,8 +529,33 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
        panic("unreached")
 }
 
+func (f *decompressor) copyHuff() {
+       length := f.copyLen
+       dist := f.copyDist
+       p := f.hp - dist
+       if p < 0 {
+               p += len(f.hist)
+       }
+       for i := 0; i < length; i++ {
+               f.hist[f.hp] = f.hist[p]
+               f.hp++
+               p++
+               if f.hp == len(f.hist) {
+                       f.copyLen = length - (i + 1)
+                       f.flush((*decompressor).copyHuff)
+                       return
+               }
+               if p == len(f.hist) {
+                       p = 0
+               }
+       }
+
+       // Continue processing Huffman block.
+       f.huffmanBlock()
+}
+
 // Copy a single uncompressed data block from input to output.
-func (f *decompressor) dataBlock() os.Error {
+func (f *decompressor) dataBlock() {
        // Uncompressed.
        // Discard current half-byte.
        f.nb = 0
@@ -488,21 +565,30 @@ func (f *decompressor) dataBlock() os.Error {
        nr, err := io.ReadFull(f.r, f.buf[0:4])
        f.roffset += int64(nr)
        if err != nil {
-               return &ReadError{f.roffset, err}
+               f.err = &ReadError{f.roffset, err}
+               return
        }
        n := int(f.buf[0]) | int(f.buf[1])<<8
        nn := int(f.buf[2]) | int(f.buf[3])<<8
        if uint16(nn) != uint16(^n) {
-               return CorruptInputError(f.roffset)
+               f.err = CorruptInputError(f.roffset)
+               return
        }
 
        if n == 0 {
                // 0-length block means sync
-               return f.flush()
+               f.flush((*decompressor).nextBlock)
+               return
        }
 
-       // Read len bytes into history,
-       // writing as history fills.
+       f.copyLen = n
+       f.copyData()
+}
+
+func (f *decompressor) copyData() {
+       // Read f.dataLen bytes into history,
+       // pausing for reads as history fills.
+       n := f.copyLen
        for n > 0 {
                m := len(f.hist) - f.hp
                if m > n {
@@ -511,17 +597,18 @@ func (f *decompressor) dataBlock() os.Error {
                m, err := io.ReadFull(f.r, f.hist[f.hp:f.hp+m])
                f.roffset += int64(m)
                if err != nil {
-                       return &ReadError{f.roffset, err}
+                       f.err = &ReadError{f.roffset, err}
+                       return
                }
                n -= m
                f.hp += m
                if f.hp == len(f.hist) {
-                       if err = f.flush(); err != nil {
-                               return err
-                       }
+                       f.copyLen = n
+                       f.flush((*decompressor).copyData)
+                       return
                }
        }
-       return nil
+       f.step = (*decompressor).nextBlock
 }
 
 func (f *decompressor) setDict(dict []byte) {
@@ -577,17 +664,8 @@ func (f *decompressor) huffSym(h *huffmanDecoder) (int, os.Error) {
 }
 
 // Flush any buffered output to the underlying writer.
-func (f *decompressor) flush() os.Error {
-       if f.hw == f.hp {
-               return nil
-       }
-       n, err := f.w.Write(f.hist[f.hw:f.hp])
-       if n != f.hp-f.hw && err == nil {
-               err = io.ErrShortWrite
-       }
-       if err != nil {
-               return &WriteError{f.woffset, err}
-       }
+func (f *decompressor) flush(step func(*decompressor)) {
+       f.toRead = f.hist[f.hw:f.hp]
        f.woffset += int64(f.hp - f.hw)
        f.hw = f.hp
        if f.hp == len(f.hist) {
@@ -595,7 +673,7 @@ func (f *decompressor) flush() os.Error {
                f.hw = 0
                f.hfull = true
        }
-       return nil
+       f.step = step
 }
 
 func makeReader(r io.Reader) Reader {
@@ -605,30 +683,15 @@ func makeReader(r io.Reader) Reader {
        return bufio.NewReader(r)
 }
 
-// decompress reads DEFLATE-compressed data from r and writes
-// the uncompressed data to w.
-func (f *decompressor) decompress(r io.Reader, w io.Writer) os.Error {
-       f.r = makeReader(r)
-       f.w = w
-       f.woffset = 0
-       if err := f.inflate(); err != nil {
-               return err
-       }
-       if err := f.flush(); err != nil {
-               return err
-       }
-       return nil
-}
-
 // NewReader returns a new ReadCloser that can be used
 // to read the uncompressed version of r.  It is the caller's
 // responsibility to call Close on the ReadCloser when
 // finished reading.
 func NewReader(r io.Reader) io.ReadCloser {
        var f decompressor
-       pr, pw := io.Pipe()
-       go func() { pw.CloseWithError(f.decompress(r, pw)) }()
-       return pr
+       f.r = makeReader(r)
+       f.step = (*decompressor).nextBlock
+       return &f
 }
 
 // NewReaderDict is like NewReader but initializes the reader
@@ -639,7 +702,7 @@ func NewReader(r io.Reader) io.ReadCloser {
 func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
        var f decompressor
        f.setDict(dict)
-       pr, pw := io.Pipe()
-       go func() { pw.CloseWithError(f.decompress(r, pw)) }()
-       return pr
+       f.r = makeReader(r)
+       f.step = (*decompressor).nextBlock
+       return &f
 }