import (
        "bufio"
        "image"
+       "image/ycbcr"
        "io"
        "os"
 )
 
+// TODO(nigeltao): fix up the doc comment style so that sentences start with
+// the name of the type or function that they annotate.
+
 // A FormatError reports that the input is not a valid JPEG.
 type FormatError string
 
 
 // Component specification, specified in section B.2.2.
 type component struct {
+       h  int   // Horizontal sampling factor.
+       v  int   // Vertical sampling factor.
        c  uint8 // Component identifier.
-       h  uint8 // Horizontal sampling factor.
-       v  uint8 // Vertical sampling factor.
        tq uint8 // Quantization table destination selector.
 }
 
 type decoder struct {
        r             Reader
        width, height int
-       image         *image.RGBA
+       img           *ycbcr.YCbCr
        ri            int // Restart Interval.
        comps         [nComponent]component
        huff          [maxTc + 1][maxTh + 1]huffman
        }
        for i := 0; i < nComponent; i++ {
                hv := d.tmp[7+3*i]
+               d.comps[i].h = int(hv >> 4)
+               d.comps[i].v = int(hv & 0x0f)
                d.comps[i].c = d.tmp[6+3*i]
-               d.comps[i].h = hv >> 4
-               d.comps[i].v = hv & 0x0f
                d.comps[i].tq = d.tmp[8+3*i]
                // We only support YCbCr images, and 4:4:4, 4:2:2 or 4:2:0 chroma downsampling ratios. This implies that
                // the (h, v) values for the Y component are either (1, 1), (2, 1) or (2, 2), and the
        return nil
 }
 
