From: Brad Fitzpatrick Date: Mon, 8 Sep 2014 03:48:40 +0000 (-0700) Subject: net/http: add Transport.DialTLS hook X-Git-Tag: go1.4beta1~491 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=ae47e044a8c2f1a4cf21ba2b02eba7304c2d157d;p=gostls13.git net/http: add Transport.DialTLS hook Per discussions out of https://golang.org/cl/128930043/ and golang-nuts threads and with agl. Fixes #8522 LGTM=agl, adg R=agl, c, adg CC=c, golang-codereviews https://golang.org/cl/137940043 --- diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go index 7a229c1b71..527ed8bdd1 100644 --- a/src/pkg/net/http/transport.go +++ b/src/pkg/net/http/transport.go @@ -43,8 +43,8 @@ var DefaultTransport RoundTripper = &Transport{ // MaxIdleConnsPerHost. const DefaultMaxIdleConnsPerHost = 2 -// Transport is an implementation of RoundTripper that supports http, -// https, and http proxies (for either http or https with CONNECT). +// Transport is an implementation of RoundTripper that supports HTTP, +// HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT). // Transport can also cache connections for future re-use. type Transport struct { idleMu sync.Mutex @@ -61,11 +61,22 @@ type Transport struct { // If Proxy is nil or returns a nil *URL, no proxy is used. Proxy func(*Request) (*url.URL, error) - // Dial specifies the dial function for creating TCP - // connections. + // Dial specifies the dial function for creating unencrypted + // TCP connections. // If Dial is nil, net.Dial is used. Dial func(network, addr string) (net.Conn, error) + // DialTLS specifies an optional dial function for creating + // TLS connections for non-proxied HTTPS requests. + // + // If DialTLS is nil, Dial and TLSClientConfig are used. + // + // If DialTLS is set, the Dial hook is not used for HTTPS + // requests and the TLSClientConfig and TLSHandshakeTimeout + // are ignored. The returned net.Conn is assumed to already be + // past the TLS handshake. + DialTLS func(network, addr string) (net.Conn, error) + // TLSClientConfig specifies the TLS configuration to use with // tls.Client. If nil, the default configuration is used. TLSClientConfig *tls.Config @@ -504,44 +515,56 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error } func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { - conn, err := t.dial("tcp", cm.addr()) - if err != nil { - if cm.proxyURL != nil { - err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err) - } - return nil, err - } - - pa := cm.proxyAuth() - pconn := &persistConn{ t: t, cacheKey: cm.key(), - conn: conn, reqch: make(chan requestAndChan, 1), writech: make(chan writeRequest, 1), closech: make(chan struct{}), writeErrCh: make(chan error, 1), } + tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil + if tlsDial { + var err error + pconn.conn, err = t.DialTLS("tcp", cm.addr()) + if err != nil { + return nil, err + } + if tc, ok := pconn.conn.(*tls.Conn); ok { + cs := tc.ConnectionState() + pconn.tlsState = &cs + } + } else { + conn, err := t.dial("tcp", cm.addr()) + if err != nil { + if cm.proxyURL != nil { + err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err) + } + return nil, err + } + pconn.conn = conn + } + // Proxy setup. switch { case cm.proxyURL == nil: - // Do nothing. + // Do nothing. Not using a proxy. case cm.targetScheme == "http": pconn.isProxy = true - if pa != "" { + if pa := cm.proxyAuth(); pa != "" { pconn.mutateHeaderFunc = func(h Header) { h.Set("Proxy-Authorization", pa) } } case cm.targetScheme == "https": + conn := pconn.conn connectReq := &Request{ Method: "CONNECT", URL: &url.URL{Opaque: cm.targetAddr}, Host: cm.targetAddr, Header: make(Header), } - if pa != "" { + if pa := cm.proxyAuth(); pa != "" { connectReq.Header.Set("Proxy-Authorization", pa) } connectReq.Write(conn) @@ -562,7 +585,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { } } - if cm.targetScheme == "https" { + if cm.targetScheme == "https" && !tlsDial { // Initiate TLS and check remote host name against certificate. cfg := t.TLSClientConfig if cfg == nil || cfg.ServerName == "" { @@ -575,7 +598,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { cfg = &clone } } - plainConn := conn + plainConn := pconn.conn tlsConn := tls.Client(plainConn, cfg) errc := make(chan error, 2) var timer *time.Timer // for canceling TLS handshake diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go index b55d30ddf9..3460d690e3 100644 --- a/src/pkg/net/http/transport_test.go +++ b/src/pkg/net/http/transport_test.go @@ -2096,6 +2096,46 @@ func TestTransportClosesBodyOnError(t *testing.T) { } } +func TestTransportDialTLS(t *testing.T) { + var mu sync.Mutex // guards following + var gotReq, didDial bool + + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + mu.Lock() + gotReq = true + mu.Unlock() + })) + defer ts.Close() + tr := &Transport{ + DialTLS: func(netw, addr string) (net.Conn, error) { + mu.Lock() + didDial = true + mu.Unlock() + c, err := tls.Dial(netw, addr, &tls.Config{ + InsecureSkipVerify: true, + }) + if err != nil { + return nil, err + } + return c, c.Handshake() + }, + } + defer tr.CloseIdleConnections() + client := &Client{Transport: tr} + res, err := client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + mu.Lock() + if !gotReq { + t.Error("didn't get request") + } + if !didDial { + t.Error("didn't use dial hook") + } +} + func wantBody(res *http.Response, err error, want string) error { if err != nil { return err