From: Joe Tsai Date: Sun, 6 Mar 2016 23:51:57 +0000 (-0800) Subject: compress/zlib: make errors persistent X-Git-Tag: go1.7beta1~1514 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=5a966cf2da1f054d92c36b0f7f407d3ee467ef34;p=gostls13.git compress/zlib: make errors persistent Ensure that all errors (including io.EOF) are persistent across method calls on zlib.Reader. Furthermore, ensure that these persistent errors are properly cleared when Reset is called. Fixes #14675 Change-Id: I15a20c7e25dc38219e7e0ff255d1ba775a86bb47 Reviewed-on: https://go-review.googlesource.com/20292 Reviewed-by: Brad Fitzpatrick --- diff --git a/src/compress/zlib/reader.go b/src/compress/zlib/reader.go index 78ea7043bc..30535fd980 100644 --- a/src/compress/zlib/reader.go +++ b/src/compress/zlib/reader.go @@ -84,19 +84,17 @@ func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, error) { return z, nil } -func (z *reader) Read(p []byte) (n int, err error) { +func (z *reader) Read(p []byte) (int, error) { if z.err != nil { return 0, z.err } - if len(p) == 0 { - return 0, nil - } - n, err = z.decompressor.Read(p) + var n int + n, z.err = z.decompressor.Read(p) z.digest.Write(p[0:n]) - if n != 0 || err != io.EOF { - z.err = err - return + if z.err != io.EOF { + // In the normal case we return here. + return n, z.err } // Finished file; check checksum. @@ -105,20 +103,20 @@ func (z *reader) Read(p []byte) (n int, err error) { err = io.ErrUnexpectedEOF } z.err = err - return 0, err + return n, z.err } // ZLIB (RFC 1950) is big-endian, unlike GZIP (RFC 1952). checksum := uint32(z.scratch[0])<<24 | uint32(z.scratch[1])<<16 | uint32(z.scratch[2])<<8 | uint32(z.scratch[3]) if checksum != z.digest.Sum32() { z.err = ErrChecksum - return 0, z.err + return n, z.err } - return + return n, io.EOF } // Calling Close does not close the wrapped io.Reader originally passed to NewReader. func (z *reader) Close() error { - if z.err != nil { + if z.err != nil && z.err != io.EOF { return z.err } z.err = z.decompressor.Close() @@ -126,36 +124,42 @@ func (z *reader) Close() error { } func (z *reader) Reset(r io.Reader, dict []byte) error { + *z = reader{decompressor: z.decompressor} if fr, ok := r.(flate.Reader); ok { z.r = fr } else { z.r = bufio.NewReader(r) } - _, err := io.ReadFull(z.r, z.scratch[0:2]) - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF + + // Read the header (RFC 1950 section 2.2.). + _, z.err = io.ReadFull(z.r, z.scratch[0:2]) + if z.err != nil { + if z.err == io.EOF { + z.err = io.ErrUnexpectedEOF } - return err + return z.err } h := uint(z.scratch[0])<<8 | uint(z.scratch[1]) if (z.scratch[0]&0x0f != zlibDeflate) || (h%31 != 0) { - return ErrHeader + z.err = ErrHeader + return z.err } haveDict := z.scratch[1]&0x20 != 0 if haveDict { - _, err = io.ReadFull(z.r, z.scratch[0:4]) - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF + _, z.err = io.ReadFull(z.r, z.scratch[0:4]) + if z.err != nil { + if z.err == io.EOF { + z.err = io.ErrUnexpectedEOF } - return err + return z.err } checksum := uint32(z.scratch[0])<<24 | uint32(z.scratch[1])<<16 | uint32(z.scratch[2])<<8 | uint32(z.scratch[3]) if checksum != adler32.Checksum(dict) { - return ErrDictionary + z.err = ErrDictionary + return z.err } } + if z.decompressor == nil { if haveDict { z.decompressor = flate.NewReaderDict(z.r, dict) diff --git a/src/compress/zlib/reader_test.go b/src/compress/zlib/reader_test.go index 449f4460bc..f74bff1f3c 100644 --- a/src/compress/zlib/reader_test.go +++ b/src/compress/zlib/reader_test.go @@ -127,16 +127,18 @@ func TestDecompressor(t *testing.T) { b := new(bytes.Buffer) for _, tt := range zlibTests { in := bytes.NewReader(tt.compressed) - zlib, err := NewReaderDict(in, tt.dict) + zr, err := NewReaderDict(in, tt.dict) if err != nil { if err != tt.err { t.Errorf("%s: NewReader: %s", tt.desc, err) } continue } - defer zlib.Close() + defer zr.Close() + + // Read and verify correctness of data. b.Reset() - n, err := io.Copy(b, zlib) + n, err := io.Copy(b, zr) if err != nil { if err != tt.err { t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, tt.err) @@ -147,5 +149,13 @@ func TestDecompressor(t *testing.T) { if s != tt.raw { t.Errorf("%s: got %d-byte %q want %d-byte %q", tt.desc, n, s, len(tt.raw), tt.raw) } + + // Check for sticky errors. + if n, err := zr.Read([]byte{0}); n != 0 || err != io.EOF { + t.Errorf("%s: Read() = (%d, %v), want (0, io.EOF)", tt.desc, n, err) + } + if err := zr.Close(); err != nil { + t.Errorf("%s: Close() = %v, want nil", tt.desc, err) + } } }