]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: fix sniffing when using ReadFrom.
authorDavid Symonds <dsymonds@golang.org>
Wed, 9 Nov 2011 04:48:05 +0000 (15:48 +1100)
committerDavid Symonds <dsymonds@golang.org>
Wed, 9 Nov 2011 04:48:05 +0000 (15:48 +1100)
R=golang-dev, rsc, bradfitz
CC=golang-dev
https://golang.org/cl/5362046

src/pkg/net/http/server.go
src/pkg/net/http/sniff_test.go

index 8c4889436f1c6006dd8d52f31bf00305719de9f6..7221d2508bb4e4856e8f61080160190fbf295f5e 100644 (file)
@@ -149,11 +149,13 @@ type writerOnly struct {
 }
 
 func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
-       // Flush before checking w.chunking, as Flush will call
-       // WriteHeader if it hasn't been called yet, and WriteHeader
-       // is what sets w.chunking.
-       w.Flush()
+       // Call WriteHeader before checking w.chunking if it hasn't
+       // been called yet, since WriteHeader is what sets w.chunking.
+       if !w.wroteHeader {
+               w.WriteHeader(StatusOK)
+       }
        if !w.chunking && w.bodyAllowed() && !w.needSniff {
+               w.Flush()
                if rf, ok := w.conn.rwc.(io.ReaderFrom); ok {
                        n, err = rf.ReadFrom(src)
                        w.written += n
index a414e6420db88c39058cfd6b2f59d1faeb6214b7..56d589a150161078115b88aab751c86c4ebed44a 100644 (file)
@@ -6,6 +6,7 @@ package http_test
 
 import (
        "bytes"
+       "io"
        "io/ioutil"
        "log"
        . "net/http"
@@ -79,3 +80,35 @@ func TestServerContentType(t *testing.T) {
                resp.Body.Close()
        }
 }
+
+func TestContentTypeWithCopy(t *testing.T) {
+       const (
+               input    = "\n<html>\n\t<head>\n"
+               expected = "text/html; charset=utf-8"
+       )
+
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               // Use io.Copy from a bytes.Buffer to trigger ReadFrom.
+               buf := bytes.NewBuffer([]byte(input))
+               n, err := io.Copy(w, buf)
+               if int(n) != len(input) || err != nil {
+                       t.Fatalf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input))
+               }
+       }))
+       defer ts.Close()
+
+       resp, err := Get(ts.URL)
+       if err != nil {
+               t.Fatalf("Get: %v", err)
+       }
+       if ct := resp.Header.Get("Content-Type"); ct != expected {
+               t.Errorf("Content-Type = %q, want %q", ct, expected)
+       }
+       data, err := ioutil.ReadAll(resp.Body)
+       if err != nil {
+               t.Errorf("reading body: %v", err)
+       } else if !bytes.Equal(data, []byte(input)) {
+               t.Errorf("data is %q, want %q", data, input)
+       }
+       resp.Body.Close()
+}