]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httptest: fill ContentLength in recorded Response
authorThomas de Zeeuw <thomasdezeeuw@gmail.com>
Thu, 1 Sep 2016 12:54:08 +0000 (14:54 +0200)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 21 Sep 2016 17:34:01 +0000 (17:34 +0000)
This change fills the ContentLength field in the http.Response returned by
ResponseRecorder.Result.

Fixes #16952.

Change-Id: I9c49b1bf83e3719b5275b03a43aff5033156637d
Reviewed-on: https://go-review.googlesource.com/28302
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/http/httptest/recorder.go
src/net/http/httptest/recorder_test.go

index 725ba0b70a9dca8ea1e7a6b8c3ee7dab4aee83c4..bc99797b3328ca88ab73df894169c1b08359606b 100644 (file)
@@ -8,6 +8,8 @@ import (
        "bytes"
        "io/ioutil"
        "net/http"
+       "strconv"
+       "strings"
 )
 
 // ResponseRecorder is an implementation of http.ResponseWriter that
@@ -162,6 +164,7 @@ func (rw *ResponseRecorder) Result() *http.Response {
        if rw.Body != nil {
                res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
        }
+       res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
 
        if trailers, ok := rw.snapHeader["Trailer"]; ok {
                res.Trailer = make(http.Header, len(trailers))
@@ -186,3 +189,20 @@ func (rw *ResponseRecorder) Result() *http.Response {
        }
        return res
 }
+
+// parseContentLength trims whitespace from s and returns -1 if no value
+// is set, or the value if it's >= 0.
+//
+// This a modified version of same function found in net/http/transfer.go. This
+// one just ignores an invalid header.
+func parseContentLength(cl string) int64 {
+       cl = strings.TrimSpace(cl)
+       if cl == "" {
+               return -1
+       }
+       n, err := strconv.ParseInt(cl, 10, 64)
+       if err != nil {
+               return -1
+       }
+       return n
+}
index d4e7137913e4924f0b1f4db5ad4e666961c5e98d..ff9b9911a8636e4a4ae54d3f5672b9365c099341 100644 (file)
@@ -94,6 +94,14 @@ func TestRecorder(t *testing.T) {
                        return nil
                }
        }
+       hasContentLength := func(length int64) checkFunc {
+               return func(rec *ResponseRecorder) error {
+                       if got := rec.Result().ContentLength; got != length {
+                               return fmt.Errorf("ContentLength = %d; want %d", got, length)
+                       }
+                       return nil
+               }
+       }
 
        tests := []struct {
                name   string
@@ -141,7 +149,7 @@ func TestRecorder(t *testing.T) {
                                w.(http.Flusher).Flush() // also sends a 200
                                w.WriteHeader(201)
                        },
-                       check(hasStatus(200), hasFlush(true)),
+                       check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
                },
                {
                        "Content-Type detection",
@@ -244,6 +252,16 @@ func TestRecorder(t *testing.T) {
                                hasNotHeaders("X-Bar"),
                        ),
                },
+               {
+                       "setting Content-Length header",
+                       func(w http.ResponseWriter, r *http.Request) {
+                               body := "Some body"
+                               contentLength := fmt.Sprintf("%d", len(body))
+                               w.Header().Set("Content-Length", contentLength)
+                               io.WriteString(w, body)
+                       },
+                       check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
+               },
        }
        r, _ := http.NewRequest("GET", "http://foo.com/", nil)
        for _, tt := range tests {