From 541620409dee210c5498cc38433dcf690f58f888 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 24 Aug 2018 19:08:28 +0000 Subject: [PATCH] net/http: make Transport return Writable Response.Body on protocol switch Updates #26937 Updates #17227 Change-Id: I79865938b05c219e1947822e60e4f52bb2604b70 Reviewed-on: https://go-review.googlesource.com/131279 Reviewed-by: Brad Fitzpatrick --- src/net/http/response.go | 25 ++++++++++++++++ src/net/http/transport.go | 52 ++++++++++++++++++++++++++++++++-- src/net/http/transport_test.go | 51 +++++++++++++++++++++++++++++++++ 3 files changed, 126 insertions(+), 2 deletions(-) diff --git a/src/net/http/response.go b/src/net/http/response.go index bf1e13c8ae..b3ca56c419 100644 --- a/src/net/http/response.go +++ b/src/net/http/response.go @@ -12,6 +12,7 @@ import ( "crypto/tls" "errors" "fmt" + "golang_org/x/net/http/httpguts" "io" "net/textproto" "net/url" @@ -63,6 +64,10 @@ type Response struct { // // The Body is automatically dechunked if the server replied // with a "chunked" Transfer-Encoding. + // + // As of Go 1.12, the Body will be also implement io.Writer + // on a successful "101 Switching Protocols" responses, + // as used by WebSockets and HTTP/2's "h2c" mode. Body io.ReadCloser // ContentLength records the length of the associated content. The @@ -333,3 +338,23 @@ func (r *Response) closeBody() { r.Body.Close() } } + +// bodyIsWritable reports whether the Body supports writing. The +// Transport returns Writable bodies for 101 Switching Protocols +// responses. +// The Transport uses this method to determine whether a persistent +// connection is done being managed from its perspective. Once we +// return a writable response body to a user, the net/http package is +// done managing that connection. +func (r *Response) bodyIsWritable() bool { + _, ok := r.Body.(io.Writer) + return ok +} + +// isProtocolSwitch reports whether r is a response to a successful +// protocol upgrade. +func (r *Response) isProtocolSwitch() bool { + return r.StatusCode == StatusSwitchingProtocols && + r.Header.Get("Upgrade") != "" && + httpguts.HeaderValuesContainsToken(r.Header["Connection"], "Upgrade") +} diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 40947baf87..ffe4cdc0d6 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -1607,6 +1607,11 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte return err } +// errCallerOwnsConn is an internal sentinel error used when we hand +// off a writable response.Body to the caller. We use this to prevent +// closing a net.Conn that is now owned by the caller. +var errCallerOwnsConn = errors.New("read loop ending; caller owns writable underlying conn") + func (pc *persistConn) readLoop() { closeErr := errReadLoopExiting // default value, if not changed below defer func() { @@ -1681,9 +1686,10 @@ func (pc *persistConn) readLoop() { pc.numExpectedResponses-- pc.mu.Unlock() + bodyWritable := resp.bodyIsWritable() hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0 - if resp.Close || rc.req.Close || resp.StatusCode <= 199 { + if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable { // Don't do keep-alive on error if either party requested a close // or we get an unexpected informational (1xx) response. // StatusCode 100 is already handled above. @@ -1704,6 +1710,10 @@ func (pc *persistConn) readLoop() { pc.wroteRequest() && tryPutIdleConn(trace) + if bodyWritable { + closeErr = errCallerOwnsConn + } + select { case rc.ch <- responseAndError{res: resp}: case <-rc.callerGone: @@ -1848,6 +1858,10 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr } break } + if resp.isProtocolSwitch() { + resp.Body = newReadWriteCloserBody(pc.br, pc.conn) + } + resp.TLS = pc.tlsState return } @@ -1874,6 +1888,38 @@ func (pc *persistConn) waitForContinue(continueCh <-chan struct{}) func() bool { } } +func newReadWriteCloserBody(br *bufio.Reader, rwc io.ReadWriteCloser) io.ReadWriteCloser { + body := &readWriteCloserBody{ReadWriteCloser: rwc} + if br.Buffered() != 0 { + body.br = br + } + return body +} + +// readWriteCloserBody is the Response.Body type used when we want to +// give users write access to the Body through the underlying +// connection (TCP, unless using custom dialers). This is then +// the concrete type for a Response.Body on the 101 Switching +// Protocols response, as used by WebSockets, h2c, etc. +type readWriteCloserBody struct { + br *bufio.Reader // used until empty + io.ReadWriteCloser +} + +func (b *readWriteCloserBody) Read(p []byte) (n int, err error) { + if b.br != nil { + if n := b.br.Buffered(); len(p) > n { + p = p[:n] + } + n, err = b.br.Read(p) + if b.br.Buffered() == 0 { + b.br = nil + } + return n, err + } + return b.ReadWriteCloser.Read(p) +} + // nothingWrittenError wraps a write errors which ended up writing zero bytes. type nothingWrittenError struct { error @@ -2193,7 +2239,9 @@ func (pc *persistConn) closeLocked(err error) { // freelist for http2. That's done by the // alternate protocol's RoundTripper. } else { - pc.conn.Close() + if err != errCallerOwnsConn { + pc.conn.Close() + } close(pc.closech) } } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 73e6e30331..327b3b4996 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -4836,3 +4836,54 @@ func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { t.Fatal("timeout") } } + +func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { + setParallel(t) + defer afterTest(t) + done := make(chan struct{}) + defer close(done) + cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + conn, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer conn.Close() + io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n") + bs := bufio.NewScanner(conn) + bs.Scan() + fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text())) + <-done + })) + defer cst.close() + + req, _ := NewRequest("GET", cst.ts.URL, nil) + req.Header.Set("Upgrade", "foo") + req.Header.Set("Connection", "upgrade") + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 101 { + t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header) + } + rwc, ok := res.Body.(io.ReadWriteCloser) + if !ok { + t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body) + } + defer rwc.Close() + bs := bufio.NewScanner(rwc) + if !bs.Scan() { + t.Fatalf("expected readable input") + } + if got, want := bs.Text(), "Some buffered data"; got != want { + t.Errorf("read %q; want %q", got, want) + } + io.WriteString(rwc, "echo\n") + if !bs.Scan() { + t.Fatalf("expected another line") + } + if got, want := bs.Text(), "ECHO"; got != want { + t.Errorf("read %q; want %q", got, want) + } +} -- 2.50.0