"L4", "NET", "OS", "crypto/tls", "flag", "net/http", "net/http/internal", "crypto/x509",
"golang_org/x/net/http/httpguts",
},
- "net/http/httputil": {"L4", "NET", "OS", "context", "net/http", "net/http/internal"},
+ "net/http/httputil": {"L4", "NET", "OS", "context", "net/http", "net/http/internal", "golang_org/x/net/http/httpguts"},
"net/http/pprof": {"L4", "OS", "html/template", "net/http", "runtime/pprof", "runtime/trace"},
"net/rpc": {"L4", "NET", "encoding/gob", "html/template", "net/http"},
"net/rpc/jsonrpc": {"L4", "NET", "encoding/json", "net/rpc"},
import (
"context"
+ "fmt"
"io"
"log"
"net"
"strings"
"sync"
"time"
+
+ "golang_org/x/net/http/httpguts"
)
// ReverseProxy is an HTTP Handler that takes an incoming request and
p.Director(outreq)
outreq.Close = false
+ reqUpType := upgradeType(outreq.Header)
removeConnectionHeaders(outreq.Header)
// Remove hop-by-hop headers to the backend. Especially
outreq.Header.Del(h)
}
+ // After stripping all the hop-by-hop connection headers above, add back any
+ // necessary for protocol upgrades, such as for websockets.
+ if reqUpType != "" {
+ outreq.Header.Set("Connection", "Upgrade")
+ outreq.Header.Set("Upgrade", reqUpType)
+ }
+
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
return
}
+ // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
+ if res.StatusCode == http.StatusSwitchingProtocols {
+ p.handleUpgradeResponse(rw, outreq, res)
+ return
+ }
+
removeConnectionHeaders(res.Header)
for _, h := range hopHeaders {
m.t.Stop()
}
}
+
+func upgradeType(h http.Header) string {
+ if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
+ return ""
+ }
+ return strings.ToLower(h.Get("Upgrade"))
+}
+
+func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
+ reqUpType := upgradeType(req.Header)
+ resUpType := upgradeType(res.Header)
+ if reqUpType != resUpType {
+ p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
+ return
+ }
+ hj, ok := rw.(http.Hijacker)
+ if !ok {
+ p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
+ return
+ }
+ backConn, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
+ return
+ }
+ defer backConn.Close()
+ conn, brw, err := hj.Hijack()
+ if err != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
+ return
+ }
+ defer conn.Close()
+ res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
+ if err := res.Write(brw); err != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
+ return
+ }
+ if err := brw.Flush(); err != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
+ return
+ }
+ errc := make(chan error, 1)
+ spc := switchProtocolCopier{user: conn, backend: backConn}
+ go spc.copyToBackend(errc)
+ go spc.copyFromBackend(errc)
+ <-errc
+ return
+}
+
+// switchProtocolCopier exists so goroutines proxying data back and
+// forth have nice names in stacks.
+type switchProtocolCopier struct {
+ user, backend io.ReadWriter
+}
+
+func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
+ _, err := io.Copy(c.user, c.backend)
+ errc <- err
+}
+
+func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
+ _, err := io.Copy(c.backend, c.user)
+ errc <- err
+}
func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
const fakeConnectionToken = "X-Fake-Connection-Token"
const backendResponse = "I am the backend"
+
+ // someConnHeader is some arbitrary header to be declared as a hop-by-hop header
+ // in the Request's Connection header.
+ const someConnHeader = "X-Some-Conn-Header"
+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if c := r.Header.Get(fakeConnectionToken); c != "" {
t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
}
- if c := r.Header.Get("Upgrade"); c != "" {
- t.Errorf("handler got header %q = %q; want empty", "Upgrade", c)
+ if c := r.Header.Get(someConnHeader); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
}
- w.Header().Set("Connection", "Upgrade, "+fakeConnectionToken)
- w.Header().Set("Upgrade", "should be deleted")
+ w.Header().Set("Connection", someConnHeader+", "+fakeConnectionToken)
+ w.Header().Set(someConnHeader, "should be deleted")
w.Header().Set(fakeConnectionToken, "should be deleted")
io.WriteString(w, backendResponse)
}))
proxyHandler := NewSingleHostReverseProxy(backendURL)
frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxyHandler.ServeHTTP(w, r)
- if c := r.Header.Get("Upgrade"); c != "original value" {
- t.Errorf("handler modified header %q = %q; want %q", "Upgrade", c, "original value")
+ if c := r.Header.Get(someConnHeader); c != "original value" {
+ t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "original value")
}
}))
defer frontend.Close()
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
- getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken)
- getReq.Header.Set("Upgrade", "original value")
+ getReq.Header.Set("Connection", someConnHeader+", "+fakeConnectionToken)
+ getReq.Header.Set(someConnHeader, "original value")
getReq.Header.Set(fakeConnectionToken, "should be deleted")
res, err := frontend.Client().Do(getReq)
if err != nil {
if got, want := string(bodyBytes), backendResponse; got != want {
t.Errorf("got body %q; want %q", got, want)
}
- if c := res.Header.Get("Upgrade"); c != "" {
- t.Errorf("handler got header %q = %q; want empty", "Upgrade", c)
+ if c := res.Header.Get(someConnHeader); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
}
if c := res.Header.Get(fakeConnectionToken); c != "" {
t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
})
}
}
+
+func TestReverseProxyWebSocket(t *testing.T) {
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if upgradeType(r.Header) != "websocket" {
+ t.Error("unexpected backend request")
+ http.Error(w, "unexpected request", 400)
+ return
+ }
+ c, _, err := w.(http.Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer c.Close()
+ io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
+ bs := bufio.NewScanner(c)
+ if !bs.Scan() {
+ t.Errorf("backend failed to read line from client: %v", bs.Err())
+ return
+ }
+ fmt.Fprintf(c, "backend got %q\n", bs.Text())
+ }))
+ defer backendServer.Close()
+
+ backURL, _ := url.Parse(backendServer.URL)
+ rproxy := NewSingleHostReverseProxy(backURL)
+ rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
+
+ frontendProxy := httptest.NewServer(rproxy)
+ defer frontendProxy.Close()
+
+ req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
+ req.Header.Set("Connection", "Upgrade")
+ req.Header.Set("Upgrade", "websocket")
+
+ c := frontendProxy.Client()
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 101 {
+ t.Fatalf("status = %v; want 101", res.Status)
+ }
+ if upgradeType(res.Header) != "websocket" {
+ t.Fatalf("not websocket upgrade; got %#v", res.Header)
+ }
+ rwc, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
+ }
+ defer rwc.Close()
+
+ io.WriteString(rwc, "Hello\n")
+ bs := bufio.NewScanner(rwc)
+ if !bs.Scan() {
+ t.Fatalf("Scan: %v", bs.Err())
+ }
+ got := bs.Text()
+ want := `backend got "Hello"`
+ if got != want {
+ t.Errorf("got %#q, want %#q", got, want)
+ }
+}