]> Cypherpunks repositories - gostls13.git/commitdiff
[release-branch.go1.7] compress/flate: make huffmanBitWriter errors persistent
authorJoe Tsai <joetsai@digital-static.net>
Tue, 16 Aug 2016 23:03:00 +0000 (16:03 -0700)
committerChris Broadfoot <cbro@golang.org>
Wed, 7 Sep 2016 17:45:45 +0000 (17:45 +0000)
For persistent error handling, the methods of huffmanBitWriter have to be
consistent about how they check errors. It must either consistently
check error *before* every operation OR immediately *after* every
operation. Since most of the current logic uses the previous approach,
we apply the same style of error checking to writeBits and all calls
to Write such that they only operate if w.err is already nil going
into them.

The error handling approach is brittle and easily broken by future commits to
the code. In the near future, we should switch the logic to use panic at the
lowest levels and a recover at the edge of the public API to ensure
that errors are always persistent.

Fixes #16749

Change-Id: Ie1d83e4ed8842f6911a31e23311cd3cbf38abe8c
Reviewed-on: https://go-review.googlesource.com/27200
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-on: https://go-review.googlesource.com/28634
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>

src/compress/flate/deflate.go
src/compress/flate/deflate_test.go
src/compress/flate/huffman_bit_writer.go

