From: Gabriel Rosenhouse Date: Mon, 11 Nov 2019 19:50:55 +0000 (+0000) Subject: net/http: add DialTLSContext hook to Transport X-Git-Tag: go1.14beta1~256 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=eb55a0c86438b815e244c1f00dea273b3122592a;p=gostls13.git net/http: add DialTLSContext hook to Transport Fixes #21526 Change-Id: I2f8215cd671641cddfa8499f8a8c0130db93dbc6 Reviewed-on: https://go-review.googlesource.com/c/go/+/61291 Reviewed-by: Brad Fitzpatrick Run-TryBot: Brad Fitzpatrick --- diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 6fade795ab..bdc767a236 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -142,15 +142,24 @@ type Transport struct { // If both are set, DialContext takes priority. Dial func(network, addr string) (net.Conn, error) - // DialTLS specifies an optional dial function for creating + // DialTLSContext specifies an optional dial function for creating // TLS connections for non-proxied HTTPS requests. // - // If DialTLS is nil, Dial and TLSClientConfig are used. + // If DialTLSContext is nil (and the deprecated DialTLS below is also nil), + // DialContext and TLSClientConfig are used. // - // If DialTLS is set, the Dial hook is not used for HTTPS + // If DialTLSContext is set, the Dial and DialContext hooks are not used for HTTPS // requests and the TLSClientConfig and TLSHandshakeTimeout // are ignored. The returned net.Conn is assumed to already be // past the TLS handshake. + DialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + + // DialTLS specifies an optional dial function for creating + // TLS connections for non-proxied HTTPS requests. + // + // Deprecated: Use DialTLSContext instead, which allows the transport + // to cancel dials as soon as they are no longer needed. + // If both are set, DialTLSContext takes priority. DialTLS func(network, addr string) (net.Conn, error) // TLSClientConfig specifies the TLS configuration to use with @@ -286,6 +295,7 @@ func (t *Transport) Clone() *Transport { DialContext: t.DialContext, Dial: t.Dial, DialTLS: t.DialTLS, + DialTLSContext: t.DialTLSContext, TLSHandshakeTimeout: t.TLSHandshakeTimeout, DisableKeepAlives: t.DisableKeepAlives, DisableCompression: t.DisableCompression, @@ -324,6 +334,10 @@ type h2Transport interface { CloseIdleConnections() } +func (t *Transport) hasCustomTLSDialer() bool { + return t.DialTLS != nil || t.DialTLSContext != nil +} + // onceSetNextProtoDefaults initializes TLSNextProto. // It must be called via t.nextProtoOnce.Do. func (t *Transport) onceSetNextProtoDefaults() { @@ -352,7 +366,7 @@ func (t *Transport) onceSetNextProtoDefaults() { // Transport. return } - if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialTLS != nil || t.DialContext != nil) { + if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialContext != nil || t.hasCustomTLSDialer()) { // Be conservative and don't automatically enable // http2 if they've specified a custom TLS config or // custom dialers. Let them opt-in themselves via @@ -1185,6 +1199,18 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) { } } +func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) { + if t.DialTLSContext != nil { + conn, err = t.DialTLSContext(ctx, network, addr) + } else { + conn, err = t.DialTLS(network, addr) + } + if conn == nil && err == nil { + err = errors.New("net/http: Transport.DialTLS or DialTLSContext returned (nil, nil)") + } + return +} + // getConn dials and creates a new persistConn to the target as // specified in the connectMethod. This includes doing a proxy CONNECT // and/or setting up TLS. If this doesn't return an error, the persistConn @@ -1435,15 +1461,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } return err } - if cm.scheme() == "https" && t.DialTLS != nil { + if cm.scheme() == "https" && t.hasCustomTLSDialer() { var err error - pconn.conn, err = t.DialTLS("tcp", cm.addr()) + pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr()) if err != nil { return nil, wrapErr(err) } - if pconn.conn == nil { - return nil, wrapErr(errors.New("net/http: Transport.DialTLS returned (nil, nil)")) - } if tc, ok := pconn.conn.(*tls.Conn); ok { // Handshake here, in case DialTLS didn't. TLSNextProto below // depends on it for knowing the connection state. diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 692868094c..27be26cedc 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -3506,6 +3506,90 @@ func TestTransportDialTLS(t *testing.T) { } } +func TestTransportDialContext(t *testing.T) { + setParallel(t) + defer afterTest(t) + var mu sync.Mutex // guards following + var gotReq bool + var receivedContext context.Context + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + mu.Lock() + gotReq = true + mu.Unlock() + })) + defer ts.Close() + c := ts.Client() + c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { + mu.Lock() + receivedContext = ctx + mu.Unlock() + return net.Dial(netw, addr) + } + + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + ctx := context.WithValue(context.Background(), "some-key", "some-value") + res, err := c.Do(req.WithContext(ctx)) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + mu.Lock() + if !gotReq { + t.Error("didn't get request") + } + if receivedContext != ctx { + t.Error("didn't receive correct context") + } +} + +func TestTransportDialTLSContext(t *testing.T) { + setParallel(t) + defer afterTest(t) + var mu sync.Mutex // guards following + var gotReq bool + var receivedContext context.Context + + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + mu.Lock() + gotReq = true + mu.Unlock() + })) + defer ts.Close() + c := ts.Client() + c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { + mu.Lock() + receivedContext = ctx + mu.Unlock() + c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) + if err != nil { + return nil, err + } + return c, c.Handshake() + } + + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + ctx := context.WithValue(context.Background(), "some-key", "some-value") + res, err := c.Do(req.WithContext(ctx)) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + mu.Lock() + if !gotReq { + t.Error("didn't get request") + } + if receivedContext != ctx { + t.Error("didn't receive correct context") + } +} + // Test for issue 8755 // Ensure that if a proxy returns an error, it is exposed by RoundTrip func TestRoundTripReturnsProxyError(t *testing.T) { @@ -5577,6 +5661,7 @@ func TestTransportClone(t *testing.T) { DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, Dial: func(network, addr string) (net.Conn, error) { panic("") }, DialTLS: func(network, addr string) (net.Conn, error) { panic("") }, + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, TLSClientConfig: new(tls.Config), TLSHandshakeTimeout: time.Second, DisableKeepAlives: true,