]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httputil: don't call WriteHeader after Hijack
authorDamien Neil <dneil@google.com>
Wed, 19 Mar 2025 16:26:31 +0000 (09:26 -0700)
committerGopher Robot <gobot@golang.org>
Wed, 19 Mar 2025 16:57:03 +0000 (09:57 -0700)
CL 637939 changed ReverseProxy to report errors encountered when
copying data on an hijacked connection. This is generally not useful,
and when using the default error handler results in WriteHeader
being called on a hijacked connection.

While this is harmless with standard net/http ResponseWriter
implementations, it can confuse middleware layers.

Fixes #72954

Change-Id: I21f3d3d515e114dc5c298d7dbc3796c505d3c82f
Reviewed-on: https://go-review.googlesource.com/c/go/+/659255
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Auto-Submit: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/net/http/httputil/reverseproxy.go
src/net/http/httputil/reverseproxy_test.go

index 5d2788073596f0fa0a9bb5493b023668e960600e..8d3e20c302457c6b8377ce5514e2197c24692b59 100644 (file)
@@ -802,9 +802,6 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
        if err == nil {
                err = <-errc
        }
-       if err != nil && err != errCopyDone {
-               p.getErrorHandler()(rw, req, fmt.Errorf("can't copy: %v", err))
-       }
 }
 
 var errCopyDone = errors.New("hijacked connection copy complete")
index 1acbc296c3692c90ef6cba06ad556cf499eaacdc..62c93fb261a6e5b5782501517fe08e4eba9b59b0 100644 (file)
@@ -2104,10 +2104,56 @@ func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool,
        }
 }
 
+// Issue #72954: We should not call WriteHeader on a ResponseWriter after hijacking
+// the connection.
+func TestReverseProxyHijackCopyError(t *testing.T) {
+       backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               w.Header().Set("Upgrade", "someproto")
+               w.WriteHeader(http.StatusSwitchingProtocols)
+       }))
+       defer backend.Close()
+       backendURL, err := url.Parse(backend.URL)
+       if err != nil {
+               t.Fatal(err)
+       }
+       proxyHandler := &ReverseProxy{
+               Rewrite: func(r *ProxyRequest) {
+                       r.SetURL(backendURL)
+               },
+               ModifyResponse: func(resp *http.Response) error {
+                       resp.Body = &testReadWriteCloser{
+                               read: func([]byte) (int, error) {
+                                       return 0, errors.New("read error")
+                               },
+                       }
+                       return nil
+               },
+       }
+
+       hijacked := false
+       rw := &testResponseWriter{
+               writeHeader: func(statusCode int) {
+                       if hijacked {
+                               t.Errorf("WriteHeader(%v) called after Hijack", statusCode)
+                       }
+               },
+               hijack: func() (net.Conn, *bufio.ReadWriter, error) {
+                       hijacked = true
+                       cli, srv := net.Pipe()
+                       go io.Copy(io.Discard, cli)
+                       return srv, bufio.NewReadWriter(bufio.NewReader(srv), bufio.NewWriter(srv)), nil
+               },
+       }
+       req, _ := http.NewRequest("GET", "http://example.tld/", nil)
+       req.Header.Set("Upgrade", "someproto")
+       proxyHandler.ServeHTTP(rw, req)
+}
+
 type testResponseWriter struct {
        h           http.Header
        writeHeader func(int)
        write       func([]byte) (int, error)
+       hijack      func() (net.Conn, *bufio.ReadWriter, error)
 }
 
 func (rw *testResponseWriter) Header() http.Header {
@@ -2129,3 +2175,37 @@ func (rw *testResponseWriter) Write(p []byte) (int, error) {
        }
        return len(p), nil
 }
+
+func (rw *testResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+       if rw.hijack != nil {
+               return rw.hijack()
+       }
+       return nil, nil, errors.ErrUnsupported
+}
+
+type testReadWriteCloser struct {
+       read  func([]byte) (int, error)
+       write func([]byte) (int, error)
+       close func() error
+}
+
+func (rc *testReadWriteCloser) Read(p []byte) (int, error) {
+       if rc.read != nil {
+               return rc.read(p)
+       }
+       return 0, io.EOF
+}
+
+func (rc *testReadWriteCloser) Write(p []byte) (int, error) {
+       if rc.write != nil {
+               return rc.write(p)
+       }
+       return len(p), nil
+}
+
+func (rc *testReadWriteCloser) Close() error {
+       if rc.close != nil {
+               return rc.close()
+       }
+       return nil
+}