From 9c6a73e478e6e46859c68057144b8c3297e7a881 Mon Sep 17 00:00:00 2001 From: David Symonds Date: Wed, 9 Nov 2011 15:48:05 +1100 Subject: [PATCH] net/http: fix sniffing when using ReadFrom. R=golang-dev, rsc, bradfitz CC=golang-dev https://golang.org/cl/5362046 --- src/pkg/net/http/server.go | 10 ++++++---- src/pkg/net/http/sniff_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go index 8c4889436f..7221d2508b 100644 --- a/src/pkg/net/http/server.go +++ b/src/pkg/net/http/server.go @@ -149,11 +149,13 @@ type writerOnly struct { } func (w *response) ReadFrom(src io.Reader) (n int64, err error) { - // Flush before checking w.chunking, as Flush will call - // WriteHeader if it hasn't been called yet, and WriteHeader - // is what sets w.chunking. - w.Flush() + // Call WriteHeader before checking w.chunking if it hasn't + // been called yet, since WriteHeader is what sets w.chunking. + if !w.wroteHeader { + w.WriteHeader(StatusOK) + } if !w.chunking && w.bodyAllowed() && !w.needSniff { + w.Flush() if rf, ok := w.conn.rwc.(io.ReaderFrom); ok { n, err = rf.ReadFrom(src) w.written += n diff --git a/src/pkg/net/http/sniff_test.go b/src/pkg/net/http/sniff_test.go index a414e6420d..56d589a150 100644 --- a/src/pkg/net/http/sniff_test.go +++ b/src/pkg/net/http/sniff_test.go @@ -6,6 +6,7 @@ package http_test import ( "bytes" + "io" "io/ioutil" "log" . "net/http" @@ -79,3 +80,35 @@ func TestServerContentType(t *testing.T) { resp.Body.Close() } } + +func TestContentTypeWithCopy(t *testing.T) { + const ( + input = "\n\n\t\n" + expected = "text/html; charset=utf-8" + ) + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + // Use io.Copy from a bytes.Buffer to trigger ReadFrom. + buf := bytes.NewBuffer([]byte(input)) + n, err := io.Copy(w, buf) + if int(n) != len(input) || err != nil { + t.Fatalf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input)) + } + })) + defer ts.Close() + + resp, err := Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + if ct := resp.Header.Get("Content-Type"); ct != expected { + t.Errorf("Content-Type = %q, want %q", ct, expected) + } + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("reading body: %v", err) + } else if !bytes.Equal(data, []byte(input)) { + t.Errorf("data is %q, want %q", data, input) + } + resp.Body.Close() +} -- 2.50.0