]> Cypherpunks repositories - gostls13.git/commitdiff
compress/flate: return error on closed stream write
authorGregory Man <man.gregory@gmail.com>
Thu, 20 Sep 2018 08:48:17 +0000 (11:48 +0300)
committerGopher Robot <gobot@golang.org>
Mon, 2 May 2022 14:46:44 +0000 (14:46 +0000)
Previously flate.Writer allowed writes after Close, and this behavior
could lead to stream corruption.

Fixes #27741

Change-Id: Iee1ac69f8199232f693dba77b275f7078257b582
Reviewed-on: https://go-review.googlesource.com/c/go/+/136475
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Ian Lance Taylor <iant@google.com>
Reviewed-by: Joseph Tsai <joetsai@digital-static.net>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@google.com>

src/compress/flate/deflate.go
src/compress/flate/deflate_test.go

index 550032176d76ca119e6b04f4d346b783acded6f9..ccf03d74eb211baccefd4699fce917cc6ed622ab 100644 (file)
@@ -5,6 +5,7 @@
 package flate
 
 import (
+       "errors"
        "fmt"
        "io"
        "math"
@@ -699,17 +700,27 @@ func (w *dictWriter) Write(b []byte) (n int, err error) {
        return w.w.Write(b)
 }
 
+var errWriteAfterClose = errors.New("compress/flate: write after close")
+
 // A Writer takes data written to it and writes the compressed
 // form of that data to an underlying writer (see NewWriter).
 type Writer struct {
        d    compressor
        dict []byte
+       err  error
 }
 
 // Write writes data to w, which will eventually write the
 // compressed form of data to its underlying writer.
 func (w *Writer) Write(data []byte) (n int, err error) {
-       return w.d.write(data)
+       if w.err != nil {
+               return 0, w.err
+       }
+       n, err = w.d.write(data)
+       if err != nil {
+               w.err = err
+       }
+       return n, err
 }
 
 // Flush flushes any pending data to the underlying writer.
@@ -724,18 +735,37 @@ func (w *Writer) Write(data []byte) (n int, err error) {
 func (w *Writer) Flush() error {
        // For more about flushing:
        // https://www.bolet.org/~pornin/deflate-flush.html
-       return w.d.syncFlush()
+       if w.err != nil {
+               return w.err
+       }
+       if err := w.d.syncFlush(); err != nil {
+               w.err = err
+               return err
+       }
+       return nil
 }
 
 // Close flushes and closes the writer.
 func (w *Writer) Close() error {
-       return w.d.close()
+       if w.err == errWriteAfterClose {
+               return nil
+       }
+       if w.err != nil {
+               return w.err
+       }
+       if err := w.d.close(); err != nil {
+               w.err = err
+               return err
+       }
+       w.err = errWriteAfterClose
+       return nil
 }
 
 // Reset discards the writer's state and makes it equivalent to
 // the result of NewWriter or NewWriterDict called with dst
 // and w's level and dictionary.
 func (w *Writer) Reset(dst io.Writer) {
+       w.err = nil
        if dw, ok := w.d.w.writer.(*dictWriter); ok {
                // w was created with NewWriterDict
                dw.w = dst
index ff567121236d2d7377ea86e924b4aa75d0d3132a..8c9ee72feb47e0f06e92a715c0e0ac62a4c2e48a 100644 (file)
@@ -125,6 +125,40 @@ func TestDeflate(t *testing.T) {
        }
 }
 
+func TestWriterClose(t *testing.T) {
+       b := new(bytes.Buffer)
+       zw, err := NewWriter(b, 6)
+       if err != nil {
+               t.Fatalf("NewWriter: %v", err)
+       }
+
+       if c, err := zw.Write([]byte("Test")); err != nil || c != 4 {
+               t.Fatalf("Write to not closed writer: %s, %d", err, c)
+       }
+
+       if err := zw.Close(); err != nil {
+               t.Fatalf("Close: %v", err)
+       }
+
+       afterClose := b.Len()
+
+       if c, err := zw.Write([]byte("Test")); err == nil || c != 0 {
+               t.Fatalf("Write to closed writer: %s, %d", err, c)
+       }
+
+       if err := zw.Flush(); err == nil {
+               t.Fatalf("Flush to closed writer: %s", err)
+       }
+
+       if err := zw.Close(); err != nil {
+               t.Fatalf("Close: %v", err)
+       }
+
+       if afterClose != b.Len() {
+               t.Fatalf("Writer wrote data after close. After close: %d. After writes on closed stream: %d", afterClose, b.Len())
+       }
+}
+
 // A sparseReader returns a stream consisting of 0s followed by 1<<16 1s.
 // This tests missing hash references in a very large input.
 type sparseReader struct {
@@ -683,7 +717,7 @@ func (w *failWriter) Write(b []byte) (int, error) {
        return len(b), nil
 }
 
-func TestWriterPersistentError(t *testing.T) {
+func TestWriterPersistentWriteError(t *testing.T) {
        t.Parallel()
        d, err := os.ReadFile("../../testdata/Isaac.Newton-Opticks.txt")
        if err != nil {
@@ -706,12 +740,16 @@ func TestWriterPersistentError(t *testing.T) {
 
                _, werr := zw.Write(d)
                cerr := zw.Close()
+               ferr := zw.Flush()
                if werr != errIO && werr != nil {
                        t.Errorf("test %d, mismatching Write error: got %v, want %v", i, werr, errIO)
                }
                if cerr != errIO && fw.n < 0 {
                        t.Errorf("test %d, mismatching Close error: got %v, want %v", i, cerr, errIO)
                }
+               if ferr != errIO && fw.n < 0 {
+                       t.Errorf("test %d, mismatching Flush error: got %v, want %v", i, ferr, errIO)
+               }
                if fw.n >= 0 {
                        // At this point, the failure threshold was sufficiently high enough
                        // that we wrote the whole stream without any errors.
@@ -719,6 +757,54 @@ func TestWriterPersistentError(t *testing.T) {
                }
        }
 }
+func TestWriterPersistentFlushError(t *testing.T) {
+       zw, err := NewWriter(&failWriter{0}, DefaultCompression)
+       if err != nil {
+               t.Fatalf("NewWriter: %v", err)
+       }
+       flushErr := zw.Flush()
+       closeErr := zw.Close()
+       _, writeErr := zw.Write([]byte("Test"))
+       checkErrors([]error{closeErr, flushErr, writeErr}, errIO, t)
+}
+
+func TestWriterPersistentCloseError(t *testing.T) {
+       // If underlying writer return error on closing stream we should persistent this error across all writer calls.
+       zw, err := NewWriter(&failWriter{0}, DefaultCompression)
+       if err != nil {
+               t.Fatalf("NewWriter: %v", err)
+       }
+       closeErr := zw.Close()
+       flushErr := zw.Flush()
+       _, writeErr := zw.Write([]byte("Test"))
+       checkErrors([]error{closeErr, flushErr, writeErr}, errIO, t)
+
+       // After closing writer we should persistent "write after close" error across Flush and Write calls, but return nil
+       // on next Close calls.
+       var b bytes.Buffer
+       zw.Reset(&b)
+       err = zw.Close()
+       if err != nil {
+               t.Fatalf("First call to close returned error: %s", err)
+       }
+       err = zw.Close()
+       if err != nil {
+               t.Fatalf("Second call to close returned error: %s", err)
+       }
+
+       flushErr = zw.Flush()
+       _, writeErr = zw.Write([]byte("Test"))
+       checkErrors([]error{flushErr, writeErr}, errWriteAfterClose, t)
+}
+
+func checkErrors(got []error, want error, t *testing.T) {
+       t.Helper()
+       for _, err := range got {
+               if err != want {
+                       t.Errorf("Errors dosn't match\nWant: %s\nGot: %s", want, got)
+               }
+       }
+}
 
 func TestBestSpeedMatch(t *testing.T) {
        t.Parallel()