]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: populate ContentLength in HEAD responses
authorBrad Fitzpatrick <bradfitz@golang.org>
Thu, 6 Dec 2012 06:36:23 +0000 (22:36 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 6 Dec 2012 06:36:23 +0000 (22:36 -0800)
Also fixes a necessary TODO in the process.

Fixes #4126

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/6869053

src/pkg/net/http/client_test.go
src/pkg/net/http/response.go
src/pkg/net/http/response_test.go
src/pkg/net/http/server.go
src/pkg/net/http/transfer.go
src/pkg/net/http/transport.go
src/pkg/net/http/transport_test.go

index 9a45b147ef1894f1f5405395ddab6e23cf241838..f4ba6a9e652ac7d2d611c4c16b8c15f2a811930d 100644 (file)
@@ -527,3 +527,38 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) {
                t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
        }
 }
+
+// Verify Response.ContentLength is populated. http://golang.org/issue/4126
+func TestClientHeadContentLength(t *testing.T) {
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               if v := r.FormValue("cl"); v != "" {
+                       w.Header().Set("Content-Length", v)
+               }
+       }))
+       defer ts.Close()
+       tests := []struct {
+               suffix string
+               want   int64
+       }{
+               {"/?cl=1234", 1234},
+               {"/?cl=0", 0},
+               {"", -1},
+       }
+       for _, tt := range tests {
+               req, _ := NewRequest("HEAD", ts.URL+tt.suffix, nil)
+               res, err := DefaultClient.Do(req)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if res.ContentLength != tt.want {
+                       t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want)
+               }
+               bs, err := ioutil.ReadAll(res.Body)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if len(bs) != 0 {
+                       t.Errorf("Unexpected content: %q", bs)
+               }
+       }
+}
index 92d2f499839946b63a4d728aedbd7ec9d7651f4a..7901c49f5a50db51866afd31b7f264e8bf02c4bd 100644 (file)
@@ -49,7 +49,7 @@ type Response struct {
        Body io.ReadCloser
 
        // ContentLength records the length of the associated content.  The
-       // value -1 indicates that the length is unknown.  Unless RequestMethod
+       // value -1 indicates that the length is unknown.  Unless Request.Method
        // is "HEAD", values >= 0 indicate that the given number of bytes may
        // be read from Body.
        ContentLength int64
@@ -178,7 +178,7 @@ func (r *Response) ProtoAtLeast(major, minor int) bool {
 //  StatusCode
 //  ProtoMajor
 //  ProtoMinor
-//  RequestMethod
+//  Request.Method
 //  TransferEncoding
 //  Trailer
 //  Body
index 6eed4887ddceb8620cc434fd9be18f34413d480e..f31e5d09fe5d9133edc10d7fe299255bc79319e1 100644 (file)
@@ -193,7 +193,7 @@ var respTests = []respTest{
                        Request:       dummyReq("HEAD"),
                        Header:        Header{},
                        Close:         true,
-                       ContentLength: 0,
+                       ContentLength: -1,
                },
 
                "",
index 21480458b683afa4583a5b7037b7ad7428dcd19f..53879c770fb70120195fbcb64646ca6aed757061 100644 (file)
@@ -614,7 +614,7 @@ func (w *response) finishRequest() {
        // HTTP/1.0 clients keep their "keep-alive" connections alive, and for
        // HTTP/1.1 clients is just as good as the alternative: sending a
        // chunked response and immediately sending the zero-length EOF chunk.
-       if w.written == 0 && w.header.get("Content-Length") == "" {
+       if w.written == 0 && w.header.get("Content-Length") == "" && w.req.Method != "HEAD" {
                w.header.Set("Content-Length", "0")
        }
        // If this was an HTTP/1.0 request with keep-alive and we sent a
index 9833dddf2b65b4dce25267112955e92c053011ef..757a0ec4623b0ec0505fc23a134f20a40d7614c5 100644 (file)
@@ -294,10 +294,19 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
                return err
        }
 
-       t.ContentLength, err = fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding)
+       realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding)
        if err != nil {
                return err
        }
+       if isResponse && t.RequestMethod == "HEAD" {
+               if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil {
+                       return err
+               } else {
+                       t.ContentLength = n
+               }
+       } else {
+               t.ContentLength = realLength
+       }
 
        // Trailer
        t.Trailer, err = fixTrailer(t.Header, t.TransferEncoding)
@@ -310,7 +319,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
        // See RFC2616, section 4.4.
        switch msg.(type) {
        case *Response:
-               if t.ContentLength == -1 &&
+               if realLength == -1 &&
                        !chunked(t.TransferEncoding) &&
                        bodyAllowedForStatus(t.StatusCode) {
                        // Unbounded body.
@@ -323,11 +332,11 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
        switch {
        case chunked(t.TransferEncoding):
                t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
-       case t.ContentLength >= 0:
+       case realLength >= 0:
                // TODO: limit the Content-Length. This is an easy DoS vector.
-               t.Body = &body{Reader: io.LimitReader(r, t.ContentLength), closing: t.Close}
+               t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close}
        default:
-               // t.ContentLength < 0, i.e. "Content-Length" not mentioned in header
+               // realLength < 0, i.e. "Content-Length" not mentioned in header
                if t.Close {
                        // Close semantics (i.e. HTTP/1.0)
                        t.Body = &body{Reader: r, closing: t.Close}
@@ -434,9 +443,9 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header,
        // Logic based on Content-Length
        cl := strings.TrimSpace(header.get("Content-Length"))
        if cl != "" {
-               n, err := strconv.ParseInt(cl, 10, 64)
-               if err != nil || n < 0 {
-                       return -1, &badStringError{"bad Content-Length", cl}
+               n, err := parseContentLength(cl)
+               if err != nil {
+                       return -1, err
                }
                return n, nil
        } else {
@@ -641,3 +650,18 @@ func (b *body) Close() error {
        }
        return nil
 }
+
+// parseContentLength trims whitespace from s and returns -1 if no value
+// is set, or the value if it's >= 0.
+func parseContentLength(cl string) (int64, error) {
+       cl = strings.TrimSpace(cl)
+       if cl == "" {
+               return -1, nil
+       }
+       n, err := strconv.ParseInt(cl, 10, 64)
+       if err != nil || n < 0 {
+               return 0, &badStringError{"bad Content-Length", cl}
+       }
+       return n, nil
+
+}
index 068c50ff0ce48460a12285a3ba1878eb7dbf1a52..7b4afeb8efca88dfc7cf4896bdefb64e7340af60 100644 (file)
@@ -604,10 +604,7 @@ func (pc *persistConn) readLoop() {
                        alive = false
                }
 
-               // TODO(bradfitz): this hasBody conflicts with the defition
-               // above which excludes HEAD requests.  Is this one
-               // incomplete?
-               hasBody := resp != nil && resp.ContentLength != 0
+               hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0
                var waitForBodyRead chan bool
                if hasBody {
                        lastbody = resp.Body
index 0e6cf85281f44a15f58b1926d25bd9d1799ab132..f1d415888ca870e407a9d519b2aa0fab88933c7a 100644 (file)
@@ -464,7 +464,7 @@ func TestTransportHeadResponses(t *testing.T) {
                if e, g := "123", res.Header.Get("Content-Length"); e != g {
                        t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
                }
-               if e, g := int64(0), res.ContentLength; e != g {
+               if e, g := int64(123), res.ContentLength; e != g {
                        t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
                }
        }