]> Cypherpunks repositories - gostls13.git/commitdiff
archive/zip: add Writer
authorAndrew Gerrand <adg@golang.org>
Sun, 10 Jul 2011 01:30:16 +0000 (11:30 +1000)
committerAndrew Gerrand <adg@golang.org>
Sun, 10 Jul 2011 01:30:16 +0000 (11:30 +1000)
R=bradfitz, dchest, r, rsc
CC=golang-dev
https://golang.org/cl/4523077

src/pkg/archive/zip/Makefile
src/pkg/archive/zip/reader.go
src/pkg/archive/zip/struct.go
src/pkg/archive/zip/writer.go [new file with mode: 0644]
src/pkg/archive/zip/writer_test.go [new file with mode: 0644]

index 32a543133c8fae40e6aeaed7ed1d90f18a8f5ca2..9071690f0a9b8ca33bd108a297b02ad6eb8ecaec 100644 (file)
@@ -8,5 +8,6 @@ TARG=archive/zip
 GOFILES=\
        reader.go\
        struct.go\
+       writer.go\
 
 include ../../../Make.pkg
index 17464c5d8e4129d65dc4e7a6ac01d66012d9c49e..7deff117cb35c09487a5b3d1d325eaff4c94de2a 100644 (file)
@@ -2,13 +2,6 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-/*
-Package zip provides support for reading ZIP archives.
-
-See: http://www.pkware.com/documents/casestudies/APPNOTE.TXT
-
-This package does not support ZIP64 or disk spanning.
-*/
 package zip
 
 import (
@@ -24,9 +17,9 @@ import (
 )
 
 var (
-       FormatError       = os.NewError("not a valid zip file")
-       UnsupportedMethod = os.NewError("unsupported compression algorithm")
-       ChecksumError     = os.NewError("checksum error")
+       FormatError       = os.NewError("zip: not a valid zip file")
+       UnsupportedMethod = os.NewError("zip: unsupported compression algorithm")
+       ChecksumError     = os.NewError("zip: checksum error")
 )
 
 type Reader struct {
@@ -52,7 +45,7 @@ func (f *File) hasDataDescriptor() bool {
        return f.Flags&0x8 != 0
 }
 
-// OpenReader will open the Zip file specified by name and return a ReaderCloser.
+// OpenReader will open the Zip file specified by name and return a ReadCloser.
 func OpenReader(name string) (*ReadCloser, os.Error) {
        f, err := os.Open(name)
        if err != nil {
@@ -111,6 +104,7 @@ func (rc *ReadCloser) Close() os.Error {
 // Open returns a ReadCloser that provides access to the File's contents.
 func (f *File) Open() (rc io.ReadCloser, err os.Error) {
        off := int64(f.headerOffset)
+       size := int64(f.CompressedSize)
        if f.bodyOffset == 0 {
                r := io.NewSectionReader(f.zipr, off, f.zipsize-off)
                if err = readFileHeader(f, r); err != nil {
@@ -119,21 +113,19 @@ func (f *File) Open() (rc io.ReadCloser, err os.Error) {
                if f.bodyOffset, err = r.Seek(0, os.SEEK_CUR); err != nil {
                        return
                }
-       }
-       size := int64(f.CompressedSize)
-       if f.hasDataDescriptor() {
                if size == 0 {
-                       // permit SectionReader to see the rest of the file
-                       size = f.zipsize - (off + f.bodyOffset)
-               } else {
-                       size += dataDescriptorLen
+                       size = int64(f.CompressedSize)
                }
        }
+       if f.hasDataDescriptor() && size == 0 {
+               // permit SectionReader to see the rest of the file
+               size = f.zipsize - (off + f.bodyOffset)
+       }
        r := io.NewSectionReader(f.zipr, off+f.bodyOffset, size)
        switch f.Method {
-       case 0: // store (no compression)
+       case Store: // (no compression)
                rc = ioutil.NopCloser(r)
-       case 8: // DEFLATE
+       case Deflate:
                rc = flate.NewReader(r)
        default:
                err = UnsupportedMethod
@@ -171,11 +163,7 @@ func (r *checksumReader) Read(b []byte) (n int, err os.Error) {
 func (r *checksumReader) Close() os.Error { return r.rc.Close() }
 
 func readFileHeader(f *File, r io.Reader) (err os.Error) {
-       defer func() {
-               if rerr, ok := recover().(os.Error); ok {
-                       err = rerr
-               }
-       }()
+       defer recoverError(&err)
        var (
                signature      uint32
                filenameLength uint16
@@ -201,11 +189,7 @@ func readFileHeader(f *File, r io.Reader) (err os.Error) {
 }
 
 func readDirectoryHeader(f *File, r io.Reader) (err os.Error) {
-       defer func() {
-               if rerr, ok := recover().(os.Error); ok {
-                       err = rerr
-               }
-       }()
+       defer recoverError(&err)
        var (
                signature          uint32
                filenameLength     uint16
@@ -242,18 +226,14 @@ func readDirectoryHeader(f *File, r io.Reader) (err os.Error) {
 }
 
 func readDataDescriptor(r io.Reader, f *File) (err os.Error) {
-       defer func() {
-               if rerr, ok := recover().(os.Error); ok {
-                       err = rerr
-               }
-       }()
+       defer recoverError(&err)
        read(r, &f.CRC32)
        read(r, &f.CompressedSize)
        read(r, &f.UncompressedSize)
        return
 }
 
-func readDirectoryEnd(r io.ReaderAt, size int64) (d *directoryEnd, err os.Error) {
+func readDirectoryEnd(r io.ReaderAt, size int64) (dir *directoryEnd, err os.Error) {
        // look for directoryEndSignature in the last 1k, then in the last 65k
        var b []byte
        for i, bLen := range []int64{1024, 65 * 1024} {
@@ -274,14 +254,9 @@ func readDirectoryEnd(r io.ReaderAt, size int64) (d *directoryEnd, err os.Error)
        }
 
        // read header into struct
-       defer func() {
-               if rerr, ok := recover().(os.Error); ok {
-                       err = rerr
-                       d = nil
-               }
-       }()
+       defer recoverError(&err)
        br := bytes.NewBuffer(b[4:]) // skip over signature
-       d = new(directoryEnd)
+       d := new(directoryEnd)
        read(br, &d.diskNbr)
        read(br, &d.dirDiskNbr)
        read(br, &d.dirRecordsThisDisk)
index bfe0aae2e9a2ac8ac8fefdc0a242fd59fb4aea6f..3092314c9c1e8eee64922e89dfab4fd1b1a9f15c 100644 (file)
@@ -1,5 +1,24 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+Package zip provides support for reading and writing ZIP archives.
+
+See: http://www.pkware.com/documents/casestudies/APPNOTE.TXT
+
+This package does not support ZIP64 or disk spanning.
+*/
 package zip
 
+import "os"
+
+// Compression methods.
+const (
+       Store   uint16 = 0
+       Deflate uint16 = 8
+)
+
 const (
        fileHeaderSignature      = 0x04034b50
        directoryHeaderSignature = 0x02014b50
@@ -32,3 +51,13 @@ type directoryEnd struct {
        commentLen         uint16
        comment            string
 }
+
+func recoverError(err *os.Error) {
+       if e := recover(); e != nil {
+               if osErr, ok := e.(os.Error); ok {
+                       *err = osErr
+                       return
+               }
+               panic(e)
+       }
+}
diff --git a/src/pkg/archive/zip/writer.go b/src/pkg/archive/zip/writer.go
new file mode 100644 (file)
index 0000000..2065b06
--- /dev/null
@@ -0,0 +1,244 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zip
+
+import (
+       "bufio"
+       "compress/flate"
+       "encoding/binary"
+       "hash"
+       "hash/crc32"
+       "io"
+       "os"
+)
+
+// TODO(adg): support zip file comments
+// TODO(adg): support specifying deflate level
+
+// Writer implements a zip file writer.
+type Writer struct {
+       *countWriter
+       dir    []*header
+       last   *fileWriter
+       closed bool
+}
+
+type header struct {
+       *FileHeader
+       offset uint32
+}
+
+// NewWriter returns a new Writer writing a zip file to w.
+func NewWriter(w io.Writer) *Writer {
+       return &Writer{countWriter: &countWriter{w: bufio.NewWriter(w)}}
+}
+
+// Close finishes writing the zip file by writing the central directory.
+// It does not (and can not) close the underlying writer.
+func (w *Writer) Close() (err os.Error) {
+       if w.last != nil && !w.last.closed {
+               if err = w.last.close(); err != nil {
+                       return
+               }
+               w.last = nil
+       }
+       if w.closed {
+               return os.NewError("zip: writer closed twice")
+       }
+       w.closed = true
+
+       defer recoverError(&err)
+
+       // write central directory
+       start := w.count
+       for _, h := range w.dir {
+               write(w, uint32(directoryHeaderSignature))
+               write(w, h.CreatorVersion)
+               write(w, h.ReaderVersion)
+               write(w, h.Flags)
+               write(w, h.Method)
+               write(w, h.ModifiedTime)
+               write(w, h.ModifiedDate)
+               write(w, h.CRC32)
+               write(w, h.CompressedSize)
+               write(w, h.UncompressedSize)
+               write(w, uint16(len(h.Name)))
+               write(w, uint16(len(h.Extra)))
+               write(w, uint16(len(h.Comment)))
+               write(w, uint16(0)) // disk number start
+               write(w, uint16(0)) // internal file attributes
+               write(w, uint32(0)) // external file attributes
+               write(w, h.offset)
+               writeBytes(w, []byte(h.Name))
+               writeBytes(w, h.Extra)
+               writeBytes(w, []byte(h.Comment))
+       }
+       end := w.count
+
+       // write end record
+       write(w, uint32(directoryEndSignature))
+       write(w, uint16(0))          // disk number
+       write(w, uint16(0))          // disk number where directory starts
+       write(w, uint16(len(w.dir))) // number of entries this disk
+       write(w, uint16(len(w.dir))) // number of entries total
+       write(w, uint32(end-start))  // size of directory
+       write(w, uint32(start))      // start of directory
+       write(w, uint16(0))          // size of comment
+
+       return w.w.(*bufio.Writer).Flush()
+}
+
+// Create adds a file to the zip file using the provided name.
+// It returns a Writer to which the file contents should be written.
+// The file's contents must be written to the io.Writer before the next
+// call to Create, CreateHeader, or Close.
+func (w *Writer) Create(name string) (io.Writer, os.Error) {
+       header := &FileHeader{
+               Name:   name,
+               Method: Deflate,
+       }
+       return w.CreateHeader(header)
+}
+
+// CreateHeader adds a file to the zip file using the provided FileHeader
+// for the file metadata. 
+// It returns a Writer to which the file contents should be written.
+// The file's contents must be written to the io.Writer before the next
+// call to Create, CreateHeader, or Close.
+func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, os.Error) {
+       if w.last != nil && !w.last.closed {
+               if err := w.last.close(); err != nil {
+                       return nil, err
+               }
+       }
+
+       fh.Flags |= 0x8 // we will write a data descriptor
+       fh.CreatorVersion = 0x14
+       fh.ReaderVersion = 0x14
+
+       fw := &fileWriter{
+               zipw:      w,
+               compCount: &countWriter{w: w},
+               crc32:     crc32.NewIEEE(),
+       }
+       switch fh.Method {
+       case Store:
+               fw.comp = nopCloser{fw.compCount}
+       case Deflate:
+               fw.comp = flate.NewWriter(fw.compCount, 5)
+       default:
+               return nil, UnsupportedMethod
+       }
+       fw.rawCount = &countWriter{w: fw.comp}
+
+       h := &header{
+               FileHeader: fh,
+               offset:     uint32(w.count),
+       }
+       w.dir = append(w.dir, h)
+       fw.header = h
+
+       if err := writeHeader(w, fh); err != nil {
+               return nil, err
+       }
+
+       w.last = fw
+       return fw, nil
+}
+
+func writeHeader(w io.Writer, h *FileHeader) (err os.Error) {
+       defer recoverError(&err)
+       write(w, uint32(fileHeaderSignature))
+       write(w, h.ReaderVersion)
+       write(w, h.Flags)
+       write(w, h.Method)
+       write(w, h.ModifiedTime)
+       write(w, h.ModifiedDate)
+       write(w, h.CRC32)
+       write(w, h.CompressedSize)
+       write(w, h.UncompressedSize)
+       write(w, uint16(len(h.Name)))
+       write(w, uint16(len(h.Extra)))
+       writeBytes(w, []byte(h.Name))
+       writeBytes(w, h.Extra)
+       return nil
+}
+
+type fileWriter struct {
+       *header
+       zipw      io.Writer
+       rawCount  *countWriter
+       comp      io.WriteCloser
+       compCount *countWriter
+       crc32     hash.Hash32
+       closed    bool
+}
+
+func (w *fileWriter) Write(p []byte) (int, os.Error) {
+       if w.closed {
+               return 0, os.NewError("zip: write to closed file")
+       }
+       w.crc32.Write(p)
+       return w.rawCount.Write(p)
+}
+
+func (w *fileWriter) close() (err os.Error) {
+       if w.closed {
+               return os.NewError("zip: file closed twice")
+       }
+       w.closed = true
+       if err = w.comp.Close(); err != nil {
+               return
+       }
+
+       // update FileHeader
+       fh := w.header.FileHeader
+       fh.CRC32 = w.crc32.Sum32()
+       fh.CompressedSize = uint32(w.compCount.count)
+       fh.UncompressedSize = uint32(w.rawCount.count)
+
+       // write data descriptor
+       defer recoverError(&err)
+       write(w.zipw, fh.CRC32)
+       write(w.zipw, fh.CompressedSize)
+       write(w.zipw, fh.UncompressedSize)
+
+       return nil
+}
+
+type countWriter struct {
+       w     io.Writer
+       count int64
+}
+
+func (w *countWriter) Write(p []byte) (int, os.Error) {
+       n, err := w.w.Write(p)
+       w.count += int64(n)
+       return n, err
+}
+
+type nopCloser struct {
+       io.Writer
+}
+
+func (w nopCloser) Close() os.Error {
+       return nil
+}
+
+func write(w io.Writer, data interface{}) {
+       if err := binary.Write(w, binary.LittleEndian, data); err != nil {
+               panic(err)
+       }
+}
+
+func writeBytes(w io.Writer, b []byte) {
+       n, err := w.Write(b)
+       if err != nil {
+               panic(err)
+       }
+       if n != len(b) {
+               panic(io.ErrShortWrite)
+       }
+}
diff --git a/src/pkg/archive/zip/writer_test.go b/src/pkg/archive/zip/writer_test.go
new file mode 100644 (file)
index 0000000..eb2a80c
--- /dev/null
@@ -0,0 +1,73 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zip
+
+import (
+       "bytes"
+       "io/ioutil"
+       "rand"
+       "testing"
+)
+
+// TODO(adg): a more sophisticated test suite
+
+const testString = "Rabbits, guinea pigs, gophers, marsupial rats, and quolls."
+
+func TestWriter(t *testing.T) {
+       largeData := make([]byte, 1<<17)
+       for i := range largeData {
+               largeData[i] = byte(rand.Int())
+       }
+
+       // write a zip file
+       buf := new(bytes.Buffer)
+       w := NewWriter(buf)
+       testCreate(t, w, "foo", []byte(testString), Store)
+       testCreate(t, w, "bar", largeData, Deflate)
+       if err := w.Close(); err != nil {
+               t.Fatal(err)
+       }
+
+       // read it back
+       r, err := NewReader(sliceReaderAt(buf.Bytes()), int64(buf.Len()))
+       if err != nil {
+               t.Fatal(err)
+       }
+       testReadFile(t, r.File[0], []byte(testString))
+       testReadFile(t, r.File[1], largeData)
+}
+
+func testCreate(t *testing.T, w *Writer, name string, data []byte, method uint16) {
+       header := &FileHeader{
+               Name:   name,
+               Method: method,
+       }
+       f, err := w.CreateHeader(header)
+       if err != nil {
+               t.Fatal(err)
+       }
+       _, err = f.Write(data)
+       if err != nil {
+               t.Fatal(err)
+       }
+}
+
+func testReadFile(t *testing.T, f *File, data []byte) {
+       rc, err := f.Open()
+       if err != nil {
+               t.Fatal("opening:", err)
+       }
+       b, err := ioutil.ReadAll(rc)
+       if err != nil {
+               t.Fatal("reading:", err)
+       }
+       err = rc.Close()
+       if err != nil {
+               t.Fatal("closing:", err)
+       }
+       if !bytes.Equal(b, data) {
+               t.Errorf("File contents %q, want %q", b, data)
+       }
+}