]> Cypherpunks repositories - gostls13.git/commitdiff
archive/zip: avoid overflow in record count and byte offset fields
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 15 Nov 2016 19:33:10 +0000 (11:33 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 17 Nov 2016 18:54:33 +0000 (18:54 +0000)
This is Quentin's https://golang.org/cl/33012 with updated tests.

Fixes #14186

Change-Id: Ib51deaab0368c6bad32ce9d6345119ff44f3c2d6
Reviewed-on: https://go-review.googlesource.com/33291
Reviewed-by: Quentin Smith <quentin@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/archive/zip/writer.go
src/archive/zip/zip_test.go

index 4ab993d9497c31268330ed533e3622dd46d90122..ea4559e698428acd02aba0d5992a729debe2e7ee 100644 (file)
@@ -22,6 +22,10 @@ type Writer struct {
        last        *fileWriter
        closed      bool
        compressors map[uint16]Compressor
+
+       // testHookCloseSizeOffset if non-nil is called with the size
+       // of offset of the central directory at Close.
+       testHookCloseSizeOffset func(size, offset uint64)
 }
 
 type header struct {
@@ -140,7 +144,11 @@ func (w *Writer) Close() error {
        size := uint64(end - start)
        offset := uint64(start)
 
-       if records > uint16max || size > uint32max || offset > uint32max {
+       if f := w.testHookCloseSizeOffset; f != nil {
+               f(size, offset)
+       }
+
+       if records >= uint16max || size >= uint32max || offset >= uint32max {
                var buf [directory64EndLen + directory64LocLen]byte
                b := writeBuf(buf[:])
 
index f166b76e3f97590ede186515751579875628abcb..e1e67e535703db415f75455f5df95aa4fbcc157b 100644 (file)
@@ -8,6 +8,7 @@ package zip
 
 import (
        "bytes"
+       "errors"
        "fmt"
        "hash"
        "internal/testenv"
@@ -271,6 +272,7 @@ func TestZip64(t *testing.T) {
        if testing.Short() {
                t.Skip("slow test; skipping")
        }
+       t.Parallel()
        const size = 1 << 32 // before the "END\n" part
        buf := testZip64(t, size)
        testZip64DirectoryRecordLength(buf, t)
@@ -280,6 +282,7 @@ func TestZip64EdgeCase(t *testing.T) {
        if testing.Short() {
                t.Skip("slow test; skipping")
        }
+       t.Parallel()
        // Test a zip file with uncompressed size 0xFFFFFFFF.
        // That's the magic marker for a 64-bit file, so even though
        // it fits in a 32-bit field we must use the 64-bit field.
@@ -290,6 +293,250 @@ func TestZip64EdgeCase(t *testing.T) {
        testZip64DirectoryRecordLength(buf, t)
 }
 
+// Tests that we generate a zip64 file if the the directory at offset
+// 0xFFFFFFFF, but not before.
+func TestZip64DirectoryOffset(t *testing.T) {
+       t.Parallel()
+       const filename = "huge.txt"
+       gen := func(wantOff uint64) func(*Writer) {
+               return func(w *Writer) {
+                       w.testHookCloseSizeOffset = func(size, off uint64) {
+                               if off != wantOff {
+                                       t.Errorf("central directory offset = %d (%x); want %d", off, off, wantOff)
+                               }
+                       }
+                       f, err := w.CreateHeader(&FileHeader{
+                               Name:   filename,
+                               Method: Store,
+                       })
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       f.(*fileWriter).crc32 = fakeHash32{}
+                       size := wantOff - fileHeaderLen - uint64(len(filename)) - dataDescriptorLen
+                       if _, err := io.CopyN(f, zeros{}, int64(size)); err != nil {
+                               t.Fatal(err)
+                       }
+                       if err := w.Close(); err != nil {
+                               t.Fatal(err)
+                       }
+               }
+       }
+       t.Run("uint32max-2_NoZip64", func(t *testing.T) {
+               t.Parallel()
+               if generatesZip64(t, gen(0xfffffffe)) {
+                       t.Error("unexpected zip64")
+               }
+       })
+       t.Run("uint32max-1_Zip64", func(t *testing.T) {
+               t.Parallel()
+               if !generatesZip64(t, gen(0xffffffff)) {
+                       t.Error("expected zip64")
+               }
+       })
+}
+
+// At 16k records, we need to generate a zip64 file.
+func TestZip64ManyRecords(t *testing.T) {
+       t.Parallel()
+       gen := func(numRec int) func(*Writer) {
+               return func(w *Writer) {
+                       for i := 0; i < numRec; i++ {
+                               _, err := w.CreateHeader(&FileHeader{
+                                       Name:   "a.txt",
+                                       Method: Store,
+                               })
+                               if err != nil {
+                                       t.Fatal(err)
+                               }
+                       }
+                       if err := w.Close(); err != nil {
+                               t.Fatal(err)
+                       }
+               }
+       }
+       // 16k-1 records shouldn't make a zip64:
+       t.Run("uint16max-1_NoZip64", func(t *testing.T) {
+               t.Parallel()
+               if generatesZip64(t, gen(0xfffe)) {
+                       t.Error("unexpected zip64")
+               }
+       })
+       // 16k records should make a zip64:
+       t.Run("uint16max_Zip64", func(t *testing.T) {
+               t.Parallel()
+               if !generatesZip64(t, gen(0xffff)) {
+                       t.Error("expected zip64")
+               }
+       })
+}
+
+// suffixSaver is an io.Writer & io.ReaderAt that remembers the last 0
+// to 'keep' bytes of data written to it. Call Suffix to get the
+// suffix bytes.
+type suffixSaver struct {
+       keep  int
+       buf   []byte
+       start int
+       size  int64
+}
+
+func (ss *suffixSaver) Size() int64 { return ss.size }
+
+var errDiscardedBytes = errors.New("ReadAt of discarded bytes")
+
+func (ss *suffixSaver) ReadAt(p []byte, off int64) (n int, err error) {
+       back := ss.size - off
+       if back > int64(ss.keep) {
+               return 0, errDiscardedBytes
+       }
+       suf := ss.Suffix()
+       n = copy(p, suf[len(suf)-int(back):])
+       if n != len(p) {
+               err = io.EOF
+       }
+       return
+}
+
+func (ss *suffixSaver) Suffix() []byte {
+       if len(ss.buf) < ss.keep {
+               return ss.buf
+       }
+       buf := make([]byte, ss.keep)
+       n := copy(buf, ss.buf[ss.start:])
+       copy(buf[n:], ss.buf[:])
+       return buf
+}
+
+func (ss *suffixSaver) Write(p []byte) (n int, err error) {
+       n = len(p)
+       ss.size += int64(len(p))
+       if len(ss.buf) < ss.keep {
+               space := ss.keep - len(ss.buf)
+               add := len(p)
+               if add > space {
+                       add = space
+               }
+               ss.buf = append(ss.buf, p[:add]...)
+               p = p[add:]
+       }
+       for len(p) > 0 {
+               n := copy(ss.buf[ss.start:], p)
+               p = p[n:]
+               ss.start += n
+               if ss.start == ss.keep {
+                       ss.start = 0
+               }
+       }
+       return
+}
+
+// generatesZip64 reports whether f wrote a zip64 file.
+// f is also responsible for closing w.
+func generatesZip64(t *testing.T, f func(w *Writer)) bool {
+       ss := &suffixSaver{keep: 10 << 20}
+       w := NewWriter(ss)
+       f(w)
+       return suffixIsZip64(t, ss)
+}
+
+type sizedReaderAt interface {
+       io.ReaderAt
+       Size() int64
+}
+
+func suffixIsZip64(t *testing.T, zip sizedReaderAt) bool {
+       d := make([]byte, 1024)
+       if _, err := zip.ReadAt(d, zip.Size()-int64(len(d))); err != nil {
+               t.Fatalf("ReadAt: %v", err)
+       }
+
+       sigOff := findSignatureInBlock(d)
+       if sigOff == -1 {
+               t.Errorf("failed to find signature in block")
+               return false
+       }
+
+       dirOff, err := findDirectory64End(zip, zip.Size()-int64(len(d))+int64(sigOff))
+       if err != nil {
+               t.Fatalf("findDirectory64End: %v", err)
+       }
+       if dirOff == -1 {
+               return false
+       }
+
+       d = make([]byte, directory64EndLen)
+       if _, err := zip.ReadAt(d, dirOff); err != nil {
+               t.Fatalf("ReadAt(off=%d): %v", dirOff, err)
+       }
+
+       b := readBuf(d)
+       if sig := b.uint32(); sig != directory64EndSignature {
+               return false
+       }
+
+       size := b.uint64()
+       if size != directory64EndLen-12 {
+               t.Errorf("expected length of %d, got %d", directory64EndLen-12, size)
+       }
+       return true
+}
+
+// Zip64 is required if the total size of the records is uint32max.
+func TestZip64LargeDirectory(t *testing.T) {
+       if testing.Short() {
+               t.Skip("skipping in short mode")
+       }
+       t.Parallel()
+       // gen returns a func that writes a zip with a wantLen bytes
+       // of central directory.
+       gen := func(wantLen int64) func(*Writer) {
+               return func(w *Writer) {
+                       w.testHookCloseSizeOffset = func(size, off uint64) {
+                               if size != uint64(wantLen) {
+                                       t.Errorf("Close central directory size = %d; want %d", size, wantLen)
+                               }
+                       }
+
+                       uint16string := strings.Repeat(".", uint16max)
+                       remain := wantLen
+                       for remain > 0 {
+                               commentLen := int(uint16max) - directoryHeaderLen - 1
+                               thisRecLen := directoryHeaderLen + int(uint16max) + commentLen
+                               if int64(thisRecLen) > remain {
+                                       remove := thisRecLen - int(remain)
+                                       commentLen -= remove
+                                       thisRecLen -= remove
+                               }
+                               remain -= int64(thisRecLen)
+                               f, err := w.CreateHeader(&FileHeader{
+                                       Name:    uint16string,
+                                       Comment: uint16string[:commentLen],
+                               })
+                               if err != nil {
+                                       t.Fatalf("CreateHeader: %v", err)
+                               }
+                               f.(*fileWriter).crc32 = fakeHash32{}
+                       }
+                       if err := w.Close(); err != nil {
+                               t.Fatalf("Close: %v", err)
+                       }
+               }
+       }
+       t.Run("uint32max-1_NoZip64", func(t *testing.T) {
+               t.Parallel()
+               if generatesZip64(t, gen(uint32max-1)) {
+                       t.Error("unexpected zip64")
+               }
+       })
+       t.Run("uint32max_HasZip64", func(t *testing.T) {
+               t.Parallel()
+               if !generatesZip64(t, gen(uint32max)) {
+                       t.Error("expected zip64")
+               }
+       })
+}
+
 func testZip64(t testing.TB, size int64) *rleBuffer {
        const chunkSize = 1024
        chunks := int(size / chunkSize)
@@ -378,30 +625,8 @@ func testZip64(t testing.TB, size int64) *rleBuffer {
 
 // Issue 9857
 func testZip64DirectoryRecordLength(buf *rleBuffer, t *testing.T) {
-       d := make([]byte, 1024)
-       if _, err := buf.ReadAt(d, buf.Size()-int64(len(d))); err != nil {
-               t.Fatal("read:", err)
-       }
-
-       sigOff := findSignatureInBlock(d)
-       dirOff, err := findDirectory64End(buf, buf.Size()-int64(len(d))+int64(sigOff))
-       if err != nil {
-               t.Fatal("findDirectory64End:", err)
-       }
-
-       d = make([]byte, directory64EndLen)
-       if _, err := buf.ReadAt(d, dirOff); err != nil {
-               t.Fatal("read:", err)
-       }
-
-       b := readBuf(d)
-       if sig := b.uint32(); sig != directory64EndSignature {
-               t.Fatalf("Expected directory64EndSignature (%d), got %d", directory64EndSignature, sig)
-       }
-
-       size := b.uint64()
-       if size != directory64EndLen-12 {
-               t.Fatalf("Expected length of %d, got %d", directory64EndLen-12, size)
+       if !suffixIsZip64(t, buf) {
+               t.Fatal("not a zip64")
        }
 }
 
@@ -487,3 +712,47 @@ func BenchmarkZip64Test(b *testing.B) {
                testZip64(b, 1<<26)
        }
 }
+
+func TestSuffixSaver(t *testing.T) {
+       const keep = 10
+       ss := &suffixSaver{keep: keep}
+       ss.Write([]byte("abc"))
+       if got := string(ss.Suffix()); got != "abc" {
+               t.Errorf("got = %q; want abc", got)
+       }
+       ss.Write([]byte("defghijklmno"))
+       if got := string(ss.Suffix()); got != "fghijklmno" {
+               t.Errorf("got = %q; want fghijklmno", got)
+       }
+       if got, want := ss.Size(), int64(len("abc")+len("defghijklmno")); got != want {
+               t.Errorf("Size = %d; want %d", got, want)
+       }
+       buf := make([]byte, ss.Size())
+       for off := int64(0); off < ss.Size(); off++ {
+               for size := 1; size <= int(ss.Size()-off); size++ {
+                       readBuf := buf[:size]
+                       n, err := ss.ReadAt(readBuf, off)
+                       if off < ss.Size()-keep {
+                               if err != errDiscardedBytes {
+                                       t.Errorf("off %d, size %d = %v, %v (%q); want errDiscardedBytes", off, size, n, err, readBuf[:n])
+                               }
+                               continue
+                       }
+                       want := "abcdefghijklmno"[off : off+int64(size)]
+                       got := string(readBuf[:n])
+                       if err != nil || got != want {
+                               t.Errorf("off %d, size %d = %v, %v (%q); want %q", off, size, n, err, got, want)
+                       }
+               }
+       }
+
+}
+
+type zeros struct{}
+
+func (zeros) Read(p []byte) (int, error) {
+       for i := range p {
+               p[i] = 0
+       }
+       return len(p), nil
+}