]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httputil: make Switching Protocol requests (e.g. Websockets) cancelable
authorPierre Carru <pierre.carru@eshard.com>
Sun, 26 Apr 2020 09:11:35 +0000 (09:11 +0000)
committerEmmanuel Odeke <emm.odeke@gmail.com>
Sun, 26 Apr 2020 09:26:10 +0000 (09:26 +0000)
Ensures that a canceled client request for Switching Protocols
(e.g. h2c, Websockets) will cause the underlying connection to
be terminated.

Adds a goroutine in handleUpgradeResponse in order to select on
the incoming client request's context and appropriately cancel it.

Fixes #35559

Change-Id: I1238e18fd4cce457f034f78d9cdce0e7f93b8bf6
GitHub-Last-Rev: 3629c78493f667703ea99f9f4db5e63ddaaa0e6b
GitHub-Pull-Request: golang/go#38021
Reviewed-on: https://go-review.googlesource.com/c/go/+/224897
Run-TryBot: Emmanuel Odeke <emm.odeke@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
src/net/http/httputil/reverseproxy.go
src/net/http/httputil/reverseproxy_test.go

index 4d6a085f60ae3f38e0f506a7da38117e57c523e3..eb17bef979812e4dfae4729e214cd819c1c00ada 100644 (file)
@@ -526,7 +526,20 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
                p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
                return
        }
-       defer backConn.Close()
+
+       backConnCloseCh := make(chan bool)
+       go func() {
+               // Ensure that the cancelation of a request closes the backend.
+               // See issue https://golang.org/issue/35559.
+               select {
+               case <-req.Context().Done():
+               case <-backConnCloseCh:
+               }
+               backConn.Close()
+       }()
+
+       defer close(backConnCloseCh)
+
        conn, brw, err := hj.Hijack()
        if err != nil {
                p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
index 08cccb7d92fd5e81c3dd31437a2cbc9e51193ebb..6fb9ba60a9c221219300eec41e0ce22cb2bd0b07 100644 (file)
@@ -1158,6 +1158,128 @@ func TestReverseProxyWebSocket(t *testing.T) {
        }
 }
 
+func TestReverseProxyWebSocketCancelation(t *testing.T) {
+       n := 5
+       triggerCancelCh := make(chan bool, n)
+       nthResponse := func(i int) string {
+               return fmt.Sprintf("backend response #%d\n", i)
+       }
+       terminalMsg := "final message"
+
+       cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               if g, ws := upgradeType(r.Header), "websocket"; g != ws {
+                       t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
+                       http.Error(w, "Unexpected request", 400)
+                       return
+               }
+               conn, bufrw, err := w.(http.Hijacker).Hijack()
+               if err != nil {
+                       t.Error(err)
+                       return
+               }
+               defer conn.Close()
+
+               upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
+               if _, err := io.WriteString(conn, upgradeMsg); err != nil {
+                       t.Error(err)
+                       return
+               }
+               if _, _, err := bufrw.ReadLine(); err != nil {
+                       t.Errorf("Failed to read line from client: %v", err)
+                       return
+               }
+
+               for i := 0; i < n; i++ {
+                       if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
+                               t.Errorf("Writing response #%d failed: %v", i, err)
+                       }
+                       bufrw.Flush()
+                       time.Sleep(time.Second)
+               }
+               if _, err := bufrw.WriteString(terminalMsg); err != nil {
+                       t.Errorf("Failed to write terminal message: %v", err)
+               }
+               bufrw.Flush()
+       }))
+       defer cst.Close()
+
+       backendURL, _ := url.Parse(cst.URL)
+       rproxy := NewSingleHostReverseProxy(backendURL)
+       rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
+       rproxy.ModifyResponse = func(res *http.Response) error {
+               res.Header.Add("X-Modified", "true")
+               return nil
+       }
+
+       handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+               rw.Header().Set("X-Header", "X-Value")
+               ctx, cancel := context.WithCancel(req.Context())
+               go func() {
+                       <-triggerCancelCh
+                       cancel()
+               }()
+               rproxy.ServeHTTP(rw, req.WithContext(ctx))
+       })
+
+       frontendProxy := httptest.NewServer(handler)
+       defer frontendProxy.Close()
+
+       req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
+       req.Header.Set("Connection", "Upgrade")
+       req.Header.Set("Upgrade", "websocket")
+
+       res, err := frontendProxy.Client().Do(req)
+       if err != nil {
+               t.Fatalf("Dialing to frontend proxy: %v", err)
+       }
+       defer res.Body.Close()
+       if g, w := res.StatusCode, 101; g != w {
+               t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
+       }
+
+       if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
+               t.Errorf("X-Header mismatch\n\tgot:  %q\n\twant: %q", g, w)
+       }
+
+       if g, w := upgradeType(res.Header), "websocket"; g != w {
+               t.Fatalf("Upgrade header mismatch\n\tgot:  %q\n\twant: %q", g, w)
+       }
+
+       rwc, ok := res.Body.(io.ReadWriteCloser)
+       if !ok {
+               t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
+       }
+
+       if got, want := res.Header.Get("X-Modified"), "true"; got != want {
+               t.Errorf("response X-Modified header = %q; want %q", got, want)
+       }
+
+       if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
+               t.Fatalf("Failed to write first message: %v", err)
+       }
+
+       // Read loop.
+
+       br := bufio.NewReader(rwc)
+       for {
+               line, err := br.ReadString('\n')
+               switch {
+               case line == terminalMsg: // this case before "err == io.EOF"
+                       t.Fatalf("The websocket request was not canceled, unfortunately!")
+
+               case err == io.EOF:
+                       return
+
+               case err != nil:
+                       t.Fatalf("Unexpected error: %v", err)
+
+               case line == nthResponse(0): // We've gotten the first response back
+                       // Let's trigger a cancel.
+                       close(triggerCancelCh)
+               }
+       }
+}
+
 func TestUnannouncedTrailer(t *testing.T) {
        backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
                w.WriteHeader(http.StatusOK)