]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: HTTPS proxies support
authorBen Schwartz <bemasc@google.com>
Thu, 5 Oct 2017 18:07:55 +0000 (14:07 -0400)
committerTom Bergan <tombergan@google.com>
Fri, 13 Oct 2017 18:20:45 +0000 (18:20 +0000)
net/http already supports http proxies. This CL allows it to establish
a connection to the http proxy over https. See more at:
https://www.chromium.org/developers/design-documents/secure-web-proxy

Fixes golang/go#11332

Change-Id: If0e017df0e8f8c2c499a2ddcbbeb625c8fa2bb6b
Reviewed-on: https://go-review.googlesource.com/68550
Run-TryBot: Tom Bergan <tombergan@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Tom Bergan <tombergan@google.com>
src/net/http/transport.go
src/net/http/transport_test.go

index 5f2ace7b4b4fcf703f00c10a987720f35f0592ea..258b912b0a536bf4724bdccff9821db3917b27ce 100644 (file)
@@ -618,11 +618,6 @@ func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectM
                        if port := cm.proxyURL.Port(); !validPort(port) {
                                return cm, fmt.Errorf("invalid proxy URL port %q", port)
                        }
-                       switch cm.proxyURL.Scheme {
-                       case "http", "socks5":
-                       default:
-                               return cm, fmt.Errorf("invalid proxy URL scheme %q", cm.proxyURL.Scheme)
-                       }
                }
        }
        return cm, err
@@ -1021,6 +1016,69 @@ func (d oneConnDialer) Dial(network, addr string) (net.Conn, error) {
        }
 }
 
+// The connect method and the transport can both specify a TLS
+// Host name.  The transport's name takes precedence if present.
+func chooseTLSHost(cm connectMethod, t *Transport) string {
+       tlsHost := ""
+       if t.TLSClientConfig != nil {
+               tlsHost = t.TLSClientConfig.ServerName
+       }
+       if tlsHost == "" {
+               tlsHost = cm.tlsHost()
+       }
+       return tlsHost
+}
+
+// Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS
+// tunnel, this function establishes a nested TLS session inside the encrypted channel.
+// The remote endpoint's name may be overridden by TLSClientConfig.ServerName.
+func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) error {
+       // Initiate TLS and check remote host name against certificate.
+       cfg := cloneTLSConfig(pconn.t.TLSClientConfig)
+       if cfg.ServerName == "" {
+               cfg.ServerName = name
+       }
+       plainConn := pconn.conn
+       tlsConn := tls.Client(plainConn, cfg)
+       errc := make(chan error, 2)
+       var timer *time.Timer // for canceling TLS handshake
+       if d := pconn.t.TLSHandshakeTimeout; d != 0 {
+               timer = time.AfterFunc(d, func() {
+                       errc <- tlsHandshakeTimeoutError{}
+               })
+       }
+       go func() {
+               if trace != nil && trace.TLSHandshakeStart != nil {
+                       trace.TLSHandshakeStart()
+               }
+               err := tlsConn.Handshake()
+               if timer != nil {
+                       timer.Stop()
+               }
+               errc <- err
+       }()
+       if err := <-errc; err != nil {
+               plainConn.Close()
+               if trace != nil && trace.TLSHandshakeDone != nil {
+                       trace.TLSHandshakeDone(tls.ConnectionState{}, err)
+               }
+               return err
+       }
+       if !cfg.InsecureSkipVerify {
+               if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
+                       plainConn.Close()
+                       return err
+               }
+       }
+       cs := tlsConn.ConnectionState()
+       if trace != nil && trace.TLSHandshakeDone != nil {
+               trace.TLSHandshakeDone(cs, nil)
+       }
+       pconn.tlsState = &cs
+       pconn.conn = tlsConn
+       return nil
+}
+
 func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistConn, error) {
        pconn := &persistConn{
                t:             t,
@@ -1032,15 +1090,21 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
                writeLoopDone: make(chan struct{}),
        }
        trace := httptrace.ContextClientTrace(ctx)
-       tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil
-       if tlsDial {
+       wrapErr := func(err error) error {
+               if cm.proxyURL != nil {
+                       // Return a typed error, per Issue 16997
+                       return &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err}
+               }
+               return err
+       }
+       if cm.scheme() == "https" && t.DialTLS != nil {
                var err error
                pconn.conn, err = t.DialTLS("tcp", cm.addr())
                if err != nil {
-                       return nil, err
+                       return nil, wrapErr(err)
                }
                if pconn.conn == nil {
-                       return nil, errors.New("net/http: Transport.DialTLS returned (nil, nil)")
+                       return nil, wrapErr(errors.New("net/http: Transport.DialTLS returned (nil, nil)"))
                }
                if tc, ok := pconn.conn.(*tls.Conn); ok {
                        // Handshake here, in case DialTLS didn't. TLSNextProto below
@@ -1064,13 +1128,18 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
        } else {
                conn, err := t.dial(ctx, "tcp", cm.addr())
                if err != nil {
-                       if cm.proxyURL != nil {
-                               // Return a typed error, per Issue 16997:
-                               err = &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err}
-                       }
-                       return nil, err
+                       return nil, wrapErr(err)
                }
                pconn.conn = conn
+               if cm.scheme() == "https" {
+                       var firstTLSHost string
+                       if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil {
+                               return nil, wrapErr(err)
+                       }
+                       if err = pconn.addTLS(firstTLSHost, trace); err != nil {
+                               return nil, wrapErr(err)
+                       }
+               }
        }
 
        // Proxy setup.
@@ -1134,50 +1203,10 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
                }
        }
 
