]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httputil: make ReverseProxy automatically proxy WebSocket requests
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 29 Oct 2018 23:21:40 +0000 (23:21 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 13 Nov 2018 01:42:47 +0000 (01:42 +0000)
Fixes #26937

Change-Id: I6cdc1bad4cf476cd2ea1462b53444eccd8841e14
Reviewed-on: https://go-review.googlesource.com/c/146437
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
src/go/build/deps_test.go
src/net/http/httputil/reverseproxy.go
src/net/http/httputil/reverseproxy_test.go
src/net/http/transport.go

index d632954d0c4b3f2dd020f703c34a023265f134d2..0ecf38c567df10ef44cad8a93e03f337aff1d004 100644 (file)
@@ -436,7 +436,7 @@ var pkgDeps = map[string][]string{
                "L4", "NET", "OS", "crypto/tls", "flag", "net/http", "net/http/internal", "crypto/x509",
                "golang_org/x/net/http/httpguts",
        },
-       "net/http/httputil": {"L4", "NET", "OS", "context", "net/http", "net/http/internal"},
+       "net/http/httputil": {"L4", "NET", "OS", "context", "net/http", "net/http/internal", "golang_org/x/net/http/httpguts"},
        "net/http/pprof":    {"L4", "OS", "html/template", "net/http", "runtime/pprof", "runtime/trace"},
        "net/rpc":           {"L4", "NET", "encoding/gob", "html/template", "net/http"},
        "net/rpc/jsonrpc":   {"L4", "NET", "encoding/json", "net/rpc"},
index f82d820a43007d614feec03398f06d25c49b9f9a..e9552a225641410a442d753487ce919d7d7fcc52 100644 (file)
@@ -8,6 +8,7 @@ package httputil
 
 import (
        "context"
+       "fmt"
        "io"
        "log"
        "net"
@@ -16,6 +17,8 @@ import (
        "strings"
        "sync"
        "time"
+
+       "golang_org/x/net/http/httpguts"
 )
 
 // ReverseProxy is an HTTP Handler that takes an incoming request and
@@ -199,6 +202,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
        p.Director(outreq)
        outreq.Close = false
 
+       reqUpType := upgradeType(outreq.Header)
        removeConnectionHeaders(outreq.Header)
 
        // Remove hop-by-hop headers to the backend. Especially
@@ -221,6 +225,13 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
                outreq.Header.Del(h)
        }
 
+       // After stripping all the hop-by-hop connection headers above, add back any
+       // necessary for protocol upgrades, such as for websockets.
+       if reqUpType != "" {
+               outreq.Header.Set("Connection", "Upgrade")
+               outreq.Header.Set("Upgrade", reqUpType)
+       }
+
        if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
                // If we aren't the first proxy retain prior
                // X-Forwarded-For information as a comma+space
@@ -237,6 +248,12 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
                return
        }
 
+       // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
+       if res.StatusCode == http.StatusSwitchingProtocols {
+               p.handleUpgradeResponse(rw, outreq, res)
+               return
+       }
+
        removeConnectionHeaders(res.Header)
 
        for _, h := range hopHeaders {
@@ -463,3 +480,67 @@ func (m *maxLatencyWriter) stop() {
                m.t.Stop()
        }
 }
+
+func upgradeType(h http.Header) string {
+       if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
+               return ""
+       }
+       return strings.ToLower(h.Get("Upgrade"))
+}
+
+func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
+       reqUpType := upgradeType(req.Header)
+       resUpType := upgradeType(res.Header)
+       if reqUpType != resUpType {
+               p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
+               return
+       }
+       hj, ok := rw.(http.Hijacker)
+       if !ok {
+               p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
+               return
+       }
+       backConn, ok := res.Body.(io.ReadWriteCloser)
+       if !ok {
+               p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
+               return
+       }
+       defer backConn.Close()
+       conn, brw, err := hj.Hijack()
+       if err != nil {
+               p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
+               return
+       }
+       defer conn.Close()
+       res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
+       if err := res.Write(brw); err != nil {
+               p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
+               return
+       }
+       if err := brw.Flush(); err != nil {
+               p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
+               return
+       }
+       errc := make(chan error, 1)
+       spc := switchProtocolCopier{user: conn, backend: backConn}
+       go spc.copyToBackend(errc)
+       go spc.copyFromBackend(errc)
+       <-errc
+       return
+}
+
+// switchProtocolCopier exists so goroutines proxying data back and
+// forth have nice names in stacks.
+type switchProtocolCopier struct {
+       user, backend io.ReadWriter
+}
+
+func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
+       _, err := io.Copy(c.user, c.backend)
+       errc <- err
+}
+
+func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
+       _, err := io.Copy(c.backend, c.user)
+       errc <- err
+}
index ddae11b168ce7da8eb20f704c02fd26b0eaf07e5..039273e7c581ec4bc808629f674ea69add44bf98 100644 (file)
@@ -153,15 +153,20 @@ func TestReverseProxy(t *testing.T) {
 func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
        const fakeConnectionToken = "X-Fake-Connection-Token"
        const backendResponse = "I am the backend"
+
+       // someConnHeader is some arbitrary header to be declared as a hop-by-hop header
+       // in the Request's Connection header.
+       const someConnHeader = "X-Some-Conn-Header"
+
        backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
                if c := r.Header.Get(fakeConnectionToken); c != "" {
                        t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
                }
-               if c := r.Header.Get("Upgrade"); c != "" {
-                       t.Errorf("handler got header %q = %q; want empty", "Upgrade", c)
+               if c := r.Header.Get(someConnHeader); c != "" {
+                       t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
                }
-               w.Header().Set("Connection", "Upgrade, "+fakeConnectionToken)
-               w.Header().Set("Upgrade", "should be deleted")
+               w.Header().Set("Connection", someConnHeader+", "+fakeConnectionToken)
+               w.Header().Set(someConnHeader, "should be deleted")
                w.Header().Set(fakeConnectionToken, "should be deleted")
                io.WriteString(w, backendResponse)
        }))
