]> Cypherpunks repositories - gostls13.git/commitdiff
compress/lzw: do not use background goroutines
authorRuss Cox <rsc@golang.org>
Tue, 7 Jun 2011 18:37:06 +0000 (14:37 -0400)
committerRuss Cox <rsc@golang.org>
Tue, 7 Jun 2011 18:37:06 +0000 (14:37 -0400)
Programs expect that Read and Write are synchronous.
The background goroutines make the implementation
a little easier, but they introduce asynchrony that
trips up calling code.  Remove them.

R=golang-dev, nigeltao
CC=golang-dev
https://golang.org/cl/4548080

src/pkg/compress/lzw/reader.go
src/pkg/compress/lzw/writer_test.go

index a1cd2abc0439e0fea9fb1979eadf73dfc362e54c..ccd882f88b14cfe7f2e85bc31b6717f441c1676c 100644 (file)
@@ -32,13 +32,48 @@ const (
        MSB
 )
 
+const (
+       maxWidth           = 12
+       decoderInvalidCode = 0xffff
+       flushBuffer        = 1 << maxWidth
+)
+
 // decoder is the state from which the readXxx method converts a byte
 // stream into a code stream.
 type decoder struct {
-       r     io.ByteReader
-       bits  uint32
-       nBits uint
-       width uint
+       r        io.ByteReader
+       bits     uint32
+       nBits    uint
+       width    uint
+       read     func(*decoder) (uint16, os.Error) // readLSB or readMSB
+       litWidth int                               // width in bits of literal codes
+       err      os.Error
+
+       // The first 1<<litWidth codes are literal codes.
+       // The next two codes mean clear and EOF.
+       // Other valid codes are in the range [lo, hi] where lo := clear + 2,
+       // with the upper bound incrementing on each code seen.
+       // overflow is the code at which hi overflows the code width.
+       // last is the most recently seen code, or decoderInvalidCode.
+       clear, eof, hi, overflow, last uint16
+
+       // Each code c in [lo, hi] expands to two or more bytes. For c != hi:
+       //   suffix[c] is the last of these bytes.
+       //   prefix[c] is the code for all but the last byte.
+       //   This code can either be a literal code or another code in [lo, c).
+       // The c == hi case is a special case.
+       suffix [1 << maxWidth]uint8
+       prefix [1 << maxWidth]uint16
+       // buf is a scratch buffer for reconstituting the bytes that a code expands to.
+       // Code suffixes are written right-to-left from the end of the buffer.
+       buf [1 << maxWidth]byte
+
+       // output is the temporary output buffer.
+       // It is flushed when it contains >= 1<<maxWidth bytes,
+       // so that there is always room to copy buf into it while decoding.
+       output [2 * 1 << maxWidth]byte
+       o      int    // write index into output
+       toRead []byte // bytes to return from Read
 }
 
 // readLSB returns the next code for "Least Significant Bits first" data.
@@ -73,119 +108,113 @@ func (d *decoder) readMSB() (uint16, os.Error) {
        return code, nil
 }
 