-       if cm.targetScheme == "https" && !tlsDial {
-               // Initiate TLS and check remote host name against certificate.
-               cfg := cloneTLSConfig(t.TLSClientConfig)
-               if cfg.ServerName == "" {
-                       cfg.ServerName = cm.tlsHost()
-               }
-               plainConn := pconn.conn
-               tlsConn := tls.Client(plainConn, cfg)
-               errc := make(chan error, 2)
-               var timer *time.Timer // for canceling TLS handshake
-               if d := t.TLSHandshakeTimeout; d != 0 {
-                       timer = time.AfterFunc(d, func() {
-                               errc <- tlsHandshakeTimeoutError{}
-                       })
-               }
-               go func() {
-                       if trace != nil && trace.TLSHandshakeStart != nil {
-                               trace.TLSHandshakeStart()
-                       }
-                       err := tlsConn.Handshake()
-                       if timer != nil {
-                               timer.Stop()
-                       }
-                       errc <- err
-               }()
-               if err := <-errc; err != nil {
-                       plainConn.Close()
-                       if trace != nil && trace.TLSHandshakeDone != nil {
-                               trace.TLSHandshakeDone(tls.ConnectionState{}, err)
-                       }
+       if cm.proxyURL != nil && cm.targetScheme == "https" {
+               if err := pconn.addTLS(cm.tlsHost(), trace); err != nil {
                        return nil, err
                }
-               if !cfg.InsecureSkipVerify {
-                       if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
-                               plainConn.Close()
-                               return nil, err
-                       }
-               }
-               cs := tlsConn.ConnectionState()
-               if trace != nil && trace.TLSHandshakeDone != nil {
-                       trace.TLSHandshakeDone(cs, nil)
-               }
-               pconn.tlsState = &cs
-               pconn.conn = tlsConn
        }
 
        if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" {
@@ -1279,13 +1308,16 @@ func useProxy(addr string) bool {
 // http://proxy.com|http             http to proxy, http to anywhere after that
 // socks5://proxy.com|http|foo.com   socks5 to proxy, then http to foo.com
 // socks5://proxy.com|https|foo.com  socks5 to proxy, then https to foo.com
-//
-// Note: no support to https to the proxy yet.
+// https://proxy.com|https|foo.com   https to proxy, then CONNECT to foo.com
+// https://proxy.com|http            https to proxy, http to anywhere after that
 //
 type connectMethod struct {
        proxyURL     *url.URL // nil for no proxy, else full proxy URL
        targetScheme string   // "http" or "https"
-       targetAddr   string   // Not used if http proxy + http targetScheme (4th example in table)
+       // If proxyURL specifies an http or https proxy, and targetScheme is http (not https),
+       // then targetAddr is not included in the connect method key, because the socket can
+       // be reused for different targetAddr values.
+       targetAddr string
 }
 
 func (cm *connectMethod) key() connectMethodKey {
@@ -1293,7 +1325,7 @@ func (cm *connectMethod) key() connectMethodKey {
        targetAddr := cm.targetAddr
        if cm.proxyURL != nil {
                proxyStr = cm.proxyURL.String()
-               if strings.HasPrefix(cm.proxyURL.Scheme, "http") && cm.targetScheme == "http" {
+               if (cm.proxyURL.Scheme == "http" || cm.proxyURL.Scheme == "https") && cm.targetScheme == "http" {
                        targetAddr = ""
                }
        }
@@ -1304,6 +1336,14 @@ func (cm *connectMethod) key() connectMethodKey {
        }
 }
 
+// scheme returns the first hop scheme: http, https, or socks5
+func (cm *connectMethod) scheme() string {
+       if cm.proxyURL != nil {
+               return cm.proxyURL.Scheme
+       }
+       return cm.targetScheme
+}
+
 // addr returns the first hop "host:port" to which we need to TCP connect.
 func (cm *connectMethod) addr() string {
        if cm.proxyURL != nil {
index 39b5cd358f810a006007ba5f4d86200cccbe9fd1..ad63cca5fe9274cb877e0129f2ef3aec0536a7d2 100644 (file)
@@ -961,14 +961,10 @@ func TestTransportExpect100Continue(t *testing.T) {
 func TestSocks5Proxy(t *testing.T) {
        defer afterTest(t)
        ch := make(chan string, 1)
-       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
-               ch <- "real server"
-       }))
-       defer ts.Close()
        l := newLocalListener(t)
        defer l.Close()
-       go func() {
-               defer close(ch)
+       defer close(ch)
+       proxy := func(t *testing.T) {
                s, err := l.Accept()
                if err != nil {
                        t.Errorf("socks5 proxy Accept(): %v", err)
@@ -1003,7 +999,8 @@ func TestSocks5Proxy(t *testing.T) {
                case 4:
                        ipLen = 16
                default:
-                       t.Fatalf("socks5 proxy second read: unexpected address type %v", buf[4])
+                       t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
+                       return
                }
                if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
                        t.Errorf("socks5 proxy address read: %v", err)
@@ -1016,71 +1013,197 @@ func TestSocks5Proxy(t *testing.T) {
                        t.Errorf("socks5 proxy connect write: %v", err)
                        return
                }
-               done := make(chan struct{})
-               srv := &Server{Handler: HandlerFunc(func(w ResponseWriter, r *Request) {
-                       done <- struct{}{}
-               })}
-               srv.Serve(&oneConnListener{conn: s})
-               <-done
-               srv.Shutdown(context.Background())
                ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
-       }()
 
-       pu, err := url.Parse("socks5://" + l.Addr().String())
-       if err != nil {
-               t.Fatal(err)
-       }
-       c := ts.Client()
-       c.Transport.(*Transport).Proxy = ProxyURL(pu)
-       if _, err := c.Head(ts.URL); err != nil {
-               t.Error(err)
-       }
-       var got string
-       select {
-       case got = <-ch:
-       case <-time.After(5 * time.Second):
-               t.Fatal("timeout connecting to socks5 proxy")
+               // Implement proxying.
+               targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
+               targetConn, err := net.Dial("tcp", targetHost)
+               if err != nil {
+                       t.Errorf("net.Dial failed")
+                       return
+               }
+               go io.Copy(targetConn, s)
+               io.Copy(s, targetConn) // Wait for the client to close the socket.
+               targetConn.Close()
        }
-       tsu, err := url.Parse(ts.URL)
+
+       pu, err := url.Parse("socks5://" + l.Addr().String())
        if err != nil {
                t.Fatal(err)
        }
-       want := "proxy for " + tsu.Host
-       if got != want {
-               t.Errorf("got %q, want %q", got, want)
+
+       sentinelHeader := "X-Sentinel"
+       sentinelValue := "12345"
+       h := HandlerFunc(func(w ResponseWriter, r *Request) {
+               w.Header().Set(sentinelHeader, sentinelValue)
+       })
+       for _, useTLS := range []bool{false, true} {
+               t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
+                       var ts *httptest.Server
+                       if useTLS {
+                               ts = httptest.NewTLSServer(h)
+                       } else {
+                               ts = httptest.NewServer(h)
+                       }
+                       go proxy(t)
+                       c := ts.Client()
+                       c.Transport.(*Transport).Proxy = ProxyURL(pu)
+                       r, err := c.Head(ts.URL)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       if r.Header.Get(sentinelHeader) != sentinelValue {
+                               t.Errorf("Failed to retrieve sentinel value")
+                       }
+                       var got string
+                       select {
+                       case got = <-ch:
+                       case <-time.After(5 * time.Second):
+                               t.Fatal("timeout connecting to socks5 proxy")
+                       }
+                       ts.Close()
+                       tsu, err := url.Parse(ts.URL)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       want := "proxy for " + tsu.Host
+                       if got != want {
+                               t.Errorf("got %q, want %q", got, want)
+                       }
+               })
        }
 }
 
 func TestTransportProxy(t *testing.T) {
        defer afterTest(t)
-       ch := make(chan string, 1)
-       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
-               ch <- "real server"
-       }))
-       defer ts.Close()
-       proxy := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
-               ch <- "proxy for " + r.URL.String()
-       }))
-       defer proxy.Close()
+       testCases := []struct{ httpsSite, httpsProxy bool }{
+               {false, false},
+               {false, true},
+               {true, false},
+               {true, true},
+       }
+       for _, testCase := range testCases {
+               httpsSite := testCase.httpsSite
+               httpsProxy := testCase.httpsProxy
+               t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) {
+                       siteCh := make(chan *Request, 1)
+                       h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
+                               siteCh <- r
+                       })
+                       proxyCh := make(chan *Request, 1)
+                       h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
+                               proxyCh <- r
+                               // Implement an entire CONNECT proxy
+                               if r.Method == "CONNECT" {
+                                       hijacker, ok := w.(Hijacker)
+                                       if !ok {
+                                               t.Errorf("hijack not allowed")
+                                               return
+                                       }
+                                       clientConn, _, err := hijacker.Hijack()
+                                       if err != nil {
+                                               t.Errorf("hijacking failed")
+                                               return
+                                       }
+                                       res := &Response{
+                                               StatusCode: StatusOK,
+                                               Proto:      "HTTP/1.1",
+                                               ProtoMajor: 1,
+                                               ProtoMinor: 1,
+                                               Header:     make(Header),
+                                       }
+
+                                       log.Printf("Dialing %s", r.URL.Host)
+                                       targetConn, err := net.Dial("tcp", r.URL.Host)
+                                       if err != nil {
+                                               t.Errorf("net.Dial failed")
+                                               return
+                                       }
+
+                                       if err := res.Write(clientConn); err != nil {
+                                               t.Errorf("Writing 200 OK failed")
+                                               return
+                                       }
+
+                                       go io.Copy(targetConn, clientConn)
+                                       go func() {
+                                               io.Copy(clientConn, targetConn)
+                                               targetConn.Close()
+                                       }()
+                               }
+                       })
+                       var ts *httptest.Server
+                       if httpsSite {
+                               ts = httptest.NewTLSServer(h1)
+                       } else {
+                               ts = httptest.NewServer(h1)
+                       }
+                       var proxy *httptest.Server
+                       if httpsProxy {
+                               proxy = httptest.NewTLSServer(h2)
+                       } else {
+                               proxy = httptest.NewServer(h2)
+                       }
 
