]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Transport.DialTLS hook
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 8 Sep 2014 03:48:40 +0000 (20:48 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 8 Sep 2014 03:48:40 +0000 (20:48 -0700)
Per discussions out of https://golang.org/cl/128930043/
and golang-nuts threads and with agl.

Fixes #8522

LGTM=agl, adg
R=agl, c, adg
CC=c, golang-codereviews
https://golang.org/cl/137940043

src/pkg/net/http/transport.go
src/pkg/net/http/transport_test.go

index 7a229c1b7167b33680f8bdc81171fdcdcdf52bd8..527ed8bdd117490aa8451ba5e9943e2e2aaf3223 100644 (file)
@@ -43,8 +43,8 @@ var DefaultTransport RoundTripper = &Transport{
 // MaxIdleConnsPerHost.
 const DefaultMaxIdleConnsPerHost = 2
 
-// Transport is an implementation of RoundTripper that supports http,
-// https, and http proxies (for either http or https with CONNECT).
+// Transport is an implementation of RoundTripper that supports HTTP,
+// HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT).
 // Transport can also cache connections for future re-use.
 type Transport struct {
        idleMu      sync.Mutex
@@ -61,11 +61,22 @@ type Transport struct {
        // If Proxy is nil or returns a nil *URL, no proxy is used.
        Proxy func(*Request) (*url.URL, error)
 
-       // Dial specifies the dial function for creating TCP
-       // connections.
+       // Dial specifies the dial function for creating unencrypted
+       // TCP connections.
        // If Dial is nil, net.Dial is used.
        Dial func(network, addr string) (net.Conn, error)
 
+       // DialTLS specifies an optional dial function for creating
+       // TLS connections for non-proxied HTTPS requests.
+       //
+       // If DialTLS is nil, Dial and TLSClientConfig are used.
+       //
+       // If DialTLS is set, the Dial hook is not used for HTTPS
+       // requests and the TLSClientConfig and TLSHandshakeTimeout
+       // are ignored. The returned net.Conn is assumed to already be
+       // past the TLS handshake.
+       DialTLS func(network, addr string) (net.Conn, error)
+
        // TLSClientConfig specifies the TLS configuration to use with
        // tls.Client. If nil, the default configuration is used.
        TLSClientConfig *tls.Config
@@ -504,44 +515,56 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error
 }
 
 func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
-       conn, err := t.dial("tcp", cm.addr())
-       if err != nil {
-               if cm.proxyURL != nil {
-                       err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err)
-               }
-               return nil, err
-       }
-
-       pa := cm.proxyAuth()
-
        pconn := &persistConn{
                t:          t,
                cacheKey:   cm.key(),
-               conn:       conn,
                reqch:      make(chan requestAndChan, 1),
                writech:    make(chan writeRequest, 1),
                closech:    make(chan struct{}),
                writeErrCh: make(chan error, 1),
        }
+       tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil
+       if tlsDial {
+               var err error
+               pconn.conn, err = t.DialTLS("tcp", cm.addr())
+               if err != nil {
+                       return nil, err
+               }
+               if tc, ok := pconn.conn.(*tls.Conn); ok {
+                       cs := tc.ConnectionState()
+                       pconn.tlsState = &cs
+               }
+       } else {
+               conn, err := t.dial("tcp", cm.addr())
+               if err != nil {
+                       if cm.proxyURL != nil {
+                               err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err)
+                       }
+                       return nil, err
+               }
+               pconn.conn = conn
+       }
 
+       // Proxy setup.
        switch {
        case cm.proxyURL == nil:
-               // Do nothing.
+               // Do nothing. Not using a proxy.
        case cm.targetScheme == "http":
                pconn.isProxy = true
-               if pa != "" {
+               if pa := cm.proxyAuth(); pa != "" {
                        pconn.mutateHeaderFunc = func(h Header) {
                                h.Set("Proxy-Authorization", pa)
                        }
                }
        case cm.targetScheme == "https":
+               conn := pconn.conn
                connectReq := &Request{
                        Method: "CONNECT",
                        URL:    &url.URL{Opaque: cm.targetAddr},
                        Host:   cm.targetAddr,
                        Header: make(Header),
                }
-               if pa != "" {
+               if pa := cm.proxyAuth(); pa != "" {
                        connectReq.Header.Set("Proxy-Authorization", pa)
                }
                connectReq.Write(conn)
@@ -562,7 +585,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
                }
        }
 
-       if cm.targetScheme == "https" {
+       if cm.targetScheme == "https" && !tlsDial {
                // Initiate TLS and check remote host name against certificate.
                cfg := t.TLSClientConfig
                if cfg == nil || cfg.ServerName == "" {
@@ -575,7 +598,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
                                cfg = &clone
                        }
                }
-               plainConn := conn
+               plainConn := pconn.conn
                tlsConn := tls.Client(plainConn, cfg)
                errc := make(chan error, 2)
                var timer *time.Timer // for canceling TLS handshake
index b55d30ddf957fca5f00231f87758ba3ef503ce1b..3460d690e35262ecd171b1e42c8c1caa2c48a459 100644 (file)
@@ -2096,6 +2096,46 @@ func TestTransportClosesBodyOnError(t *testing.T) {
        }
 }
 
+func TestTransportDialTLS(t *testing.T) {
+       var mu sync.Mutex // guards following
+       var gotReq, didDial bool
+
+       ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               mu.Lock()
+               gotReq = true
+               mu.Unlock()
+       }))
+       defer ts.Close()
+       tr := &Transport{
+               DialTLS: func(netw, addr string) (net.Conn, error) {
+                       mu.Lock()
+                       didDial = true
+                       mu.Unlock()
+                       c, err := tls.Dial(netw, addr, &tls.Config{
+                               InsecureSkipVerify: true,
+                       })
+                       if err != nil {
+                               return nil, err
+                       }
+                       return c, c.Handshake()
+               },
+       }
+       defer tr.CloseIdleConnections()
+       client := &Client{Transport: tr}
+       res, err := client.Get(ts.URL)
+       if err != nil {
+               t.Fatal(err)
+       }
+       res.Body.Close()
+       mu.Lock()
+       if !gotReq {
+               t.Error("didn't get request")
+       }
+       if !didDial {
+               t.Error("didn't use dial hook")
+       }
+}
+
 func wantBody(res *http.Response, err error, want string) error {
        if err != nil {
                return err