if port := cm.proxyURL.Port(); !validPort(port) {
return cm, fmt.Errorf("invalid proxy URL port %q", port)
}
- switch cm.proxyURL.Scheme {
- case "http", "socks5":
- default:
- return cm, fmt.Errorf("invalid proxy URL scheme %q", cm.proxyURL.Scheme)
- }
}
}
return cm, err
}
}
+// The connect method and the transport can both specify a TLS
+// Host name. The transport's name takes precedence if present.
+func chooseTLSHost(cm connectMethod, t *Transport) string {
+ tlsHost := ""
+ if t.TLSClientConfig != nil {
+ tlsHost = t.TLSClientConfig.ServerName
+ }
+ if tlsHost == "" {
+ tlsHost = cm.tlsHost()
+ }
+ return tlsHost
+}
+
+// Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS
+// tunnel, this function establishes a nested TLS session inside the encrypted channel.
+// The remote endpoint's name may be overridden by TLSClientConfig.ServerName.
+func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) error {
+ // Initiate TLS and check remote host name against certificate.
+ cfg := cloneTLSConfig(pconn.t.TLSClientConfig)
+ if cfg.ServerName == "" {
+ cfg.ServerName = name
+ }
+ plainConn := pconn.conn
+ tlsConn := tls.Client(plainConn, cfg)
+ errc := make(chan error, 2)
+ var timer *time.Timer // for canceling TLS handshake
+ if d := pconn.t.TLSHandshakeTimeout; d != 0 {
+ timer = time.AfterFunc(d, func() {
+ errc <- tlsHandshakeTimeoutError{}
+ })
+ }
+ go func() {
+ if trace != nil && trace.TLSHandshakeStart != nil {
+ trace.TLSHandshakeStart()
+ }
+ err := tlsConn.Handshake()
+ if timer != nil {
+ timer.Stop()
+ }
+ errc <- err
+ }()
+ if err := <-errc; err != nil {
+ plainConn.Close()
+ if trace != nil && trace.TLSHandshakeDone != nil {
+ trace.TLSHandshakeDone(tls.ConnectionState{}, err)
+ }
+ return err
+ }
+ if !cfg.InsecureSkipVerify {
+ if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
+ plainConn.Close()
+ return err
+ }
+ }
+ cs := tlsConn.ConnectionState()
+ if trace != nil && trace.TLSHandshakeDone != nil {
+ trace.TLSHandshakeDone(cs, nil)
+ }
+ pconn.tlsState = &cs
+ pconn.conn = tlsConn
+ return nil
+}
+
func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistConn, error) {
pconn := &persistConn{
t: t,
writeLoopDone: make(chan struct{}),
}
trace := httptrace.ContextClientTrace(ctx)
- tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil
- if tlsDial {
+ wrapErr := func(err error) error {
+ if cm.proxyURL != nil {
+ // Return a typed error, per Issue 16997
+ return &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err}
+ }
+ return err
+ }
+ if cm.scheme() == "https" && t.DialTLS != nil {
var err error
pconn.conn, err = t.DialTLS("tcp", cm.addr())
if err != nil {
- return nil, err
+ return nil, wrapErr(err)
}
if pconn.conn == nil {
- return nil, errors.New("net/http: Transport.DialTLS returned (nil, nil)")
+ return nil, wrapErr(errors.New("net/http: Transport.DialTLS returned (nil, nil)"))
}
if tc, ok := pconn.conn.(*tls.Conn); ok {
// Handshake here, in case DialTLS didn't. TLSNextProto below
} else {
conn, err := t.dial(ctx, "tcp", cm.addr())
if err != nil {
- if cm.proxyURL != nil {
- // Return a typed error, per Issue 16997:
- err = &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err}
- }
- return nil, err
+ return nil, wrapErr(err)
}
pconn.conn = conn
+ if cm.scheme() == "https" {
+ var firstTLSHost string
+ if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil {
+ return nil, wrapErr(err)
+ }
+ if err = pconn.addTLS(firstTLSHost, trace); err != nil {
+ return nil, wrapErr(err)
+ }
+ }
}
// Proxy setup.
}
}
- if cm.targetScheme == "https" && !tlsDial {
- // Initiate TLS and check remote host name against certificate.
- cfg := cloneTLSConfig(t.TLSClientConfig)
- if cfg.ServerName == "" {
- cfg.ServerName = cm.tlsHost()
- }
- plainConn := pconn.conn
- tlsConn := tls.Client(plainConn, cfg)
- errc := make(chan error, 2)
- var timer *time.Timer // for canceling TLS handshake
- if d := t.TLSHandshakeTimeout; d != 0 {
- timer = time.AfterFunc(d, func() {
- errc <- tlsHandshakeTimeoutError{}
- })
- }
- go func() {
- if trace != nil && trace.TLSHandshakeStart != nil {
- trace.TLSHandshakeStart()
- }
- err := tlsConn.Handshake()
- if timer != nil {
- timer.Stop()
- }
- errc <- err
- }()
- if err := <-errc; err != nil {
- plainConn.Close()
- if trace != nil && trace.TLSHandshakeDone != nil {
- trace.TLSHandshakeDone(tls.ConnectionState{}, err)
- }
+ if cm.proxyURL != nil && cm.targetScheme == "https" {
+ if err := pconn.addTLS(cm.tlsHost(), trace); err != nil {
return nil, err
}
- if !cfg.InsecureSkipVerify {
- if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
- plainConn.Close()
- return nil, err
- }
- }
- cs := tlsConn.ConnectionState()
- if trace != nil && trace.TLSHandshakeDone != nil {
- trace.TLSHandshakeDone(cs, nil)
- }
- pconn.tlsState = &cs
- pconn.conn = tlsConn
}
if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" {
// http://proxy.com|http http to proxy, http to anywhere after that
// socks5://proxy.com|http|foo.com socks5 to proxy, then http to foo.com
// socks5://proxy.com|https|foo.com socks5 to proxy, then https to foo.com
-//
-// Note: no support to https to the proxy yet.
+// https://proxy.com|https|foo.com https to proxy, then CONNECT to foo.com
+// https://proxy.com|http https to proxy, http to anywhere after that
//
type connectMethod struct {
proxyURL *url.URL // nil for no proxy, else full proxy URL
targetScheme string // "http" or "https"
- targetAddr string // Not used if http proxy + http targetScheme (4th example in table)
+ // If proxyURL specifies an http or https proxy, and targetScheme is http (not https),
+ // then targetAddr is not included in the connect method key, because the socket can
+ // be reused for different targetAddr values.
+ targetAddr string
}
func (cm *connectMethod) key() connectMethodKey {
targetAddr := cm.targetAddr
if cm.proxyURL != nil {
proxyStr = cm.proxyURL.String()
- if strings.HasPrefix(cm.proxyURL.Scheme, "http") && cm.targetScheme == "http" {
+ if (cm.proxyURL.Scheme == "http" || cm.proxyURL.Scheme == "https") && cm.targetScheme == "http" {
targetAddr = ""
}
}
}
}
+// scheme returns the first hop scheme: http, https, or socks5
+func (cm *connectMethod) scheme() string {
+ if cm.proxyURL != nil {
+ return cm.proxyURL.Scheme
+ }
+ return cm.targetScheme
+}
+
// addr returns the first hop "host:port" to which we need to TCP connect.
func (cm *connectMethod) addr() string {
if cm.proxyURL != nil {
func TestSocks5Proxy(t *testing.T) {
defer afterTest(t)
ch := make(chan string, 1)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- ch <- "real server"
- }))
- defer ts.Close()
l := newLocalListener(t)
defer l.Close()
- go func() {
- defer close(ch)
+ defer close(ch)
+ proxy := func(t *testing.T) {
s, err := l.Accept()
if err != nil {
t.Errorf("socks5 proxy Accept(): %v", err)
case 4:
ipLen = 16
default:
- t.Fatalf("socks5 proxy second read: unexpected address type %v", buf[4])
+ t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
+ return
}
if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
t.Errorf("socks5 proxy address read: %v", err)
t.Errorf("socks5 proxy connect write: %v", err)
return
}
- done := make(chan struct{})
- srv := &Server{Handler: HandlerFunc(func(w ResponseWriter, r *Request) {
- done <- struct{}{}
- })}
- srv.Serve(&oneConnListener{conn: s})
- <-done
- srv.Shutdown(context.Background())
ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
- }()
- pu, err := url.Parse("socks5://" + l.Addr().String())
- if err != nil {
- t.Fatal(err)
- }
- c := ts.Client()
- c.Transport.(*Transport).Proxy = ProxyURL(pu)
- if _, err := c.Head(ts.URL); err != nil {
- t.Error(err)
- }
- var got string
- select {
- case got = <-ch:
- case <-time.After(5 * time.Second):
- t.Fatal("timeout connecting to socks5 proxy")
+ // Implement proxying.
+ targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
+ targetConn, err := net.Dial("tcp", targetHost)
+ if err != nil {
+ t.Errorf("net.Dial failed")
+ return
+ }
+ go io.Copy(targetConn, s)
+ io.Copy(s, targetConn) // Wait for the client to close the socket.
+ targetConn.Close()
}
- tsu, err := url.Parse(ts.URL)
+
+ pu, err := url.Parse("socks5://" + l.Addr().String())
if err != nil {
t.Fatal(err)
}
- want := "proxy for " + tsu.Host
- if got != want {
- t.Errorf("got %q, want %q", got, want)
+
+ sentinelHeader := "X-Sentinel"
+ sentinelValue := "12345"
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set(sentinelHeader, sentinelValue)
+ })
+ for _, useTLS := range []bool{false, true} {
+ t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
+ var ts *httptest.Server
+ if useTLS {
+ ts = httptest.NewTLSServer(h)
+ } else {
+ ts = httptest.NewServer(h)
+ }
+ go proxy(t)
+ c := ts.Client()
+ c.Transport.(*Transport).Proxy = ProxyURL(pu)
+ r, err := c.Head(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if r.Header.Get(sentinelHeader) != sentinelValue {
+ t.Errorf("Failed to retrieve sentinel value")
+ }
+ var got string
+ select {
+ case got = <-ch:
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout connecting to socks5 proxy")
+ }
+ ts.Close()
+ tsu, err := url.Parse(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := "proxy for " + tsu.Host
+ if got != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
+ })
}
}
func TestTransportProxy(t *testing.T) {
defer afterTest(t)
- ch := make(chan string, 1)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- ch <- "real server"
- }))
- defer ts.Close()
- proxy := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- ch <- "proxy for " + r.URL.String()
- }))
- defer proxy.Close()
+ testCases := []struct{ httpsSite, httpsProxy bool }{
+ {false, false},
+ {false, true},
+ {true, false},
+ {true, true},
+ }
+ for _, testCase := range testCases {
+ httpsSite := testCase.httpsSite
+ httpsProxy := testCase.httpsProxy
+ t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) {
+ siteCh := make(chan *Request, 1)
+ h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
+ siteCh <- r
+ })
+ proxyCh := make(chan *Request, 1)
+ h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
+ proxyCh <- r
+ // Implement an entire CONNECT proxy
+ if r.Method == "CONNECT" {
+ hijacker, ok := w.(Hijacker)
+ if !ok {
+ t.Errorf("hijack not allowed")
+ return
+ }
+ clientConn, _, err := hijacker.Hijack()
+ if err != nil {
+ t.Errorf("hijacking failed")
+ return
+ }
+ res := &Response{
+ StatusCode: StatusOK,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(Header),
+ }
+
+ log.Printf("Dialing %s", r.URL.Host)
+ targetConn, err := net.Dial("tcp", r.URL.Host)
+ if err != nil {
+ t.Errorf("net.Dial failed")
+ return
+ }
+
+ if err := res.Write(clientConn); err != nil {
+ t.Errorf("Writing 200 OK failed")
+ return
+ }
+
+ go io.Copy(targetConn, clientConn)
+ go func() {
+ io.Copy(clientConn, targetConn)
+ targetConn.Close()
+ }()
+ }
+ })
+ var ts *httptest.Server
+ if httpsSite {
+ ts = httptest.NewTLSServer(h1)
+ } else {
+ ts = httptest.NewServer(h1)
+ }
+ var proxy *httptest.Server
+ if httpsProxy {
+ proxy = httptest.NewTLSServer(h2)
+ } else {
+ proxy = httptest.NewServer(h2)
+ }
- pu, err := url.Parse(proxy.URL)
- if err != nil {
- t.Fatal(err)
- }
- c := ts.Client()
- c.Transport.(*Transport).Proxy = ProxyURL(pu)
- if _, err := c.Head(ts.URL); err != nil {
- t.Error(err)
- }
- var got string
- select {
- case got = <-ch:
- case <-time.After(5 * time.Second):
- t.Fatal("timeout connecting to http proxy")
- }
- want := "proxy for " + ts.URL + "/"
- if got != want {
- t.Errorf("got %q, want %q", got, want)
+ pu, err := url.Parse(proxy.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // If neither server is HTTPS or both are, then c may be derived from either.
+ // If only one server is HTTPS, c must be derived from that server in order
+ // to ensure that it is configured to use the fake root CA from testcert.go.
+ c := proxy.Client()
+ if httpsSite {
+ c = ts.Client()
+ }
+
+ c.Transport.(*Transport).Proxy = ProxyURL(pu)
+ if _, err := c.Head(ts.URL); err != nil {
+ t.Error(err)
+ }
+ var got *Request
+ select {
+ case got = <-proxyCh:
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout connecting to http proxy")
+ }
+ c.Transport.(*Transport).CloseIdleConnections()
+ ts.Close()
+ proxy.Close()
+ if httpsSite {
+ // First message should be a CONNECT, asking for a socket to the real server,
+ if got.Method != "CONNECT" {
+ t.Errorf("Wrong method for secure proxying: %q", got.Method)
+ }
+ gotHost := got.URL.Host
+ pu, err := url.Parse(ts.URL)
+ if err != nil {
+ t.Fatal("Invalid site URL")
+ }
+ if wantHost := pu.Host; gotHost != wantHost {
+ t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
+ }
+
+ // The next message on the channel should be from the site's server.
+ next := <-siteCh
+ if next.Method != "HEAD" {
+ t.Errorf("Wrong method at destination: %s", next.Method)
+ }
+ if nextURL := next.URL.String(); nextURL != "/" {
+ t.Errorf("Wrong URL at destination: %s", nextURL)
+ }
+ } else {
+ if got.Method != "HEAD" {
+ t.Errorf("Wrong method for destination: %q", got.Method)
+ }
+ gotURL := got.URL.String()
+ wantURL := ts.URL + "/"
+ if gotURL != wantURL {
+ t.Errorf("Got URL %q, want %q", gotURL, wantURL)
+ }
+ }
+ })
}
}