import (
"bytes"
+ "io/ioutil"
"net/http"
)
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
Flushed bool
- stagingMap http.Header // map that handlers manipulate to set headers
- trailerMap http.Header // lazily filled when Trailers() is called
-
+ result *http.Response // cache of Result's return value
+ snapHeader http.Header // snapshot of HeaderMap at first Write
wroteHeader bool
}
// Header returns the response headers.
func (rw *ResponseRecorder) Header() http.Header {
- m := rw.stagingMap
+ m := rw.HeaderMap
if m == nil {
m = make(http.Header)
- rw.stagingMap = m
+ rw.HeaderMap = m
}
return m
}
if rw.HeaderMap == nil {
rw.HeaderMap = make(http.Header)
}
- for k, vv := range rw.stagingMap {
+ rw.snapHeader = cloneHeader(rw.HeaderMap)
+}
+
+func cloneHeader(h http.Header) http.Header {
+ h2 := make(http.Header, len(h))
+ for k, vv := range h {
vv2 := make([]string, len(vv))
copy(vv2, vv)
- rw.HeaderMap[k] = vv2
+ h2[k] = vv2
}
+ return h2
}
// Flush sets rw.Flushed to true.
rw.Flushed = true
}
-// Trailers returns any trailers set by the handler. It must be called
-// after the handler finished running.
-func (rw *ResponseRecorder) Trailers() http.Header {
- if rw.trailerMap != nil {
- return rw.trailerMap
- }
- trailers, ok := rw.HeaderMap["Trailer"]
- if !ok {
- rw.trailerMap = make(http.Header)
- return rw.trailerMap
- }
- rw.trailerMap = make(http.Header, len(trailers))
- for _, k := range trailers {
- switch k {
- case "Transfer-Encoding", "Content-Length", "Trailer":
- // Ignore since forbidden by RFC 2616 14.40.
- continue
- }
- k = http.CanonicalHeaderKey(k)
- vv, ok := rw.stagingMap[k]
- if !ok {
- continue
+// Result returns the response generated by the handler.
+//
+// The returned Response will have at least its StatusCode,
+// Header, Body, and optionally Trailer populated.
+// More fields may be populated in the future, so callers should
+// not DeepEqual the result in tests.
+//
+// The Response.Header is a snapshot of the headers at the time of the
+// first write call, or at the time of this call, if the handler never
+// did a write.
+//
+// Result must only be called after the handler has finished running.
+func (rw *ResponseRecorder) Result() *http.Response {
+ if rw.result != nil {
+ return rw.result
+ }
+ if rw.snapHeader == nil {
+ rw.snapHeader = cloneHeader(rw.HeaderMap)
+ }
+ res := &http.Response{
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ StatusCode: rw.Code,
+ Header: rw.snapHeader,
+ }
+ rw.result = res
+ if res.StatusCode == 0 {
+ res.StatusCode = 200
+ }
+ res.Status = http.StatusText(res.StatusCode)
+ if rw.Body != nil {
+ res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
+ }
+
+ if trailers, ok := rw.snapHeader["Trailer"]; ok {
+ res.Trailer = make(http.Header, len(trailers))
+ for _, k := range trailers {
+ // TODO: use http2.ValidTrailerHeader, but we can't
+ // get at it easily because it's bundled into net/http
+ // unexported. This is good enough for now:
+ switch k {
+ case "Transfer-Encoding", "Content-Length", "Trailer":
+ // Ignore since forbidden by RFC 2616 14.40.
+ continue
+ }
+ k = http.CanonicalHeaderKey(k)
+ vv, ok := rw.HeaderMap[k]
+ if !ok {
+ continue
+ }
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ res.Trailer[k] = vv2
}
- vv2 := make([]string, len(vv))
- copy(vv2, vv)
- rw.trailerMap[k] = vv2
}
- return rw.trailerMap
+ return res
}
return nil
}
}
+ hasResultStatus := func(wantCode int) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Result().StatusCode != wantCode {
+ return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
+ }
+ return nil
+ }
+ }
hasContents := func(want string) checkFunc {
return func(rec *ResponseRecorder) error {
if rec.Body.String() != want {
return nil
}
}
- hasHeader := func(key, want string) checkFunc {
+ hasOldHeader := func(key, want string) checkFunc {
return func(rec *ResponseRecorder) error {
if got := rec.HeaderMap.Get(key); got != want {
- return fmt.Errorf("header %s = %q; want %q", key, got, want)
+ return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
+ }
+ return nil
+ }
+ }
+ hasHeader := func(key, want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if got := rec.Result().Header.Get(key); got != want {
+ return fmt.Errorf("final header %s = %q; want %q", key, got, want)
}
return nil
}
hasNotHeaders := func(keys ...string) checkFunc {
return func(rec *ResponseRecorder) error {
for _, k := range keys {
- _, ok := rec.HeaderMap[http.CanonicalHeaderKey(k)]
+ v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
if ok {
- return fmt.Errorf("unexpected header %s", k)
+ return fmt.Errorf("unexpected header %s with value %q", k, v)
}
}
return nil
}
hasTrailer := func(key, want string) checkFunc {
return func(rec *ResponseRecorder) error {
- if got := rec.Trailers().Get(key); got != want {
+ if got := rec.Result().Trailer.Get(key); got != want {
return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
}
return nil
}
hasNotTrailers := func(keys ...string) checkFunc {
return func(rec *ResponseRecorder) error {
- trailers := rec.Trailers()
+ trailers := rec.Result().Trailer
for _, k := range keys {
_, ok := trailers[http.CanonicalHeaderKey(k)]
if ok {
hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
),
},
+ {
+ "Header set without any write", // Issue 15560
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Foo", "1")
+
+ // Simulate somebody using
+ // new(ResponseRecorder) instead of
+ // using the constructor which sets
+ // this to 200
+ w.(*ResponseRecorder).Code = 0
+ },
+ check(
+ hasOldHeader("X-Foo", "1"),
+ hasStatus(0),
+ hasHeader("X-Foo", "1"),
+ hasResultStatus(200),
+ ),
+ },
+ {
+ "HeaderMap vs FinalHeaders", // more for Issue 15560
+ func(w http.ResponseWriter, r *http.Request) {
+ h := w.Header()
+ h.Set("X-Foo", "1")
+ w.Write([]byte("hi"))
+ h.Set("X-Foo", "2")
+ h.Set("X-Bar", "2")
+ },
+ check(
+ hasOldHeader("X-Foo", "2"),
+ hasOldHeader("X-Bar", "2"),
+ hasHeader("X-Foo", "1"),
+ hasNotHeaders("X-Bar"),
+ ),
+ },
}
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
for _, tt := range tests {