index 3e4dc7b57e1e052233407b5200335f1729fc9c92..9f53d51a6e751ffe7157f54efeb5d34afd1e92ed 100644 (file)
@@ -724,7 +724,7 @@ func (w *Writer) Close() error {
 // the result of NewWriter or NewWriterDict called with dst
 // and w's level and dictionary.
 func (w *Writer) Reset(dst io.Writer) {
-       if dw, ok := w.d.w.w.(*dictWriter); ok {
+       if dw, ok := w.d.w.writer.(*dictWriter); ok {
                // w was created with NewWriterDict
                dw.w = dst
                w.d.reset(dw)
index 27a3b3823a649401df3c58f8be497328feb6e7a5..3322c40845d0678d29b51f7326d4cd045c0c593e 100644 (file)
@@ -6,6 +6,7 @@ package flate
 
 import (
        "bytes"
+       "errors"
        "fmt"
        "internal/testenv"
        "io"
@@ -631,3 +632,52 @@ func TestBestSpeed(t *testing.T) {
                }
        }
 }
+
+var errIO = errors.New("IO error")
+
+// failWriter fails with errIO exactly at the nth call to Write.
+type failWriter struct{ n int }
+
+func (w *failWriter) Write(b []byte) (int, error) {
+       w.n--
+       if w.n == -1 {
+               return 0, errIO
+       }
+       return len(b), nil
+}
+
+func TestWriterPersistentError(t *testing.T) {
+       d, err := ioutil.ReadFile("../testdata/Mark.Twain-Tom.Sawyer.txt")
+       if err != nil {
+               t.Fatalf("ReadFile: %v", err)
+       }
+       d = d[:10000] // Keep this test short
+
+       zw, err := NewWriter(nil, DefaultCompression)
+       if err != nil {
+               t.Fatalf("NewWriter: %v", err)
+       }
+
+       // Sweep over the threshold at which an error is returned.
+       // The variable i makes it such that the ith call to failWriter.Write will
+       // return errIO. Since failWriter errors are not persistent, we must ensure
+       // that flate.Writer errors are persistent.
+       for i := 0; i < 1000; i++ {
+               fw := &failWriter{i}
+               zw.Reset(fw)
+
+               _, werr := zw.Write(d)
+               cerr := zw.Close()
+               if werr != errIO && werr != nil {
+                       t.Errorf("test %d, mismatching Write error: got %v, want %v", i, werr, errIO)
+               }
+               if cerr != errIO && fw.n < 0 {
+                       t.Errorf("test %d, mismatching Close error: got %v, want %v", i, cerr, errIO)
+               }
+               if fw.n >= 0 {
+                       // At this point, the failure threshold was sufficiently high enough
+                       // that we wrote the whole stream without any errors.
+                       return
+               }
+       }
+}
index c4adef9ff53010b6db8d86f824bf235dc8b87ffd..d8b5a3ebd7b50c7da64274dca161241e1bbb4d60 100644 (file)
@@ -77,7 +77,11 @@ var offsetBase = []uint32{
 var codegenOrder = []uint32{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}
 
 type huffmanBitWriter struct {
-       w io.Writer
+       // writer is the underlying writer.
+       // Do not use it directly; use the write method, which ensures
+       // that Write errors are sticky.
+       writer io.Writer
+
        // Data waiting to be written is bytes[0:nbytes]
        // and then the low nbits of bits.
        bits            uint64
@@ -96,7 +100,7 @@ type huffmanBitWriter struct {
 
 func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
        return &huffmanBitWriter{
-               w:               w,
+               writer:          w,
                literalFreq:     make([]int32, maxNumLit),
                offsetFreq:      make([]int32, offsetCodeCount),
                codegen:         make([]uint8, maxNumLit+offsetCodeCount+1),
@@ -107,7 +111,7 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
 }
 
 func (w *huffmanBitWriter) reset(writer io.Writer) {
-       w.w = writer
+       w.writer = writer
        w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil
        w.bytes = [bufferSize]byte{}
 }
@@ -129,11 +133,21 @@ func (w *huffmanBitWriter) flush() {
                n++
        }
        w.bits = 0
-       _, w.err = w.w.Write(w.bytes[:n])
+       w.write(w.bytes[:n])
        w.nbytes = 0
 }
 
+func (w *huffmanBitWriter) write(b []byte) {
+       if w.err != nil {
+               return
+       }
+       _, w.err = w.writer.Write(b)
+}
+
 func (w *huffmanBitWriter) writeBits(b int32, nb uint) {
+       if w.err != nil {
+               return
+       }
        w.bits |= uint64(b) << w.nbits
        w.nbits += nb
        if w.nbits >= 48 {
@@ -150,7 +164,7 @@ func (w *huffmanBitWriter) writeBits(b int32, nb uint) {
                bytes[5] = byte(bits >> 40)
                n += 6
                if n >= bufferFlushSize {
-                       _, w.err = w.w.Write(w.bytes[:n])
+                       w.write(w.bytes[:n])
                        n = 0
                }
                w.nbytes = n
@@ -173,13 +187,10 @@ func (w *huffmanBitWriter) writeBytes(bytes []byte) {
                n++
        }
        if n != 0 {
-               _, w.err = w.w.Write(w.bytes[:n])
-               if w.err != nil {
-                       return
-               }
+               w.write(w.bytes[:n])
        }
        w.nbytes = 0
-       _, w.err = w.w.Write(bytes)
+       w.write(bytes)
 }
 
 // RFC 1951 3.2.7 specifies a special run-length encoding for specifying
@@ -341,7 +352,7 @@ func (w *huffmanBitWriter) writeCode(c hcode) {
                bytes[5] = byte(bits >> 40)
                n += 6
                if n >= bufferFlushSize {
-                       _, w.err = w.w.Write(w.bytes[:n])
+                       w.write(w.bytes[:n])
                        n = 0
                }
                w.nbytes = n
@@ -572,6 +583,9 @@ func (w *huffmanBitWriter) indexTokens(tokens []token) (numLiterals, numOffsets
 // writeTokens writes a slice of tokens to the output.
 // codes for literal and offset encoding must be supplied.
 func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode) {
+       if w.err != nil {
+               return
+       }
        for _, t := range tokens {
                if t < matchType {
                        w.writeCode(leCodes[t.literal()])
@@ -676,9 +690,9 @@ func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte) {
                if n < bufferFlushSize {
                        continue
                }
-               _, w.err = w.w.Write(w.bytes[:n])
+               w.write(w.bytes[:n])
                if w.err != nil {
-                       return
+                       return // Return early in the event of write failures
                }
                n = 0
        }