]> Cypherpunks repositories - gostls13.git/commitdiff
[release-branch.r57] mime/multipart: fix regression from previous ReadSlice change
authorRuss Cox <rsc@golang.org>
Tue, 3 May 2011 05:15:56 +0000 (01:15 -0400)
committerRuss Cox <rsc@golang.org>
Tue, 3 May 2011 05:15:56 +0000 (01:15 -0400)
««« CL 4432083 / 698b5ea9c782
mime/multipart: fix regression from previous ReadSlice change

The previous change to make multipart use ReadSlice out of
paranoia broke multipart to not deal with large lines in
the bodies.

We should only be paranoid about long lines in the header
sections.

Fixes http://code.google.com/p/camlistore/issues/detail?id=4

R=adg
CC=golang-dev
https://golang.org/cl/4432083
»»»

R=adg
CC=golang-dev
https://golang.org/cl/4452062

src/pkg/mime/multipart/multipart.go
src/pkg/mime/multipart/multipart_test.go

index e0b747c3fb7a74b74ee7ec484ef1073249142b1d..60329fe17b597c0d46f5ffc5af61bf9de8cbcc2d 100644 (file)
@@ -15,13 +15,13 @@ package multipart
 import (
        "bufio"
        "bytes"
+       "fmt"
        "io"
        "io/ioutil"
        "mime"
        "net/textproto"
        "os"
        "regexp"
-       "strings"
 )
 
 var headerRegexp *regexp.Regexp = regexp.MustCompile("^([a-zA-Z0-9\\-]+): *([^\r\n]+)")
