]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: add HandshakeContext method to Conn
authorJohan Brandhorst <johan.brandhorst@gmail.com>
Sat, 1 Aug 2020 11:18:31 +0000 (12:18 +0100)
committerJohan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>
Tue, 16 Mar 2021 14:05:45 +0000 (14:05 +0000)
Adds the (*tls.Conn).HandshakeContext method. This allows
us to pass the context provided down the call stack to
eventually reach the tls.ClientHelloInfo and
tls.CertificateRequestInfo structs.
These contexts are exposed to the user as read-only via Context()
methods.

This allows users of (*tls.Config).GetCertificate and
(*tls.Config).GetClientCertificate to use the context for
request scoped parameters and cancellation.

Replace uses of (*tls.Conn).Handshake with (*tls.Conn).HandshakeContext
where appropriate, to propagate existing contexts.

Fixes #32406

Change-Id: I259939c744bdc9b805bf51a845a8bc462c042483
Reviewed-on: https://go-review.googlesource.com/c/go/+/295370
Run-TryBot: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Katie Hockman <katie@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
13 files changed:
doc/go1.17.html
src/crypto/tls/common.go
src/crypto/tls/conn.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_client_test.go
src/crypto/tls/handshake_client_tls13.go
src/crypto/tls/handshake_server.go
src/crypto/tls/handshake_server_test.go
src/crypto/tls/handshake_server_tls13.go
src/crypto/tls/tls.go
src/net/http/server.go
src/net/http/transport.go
src/net/http/transport_test.go

index a07290714fa915743c59811d21ec8acbca440ecf..34cfce7a06c8e88467ee370de5b347568a7368e8 100644 (file)
@@ -79,6 +79,23 @@ Do not send CLs removing the interior tags from such phrases.
   TODO: complete this section
 </p>
 
+<h3 id="crypto/tls"><a href="/pkg/crypto/tls">crypto/tls</a></h3>
+
+<p><!-- CL 295370 -->
+  <a href="/pkg/crypto/tls#Conn.HandshakeContext">(*Conn).HandshakeContext</a> was added to
+  allow the user to control cancellation of an in-progress TLS Handshake.
+  The context provided is propagated into the
+  <a href="/pkg/crypto/tls#ClientHelloInfo">ClientHelloInfo</a>
+  and <a href="/pkg/crypto/tls#CertificateRequestInfo">CertificateRequestInfo</a>
+  structs and accessible through the new
+  <a href="/pkg/crypto/tls#ClientHelloInfo.Context">(*ClientHelloInfo).Context</a>
+  and
+  <a href="/pkg/crypto/tls#CertificateRequestInfo.Context">
+      (*CertificateRequestInfo).Context
+  </a> methods respectively. Canceling the context after the handshake has finished
+  has no effect.
+</p>
+
 <h3 id="minor_library_changes">Minor changes to the library</h3>
 
 <p>
@@ -87,6 +104,15 @@ Do not send CLs removing the interior tags from such phrases.
   in mind.
 </p>
 
+<dl id="net/http"><dt><a href="/pkg/net/http/">net/http</a></dt>
+  <p>
+    The <a href="/pkg/net/http/"><code>net/http</code></a> package now uses the new
+    <a href="/pkg/crypto/tls#Conn.HandshakeContext"><code>(*tls.Conn).HandshakeContext</code></a>
+    with the <a href="/pkg/net/http/#Request"><code>Request</code></a> context
+    when performing TLS handshakes in the client or server.
+  </p>
+</dl><!-- net/http -->
+
 <p>
   TODO: complete this section
 </p>
