"crypto/tls"
"errors"
"fmt"
+ "golang_org/x/net/http/httpguts"
"io"
"net/textproto"
"net/url"
//
// The Body is automatically dechunked if the server replied
// with a "chunked" Transfer-Encoding.
+ //
+ // As of Go 1.12, the Body will be also implement io.Writer
+ // on a successful "101 Switching Protocols" responses,
+ // as used by WebSockets and HTTP/2's "h2c" mode.
Body io.ReadCloser
// ContentLength records the length of the associated content. The
r.Body.Close()
}
}
+
+// bodyIsWritable reports whether the Body supports writing. The
+// Transport returns Writable bodies for 101 Switching Protocols
+// responses.
+// The Transport uses this method to determine whether a persistent
+// connection is done being managed from its perspective. Once we
+// return a writable response body to a user, the net/http package is
+// done managing that connection.
+func (r *Response) bodyIsWritable() bool {
+ _, ok := r.Body.(io.Writer)
+ return ok
+}
+
+// isProtocolSwitch reports whether r is a response to a successful
+// protocol upgrade.
+func (r *Response) isProtocolSwitch() bool {
+ return r.StatusCode == StatusSwitchingProtocols &&
+ r.Header.Get("Upgrade") != "" &&
+ httpguts.HeaderValuesContainsToken(r.Header["Connection"], "Upgrade")
+}
return err
}
+// errCallerOwnsConn is an internal sentinel error used when we hand
+// off a writable response.Body to the caller. We use this to prevent
+// closing a net.Conn that is now owned by the caller.
+var errCallerOwnsConn = errors.New("read loop ending; caller owns writable underlying conn")
+
func (pc *persistConn) readLoop() {
closeErr := errReadLoopExiting // default value, if not changed below
defer func() {
pc.numExpectedResponses--
pc.mu.Unlock()
+ bodyWritable := resp.bodyIsWritable()
hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0
- if resp.Close || rc.req.Close || resp.StatusCode <= 199 {
+ if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable {
// Don't do keep-alive on error if either party requested a close
// or we get an unexpected informational (1xx) response.
// StatusCode 100 is already handled above.
pc.wroteRequest() &&
tryPutIdleConn(trace)
+ if bodyWritable {
+ closeErr = errCallerOwnsConn
+ }
+
select {
case rc.ch <- responseAndError{res: resp}:
case <-rc.callerGone:
}
break
}
+ if resp.isProtocolSwitch() {
+ resp.Body = newReadWriteCloserBody(pc.br, pc.conn)
+ }
+
resp.TLS = pc.tlsState
return
}
}
}
+func newReadWriteCloserBody(br *bufio.Reader, rwc io.ReadWriteCloser) io.ReadWriteCloser {
+ body := &readWriteCloserBody{ReadWriteCloser: rwc}
+ if br.Buffered() != 0 {
+ body.br = br
+ }
+ return body
+}
+
+// readWriteCloserBody is the Response.Body type used when we want to
+// give users write access to the Body through the underlying
+// connection (TCP, unless using custom dialers). This is then
+// the concrete type for a Response.Body on the 101 Switching
+// Protocols response, as used by WebSockets, h2c, etc.
+type readWriteCloserBody struct {
+ br *bufio.Reader // used until empty
+ io.ReadWriteCloser
+}
+
+func (b *readWriteCloserBody) Read(p []byte) (n int, err error) {
+ if b.br != nil {
+ if n := b.br.Buffered(); len(p) > n {
+ p = p[:n]
+ }
+ n, err = b.br.Read(p)
+ if b.br.Buffered() == 0 {
+ b.br = nil
+ }
+ return n, err
+ }
+ return b.ReadWriteCloser.Read(p)
+}
+
// nothingWrittenError wraps a write errors which ended up writing zero bytes.
type nothingWrittenError struct {
error
// freelist for http2. That's done by the
// alternate protocol's RoundTripper.
} else {
- pc.conn.Close()
+ if err != errCallerOwnsConn {
+ pc.conn.Close()
+ }
close(pc.closech)
}
}
t.Fatal("timeout")
}
}
+
+func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ done := make(chan struct{})
+ defer close(done)
+ cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer conn.Close()
+ io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
+ bs := bufio.NewScanner(conn)
+ bs.Scan()
+ fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
+ <-done
+ }))
+ defer cst.close()
+
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req.Header.Set("Upgrade", "foo")
+ req.Header.Set("Connection", "upgrade")
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 101 {
+ t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
+ }
+ rwc, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
+ }
+ defer rwc.Close()
+ bs := bufio.NewScanner(rwc)
+ if !bs.Scan() {
+ t.Fatalf("expected readable input")
+ }
+ if got, want := bs.Text(), "Some buffered data"; got != want {
+ t.Errorf("read %q; want %q", got, want)
+ }
+ io.WriteString(rwc, "echo\n")
+ if !bs.Scan() {
+ t.Fatalf("expected another line")
+ }
+ if got, want := bs.Text(), "ECHO"; got != want {
+ t.Errorf("read %q; want %q", got, want)
+ }
+}