]> Cypherpunks repositories - gostls13.git/commitdiff
net/textproto: always copy the data from bufio to avoid corruption
authorAndrew Gerrand <adg@golang.org>
Thu, 12 Jan 2012 03:15:58 +0000 (14:15 +1100)
committerAndrew Gerrand <adg@golang.org>
Thu, 12 Jan 2012 03:15:58 +0000 (14:15 +1100)
Fixes #2621.

R=rsc, rsc
CC=golang-dev
https://golang.org/cl/5498104

src/pkg/net/textproto/reader.go
src/pkg/net/textproto/reader_test.go

index 793c6c2c83e5fb9b8ef1ec3bb7a6b72b533711bf..862cd536c467e237ec6f56d56195ea5fa255444c 100644 (file)
@@ -22,6 +22,7 @@ import (
 type Reader struct {
        R   *bufio.Reader
        dot *dotReader
+       buf []byte // a re-usable buffer for readContinuedLineSlice
 }
 
 // NewReader returns a new Reader reading from r.
@@ -121,74 +122,44 @@ func (r *Reader) readContinuedLineSlice() ([]byte, error) {
        // Read the first line.
        line, err := r.readLineSlice()
        if err != nil {
-               return line, err
+               return nil, err
        }
        if len(line) == 0 { // blank line - no continuation
                return line, nil
        }
-       line = trim(line)
 
-       copied := false
-       if r.R.Buffered() < 1 {
-               // ReadByte will flush the buffer; make a copy of the slice.
-               copied = true
-               line = append([]byte(nil), line...)
-       }
-
-       // Look for a continuation line.
-       c, err := r.R.ReadByte()
-       if err != nil {
-               // Delay err until we read the byte next time.
-               return line, nil
-       }
-       if c != ' ' && c != '\t' {
-               // Not a continuation.
-               r.R.UnreadByte()
-               return line, nil
-       }
-
-       if !copied {
-               // The next readLineSlice will invalidate the previous one.
-               line = append(make([]byte, 0, len(line)*2), line...)
-       }
+       // ReadByte or the next readLineSlice will flush the read buffer;
+       // copy the slice into buf.
+       r.buf = append(r.buf[:0], trim(line)...)
 
        // Read continuation lines.
-       for {
-               // Consume leading spaces; one already gone.
-               for {
-                       c, err = r.R.ReadByte()
-                       if err != nil {
-                               break
-                       }
-                       if c != ' ' && c != '\t' {
-                               r.R.UnreadByte()
-                               break
-                       }
-               }
-               var cont []byte
-               cont, err = r.readLineSlice()
-               cont = trim(cont)
-               line = append(line, ' ')
-               line = append(line, cont...)
+       for r.skipSpace() > 0 {
+               line, err := r.readLineSlice()
                if err != nil {
                        break
                }
+               r.buf = append(r.buf, ' ')
+               r.buf = append(r.buf, line...)
+       }
+       return r.buf, nil
+}
 
-               // Check for leading space on next line.
-               if c, err = r.R.ReadByte(); err != nil {
+// skipSpace skips R over all spaces and returns the number of bytes skipped.
+func (r *Reader) skipSpace() int {
+       n := 0
+       for {
+               c, err := r.R.ReadByte()
+               if err != nil {
+                       // Bufio will keep err until next read.
                        break
                }
                if c != ' ' && c != '\t' {
                        r.R.UnreadByte()
                        break
                }
+               n++
        }
-
-       // Delay error until next call.
-       if len(line) > 0 {
-               err = nil
-       }
-       return line, err
+       return n
 }
 
 func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
index 0460c1c8deeb5557c07524ebed9a39f1ee9e7dd3..4d036914801f9b8fb30fe51fa16be01fead1dbaa 100644 (file)
@@ -138,6 +138,15 @@ func TestReadMIMEHeader(t *testing.T) {
        }
 }
 
+func TestReadMIMEHeaderSingle(t *testing.T) {
+       r := reader("Foo: bar\n\n")
+       m, err := r.ReadMIMEHeader()
+       want := MIMEHeader{"Foo": {"bar"}}
+       if !reflect.DeepEqual(m, want) || err != nil {
+               t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want)
+       }
+}
+
 func TestLargeReadMIMEHeader(t *testing.T) {
        data := make([]byte, 16*1024)
        for i := 0; i < len(data); i++ {