]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: make NewRequest pick a ContentLength from a *bytes.Reader too
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 2 Jan 2013 22:40:27 +0000 (14:40 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 2 Jan 2013 22:40:27 +0000 (14:40 -0800)
It already did so for its sibling, *strings.Reader, as well as *bytes.Buffer.

R=edsrzf, dave, adg, kevlar, remyoudompheng, adg, rsc
CC=golang-dev
https://golang.org/cl/7031045

src/pkg/net/http/request.go
src/pkg/net/http/request_test.go

index f50e254fb28ce2ce7fcc5d5fa1976609fa5e9375..3b799108ac00fcaab115eabb7c107a049603e628 100644 (file)
@@ -433,10 +433,12 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
        }
        if body != nil {
                switch v := body.(type) {
-               case *strings.Reader:
-                       req.ContentLength = int64(v.Len())
                case *bytes.Buffer:
                        req.ContentLength = int64(v.Len())
+               case *bytes.Reader:
+                       req.ContentLength = int64(v.Len())
+               case *strings.Reader:
+                       req.ContentLength = int64(v.Len())
                }
        }
 
index 2f34d124128d47de91cd9f73d2f9fc8782e92c7b..fc485fcdf864e101db698b67b6bf4ca069e11ca5 100644 (file)
@@ -238,6 +238,35 @@ func TestNewRequestHost(t *testing.T) {
        }
 }
 
+func TestNewRequestContentLength(t *testing.T) {
+       readByte := func(r io.Reader) io.Reader {
+               var b [1]byte
+               r.Read(b[:])
+               return r
+       }
+       tests := []struct {
+               r    io.Reader
+               want int64
+       }{
+               {bytes.NewReader([]byte("123")), 3},
+               {bytes.NewBuffer([]byte("1234")), 4},
+               {strings.NewReader("12345"), 5},
+               // Not detected:
+               {struct{ io.Reader }{strings.NewReader("xyz")}, 0},
+               {io.NewSectionReader(strings.NewReader("x"), 0, 6), 0},
+               {readByte(io.NewSectionReader(strings.NewReader("xy"), 0, 6)), 0},
+       }
+       for _, tt := range tests {
+               req, err := NewRequest("POST", "http://localhost/", tt.r)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if req.ContentLength != tt.want {
+                       t.Errorf("ContentLength(%#T) = %d; want %d", tt.r, req.ContentLength, tt.want)
+               }
+       }
+}
+
 func testMissingFile(t *testing.T, req *Request) {
        f, fh, err := req.FormFile("missing")
        if f != nil {