Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
Flushed bool
+ stagingMap http.Header // map that handlers manipulate to set headers
+ trailerMap http.Header // lazily filled when Trailers() is called
+
wroteHeader bool
}
// Header returns the response headers.
func (rw *ResponseRecorder) Header() http.Header {
- m := rw.HeaderMap
+ m := rw.stagingMap
if m == nil {
m = make(http.Header)
- rw.HeaderMap = m
+ rw.stagingMap = m
}
return m
}
str = str[:512]
}
- _, hasType := rw.HeaderMap["Content-Type"]
- hasTE := rw.HeaderMap.Get("Transfer-Encoding") != ""
+ m := rw.Header()
+
+ _, hasType := m["Content-Type"]
+ hasTE := m.Get("Transfer-Encoding") != ""
if !hasType && !hasTE {
if b == nil {
b = []byte(str)
}
- if rw.HeaderMap == nil {
- rw.HeaderMap = make(http.Header)
- }
- rw.HeaderMap.Set("Content-Type", http.DetectContentType(b))
+ m.Set("Content-Type", http.DetectContentType(b))
}
rw.WriteHeader(200)
return len(str), nil
}
-// WriteHeader sets rw.Code.
+// WriteHeader sets rw.Code. After it is called, changing rw.Header
+// will not affect rw.HeaderMap.
func (rw *ResponseRecorder) WriteHeader(code int) {
- if !rw.wroteHeader {
- rw.Code = code
- rw.wroteHeader = true
+ if rw.wroteHeader {
+ return
+ }
+ rw.Code = code
+ rw.wroteHeader = true
+ if rw.HeaderMap == nil {
+ rw.HeaderMap = make(http.Header)
+ }
+ for k, vv := range rw.stagingMap {
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ rw.HeaderMap[k] = vv2
}
}
}
rw.Flushed = true
}
+
+// Trailers returns any trailers set by the handler. It must be called
+// after the handler finished running.
+func (rw *ResponseRecorder) Trailers() http.Header {
+ if rw.trailerMap != nil {
+ return rw.trailerMap
+ }
+ trailers, ok := rw.HeaderMap["Trailer"]
+ if !ok {
+ rw.trailerMap = make(http.Header)
+ return rw.trailerMap
+ }
+ rw.trailerMap = make(http.Header, len(trailers))
+ for _, k := range trailers {
+ switch k {
+ case "Transfer-Encoding", "Content-Length", "Trailer":
+ // Ignore since forbidden by RFC 2616 14.40.
+ continue
+ }
+ k = http.CanonicalHeaderKey(k)
+ vv, ok := rw.stagingMap[k]
+ if !ok {
+ continue
+ }
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ rw.trailerMap[k] = vv2
+ }
+ return rw.trailerMap
+}
return nil
}
}
+ hasNotHeaders := func(keys ...string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ for _, k := range keys {
+ _, ok := rec.HeaderMap[http.CanonicalHeaderKey(k)]
+ if ok {
+ return fmt.Errorf("unexpected header %s", k)
+ }
+ }
+ return nil
+ }
+ }
+ hasTrailer := func(key, want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if got := rec.Trailers().Get(key); got != want {
+ return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
+ }
+ return nil
+ }
+ }
+ hasNotTrailers := func(keys ...string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ trailers := rec.Trailers()
+ for _, k := range keys {
+ _, ok := trailers[http.CanonicalHeaderKey(k)]
+ if ok {
+ return fmt.Errorf("unexpected trailer %s", k)
+ }
+ }
+ return nil
+ }
+ }
tests := []struct {
name string
},
check(hasHeader("Content-Type", "text/html; charset=utf-8")),
},
+ {
+ "Header is not changed after write",
+ func(w http.ResponseWriter, r *http.Request) {
+ hdr := w.Header()
+ hdr.Set("Key", "correct")
+ w.WriteHeader(200)
+ hdr.Set("Key", "incorrect")
+ },
+ check(hasHeader("Key", "correct")),
+ },
+ {
+ "Trailer headers are correctly recorded",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Non-Trailer", "correct")
+ w.Header().Set("Trailer", "Trailer-A")
+ w.Header().Add("Trailer", "Trailer-B")
+ w.Header().Add("Trailer", "Trailer-C")
+ io.WriteString(w, "<html>")
+ w.Header().Set("Non-Trailer", "incorrect")
+ w.Header().Set("Trailer-A", "valuea")
+ w.Header().Set("Trailer-C", "valuec")
+ w.Header().Set("Trailer-NotDeclared", "should be omitted")
+ },
+ check(
+ hasStatus(200),
+ hasHeader("Content-Type", "text/html; charset=utf-8"),
+ hasHeader("Non-Trailer", "correct"),
+ hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
+ hasTrailer("Trailer-A", "valuea"),
+ hasTrailer("Trailer-C", "valuec"),
+ hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
+ ),
+ },
}
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
for _, tt := range tests {