spc := switchProtocolCopier{user: conn, backend: backConn}
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
- <-errc
+
+ // wait until both copy functions have sent on the error channel
+ err := <-errc
+ if err == nil {
+ err = <-errc
+ }
+ if err != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("can't copy: %v", err))
+ }
}
// switchProtocolCopier exists so goroutines proxying data back and
}
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
- _, err := io.Copy(c.user, c.backend)
- errc <- err
+ if _, err := io.Copy(c.user, c.backend); err != nil {
+ errc <- err
+ return
+ }
+
+ // backend conn has reached EOF so propogate close write to user conn
+ if wc, ok := c.user.(interface{ CloseWrite() error }); ok {
+ errc <- wc.CloseWrite()
+ return
+ }
+
+ errc <- nil
}
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
- _, err := io.Copy(c.backend, c.user)
- errc <- err
+ if _, err := io.Copy(c.backend, c.user); err != nil {
+ errc <- err
+ return
+ }
+
+ // user conn has reached EOF so propogate close write to backend conn
+ if wc, ok := c.backend.(interface{ CloseWrite() error }); ok {
+ errc <- wc.CloseWrite()
+ return
+ }
+
+ errc <- nil
}
func cleanQueryParams(s string) string {
"fmt"
"io"
"log"
+ "net"
"net/http"
"net/http/httptest"
"net/http/httptrace"
}
}
+func TestReverseProxyWebSocketHalfTCP(t *testing.T) {
+ // Issue #35892: support TCP half-close when HTTP is upgraded in the ReverseProxy.
+ // Specifically testing:
+ // - the communication through the reverse proxy when the client or server closes
+ // either the read or write streams
+ // - that closing the write stream is propagated through the proxy and results in reading
+ // EOF at the other end of the connection
+
+ mustRead := func(t *testing.T, conn *net.TCPConn, msg string) {
+ b := make([]byte, len(msg))
+ if _, err := conn.Read(b); err != nil {
+ t.Errorf("failed to read: %v", err)
+ }
+
+ if got, want := string(b), msg; got != want {
+ t.Errorf("got %#q, want %#q", got, want)
+ }
+ }
+
+ mustReadError := func(t *testing.T, conn *net.TCPConn, e error) {
+ b := make([]byte, 1)
+ if _, err := conn.Read(b); !errors.Is(err, e) {
+ t.Errorf("failed to read error: %v", err)
+ }
+ }
+
+ mustWrite := func(t *testing.T, conn *net.TCPConn, msg string) {
+ if _, err := conn.Write([]byte(msg)); err != nil {
+ t.Errorf("failed to write: %v", err)
+ }
+ }
+
+ mustCloseRead := func(t *testing.T, conn *net.TCPConn) {
+ if err := conn.CloseRead(); err != nil {
+ t.Errorf("failed to CloseRead: %v", err)
+ }
+ }
+
+ mustCloseWrite := func(t *testing.T, conn *net.TCPConn) {
+ if err := conn.CloseWrite(); err != nil {
+ t.Errorf("failed to CloseWrite: %v", err)
+ }
+ }
+
+ tests := map[string]func(t *testing.T, cli, srv *net.TCPConn){
+ "server close read": func(t *testing.T, cli, srv *net.TCPConn) {
+ mustCloseRead(t, srv)
+ mustWrite(t, srv, "server sends")
+ mustRead(t, cli, "server sends")
+ },
+ "server close write": func(t *testing.T, cli, srv *net.TCPConn) {
+ mustCloseWrite(t, srv)
+ mustWrite(t, cli, "client sends")
+ mustRead(t, srv, "client sends")
+ mustReadError(t, cli, io.EOF)
+ },
+ "client close read": func(t *testing.T, cli, srv *net.TCPConn) {
+ mustCloseRead(t, cli)
+ mustWrite(t, cli, "client sends")
+ mustRead(t, srv, "client sends")
+ },
+ "client close write": func(t *testing.T, cli, srv *net.TCPConn) {
+ mustCloseWrite(t, cli)
+ mustWrite(t, srv, "server sends")
+ mustRead(t, cli, "server sends")
+ mustReadError(t, srv, io.EOF)
+ },
+ }
+
+ for name, test := range tests {
+ t.Run(name, func(t *testing.T) {
+ var srv *net.TCPConn
+
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if g, ws := upgradeType(r.Header), "websocket"; g != ws {
+ t.Fatalf("Unexpected upgrade type %q, want %q", g, ws)
+ }
+
+ conn, _, err := w.(http.Hijacker).Hijack()
+ if err != nil {
+ conn.Close()
+ t.Fatalf("hijack failed: %v", err)
+ }
+
+ var ok bool
+ if srv, ok = conn.(*net.TCPConn); !ok {
+ conn.Close()
+ t.Fatal("conn is not a TCPConn")
+ }
+
+ upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
+ if _, err := io.WriteString(srv, upgradeMsg); err != nil {
+ srv.Close()
+ t.Fatalf("backend upgrade failed: %v", err)
+ }
+ }))
+ defer backendServer.Close()
+
+ backendURL, _ := url.Parse(backendServer.URL)
+ rproxy := NewSingleHostReverseProxy(backendURL)
+ rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ frontendProxy := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ rproxy.ServeHTTP(rw, req)
+ }))
+ defer frontendProxy.Close()
+
+ frontendURL, _ := url.Parse(frontendProxy.URL)
+ addr, err := net.ResolveTCPAddr("tcp", frontendURL.Host)
+ if err != nil {
+ t.Fatalf("failed to resolve TCP address: %v", err)
+ }
+ cli, err := net.DialTCP("tcp", nil, addr)
+ if err != nil {
+ t.Fatalf("failed to dial TCP address: %v", err)
+ }
+ defer cli.Close()
+
+ req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
+ req.Header.Set("Connection", "Upgrade")
+ req.Header.Set("Upgrade", "websocket")
+ if err := req.Write(cli); err != nil {
+ t.Fatalf("failed to write request: %v", err)
+ }
+
+ br := bufio.NewReader(cli)
+ resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
+ if err != nil {
+ t.Fatalf("failed to read response: %v", err)
+ }
+ if resp.StatusCode != 101 {
+ t.Fatalf("status code not 101: %v", resp.StatusCode)
+ }
+ if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
+ strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
+ t.Fatalf("frontend upgrade failed")
+ }
+ defer srv.Close()
+
+ test(t, cli, srv)
+ })
+ }
+}
+
func TestUnannouncedTrailer(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)