From: Yasuharu Goto Date: Thu, 14 May 2015 15:44:34 +0000 (+0900) Subject: net/http: Client support for Expect: 100-continue X-Git-Tag: go1.6beta1~822 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=dab143c8820151538fea908efe54e9625d1bc795;p=gostls13.git net/http: Client support for Expect: 100-continue Current http client doesn't support Expect: 100-continue request header(RFC2616-8/RFC7231-5.1.1). So even if the client have the header, the head of the request body is consumed prematurely. Those are my intentions to avoid premature consuming body in this change. - If http.Request header contains body and Expect: 100-continue header, it blocks sending body until it gets the first response. - If the first status code to the request were 100, the request starts sending body. Otherwise, sending body will be cancelled. - Tranport.ExpectContinueTimeout specifies the amount of the time to wait for the first response. Fixes #3665 Change-Id: I4c04f7d88573b08cabd146c4e822061764a7cd1f Reviewed-on: https://go-review.googlesource.com/10091 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Brad Fitzpatrick --- diff --git a/src/net/http/request.go b/src/net/http/request.go index 31fe45a4ed..8467decc18 100644 --- a/src/net/http/request.go +++ b/src/net/http/request.go @@ -354,7 +354,7 @@ const defaultUserAgent = "Go-http-client/1.1" // hasn't been set to "identity", Write adds "Transfer-Encoding: // chunked" to the header. Body is closed after it is sent. func (r *Request) Write(w io.Writer) error { - return r.write(w, false, nil) + return r.write(w, false, nil, nil) } // WriteProxy is like Write but writes the request in the form @@ -364,11 +364,12 @@ func (r *Request) Write(w io.Writer) error { // In either case, WriteProxy also writes a Host header, using // either r.Host or r.URL.Host. func (r *Request) WriteProxy(w io.Writer) error { - return r.write(w, true, nil) + return r.write(w, true, nil, nil) } // extraHeaders may be nil -func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) error { +// waitForContinue may be nil +func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) error { // Find the target host. Prefer the Host: header, but if that // is not given, use the host from the request URL. // @@ -458,6 +459,21 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err return err } + // Flush and wait for 100-continue if expected. + if waitForContinue != nil { + if bw, ok := w.(*bufio.Writer); ok { + err = bw.Flush() + if err != nil { + return err + } + } + + if !waitForContinue() { + req.closeBody() + return nil + } + } + // Write body and trailer err = tw.WriteBody(w) if err != nil { diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 70d1864605..31599237e0 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -36,7 +36,8 @@ var DefaultTransport RoundTripper = &Transport{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, }).Dial, - TLSHandshakeTimeout: 10 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, } // DefaultMaxIdleConnsPerHost is the default value of Transport's @@ -113,6 +114,13 @@ type Transport struct { // time does not include the time to read the response body. ResponseHeaderTimeout time.Duration + // ExpectContinueTimeout, if non-zero, specifies the amount of + // time to wait for a server's first response headers after fully + // writing the request headers if the request has an + // "Expect: 100-continue" header. Zero means no timeout. + // This time does not include the time to send the request header. + ExpectContinueTimeout time.Duration + // TODO: tunable on global max cached connections // TODO: tunable on timeout on cached connections } @@ -894,13 +902,17 @@ func (pc *persistConn) readLoop() { var resp *Response if err == nil { resp, err = ReadResponse(pc.br, rc.req) - if err == nil && resp.StatusCode == 100 { - // Skip any 100-continue for now. - // TODO(bradfitz): if rc.req had "Expect: 100-continue", - // actually block the request body write and signal the - // writeLoop now to begin sending it. (Issue 2184) For now we - // eat it, since we're never expecting one. - resp, err = ReadResponse(pc.br, rc.req) + if err == nil { + if rc.continueCh != nil { + if resp.StatusCode == 100 { + rc.continueCh <- struct{}{} + } else { + close(rc.continueCh) + } + } + if resp.StatusCode == 100 { + resp, err = ReadResponse(pc.br, rc.req) + } } } @@ -1004,6 +1016,28 @@ func (pc *persistConn) readLoop() { pc.close() } +// waitForContinue returns the function to block until +// any response, timeout or connection close. After any of them, +// the function returns a bool which indicates if the body should be sent. +func (pc *persistConn) waitForContinue(continueCh <-chan struct{}) func() bool { + if continueCh == nil { + return nil + } + return func() bool { + timer := time.NewTimer(pc.t.ExpectContinueTimeout) + defer timer.Stop() + + select { + case _, ok := <-continueCh: + return ok + case <-timer.C: + return true + case <-pc.closech: + return false + } + } +} + func (pc *persistConn) writeLoop() { for { select { @@ -1012,7 +1046,7 @@ func (pc *persistConn) writeLoop() { wr.ch <- errors.New("http: can't write HTTP request on broken connection") continue } - err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra) + err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh)) if err == nil { err = pc.bw.Flush() } @@ -1069,6 +1103,12 @@ type requestAndChan struct { // Accept-Encoding gzip header? only if it we set it do // we transparently decode the gzip. addedGzip bool + + // Optional blocking chan for Expect: 100-continue (for send). + // If the request has an "Expect: 100-continue" header and + // the server responds 100 Continue, readLoop send a value + // to writeLoop via this chan. + continueCh chan<- struct{} } // A writeRequest is sent by the readLoop's goroutine to the @@ -1078,6 +1118,11 @@ type requestAndChan struct { type writeRequest struct { req *transportRequest ch chan<- error + + // Optional blocking chan for Expect: 100-continue (for recieve). + // If not nil, writeLoop blocks sending request body until + // it receives from this chan. + continueCh <-chan struct{} } type httpError struct { @@ -1143,6 +1188,11 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err req.extraHeaders().Set("Accept-Encoding", "gzip") } + var continueCh chan struct{} + if req.ProtoAtLeast(1, 1) && req.Body != nil && req.expectsContinue() { + continueCh = make(chan struct{}, 1) + } + if pc.t.DisableKeepAlives { req.extraHeaders().Set("Connection", "close") } @@ -1151,10 +1201,10 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // in case the server decides to reply before reading our full // request body. writeErrCh := make(chan error, 1) - pc.writech <- writeRequest{req, writeErrCh} + pc.writech <- writeRequest{req, writeErrCh, continueCh} resc := make(chan responseAndError, 1) - pc.reqch <- requestAndChan{req.Request, resc, requestedGzip} + pc.reqch <- requestAndChan{req.Request, resc, requestedGzip, continueCh} var re responseAndError var respHeaderTimer <-chan time.Time diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 5811650b0e..f721fd5558 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -790,6 +790,94 @@ func TestTransportGzip(t *testing.T) { } } +// If a request has Expect:100-continue header, the request blocks sending body until the first response. +// Premature consumption of the request body should not be occurred. +func TestTransportExpect100Continue(t *testing.T) { + defer afterTest(t) + + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + switch req.URL.Path { + case "/100": + // This endpoint implicitly responds 100 Continue and reads body. + if _, err := io.Copy(ioutil.Discard, req.Body); err != nil { + t.Error("Failed to read Body", err) + } + rw.WriteHeader(StatusOK) + case "/200": + // Go 1.5 adds Connection: close header if the client expect + // continue but not entire request body is consumed. + rw.WriteHeader(StatusOK) + case "/500": + rw.WriteHeader(StatusInternalServerError) + case "/keepalive": + // This hijacked endpoint responds error without Connection:close. + _, bufrw, err := rw.(Hijacker).Hijack() + if err != nil { + log.Fatal(err) + } + bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n") + bufrw.WriteString("Content-Length: 0\r\n\r\n") + bufrw.Flush() + case "/timeout": + // This endpoint tries to read body without 100 (Continue) response. + // After ExpectContinueTimeout, the reading will be started. + conn, bufrw, err := rw.(Hijacker).Hijack() + if err != nil { + log.Fatal(err) + } + if _, err := io.CopyN(ioutil.Discard, bufrw, req.ContentLength); err != nil { + t.Error("Failed to read Body", err) + } + bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n") + bufrw.Flush() + conn.Close() + } + + })) + defer ts.Close() + + tests := []struct { + path string + body []byte + sent int + status int + }{ + {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent. + {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent. + {path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent. + {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Althogh without Connection:close, body isn't sent. + {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent. + } + + for i, v := range tests { + tr := &Transport{ExpectContinueTimeout: 2 * time.Second} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + body := bytes.NewReader(v.body) + req, err := NewRequest("PUT", ts.URL+v.path, body) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Expect", "100-continue") + req.ContentLength = int64(len(v.body)) + + resp, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + + sent := len(v.body) - body.Len() + if v.status != resp.StatusCode { + t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path) + } + if v.sent != sent { + t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path) + } + } +} + func TestTransportProxy(t *testing.T) { defer afterTest(t) ch := make(chan string, 1)