]> Cypherpunks repositories - gostls13.git/commitdiff
image/gif: tighten the checks for when the amount of an image's pixel
authorNigel Tao <nigeltao@golang.org>
Fri, 22 Mar 2013 03:42:02 +0000 (14:42 +1100)
committerNigel Tao <nigeltao@golang.org>
Fri, 22 Mar 2013 03:42:02 +0000 (14:42 +1100)
data does not agree with its bounds.

R=r, jeff.allen
CC=golang-dev
https://golang.org/cl/7938043

src/pkg/image/gif/reader.go
src/pkg/image/gif/reader_test.go [new file with mode: 0644]

index ed493eac2f95928c95f24225bf6c084265ac65f5..2e0fed5e59af3066df872807d7cb94f7f4767513 100644 (file)
@@ -17,6 +17,11 @@ import (
        "io"
 )
 
+var (
+       errNotEnough = errors.New("gif: not enough image data")
+       errTooMuch   = errors.New("gif: too much image data")
+)
+
 // If the io.Reader does not also have ReadByte, then decode will introduce its own buffering.
 type reader interface {
        io.Reader
@@ -89,29 +94,35 @@ type decoder struct {
 // comprises (n, (n bytes)) blocks, with 1 <= n <= 255.  It is the
 // reader given to the LZW decoder, which is thus immune to the
 // blocking.  After the LZW decoder completes, there will be a 0-byte
-// block remaining (0, ()), but under normal execution blockReader
-// doesn't consume it, so it is handled in decode.
+// block remaining (0, ()), which is consumed when checking that the
+// blockReader is exhausted.
 type blockReader struct {
        r     reader
        slice []byte
+       err   error
        tmp   [256]byte
 }
 
 func (b *blockReader) Read(p []byte) (int, error) {
+       if b.err != nil {
+               return 0, b.err
+       }
        if len(p) == 0 {
                return 0, nil
        }
        if len(b.slice) == 0 {
-               blockLen, err := b.r.ReadByte()
-               if err != nil {
-                       return 0, err
+               var blockLen uint8
+               blockLen, b.err = b.r.ReadByte()
+               if b.err != nil {
+                       return 0, b.err
                }
                if blockLen == 0 {
-                       return 0, io.EOF
+                       b.err = io.EOF
+                       return 0, b.err
                }
                b.slice = b.tmp[0:blockLen]
-               if _, err = io.ReadFull(b.r, b.slice); err != nil {
-                       return 0, err
+               if _, b.err = io.ReadFull(b.r, b.slice); b.err != nil {
+                       return 0, b.err
                }
        }
        n := copy(p, b.slice)
@@ -142,35 +153,33 @@ func (d *decoder) decode(r io.Reader, configOnly bool) error {
                }
        }
 
-Loop:
-       for err == nil {
-               var c byte
-               c, err = d.r.ReadByte()
-               if err == io.EOF {
-                       break
+       for {
+               c, err := d.r.ReadByte()
+               if err != nil {
+                       return err
                }
                switch c {
                case sExtension:
-                       err = d.readExtension()
+                       if err = d.readExtension(); err != nil {
+                               return err
+                       }
 
                case sImageDescriptor:
-                       var m *image.Paletted
-                       m, err = d.newImageFromDescriptor()
+                       m, err := d.newImageFromDescriptor()
                        if err != nil {
-                               break
+                               return err
                        }
                        if d.imageFields&fColorMapFollows != 0 {
                                m.Palette, err = d.readColorMap()
                                if err != nil {
-                                       break
+                                       return err
                                }
                                // TODO: do we set transparency in this map too? That would be
                                // d.setTransparency(m.Palette)
                        } else {
                                m.Palette = d.globalColorMap
                        }
-                       var litWidth uint8
-                       litWidth, err = d.r.ReadByte()
+                       litWidth, err := d.r.ReadByte()
                        if err != nil {
                                return err
                        }
@@ -178,18 +187,27 @@ Loop:
                                return fmt.Errorf("gif: pixel size in decode out of range: %d", litWidth)
                        }
                        // A wonderfully Go-like piece of magic.
-                       lzwr := lzw.NewReader(&blockReader{r: d.r}, lzw.LSB, int(litWidth))
+                       br := &blockReader{r: d.r}
+                       lzwr := lzw.NewReader(br, lzw.LSB, int(litWidth))
                        if _, err = io.ReadFull(lzwr, m.Pix); err != nil {
-                               break
+                               if err != io.ErrUnexpectedEOF {
+                                       return err
+                               }
+                               return errNotEnough
                        }
-
-                       // There should be a "0" block remaining; drain that.
-                       c, err = d.r.ReadByte()
-                       if err != nil {
-                               return err
+                       // Both lzwr and br should be exhausted. Reading from them
+                       // should yield (0, io.EOF).
+                       if n, err := lzwr.Read(d.tmp[:1]); n != 0 || err != io.EOF {
+                               if err != nil {
+                                       return err
+                               }
+                               return errTooMuch
                        }
-                       if c != 0 {
-                               return errors.New("gif: extra data after image")
+                       if n, err := br.Read(d.tmp[:1]); n != 0 || err != io.EOF {
+                               if err != nil {
+                                       return err
+                               }
+                               return errTooMuch
                        }
 
                        // Undo the interlacing if necessary.
@@ -202,19 +220,15 @@ Loop:
                        d.delayTime = 0 // TODO: is this correct, or should we hold on to the value?
 
                case sTrailer:
-                       break Loop
+                       if len(d.image) == 0 {
+                               return io.ErrUnexpectedEOF
+                       }
+                       return nil
 
                default:
-                       err = fmt.Errorf("gif: unknown block type: 0x%.2x", c)
+                       return fmt.Errorf("gif: unknown block type: 0x%.2x", c)
                }
        }
-       if err != nil {
-               return err
-       }
-       if len(d.image) == 0 {
-               return io.ErrUnexpectedEOF
-       }
-       return nil
 }
 
 func (d *decoder) readHeaderAndScreenDescriptor() error {
diff --git a/src/pkg/image/gif/reader_test.go b/src/pkg/image/gif/reader_test.go
new file mode 100644 (file)
index 0000000..77810f8
--- /dev/null
@@ -0,0 +1,86 @@
+package gif
+
+import (
+       "bytes"
+       "compress/lzw"
+       "image"
+       "image/color"
+       "reflect"
+       "testing"
+)
+
+func TestDecode(t *testing.T) {
+       // header and trailer are parts of a valid 2x1 GIF image.
+       const (
+               header = "GIF89a" +
+                       "\x02\x00\x01\x00" + // width=2, height=1
+                       "\x80\x00\x00" + // headerFields=(a color map of 2 pixels), backgroundIndex, aspect
+                       "\x10\x20\x30\x40\x50\x60" // the color map, also known as a palette
+               trailer = "\x3b"
+       )
+
+       // lzwEncode returns an LZW encoding (with 2-bit literals) of n zeroes.
+       lzwEncode := func(n int) []byte {
+               b := &bytes.Buffer{}
+               w := lzw.NewWriter(b, lzw.LSB, 2)
+               w.Write(make([]byte, n))
+               w.Close()
+               return b.Bytes()
+       }
+
+       testCases := []struct {
+               nPix    int  // The number of pixels in the image data.
+               extra   bool // Whether to write an extra block after the LZW-encoded data.
+               wantErr error
+       }{
+               {0, false, errNotEnough},
+               {1, false, errNotEnough},
+               {2, false, nil},
+               {2, true, errTooMuch},
+               {3, false, errTooMuch},
+       }
+       for _, tc := range testCases {
+               b := &bytes.Buffer{}
+               b.WriteString(header)
+               // Write an image with bounds 2x1 but tc.nPix pixels. If tc.nPix != 2
+               // then this should result in an invalid GIF image. First, write a
+               // magic 0x2c (image descriptor) byte, bounds=(0,0)-(2,1), a flags
+               // byte, and 2-bit LZW literals.
+               b.WriteString("\x2c\x00\x00\x00\x00\x02\x00\x01\x00\x00\x02")
+               if tc.nPix > 0 {
+                       enc := lzwEncode(tc.nPix)
+                       if len(enc) > 0xff {
+                               t.Errorf("nPix=%d, extra=%t: compressed length %d is too large", tc.nPix, tc.extra, len(enc))
+                               continue
+                       }
+                       b.WriteByte(byte(len(enc)))
+                       b.Write(enc)
+               }
+               if tc.extra {
+                       b.WriteString("\x01\x02") // A 1-byte payload with an 0x02 byte.
+               }
+               b.WriteByte(0x00) // An empty block signifies the end of the image data.
+               b.WriteString(trailer)
+
+               got, err := Decode(b)
+               if err != tc.wantErr {
+                       t.Errorf("nPix=%d, extra=%t\ngot  %v\nwant %v", tc.nPix, tc.extra, err, tc.wantErr)
+               }
+
+               if tc.wantErr != nil {
+                       continue
+               }
+               want := &image.Paletted{
+                       Pix:    []uint8{0, 0},
+                       Stride: 2,
+                       Rect:   image.Rect(0, 0, 2, 1),
+                       Palette: color.Palette{
+                               color.RGBA{0x10, 0x20, 0x30, 0xff},
+                               color.RGBA{0x40, 0x50, 0x60, 0xff},
+                       },
+               }
+               if !reflect.DeepEqual(got, want) {
+                       t.Errorf("nPix=%d, extra=%t\ngot  %v\nwant %v", tc.nPix, tc.extra, got, want)
+               }
+       }
+}