]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: support TCP half-close when HTTP is upgraded in ReverseProxy
authorMark Wakefield <mark.a.wakefield@gmail.com>
Thu, 2 Jan 2025 19:18:01 +0000 (19:18 +0000)
committerGopher Robot <gobot@golang.org>
Tue, 4 Mar 2025 12:59:32 +0000 (04:59 -0800)
This CL propagates closing the write stream from either side of the
reverse proxy and ensures the proxy waits for both copy-to and the
copy-from the backend to complete.

The new unit test checks communication through the reverse proxy when
the backend or frontend closes either the read or write streams.
That closing the write stream is propagated through the proxy from
either the backend or the frontend. That closing the read stream is
not propagated through the proxy.

Fixes #35892

Change-Id: I83ce377df66a0f17b9ba2b53caf9e4991a95f6a0
Reviewed-on: https://go-review.googlesource.com/c/go/+/637939
Reviewed-by: Michael Pratt <mpratt@google.com>
Reviewed-by: Sean Liao <sean@liao.dev>
Auto-Submit: Sean Liao <sean@liao.dev>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Matej Kramny <matejkramny@gmail.com>
src/net/http/httputil/reverseproxy.go
src/net/http/httputil/reverseproxy_test.go
src/net/http/transport.go

index 15e968470815d475f309675e1f3c6af0c8e9f463..bbb7c13d415c3ce0c0398a108ce8d503fea6f199 100644 (file)
@@ -793,7 +793,15 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
        spc := switchProtocolCopier{user: conn, backend: backConn}
        go spc.copyToBackend(errc)
        go spc.copyFromBackend(errc)
-       <-errc
+
+       // wait until both copy functions have sent on the error channel
+       err := <-errc
+       if err == nil {
+               err = <-errc
+       }
+       if err != nil {
+               p.getErrorHandler()(rw, req, fmt.Errorf("can't copy: %v", err))
+       }
 }
 
 // switchProtocolCopier exists so goroutines proxying data back and
@@ -803,13 +811,33 @@ type switchProtocolCopier struct {
 }
 
 func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
-       _, err := io.Copy(c.user, c.backend)
-       errc <- err
+       if _, err := io.Copy(c.user, c.backend); err != nil {
+               errc <- err
+               return
+       }
+
+       // backend conn has reached EOF so propogate close write to user conn
+       if wc, ok := c.user.(interface{ CloseWrite() error }); ok {
+               errc <- wc.CloseWrite()
+               return
+       }
+
+       errc <- nil
 }
 
 func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
