]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add DialTLSContext hook to Transport
authorGabriel Rosenhouse <rosenhouse@gmail.com>
Mon, 11 Nov 2019 19:50:55 +0000 (19:50 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 11 Nov 2019 20:17:03 +0000 (20:17 +0000)
Fixes #21526

Change-Id: I2f8215cd671641cddfa8499f8a8c0130db93dbc6
Reviewed-on: https://go-review.googlesource.com/c/go/+/61291
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>

src/net/http/transport.go
src/net/http/transport_test.go

index 6fade795ab64f13601d37ecb3cd26b024507ca27..bdc767a2363fa0d9c2be19468f5805e8549503fd 100644 (file)
@@ -142,15 +142,24 @@ type Transport struct {
        // If both are set, DialContext takes priority.
        Dial func(network, addr string) (net.Conn, error)
 
-       // DialTLS specifies an optional dial function for creating
+       // DialTLSContext specifies an optional dial function for creating
        // TLS connections for non-proxied HTTPS requests.
        //
-       // If DialTLS is nil, Dial and TLSClientConfig are used.
+       // If DialTLSContext is nil (and the deprecated DialTLS below is also nil),
+       // DialContext and TLSClientConfig are used.
        //
-       // If DialTLS is set, the Dial hook is not used for HTTPS
+       // If DialTLSContext is set, the Dial and DialContext hooks are not used for HTTPS
        // requests and the TLSClientConfig and TLSHandshakeTimeout
        // are ignored. The returned net.Conn is assumed to already be
        // past the TLS handshake.
+       DialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
+
+       // DialTLS specifies an optional dial function for creating
+       // TLS connections for non-proxied HTTPS requests.
+       //
+       // Deprecated: Use DialTLSContext instead, which allows the transport
+       // to cancel dials as soon as they are no longer needed.
+       // If both are set, DialTLSContext takes priority.
        DialTLS func(network, addr string) (net.Conn, error)
 
        // TLSClientConfig specifies the TLS configuration to use with
@@ -286,6 +295,7 @@ func (t *Transport) Clone() *Transport {
                DialContext:            t.DialContext,
                Dial:                   t.Dial,
                DialTLS:                t.DialTLS,
+               DialTLSContext:         t.DialTLSContext,
                TLSHandshakeTimeout:    t.TLSHandshakeTimeout,
                DisableKeepAlives:      t.DisableKeepAlives,
                DisableCompression:     t.DisableCompression,
@@ -324,6 +334,10 @@ type h2Transport interface {
        CloseIdleConnections()
 }
 
+func (t *Transport) hasCustomTLSDialer() bool {
+       return t.DialTLS != nil || t.DialTLSContext != nil
+}
+
 // onceSetNextProtoDefaults initializes TLSNextProto.
 // It must be called via t.nextProtoOnce.Do.
 func (t *Transport) onceSetNextProtoDefaults() {
@@ -352,7 +366,7 @@ func (t *Transport) onceSetNextProtoDefaults() {
                // Transport.
                return
        }
-       if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialTLS != nil || t.DialContext != nil) {
+       if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialContext != nil || t.hasCustomTLSDialer()) {
                // Be conservative and don't automatically enable
                // http2 if they've specified a custom TLS config or
                // custom dialers. Let them opt-in themselves via
@@ -1185,6 +1199,18 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) {
        }
 }
 
+func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) {
+       if t.DialTLSContext != nil {
+               conn, err = t.DialTLSContext(ctx, network, addr)
+       } else {
+               conn, err = t.DialTLS(network, addr)
+       }
+       if conn == nil && err == nil {
+               err = errors.New("net/http: Transport.DialTLS or DialTLSContext returned (nil, nil)")
+       }
+       return
+}
+
 // getConn dials and creates a new persistConn to the target as
 // specified in the connectMethod. This includes doing a proxy CONNECT
 // and/or setting up TLS.  If this doesn't return an error, the persistConn
@@ -1435,15 +1461,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
                }
                return err
        }
-       if cm.scheme() == "https" && t.DialTLS != nil {
+       if cm.scheme() == "https" && t.hasCustomTLSDialer() {
                var err error
-               pconn.conn, err = t.DialTLS("tcp", cm.addr())
+               pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr())
                if err != nil {
                        return nil, wrapErr(err)
                }
-               if pconn.conn == nil {
-                       return nil, wrapErr(errors.New("net/http: Transport.DialTLS returned (nil, nil)"))
-               }
                if tc, ok := pconn.conn.(*tls.Conn); ok {
                        // Handshake here, in case DialTLS didn't. TLSNextProto below
                        // depends on it for knowing the connection state.
index 692868094c11234e3a2228aa61409f1588796a1e..27be26cedc5b5d948d9939f479522632850f542c 100644 (file)
@@ -3506,6 +3506,90 @@ func TestTransportDialTLS(t *testing.T) {
        }
 }
 
+func TestTransportDialContext(t *testing.T) {
+       setParallel(t)
+       defer afterTest(t)
+       var mu sync.Mutex // guards following
+       var gotReq bool
+       var receivedContext context.Context
+
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               mu.Lock()
+               gotReq = true
+               mu.Unlock()
+       }))
+       defer ts.Close()
+       c := ts.Client()
+       c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
+               mu.Lock()
+               receivedContext = ctx
+               mu.Unlock()
+               return net.Dial(netw, addr)
+       }
+
+       req, err := NewRequest("GET", ts.URL, nil)
+       if err != nil {
+               t.Fatal(err)
+       }
+       ctx := context.WithValue(context.Background(), "some-key", "some-value")
+       res, err := c.Do(req.WithContext(ctx))
+       if err != nil {
+               t.Fatal(err)
+       }
+       res.Body.Close()
+       mu.Lock()
+       if !gotReq {
+               t.Error("didn't get request")
+       }
+       if receivedContext != ctx {
+               t.Error("didn't receive correct context")
+       }
+}
+
+func TestTransportDialTLSContext(t *testing.T) {
+       setParallel(t)
+       defer afterTest(t)
+       var mu sync.Mutex // guards following
+       var gotReq bool
+       var receivedContext context.Context
+
+       ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               mu.Lock()
+               gotReq = true
+               mu.Unlock()
+       }))
+       defer ts.Close()
+       c := ts.Client()
+       c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
+               mu.Lock()
+               receivedContext = ctx
+               mu.Unlock()
+               c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
+               if err != nil {
+                       return nil, err
+               }
+               return c, c.Handshake()
+       }
+
+       req, err := NewRequest("GET", ts.URL, nil)
+       if err != nil {
+               t.Fatal(err)
+       }
+       ctx := context.WithValue(context.Background(), "some-key", "some-value")
+       res, err := c.Do(req.WithContext(ctx))
+       if err != nil {
+               t.Fatal(err)
+       }
+       res.Body.Close()
+       mu.Lock()
+       if !gotReq {
+               t.Error("didn't get request")
+       }
+       if receivedContext != ctx {
+               t.Error("didn't receive correct context")
+       }
+}
+
 // Test for issue 8755
 // Ensure that if a proxy returns an error, it is exposed by RoundTrip
 func TestRoundTripReturnsProxyError(t *testing.T) {
@@ -5577,6 +5661,7 @@ func TestTransportClone(t *testing.T) {
                DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
                Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
                DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
+               DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
                TLSClientConfig:        new(tls.Config),
                TLSHandshakeTimeout:    time.Second,
                DisableKeepAlives:      true,