go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
- // wait until both copy functions have sent on the error channel
+ // Wait until both copy functions have sent on the error channel,
+ // or until one fails.
err := <-errc
if err == nil {
err = <-errc
}
- if err != nil {
+ if err != nil && err != errCopyDone {
p.getErrorHandler()(rw, req, fmt.Errorf("can't copy: %v", err))
}
}
+var errCopyDone = errors.New("hijacked connection copy complete")
+
// switchProtocolCopier exists so goroutines proxying data back and
// forth have nice names in stacks.
type switchProtocolCopier struct {
return
}
- errc <- nil
+ errc <- errCopyDone
}
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
return
}
- errc <- nil
+ errc <- errCopyDone
}
func cleanQueryParams(s string) string {
}
}
+func TestReverseProxyUpgradeNoCloseWrite(t *testing.T) {
+ // The backend hijacks the connection,
+ // reads all data from the client,
+ // and returns.
+ backendDone := make(chan struct{})
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Connection", "upgrade")
+ w.Header().Set("Upgrade", "u")
+ w.WriteHeader(101)
+ conn, _, err := http.NewResponseController(w).Hijack()
+ if err != nil {
+ t.Errorf("Hijack: %v", err)
+ }
+ io.Copy(io.Discard, conn)
+ close(backendDone)
+ }))
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // The proxy includes a ModifyResponse function which replaces the response body
+ // with its own wrapper, dropping the original body's CloseWrite method.
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ModifyResponse = func(resp *http.Response) error {
+ type readWriteCloserOnly struct {
+ io.ReadWriteCloser
+ }
+ resp.Body = readWriteCloserOnly{resp.Body.(io.ReadWriteCloser)}
+ return nil
+ }
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ // The client sends a request and closes the connection.
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Header.Set("Connection", "upgrade")
+ req.Header.Set("Upgrade", "u")
+ resp, err := frontend.Client().Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+
+ // We expect that the client's closure of the connection is propagated to the backend.
+ <-backendDone
+}
+
func TestUnannouncedTrailer(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)