From: Brad Fitzpatrick Date: Mon, 29 Oct 2018 23:21:40 +0000 (+0000) Subject: net/http/httputil: make ReverseProxy automatically proxy WebSocket requests X-Git-Tag: go1.12beta1~378 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=ee55f0856a3f1fed5d8c15af54c40e4799c2d32f;p=gostls13.git net/http/httputil: make ReverseProxy automatically proxy WebSocket requests Fixes #26937 Change-Id: I6cdc1bad4cf476cd2ea1462b53444eccd8841e14 Reviewed-on: https://go-review.googlesource.com/c/146437 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Dmitri Shuralyov --- diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go index d632954d0c..0ecf38c567 100644 --- a/src/go/build/deps_test.go +++ b/src/go/build/deps_test.go @@ -436,7 +436,7 @@ var pkgDeps = map[string][]string{ "L4", "NET", "OS", "crypto/tls", "flag", "net/http", "net/http/internal", "crypto/x509", "golang_org/x/net/http/httpguts", }, - "net/http/httputil": {"L4", "NET", "OS", "context", "net/http", "net/http/internal"}, + "net/http/httputil": {"L4", "NET", "OS", "context", "net/http", "net/http/internal", "golang_org/x/net/http/httpguts"}, "net/http/pprof": {"L4", "OS", "html/template", "net/http", "runtime/pprof", "runtime/trace"}, "net/rpc": {"L4", "NET", "encoding/gob", "html/template", "net/http"}, "net/rpc/jsonrpc": {"L4", "NET", "encoding/json", "net/rpc"}, diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index f82d820a43..e9552a2256 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -8,6 +8,7 @@ package httputil import ( "context" + "fmt" "io" "log" "net" @@ -16,6 +17,8 @@ import ( "strings" "sync" "time" + + "golang_org/x/net/http/httpguts" ) // ReverseProxy is an HTTP Handler that takes an incoming request and @@ -199,6 +202,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { p.Director(outreq) outreq.Close = false + reqUpType := upgradeType(outreq.Header) removeConnectionHeaders(outreq.Header) // Remove hop-by-hop headers to the backend. Especially @@ -221,6 +225,13 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { outreq.Header.Del(h) } + // After stripping all the hop-by-hop connection headers above, add back any + // necessary for protocol upgrades, such as for websockets. + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + } + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { // If we aren't the first proxy retain prior // X-Forwarded-For information as a comma+space @@ -237,6 +248,12 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) + if res.StatusCode == http.StatusSwitchingProtocols { + p.handleUpgradeResponse(rw, outreq, res) + return + } + removeConnectionHeaders(res.Header) for _, h := range hopHeaders { @@ -463,3 +480,67 @@ func (m *maxLatencyWriter) stop() { m.t.Stop() } } + +func upgradeType(h http.Header) string { + if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { + return "" + } + return strings.ToLower(h.Get("Upgrade")) +} + +func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { + reqUpType := upgradeType(req.Header) + resUpType := upgradeType(res.Header) + if reqUpType != resUpType { + p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) + return + } + hj, ok := rw.(http.Hijacker) + if !ok { + p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) + return + } + defer backConn.Close() + conn, brw, err := hj.Hijack() + if err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) + return + } + defer conn.Close() + res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above + if err := res.Write(brw); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) + return + } + if err := brw.Flush(); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) + return + } + errc := make(chan error, 1) + spc := switchProtocolCopier{user: conn, backend: backConn} + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + <-errc + return +} + +// switchProtocolCopier exists so goroutines proxying data back and +// forth have nice names in stacks. +type switchProtocolCopier struct { + user, backend io.ReadWriter +} + +func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { + _, err := io.Copy(c.user, c.backend) + errc <- err +} + +func (c switchProtocolCopier) copyToBackend(errc chan<- error) { + _, err := io.Copy(c.backend, c.user) + errc <- err +} diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index ddae11b168..039273e7c5 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -153,15 +153,20 @@ func TestReverseProxy(t *testing.T) { func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { const fakeConnectionToken = "X-Fake-Connection-Token" const backendResponse = "I am the backend" + + // someConnHeader is some arbitrary header to be declared as a hop-by-hop header + // in the Request's Connection header. + const someConnHeader = "X-Some-Conn-Header" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if c := r.Header.Get(fakeConnectionToken); c != "" { t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) } - if c := r.Header.Get("Upgrade"); c != "" { - t.Errorf("handler got header %q = %q; want empty", "Upgrade", c) + if c := r.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) } - w.Header().Set("Connection", "Upgrade, "+fakeConnectionToken) - w.Header().Set("Upgrade", "should be deleted") + w.Header().Set("Connection", someConnHeader+", "+fakeConnectionToken) + w.Header().Set(someConnHeader, "should be deleted") w.Header().Set(fakeConnectionToken, "should be deleted") io.WriteString(w, backendResponse) })) @@ -173,15 +178,15 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { proxyHandler := NewSingleHostReverseProxy(backendURL) frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { proxyHandler.ServeHTTP(w, r) - if c := r.Header.Get("Upgrade"); c != "original value" { - t.Errorf("handler modified header %q = %q; want %q", "Upgrade", c, "original value") + if c := r.Header.Get(someConnHeader); c != "original value" { + t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "original value") } })) defer frontend.Close() getReq, _ := http.NewRequest("GET", frontend.URL, nil) - getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken) - getReq.Header.Set("Upgrade", "original value") + getReq.Header.Set("Connection", someConnHeader+", "+fakeConnectionToken) + getReq.Header.Set(someConnHeader, "original value") getReq.Header.Set(fakeConnectionToken, "should be deleted") res, err := frontend.Client().Do(getReq) if err != nil { @@ -195,8 +200,8 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { if got, want := string(bodyBytes), backendResponse; got != want { t.Errorf("got body %q; want %q", got, want) } - if c := res.Header.Get("Upgrade"); c != "" { - t.Errorf("handler got header %q = %q; want empty", "Upgrade", c) + if c := res.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) } if c := res.Header.Get(fakeConnectionToken); c != "" { t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) @@ -980,3 +985,66 @@ func TestSelectFlushInterval(t *testing.T) { }) } } + +func TestReverseProxyWebSocket(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if upgradeType(r.Header) != "websocket" { + t.Error("unexpected backend request") + http.Error(w, "unexpected request", 400) + return + } + c, _, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer c.Close() + io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n") + bs := bufio.NewScanner(c) + if !bs.Scan() { + t.Errorf("backend failed to read line from client: %v", bs.Err()) + return + } + fmt.Fprintf(c, "backend got %q\n", bs.Text()) + })) + defer backendServer.Close() + + backURL, _ := url.Parse(backendServer.URL) + rproxy := NewSingleHostReverseProxy(backURL) + rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + req, _ := http.NewRequest("GET", frontendProxy.URL, nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + + c := frontendProxy.Client() + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 101 { + t.Fatalf("status = %v; want 101", res.Status) + } + if upgradeType(res.Header) != "websocket" { + t.Fatalf("not websocket upgrade; got %#v", res.Header) + } + rwc, ok := res.Body.(io.ReadWriteCloser) + if !ok { + t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body) + } + defer rwc.Close() + + io.WriteString(rwc, "Hello\n") + bs := bufio.NewScanner(rwc) + if !bs.Scan() { + t.Fatalf("Scan: %v", bs.Err()) + } + got := bs.Text() + want := `backend got "Hello"` + if got != want { + t.Errorf("got %#q, want %#q", got, want) + } +} diff --git a/src/net/http/transport.go b/src/net/http/transport.go index c459092cb8..7ef414ba53 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -1714,7 +1714,7 @@ func (pc *persistConn) readLoop() { alive = false } - if !hasBody { + if !hasBody || bodyWritable { pc.t.setReqCanceler(rc.req, nil) // Put the idle conn back into the pool before we send the response