]> Cypherpunks repositories - gostls13.git/commitdiff
compress/flate: forward upstream Writer errors
authorKlaus Post <klauspost@gmail.com>
Thu, 10 Mar 2016 15:46:25 +0000 (16:46 +0100)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 10 Mar 2016 17:46:46 +0000 (17:46 +0000)
If the upstream writer has returned an error, it may not
be returned by subsequent calls.

This makes sure that if an error has been returned, the
Writer will keep returning an error on all subsequent calls,
and not silently "swallow" them.

Change-Id: I2c9f614df72e1f4786705bf94e119b66c62abe5e
Reviewed-on: https://go-review.googlesource.com/20515
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
src/compress/flate/deflate.go
src/compress/flate/writer_test.go

index 8bcd61ac2ab3002abae4d475e5463e3604415745..199fc4cf3c78f3930c4ecfd42a2b8eedbd80c479 100644 (file)
@@ -373,16 +373,25 @@ func (d *compressor) store() {
 }
 
 func (d *compressor) write(b []byte) (n int, err error) {
+       if d.err != nil {
+               return 0, d.err
+       }
        n = len(b)
        b = b[d.fill(d, b):]
        for len(b) > 0 {
                d.step(d)
                b = b[d.fill(d, b):]
+               if d.err != nil {
+                       return 0, d.err
+               }
        }
-       return n, d.err
+       return n, nil
 }
 
 func (d *compressor) syncFlush() error {
+       if d.err != nil {
+               return d.err
+       }
        d.sync = true
        d.step(d)
        if d.err == nil {
@@ -461,6 +470,9 @@ func (d *compressor) reset(w io.Writer) {
 }
 
 func (d *compressor) close() error {
+       if d.err != nil {
+               return d.err
+       }
        d.sync = true
        d.step(d)
        if d.err != nil {
index 85101afafbed4157f69ea8a5f08a9aa247e75e77..5c18ba346ca605257f8d434ad3e0fcbfc0221136 100644 (file)
@@ -5,6 +5,9 @@
 package flate
 
 import (
+       "bytes"
+       "fmt"
+       "io"
        "io/ioutil"
        "runtime"
        "testing"
@@ -59,3 +62,68 @@ func BenchmarkEncodeTwainDefault1e6(b *testing.B)   { benchmarkEncoder(b, twain,
 func BenchmarkEncodeTwainCompress1e4(b *testing.B)  { benchmarkEncoder(b, twain, compress, 1e4) }
 func BenchmarkEncodeTwainCompress1e5(b *testing.B)  { benchmarkEncoder(b, twain, compress, 1e5) }
 func BenchmarkEncodeTwainCompress1e6(b *testing.B)  { benchmarkEncoder(b, twain, compress, 1e6) }
+
+// errorWriter is a writer that fails after N writes.
+type errorWriter struct {
+       N int
+}
+
+func (e *errorWriter) Write(b []byte) (int, error) {
+       if e.N <= 0 {
+               return 0, io.ErrClosedPipe
+       }
+       e.N--
+       return len(b), nil
+}
+
+// Test if errors from the underlying writer is passed upwards.
+func TestWriteError(t *testing.T) {
+       buf := new(bytes.Buffer)
+       for i := 0; i < 1024*1024; i++ {
+               buf.WriteString(fmt.Sprintf("asdasfasf%d%dfghfgujyut%dyutyu\n", i, i, i))
+       }
+       in := buf.Bytes()
+       // We create our own buffer to control number of writes.
+       copyBuffer := make([]byte, 1024)
+       for l := 0; l < 10; l++ {
+               for fail := 1; fail <= 512; fail *= 2 {
+                       // Fail after 'fail' writes
+                       ew := &errorWriter{N: fail}
+                       w, err := NewWriter(ew, l)
+                       if err != nil {
+                               t.Fatalf("NewWriter: level %d: %v", l, err)
+                       }
+                       n, err := io.CopyBuffer(w, bytes.NewBuffer(in), copyBuffer)
+                       if err == nil {
+                               t.Fatalf("Level %d: Expected an error, writer was %#v", l, ew)
+                       }
+                       n2, err := w.Write([]byte{1, 2, 2, 3, 4, 5})
+                       if n2 != 0 {
+                               t.Fatal("Level", l, "Expected 0 length write, got", n)
+                       }
+                       if err == nil {
+                               t.Fatal("Level", l, "Expected an error")
+                       }
+                       err = w.Flush()
+                       if err == nil {
+                               t.Fatal("Level", l, "Expected an error on flush")
+                       }
+                       err = w.Close()
+                       if err == nil {
+                               t.Fatal("Level", l, "Expected an error on close")
+                       }
+
+                       w.Reset(ioutil.Discard)
+                       n2, err = w.Write([]byte{1, 2, 3, 4, 5, 6})
+                       if err != nil {
+                               t.Fatal("Level", l, "Got unexpected error after reset:", err)
+                       }
+                       if n2 == 0 {
+                               t.Fatal("Level", l, "Got 0 length write, expected > 0")
+                       }
+                       if testing.Short() {
+                               return
+                       }
+               }
+       }
+}