]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: make Transport send WebSocket upgrade requests over HTTP/1
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 25 Sep 2018 20:59:52 +0000 (20:59 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 2 Oct 2018 23:33:23 +0000 (23:33 +0000)
WebSockets requires HTTP/1 in practice (no spec or implementations
work over HTTP/2), so if we get an HTTP request that looks like it's
trying to initiate WebSockets, use HTTP/1, like browsers do.

This is part of a series of commits to make WebSockets work over
httputil.ReverseProxy. See #26937.

Updates #26937

Change-Id: I6ad3df9b0a21fddf62fa7d9cacef48e7d5d9585b
Reviewed-on: https://go-review.googlesource.com/c/137437
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
src/net/http/clientserver_test.go
src/net/http/export_test.go
src/net/http/proxy_test.go
src/net/http/request.go
src/net/http/transport.go

index 9a05b648e32bee83b83be5e5eec20d1184ef86fe..3e88c64b6fdede4e5526f52570ff0bb11a269646 100644 (file)
@@ -252,7 +252,7 @@ type slurpResult struct {
 func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
 
 func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
-       if res.Proto == wantProto {
+       if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
                res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
        } else {
                t.Errorf("got %q response; want %q", res.Proto, wantProto)
@@ -1546,3 +1546,25 @@ func TestBidiStreamReverseProxy(t *testing.T) {
        }
 
 }
+
+// Always use HTTP/1.1 for WebSocket upgrades.
+func TestH12_WebSocketUpgrade(t *testing.T) {
+       h12Compare{
+               Handler: func(w ResponseWriter, r *Request) {
+                       h := w.Header()
+                       h.Set("Foo", "bar")
+               },
+               ReqFunc: func(c *Client, url string) (*Response, error) {
+                       req, _ := NewRequest("GET", url, nil)
+                       req.Header.Set("Connection", "Upgrade")
+                       req.Header.Set("Upgrade", "WebSocket")
+                       return c.Do(req)
+               },
+               EarlyCheckResponse: func(proto string, res *Response) {
+                       if res.Proto != "HTTP/1.1" {
+                               t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
+                       }
+                       res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0
+               },
+       }.run(t)
+}
index bc0db53a2c600549bdb884899a4626b144f352f6..716e8ecac70a90c3596240e59b2211f388db70a4 100644 (file)
@@ -155,7 +155,7 @@ func (t *Transport) IdleConnStrsForTesting_h2() []string {
 func (t *Transport) IdleConnCountForTesting(scheme, addr string) int {
        t.idleMu.Lock()
        defer t.idleMu.Unlock()
-       key := connectMethodKey{"", scheme, addr}
+       key := connectMethodKey{"", scheme, addr, false}
        cacheKey := key.String()
        for k, conns := range t.idleConn {
                if k.String() == cacheKey {
@@ -178,12 +178,12 @@ func (t *Transport) IsIdleForTesting() bool {
 }
 
 func (t *Transport) RequestIdleConnChForTesting() {
-       t.getIdleConnCh(connectMethod{nil, "http", "example.com"})
+       t.getIdleConnCh(connectMethod{nil, "http", "example.com", false})
 }
 
 func (t *Transport) PutIdleTestConn(scheme, addr string) bool {
        c, _ := net.Pipe()
-       key := connectMethodKey{"", scheme, addr}
+       key := connectMethodKey{"", scheme, addr, false}
        select {
        case <-t.incHostConnCount(key):
        default:
index eef0ca82f8c9a75e9af7df37424c3641525983d1..feb7047a58e55f289b913f730ca65aff7a6b3cdc 100644 (file)
@@ -35,7 +35,7 @@ func TestCacheKeys(t *testing.T) {
                        }
                        proxy = u
                }
-               cm := connectMethod{proxy, tt.scheme, tt.addr}
+               cm := connectMethod{proxy, tt.scheme, tt.addr, false}
                if got := cm.key().String(); got != tt.key {
                        t.Fatalf("{%q, %q, %q} cache key = %q; want %q", tt.proxy, tt.scheme, tt.addr, got, tt.key)
                }
index ac3302934fe7d658d531df0bab7462ce7cd8fd25..967de7917f7d3eb4998829e8ec5cf70e8674dd8c 100644 (file)
@@ -1371,3 +1371,10 @@ func requestMethodUsuallyLacksBody(method string) bool {
        }
        return false
 }
+
+// requiresHTTP1 reports whether this request requires being sent on
+// an HTTP/1 connection.
+func (r *Request) requiresHTTP1() bool {
+       return hasToken(r.Header.Get("Connection"), "upgrade") &&
+               strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
+}
index e6493036e8cfdca96d359d86a04fa56c8decf23b..b298ec6d7d28c16817c5f06f4ad305fc8ac11093 100644 (file)
@@ -382,6 +382,19 @@ func (tr *transportRequest) setError(err error) {
        tr.mu.Unlock()
 }
 
+// useRegisteredProtocol reports whether an alternate protocol (as reqistered
+// with Transport.RegisterProtocol) should be respected for this request.
+func (t *Transport) useRegisteredProtocol(req *Request) bool {
+       if req.URL.Scheme == "https" && req.requiresHTTP1() {
+               // If this request requires HTTP/1, don't use the
+               // "https" alternate protocol, which is used by the
+               // HTTP/2 code to take over requests if there's an
+               // existing cached HTTP/2 connection.
+               return false
+       }
+       return true
+}
+
 // roundTrip implements a RoundTripper over HTTP.
 func (t *Transport) roundTrip(req *Request) (*Response, error) {
        t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
@@ -411,10 +424,12 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
                }
        }
 
-       altProto, _ := t.altProto.Load().(map[string]RoundTripper)
-       if altRT := altProto[scheme]; altRT != nil {
-               if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
-                       return resp, err
+       if t.useRegisteredProtocol(req) {
+               altProto, _ := t.altProto.Load().(map[string]RoundTripper)
+               if altRT := altProto[scheme]; altRT != nil {
+                       if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
+                               return resp, err
+                       }
                }
        }
        if !isHTTP {
@@ -653,6 +668,7 @@ func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectM
                        }
                }
        }
+       cm.onlyH1 = treq.requiresHTTP1()
        return cm, err
 }
 
@@ -1155,6 +1171,9 @@ func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) erro
        if cfg.ServerName == "" {
                cfg.ServerName = name
        }
+       if pconn.cacheKey.onlyH1 {
+               cfg.NextProtos = nil
+       }
        plainConn := pconn.conn
        tlsConn := tls.Client(plainConn, cfg)
        errc := make(chan error, 2)
@@ -1361,10 +1380,11 @@ func (w persistConnWriter) Write(p []byte) (n int, err error) {
 //
 // A connect method may be of the following types:
 //
-//     Cache key form                    Description
-//     -----------------                 -------------------------
+//     connectMethod.key().String()      Description
+//     ------------------------------    -------------------------
 //     |http|foo.com                     http directly to server, no proxy
 //     |https|foo.com                    https directly to server, no proxy
+//     |https,h1|foo.com                 https directly to server w/o HTTP/2, no proxy
 //     http://proxy.com|https|foo.com    http to proxy, then CONNECT to foo.com
 //     http://proxy.com|http             http to proxy, http to anywhere after that
 //     socks5://proxy.com|http|foo.com   socks5 to proxy, then http to foo.com
@@ -1379,6 +1399,7 @@ type connectMethod struct {
        // then targetAddr is not included in the connect method key, because the socket can
        // be reused for different targetAddr values.
        targetAddr string
+       onlyH1     bool // whether to disable HTTP/2 and force HTTP/1
 }
 
 func (cm *connectMethod) key() connectMethodKey {
@@ -1394,6 +1415,7 @@ func (cm *connectMethod) key() connectMethodKey {
                proxy:  proxyStr,
                scheme: cm.targetScheme,
                addr:   targetAddr,
+               onlyH1: cm.onlyH1,
        }
 }
 
@@ -1428,11 +1450,16 @@ func (cm *connectMethod) tlsHost() string {
 // a URL.
 type connectMethodKey struct {
        proxy, scheme, addr string
+       onlyH1              bool
 }
 
 func (k connectMethodKey) String() string {
        // Only used by tests.
-       return fmt.Sprintf("%s|%s|%s", k.proxy, k.scheme, k.addr)
+       var h1 string
+       if k.onlyH1 {
+               h1 = ",h1"
+       }
+       return fmt.Sprintf("%s|%s%s|%s", k.proxy, k.scheme, h1, k.addr)
 }
 
 // persistConn wraps a connection, usually a persistent one