From a053e79024f56a2a64728b1287509e880fad203e Mon Sep 17 00:00:00 2001 From: Mark Wakefield Date: Thu, 2 Jan 2025 19:18:01 +0000 Subject: [PATCH] net/http: support TCP half-close when HTTP is upgraded in ReverseProxy This CL propagates closing the write stream from either side of the reverse proxy and ensures the proxy waits for both copy-to and the copy-from the backend to complete. The new unit test checks communication through the reverse proxy when the backend or frontend closes either the read or write streams. That closing the write stream is propagated through the proxy from either the backend or the frontend. That closing the read stream is not propagated through the proxy. Fixes #35892 Change-Id: I83ce377df66a0f17b9ba2b53caf9e4991a95f6a0 Reviewed-on: https://go-review.googlesource.com/c/go/+/637939 Reviewed-by: Michael Pratt Reviewed-by: Sean Liao Auto-Submit: Sean Liao LUCI-TryBot-Result: Go LUCI Reviewed-by: Damien Neil Reviewed-by: Matej Kramny --- src/net/http/httputil/reverseproxy.go | 38 +++++- src/net/http/httputil/reverseproxy_test.go | 144 +++++++++++++++++++++ src/net/http/transport.go | 7 + 3 files changed, 184 insertions(+), 5 deletions(-) diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index 15e9684708..bbb7c13d41 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -793,7 +793,15 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R spc := switchProtocolCopier{user: conn, backend: backConn} go spc.copyToBackend(errc) go spc.copyFromBackend(errc) - <-errc + + // wait until both copy functions have sent on the error channel + err := <-errc + if err == nil { + err = <-errc + } + if err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("can't copy: %v", err)) + } } // switchProtocolCopier exists so goroutines proxying data back and @@ -803,13 +811,33 @@ type switchProtocolCopier struct { } func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { - _, err := io.Copy(c.user, c.backend) - errc <- err + if _, err := io.Copy(c.user, c.backend); err != nil { + errc <- err + return + } + + // backend conn has reached EOF so propogate close write to user conn + if wc, ok := c.user.(interface{ CloseWrite() error }); ok { + errc <- wc.CloseWrite() + return + } + + errc <- nil } func (c switchProtocolCopier) copyToBackend(errc chan<- error) { - _, err := io.Copy(c.backend, c.user) - errc <- err + if _, err := io.Copy(c.backend, c.user); err != nil { + errc <- err + return + } + + // user conn has reached EOF so propogate close write to backend conn + if wc, ok := c.backend.(interface{ CloseWrite() error }); ok { + errc <- wc.CloseWrite() + return + } + + errc <- nil } func cleanQueryParams(s string) string { diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index c618f6f19e..f089ce0574 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "log" + "net" "net/http" "net/http/httptest" "net/http/httptrace" @@ -1551,6 +1552,149 @@ func TestReverseProxyWebSocketCancellation(t *testing.T) { } } +func TestReverseProxyWebSocketHalfTCP(t *testing.T) { + // Issue #35892: support TCP half-close when HTTP is upgraded in the ReverseProxy. + // Specifically testing: + // - the communication through the reverse proxy when the client or server closes + // either the read or write streams + // - that closing the write stream is propagated through the proxy and results in reading + // EOF at the other end of the connection + + mustRead := func(t *testing.T, conn *net.TCPConn, msg string) { + b := make([]byte, len(msg)) + if _, err := conn.Read(b); err != nil { + t.Errorf("failed to read: %v", err) + } + + if got, want := string(b), msg; got != want { + t.Errorf("got %#q, want %#q", got, want) + } + } + + mustReadError := func(t *testing.T, conn *net.TCPConn, e error) { + b := make([]byte, 1) + if _, err := conn.Read(b); !errors.Is(err, e) { + t.Errorf("failed to read error: %v", err) + } + } + + mustWrite := func(t *testing.T, conn *net.TCPConn, msg string) { + if _, err := conn.Write([]byte(msg)); err != nil { + t.Errorf("failed to write: %v", err) + } + } + + mustCloseRead := func(t *testing.T, conn *net.TCPConn) { + if err := conn.CloseRead(); err != nil { + t.Errorf("failed to CloseRead: %v", err) + } + } + + mustCloseWrite := func(t *testing.T, conn *net.TCPConn) { + if err := conn.CloseWrite(); err != nil { + t.Errorf("failed to CloseWrite: %v", err) + } + } + + tests := map[string]func(t *testing.T, cli, srv *net.TCPConn){ + "server close read": func(t *testing.T, cli, srv *net.TCPConn) { + mustCloseRead(t, srv) + mustWrite(t, srv, "server sends") + mustRead(t, cli, "server sends") + }, + "server close write": func(t *testing.T, cli, srv *net.TCPConn) { + mustCloseWrite(t, srv) + mustWrite(t, cli, "client sends") + mustRead(t, srv, "client sends") + mustReadError(t, cli, io.EOF) + }, + "client close read": func(t *testing.T, cli, srv *net.TCPConn) { + mustCloseRead(t, cli) + mustWrite(t, cli, "client sends") + mustRead(t, srv, "client sends") + }, + "client close write": func(t *testing.T, cli, srv *net.TCPConn) { + mustCloseWrite(t, cli) + mustWrite(t, srv, "server sends") + mustRead(t, cli, "server sends") + mustReadError(t, srv, io.EOF) + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + var srv *net.TCPConn + + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if g, ws := upgradeType(r.Header), "websocket"; g != ws { + t.Fatalf("Unexpected upgrade type %q, want %q", g, ws) + } + + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + conn.Close() + t.Fatalf("hijack failed: %v", err) + } + + var ok bool + if srv, ok = conn.(*net.TCPConn); !ok { + conn.Close() + t.Fatal("conn is not a TCPConn") + } + + upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" + if _, err := io.WriteString(srv, upgradeMsg); err != nil { + srv.Close() + t.Fatalf("backend upgrade failed: %v", err) + } + })) + defer backendServer.Close() + + backendURL, _ := url.Parse(backendServer.URL) + rproxy := NewSingleHostReverseProxy(backendURL) + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontendProxy := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rproxy.ServeHTTP(rw, req) + })) + defer frontendProxy.Close() + + frontendURL, _ := url.Parse(frontendProxy.URL) + addr, err := net.ResolveTCPAddr("tcp", frontendURL.Host) + if err != nil { + t.Fatalf("failed to resolve TCP address: %v", err) + } + cli, err := net.DialTCP("tcp", nil, addr) + if err != nil { + t.Fatalf("failed to dial TCP address: %v", err) + } + defer cli.Close() + + req, _ := http.NewRequest("GET", frontendProxy.URL, nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + if err := req.Write(cli); err != nil { + t.Fatalf("failed to write request: %v", err) + } + + br := bufio.NewReader(cli) + resp, err := http.ReadResponse(br, &http.Request{Method: "GET"}) + if err != nil { + t.Fatalf("failed to read response: %v", err) + } + if resp.StatusCode != 101 { + t.Fatalf("status code not 101: %v", resp.StatusCode) + } + if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" || + strings.ToLower(resp.Header.Get("Connection")) != "upgrade" { + t.Fatalf("frontend upgrade failed") + } + defer srv.Close() + + test(t, cli, srv) + }) + } +} + func TestUnannouncedTrailer(t *testing.T) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 4a6c928827..59a125cbc7 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -2575,6 +2575,13 @@ func (b *readWriteCloserBody) Read(p []byte) (n int, err error) { return b.ReadWriteCloser.Read(p) } +func (b *readWriteCloserBody) CloseWrite() error { + if cw, ok := b.ReadWriteCloser.(interface{ CloseWrite() error }); ok { + return cw.CloseWrite() + } + return fmt.Errorf("CloseWrite: %w", ErrNotSupported) +} + // nothingWrittenError wraps a write errors which ended up writing zero bytes. type nothingWrittenError struct { error -- 2.50.0