From f5cd3868d52babd106e0509a67295690246a5252 Mon Sep 17 00:00:00 2001 From: Ben Schwartz Date: Thu, 5 Oct 2017 14:07:55 -0400 Subject: [PATCH] net/http: HTTPS proxies support net/http already supports http proxies. This CL allows it to establish a connection to the http proxy over https. See more at: https://www.chromium.org/developers/design-documents/secure-web-proxy Fixes golang/go#11332 Change-Id: If0e017df0e8f8c2c499a2ddcbbeb625c8fa2bb6b Reviewed-on: https://go-review.googlesource.com/68550 Run-TryBot: Tom Bergan TryBot-Result: Gobot Gobot Reviewed-by: Tom Bergan --- src/net/http/transport.go | 160 ++++++++++++++-------- src/net/http/transport_test.go | 243 +++++++++++++++++++++++++-------- 2 files changed, 283 insertions(+), 120 deletions(-) diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 5f2ace7b4b..258b912b0a 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -618,11 +618,6 @@ func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectM if port := cm.proxyURL.Port(); !validPort(port) { return cm, fmt.Errorf("invalid proxy URL port %q", port) } - switch cm.proxyURL.Scheme { - case "http", "socks5": - default: - return cm, fmt.Errorf("invalid proxy URL scheme %q", cm.proxyURL.Scheme) - } } } return cm, err @@ -1021,6 +1016,69 @@ func (d oneConnDialer) Dial(network, addr string) (net.Conn, error) { } } +// The connect method and the transport can both specify a TLS +// Host name. The transport's name takes precedence if present. +func chooseTLSHost(cm connectMethod, t *Transport) string { + tlsHost := "" + if t.TLSClientConfig != nil { + tlsHost = t.TLSClientConfig.ServerName + } + if tlsHost == "" { + tlsHost = cm.tlsHost() + } + return tlsHost +} + +// Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS +// tunnel, this function establishes a nested TLS session inside the encrypted channel. +// The remote endpoint's name may be overridden by TLSClientConfig.ServerName. +func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) error { + // Initiate TLS and check remote host name against certificate. + cfg := cloneTLSConfig(pconn.t.TLSClientConfig) + if cfg.ServerName == "" { + cfg.ServerName = name + } + plainConn := pconn.conn + tlsConn := tls.Client(plainConn, cfg) + errc := make(chan error, 2) + var timer *time.Timer // for canceling TLS handshake + if d := pconn.t.TLSHandshakeTimeout; d != 0 { + timer = time.AfterFunc(d, func() { + errc <- tlsHandshakeTimeoutError{} + }) + } + go func() { + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err := tlsConn.Handshake() + if timer != nil { + timer.Stop() + } + errc <- err + }() + if err := <-errc; err != nil { + plainConn.Close() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tls.ConnectionState{}, err) + } + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + plainConn.Close() + return err + } + } + cs := tlsConn.ConnectionState() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(cs, nil) + } + pconn.tlsState = &cs + pconn.conn = tlsConn + return nil +} + func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistConn, error) { pconn := &persistConn{ t: t, @@ -1032,15 +1090,21 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon writeLoopDone: make(chan struct{}), } trace := httptrace.ContextClientTrace(ctx) - tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil - if tlsDial { + wrapErr := func(err error) error { + if cm.proxyURL != nil { + // Return a typed error, per Issue 16997 + return &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err} + } + return err + } + if cm.scheme() == "https" && t.DialTLS != nil { var err error pconn.conn, err = t.DialTLS("tcp", cm.addr()) if err != nil { - return nil, err + return nil, wrapErr(err) } if pconn.conn == nil { - return nil, errors.New("net/http: Transport.DialTLS returned (nil, nil)") + return nil, wrapErr(errors.New("net/http: Transport.DialTLS returned (nil, nil)")) } if tc, ok := pconn.conn.(*tls.Conn); ok { // Handshake here, in case DialTLS didn't. TLSNextProto below @@ -1064,13 +1128,18 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon } else { conn, err := t.dial(ctx, "tcp", cm.addr()) if err != nil { - if cm.proxyURL != nil { - // Return a typed error, per Issue 16997: - err = &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err} - } - return nil, err + return nil, wrapErr(err) } pconn.conn = conn + if cm.scheme() == "https" { + var firstTLSHost string + if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { + return nil, wrapErr(err) + } + if err = pconn.addTLS(firstTLSHost, trace); err != nil { + return nil, wrapErr(err) + } + } } // Proxy setup. @@ -1134,50 +1203,10 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon } } - if cm.targetScheme == "https" && !tlsDial { - // Initiate TLS and check remote host name against certificate. - cfg := cloneTLSConfig(t.TLSClientConfig) - if cfg.ServerName == "" { - cfg.ServerName = cm.tlsHost() - } - plainConn := pconn.conn - tlsConn := tls.Client(plainConn, cfg) - errc := make(chan error, 2) - var timer *time.Timer // for canceling TLS handshake - if d := t.TLSHandshakeTimeout; d != 0 { - timer = time.AfterFunc(d, func() { - errc <- tlsHandshakeTimeoutError{} - }) - } - go func() { - if trace != nil && trace.TLSHandshakeStart != nil { - trace.TLSHandshakeStart() - } - err := tlsConn.Handshake() - if timer != nil { - timer.Stop() - } - errc <- err - }() - if err := <-errc; err != nil { - plainConn.Close() - if trace != nil && trace.TLSHandshakeDone != nil { - trace.TLSHandshakeDone(tls.ConnectionState{}, err) - } + if cm.proxyURL != nil && cm.targetScheme == "https" { + if err := pconn.addTLS(cm.tlsHost(), trace); err != nil { return nil, err } - if !cfg.InsecureSkipVerify { - if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { - plainConn.Close() - return nil, err - } - } - cs := tlsConn.ConnectionState() - if trace != nil && trace.TLSHandshakeDone != nil { - trace.TLSHandshakeDone(cs, nil) - } - pconn.tlsState = &cs - pconn.conn = tlsConn } if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { @@ -1279,13 +1308,16 @@ func useProxy(addr string) bool { // http://proxy.com|http http to proxy, http to anywhere after that // socks5://proxy.com|http|foo.com socks5 to proxy, then http to foo.com // socks5://proxy.com|https|foo.com socks5 to proxy, then https to foo.com -// -// Note: no support to https to the proxy yet. +// https://proxy.com|https|foo.com https to proxy, then CONNECT to foo.com +// https://proxy.com|http https to proxy, http to anywhere after that // type connectMethod struct { proxyURL *url.URL // nil for no proxy, else full proxy URL targetScheme string // "http" or "https" - targetAddr string // Not used if http proxy + http targetScheme (4th example in table) + // If proxyURL specifies an http or https proxy, and targetScheme is http (not https), + // then targetAddr is not included in the connect method key, because the socket can + // be reused for different targetAddr values. + targetAddr string } func (cm *connectMethod) key() connectMethodKey { @@ -1293,7 +1325,7 @@ func (cm *connectMethod) key() connectMethodKey { targetAddr := cm.targetAddr if cm.proxyURL != nil { proxyStr = cm.proxyURL.String() - if strings.HasPrefix(cm.proxyURL.Scheme, "http") && cm.targetScheme == "http" { + if (cm.proxyURL.Scheme == "http" || cm.proxyURL.Scheme == "https") && cm.targetScheme == "http" { targetAddr = "" } } @@ -1304,6 +1336,14 @@ func (cm *connectMethod) key() connectMethodKey { } } +// scheme returns the first hop scheme: http, https, or socks5 +func (cm *connectMethod) scheme() string { + if cm.proxyURL != nil { + return cm.proxyURL.Scheme + } + return cm.targetScheme +} + // addr returns the first hop "host:port" to which we need to TCP connect. func (cm *connectMethod) addr() string { if cm.proxyURL != nil { diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 39b5cd358f..ad63cca5fe 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -961,14 +961,10 @@ func TestTransportExpect100Continue(t *testing.T) { func TestSocks5Proxy(t *testing.T) { defer afterTest(t) ch := make(chan string, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - ch <- "real server" - })) - defer ts.Close() l := newLocalListener(t) defer l.Close() - go func() { - defer close(ch) + defer close(ch) + proxy := func(t *testing.T) { s, err := l.Accept() if err != nil { t.Errorf("socks5 proxy Accept(): %v", err) @@ -1003,7 +999,8 @@ func TestSocks5Proxy(t *testing.T) { case 4: ipLen = 16 default: - t.Fatalf("socks5 proxy second read: unexpected address type %v", buf[4]) + t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4]) + return } if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil { t.Errorf("socks5 proxy address read: %v", err) @@ -1016,71 +1013,197 @@ func TestSocks5Proxy(t *testing.T) { t.Errorf("socks5 proxy connect write: %v", err) return } - done := make(chan struct{}) - srv := &Server{Handler: HandlerFunc(func(w ResponseWriter, r *Request) { - done <- struct{}{} - })} - srv.Serve(&oneConnListener{conn: s}) - <-done - srv.Shutdown(context.Background()) ch <- fmt.Sprintf("proxy for %s:%d", ip, port) - }() - pu, err := url.Parse("socks5://" + l.Addr().String()) - if err != nil { - t.Fatal(err) - } - c := ts.Client() - c.Transport.(*Transport).Proxy = ProxyURL(pu) - if _, err := c.Head(ts.URL); err != nil { - t.Error(err) - } - var got string - select { - case got = <-ch: - case <-time.After(5 * time.Second): - t.Fatal("timeout connecting to socks5 proxy") + // Implement proxying. + targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port))) + targetConn, err := net.Dial("tcp", targetHost) + if err != nil { + t.Errorf("net.Dial failed") + return + } + go io.Copy(targetConn, s) + io.Copy(s, targetConn) // Wait for the client to close the socket. + targetConn.Close() } - tsu, err := url.Parse(ts.URL) + + pu, err := url.Parse("socks5://" + l.Addr().String()) if err != nil { t.Fatal(err) } - want := "proxy for " + tsu.Host - if got != want { - t.Errorf("got %q, want %q", got, want) + + sentinelHeader := "X-Sentinel" + sentinelValue := "12345" + h := HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set(sentinelHeader, sentinelValue) + }) + for _, useTLS := range []bool{false, true} { + t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) { + var ts *httptest.Server + if useTLS { + ts = httptest.NewTLSServer(h) + } else { + ts = httptest.NewServer(h) + } + go proxy(t) + c := ts.Client() + c.Transport.(*Transport).Proxy = ProxyURL(pu) + r, err := c.Head(ts.URL) + if err != nil { + t.Fatal(err) + } + if r.Header.Get(sentinelHeader) != sentinelValue { + t.Errorf("Failed to retrieve sentinel value") + } + var got string + select { + case got = <-ch: + case <-time.After(5 * time.Second): + t.Fatal("timeout connecting to socks5 proxy") + } + ts.Close() + tsu, err := url.Parse(ts.URL) + if err != nil { + t.Fatal(err) + } + want := "proxy for " + tsu.Host + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) } } func TestTransportProxy(t *testing.T) { defer afterTest(t) - ch := make(chan string, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - ch <- "real server" - })) - defer ts.Close() - proxy := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - ch <- "proxy for " + r.URL.String() - })) - defer proxy.Close() + testCases := []struct{ httpsSite, httpsProxy bool }{ + {false, false}, + {false, true}, + {true, false}, + {true, true}, + } + for _, testCase := range testCases { + httpsSite := testCase.httpsSite + httpsProxy := testCase.httpsProxy + t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) { + siteCh := make(chan *Request, 1) + h1 := HandlerFunc(func(w ResponseWriter, r *Request) { + siteCh <- r + }) + proxyCh := make(chan *Request, 1) + h2 := HandlerFunc(func(w ResponseWriter, r *Request) { + proxyCh <- r + // Implement an entire CONNECT proxy + if r.Method == "CONNECT" { + hijacker, ok := w.(Hijacker) + if !ok { + t.Errorf("hijack not allowed") + return + } + clientConn, _, err := hijacker.Hijack() + if err != nil { + t.Errorf("hijacking failed") + return + } + res := &Response{ + StatusCode: StatusOK, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(Header), + } + + log.Printf("Dialing %s", r.URL.Host) + targetConn, err := net.Dial("tcp", r.URL.Host) + if err != nil { + t.Errorf("net.Dial failed") + return + } + + if err := res.Write(clientConn); err != nil { + t.Errorf("Writing 200 OK failed") + return + } + + go io.Copy(targetConn, clientConn) + go func() { + io.Copy(clientConn, targetConn) + targetConn.Close() + }() + } + }) + var ts *httptest.Server + if httpsSite { + ts = httptest.NewTLSServer(h1) + } else { + ts = httptest.NewServer(h1) + } + var proxy *httptest.Server + if httpsProxy { + proxy = httptest.NewTLSServer(h2) + } else { + proxy = httptest.NewServer(h2) + } - pu, err := url.Parse(proxy.URL) - if err != nil { - t.Fatal(err) - } - c := ts.Client() - c.Transport.(*Transport).Proxy = ProxyURL(pu) - if _, err := c.Head(ts.URL); err != nil { - t.Error(err) - } - var got string - select { - case got = <-ch: - case <-time.After(5 * time.Second): - t.Fatal("timeout connecting to http proxy") - } - want := "proxy for " + ts.URL + "/" - if got != want { - t.Errorf("got %q, want %q", got, want) + pu, err := url.Parse(proxy.URL) + if err != nil { + t.Fatal(err) + } + + // If neither server is HTTPS or both are, then c may be derived from either. + // If only one server is HTTPS, c must be derived from that server in order + // to ensure that it is configured to use the fake root CA from testcert.go. + c := proxy.Client() + if httpsSite { + c = ts.Client() + } + + c.Transport.(*Transport).Proxy = ProxyURL(pu) + if _, err := c.Head(ts.URL); err != nil { + t.Error(err) + } + var got *Request + select { + case got = <-proxyCh: + case <-time.After(5 * time.Second): + t.Fatal("timeout connecting to http proxy") + } + c.Transport.(*Transport).CloseIdleConnections() + ts.Close() + proxy.Close() + if httpsSite { + // First message should be a CONNECT, asking for a socket to the real server, + if got.Method != "CONNECT" { + t.Errorf("Wrong method for secure proxying: %q", got.Method) + } + gotHost := got.URL.Host + pu, err := url.Parse(ts.URL) + if err != nil { + t.Fatal("Invalid site URL") + } + if wantHost := pu.Host; gotHost != wantHost { + t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost) + } + + // The next message on the channel should be from the site's server. + next := <-siteCh + if next.Method != "HEAD" { + t.Errorf("Wrong method at destination: %s", next.Method) + } + if nextURL := next.URL.String(); nextURL != "/" { + t.Errorf("Wrong URL at destination: %s", nextURL) + } + } else { + if got.Method != "HEAD" { + t.Errorf("Wrong method for destination: %q", got.Method) + } + gotURL := got.URL.String() + wantURL := ts.URL + "/" + if gotURL != wantURL { + t.Errorf("Got URL %q, want %q", gotURL, wantURL) + } + } + }) } } -- 2.48.1