]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: remove more garbage from chunk reading
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 20 Nov 2012 03:50:42 +0000 (19:50 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 20 Nov 2012 03:50:42 +0000 (19:50 -0800)
Noticed this while closing tabs. Yesterday I thought I could
ignore this garbage and hope that a fix for issue 2205 handled
it, but I just realized that's the opposite case,
string->[]byte, whereas this is []byte->string.  I'm having a
hard time convincing myself that an Issue 2205-style fix with
static analysis and faking a string header would be safe in
all cases without violating the memory model (callee assumes
frozen memory; are there non-racy ways it could keep being
modified?)

R=dsymonds
CC=dave, gobot, golang-dev
https://golang.org/cl/6850067

src/pkg/net/http/chunked.go
src/pkg/net/http/chunked_test.go
src/pkg/net/http/httputil/chunked.go
src/pkg/net/http/httputil/chunked_test.go

index 7cf39cfa5fcee812534f5ca96c898ee9066e6707..91db017245656824884e4c4eb294913026d80094 100644 (file)
@@ -11,11 +11,9 @@ package http
 
 import (
        "bufio"
-       "bytes"
        "errors"
        "fmt"
        "io"
-       "strconv"
 )
 
 const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
@@ -45,12 +43,12 @@ type chunkedReader struct {
 
 func (cr *chunkedReader) beginChunk() {
        // chunk-size CRLF
-       var line string
+       var line []byte
        line, cr.err = readLine(cr.r)
        if cr.err != nil {
                return
        }
-       cr.n, cr.err = strconv.ParseUint(line, 16, 64)
+       cr.n, cr.err = parseHexUint(line)
        if cr.err != nil {
                return
        }
@@ -89,7 +87,7 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
 // Give up if the line exceeds maxLineLength.
 // The returned bytes are a pointer into storage in
 // the bufio, so they are only valid until the next bufio read.
-func readLineBytes(b *bufio.Reader) (p []byte, err error) {
+func readLine(b *bufio.Reader) (p []byte, err error) {
        if p, err = b.ReadSlice('\n'); err != nil {
                // We always know when EOF is coming.
                // If the caller asked for a line, there should be a line.
@@ -103,20 +101,18 @@ func readLineBytes(b *bufio.Reader) (p []byte, err error) {
        if len(p) >= maxLineLength {
                return nil, ErrLineTooLong
        }
-
-       // Chop off trailing white space.
-       p = bytes.TrimRight(p, " \r\t\n")
-
-       return p, nil
+       return trimTrailingWhitespace(p), nil
 }
 
-// readLineBytes, but convert the bytes into a string.
-func readLine(b *bufio.Reader) (s string, err error) {
-       p, e := readLineBytes(b)
-       if e != nil {
-               return "", e
+func trimTrailingWhitespace(b []byte) []byte {
+       for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
+               b = b[:len(b)-1]
        }
-       return string(p), nil
+       return b
+}
+
+func isASCIISpace(b byte) bool {
+       return b == ' ' || b == '\t' || b == '\n' || b == '\r'
 }
 
 // newChunkedWriter returns a new chunkedWriter that translates writes into HTTP
@@ -167,3 +163,21 @@ func (cw *chunkedWriter) Close() error {
        _, err := io.WriteString(cw.Wire, "0\r\n")
        return err
 }
+
+func parseHexUint(v []byte) (n uint64, err error) {
+       for _, b := range v {
+               n <<= 4
+               switch {
+               case '0' <= b && b <= '9':
+                       b = b - '0'
+               case 'a' <= b && b <= 'f':
+                       b = b - 'a' + 10
+               case 'A' <= b && b <= 'F':
+                       b = b - 'A' + 10
+               default:
+                       return 0, errors.New("invalid byte in chunk length")
+               }
+               n |= uint64(b)
+       }
+       return
+}
index b77ee2ff26ce492e83581ad19252b45462fd906c..ad88eb16735accbb3f5b5ac1e1120fb58596e28a 100644 (file)
@@ -9,7 +9,10 @@ package http
 
 import (
        "bytes"
+       "fmt"
+       "io"
        "io/ioutil"
+       "runtime"
        "testing"
 )
 
@@ -37,3 +40,52 @@ func TestChunk(t *testing.T) {
                t.Errorf("chunk reader read %q; want %q", g, e)
        }
 }
+
+func TestChunkReaderAllocs(t *testing.T) {
+       var buf bytes.Buffer
+       w := newChunkedWriter(&buf)
+       a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc")
+       w.Write(a)
+       w.Write(b)
+       w.Write(c)
+       w.Close()
+
+       r := newChunkedReader(&buf)
+       readBuf := make([]byte, len(a)+len(b)+len(c)+1)
+
+       var ms runtime.MemStats
+       runtime.ReadMemStats(&ms)
+       m0 := ms.Mallocs
+
+       n, err := io.ReadFull(r, readBuf)
+
+       runtime.ReadMemStats(&ms)
+       mallocs := ms.Mallocs - m0
+       if mallocs > 1 {
+               t.Errorf("%d mallocs; want <= 1", mallocs)
+       }
+
+       if n != len(readBuf)-1 {
+               t.Errorf("read %d bytes; want %d", n, len(readBuf)-1)
+       }
+       if err != io.ErrUnexpectedEOF {
+               t.Errorf("read error = %v; want ErrUnexpectedEOF", err)
+       }
+}
+
+func TestParseHexUint(t *testing.T) {
+       for i := uint64(0); i <= 1234; i++ {
+               line := []byte(fmt.Sprintf("%x", i))
+               got, err := parseHexUint(line)
+               if err != nil {
+                       t.Fatalf("on %d: %v", i, err)
+               }
+               if got != i {
+                       t.Errorf("for input %q = %d; want %d", line, got, i)
+               }
+       }
+       _, err := parseHexUint([]byte("bogus"))
+       if err == nil {
+               t.Error("expected error on bogus input")
+       }
+}
index 26daee5f2c74f80ebb822c24072fe3917cb24396..b66d4095153f91fe18343317ee92f20ead9ff849 100644 (file)
@@ -13,11 +13,9 @@ package httputil
 
 import (
        "bufio"
-       "bytes"
        "errors"
        "fmt"
        "io"
-       "strconv"
 )
 
 const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
@@ -47,12 +45,12 @@ type chunkedReader struct {
 
 func (cr *chunkedReader) beginChunk() {
        // chunk-size CRLF
-       var line string
+       var line []byte
        line, cr.err = readLine(cr.r)
        if cr.err != nil {
                return
        }
-       cr.n, cr.err = strconv.ParseUint(line, 16, 64)
+       cr.n, cr.err = parseHexUint(line)
        if cr.err != nil {
                return
        }
@@ -91,7 +89,7 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
 // Give up if the line exceeds maxLineLength.
 // The returned bytes are a pointer into storage in
 // the bufio, so they are only valid until the next bufio read.
-func readLineBytes(b *bufio.Reader) (p []byte, err error) {
+func readLine(b *bufio.Reader) (p []byte, err error) {
        if p, err = b.ReadSlice('\n'); err != nil {
                // We always know when EOF is coming.
                // If the caller asked for a line, there should be a line.
@@ -105,20 +103,18 @@ func readLineBytes(b *bufio.Reader) (p []byte, err error) {
        if len(p) >= maxLineLength {
                return nil, ErrLineTooLong
        }
-
-       // Chop off trailing white space.
-       p = bytes.TrimRight(p, " \r\t\n")
-
-       return p, nil
+       return trimTrailingWhitespace(p), nil
 }
 
-// readLineBytes, but convert the bytes into a string.
-func readLine(b *bufio.Reader) (s string, err error) {
-       p, e := readLineBytes(b)
-       if e != nil {
-               return "", e
+func trimTrailingWhitespace(b []byte) []byte {
+       for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
+               b = b[:len(b)-1]
        }
-       return string(p), nil
+       return b
+}
+
+func isASCIISpace(b byte) bool {
+       return b == ' ' || b == '\t' || b == '\n' || b == '\r'
 }
 
 // NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
@@ -169,3 +165,21 @@ func (cw *chunkedWriter) Close() error {
        _, err := io.WriteString(cw.Wire, "0\r\n")
        return err
 }
+
+func parseHexUint(v []byte) (n uint64, err error) {
+       for _, b := range v {
+               n <<= 4
+               switch {
+               case '0' <= b && b <= '9':
+                       b = b - '0'
+               case 'a' <= b && b <= 'f':
+                       b = b - 'a' + 10
+               case 'A' <= b && b <= 'F':
+                       b = b - 'A' + 10
+               default:
+                       return 0, errors.New("invalid byte in chunk length")
+               }
+               n |= uint64(b)
+       }
+       return
+}
index 155a32bdf9a2aca3a371e503699cd5f5165ac186..22c1bb75489d7ec2c6028ff622d668b8e7af8a12 100644 (file)
@@ -11,7 +11,10 @@ package httputil
 
 import (
        "bytes"
+       "fmt"
+       "io"
        "io/ioutil"
+       "runtime"
        "testing"
 )
 
@@ -39,3 +42,52 @@ func TestChunk(t *testing.T) {
                t.Errorf("chunk reader read %q; want %q", g, e)
        }
 }
+
+func TestChunkReaderAllocs(t *testing.T) {
+       var buf bytes.Buffer
+       w := NewChunkedWriter(&buf)
+       a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc")
+       w.Write(a)
+       w.Write(b)
+       w.Write(c)
+       w.Close()
+
+       r := NewChunkedReader(&buf)
+       readBuf := make([]byte, len(a)+len(b)+len(c)+1)
+
+       var ms runtime.MemStats
+       runtime.ReadMemStats(&ms)
+       m0 := ms.Mallocs
+
+       n, err := io.ReadFull(r, readBuf)
+
+       runtime.ReadMemStats(&ms)
+       mallocs := ms.Mallocs - m0
+       if mallocs > 1 {
+               t.Errorf("%d mallocs; want <= 1", mallocs)
+       }
+
+       if n != len(readBuf)-1 {
+               t.Errorf("read %d bytes; want %d", n, len(readBuf)-1)
+       }
+       if err != io.ErrUnexpectedEOF {
+               t.Errorf("read error = %v; want ErrUnexpectedEOF", err)
+       }
+}
+
+func TestParseHexUint(t *testing.T) {
+       for i := uint64(0); i <= 1234; i++ {
+               line := []byte(fmt.Sprintf("%x", i))
+               got, err := parseHexUint(line)
+               if err != nil {
+                       t.Fatalf("on %d: %v", i, err)
+               }
+               if got != i {
+                       t.Errorf("for input %q = %d; want %d", line, got, i)
+               }
+       }
+       _, err := parseHexUint([]byte("bogus"))
+       if err == nil {
+               t.Error("expected error on bogus input")
+       }
+}