"bytes"
"io/ioutil"
"net/http"
+ "strconv"
+ "strings"
)
// ResponseRecorder is an implementation of http.ResponseWriter that
if rw.Body != nil {
res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
}
+ res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
if trailers, ok := rw.snapHeader["Trailer"]; ok {
res.Trailer = make(http.Header, len(trailers))
}
return res
}
+
+// parseContentLength trims whitespace from s and returns -1 if no value
+// is set, or the value if it's >= 0.
+//
+// This a modified version of same function found in net/http/transfer.go. This
+// one just ignores an invalid header.
+func parseContentLength(cl string) int64 {
+ cl = strings.TrimSpace(cl)
+ if cl == "" {
+ return -1
+ }
+ n, err := strconv.ParseInt(cl, 10, 64)
+ if err != nil {
+ return -1
+ }
+ return n
+}
return nil
}
}
+ hasContentLength := func(length int64) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if got := rec.Result().ContentLength; got != length {
+ return fmt.Errorf("ContentLength = %d; want %d", got, length)
+ }
+ return nil
+ }
+ }
tests := []struct {
name string
w.(http.Flusher).Flush() // also sends a 200
w.WriteHeader(201)
},
- check(hasStatus(200), hasFlush(true)),
+ check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
},
{
"Content-Type detection",
hasNotHeaders("X-Bar"),
),
},
+ {
+ "setting Content-Length header",
+ func(w http.ResponseWriter, r *http.Request) {
+ body := "Some body"
+ contentLength := fmt.Sprintf("%d", len(body))
+ w.Header().Set("Content-Length", contentLength)
+ io.WriteString(w, body)
+ },
+ check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
+ },
}
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
for _, tt := range tests {