@@ -79,25 +79,28 @@ func (p *Part) FormName() string {
 // NewReader creates a new multipart Reader reading from r using the
 // given MIME boundary.
 func NewReader(reader io.Reader, boundary string) Reader {
+       b := []byte("\r\n--" + boundary + "--")
        return &multiReader{
-               boundary:     boundary,
-               dashBoundary: "--" + boundary,
-               endLine:      "--" + boundary + "--",
-               bufReader:    bufio.NewReader(reader),
+               bufReader: bufio.NewReader(reader),
+
+               nlDashBoundary:   b[:len(b)-2],
+               dashBoundaryDash: b[2:],
+               dashBoundary:     b[2 : len(b)-2],
        }
 }
 
 // Implementation ....
 
-func newPart(mr *multiReader) (bp *Part, err os.Error) {
-       bp = new(Part)
-       bp.Header = make(map[string][]string)
-       bp.mr = mr
-       bp.buffer = new(bytes.Buffer)
-       if err = bp.populateHeaders(); err != nil {
-               bp = nil
+func newPart(mr *multiReader) (*Part, os.Error) {
+       bp := &Part{
+               Header: make(map[string][]string),
+               mr:     mr,
+               buffer: new(bytes.Buffer),
        }
-       return
+       if err := bp.populateHeaders(); err != nil {
+               return nil, err
+       }
+       return bp, nil
 }
 
 func (bp *Part) populateHeaders() os.Error {
@@ -122,44 +125,49 @@ func (bp *Part) populateHeaders() os.Error {
 // Read reads the body of a part, after its headers and before the
 // next part (if any) begins.
 func (bp *Part) Read(p []byte) (n int, err os.Error) {
-       for {
-               if bp.buffer.Len() >= len(p) {
-                       // Internal buffer of unconsumed data is large enough for
-                       // the read request.  No need to parse more at the moment.
-                       break
-               }
-               if !bp.mr.ensureBufferedLine() {
-                       return 0, io.ErrUnexpectedEOF
-               }
-               if bp.mr.bufferedLineIsBoundary() {
-                       // Don't consume this line
-                       break
-               }
+       if bp.buffer.Len() >= len(p) {
+               // Internal buffer of unconsumed data is large enough for
+               // the read request.  No need to parse more at the moment.
+               return bp.buffer.Read(p)
+       }
+       peek, err := bp.mr.bufReader.Peek(4096) // TODO(bradfitz): add buffer size accessor
+       unexpectedEof := err == os.EOF
+       if err != nil && !unexpectedEof {
+               return 0, fmt.Errorf("multipart: Part Read: %v", err)
+       }
+       if peek == nil {
+               panic("nil peek buf")
+       }
 
-               // Write all of this line, except the final CRLF
-               s := *bp.mr.bufferedLine
-               if strings.HasSuffix(s, "\r\n") {
-                       bp.mr.consumeLine()
-                       if !bp.mr.ensureBufferedLine() {
-                               return 0, io.ErrUnexpectedEOF
-                       }
-                       if bp.mr.bufferedLineIsBoundary() {
-                               // The final \r\n isn't ours.  It logically belongs
-                               // to the boundary line which follows.
-                               bp.buffer.WriteString(s[0 : len(s)-2])
-                       } else {
-                               bp.buffer.WriteString(s)
-                       }
-                       break
-               }
-               if strings.HasSuffix(s, "\n") {
-                       bp.buffer.WriteString(s)
-                       bp.mr.consumeLine()
-                       continue
+       // Search the peek buffer for "\r\n--boundary". If found,
+       // consume everything up to the boundary. If not, consume only
+       // as much of the peek buffer as cannot hold the boundary
+       // string.
+       nCopy := 0
+       foundBoundary := false
+       if idx := bytes.Index(peek, bp.mr.nlDashBoundary); idx != -1 {
+               nCopy = idx
+               foundBoundary = true
+       } else if safeCount := len(peek) - len(bp.mr.nlDashBoundary); safeCount > 0 {
+               nCopy = safeCount
+       } else if unexpectedEof {
+               // If we've run out of peek buffer and the boundary
+               // wasn't found (and can't possibly fit), we must have
+               // hit the end of the file unexpectedly.
+               return 0, io.ErrUnexpectedEOF
+       }
+       if nCopy > 0 {
+               if _, err := io.Copyn(bp.buffer, bp.mr.bufReader, int64(nCopy)); err != nil {
+                       return 0, err
                }
-               return 0, os.NewError("multipart parse error during Read; unexpected line: " + s)
        }
-       return bp.buffer.Read(p)
+       n, err = bp.buffer.Read(p)
+       if err == os.EOF && !foundBoundary {
+               // If the boundary hasn't been reached there's more to
+               // read, so don't pass through an EOF from the buffer
+               err = nil
+       }
+       return
 }
 
 func (bp *Part) Close() os.Error {
@@ -168,46 +176,12 @@ func (bp *Part) Close() os.Error {
 }
 
 type multiReader struct {
-       boundary     string
-       dashBoundary string // --boundary
-       endLine      string // --boundary--
+       bufReader *bufio.Reader
 
-       bufferedLine *string
-
-       bufReader   *bufio.Reader
        currentPart *Part
        partsRead   int
-}
 
-func (mr *multiReader) eof() bool {
-       return mr.bufferedLine == nil &&
-               !mr.readLine()
-}
-
-func (mr *multiReader) readLine() bool {
-       lineBytes, err := mr.bufReader.ReadSlice('\n')
-       if err != nil {
-               // TODO: care about err being EOF or not?
-               return false
-       }
-       line := string(lineBytes)
-       mr.bufferedLine = &line
-       return true
-}
-
-func (mr *multiReader) bufferedLineIsBoundary() bool {
-       return strings.HasPrefix(*mr.bufferedLine, mr.dashBoundary)
-}
-
-func (mr *multiReader) ensureBufferedLine() bool {
-       if mr.bufferedLine == nil {
-               return mr.readLine()
-       }
-       return true
-}
-
-func (mr *multiReader) consumeLine() {
-       mr.bufferedLine = nil
+       nlDashBoundary, dashBoundaryDash, dashBoundary []byte
 }
 
 func (mr *multiReader) NextPart() (*Part, os.Error) {
@@ -215,13 +189,14 @@ func (mr *multiReader) NextPart() (*Part, os.Error) {
                mr.currentPart.Close()
        }
 
+       expectNewPart := false
        for {
-               if mr.eof() {
-                       return nil, io.ErrUnexpectedEOF
+               line, err := mr.bufReader.ReadSlice('\n')
+               if err != nil {
+                       return nil, fmt.Errorf("multipart: NextPart: %v", err)
                }
 
-               if isBoundaryDelimiterLine(*mr.bufferedLine, mr.dashBoundary) {
-                       mr.consumeLine()
+               if mr.isBoundaryDelimiterLine(line) {
                        mr.partsRead++
                        bp, err := newPart(mr)
                        if err != nil {
@@ -231,55 +206,67 @@ func (mr *multiReader) NextPart() (*Part, os.Error) {
                        return bp, nil
                }
 
-               if hasPrefixThenNewline(*mr.bufferedLine, mr.endLine) {
-                       mr.consumeLine()
+               if hasPrefixThenNewline(line, mr.dashBoundaryDash) {
                        // Expected EOF (no error)
+                       // TODO(bradfitz): should return an os.EOF error here, not using nil for errors
                        return nil, nil
                }
 
+               if expectNewPart {
+                       return nil, fmt.Errorf("multipart: expecting a new Part; got line %q", string(line))
+               }
+
                if mr.partsRead == 0 {
                        // skip line
-                       mr.consumeLine()
                        continue
                }
 
-               return nil, os.NewError("Unexpected line in Next().")
+               if bytes.Equal(line, []byte("\r\n")) {
+                       // Consume the "\r\n" separator between the
+                       // body of the previous part and the boundary
+                       // line we now expect will follow. (either a
+                       // new part or the end boundary)
+                       expectNewPart = true
+                       continue
+               }
+
+               return nil, fmt.Errorf("multipart: unexpected line in Next(): %q", line)
        }
        panic("unreachable")
 }
 
-func isBoundaryDelimiterLine(line, dashPrefix string) bool {
+func (mr *multiReader) isBoundaryDelimiterLine(line []byte) bool {
        // http://tools.ietf.org/html/rfc2046#section-5.1
        //   The boundary delimiter line is then defined as a line
        //   consisting entirely of two hyphen characters ("-",
        //   decimal value 45) followed by the boundary parameter
        //   value from the Content-Type header field, optional linear
        //   whitespace, and a terminating CRLF.
-       if !strings.HasPrefix(line, dashPrefix) {
+       if !bytes.HasPrefix(line, mr.dashBoundary) {
                return false
        }
-       if strings.HasSuffix(line, "\r\n") {
-               return onlyHorizontalWhitespace(line[len(dashPrefix) : len(line)-2])
+       if bytes.HasSuffix(line, []byte("\r\n")) {
+               return onlyHorizontalWhitespace(line[len(mr.dashBoundary) : len(line)-2])
        }
        // Violate the spec and also support newlines without the
        // carriage return...
-       if strings.HasSuffix(line, "\n") {
-               return onlyHorizontalWhitespace(line[len(dashPrefix) : len(line)-1])
+       if bytes.HasSuffix(line, []byte("\n")) {
+               return onlyHorizontalWhitespace(line[len(mr.dashBoundary) : len(line)-1])
        }
        return false
 }
 
-func onlyHorizontalWhitespace(s string) bool {
-       for i := 0; i < len(s); i++ {
-               if s[i] != ' ' && s[i] != '\t' {
+func onlyHorizontalWhitespace(s []byte) bool {
+       for _, b := range s {
+               if b != ' ' && b != '\t' {
                        return false
                }
        }
        return true
 }
 
-func hasPrefixThenNewline(s, prefix string) bool {
-       return strings.HasPrefix(s, prefix) &&
-               (len(s) == len(prefix)+1 && strings.HasSuffix(s, "\n") ||
-                       len(s) == len(prefix)+2 && strings.HasSuffix(s, "\r\n"))
+func hasPrefixThenNewline(s, prefix []byte) bool {
+       return bytes.HasPrefix(s, prefix) &&
+               (len(s) == len(prefix)+1 && s[len(s)-1] == '\n' ||
+                       len(s) == len(prefix)+2 && bytes.HasSuffix(s, []byte("\r\n")))
 }
index f8f10f3e162b2448df0a2db1d2082d97978836eb..16249146c9d0d09fc70b2da15cab4b5fdfa55d2f 100644 (file)
@@ -8,38 +8,37 @@ import (
        "bytes"
        "fmt"
        "io"
+       "io/ioutil"
        "json"
        "os"
-       "regexp"
        "strings"
        "testing"
 )
 
 func TestHorizontalWhitespace(t *testing.T) {
-       if !onlyHorizontalWhitespace(" \t") {
+       if !onlyHorizontalWhitespace([]byte(" \t")) {
                t.Error("expected pass")
        }
-       if onlyHorizontalWhitespace("foo bar") {
+       if onlyHorizontalWhitespace([]byte("foo bar")) {
                t.Error("expected failure")
        }
 }
 
 func TestBoundaryLine(t *testing.T) {
-       boundary := "myBoundary"
-       prefix := "--" + boundary
-       if !isBoundaryDelimiterLine("--myBoundary\r\n", prefix) {
+       mr := NewReader(strings.NewReader(""), "myBoundary").(*multiReader)
+       if !mr.isBoundaryDelimiterLine([]byte("--myBoundary\r\n")) {
                t.Error("expected")
        }
-       if !isBoundaryDelimiterLine("--myBoundary \r\n", prefix) {
+       if !mr.isBoundaryDelimiterLine([]byte("--myBoundary \r\n")) {
                t.Error("expected")
        }
-       if !isBoundaryDelimiterLine("--myBoundary \n", prefix) {
+       if !mr.isBoundaryDelimiterLine([]byte("--myBoundary \n")) {
                t.Error("expected")
        }
-       if isBoundaryDelimiterLine("--myBoundary bogus \n", prefix) {
+       if mr.isBoundaryDelimiterLine([]byte("--myBoundary bogus \n")) {
                t.Error("expected fail")
        }
-       if isBoundaryDelimiterLine("--myBoundary bogus--", prefix) {
+       if mr.isBoundaryDelimiterLine([]byte("--myBoundary bogus--")) {
                t.Error("expected fail")
        }
 }
@@ -79,7 +78,9 @@ func TestFormName(t *testing.T) {
        }
 }
 
-func TestMultipart(t *testing.T) {
+var longLine = strings.Repeat("\n\n\r\r\r\n\r\000", (1<<20)/8)
+
+func testMultipartBody() string {
        testBody := `
 This is a multi-part message.  This line is ignored.
 --MyBoundary
@@ -90,6 +91,10 @@ foo-bar: baz
 My value
 The end.
 --MyBoundary
+name: bigsection
+
+[longline]
+--MyBoundary
 Header1: value1b
 HEADER2: value2b
 foo-bar: bazb
@@ -102,11 +107,26 @@ Line 3 ends in a newline, but just one.
 
 never read data
 --MyBoundary--
+
+
+useless trailer
 `
-       testBody = regexp.MustCompile("\n").ReplaceAllString(testBody, "\r\n")
-       bodyReader := strings.NewReader(testBody)
+       testBody = strings.Replace(testBody, "\n", "\r\n", -1)
+       return strings.Replace(testBody, "[longline]", longLine, 1)
+}
+
+func TestMultipart(t *testing.T) {
+       bodyReader := strings.NewReader(testMultipartBody())
+       testMultipart(t, bodyReader)
+}
+
+func TestMultipartSlowInput(t *testing.T) {
+       bodyReader := strings.NewReader(testMultipartBody())
+       testMultipart(t, &slowReader{bodyReader})
+}
 
-       reader := NewReader(bodyReader, "MyBoundary")
+func testMultipart(t *testing.T, r io.Reader) {
+       reader := NewReader(r, "MyBoundary")
        buf := new(bytes.Buffer)
 
        // Part1
@@ -125,38 +145,64 @@ never read data
                t.Error("Expected Foo-Bar: baz")
        }
        buf.Reset()
-       io.Copy(buf, part)
+       if _, err := io.Copy(buf, part); err != nil {
+               t.Errorf("part 1 copy: %v", err)
+       }
        expectEq(t, "My value\r\nThe end.",
                buf.String(), "Value of first part")
 
        // Part2
        part, err = reader.NextPart()
+       if err != nil {
+               t.Fatalf("Expected part2; got: %v", err)
+               return
+       }
+       if e, g := "bigsection", part.Header.Get("name"); e != g {
+               t.Errorf("part2's name header: expected %q, got %q", e, g)
+       }
+       buf.Reset()
+       if _, err := io.Copy(buf, part); err != nil {
+               t.Errorf("part 2 copy: %v", err)
+       }
+       s := buf.String()
+       if len(s) != len(longLine) {
+               t.Errorf("part2 body expected long line of length %d; got length %d",
+                       len(longLine), len(s))
+       }
+       if s != longLine {
+               t.Errorf("part2 long body didn't match")
+       }
+
+       // Part3
+       part, err = reader.NextPart()
        if part == nil || err != nil {
-               t.Error("Expected part2")
+               t.Error("Expected part3")
                return
        }
        if part.Header.Get("foo-bar") != "bazb" {
                t.Error("Expected foo-bar: bazb")
        }
        buf.Reset()
-       io.Copy(buf, part)
+       if _, err := io.Copy(buf, part); err != nil {
+               t.Errorf("part 3 copy: %v", err)
+       }
        expectEq(t, "Line 1\r\nLine 2\r\nLine 3 ends in a newline, but just one.\r\n",
-               buf.String(), "Value of second part")
+               buf.String(), "body of part 3")
 
-       // Part3
+       // Part4
        part, err = reader.NextPart()
        if part == nil || err != nil {
-               t.Error("Expected part3 without errors")
+               t.Error("Expected part 4 without errors")
                return
        }
 
-       // Non-existent part4
+       // Non-existent part5
        part, err = reader.NextPart()
        if part != nil {
-               t.Error("Didn't expect a third part.")
+               t.Error("Didn't expect a fifth part.")
        }
        if err != nil {
-               t.Errorf("Unexpected error getting third part: %v", err)
+               t.Errorf("Unexpected error getting fifth part: %v", err)
        }
 }
 
@@ -237,3 +283,36 @@ func TestLineLimit(t *testing.T) {
                t.Errorf("expected to read < %d bytes; read %d", maxReadThreshold, mr.n)
        }
 }
+
+func TestMultipartTruncated(t *testing.T) {
+       testBody := `
+This is a multi-part message.  This line is ignored.
+--MyBoundary
+foo-bar: baz
+
+Oh no, premature EOF!
+`
+       body := strings.Replace(testBody, "\n", "\r\n", -1)
+       bodyReader := strings.NewReader(body)
+       r := NewReader(bodyReader, "MyBoundary")
+
+       part, err := r.NextPart()
+       if err != nil {
+               t.Fatalf("didn't get a part")
+       }
+       _, err = io.Copy(ioutil.Discard, part)
+       if err != io.ErrUnexpectedEOF {
+               t.Fatalf("expected error io.ErrUnexpectedEOF; got %v", err)
+       }
+}
+
+type slowReader struct {
+       r io.Reader
+}
+
+func (s *slowReader) Read(p []byte) (int, os.Error) {
+       if len(p) == 0 {
+               return s.r.Read(p)
+       }
+       return s.r.Read(p[:1])
+}