]> Cypherpunks repositories - gostls13.git/commitdiff
compress/lzw: fix hi code overflow.
authorNigel Tao <nigeltao@golang.org>
Fri, 28 Apr 2017 01:07:16 +0000 (11:07 +1000)
committerNigel Tao <nigeltao@golang.org>
Fri, 28 Apr 2017 05:59:30 +0000 (05:59 +0000)
Change-Id: I2d3c3c715d857305944cd96c45554a16cb7967e9
Reviewed-on: https://go-review.googlesource.com/42032
Reviewed-by: David Symonds <dsymonds@golang.org>
src/compress/lzw/reader.go
src/compress/lzw/reader_test.go

index 9eef2b2a782b7c763de77116d14c843747507bc9..557955bc3fd03b54d4f9e9d108c5dd80babd1027 100644 (file)
@@ -57,8 +57,14 @@ type decoder struct {
        // The next two codes mean clear and EOF.
        // Other valid codes are in the range [lo, hi] where lo := clear + 2,
        // with the upper bound incrementing on each code seen.
-       // overflow is the code at which hi overflows the code width.
+       //
+       // overflow is the code at which hi overflows the code width. It always
+       // equals 1 << width.
+       //
        // last is the most recently seen code, or decoderInvalidCode.
+       //
+       // An invariant is that
+       // (hi < overflow) || (hi == overflow && last == decoderInvalidCode)
        clear, eof, hi, overflow, last uint16
 
        // Each code c in [lo, hi] expands to two or more bytes. For c != hi:
@@ -196,6 +202,10 @@ loop:
                if d.hi >= d.overflow {
                        if d.width == maxWidth {
                                d.last = decoderInvalidCode
+                               // Undo the d.hi++ a few lines above, so that (1) we maintain
+                               // the invariant that d.hi <= d.overflow, and (2) d.hi does not
+                               // eventually overflow a uint16.
+                               d.hi--
                        } else {
                                d.width++
                                d.overflow <<= 1
index 6b9f9a3da7035ddc80203c5b4730bc0f5d93eda0..53c9cdd865dce0838facdd791ad2fd7773a7711e 100644 (file)
@@ -120,6 +120,32 @@ func TestReader(t *testing.T) {
        }
 }
 
+type devZero struct{}
+
+func (devZero) Read(p []byte) (int, error) {
+       for i := range p {
+               p[i] = 0
+       }
+       return len(p), nil
+}
+
+func TestHiCodeDoesNotOverflow(t *testing.T) {
+       r := NewReader(devZero{}, LSB, 8)
+       d := r.(*decoder)
+       buf := make([]byte, 1024)
+       oldHi := uint16(0)
+       for i := 0; i < 100; i++ {
+               if _, err := io.ReadFull(r, buf); err != nil {
+                       t.Fatalf("i=%d: %v", i, err)
+               }
+               // The hi code should never decrease.
+               if d.hi < oldHi {
+                       t.Fatalf("i=%d: hi=%d decreased from previous value %d", i, d.hi, oldHi)
+               }
+               oldHi = d.hi
+       }
+}
+
 func BenchmarkDecoder(b *testing.B) {
        buf, err := ioutil.ReadFile("../testdata/e.txt")
        if err != nil {