t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
}
}
+
+// Verify Response.ContentLength is populated. http://golang.org/issue/4126
+func TestClientHeadContentLength(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if v := r.FormValue("cl"); v != "" {
+ w.Header().Set("Content-Length", v)
+ }
+ }))
+ defer ts.Close()
+ tests := []struct {
+ suffix string
+ want int64
+ }{
+ {"/?cl=1234", 1234},
+ {"/?cl=0", 0},
+ {"", -1},
+ }
+ for _, tt := range tests {
+ req, _ := NewRequest("HEAD", ts.URL+tt.suffix, nil)
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.ContentLength != tt.want {
+ t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want)
+ }
+ bs, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(bs) != 0 {
+ t.Errorf("Unexpected content: %q", bs)
+ }
+ }
+}
return err
}
- t.ContentLength, err = fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding)
+ realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding)
if err != nil {
return err
}
+ if isResponse && t.RequestMethod == "HEAD" {
+ if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil {
+ return err
+ } else {
+ t.ContentLength = n
+ }
+ } else {
+ t.ContentLength = realLength
+ }
// Trailer
t.Trailer, err = fixTrailer(t.Header, t.TransferEncoding)
// See RFC2616, section 4.4.
switch msg.(type) {
case *Response:
- if t.ContentLength == -1 &&
+ if realLength == -1 &&
!chunked(t.TransferEncoding) &&
bodyAllowedForStatus(t.StatusCode) {
// Unbounded body.
switch {
case chunked(t.TransferEncoding):
t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
- case t.ContentLength >= 0:
+ case realLength >= 0:
// TODO: limit the Content-Length. This is an easy DoS vector.
- t.Body = &body{Reader: io.LimitReader(r, t.ContentLength), closing: t.Close}
+ t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close}
default:
- // t.ContentLength < 0, i.e. "Content-Length" not mentioned in header
+ // realLength < 0, i.e. "Content-Length" not mentioned in header
if t.Close {
// Close semantics (i.e. HTTP/1.0)
t.Body = &body{Reader: r, closing: t.Close}
// Logic based on Content-Length
cl := strings.TrimSpace(header.get("Content-Length"))
if cl != "" {
- n, err := strconv.ParseInt(cl, 10, 64)
- if err != nil || n < 0 {
- return -1, &badStringError{"bad Content-Length", cl}
+ n, err := parseContentLength(cl)
+ if err != nil {
+ return -1, err
}
return n, nil
} else {
}
return nil
}
+
+// parseContentLength trims whitespace from s and returns -1 if no value
+// is set, or the value if it's >= 0.
+func parseContentLength(cl string) (int64, error) {
+ cl = strings.TrimSpace(cl)
+ if cl == "" {
+ return -1, nil
+ }
+ n, err := strconv.ParseInt(cl, 10, 64)
+ if err != nil || n < 0 {
+ return 0, &badStringError{"bad Content-Length", cl}
+ }
+ return n, nil
+
+}
if e, g := "123", res.Header.Get("Content-Length"); e != g {
t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
}
- if e, g := int64(0), res.ContentLength; e != g {
+ if e, g := int64(123), res.ContentLength; e != g {
t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
}
}