]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: populate Request.Close in ReadRequest
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 25 Aug 2014 18:44:08 +0000 (11:44 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 25 Aug 2014 18:44:08 +0000 (11:44 -0700)
Fixes #8261

LGTM=r
R=r
CC=golang-codereviews
https://golang.org/cl/126620043

src/pkg/net/http/readrequest_test.go
src/pkg/net/http/request.go
src/pkg/net/http/transfer.go

index ffdd6a892dab7f2f1265c87ea983fff293b53176..e930d99af62e8720f5ac240a39c2ddc790ec4be0 100644 (file)
@@ -11,6 +11,7 @@ import (
        "io"
        "net/url"
        "reflect"
+       "strings"
        "testing"
 )
 
@@ -295,14 +296,39 @@ var reqTests = []reqTest{
                noTrailer,
                noError,
        },
+
+       // Connection: close. golang.org/issue/8261
+       {
+               "GET / HTTP/1.1\r\nHost: issue8261.com\r\nConnection: close\r\n\r\n",
+               &Request{
+                       Method: "GET",
+                       URL: &url.URL{
+                               Path: "/",
+                       },
+                       Header: Header{
+                               // This wasn't removed from Go 1.0 to
+                               // Go 1.3, so locking it in that we
+                               // keep this:
+                               "Connection": []string{"close"},
+                       },
+                       Host:       "issue8261.com",
+                       Proto:      "HTTP/1.1",
+                       ProtoMajor: 1,
+                       ProtoMinor: 1,
+                       Close:      true,
+                       RequestURI: "/",
+               },
+
+               noBody,
+               noTrailer,
+               noError,
+       },
 }
 
 func TestReadRequest(t *testing.T) {
        for i := range reqTests {
                tt := &reqTests[i]
-               var braw bytes.Buffer
-               braw.WriteString(tt.Raw)
-               req, err := ReadRequest(bufio.NewReader(&braw))
+               req, err := ReadRequest(bufio.NewReader(strings.NewReader(tt.Raw)))
                if err != nil {
                        if err.Error() != tt.Error {
                                t.Errorf("#%d: error %q, want error %q", i, err.Error(), tt.Error)
@@ -311,21 +337,22 @@ func TestReadRequest(t *testing.T) {
                }
                rbody := req.Body
                req.Body = nil
-               diff(t, fmt.Sprintf("#%d Request", i), req, tt.Req)
+               testName := fmt.Sprintf("Test %d (%q)", i, tt.Raw)
+               diff(t, testName, req, tt.Req)
                var bout bytes.Buffer
                if rbody != nil {
                        _, err := io.Copy(&bout, rbody)
                        if err != nil {
-                               t.Fatalf("#%d. copying body: %v", i, err)
+                               t.Fatalf("%s: copying body: %v", testName, err)
                        }
                        rbody.Close()
                }
                body := bout.String()
                if body != tt.Body {
-                       t.Errorf("#%d: Body = %q want %q", i, body, tt.Body)
+                       t.Errorf("%s: Body = %q want %q", testName, body, tt.Body)
                }
                if !reflect.DeepEqual(tt.Trailer, req.Trailer) {
-                       t.Errorf("#%d. Trailers differ.\n got: %v\nwant: %v", i, req.Trailer, tt.Trailer)
+                       t.Errorf("%s: Trailers differ.\n got: %v\nwant: %v", testName, req.Trailer, tt.Trailer)
                }
        }
 }
index 131cb6d67eef5bc1a1b744f019ed39360c1e89cb..63729431887a126cfce57e09595b63e8a988e735 100644 (file)
@@ -635,6 +635,7 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) {
                return nil, err
        }
 
+       req.Close = shouldClose(req.ProtoMajor, req.ProtoMinor, req.Header, false)
        return req, nil
 }
 
index 51b1dcb30bb7e068e8b5776cfde62aa018493cf9..520500330bc8892ddac7110016b3aeabf022896e 100644 (file)
@@ -303,7 +303,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
                t.StatusCode = rr.StatusCode
                t.ProtoMajor = rr.ProtoMajor
                t.ProtoMinor = rr.ProtoMinor
-               t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header)
+               t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header, true)
                isResponse = true
                if rr.Request != nil {
                        t.RequestMethod = rr.Request.Method
@@ -502,7 +502,7 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header,
 // Determine whether to hang up after sending a request and body, or
 // receiving a response and body
 // 'header' is the request headers
-func shouldClose(major, minor int, header Header) bool {
+func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool {
        if major < 1 {
                return true
        } else if major == 1 && minor == 0 {
@@ -514,7 +514,9 @@ func shouldClose(major, minor int, header Header) bool {
                // TODO: Should split on commas, toss surrounding white space,
                // and check each field.
                if strings.ToLower(header.get("Connection")) == "close" {
-                       header.Del("Connection")
+                       if removeCloseHeader {
+                               header.Del("Connection")
+                       }
                        return true
                }
        }