]> Cypherpunks repositories - gostls13.git/commitdiff
compress/lzw: implement an encoder.
authorNigel Tao <nigeltao@golang.org>
Thu, 24 Feb 2011 22:20:04 +0000 (09:20 +1100)
committerNigel Tao <nigeltao@golang.org>
Thu, 24 Feb 2011 22:20:04 +0000 (09:20 +1100)
R=rsc, nigeltao_gnome
CC=golang-dev
https://golang.org/cl/4209043

src/pkg/compress/lzw/Makefile
src/pkg/compress/lzw/reader.go
src/pkg/compress/lzw/writer.go [new file with mode: 0644]
src/pkg/compress/lzw/writer_test.go [new file with mode: 0644]

index 8f2a376f4b9e1d4ba4daf6ede303c619fa714eb0..28f5e6abcb403d87b6fa7d89a26ee9c5056d3fdf 100644 (file)
@@ -7,5 +7,6 @@ include ../../../Make.inc
 TARG=compress/lzw
 GOFILES=\
        reader.go\
+       writer.go\
 
 include ../../../Make.pkg
index 47b10a8cbd15e6ba029064428f8a6a5871e83be8..505b24bb5dc6cfc4390d470ef362b4692a13ccd6 100644 (file)
@@ -14,8 +14,6 @@ package lzw
 // TODO(nigeltao): check that TIFF and PDF use LZW in the same way as GIF,
 // modulo LSB/MSB packing order.
 
