}
if body != nil {
switch v := body.(type) {
- case *strings.Reader:
- req.ContentLength = int64(v.Len())
case *bytes.Buffer:
req.ContentLength = int64(v.Len())
+ case *bytes.Reader:
+ req.ContentLength = int64(v.Len())
+ case *strings.Reader:
+ req.ContentLength = int64(v.Len())
}
}
}
}
+func TestNewRequestContentLength(t *testing.T) {
+ readByte := func(r io.Reader) io.Reader {
+ var b [1]byte
+ r.Read(b[:])
+ return r
+ }
+ tests := []struct {
+ r io.Reader
+ want int64
+ }{
+ {bytes.NewReader([]byte("123")), 3},
+ {bytes.NewBuffer([]byte("1234")), 4},
+ {strings.NewReader("12345"), 5},
+ // Not detected:
+ {struct{ io.Reader }{strings.NewReader("xyz")}, 0},
+ {io.NewSectionReader(strings.NewReader("x"), 0, 6), 0},
+ {readByte(io.NewSectionReader(strings.NewReader("xy"), 0, 6)), 0},
+ }
+ for _, tt := range tests {
+ req, err := NewRequest("POST", "http://localhost/", tt.r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if req.ContentLength != tt.want {
+ t.Errorf("ContentLength(%#T) = %d; want %d", tt.r, req.ContentLength, tt.want)
+ }
+ }
+}
+
func testMissingFile(t *testing.T, req *Request) {
f, fh, err := req.FormFile("missing")
if f != nil {