return m
}
+// writeHeader writes a header if it was not written yet and
+// detects Content-Type if needed.
+//
+// bytes or str are the beginning of the response body.
+// We pass both to avoid unnecessarily generate garbage
+// in rw.WriteString which was created for performance reasons.
+// Non-nil bytes win.
+func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
+ if rw.wroteHeader {
+ return
+ }
+ if len(str) > 512 {
+ str = str[:512]
+ }
+
+ _, hasType := rw.HeaderMap["Content-Type"]
+ hasTE := rw.HeaderMap.Get("Transfer-Encoding") != ""
+ if !hasType && !hasTE {
+ if b == nil {
+ b = []byte(str)
+ }
+ rw.HeaderMap.Set("Content-Type", http.DetectContentType(b))
+ }
+
+ rw.WriteHeader(200)
+}
+
// Write always succeeds and writes to rw.Body, if not nil.
func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
- if !rw.wroteHeader {
- rw.WriteHeader(200)
- }
+ rw.writeHeader(buf, "")
if rw.Body != nil {
rw.Body.Write(buf)
}
// WriteString always succeeds and writes to rw.Body, if not nil.
func (rw *ResponseRecorder) WriteString(str string) (int, error) {
- if !rw.wroteHeader {
- rw.WriteHeader(200)
- }
+ rw.writeHeader(nil, str)
if rw.Body != nil {
rw.Body.WriteString(str)
}
func (rw *ResponseRecorder) WriteHeader(code int) {
if !rw.wroteHeader {
rw.Code = code
+ rw.wroteHeader = true
}
- rw.wroteHeader = true
}
// Flush sets rw.Flushed to true.
return nil
}
}
+ hasHeader := func(key, want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if got := rec.HeaderMap.Get(key); got != want {
+ return fmt.Errorf("header %s = %q; want %q", key, got, want)
+ }
+ return nil
+ }
+ }
tests := []struct {
name string
func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, "hi first")
},
- check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
+ check(
+ hasStatus(200),
+ hasContents("hi first"),
+ hasFlush(false),
+ hasHeader("Content-Type", "text/plain; charset=utf-8"),
+ ),
},
{
"flush",
},
check(hasStatus(200), hasFlush(true)),
},
+ {
+ "Content-Type detection",
+ func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, "<html>")
+ },
+ check(hasHeader("Content-Type", "text/html; charset=utf-8")),
+ },
+ {
+ "no Content-Type detection with Transfer-Encoding",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Transfer-Encoding", "some encoding")
+ io.WriteString(w, "<html>")
+ },
+ check(hasHeader("Content-Type", "")), // no header
+ },
+ {
+ "no Content-Type detection if set explicitly",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "some/type")
+ io.WriteString(w, "<html>")
+ },
+ check(hasHeader("Content-Type", "some/type")),
+ },
}
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
for _, tt := range tests {