-// decode decompresses bytes from r and writes them to pw.
-// read specifies how to decode bytes into codes.
-// litWidth is the width in bits of literal codes.
-func decode(r io.Reader, read func(*decoder) (uint16, os.Error), litWidth int, pw *io.PipeWriter) {
-       br, ok := r.(io.ByteReader)
-       if !ok {
-               br = bufio.NewReader(r)
+func (d *decoder) Read(b []byte) (int, os.Error) {
+       for {
+               if len(d.toRead) > 0 {
+                       n := copy(b, d.toRead)
+                       d.toRead = d.toRead[n:]
+                       return n, nil
+               }
+               if d.err != nil {
+                       return 0, d.err
+               }
+               d.decode()
        }
-       pw.CloseWithError(decode1(pw, br, read, uint(litWidth)))
+       panic("unreachable")
 }
 
-func decode1(pw *io.PipeWriter, r io.ByteReader, read func(*decoder) (uint16, os.Error), litWidth uint) os.Error {
-       const (
-               maxWidth    = 12
-               invalidCode = 0xffff
-       )
-       d := decoder{r, 0, 0, 1 + litWidth}
-       w := bufio.NewWriter(pw)
-       // The first 1<<litWidth codes are literal codes.
-       // The next two codes mean clear and EOF.
-       // Other valid codes are in the range [lo, hi] where lo := clear + 2,
-       // with the upper bound incrementing on each code seen.
-       clear := uint16(1) << litWidth
-       eof, hi := clear+1, clear+1
-       // overflow is the code at which hi overflows the code width.
-       overflow := uint16(1) << d.width
-       var (
-               // Each code c in [lo, hi] expands to two or more bytes. For c != hi:
-               //   suffix[c] is the last of these bytes.
-               //   prefix[c] is the code for all but the last byte.
-               //   This code can either be a literal code or another code in [lo, c).
-               // The c == hi case is a special case.
-               suffix [1 << maxWidth]uint8
-               prefix [1 << maxWidth]uint16
-               // buf is a scratch buffer for reconstituting the bytes that a code expands to.
-               // Code suffixes are written right-to-left from the end of the buffer.
-               buf [1 << maxWidth]byte
-       )
-
+// decode decompresses bytes from r and leaves them in d.toRead.
+// read specifies how to decode bytes into codes.
+// litWidth is the width in bits of literal codes.
+func (d *decoder) decode() {
        // Loop over the code stream, converting codes into decompressed bytes.
-       last := uint16(invalidCode)
        for {
-               code, err := read(&d)
+               code, err := d.read(d)
                if err != nil {
                        if err == os.EOF {
                                err = io.ErrUnexpectedEOF
                        }
-                       return err
+                       d.err = err
+                       return
                }
                switch {
-               case code < clear:
+               case code < d.clear:
                        // We have a literal code.
-                       if err := w.WriteByte(uint8(code)); err != nil {
-                               return err
-                       }
-                       if last != invalidCode {
+                       d.output[d.o] = uint8(code)
+                       d.o++
+                       if d.last != decoderInvalidCode {
                                // Save what the hi code expands to.
-                               suffix[hi] = uint8(code)
-                               prefix[hi] = last
+                               d.suffix[d.hi] = uint8(code)
+                               d.prefix[d.hi] = d.last
                        }
-               case code == clear:
-                       d.width = 1 + litWidth
-                       hi = eof
-                       overflow = 1 << d.width
-                       last = invalidCode
+               case code == d.clear:
+                       d.width = 1 + uint(d.litWidth)
+                       d.hi = d.eof
+                       d.overflow = 1 << d.width
+                       d.last = decoderInvalidCode
                        continue
-               case code == eof:
-                       return w.Flush()
-               case code <= hi:
-                       c, i := code, len(buf)-1
-                       if code == hi {
+               case code == d.eof:
+                       d.flush()
+                       d.err = os.EOF
+                       return
+               case code <= d.hi:
+                       c, i := code, len(d.buf)-1
+                       if code == d.hi {
                                // code == hi is a special case which expands to the last expansion
                                // followed by the head of the last expansion. To find the head, we walk
                                // the prefix chain until we find a literal code.
-                               c = last
-                               for c >= clear {
-                                       c = prefix[c]
+                               c = d.last
+                               for c >= d.clear {
+                                       c = d.prefix[c]
                                }
-                               buf[i] = uint8(c)
+                               d.buf[i] = uint8(c)
                                i--
-                               c = last
+                               c = d.last
                        }
                        // Copy the suffix chain into buf and then write that to w.
-                       for c >= clear {
-                               buf[i] = suffix[c]
+                       for c >= d.clear {
+                               d.buf[i] = d.suffix[c]
                                i--
-                               c = prefix[c]
+                               c = d.prefix[c]
                        }
-                       buf[i] = uint8(c)
-                       if _, err := w.Write(buf[i:]); err != nil {
-                               return err
-                       }
-                       if last != invalidCode {
+                       d.buf[i] = uint8(c)
+                       d.o += copy(d.output[d.o:], d.buf[i:])
+                       if d.last != decoderInvalidCode {
                                // Save what the hi code expands to.
-                               suffix[hi] = uint8(c)
-                               prefix[hi] = last
+                               d.suffix[d.hi] = uint8(c)
+                               d.prefix[d.hi] = d.last
                        }
                default:
-                       return os.NewError("lzw: invalid code")
+                       d.err = os.NewError("lzw: invalid code")
+                       return
                }
-               last, hi = code, hi+1
-               if hi >= overflow {
+               d.last, d.hi = code, d.hi+1
+               if d.hi >= d.overflow {
                        if d.width == maxWidth {
-                               last = invalidCode
-                               continue
+                               d.last = decoderInvalidCode
+                       } else {
+                               d.width++
+                               d.overflow <<= 1
                        }
-                       d.width++
-                       overflow <<= 1
+               }
+               if d.o >= flushBuffer {
+                       d.flush()
+                       return
                }
        }
        panic("unreachable")
 }
 
+func (d *decoder) flush() {
+       d.toRead = d.output[:d.o]
+       d.o = 0
+}
+
+func (d *decoder) Close() os.Error {
+       d.err = os.EINVAL // in case any Reads come along
+       return nil
+}
+
 // NewReader creates a new io.ReadCloser that satisfies reads by decompressing
 // the data read from r.
 // It is the caller's responsibility to call Close on the ReadCloser when
@@ -193,21 +222,31 @@ func decode1(pw *io.PipeWriter, r io.ByteReader, read func(*decoder) (uint16, os
 // The number of bits to use for literal codes, litWidth, must be in the
 // range [2,8] and is typically 8.
 func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser {
-       pr, pw := io.Pipe()
-       var read func(*decoder) (uint16, os.Error)
+       d := new(decoder)
        switch order {
        case LSB:
-               read = (*decoder).readLSB
+               d.read = (*decoder).readLSB
        case MSB:
-               read = (*decoder).readMSB
+               d.read = (*decoder).readMSB
        default:
-               pw.CloseWithError(os.NewError("lzw: unknown order"))
-               return pr
+               d.err = os.NewError("lzw: unknown order")
+               return d
        }
        if litWidth < 2 || 8 < litWidth {
-               pw.CloseWithError(fmt.Errorf("lzw: litWidth %d out of range", litWidth))
-               return pr
+               d.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
+               return d
        }
-       go decode(r, read, litWidth, pw)
-       return pr
+       if br, ok := r.(io.ByteReader); ok {
+               d.r = br
+       } else {
+               d.r = bufio.NewReader(r)
+       }
+       d.litWidth = litWidth
+       d.width = 1 + uint(litWidth)
+       d.clear = uint16(1) << uint(litWidth)
+       d.eof, d.hi = d.clear+1, d.clear+1
+       d.overflow = uint16(1) << d.width
+       d.last = decoderInvalidCode
+
+       return d
 }
index 82464ecd1b0811b23eadc5b00cdeca7c9b07ccdb..4c5e522f94316cd5b726c6f68fcc749911a0a37b 100644 (file)
@@ -77,13 +77,13 @@ func testFile(t *testing.T, fn string, order Order, litWidth int) {
                t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
                return
        }
-       if len(b0) != len(b1) {
-               t.Errorf("%s (order=%d litWidth=%d): length mismatch %d versus %d", fn, order, litWidth, len(b0), len(b1))
+       if len(b1) != len(b0) {
+               t.Errorf("%s (order=%d litWidth=%d): length mismatch %d != %d", fn, order, litWidth, len(b1), len(b0))
                return
        }
        for i := 0; i < len(b0); i++ {
-               if b0[i] != b1[i] {
-                       t.Errorf("%s (order=%d litWidth=%d): mismatch at %d, 0x%02x versus 0x%02x\n", fn, order, litWidth, i, b0[i], b1[i])
+               if b1[i] != b0[i] {
+                       t.Errorf("%s (order=%d litWidth=%d): mismatch at %d, 0x%02x != 0x%02x\n", fn, order, litWidth, i, b1[i], b0[i])
                        return
                }
        }