From: Brad Fitzpatrick Date: Tue, 25 Sep 2018 20:59:52 +0000 (+0000) Subject: net/http: make Transport send WebSocket upgrade requests over HTTP/1 X-Git-Tag: go1.12beta1~934 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=a73d8f5a86185aeb39e398d0226d56be7d9247ca;p=gostls13.git net/http: make Transport send WebSocket upgrade requests over HTTP/1 WebSockets requires HTTP/1 in practice (no spec or implementations work over HTTP/2), so if we get an HTTP request that looks like it's trying to initiate WebSockets, use HTTP/1, like browsers do. This is part of a series of commits to make WebSockets work over httputil.ReverseProxy. See #26937. Updates #26937 Change-Id: I6ad3df9b0a21fddf62fa7d9cacef48e7d5d9585b Reviewed-on: https://go-review.googlesource.com/c/137437 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Dmitri Shuralyov --- diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index 9a05b648e3..3e88c64b6f 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -252,7 +252,7 @@ type slurpResult struct { func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) } func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) { - if res.Proto == wantProto { + if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" { res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0 } else { t.Errorf("got %q response; want %q", res.Proto, wantProto) @@ -1546,3 +1546,25 @@ func TestBidiStreamReverseProxy(t *testing.T) { } } + +// Always use HTTP/1.1 for WebSocket upgrades. +func TestH12_WebSocketUpgrade(t *testing.T) { + h12Compare{ + Handler: func(w ResponseWriter, r *Request) { + h := w.Header() + h.Set("Foo", "bar") + }, + ReqFunc: func(c *Client, url string) (*Response, error) { + req, _ := NewRequest("GET", url, nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "WebSocket") + return c.Do(req) + }, + EarlyCheckResponse: func(proto string, res *Response) { + if res.Proto != "HTTP/1.1" { + t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto) + } + res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0 + }, + }.run(t) +} diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go index bc0db53a2c..716e8ecac7 100644 --- a/src/net/http/export_test.go +++ b/src/net/http/export_test.go @@ -155,7 +155,7 @@ func (t *Transport) IdleConnStrsForTesting_h2() []string { func (t *Transport) IdleConnCountForTesting(scheme, addr string) int { t.idleMu.Lock() defer t.idleMu.Unlock() - key := connectMethodKey{"", scheme, addr} + key := connectMethodKey{"", scheme, addr, false} cacheKey := key.String() for k, conns := range t.idleConn { if k.String() == cacheKey { @@ -178,12 +178,12 @@ func (t *Transport) IsIdleForTesting() bool { } func (t *Transport) RequestIdleConnChForTesting() { - t.getIdleConnCh(connectMethod{nil, "http", "example.com"}) + t.getIdleConnCh(connectMethod{nil, "http", "example.com", false}) } func (t *Transport) PutIdleTestConn(scheme, addr string) bool { c, _ := net.Pipe() - key := connectMethodKey{"", scheme, addr} + key := connectMethodKey{"", scheme, addr, false} select { case <-t.incHostConnCount(key): default: diff --git a/src/net/http/proxy_test.go b/src/net/http/proxy_test.go index eef0ca82f8..feb7047a58 100644 --- a/src/net/http/proxy_test.go +++ b/src/net/http/proxy_test.go @@ -35,7 +35,7 @@ func TestCacheKeys(t *testing.T) { } proxy = u } - cm := connectMethod{proxy, tt.scheme, tt.addr} + cm := connectMethod{proxy, tt.scheme, tt.addr, false} if got := cm.key().String(); got != tt.key { t.Fatalf("{%q, %q, %q} cache key = %q; want %q", tt.proxy, tt.scheme, tt.addr, got, tt.key) } diff --git a/src/net/http/request.go b/src/net/http/request.go index ac3302934f..967de7917f 100644 --- a/src/net/http/request.go +++ b/src/net/http/request.go @@ -1371,3 +1371,10 @@ func requestMethodUsuallyLacksBody(method string) bool { } return false } + +// requiresHTTP1 reports whether this request requires being sent on +// an HTTP/1 connection. +func (r *Request) requiresHTTP1() bool { + return hasToken(r.Header.Get("Connection"), "upgrade") && + strings.EqualFold(r.Header.Get("Upgrade"), "websocket") +} diff --git a/src/net/http/transport.go b/src/net/http/transport.go index e6493036e8..b298ec6d7d 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -382,6 +382,19 @@ func (tr *transportRequest) setError(err error) { tr.mu.Unlock() } +// useRegisteredProtocol reports whether an alternate protocol (as reqistered +// with Transport.RegisterProtocol) should be respected for this request. +func (t *Transport) useRegisteredProtocol(req *Request) bool { + if req.URL.Scheme == "https" && req.requiresHTTP1() { + // If this request requires HTTP/1, don't use the + // "https" alternate protocol, which is used by the + // HTTP/2 code to take over requests if there's an + // existing cached HTTP/2 connection. + return false + } + return true +} + // roundTrip implements a RoundTripper over HTTP. func (t *Transport) roundTrip(req *Request) (*Response, error) { t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) @@ -411,10 +424,12 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { } } - altProto, _ := t.altProto.Load().(map[string]RoundTripper) - if altRT := altProto[scheme]; altRT != nil { - if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol { - return resp, err + if t.useRegisteredProtocol(req) { + altProto, _ := t.altProto.Load().(map[string]RoundTripper) + if altRT := altProto[scheme]; altRT != nil { + if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol { + return resp, err + } } } if !isHTTP { @@ -653,6 +668,7 @@ func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectM } } } + cm.onlyH1 = treq.requiresHTTP1() return cm, err } @@ -1155,6 +1171,9 @@ func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) erro if cfg.ServerName == "" { cfg.ServerName = name } + if pconn.cacheKey.onlyH1 { + cfg.NextProtos = nil + } plainConn := pconn.conn tlsConn := tls.Client(plainConn, cfg) errc := make(chan error, 2) @@ -1361,10 +1380,11 @@ func (w persistConnWriter) Write(p []byte) (n int, err error) { // // A connect method may be of the following types: // -// Cache key form Description -// ----------------- ------------------------- +// connectMethod.key().String() Description +// ------------------------------ ------------------------- // |http|foo.com http directly to server, no proxy // |https|foo.com https directly to server, no proxy +// |https,h1|foo.com https directly to server w/o HTTP/2, no proxy // http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com // http://proxy.com|http http to proxy, http to anywhere after that // socks5://proxy.com|http|foo.com socks5 to proxy, then http to foo.com @@ -1379,6 +1399,7 @@ type connectMethod struct { // then targetAddr is not included in the connect method key, because the socket can // be reused for different targetAddr values. targetAddr string + onlyH1 bool // whether to disable HTTP/2 and force HTTP/1 } func (cm *connectMethod) key() connectMethodKey { @@ -1394,6 +1415,7 @@ func (cm *connectMethod) key() connectMethodKey { proxy: proxyStr, scheme: cm.targetScheme, addr: targetAddr, + onlyH1: cm.onlyH1, } } @@ -1428,11 +1450,16 @@ func (cm *connectMethod) tlsHost() string { // a URL. type connectMethodKey struct { proxy, scheme, addr string + onlyH1 bool } func (k connectMethodKey) String() string { // Only used by tests. - return fmt.Sprintf("%s|%s|%s", k.proxy, k.scheme, k.addr) + var h1 string + if k.onlyH1 { + h1 = ",h1" + } + return fmt.Sprintf("%s|%s%s|%s", k.proxy, k.scheme, h1, k.addr) } // persistConn wraps a connection, usually a persistent one