-// TODO(nigeltao): write an encoder.
-
 import (
        "bufio"
        "fmt"
@@ -46,11 +44,11 @@ type decoder struct {
 // readLSB returns the next code for "Least Significant Bits first" data.
 func (d *decoder) readLSB() (uint16, os.Error) {
        for d.nBits < d.width {
-               c, err := d.r.ReadByte()
+               x, err := d.r.ReadByte()
                if err != nil {
                        return 0, err
                }
-               d.bits |= uint32(c) << d.nBits
+               d.bits |= uint32(x) << d.nBits
                d.nBits += 8
        }
        code := uint16(d.bits & (1<<d.width - 1))
@@ -62,11 +60,11 @@ func (d *decoder) readLSB() (uint16, os.Error) {
 // readMSB returns the next code for "Most Significant Bits first" data.
 func (d *decoder) readMSB() (uint16, os.Error) {
        for d.nBits < d.width {
-               c, err := d.r.ReadByte()
+               x, err := d.r.ReadByte()
                if err != nil {
                        return 0, err
                }
-               d.bits |= uint32(c) << (24 - d.nBits)
+               d.bits |= uint32(x) << (24 - d.nBits)
                d.nBits += 8
        }
        code := uint16(d.bits >> (32 - d.width))
@@ -177,13 +175,12 @@ func decode(pw *io.PipeWriter, r io.ByteReader, read func(*decoder) (uint16, os.
        panic("unreachable")
 }
 
-// NewReader returns a new ReadCloser that can be used to read the uncompressed
-// version of r. It is the caller's responsibility to call Close on the
-// ReadCloser when finished reading.
-// order is either LSB or MSB for Least or Most Significant Bits first packing
-// order. GIF uses LSB. TIFF and PDF use MSB.
-// litWidth is the width in bits for literal codes. Valid values range from
-// 2 to 8 inclusive.
+// NewReader creates a new io.ReadCloser that satisfies reads by decompressing
+// the data read from r.
+// It is the caller's responsibility to call Close on the ReadCloser when
+// finished reading.
+// The number of bits to use for literal codes, litWidth, must be in the
+// range [2,8] and is typically 8.
 func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser {
        pr, pw := io.Pipe()
        var read func(*decoder) (uint16, os.Error)
diff --git a/src/pkg/compress/lzw/writer.go b/src/pkg/compress/lzw/writer.go
new file mode 100644 (file)
index 0000000..87143b7
--- /dev/null
@@ -0,0 +1,259 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lzw
+
+import (
+       "bufio"
+       "fmt"
+       "io"
+       "os"
+)
+
+// A writer is a buffered, flushable writer.
+type writer interface {
+       WriteByte(byte) os.Error
+       Flush() os.Error
+}
+
+// An errWriteCloser is an io.WriteCloser that always returns a given error.
+type errWriteCloser struct {
+       err os.Error
+}
+
+func (e *errWriteCloser) Write([]byte) (int, os.Error) {
+       return 0, e.err
+}
+
+func (e *errWriteCloser) Close() os.Error {
+       return e.err
+}
+
+const (
+       // A code is a 12 bit value, stored as a uint32 when encoding to avoid
+       // type conversions when shifting bits.
+       maxCode     = 1<<12 - 1
+       invalidCode = 1<<32 - 1
+       // There are 1<<12 possible codes, which is an upper bound on the number of
+       // valid hash table entries at any given point in time. tableSize is 4x that.
+       tableSize = 4 * 1 << 12
+       tableMask = tableSize - 1
+       // A hash table entry is a uint32. Zero is an invalid entry since the
+       // lower 12 bits of a valid entry must be a non-literal code.
+       invalidEntry = 0
+)
+
+// encoder is LZW compressor.
+type encoder struct {
+       // w is the writer that compressed bytes are written to.
+       w writer
+       // write, bits, nBits and width are the state for converting a code stream
+       // into a byte stream.
+       write func(*encoder, uint32) os.Error
+       bits  uint32
+       nBits uint
+       width uint
+       // litWidth is the width in bits of literal codes.
+       litWidth uint
+       // hi is the code implied by the next code emission.
+       // overflow is the code at which hi overflows the code width.
+       hi, overflow uint32
+       // savedCode is the accumulated code at the end of the most recent Write
+       // call. It is equal to invalidCode if there was no such call.
+       savedCode uint32
+       // err is the first error encountered during writing. Closing the encoder
+       // will make any future Write calls return os.EINVAL.
+       err os.Error
+       // table is the hash table from 20-bit keys to 12-bit values. Each table
+       // entry contains key<<12|val and collisions resolve by linear probing.
+       // The keys consist of a 12-bit code prefix and an 8-bit byte suffix.
+       // The values are a 12-bit code.
+       table [tableSize]uint32
+}
+
+// writeLSB writes the code c for "Least Significant Bits first" data.
+func (e *encoder) writeLSB(c uint32) os.Error {
+       e.bits |= c << e.nBits
+       e.nBits += e.width
+       for e.nBits >= 8 {
+               if err := e.w.WriteByte(uint8(e.bits)); err != nil {
+                       return err
+               }
+               e.bits >>= 8
+               e.nBits -= 8
+       }
+       return nil
+}
+
+// writeMSB writes the code c for "Most Significant Bits first" data.
+func (e *encoder) writeMSB(c uint32) os.Error {
+       e.bits |= c << (32 - e.width - e.nBits)
+       e.nBits += e.width
+       for e.nBits >= 8 {
+               if err := e.w.WriteByte(uint8(e.bits >> 24)); err != nil {
+                       return err
+               }
+               e.bits <<= 8
+               e.nBits -= 8
+       }
+       return nil
+}
+
+// errOutOfCodes is an internal error that means that the encoder has run out
+// of unused codes and a clear code needs to be sent next.
+var errOutOfCodes = os.NewError("lzw: out of codes")
+
+// incHi increments e.hi and checks for both overflow and running out of
+// unused codes. In the latter case, incHi sends a clear code, resets the
+// encoder state and returns errOutOfCodes.
+func (e *encoder) incHi() os.Error {
+       e.hi++
+       if e.hi == e.overflow {
+               e.width++
+               e.overflow <<= 1
+       }
+       if e.hi == maxCode {
+               clear := uint32(1) << e.litWidth
+               if err := e.write(e, clear); err != nil {
+                       return err
+               }
+               e.width = uint(e.litWidth) + 1
+               e.hi = clear + 1
+               e.overflow = clear << 1
+               for i := range e.table {
+                       e.table[i] = invalidEntry
+               }
+               return errOutOfCodes
+       }
+       return nil
+}
+
+// Write writes a compressed representation of p to e's underlying writer.
+func (e *encoder) Write(p []byte) (int, os.Error) {
+       if e.err != nil {
+               return 0, e.err
+       }
+       if len(p) == 0 {
+               return 0, nil
+       }
+       litMask := uint32(1<<e.litWidth - 1)
+       code := e.savedCode
+       if code == invalidCode {
+               // The first code sent is always a literal code.
+               code, p = uint32(p[0])&litMask, p[1:]
+       }
+loop:
+       for _, x := range p {
+               literal := uint32(x) & litMask
+               key := code<<8 | literal
+               // If there is a hash table hit for this key then we continue the loop
+               // and do not emit a code yet.
+               hash := (key>>12 ^ key) & tableMask
+               for h, t := hash, e.table[hash]; t != invalidEntry; {
+                       if key == t>>12 {
+                               code = t & maxCode
+                               continue loop
+                       }
+                       h = (h + 1) & tableMask
+                       t = e.table[h]
+               }
+               // Otherwise, write the current code, and literal becomes the start of
+               // the next emitted code.
+               if e.err = e.write(e, code); e.err != nil {
+                       return 0, e.err
+               }
+               code = literal
+               // Increment e.hi, the next implied code. If we run out of codes, reset
+               // the encoder state (including clearing the hash table) and continue.
+               if err := e.incHi(); err != nil {
+                       if err == errOutOfCodes {
+                               continue
+                       }
+                       e.err = err
+                       return 0, e.err
+               }
+               // Otherwise, insert key -> e.hi into the map that e.table represents.
+               for {
+                       if e.table[hash] == invalidEntry {
+                               e.table[hash] = (key << 12) | e.hi
+                               break
+                       }
+                       hash = (hash + 1) & tableMask
+               }
+       }
+       e.savedCode = code
+       return len(p), nil
+}
+
+// Close closes the encoder, flushing any pending output. It does not close or
+// flush e's underlying writer.
+func (e *encoder) Close() os.Error {
+       if e.err != nil {
+               if e.err == os.EINVAL {
+                       return nil
+               }
+               return e.err
+       }
+       // Make any future calls to Write return os.EINVAL.
+       e.err = os.EINVAL
+       // Write the savedCode if valid.
+       if e.savedCode != invalidCode {
+               if err := e.write(e, e.savedCode); err != nil {
+                       return err
+               }
+               if err := e.incHi(); err != nil && err != errOutOfCodes {
+                       return err
+               }
+       }
+       // Write the eof code.
+       eof := uint32(1)<<e.litWidth + 1
+       if err := e.write(e, eof); err != nil {
+               return err
+       }
+       // Write the final bits.
+       if e.nBits > 0 {
+               if e.write == (*encoder).writeMSB {
+                       e.bits >>= 24
+               }
+               if err := e.w.WriteByte(uint8(e.bits)); err != nil {
+                       return err
+               }
+       }
+       return e.w.Flush()
+}
+
+// NewWriter creates a new io.WriteCloser that satisfies writes by compressing
+// the data and writing it to w.
+// It is the caller's responsibility to call Close on the WriteCloser when
+// finished writing.
+// The number of bits to use for literal codes, litWidth, must be in the
+// range [2,8] and is typically 8.
+func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
+       var write func(*encoder, uint32) os.Error
+       switch order {
+       case LSB:
+               write = (*encoder).writeLSB
+       case MSB:
+               write = (*encoder).writeMSB
+       default:
+               return &errWriteCloser{os.NewError("lzw: unknown order")}
+       }
+       if litWidth < 2 || 8 < litWidth {
+               return &errWriteCloser{fmt.Errorf("lzw: litWidth %d out of range", litWidth)}
+       }
+       bw, ok := w.(writer)
+       if !ok {
+               bw = bufio.NewWriter(w)
+       }
+       lw := uint(litWidth)
+       return &encoder{
+               w:         bw,
+               write:     write,
+               width:     1 + lw,
+               litWidth:  lw,
+               hi:        1<<lw + 1,
+               overflow:  1 << (lw + 1),
+               savedCode: invalidCode,
+       }
+}
diff --git a/src/pkg/compress/lzw/writer_test.go b/src/pkg/compress/lzw/writer_test.go
new file mode 100644 (file)
index 0000000..2199522
--- /dev/null
@@ -0,0 +1,100 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lzw
+
+import (
+       "io"
+       "io/ioutil"
+       "os"
+       "testing"
+)
+
+var filenames = []string{
+       "../testdata/e.txt",
+       "../testdata/pi.txt",
+}
+
+// testFile tests that compressing and then decompressing the given file with
+// the given options yields equivalent bytes to the original file.
+func testFile(t *testing.T, fn string, order Order, litWidth int) {
+       // Read the file, as golden output.
+       golden, err := os.Open(fn, os.O_RDONLY, 0400)
+       if err != nil {
+               t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err)
+               return
+       }
+       defer golden.Close()
+
+       // Read the file again, and push it through a pipe that compresses at the write end, and decompresses at the read end.
+       raw, err := os.Open(fn, os.O_RDONLY, 0400)
+       if err != nil {
+               t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err)
+               return
+       }
+
+       piper, pipew := io.Pipe()
+       defer piper.Close()
+       go func() {
+               defer raw.Close()
+               defer pipew.Close()
+               lzww := NewWriter(pipew, order, litWidth)
+               defer lzww.Close()
+               var b [4096]byte
+               for {
+                       n, err0 := raw.Read(b[:])
+                       if err0 != nil && err0 != os.EOF {
+                               t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err0)
+                               return
+                       }
+                       _, err1 := lzww.Write(b[:n])
+                       if err1 == os.EPIPE {
+                               // Fail, but do not report the error, as some other (presumably reportable) error broke the pipe.
+                               return
+                       }
+                       if err1 != nil {
+                               t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
+                               return
+                       }
+                       if err0 == os.EOF {
+                               break
+                       }
+               }
+       }()
+       lzwr := NewReader(piper, order, litWidth)
+       defer lzwr.Close()
+
+       // Compare the two.
+       b0, err0 := ioutil.ReadAll(golden)
+       b1, err1 := ioutil.ReadAll(lzwr)
+       if err0 != nil {
+               t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err0)
+               return
+       }
+       if err1 != nil {
+               t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
+               return
+       }
+       if len(b0) != len(b1) {
+               t.Errorf("%s (order=%d litWidth=%d): length mismatch %d versus %d", fn, order, litWidth, len(b0), len(b1))
+               return
+       }
+       for i := 0; i < len(b0); i++ {
+               if b0[i] != b1[i] {
+                       t.Errorf("%s (order=%d litWidth=%d): mismatch at %d, 0x%02x versus 0x%02x\n", fn, order, litWidth, i, b0[i], b1[i])
+                       return
+               }
+       }
+}
+
+func TestWriter(t *testing.T) {
+       for _, filename := range filenames {
+               for _, order := range [...]Order{LSB, MSB} {
+                       // The test data "2.71828 etcetera" is ASCII text requiring at least 6 bits.
+                       for _, litWidth := range [...]int{6, 7, 8} {
+                               testFile(t, filename, order, litWidth)
+                       }
+               }
+       }
+}