-       _, err := io.Copy(c.backend, c.user)
-       errc <- err
+       if _, err := io.Copy(c.backend, c.user); err != nil {
+               errc <- err
+               return
+       }
+
+       // user conn has reached EOF so propogate close write to backend conn
+       if wc, ok := c.backend.(interface{ CloseWrite() error }); ok {
+               errc <- wc.CloseWrite()
+               return
+       }
+
+       errc <- nil
 }
 
 func cleanQueryParams(s string) string {
index c618f6f19e37065c30e9d079b1ecaf8a0218fa12..f089ce0574be7156f26018abca9c753461f776e9 100644 (file)
@@ -14,6 +14,7 @@ import (
        "fmt"
        "io"
        "log"
+       "net"
        "net/http"
        "net/http/httptest"
        "net/http/httptrace"
@@ -1551,6 +1552,149 @@ func TestReverseProxyWebSocketCancellation(t *testing.T) {
        }
 }
 
+func TestReverseProxyWebSocketHalfTCP(t *testing.T) {
+       // Issue #35892: support TCP half-close when HTTP is upgraded in the ReverseProxy.
+       // Specifically testing:
+       // - the communication through the reverse proxy when the client or server closes
+       //   either the read or write streams
+       // - that closing the write stream is propagated through the proxy and results in reading
+       //   EOF at the other end of the connection
+
+       mustRead := func(t *testing.T, conn *net.TCPConn, msg string) {
+               b := make([]byte, len(msg))
+               if _, err := conn.Read(b); err != nil {
+                       t.Errorf("failed to read: %v", err)
+               }
+
+               if got, want := string(b), msg; got != want {
+                       t.Errorf("got %#q, want %#q", got, want)
+               }
+       }
+
+       mustReadError := func(t *testing.T, conn *net.TCPConn, e error) {
+               b := make([]byte, 1)
+               if _, err := conn.Read(b); !errors.Is(err, e) {
+                       t.Errorf("failed to read error: %v", err)
+               }
+       }
+
+       mustWrite := func(t *testing.T, conn *net.TCPConn, msg string) {
+               if _, err := conn.Write([]byte(msg)); err != nil {
+                       t.Errorf("failed to write: %v", err)
+               }
+       }
+
+       mustCloseRead := func(t *testing.T, conn *net.TCPConn) {
+               if err := conn.CloseRead(); err != nil {
+                       t.Errorf("failed to CloseRead: %v", err)
+               }
+       }
+
+       mustCloseWrite := func(t *testing.T, conn *net.TCPConn) {
+               if err := conn.CloseWrite(); err != nil {
+                       t.Errorf("failed to CloseWrite: %v", err)
+               }
+       }
+
+       tests := map[string]func(t *testing.T, cli, srv *net.TCPConn){
+               "server close read": func(t *testing.T, cli, srv *net.TCPConn) {
+                       mustCloseRead(t, srv)
+                       mustWrite(t, srv, "server sends")
+                       mustRead(t, cli, "server sends")
+               },
+               "server close write": func(t *testing.T, cli, srv *net.TCPConn) {
+                       mustCloseWrite(t, srv)
+                       mustWrite(t, cli, "client sends")
+                       mustRead(t, srv, "client sends")
+                       mustReadError(t, cli, io.EOF)
+               },
+               "client close read": func(t *testing.T, cli, srv *net.TCPConn) {
+                       mustCloseRead(t, cli)
+                       mustWrite(t, cli, "client sends")
+                       mustRead(t, srv, "client sends")
+               },
+               "client close write": func(t *testing.T, cli, srv *net.TCPConn) {
+                       mustCloseWrite(t, cli)
+                       mustWrite(t, srv, "server sends")
+                       mustRead(t, cli, "server sends")
+                       mustReadError(t, srv, io.EOF)
+               },
+       }
+
+       for name, test := range tests {
+               t.Run(name, func(t *testing.T) {
+                       var srv *net.TCPConn
+
+                       backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+                               if g, ws := upgradeType(r.Header), "websocket"; g != ws {
+                                       t.Fatalf("Unexpected upgrade type %q, want %q", g, ws)
+                               }
+
+                               conn, _, err := w.(http.Hijacker).Hijack()
+                               if err != nil {
+                                       conn.Close()
+                                       t.Fatalf("hijack failed: %v", err)
+                               }
+
+                               var ok bool
+                               if srv, ok = conn.(*net.TCPConn); !ok {
+                                       conn.Close()
+                                       t.Fatal("conn is not a TCPConn")
+                               }
+
+                               upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
+                               if _, err := io.WriteString(srv, upgradeMsg); err != nil {
+                                       srv.Close()
+                                       t.Fatalf("backend upgrade failed: %v", err)
+                               }
+                       }))
+                       defer backendServer.Close()
+
+                       backendURL, _ := url.Parse(backendServer.URL)
+                       rproxy := NewSingleHostReverseProxy(backendURL)
+                       rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+                       frontendProxy := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+                               rproxy.ServeHTTP(rw, req)
+                       }))
+                       defer frontendProxy.Close()
+
+                       frontendURL, _ := url.Parse(frontendProxy.URL)
+                       addr, err := net.ResolveTCPAddr("tcp", frontendURL.Host)
+                       if err != nil {
+                               t.Fatalf("failed to resolve TCP address: %v", err)
+                       }
+                       cli, err := net.DialTCP("tcp", nil, addr)
+                       if err != nil {
+                               t.Fatalf("failed to dial TCP address: %v", err)
+                       }
+                       defer cli.Close()
+
+                       req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
+                       req.Header.Set("Connection", "Upgrade")
+                       req.Header.Set("Upgrade", "websocket")
+                       if err := req.Write(cli); err != nil {
+                               t.Fatalf("failed to write request: %v", err)
+                       }
+
+                       br := bufio.NewReader(cli)
+                       resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
+                       if err != nil {
+                               t.Fatalf("failed to read response: %v", err)
+                       }
+                       if resp.StatusCode != 101 {
+                               t.Fatalf("status code not 101: %v", resp.StatusCode)
+                       }
+                       if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
+                               strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
+                               t.Fatalf("frontend upgrade failed")
+                       }
+                       defer srv.Close()
+
+                       test(t, cli, srv)
+               })
+       }
+}
+
 func TestUnannouncedTrailer(t *testing.T) {
        backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
                w.WriteHeader(http.StatusOK)
index 4a6c9288270b6c8021d0fc29d454b11e9abab5c9..59a125cbc78a54aadfecb6358691659dbd8096d9 100644 (file)
@@ -2575,6 +2575,13 @@ func (b *readWriteCloserBody) Read(p []byte) (n int, err error) {
        return b.ReadWriteCloser.Read(p)
 }
 
+func (b *readWriteCloserBody) CloseWrite() error {
+       if cw, ok := b.ReadWriteCloser.(interface{ CloseWrite() error }); ok {
+               return cw.CloseWrite()
+       }
+       return fmt.Errorf("CloseWrite: %w", ErrNotSupported)
+}
+
 // nothingWrittenError wraps a write errors which ended up writing zero bytes.
 type nothingWrittenError struct {
        error