]> Cypherpunks repositories - gostls13.git/commitdiff
compress/zlib: make errors persistent
authorJoe Tsai <joetsai@digital-static.net>
Sun, 6 Mar 2016 23:51:57 +0000 (15:51 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 7 Mar 2016 18:32:04 +0000 (18:32 +0000)
Ensure that all errors (including io.EOF) are persistent across method
calls on zlib.Reader. Furthermore, ensure that these persistent errors
are properly cleared when Reset is called.

Fixes #14675

Change-Id: I15a20c7e25dc38219e7e0ff255d1ba775a86bb47
Reviewed-on: https://go-review.googlesource.com/20292
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/compress/zlib/reader.go
src/compress/zlib/reader_test.go

index 78ea7043bcee353987e41b8ca210634247b75ad9..30535fd980e55bb180d57c05c701c36eec9b0415 100644 (file)
@@ -84,19 +84,17 @@ func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, error) {
        return z, nil
 }
 
-func (z *reader) Read(p []byte) (n int, err error) {
+func (z *reader) Read(p []byte) (int, error) {
        if z.err != nil {
                return 0, z.err
        }
-       if len(p) == 0 {
-               return 0, nil
-       }
 
-       n, err = z.decompressor.Read(p)
+       var n int
+       n, z.err = z.decompressor.Read(p)
        z.digest.Write(p[0:n])
-       if n != 0 || err != io.EOF {
-               z.err = err
-               return
+       if z.err != io.EOF {
+               // In the normal case we return here.
+               return n, z.err
        }
 
        // Finished file; check checksum.
@@ -105,20 +103,20 @@ func (z *reader) Read(p []byte) (n int, err error) {
                        err = io.ErrUnexpectedEOF
                }
                z.err = err
-               return 0, err
+               return n, z.err
        }
        // ZLIB (RFC 1950) is big-endian, unlike GZIP (RFC 1952).
        checksum := uint32(z.scratch[0])<<24 | uint32(z.scratch[1])<<16 | uint32(z.scratch[2])<<8 | uint32(z.scratch[3])
        if checksum != z.digest.Sum32() {
                z.err = ErrChecksum
-               return 0, z.err
+               return n, z.err
        }
-       return
+       return n, io.EOF
 }
 
 // Calling Close does not close the wrapped io.Reader originally passed to NewReader.
 func (z *reader) Close() error {
-       if z.err != nil {
+       if z.err != nil && z.err != io.EOF {
                return z.err
        }
        z.err = z.decompressor.Close()
@@ -126,36 +124,42 @@ func (z *reader) Close() error {
 }
 
 func (z *reader) Reset(r io.Reader, dict []byte) error {
+       *z = reader{decompressor: z.decompressor}
        if fr, ok := r.(flate.Reader); ok {
                z.r = fr
        } else {
                z.r = bufio.NewReader(r)
        }
-       _, err := io.ReadFull(z.r, z.scratch[0:2])
-       if err != nil {
-               if err == io.EOF {
-                       err = io.ErrUnexpectedEOF
+
+       // Read the header (RFC 1950 section 2.2.).
+       _, z.err = io.ReadFull(z.r, z.scratch[0:2])
+       if z.err != nil {
+               if z.err == io.EOF {
+                       z.err = io.ErrUnexpectedEOF
                }
-               return err
+               return z.err
        }
        h := uint(z.scratch[0])<<8 | uint(z.scratch[1])
        if (z.scratch[0]&0x0f != zlibDeflate) || (h%31 != 0) {
-               return ErrHeader
+               z.err = ErrHeader
+               return z.err
        }
        haveDict := z.scratch[1]&0x20 != 0
        if haveDict {
-               _, err = io.ReadFull(z.r, z.scratch[0:4])
-               if err != nil {
-                       if err == io.EOF {
-                               err = io.ErrUnexpectedEOF
+               _, z.err = io.ReadFull(z.r, z.scratch[0:4])
+               if z.err != nil {
+                       if z.err == io.EOF {
+                               z.err = io.ErrUnexpectedEOF
                        }
-                       return err
+                       return z.err
                }
                checksum := uint32(z.scratch[0])<<24 | uint32(z.scratch[1])<<16 | uint32(z.scratch[2])<<8 | uint32(z.scratch[3])
                if checksum != adler32.Checksum(dict) {
-                       return ErrDictionary
+                       z.err = ErrDictionary
+                       return z.err
                }
        }
+
        if z.decompressor == nil {
                if haveDict {
                        z.decompressor = flate.NewReaderDict(z.r, dict)
index 449f4460bcd3baa0d9b98b7c4f32b974b85b3d86..f74bff1f3cc6e9acce076d8eb8cd595b5b03f9b2 100644 (file)
@@ -127,16 +127,18 @@ func TestDecompressor(t *testing.T) {
        b := new(bytes.Buffer)
        for _, tt := range zlibTests {
                in := bytes.NewReader(tt.compressed)
-               zlib, err := NewReaderDict(in, tt.dict)
+               zr, err := NewReaderDict(in, tt.dict)
                if err != nil {
                        if err != tt.err {
                                t.Errorf("%s: NewReader: %s", tt.desc, err)
                        }
                        continue
                }
-               defer zlib.Close()
+               defer zr.Close()
+
+               // Read and verify correctness of data.
                b.Reset()
-               n, err := io.Copy(b, zlib)
+               n, err := io.Copy(b, zr)
                if err != nil {
                        if err != tt.err {
                                t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, tt.err)
@@ -147,5 +149,13 @@ func TestDecompressor(t *testing.T) {
                if s != tt.raw {
                        t.Errorf("%s: got %d-byte %q want %d-byte %q", tt.desc, n, s, len(tt.raw), tt.raw)
                }
+
+               // Check for sticky errors.
+               if n, err := zr.Read([]byte{0}); n != 0 || err != io.EOF {
+                       t.Errorf("%s: Read() = (%d, %v), want (0, io.EOF)", tt.desc, n, err)
+               }
+               if err := zr.Close(); err != nil {
+                       t.Errorf("%s: Close() = %v, want nil", tt.desc, err)
+               }
        }
 }