]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: Client support for Expect: 100-continue
authorYasuharu Goto <matope.ono@gmail.com>
Thu, 14 May 2015 15:44:34 +0000 (00:44 +0900)
committerBrad Fitzpatrick <bradfitz@golang.org>
Sat, 17 Oct 2015 00:44:46 +0000 (00:44 +0000)
Current http client doesn't support Expect: 100-continue request
header(RFC2616-8/RFC7231-5.1.1). So even if the client have the header,
the head of the request body is consumed prematurely.

Those are my intentions to avoid premature consuming body in this change.
- If http.Request header contains body and Expect: 100-continue
  header, it blocks sending body until it gets the first response.
- If the first status code to the request were 100, the request
  starts sending body. Otherwise, sending body will be cancelled.
- Tranport.ExpectContinueTimeout specifies the amount of the time to
  wait for the first response.

Fixes #3665

Change-Id: I4c04f7d88573b08cabd146c4e822061764a7cd1f
Reviewed-on: https://go-review.googlesource.com/10091
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/http/request.go
src/net/http/transport.go
src/net/http/transport_test.go

index 31fe45a4edb7d51522156bb226daf0a58391fb15..8467decc182f8db5f57c3fd45cd4d68a8a699928 100644 (file)
@@ -354,7 +354,7 @@ const defaultUserAgent = "Go-http-client/1.1"
 // hasn't been set to "identity", Write adds "Transfer-Encoding:
 // chunked" to the header. Body is closed after it is sent.
 func (r *Request) Write(w io.Writer) error {
-       return r.write(w, false, nil)
+       return r.write(w, false, nil, nil)
 }
 
 // WriteProxy is like Write but writes the request in the form
@@ -364,11 +364,12 @@ func (r *Request) Write(w io.Writer) error {
 // In either case, WriteProxy also writes a Host header, using
 // either r.Host or r.URL.Host.
 func (r *Request) WriteProxy(w io.Writer) error {
-       return r.write(w, true, nil)
+       return r.write(w, true, nil, nil)
 }
 
 // extraHeaders may be nil
-func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) error {
+// waitForContinue may be nil
+func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) error {
        // Find the target host. Prefer the Host: header, but if that
        // is not given, use the host from the request URL.
        //
@@ -458,6 +459,21 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
                return err
        }
 
+       // Flush and wait for 100-continue if expected.
+       if waitForContinue != nil {
+               if bw, ok := w.(*bufio.Writer); ok {
+                       err = bw.Flush()
+                       if err != nil {
+                               return err
+                       }
+               }
+
+               if !waitForContinue() {
+                       req.closeBody()
+                       return nil
+               }
+       }
+
        // Write body and trailer
        err = tw.WriteBody(w)
        if err != nil {
index 70d1864605993f5a46ab908251da112581d070ed..31599237e059a8c0487877fda2a15b0a30d56af1 100644 (file)
@@ -36,7 +36,8 @@ var DefaultTransport RoundTripper = &Transport{
                Timeout:   30 * time.Second,
                KeepAlive: 30 * time.Second,
        }).Dial,
-       TLSHandshakeTimeout: 10 * time.Second,
+       TLSHandshakeTimeout:   10 * time.Second,
+       ExpectContinueTimeout: 1 * time.Second,
 }
 
 // DefaultMaxIdleConnsPerHost is the default value of Transport's
@@ -113,6 +114,13 @@ type Transport struct {
        // time does not include the time to read the response body.
        ResponseHeaderTimeout time.Duration
 
+       // ExpectContinueTimeout, if non-zero, specifies the amount of
+       // time to wait for a server's first response headers after fully
+       // writing the request headers if the request has an
+       // "Expect: 100-continue" header. Zero means no timeout.
+       // This time does not include the time to send the request header.
+       ExpectContinueTimeout time.Duration
+
        // TODO: tunable on global max cached connections
        // TODO: tunable on timeout on cached connections
 }
