]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httptest: mimic the normal HTTP server's ResponseWriter more closely
authorBrad Fitzpatrick <bradfitz@golang.org>
Sun, 7 Oct 2012 16:48:14 +0000 (09:48 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Sun, 7 Oct 2012 16:48:14 +0000 (09:48 -0700)
Also adds tests, which didn't exist before.

Fixes #4188

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/6613062

src/pkg/net/http/httptest/recorder.go
src/pkg/net/http/httptest/recorder_test.go [new file with mode: 0644]

index 9aa0d510bd4d57906ee7d7d5357b545a1b172d70..5451f54234c9e6dfa8a5bbc6796495432d478085 100644 (file)
@@ -17,6 +17,8 @@ type ResponseRecorder struct {
        HeaderMap http.Header   // the HTTP response headers
        Body      *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
        Flushed   bool
+
+       wroteHeader bool
 }
 
 // NewRecorder returns an initialized ResponseRecorder.
@@ -24,6 +26,7 @@ func NewRecorder() *ResponseRecorder {
        return &ResponseRecorder{
                HeaderMap: make(http.Header),
                Body:      new(bytes.Buffer),
+               Code:      200,
        }
 }
 
@@ -33,26 +36,37 @@ const DefaultRemoteAddr = "1.2.3.4"
 
 // Header returns the response headers.
 func (rw *ResponseRecorder) Header() http.Header {
-       return rw.HeaderMap
+       m := rw.HeaderMap
+       if m == nil {
+               m = make(http.Header)
+               rw.HeaderMap = m
+       }
+       return m
 }
 
 // Write always succeeds and writes to rw.Body, if not nil.
 func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
+       if !rw.wroteHeader {
+               rw.WriteHeader(200)
+       }
        if rw.Body != nil {
                rw.Body.Write(buf)
        }
-       if rw.Code == 0 {
-               rw.Code = http.StatusOK
-       }
        return len(buf), nil
 }
 
 // WriteHeader sets rw.Code.
 func (rw *ResponseRecorder) WriteHeader(code int) {
-       rw.Code = code
+       if !rw.wroteHeader {
+               rw.Code = code
+       }
+       rw.wroteHeader = true
 }
 
 // Flush sets rw.Flushed to true.
 func (rw *ResponseRecorder) Flush() {
+       if !rw.wroteHeader {
+               rw.WriteHeader(200)
+       }
        rw.Flushed = true
 }
diff --git a/src/pkg/net/http/httptest/recorder_test.go b/src/pkg/net/http/httptest/recorder_test.go
new file mode 100644 (file)
index 0000000..2b56326
--- /dev/null
@@ -0,0 +1,90 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest
+
+import (
+       "fmt"
+       "net/http"
+       "testing"
+)
+
+func TestRecorder(t *testing.T) {
+       type checkFunc func(*ResponseRecorder) error
+       check := func(fns ...checkFunc) []checkFunc { return fns }
+
+       hasStatus := func(wantCode int) checkFunc {
+               return func(rec *ResponseRecorder) error {
+                       if rec.Code != wantCode {
+                               return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
+                       }
+                       return nil
+               }
+       }
+       hasContents := func(want string) checkFunc {
+               return func(rec *ResponseRecorder) error {
+                       if rec.Body.String() != want {
+                               return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
+                       }
+                       return nil
+               }
+       }
+       hasFlush := func(want bool) checkFunc {
+               return func(rec *ResponseRecorder) error {
+                       if rec.Flushed != want {
+                               return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
+                       }
+                       return nil
+               }
+       }
+
+       tests := []struct {
+               name   string
+               h      func(w http.ResponseWriter, r *http.Request)
+               checks []checkFunc
+       }{
+               {
+                       "200 default",
+                       func(w http.ResponseWriter, r *http.Request) {},
+                       check(hasStatus(200), hasContents("")),
+               },
+               {
+                       "first code only",
+                       func(w http.ResponseWriter, r *http.Request) {
+                               w.WriteHeader(201)
+                               w.WriteHeader(202)
+                               w.Write([]byte("hi"))
+                       },
+                       check(hasStatus(201), hasContents("hi")),
+               },
+               {
+                       "write sends 200",
+                       func(w http.ResponseWriter, r *http.Request) {
+                               w.Write([]byte("hi first"))
+                               w.WriteHeader(201)
+                               w.WriteHeader(202)
+                       },
+                       check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
+               },
+               {
+                       "flush",
+                       func(w http.ResponseWriter, r *http.Request) {
+                               w.(http.Flusher).Flush() // also sends a 200
+                               w.WriteHeader(201)
+                       },
+                       check(hasStatus(200), hasFlush(true)),
+               },
+       }
+       r, _ := http.NewRequest("GET", "http://foo.com/", nil)
+       for _, tt := range tests {
+               h := http.HandlerFunc(tt.h)
+               rec := NewRecorder()
+               h.ServeHTTP(rec, r)
+               for _, check := range tt.checks {
+                       if err := check(rec); err != nil {
+                               t.Errorf("%s: %v", tt.name, err)
+                       }
+               }
+       }
+}