]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httptest: restore historic ResponseRecorder.HeaderMap behavior
authorBrad Fitzpatrick <bradfitz@golang.org>
Thu, 19 May 2016 18:05:10 +0000 (18:05 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 19 May 2016 23:02:34 +0000 (23:02 +0000)
In Go versions 1 up to and including Go 1.6,
ResponseRecorder.HeaderMap was both the map that handlers got access
to, and was the map tests checked their results against. That did not
mimic the behavior of the real HTTP server (Issue #8857), so HeaderMap
was changed to be a snapshot at the first write in
https://golang.org/cl/20047. But that broke cases where the Handler
never did a write (#15560), so revert the behavior.

Instead, introduce the ResponseWriter.Result method, returning an
*http.Response. It subsumes ResponseWriter.Trailers which was added
for Go 1.7 in CL 20047. Result().Header now contains the correct
answer, and HeaderMap is unchanged in behavior from previous Go
releases, so we don't break people's tests. People wanting the correct
behavior can use ResponseWriter.Result.

Fixes #15560
Updates #8857

Change-Id: I7ea9b56a6b843103784553d67f67847b5315b3d2
Reviewed-on: https://go-review.googlesource.com/23257
Reviewed-by: Damien Neil <dneil@google.com>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/http/httptest/recorder.go
src/net/http/httptest/recorder_test.go

index b1f49541d576e556dee47786b33627da77097552..0ad26a3d418005c4f8ed92cba932c5fc9cbfef9e 100644 (file)
@@ -6,6 +6,7 @@ package httptest
 
 import (
        "bytes"
+       "io/ioutil"
        "net/http"
 )
 
@@ -17,9 +18,8 @@ type ResponseRecorder struct {
        Body      *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
        Flushed   bool
 
-       stagingMap http.Header // map that handlers manipulate to set headers
-       trailerMap http.Header // lazily filled when Trailers() is called
-
+       result      *http.Response // cache of Result's return value
+       snapHeader  http.Header    // snapshot of HeaderMap at first Write
        wroteHeader bool
 }
 
@@ -38,10 +38,10 @@ const DefaultRemoteAddr = "1.2.3.4"
 
 // Header returns the response headers.
 func (rw *ResponseRecorder) Header() http.Header {
-       m := rw.stagingMap
+       m := rw.HeaderMap
        if m == nil {
                m = make(http.Header)
-               rw.stagingMap = m
+               rw.HeaderMap = m
        }
        return m
 }
@@ -104,11 +104,17 @@ func (rw *ResponseRecorder) WriteHeader(code int) {
        if rw.HeaderMap == nil {
                rw.HeaderMap = make(http.Header)
        }
-       for k, vv := range rw.stagingMap {
+       rw.snapHeader = cloneHeader(rw.HeaderMap)
+}
+
+func cloneHeader(h http.Header) http.Header {
+       h2 := make(http.Header, len(h))
+       for k, vv := range h {
                vv2 := make([]string, len(vv))
                copy(vv2, vv)
-               rw.HeaderMap[k] = vv2
+               h2[k] = vv2
        }
+       return h2
 }
 
 // Flush sets rw.Flushed to true.
@@ -119,32 +125,61 @@ func (rw *ResponseRecorder) Flush() {
        rw.Flushed = true
 }
 
-// Trailers returns any trailers set by the handler. It must be called
-// after the handler finished running.
-func (rw *ResponseRecorder) Trailers() http.Header {
-       if rw.trailerMap != nil {
-               return rw.trailerMap
-       }
-       trailers, ok := rw.HeaderMap["Trailer"]
-       if !ok {
-               rw.trailerMap = make(http.Header)
-               return rw.trailerMap
-       }
-       rw.trailerMap = make(http.Header, len(trailers))
-       for _, k := range trailers {
-               switch k {
-               case "Transfer-Encoding", "Content-Length", "Trailer":
-                       // Ignore since forbidden by RFC 2616 14.40.
-                       continue
-               }
-               k = http.CanonicalHeaderKey(k)
-               vv, ok := rw.stagingMap[k]
-               if !ok {
-                       continue
+// Result returns the response generated by the handler.
+//
+// The returned Response will have at least its StatusCode,
+// Header, Body, and optionally Trailer populated.
+// More fields may be populated in the future, so callers should
+// not DeepEqual the result in tests.
+//
+// The Response.Header is a snapshot of the headers at the time of the
+// first write call, or at the time of this call, if the handler never
+// did a write.
+//
+// Result must only be called after the handler has finished running.
+func (rw *ResponseRecorder) Result() *http.Response {
+       if rw.result != nil {
+               return rw.result
+       }
+       if rw.snapHeader == nil {
+               rw.snapHeader = cloneHeader(rw.HeaderMap)
+       }
+       res := &http.Response{
+               Proto:      "HTTP/1.1",
+               ProtoMajor: 1,
+               ProtoMinor: 1,
+               StatusCode: rw.Code,
+               Header:     rw.snapHeader,
+       }
+       rw.result = res
+       if res.StatusCode == 0 {
+               res.StatusCode = 200
+       }
+       res.Status = http.StatusText(res.StatusCode)
+       if rw.Body != nil {
+               res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
+       }
+
+       if trailers, ok := rw.snapHeader["Trailer"]; ok {
+               res.Trailer = make(http.Header, len(trailers))
+               for _, k := range trailers {
+                       // TODO: use http2.ValidTrailerHeader, but we can't
+                       // get at it easily because it's bundled into net/http
+                       // unexported. This is good enough for now:
+                       switch k {
+                       case "Transfer-Encoding", "Content-Length", "Trailer":
+                               // Ignore since forbidden by RFC 2616 14.40.
+                               continue
+                       }
+                       k = http.CanonicalHeaderKey(k)
+                       vv, ok := rw.HeaderMap[k]
+                       if !ok {
+                               continue
+                       }
+                       vv2 := make([]string, len(vv))
+                       copy(vv2, vv)
+                       res.Trailer[k] = vv2
                }
-               vv2 := make([]string, len(vv))
-               copy(vv2, vv)
-               rw.trailerMap[k] = vv2
        }
-       return rw.trailerMap
+       return res
 }
index 19a37b6c54d2d656141c465228f3296c1c16808e..d4e7137913e4924f0b1f4db5ad4e666961c5e98d 100644 (file)
@@ -23,6 +23,14 @@ func TestRecorder(t *testing.T) {
                        return nil
                }
        }
+       hasResultStatus := func(wantCode int) checkFunc {
+               return func(rec *ResponseRecorder) error {
+                       if rec.Result().StatusCode != wantCode {
+                               return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
+                       }
+                       return nil
+               }
+       }
        hasContents := func(want string) checkFunc {
                return func(rec *ResponseRecorder) error {
                        if rec.Body.String() != want {
@@ -39,10 +47,18 @@ func TestRecorder(t *testing.T) {
                        return nil
                }
        }
-       hasHeader := func(key, want string) checkFunc {
+       hasOldHeader := func(key, want string) checkFunc {
                return func(rec *ResponseRecorder) error {
                        if got := rec.HeaderMap.Get(key); got != want {
-                               return fmt.Errorf("header %s = %q; want %q", key, got, want)
+                               return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
+                       }
+                       return nil
+               }
+       }
+       hasHeader := func(key, want string) checkFunc {
+               return func(rec *ResponseRecorder) error {
+                       if got := rec.Result().Header.Get(key); got != want {
+                               return fmt.Errorf("final header %s = %q; want %q", key, got, want)
                        }
                        return nil
                }
@@ -50,9 +66,9 @@ func TestRecorder(t *testing.T) {
        hasNotHeaders := func(keys ...string) checkFunc {
                return func(rec *ResponseRecorder) error {
                        for _, k := range keys {
-                               _, ok := rec.HeaderMap[http.CanonicalHeaderKey(k)]
+                               v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
                                if ok {
-                                       return fmt.Errorf("unexpected header %s", k)
+                                       return fmt.Errorf("unexpected header %s with value %q", k, v)
                                }
                        }
                        return nil
@@ -60,7 +76,7 @@ func TestRecorder(t *testing.T) {
        }
        hasTrailer := func(key, want string) checkFunc {
                return func(rec *ResponseRecorder) error {
-                       if got := rec.Trailers().Get(key); got != want {
+                       if got := rec.Result().Trailer.Get(key); got != want {
                                return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
                        }
                        return nil
@@ -68,7 +84,7 @@ func TestRecorder(t *testing.T) {
        }
        hasNotTrailers := func(keys ...string) checkFunc {
                return func(rec *ResponseRecorder) error {
-                       trailers := rec.Trailers()
+                       trailers := rec.Result().Trailer
                        for _, k := range keys {
                                _, ok := trailers[http.CanonicalHeaderKey(k)]
                                if ok {
@@ -194,6 +210,40 @@ func TestRecorder(t *testing.T) {
                                hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
                        ),
                },
+               {
+                       "Header set without any write", // Issue 15560
+                       func(w http.ResponseWriter, r *http.Request) {
+                               w.Header().Set("X-Foo", "1")
+
+                               // Simulate somebody using
+                               // new(ResponseRecorder) instead of
+                               // using the constructor which sets
+                               // this to 200
+                               w.(*ResponseRecorder).Code = 0
+                       },
+                       check(
+                               hasOldHeader("X-Foo", "1"),
+                               hasStatus(0),
+                               hasHeader("X-Foo", "1"),
+                               hasResultStatus(200),
+                       ),
+               },
+               {
+                       "HeaderMap vs FinalHeaders", // more for Issue 15560
+                       func(w http.ResponseWriter, r *http.Request) {
+                               h := w.Header()
+                               h.Set("X-Foo", "1")
+                               w.Write([]byte("hi"))
+                               h.Set("X-Foo", "2")
+                               h.Set("X-Bar", "2")
+                       },
+                       check(
+                               hasOldHeader("X-Foo", "2"),
+                               hasOldHeader("X-Bar", "2"),
+                               hasHeader("X-Foo", "1"),
+                               hasNotHeaders("X-Bar"),
+                       ),
+               },
        }
        r, _ := http.NewRequest("GET", "http://foo.com/", nil)
        for _, tt := range tests {