@@ -173,15 +178,15 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
        proxyHandler := NewSingleHostReverseProxy(backendURL)
        frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
                proxyHandler.ServeHTTP(w, r)
-               if c := r.Header.Get("Upgrade"); c != "original value" {
-                       t.Errorf("handler modified header %q = %q; want %q", "Upgrade", c, "original value")
+               if c := r.Header.Get(someConnHeader); c != "original value" {
+                       t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "original value")
                }
        }))
        defer frontend.Close()
 
        getReq, _ := http.NewRequest("GET", frontend.URL, nil)
-       getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken)
-       getReq.Header.Set("Upgrade", "original value")
+       getReq.Header.Set("Connection", someConnHeader+", "+fakeConnectionToken)
+       getReq.Header.Set(someConnHeader, "original value")
        getReq.Header.Set(fakeConnectionToken, "should be deleted")
        res, err := frontend.Client().Do(getReq)
        if err != nil {
@@ -195,8 +200,8 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
        if got, want := string(bodyBytes), backendResponse; got != want {
                t.Errorf("got body %q; want %q", got, want)
        }
-       if c := res.Header.Get("Upgrade"); c != "" {
-               t.Errorf("handler got header %q = %q; want empty", "Upgrade", c)
+       if c := res.Header.Get(someConnHeader); c != "" {
+               t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
        }
        if c := res.Header.Get(fakeConnectionToken); c != "" {
                t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
@@ -980,3 +985,66 @@ func TestSelectFlushInterval(t *testing.T) {
                })
        }
 }
+
+func TestReverseProxyWebSocket(t *testing.T) {
+       backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               if upgradeType(r.Header) != "websocket" {
+                       t.Error("unexpected backend request")
+                       http.Error(w, "unexpected request", 400)
+                       return
+               }
+               c, _, err := w.(http.Hijacker).Hijack()
+               if err != nil {
+                       t.Error(err)
+                       return
+               }
+               defer c.Close()
+               io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
+               bs := bufio.NewScanner(c)
+               if !bs.Scan() {
+                       t.Errorf("backend failed to read line from client: %v", bs.Err())
+                       return
+               }
+               fmt.Fprintf(c, "backend got %q\n", bs.Text())
+       }))
+       defer backendServer.Close()
+
+       backURL, _ := url.Parse(backendServer.URL)
+       rproxy := NewSingleHostReverseProxy(backURL)
+       rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
+
+       frontendProxy := httptest.NewServer(rproxy)
+       defer frontendProxy.Close()
+
+       req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
+       req.Header.Set("Connection", "Upgrade")
+       req.Header.Set("Upgrade", "websocket")
+
+       c := frontendProxy.Client()
+       res, err := c.Do(req)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if res.StatusCode != 101 {
+               t.Fatalf("status = %v; want 101", res.Status)
+       }
+       if upgradeType(res.Header) != "websocket" {
+               t.Fatalf("not websocket upgrade; got %#v", res.Header)
+       }
+       rwc, ok := res.Body.(io.ReadWriteCloser)
+       if !ok {
+               t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
+       }
+       defer rwc.Close()
+
+       io.WriteString(rwc, "Hello\n")
+       bs := bufio.NewScanner(rwc)
+       if !bs.Scan() {
+               t.Fatalf("Scan: %v", bs.Err())
+       }
+       got := bs.Text()
+       want := `backend got "Hello"`
+       if got != want {
+               t.Errorf("got %#q, want %#q", got, want)
+       }
+}
index c459092cb89872d787794dfbcc54c8ecfb8c5074..7ef414ba53a2f52928b0df79f3024c4c54aaa02c 100644 (file)
@@ -1714,7 +1714,7 @@ func (pc *persistConn) readLoop() {
                        alive = false
                }
 
-               if !hasBody {
+               if !hasBody || bodyWritable {
                        pc.t.setReqCanceler(rc.req, nil)
 
                        // Put the idle conn back into the pool before we send the response