-       pu, err := url.Parse(proxy.URL)
-       if err != nil {
-               t.Fatal(err)
-       }
-       c := ts.Client()
-       c.Transport.(*Transport).Proxy = ProxyURL(pu)
-       if _, err := c.Head(ts.URL); err != nil {
-               t.Error(err)
-       }
-       var got string
-       select {
-       case got = <-ch:
-       case <-time.After(5 * time.Second):
-               t.Fatal("timeout connecting to http proxy")
-       }
-       want := "proxy for " + ts.URL + "/"
-       if got != want {
-               t.Errorf("got %q, want %q", got, want)
+                       pu, err := url.Parse(proxy.URL)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+
+                       // If neither server is HTTPS or both are, then c may be derived from either.
+                       // If only one server is HTTPS, c must be derived from that server in order
+                       // to ensure that it is configured to use the fake root CA from testcert.go.
+                       c := proxy.Client()
+                       if httpsSite {
+                               c = ts.Client()
+                       }
+
+                       c.Transport.(*Transport).Proxy = ProxyURL(pu)
+                       if _, err := c.Head(ts.URL); err != nil {
+                               t.Error(err)
+                       }
+                       var got *Request
+                       select {
+                       case got = <-proxyCh:
+                       case <-time.After(5 * time.Second):
+                               t.Fatal("timeout connecting to http proxy")
+                       }
+                       c.Transport.(*Transport).CloseIdleConnections()
+                       ts.Close()
+                       proxy.Close()
+                       if httpsSite {
+                               // First message should be a CONNECT, asking for a socket to the real server,
+                               if got.Method != "CONNECT" {
+                                       t.Errorf("Wrong method for secure proxying: %q", got.Method)
+                               }
+                               gotHost := got.URL.Host
+                               pu, err := url.Parse(ts.URL)
+                               if err != nil {
+                                       t.Fatal("Invalid site URL")
+                               }
+                               if wantHost := pu.Host; gotHost != wantHost {
+                                       t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
+                               }
+
+                               // The next message on the channel should be from the site's server.
+                               next := <-siteCh
+                               if next.Method != "HEAD" {
+                                       t.Errorf("Wrong method at destination: %s", next.Method)
+                               }
+                               if nextURL := next.URL.String(); nextURL != "/" {
+                                       t.Errorf("Wrong URL at destination: %s", nextURL)
+                               }
+                       } else {
+                               if got.Method != "HEAD" {
+                                       t.Errorf("Wrong method for destination: %q", got.Method)
+                               }
+                               gotURL := got.URL.String()
+                               wantURL := ts.URL + "/"
+                               if gotURL != wantURL {
+                                       t.Errorf("Got URL %q, want %q", gotURL, wantURL)
+                               }
+                       }
+               })
        }
 }