]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httptest: detect Content-Type in ResponseRecorder
authorNodir Turakulov <nodir@google.com>
Mon, 19 Oct 2015 21:36:25 +0000 (14:36 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 20 Oct 2015 01:01:22 +0000 (01:01 +0000)
* detect Content-Type on ReponseRecorder.Write[String] call
  if header wasn't written yet, Content-Type header is not set and
  Transfer-Encoding is not set.
* fix typos in serve_test.go

Updates #12986

Change-Id: Id2ed8b1994e64657370fed71eb3882d611f76b31
Reviewed-on: https://go-review.googlesource.com/16096
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/http/httptest/recorder.go
src/net/http/httptest/recorder_test.go
src/net/http/serve_test.go

index 30c5140daee2e737c6061d906eea6b57991663d9..c813cf5021ae8198c057fa5f06e5840a28414ab5 100644 (file)
@@ -44,11 +44,36 @@ func (rw *ResponseRecorder) Header() http.Header {
        return m
 }
 
+// writeHeader writes a header if it was not written yet and
+// detects Content-Type if needed.
+//
+// bytes or str are the beginning of the response body.
+// We pass both to avoid unnecessarily generate garbage
+// in rw.WriteString which was created for performance reasons.
+// Non-nil bytes win.
+func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
+       if rw.wroteHeader {
+               return
+       }
+       if len(str) > 512 {
+               str = str[:512]
+       }
+
+       _, hasType := rw.HeaderMap["Content-Type"]
+       hasTE := rw.HeaderMap.Get("Transfer-Encoding") != ""
+       if !hasType && !hasTE {
+               if b == nil {
+                       b = []byte(str)
+               }
+               rw.HeaderMap.Set("Content-Type", http.DetectContentType(b))
+       }
+
+       rw.WriteHeader(200)
+}
+
 // Write always succeeds and writes to rw.Body, if not nil.
 func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
-       if !rw.wroteHeader {
-               rw.WriteHeader(200)
-       }
+       rw.writeHeader(buf, "")
        if rw.Body != nil {
                rw.Body.Write(buf)
        }
@@ -57,9 +82,7 @@ func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
 
 // WriteString always succeeds and writes to rw.Body, if not nil.
 func (rw *ResponseRecorder) WriteString(str string) (int, error) {
-       if !rw.wroteHeader {
-               rw.WriteHeader(200)
-       }
+       rw.writeHeader(nil, str)
        if rw.Body != nil {
                rw.Body.WriteString(str)
        }
@@ -70,8 +93,8 @@ func (rw *ResponseRecorder) WriteString(str string) (int, error) {
 func (rw *ResponseRecorder) WriteHeader(code int) {
        if !rw.wroteHeader {
                rw.Code = code
+               rw.wroteHeader = true
        }
-       rw.wroteHeader = true
 }
 
 // Flush sets rw.Flushed to true.
index bc486e6b63957b6f0a7f69ea8e502edc316a97c8..a5a1725fa98fff5aa1998b6629a9a1b48311c3a1 100644 (file)
@@ -39,6 +39,14 @@ func TestRecorder(t *testing.T) {
                        return nil
                }
        }
+       hasHeader := func(key, want string) checkFunc {
+               return func(rec *ResponseRecorder) error {
+                       if got := rec.HeaderMap.Get(key); got != want {
+                               return fmt.Errorf("header %s = %q; want %q", key, got, want)
+                       }
+                       return nil
+               }
+       }
 
        tests := []struct {
                name   string
@@ -73,7 +81,12 @@ func TestRecorder(t *testing.T) {
                        func(w http.ResponseWriter, r *http.Request) {
                                io.WriteString(w, "hi first")
                        },
-                       check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
+                       check(
+                               hasStatus(200),
+                               hasContents("hi first"),
+                               hasFlush(false),
+                               hasHeader("Content-Type", "text/plain; charset=utf-8"),
+                       ),
                },
                {
                        "flush",
@@ -83,6 +96,29 @@ func TestRecorder(t *testing.T) {
                        },
                        check(hasStatus(200), hasFlush(true)),
                },
+               {
+                       "Content-Type detection",
+                       func(w http.ResponseWriter, r *http.Request) {
+                               io.WriteString(w, "<html>")
+                       },
+                       check(hasHeader("Content-Type", "text/html; charset=utf-8")),
+               },
+               {
+                       "no Content-Type detection with Transfer-Encoding",
+                       func(w http.ResponseWriter, r *http.Request) {
+                               w.Header().Set("Transfer-Encoding", "some encoding")
+                               io.WriteString(w, "<html>")
+                       },
+                       check(hasHeader("Content-Type", "")), // no header
+               },
+               {
+                       "no Content-Type detection if set explicitly",
+                       func(w http.ResponseWriter, r *http.Request) {
+                               w.Header().Set("Content-Type", "some/type")
+                               io.WriteString(w, "<html>")
+                       },
+                       check(hasHeader("Content-Type", "some/type")),
+               },
        }
        r, _ := http.NewRequest("GET", "http://foo.com/", nil)
        for _, tt := range tests {
index 7a008274e72aba047fd9a7d0c9b17763e46a906b..f9c2accc98aefd2996609f9a8a5ac95750d37664 100644 (file)
@@ -2487,7 +2487,7 @@ func TestHeaderToWire(t *testing.T) {
                                if !strings.Contains(got, "404") {
                                        return errors.New("wrong status")
                                }
-                               if strings.Contains(got, "Some-Header") {
+                               if strings.Contains(got, "Too-Late") {
                                        return errors.New("shouldn't have seen Too-Late")
                                }
                                return nil