@@ -894,13 +902,17 @@ func (pc *persistConn) readLoop() {
                var resp *Response
                if err == nil {
                        resp, err = ReadResponse(pc.br, rc.req)
-                       if err == nil && resp.StatusCode == 100 {
-                               // Skip any 100-continue for now.
-                               // TODO(bradfitz): if rc.req had "Expect: 100-continue",
-                               // actually block the request body write and signal the
-                               // writeLoop now to begin sending it. (Issue 2184) For now we
-                               // eat it, since we're never expecting one.
-                               resp, err = ReadResponse(pc.br, rc.req)
+                       if err == nil {
+                               if rc.continueCh != nil {
+                                       if resp.StatusCode == 100 {
+                                               rc.continueCh <- struct{}{}
+                                       } else {
+                                               close(rc.continueCh)
+                                       }
+                               }
+                               if resp.StatusCode == 100 {
+                                       resp, err = ReadResponse(pc.br, rc.req)
+                               }
                        }
                }
 
@@ -1004,6 +1016,28 @@ func (pc *persistConn) readLoop() {
        pc.close()
 }
 
+// waitForContinue returns the function to block until
+// any response, timeout or connection close. After any of them,
+// the function returns a bool which indicates if the body should be sent.
+func (pc *persistConn) waitForContinue(continueCh <-chan struct{}) func() bool {
+       if continueCh == nil {
+               return nil
+       }
+       return func() bool {
+               timer := time.NewTimer(pc.t.ExpectContinueTimeout)
+               defer timer.Stop()
+
+               select {
+               case _, ok := <-continueCh:
+                       return ok
+               case <-timer.C:
+                       return true
+               case <-pc.closech:
+                       return false
+               }
+       }
+}
+
 func (pc *persistConn) writeLoop() {
        for {
                select {
@@ -1012,7 +1046,7 @@ func (pc *persistConn) writeLoop() {
                                wr.ch <- errors.New("http: can't write HTTP request on broken connection")
                                continue
                        }
-                       err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra)
+                       err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh))
                        if err == nil {
                                err = pc.bw.Flush()
                        }
@@ -1069,6 +1103,12 @@ type requestAndChan struct {
        // Accept-Encoding gzip header? only if it we set it do
        // we transparently decode the gzip.
        addedGzip bool
+
+       // Optional blocking chan for Expect: 100-continue (for send).
+       // If the request has an "Expect: 100-continue" header and
+       // the server responds 100 Continue, readLoop send a value
+       // to writeLoop via this chan.
+       continueCh chan<- struct{}
 }
 
 // A writeRequest is sent by the readLoop's goroutine to the
@@ -1078,6 +1118,11 @@ type requestAndChan struct {
 type writeRequest struct {
        req *transportRequest
        ch  chan<- error
+
+       // Optional blocking chan for Expect: 100-continue (for recieve).
+       // If not nil, writeLoop blocks sending request body until
+       // it receives from this chan.
+       continueCh <-chan struct{}
 }
 
 type httpError struct {
@@ -1143,6 +1188,11 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
                req.extraHeaders().Set("Accept-Encoding", "gzip")
        }
 
+       var continueCh chan struct{}
+       if req.ProtoAtLeast(1, 1) && req.Body != nil && req.expectsContinue() {
+               continueCh = make(chan struct{}, 1)
+       }
+
        if pc.t.DisableKeepAlives {
                req.extraHeaders().Set("Connection", "close")
        }
@@ -1151,10 +1201,10 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
        // in case the server decides to reply before reading our full
        // request body.
        writeErrCh := make(chan error, 1)
-       pc.writech <- writeRequest{req, writeErrCh}
+       pc.writech <- writeRequest{req, writeErrCh, continueCh}
 
        resc := make(chan responseAndError, 1)
-       pc.reqch <- requestAndChan{req.Request, resc, requestedGzip}
+       pc.reqch <- requestAndChan{req.Request, resc, requestedGzip, continueCh}
 
        var re responseAndError
        var respHeaderTimer <-chan time.Time
index 5811650b0e46722455322ea3fdc3d82a6ed3c855..f721fd555810f3c357bc157c8e84f42ab2330f43 100644 (file)
@@ -790,6 +790,94 @@ func TestTransportGzip(t *testing.T) {
        }
 }
 
+// If a request has Expect:100-continue header, the request blocks sending body until the first response.
+// Premature consumption of the request body should not be occurred.
+func TestTransportExpect100Continue(t *testing.T) {
+       defer afterTest(t)
+
+       ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+               switch req.URL.Path {
+               case "/100":
+                       // This endpoint implicitly responds 100 Continue and reads body.
+                       if _, err := io.Copy(ioutil.Discard, req.Body); err != nil {
+                               t.Error("Failed to read Body", err)
+                       }
+                       rw.WriteHeader(StatusOK)
+               case "/200":
+                       // Go 1.5 adds Connection: close header if the client expect
+                       // continue but not entire request body is consumed.
+                       rw.WriteHeader(StatusOK)
+               case "/500":
+                       rw.WriteHeader(StatusInternalServerError)
+               case "/keepalive":
+                       // This hijacked endpoint responds error without Connection:close.
+                       _, bufrw, err := rw.(Hijacker).Hijack()
+                       if err != nil {
+                               log.Fatal(err)
+                       }
+                       bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n")
+                       bufrw.WriteString("Content-Length: 0\r\n\r\n")
+                       bufrw.Flush()
+               case "/timeout":
+                       // This endpoint tries to read body without 100 (Continue) response.
+                       // After ExpectContinueTimeout, the reading will be started.
+                       conn, bufrw, err := rw.(Hijacker).Hijack()
+                       if err != nil {
+                               log.Fatal(err)
+                       }
+                       if _, err := io.CopyN(ioutil.Discard, bufrw, req.ContentLength); err != nil {
+                               t.Error("Failed to read Body", err)
+                       }
+                       bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n")
+                       bufrw.Flush()
+                       conn.Close()
+               }
+
+       }))
+       defer ts.Close()
+
+       tests := []struct {
+               path   string
+               body   []byte
+               sent   int
+               status int
+       }{
+               {path: "/100", body: []byte("hello"), sent: 5, status: 200},       // Got 100 followed by 200, entire body is sent.
+               {path: "/200", body: []byte("hello"), sent: 0, status: 200},       // Got 200 without 100. body isn't sent.
+               {path: "/500", body: []byte("hello"), sent: 0, status: 500},       // Got 500 without 100. body isn't sent.
+               {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Althogh without Connection:close, body isn't sent.
+               {path: "/timeout", body: []byte("hello"), sent: 5, status: 200},   // Timeout exceeded and entire body is sent.
+       }
+
+       for i, v := range tests {
+               tr := &Transport{ExpectContinueTimeout: 2 * time.Second}
+               defer tr.CloseIdleConnections()
+               c := &Client{Transport: tr}
+
+               body := bytes.NewReader(v.body)
+               req, err := NewRequest("PUT", ts.URL+v.path, body)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               req.Header.Set("Expect", "100-continue")
+               req.ContentLength = int64(len(v.body))
+
+               resp, err := c.Do(req)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               resp.Body.Close()
+
+               sent := len(v.body) - body.Len()
+               if v.status != resp.StatusCode {
+                       t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path)
+               }
+               if v.sent != sent {
+                       t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path)
+               }
+       }
+}
+
 func TestTransportProxy(t *testing.T) {
        defer afterTest(t)
        ch := make(chan string, 1)