-// Set the Pixel (px, py)'s RGB value, based on its YCbCr value.
-func (d *decoder) calcPixel(px, py, lumaBlock, lumaIndex, chromaIndex int) {
-       y, cb, cr := d.blocks[0][lumaBlock][lumaIndex], d.blocks[1][0][chromaIndex], d.blocks[2][0][chromaIndex]
-       // The JFIF specification (http://www.w3.org/Graphics/JPEG/jfif3.pdf, page 3) gives the formula
-       // for translating YCbCr to RGB as:
-       //   R = Y + 1.402 (Cr-128)
-       //   G = Y - 0.34414 (Cb-128) - 0.71414 (Cr-128)
-       //   B = Y + 1.772 (Cb-128)
-       yPlusHalf := 100000*y + 50000
-       cb -= 128
-       cr -= 128
-       r := (yPlusHalf + 140200*cr) / 100000
-       g := (yPlusHalf - 34414*cb - 71414*cr) / 100000
-       b := (yPlusHalf + 177200*cb) / 100000
-       if r < 0 {
-               r = 0
-       } else if r > 255 {
-               r = 255
-       }
-       if g < 0 {
-               g = 0
-       } else if g > 255 {
-               g = 255
+// Clip x to the range [0, 255] inclusive.
+func clip(x int) uint8 {
+       if x < 0 {
+               return 0
        }
-       if b < 0 {
-               b = 0
-       } else if b > 255 {
-               b = 255
+       if x > 255 {
+               return 255
        }
-       d.image.Pix[py*d.image.Stride+px] = image.RGBAColor{uint8(r), uint8(g), uint8(b), 0xff}
+       return uint8(x)
 }
 
-// Convert the MCU from YCbCr to RGB.
-func (d *decoder) convertMCU(mx, my, h0, v0 int) {
-       lumaBlock := 0
+// Store the MCU to the image.
+func (d *decoder) storeMCU(mx, my int) {
+       h0, v0 := d.comps[0].h, d.comps[0].v
+       // Store the luma blocks.
        for v := 0; v < v0; v++ {
                for h := 0; h < h0; h++ {
-                       chromaBase := 8*4*v + 4*h
-                       py := 8 * (v0*my + v)
-                       for y := 0; y < 8 && py < d.height; y++ {
-                               px := 8 * (h0*mx + h)
-                               lumaIndex := 8 * y
-                               chromaIndex := chromaBase + 8*(y/v0)
-                               for x := 0; x < 8 && px < d.width; x++ {
-                                       d.calcPixel(px, py, lumaBlock, lumaIndex, chromaIndex)
-                                       if h0 == 1 {
-                                               chromaIndex += 1
-                                       } else {
-                                               chromaIndex += x % 2
-                                       }
-                                       lumaIndex++
-                                       px++
+                       p := 8 * ((v0*my+v)*d.img.YStride + (h0*mx + h))
+                       for y := 0; y < 8; y++ {
+                               for x := 0; x < 8; x++ {
+                                       d.img.Y[p] = clip(d.blocks[0][h0*v+h][8*y+x])
+                                       p++
                                }
-                               py++
+                               p += d.img.YStride - 8
                        }
-                       lumaBlock++
                }
        }
+       // Store the chroma blocks.
+       p := 8 * (my*d.img.CStride + mx)
+       for y := 0; y < 8; y++ {
+               for x := 0; x < 8; x++ {
+                       d.img.Cb[p] = clip(d.blocks[1][0][8*y+x])
+                       d.img.Cr[p] = clip(d.blocks[2][0][8*y+x])
+                       p++
+               }
+               p += d.img.CStride - 8
+       }
 }
 
 // Specified in section B.2.3.
 func (d *decoder) processSOS(n int) os.Error {
-       if d.image == nil {
-               d.image = image.NewRGBA(d.width, d.height)
-       }
        if n != 4+2*nComponent {
                return UnsupportedError("SOS has wrong length")
        }
                td uint8 // DC table selector.
                ta uint8 // AC table selector.
        }
-       h0, v0 := int(d.comps[0].h), int(d.comps[0].v) // The h and v values from the Y components.
        for i := 0; i < nComponent; i++ {
                cs := d.tmp[1+2*i] // Component selector.
                if cs != d.comps[i].c {
                scanComps[i].ta = d.tmp[2+2*i] & 0x0f
        }
        // mxx and myy are the number of MCUs (Minimum Coded Units) in the image.
-       mxx := (d.width + 8*int(h0) - 1) / (8 * int(h0))
-       myy := (d.height + 8*int(v0) - 1) / (8 * int(v0))
+       h0, v0 := d.comps[0].h, d.comps[0].v // The h and v values from the Y components.
+       mxx := (d.width + 8*h0 - 1) / (8 * h0)
+       myy := (d.height + 8*v0 - 1) / (8 * v0)
+       if d.img == nil {
+               var subsampleRatio ycbcr.SubsampleRatio
+               n := h0 * v0
+               switch n {
+               case 1:
+                       subsampleRatio = ycbcr.SubsampleRatio444
+               case 2:
+                       subsampleRatio = ycbcr.SubsampleRatio422
+               case 4:
+                       subsampleRatio = ycbcr.SubsampleRatio420
+               default:
+                       panic("unreachable")
+               }
+               b := make([]byte, mxx*myy*(1*8*8*n+2*8*8))
+               d.img = &ycbcr.YCbCr{
+                       Y:              b[mxx*myy*(0*8*8*n+0*8*8) : mxx*myy*(1*8*8*n+0*8*8)],
+                       Cb:             b[mxx*myy*(1*8*8*n+0*8*8) : mxx*myy*(1*8*8*n+1*8*8)],
+                       Cr:             b[mxx*myy*(1*8*8*n+1*8*8) : mxx*myy*(1*8*8*n+2*8*8)],
+                       SubsampleRatio: subsampleRatio,
+                       YStride:        mxx * 8 * h0,
+                       CStride:        mxx * 8,
+                       Rect:           image.Rect(0, 0, d.width, d.height),
+               }
+       }
 
        mcu, expectedRST := 0, uint8(rst0Marker)
        var allZeroes block
                for mx := 0; mx < mxx; mx++ {
                        for i := 0; i < nComponent; i++ {
                                qt := &d.quant[d.comps[i].tq]
-                               for j := 0; j < int(d.comps[i].h*d.comps[i].v); j++ {
+                               for j := 0; j < d.comps[i].h*d.comps[i].v; j++ {
                                        d.blocks[i][j] = allZeroes
 
                                        // Decode the DC coefficient, as specified in section F.2.2.1.
                                                if err != nil {
                                                        return err
                                                }
-                                               v0 := value >> 4
-                                               v1 := value & 0x0f
-                                               if v1 != 0 {
-                                                       k += int(v0)
+                                               val0 := value >> 4
+                                               val1 := value & 0x0f
+                                               if val1 != 0 {
+                                                       k += int(val0)
                                                        if k > blockSize {
                                                                return FormatError("bad DCT index")
                                                        }
-                                                       ac, err := d.receiveExtend(v1)
+                                                       ac, err := d.receiveExtend(val1)
                                                        if err != nil {
                                                                return err
                                                        }
                                                        d.blocks[i][j][unzig[k]] = ac * qt[k]
                                                } else {
-                                                       if v0 != 0x0f {
+                                                       if val0 != 0x0f {
                                                                break
                                                        }
                                                        k += 0x0f
                                        idct(&d.blocks[i][j])
                                } // for j
                        } // for i
-                       d.convertMCU(mx, my, int(d.comps[0].h), int(d.comps[0].v))
+                       d.storeMCU(mx, my)
                        mcu++
                        if d.ri > 0 && mcu%d.ri == 0 && mcu < mxx*myy {
                                // A more sophisticated decoder could use RST[0-7] markers to resynchronize from corrupt input,
                        return nil, err
                }
        }
-       return d.image, nil
+       return d.img, nil
 }
 
 // Decode reads a JPEG image from r and returns it as an image.Image.