]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: make Transport return Writable Response.Body on protocol switch
authorBrad Fitzpatrick <bradfitz@golang.org>
Fri, 24 Aug 2018 19:08:28 +0000 (19:08 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Sat, 25 Aug 2018 22:46:42 +0000 (22:46 +0000)
Updates #26937
Updates #17227

Change-Id: I79865938b05c219e1947822e60e4f52bb2604b70
Reviewed-on: https://go-review.googlesource.com/131279
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/http/response.go
src/net/http/transport.go
src/net/http/transport_test.go

index bf1e13c8ae2f17075a85356a713828ac2afc6779..b3ca56c4190dbe744d2fee7a15c612fc256ba651 100644 (file)
@@ -12,6 +12,7 @@ import (
        "crypto/tls"
        "errors"
        "fmt"
+       "golang_org/x/net/http/httpguts"
        "io"
        "net/textproto"
        "net/url"
@@ -63,6 +64,10 @@ type Response struct {
        //
        // The Body is automatically dechunked if the server replied
        // with a "chunked" Transfer-Encoding.
+       //
+       // As of Go 1.12, the Body will be also implement io.Writer
+       // on a successful "101 Switching Protocols" responses,
+       // as used by WebSockets and HTTP/2's "h2c" mode.
        Body io.ReadCloser
 
        // ContentLength records the length of the associated content. The
@@ -333,3 +338,23 @@ func (r *Response) closeBody() {
                r.Body.Close()
        }
 }
+
+// bodyIsWritable reports whether the Body supports writing. The
+// Transport returns Writable bodies for 101 Switching Protocols
+// responses.
+// The Transport uses this method to determine whether a persistent
+// connection is done being managed from its perspective. Once we
+// return a writable response body to a user, the net/http package is
+// done managing that connection.
+func (r *Response) bodyIsWritable() bool {
+       _, ok := r.Body.(io.Writer)
+       return ok
+}
+
+// isProtocolSwitch reports whether r is a response to a successful
+// protocol upgrade.
+func (r *Response) isProtocolSwitch() bool {
+       return r.StatusCode == StatusSwitchingProtocols &&
+               r.Header.Get("Upgrade") != "" &&
+               httpguts.HeaderValuesContainsToken(r.Header["Connection"], "Upgrade")
+}
index 40947baf87a42290fc722b510d1e2e829c97cb5a..ffe4cdc0d6dba647a25f487286a57836fa010785 100644 (file)
@@ -1607,6 +1607,11 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte
        return err
 }
 
+// errCallerOwnsConn is an internal sentinel error used when we hand
+// off a writable response.Body to the caller. We use this to prevent
+// closing a net.Conn that is now owned by the caller.
+var errCallerOwnsConn = errors.New("read loop ending; caller owns writable underlying conn")
+
 func (pc *persistConn) readLoop() {
        closeErr := errReadLoopExiting // default value, if not changed below
        defer func() {
@@ -1681,9 +1686,10 @@ func (pc *persistConn) readLoop() {
                pc.numExpectedResponses--
                pc.mu.Unlock()
 
+               bodyWritable := resp.bodyIsWritable()
                hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0
 
-               if resp.Close || rc.req.Close || resp.StatusCode <= 199 {
+               if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable {
                        // Don't do keep-alive on error if either party requested a close
                        // or we get an unexpected informational (1xx) response.
                        // StatusCode 100 is already handled above.
@@ -1704,6 +1710,10 @@ func (pc *persistConn) readLoop() {
                                pc.wroteRequest() &&
                                tryPutIdleConn(trace)
 
+                       if bodyWritable {
+                               closeErr = errCallerOwnsConn
+                       }
+
                        select {
                        case rc.ch <- responseAndError{res: resp}:
                        case <-rc.callerGone:
@@ -1848,6 +1858,10 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr
                }
                break
        }
+       if resp.isProtocolSwitch() {
+               resp.Body = newReadWriteCloserBody(pc.br, pc.conn)
+       }
+
        resp.TLS = pc.tlsState
        return
 }
@@ -1874,6 +1888,38 @@ func (pc *persistConn) waitForContinue(continueCh <-chan struct{}) func() bool {
        }
 }
 
+func newReadWriteCloserBody(br *bufio.Reader, rwc io.ReadWriteCloser) io.ReadWriteCloser {
+       body := &readWriteCloserBody{ReadWriteCloser: rwc}
+       if br.Buffered() != 0 {
+               body.br = br
+       }
+       return body
+}
+
+// readWriteCloserBody is the Response.Body type used when we want to
+// give users write access to the Body through the underlying
+// connection (TCP, unless using custom dialers). This is then
+// the concrete type for a Response.Body on the 101 Switching
+// Protocols response, as used by WebSockets, h2c, etc.
+type readWriteCloserBody struct {
+       br *bufio.Reader // used until empty
+       io.ReadWriteCloser
+}
+
+func (b *readWriteCloserBody) Read(p []byte) (n int, err error) {
+       if b.br != nil {
+               if n := b.br.Buffered(); len(p) > n {
+                       p = p[:n]
+               }
+               n, err = b.br.Read(p)
+               if b.br.Buffered() == 0 {
+                       b.br = nil
+               }
+               return n, err
+       }
+       return b.ReadWriteCloser.Read(p)
+}
+
 // nothingWrittenError wraps a write errors which ended up writing zero bytes.
 type nothingWrittenError struct {
        error
@@ -2193,7 +2239,9 @@ func (pc *persistConn) closeLocked(err error) {
                        // freelist for http2. That's done by the
                        // alternate protocol's RoundTripper.
                } else {
-                       pc.conn.Close()
+                       if err != errCallerOwnsConn {
+                               pc.conn.Close()
+                       }
                        close(pc.closech)
                }
        }
index 73e6e3033177f2c36fd449489c321c40c5aafc66..327b3b4996c902a70b636b9251cd0e951ae39ba2 100644 (file)
@@ -4836,3 +4836,54 @@ func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
                t.Fatal("timeout")
        }
 }
+
+func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
+       setParallel(t)
+       defer afterTest(t)
+       done := make(chan struct{})
+       defer close(done)
+       cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               conn, _, err := w.(Hijacker).Hijack()
+               if err != nil {
+                       t.Error(err)
+                       return
+               }
+               defer conn.Close()
+               io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
+               bs := bufio.NewScanner(conn)
+               bs.Scan()
+               fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
+               <-done
+       }))
+       defer cst.close()
+
+       req, _ := NewRequest("GET", cst.ts.URL, nil)
+       req.Header.Set("Upgrade", "foo")
+       req.Header.Set("Connection", "upgrade")
+       res, err := cst.c.Do(req)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if res.StatusCode != 101 {
+               t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
+       }
+       rwc, ok := res.Body.(io.ReadWriteCloser)
+       if !ok {
+               t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
+       }
+       defer rwc.Close()
+       bs := bufio.NewScanner(rwc)
+       if !bs.Scan() {
+               t.Fatalf("expected readable input")
+       }
+       if got, want := bs.Text(), "Some buffered data"; got != want {
+               t.Errorf("read %q; want %q", got, want)
+       }
+       io.WriteString(rwc, "echo\n")
+       if !bs.Scan() {
+               t.Fatalf("expected another line")
+       }
+       if got, want := bs.Text(), "ECHO"; got != want {
+               t.Errorf("read %q; want %q", got, want)
+       }
+}