]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: don't buffer request writing if dest is already buffered
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 9 Jan 2013 18:33:46 +0000 (10:33 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 9 Jan 2013 18:33:46 +0000 (10:33 -0800)
The old code made it impossible to implement a reverse proxy
with anything less than 4k write granularity to the backends.

R=golang-dev, adg
CC=golang-dev
https://golang.org/cl/7060059

src/pkg/net/http/request.go
src/pkg/net/http/request_test.go

index 3b799108ac00fcaab115eabb7c107a049603e628..217f35b48331ee8bf06c3885ad58ec74e37bf550 100644 (file)
@@ -331,11 +331,20 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
        }
        // TODO(bradfitz): escape at least newlines in ruri?
 
-       bw := bufio.NewWriter(w)
-       fmt.Fprintf(bw, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
+       // Wrap the writer in a bufio Writer if it's not already buffered.
+       // Don't always call NewWriter, as that forces a bytes.Buffer
+       // and other small bufio Writers to have a minimum 4k buffer
+       // size.
+       var bw *bufio.Writer
+       if _, ok := w.(io.ByteWriter); !ok {
+               bw = bufio.NewWriter(w)
+               w = bw
+       }
+
+       fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
 
        // Header lines
-       fmt.Fprintf(bw, "Host: %s\r\n", host)
+       fmt.Fprintf(w, "Host: %s\r\n", host)
 
        // Use the defaultUserAgent unless the Header contains one, which
        // may be blank to not send the header.
@@ -346,7 +355,7 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
                }
        }
        if userAgent != "" {
-               fmt.Fprintf(bw, "User-Agent: %s\r\n", userAgent)
+               fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
        }
 
        // Process Body,ContentLength,Close,Trailer
@@ -354,33 +363,36 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
        if err != nil {
                return err
        }
-       err = tw.WriteHeader(bw)
+       err = tw.WriteHeader(w)
        if err != nil {
                return err
        }
 
        // TODO: split long values?  (If so, should share code with Conn.Write)
-       err = req.Header.WriteSubset(bw, reqWriteExcludeHeader)
+       err = req.Header.WriteSubset(w, reqWriteExcludeHeader)
        if err != nil {
                return err
        }
 
        if extraHeaders != nil {
-               err = extraHeaders.Write(bw)
+               err = extraHeaders.Write(w)
                if err != nil {
                        return err
                }
        }
 
-       io.WriteString(bw, "\r\n")
+       io.WriteString(w, "\r\n")
 
        // Write body and trailer
-       err = tw.WriteBody(bw)
+       err = tw.WriteBody(w)
        if err != nil {
                return err
        }
 
-       return bw.Flush()
+       if bw != nil {
+               return bw.Flush()
+       }
+       return nil
 }
 
 // ParseHTTPVersion parses a HTTP version string.
index fc485fcdf864e101db698b67b6bf4ca069e11ca5..bd757920b7808fc96b24923afe18d7730ad819cb 100644 (file)
@@ -267,6 +267,36 @@ func TestNewRequestContentLength(t *testing.T) {
        }
 }
 
+type logWrites struct {
+       t   *testing.T
+       dst *[]string
+}
+
+func (l logWrites) WriteByte(c byte) error {
+       l.t.Fatalf("unexpected WriteByte call")
+       return nil
+}
+
+func (l logWrites) Write(p []byte) (n int, err error) {
+       *l.dst = append(*l.dst, string(p))
+       return len(p), nil
+}
+
+func TestRequestWriteBufferedWriter(t *testing.T) {
+       got := []string{}
+       req, _ := NewRequest("GET", "http://foo.com/", nil)
+       req.Write(logWrites{t, &got})
+       want := []string{
+               "GET / HTTP/1.1\r\n",
+               "Host: foo.com\r\n",
+               "User-Agent: Go http package\r\n",
+               "\r\n",
+       }
+       if !reflect.DeepEqual(got, want) {
+               t.Errorf("Writes = %q\n  Want = %q", got, want)
+       }
+}
+
 func testMissingFile(t *testing.T, req *Request) {
        f, fh, err := req.FormFile("missing")
        if f != nil {