index eec6e1ebbd9060559637c9d2476046b2665a4bed..5b68742975cd079d25b3b8e4e56111e441cb2f8e 100644 (file)
@@ -7,6 +7,7 @@ package tls
 import (
        "bytes"
        "container/list"
+       "context"
        "crypto"
        "crypto/ecdsa"
        "crypto/ed25519"
@@ -443,6 +444,16 @@ type ClientHelloInfo struct {
        // config is embedded by the GetCertificate or GetConfigForClient caller,
        // for use with SupportsCertificate.
        config *Config
+
+       // ctx is the context of the handshake that is in progress.
+       ctx context.Context
+}
+
+// Context returns the context of the handshake that is in progress.
+// This context is a child of the context passed to HandshakeContext,
+// if any, and is canceled when the handshake concludes.
+func (c *ClientHelloInfo) Context() context.Context {
+       return c.ctx
 }
 
 // CertificateRequestInfo contains information from a server's
@@ -461,6 +472,16 @@ type CertificateRequestInfo struct {
 
        // Version is the TLS version that was negotiated for this connection.
        Version uint16
+
+       // ctx is the context of the handshake that is in progress.
+       ctx context.Context
+}
+
+// Context returns the context of the handshake that is in progress.
+// This context is a child of the context passed to HandshakeContext,
+// if any, and is canceled when the handshake concludes.
+func (c *CertificateRequestInfo) Context() context.Context {
+       return c.ctx
 }
 
 // RenegotiationSupport enumerates the different levels of support for TLS
index 72ad52c194baf49c988baf83340573ed0e54cd01..969f357834c0a5c1733b7b939248ebf07d514510 100644 (file)
@@ -8,6 +8,7 @@ package tls
 
 import (
        "bytes"
+       "context"
        "crypto/cipher"
        "crypto/subtle"
        "crypto/x509"
@@ -27,7 +28,7 @@ type Conn struct {
        // constant
        conn        net.Conn
        isClient    bool
-       handshakeFn func() error // (*Conn).clientHandshake or serverHandshake
+       handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
 
        // handshakeStatus is 1 if the connection is currently transferring
        // application data (i.e. is not currently processing a handshake).
@@ -1190,7 +1191,7 @@ func (c *Conn) handleRenegotiation() error {
        defer c.handshakeMutex.Unlock()
 
        atomic.StoreUint32(&c.handshakeStatus, 0)
-       if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
+       if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
                c.handshakes++
        }
        return c.handshakeErr
@@ -1373,8 +1374,61 @@ func (c *Conn) closeNotify() error {
 // first Read or Write will call it automatically.
 //
 // For control over canceling or setting a timeout on a handshake, use
-// the Dialer's DialContext method.
+// HandshakeContext or the Dialer's DialContext method instead.
 func (c *Conn) Handshake() error {
+       return c.HandshakeContext(context.Background())
+}
+
+// HandshakeContext runs the client or server handshake
+// protocol if it has not yet been run.
+//
+// The provided Context must be non-nil. If the context is canceled before
+// the handshake is complete, the handshake is interrupted and an error is returned.
+// Once the handshake has completed, cancellation of the context will not affect the
+// connection.
+//
+// Most uses of this package need not call HandshakeContext explicitly: the
+// first Read or Write will call it automatically.
+func (c *Conn) HandshakeContext(ctx context.Context) error {
+       // Delegate to unexported method for named return
+       // without confusing documented signature.
+       return c.handshakeContext(ctx)
+}
+
+func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
+       handshakeCtx, cancel := context.WithCancel(ctx)
+       // Note: defer this before starting the "interrupter" goroutine
+       // so that we can tell the difference between the input being canceled and
+       // this cancellation. In the former case, we need to close the connection.
+       defer cancel()
+
+       // Start the "interrupter" goroutine, if this context might be canceled.
+       // (The background context cannot).
+       //
+       // The interrupter goroutine waits for the input context to be done and
+       // closes the connection if this happens before the function returns.
+       if ctx.Done() != nil {
+               done := make(chan struct{})
+               interruptRes := make(chan error, 1)
+               defer func() {
+                       close(done)
+                       if ctxErr := <-interruptRes; ctxErr != nil {
+                               // Return context error to user.
+                               ret = ctxErr
+                       }
+               }()
+               go func() {
+                       select {
+                       case <-handshakeCtx.Done():
+                               // Close the connection, discarding the error
+                               _ = c.conn.Close()
+                               interruptRes <- handshakeCtx.Err()
+                       case <-done:
+                               interruptRes <- nil
+                       }
+               }()
+       }
+
        c.handshakeMutex.Lock()
        defer c.handshakeMutex.Unlock()
 
@@ -1388,7 +1442,7 @@ func (c *Conn) Handshake() error {
        c.in.Lock()
        defer c.in.Unlock()
 
-       c.handshakeErr = c.handshakeFn()
+       c.handshakeErr = c.handshakeFn(handshakeCtx)
        if c.handshakeErr == nil {
                c.handshakes++
        } else {
index e684b21d527223e850e8a74d802e9ac0521e9753..92e33e716903a0ea25e3f50ddc5e0d71319a0be7 100644 (file)
@@ -6,6 +6,7 @@ package tls
 
 import (
        "bytes"
+       "context"
        "crypto"
        "crypto/ecdsa"
        "crypto/ed25519"
@@ -24,6 +25,7 @@ import (
 
 type clientHandshakeState struct {
        c            *Conn
+       ctx          context.Context
        serverHello  *serverHelloMsg
        hello        *clientHelloMsg
        suite        *cipherSuite
@@ -134,7 +136,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) {
        return hello, params, nil
 }
 
-func (c *Conn) clientHandshake() (err error) {
+func (c *Conn) clientHandshake(ctx context.Context) (err error) {
        if c.config == nil {
                c.config = defaultConfig()
        }
@@ -198,6 +200,7 @@ func (c *Conn) clientHandshake() (err error) {
        if c.vers == VersionTLS13 {
                hs := &clientHandshakeStateTLS13{
                        c:           c,
+                       ctx:         ctx,
                        serverHello: serverHello,
                        hello:       hello,
                        ecdheParams: ecdheParams,
@@ -212,6 +215,7 @@ func (c *Conn) clientHandshake() (err error) {
 
        hs := &clientHandshakeState{
                c:           c,
+               ctx:         ctx,
                serverHello: serverHello,
                hello:       hello,
                session:     session,
@@ -540,7 +544,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
                certRequested = true
                hs.finishedHash.Write(certReq.marshal())
 
-               cri := certificateRequestInfoFromMsg(c.vers, certReq)
+               cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq)
                if chainToSend, err = c.getClientCertificate(cri); err != nil {
                        c.sendAlert(alertInternalError)
                        return err
@@ -880,10 +884,11 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
 
 // certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS
 // <= 1.2 CertificateRequest, making an effort to fill in missing information.
-func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo {
+func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo {
        cri := &CertificateRequestInfo{
                AcceptableCAs: certReq.certificateAuthorities,
                Version:       vers,
+               ctx:           ctx,
        }
 
        var rsaAvail, ecAvail bool
index 693f9686a790ecd6199f97ad7beebf5d8a98d5bc..2f30b3008d87937b2faabd1152c51d0325243beb 100644 (file)
@@ -6,6 +6,7 @@ package tls
 
 import (
        "bytes"
+       "context"
        "crypto/rsa"
        "crypto/x509"
        "encoding/base64"
@@ -20,6 +21,7 @@ import (
        "os/exec"
        "path/filepath"
        "reflect"
+       "runtime"
        "strconv"
        "strings"
        "testing"
@@ -2511,3 +2513,37 @@ func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
                        serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
        }
 }
+
+func TestClientHandshakeContextCancellation(t *testing.T) {
+       c, s := localPipe(t)
+       serverConfig := testConfig.Clone()
+       serverErr := make(chan error, 1)
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+       go func() {
+               defer close(serverErr)
+               defer s.Close()
+               conn := Server(s, serverConfig)
+               _, err := conn.readClientHello(ctx)
+               cancel()
+               serverErr <- err
+       }()
+       cli := Client(c, testConfig)
+       err := cli.HandshakeContext(ctx)
+       if err == nil {
+               t.Fatal("Client handshake did not error when the context was canceled")
+       }
+       if err != context.Canceled {
+               t.Errorf("Unexpected client handshake error: %v", err)
+       }
+       if err := <-serverErr; err != nil {
+               t.Errorf("Unexpected server error: %v", err)
+       }
+       if runtime.GOARCH == "wasm" {
+               t.Skip("conn.Close does not error as expected when called multiple times on WASM")
+       }
+       err = cli.Close()
+       if err == nil {
+               t.Error("Client connection was not closed when the context was canceled")
+       }
+}
index daa5d97fd35548c4871a2f7f7ffa3b6e5da013a7..be37c681c6dee92a73e15945c9a7c3063a6dedf9 100644 (file)
@@ -6,6 +6,7 @@ package tls
 
 import (
        "bytes"
+       "context"
        "crypto"
        "crypto/hmac"
        "crypto/rsa"
@@ -17,6 +18,7 @@ import (
 
 type clientHandshakeStateTLS13 struct {
        c           *Conn
+       ctx         context.Context
        serverHello *serverHelloMsg
        hello       *clientHelloMsg
        ecdheParams ecdheParameters
@@ -555,6 +557,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
                AcceptableCAs:    hs.certReq.certificateAuthorities,
                SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
                Version:          c.vers,
+               ctx:              hs.ctx,
        })
        if err != nil {
                return err
index 9c3e0f636ea7b548776ad75bf7ff112c65e7c571..5a572a9db10f62d5da87cbd5e66f99f8286518a0 100644 (file)
@@ -5,6 +5,7 @@
 package tls
 
 import (
+       "context"
        "crypto"
        "crypto/ecdsa"
        "crypto/ed25519"
@@ -23,6 +24,7 @@ import (
 // It's discarded once the handshake has completed.
 type serverHandshakeState struct {
        c            *Conn
+       ctx          context.Context
        clientHello  *clientHelloMsg
        hello        *serverHelloMsg
        suite        *cipherSuite
@@ -37,8 +39,8 @@ type serverHandshakeState struct {
 }
 
 // serverHandshake performs a TLS handshake as a server.
-func (c *Conn) serverHandshake() error {
-       clientHello, err := c.readClientHello()
+func (c *Conn) serverHandshake(ctx context.Context) error {
+       clientHello, err := c.readClientHello(ctx)
        if err != nil {
                return err
        }
@@ -46,6 +48,7 @@ func (c *Conn) serverHandshake() error {
        if c.vers == VersionTLS13 {
                hs := serverHandshakeStateTLS13{
                        c:           c,
+                       ctx:         ctx,
                        clientHello: clientHello,
                }
                return hs.handshake()
@@ -53,6 +56,7 @@ func (c *Conn) serverHandshake() error {
 
        hs := serverHandshakeState{
                c:           c,
+               ctx:         ctx,
                clientHello: clientHello,
        }
        return hs.handshake()
@@ -124,7 +128,7 @@ func (hs *serverHandshakeState) handshake() error {
 }
 
 // readClientHello reads a ClientHello message and selects the protocol version.
-func (c *Conn) readClientHello() (*clientHelloMsg, error) {
+func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
        msg, err := c.readHandshake()
        if err != nil {
                return nil, err
@@ -138,7 +142,7 @@ func (c *Conn) readClientHello() (*clientHelloMsg, error) {
        var configForClient *Config
        originalConfig := c.config
        if c.config.GetConfigForClient != nil {
-               chi := clientHelloInfo(c, clientHello)
+               chi := clientHelloInfo(ctx, c, clientHello)
                if configForClient, err = c.config.GetConfigForClient(chi); err != nil {
                        c.sendAlert(alertInternalError)
                        return nil, err
@@ -220,7 +224,7 @@ func (hs *serverHandshakeState) processClientHello() error {
                }
        }
 
-       hs.cert, err = c.config.getCertificate(clientHelloInfo(c, hs.clientHello))
+       hs.cert, err = c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
        if err != nil {
                if err == errNoCertificates {
                        c.sendAlert(alertUnrecognizedName)
@@ -828,7 +832,7 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
        return nil
 }
 
-func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
+func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
        supportedVersions := clientHello.supportedVersions
        if len(clientHello.supportedVersions) == 0 {
                supportedVersions = supportedVersionsFromMax(clientHello.vers)
@@ -844,5 +848,6 @@ func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
                SupportedVersions: supportedVersions,
                Conn:              c.conn,
                config:            c.config,
+               ctx:               ctx,
        }
 }
index d6bf9e439b01c3e50b509085675ddf2a5953e60a..432b4cfe35010ae1ca10e638be83ab9d464164c1 100644 (file)
@@ -6,6 +6,7 @@ package tls
 
 import (
        "bytes"
+       "context"
        "crypto"
        "crypto/elliptic"
        "crypto/x509"
@@ -17,6 +18,7 @@ import (
        "os"
        "os/exec"
        "path/filepath"
+       "runtime"
        "strings"
        "testing"
        "time"
@@ -38,10 +40,12 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa
                cli.writeRecord(recordTypeHandshake, m.marshal())
                c.Close()
        }()
+       ctx := context.Background()
        conn := Server(s, serverConfig)
-       ch, err := conn.readClientHello()
+       ch, err := conn.readClientHello(ctx)
        hs := serverHandshakeState{
                c:           conn,
+               ctx:         ctx,
                clientHello: ch,
        }
        if err == nil {
@@ -1421,9 +1425,11 @@ func TestSNIGivenOnFailure(t *testing.T) {
                c.Close()
        }()
        conn := Server(s, serverConfig)
-       ch, err := conn.readClientHello()
+       ctx := context.Background()
+       ch, err := conn.readClientHello(ctx)
        hs := serverHandshakeState{
                c:           conn,
+               ctx:         ctx,
                clientHello: ch,
        }
        if err == nil {
@@ -1939,3 +1945,112 @@ func TestAESCipherReordering13(t *testing.T) {
                })
        }
 }
+
+func TestServerHandshakeContextCancellation(t *testing.T) {
+       c, s := localPipe(t)
+       clientConfig := testConfig.Clone()
+       clientErr := make(chan error, 1)
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+       go func() {
+               defer close(clientErr)
+               defer c.Close()
+               clientHello := &clientHelloMsg{
+                       vers:               VersionTLS10,
+                       random:             make([]byte, 32),
+                       cipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
+                       compressionMethods: []uint8{compressionNone},
+               }
+               cli := Client(c, clientConfig)
+               _, err := cli.writeRecord(recordTypeHandshake, clientHello.marshal())
+               cancel()
+               clientErr <- err
+       }()
+       conn := Server(s, testConfig)
+       err := conn.HandshakeContext(ctx)
+       if err == nil {
+               t.Fatal("Server handshake did not error when the context was canceled")
+       }
+       if err != context.Canceled {
+               t.Errorf("Unexpected server handshake error: %v", err)
+       }
+       if err := <-clientErr; err != nil {
+               t.Errorf("Unexpected client error: %v", err)
+       }
+       if runtime.GOARCH == "wasm" {
+               t.Skip("conn.Close does not error as expected when called multiple times on WASM")
+       }
+       err = conn.Close()
+       if err == nil {
+               t.Error("Server connection was not closed when the context was canceled")
+       }
+}
+
+// TestHandshakeContextHierarchy tests whether the contexts
+// available to GetClientCertificate and GetCertificate are
+// derived from the context provided to HandshakeContext, and
+// that those contexts are cancelled after HandshakeContext has
+// returned.
+func TestHandshakeContextHierarchy(t *testing.T) {
+       c, s := localPipe(t)
+       clientErr := make(chan error, 1)
+       clientConfig := testConfig.Clone()
+       serverConfig := testConfig.Clone()
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+       key := struct{}{}
+       ctx = context.WithValue(ctx, key, true)
+       go func() {
+               defer close(clientErr)
+               defer c.Close()
+               var innerCtx context.Context
+               clientConfig.Certificates = nil
+               clientConfig.GetClientCertificate = func(certificateRequest *CertificateRequestInfo) (*Certificate, error) {
+                       if val, ok := certificateRequest.Context().Value(key).(bool); !ok || !val {
+                               t.Errorf("GetClientCertificate context was not child of HandshakeContext")
+                       }
+                       innerCtx = certificateRequest.Context()
+                       return &Certificate{
+                               Certificate: [][]byte{testRSACertificate},
+                               PrivateKey:  testRSAPrivateKey,
+                       }, nil
+               }
+               cli := Client(c, clientConfig)
+               err := cli.HandshakeContext(ctx)
+               if err != nil {
+                       clientErr <- err
+                       return
+               }
+               select {
+               case <-innerCtx.Done():
+               default:
+                       t.Errorf("GetClientCertificate context was not cancelled after HandshakeContext returned.")
+               }
+       }()
+       var innerCtx context.Context
+       serverConfig.Certificates = nil
+       serverConfig.ClientAuth = RequestClientCert
+       serverConfig.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
+               if val, ok := clientHello.Context().Value(key).(bool); !ok || !val {
+                       t.Errorf("GetClientCertificate context was not child of HandshakeContext")
+               }
+               innerCtx = clientHello.Context()
+               return &Certificate{
+                       Certificate: [][]byte{testRSACertificate},
+                       PrivateKey:  testRSAPrivateKey,
+               }, nil
+       }
+       conn := Server(s, serverConfig)
+       err := conn.HandshakeContext(ctx)
+       if err != nil {
+               t.Errorf("Unexpected server handshake error: %v", err)
+       }
+       select {
+       case <-innerCtx.Done():
+       default:
+               t.Errorf("GetCertificate context was not cancelled after HandshakeContext returned.")
+       }
+       if err := <-clientErr; err != nil {
+               t.Errorf("Unexpected client error: %v", err)
+       }
+}
index c2c288aed4325b280f435fcb26239d0d9f49780e..c7837d2955d5aa7cda93d4e3840fca075f2fd4ea 100644 (file)
@@ -6,6 +6,7 @@ package tls
 
 import (
        "bytes"
+       "context"
        "crypto"
        "crypto/hmac"
        "crypto/rsa"
@@ -23,6 +24,7 @@ const maxClientPSKIdentities = 5
 
 type serverHandshakeStateTLS13 struct {
        c               *Conn
+       ctx             context.Context
        clientHello     *clientHelloMsg
        hello           *serverHelloMsg
        sentDummyCCS    bool
@@ -374,7 +376,7 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error {
                return c.sendAlert(alertMissingExtension)
        }
 
-       certificate, err := c.config.getCertificate(clientHelloInfo(c, hs.clientHello))
+       certificate, err := c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
        if err != nil {
                if err == errNoCertificates {
                        c.sendAlert(alertUnrecognizedName)
index 4989364958ee3e18566ea8865955a0e788ebc450..b529c705237bbd8df526d6549583dda34b6232d0 100644 (file)
@@ -25,7 +25,6 @@ import (
        "net"
        "os"
        "strings"
-       "time"
 )
 
 // Server returns a new TLS server side connection
@@ -119,28 +118,16 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
 }
 
 func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
-       // We want the Timeout and Deadline values from dialer to cover the
-       // whole process: TCP connection and TLS handshake. This means that we
-       // also need to start our own timers now.
-       timeout := netDialer.Timeout
-
-       if !netDialer.Deadline.IsZero() {
-               deadlineTimeout := time.Until(netDialer.Deadline)
-               if timeout == 0 || deadlineTimeout < timeout {
-                       timeout = deadlineTimeout
-               }
+       if netDialer.Timeout != 0 {
+               var cancel context.CancelFunc
+               ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
+               defer cancel()
        }
 
-       // hsErrCh is non-nil if we might not wait for Handshake to complete.
-       var hsErrCh chan error
-       if timeout != 0 || ctx.Done() != nil {
-               hsErrCh = make(chan error, 2)
-       }
-       if timeout != 0 {
-               timer := time.AfterFunc(timeout, func() {
-                       hsErrCh <- timeoutError{}
-               })
-               defer timer.Stop()
+       if !netDialer.Deadline.IsZero() {
+               var cancel context.CancelFunc
+               ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
+               defer cancel()
        }
 
        rawConn, err := netDialer.DialContext(ctx, network, addr)
@@ -167,34 +154,10 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf
        }
 
        conn := Client(rawConn, config)
-
-       if hsErrCh == nil {
-               err = conn.Handshake()
-       } else {
-               go func() {
-                       hsErrCh <- conn.Handshake()
-               }()
-
-               select {
-               case <-ctx.Done():
-                       err = ctx.Err()
-               case err = <-hsErrCh:
-                       if err != nil {
-                               // If the error was due to the context
-                               // closing, prefer the context's error, rather
-                               // than some random network teardown error.
-                               if e := ctx.Err(); e != nil {
-                                       err = e
-                               }
-                       }
-               }
-       }
-
-       if err != nil {
+       if err := conn.HandshakeContext(ctx); err != nil {
                rawConn.Close()
                return nil, err
        }
-
        return conn, nil
 }
 
index ea3486289a3ae5a9852034df3b7808f84a3087aa..f095b7edd270133641ebe143b2278ae56d5e2cd9 100644 (file)
@@ -1837,7 +1837,7 @@ func (c *conn) serve(ctx context.Context) {
                if d := c.server.WriteTimeout; d != 0 {
                        c.rwc.SetWriteDeadline(time.Now().Add(d))
                }
-               if err := tlsConn.Handshake(); err != nil {
+               if err := tlsConn.HandshakeContext(ctx); err != nil {
                        // If the handshake failed due to the client not speaking
                        // TLS, assume they're speaking plaintext HTTP and write a
                        // 400 response on the TLS conn's underlying net.Conn.
index 0aa48273dd385848bc04166027a5d4548494f47b..6358c3897ec4774c20e496b6f71ee3cd60e71942 100644 (file)
@@ -1505,7 +1505,7 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) {
 // Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS
 // tunnel, this function establishes a nested TLS session inside the encrypted channel.
 // The remote endpoint's name may be overridden by TLSClientConfig.ServerName.
-func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) error {
+func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace) error {
        // Initiate TLS and check remote host name against certificate.
        cfg := cloneTLSConfig(pconn.t.TLSClientConfig)
        if cfg.ServerName == "" {
@@ -1527,7 +1527,7 @@ func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) erro
                if trace != nil && trace.TLSHandshakeStart != nil {
                        trace.TLSHandshakeStart()
                }
-               err := tlsConn.Handshake()
+               err := tlsConn.HandshakeContext(ctx)
                if timer != nil {
                        timer.Stop()
                }
@@ -1583,7 +1583,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
                        if trace != nil && trace.TLSHandshakeStart != nil {
                                trace.TLSHandshakeStart()
                        }
-                       if err := tc.Handshake(); err != nil {
+                       if err := tc.HandshakeContext(ctx); err != nil {
                                go pconn.conn.Close()
                                if trace != nil && trace.TLSHandshakeDone != nil {
                                        trace.TLSHandshakeDone(tls.ConnectionState{}, err)
@@ -1607,7 +1607,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
                        if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil {
                                return nil, wrapErr(err)
                        }
-                       if err = pconn.addTLS(firstTLSHost, trace); err != nil {
+                       if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil {
                                return nil, wrapErr(err)
                        }
                }
@@ -1721,7 +1721,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
        }
 
        if cm.proxyURL != nil && cm.targetScheme == "https" {
-               if err := pconn.addTLS(cm.tlsHost(), trace); err != nil {
+               if err := pconn.addTLS(ctx, cm.tlsHost(), trace); err != nil {
                        return nil, err
                }
        }
index ba85a61683add300f309a9cb5a166bb1f65e5ba9..7f6e0938c20fa92276120d51fdc8189e1afd7cd2 100644 (file)
@@ -3734,7 +3734,7 @@ func TestTransportDialTLSContext(t *testing.T) {
                if err != nil {
                        return nil, err
                }
-               return c, c.Handshake()
+               return c, c.HandshakeContext(ctx)
        }
 
        req, err := NewRequest("GET", ts.URL, nil)