]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: add HandshakeContext method to Conn
authorJohan Brandhorst <johan.brandhorst@gmail.com>
Sat, 1 Aug 2020 11:18:31 +0000 (12:18 +0100)
committerFilippo Valsorda <filippo@golang.org>
Mon, 9 Nov 2020 18:34:47 +0000 (18:34 +0000)
Adds the (*tls.Conn).HandshakeContext method. This allows
us to pass the context provided down the call stack to
eventually reach the tls.ClientHelloInfo and
tls.CertificateRequestInfo structs.
These contexts are exposed to the user as read-only via Context()
methods.

This allows users of (*tls.Config).GetCertificate and
(*tls.Config).GetClientCertificate to use the context for
request scoped parameters and cancellation.

Replace uses of (*tls.Conn).Handshake with (*tls.Conn).HandshakeContext
where appropriate, to propagate existing contexts.

Fixes #32406

Change-Id: I33c228904fe82dcf57683b63627497d3eb841ff2
Reviewed-on: https://go-review.googlesource.com/c/go/+/246338
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Roland Shoemaker <roland@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
13 files changed:
doc/go1.16.html
src/crypto/tls/common.go
src/crypto/tls/conn.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_client_test.go
src/crypto/tls/handshake_client_tls13.go
src/crypto/tls/handshake_server.go
src/crypto/tls/handshake_server_test.go
src/crypto/tls/handshake_server_tls13.go
src/crypto/tls/tls.go
src/net/http/server.go
src/net/http/transport.go
src/net/http/transport_test.go

index 6c4d076d502f40101d73ae4b4e47d899838721a2..bb920a0cb8abd3e38a52442af1a7a0fc638d5978 100644 (file)
@@ -271,6 +271,21 @@ Do not send CLs removing the interior tags from such phrases.
   indefinitely.
 </p>
 
+<p><!-- CL 246338 -->
+  <a href="/pkg/crypto/tls#Conn.HandshakeContext">(*Conn).HandshakeContext</a> was added to
+  allow the user to control cancellation of an in-progress TLS Handshake.
+  The context provided is propagated into the
+  <a href="/pkg/crypto/tls#ClientHelloInfo">ClientHelloInfo</a>
+  and <a href="/pkg/crypto/tls#CertificateRequestInfo">CertificateRequestInfo</a>
+  structs and accessible through the new
+  <a href="/pkg/crypto/tls#ClientHelloInfo.Context">(*ClientHelloInfo).Context</a>
+  and
+  <a href="/pkg/crypto/tls#CertificateRequestInfo.Context">
+    (*CertificateRequestInfo).Context
+  </a> methods respectively. Canceling the context after the handshake has finished
+  has no effect.
+</p>
+
 <h3 id="crypto/x509"><a href="/pkg/crypto/x509">crypto/x509</a></h3>
 
 <p><!-- CL 235078 -->
@@ -405,6 +420,13 @@ Do not send CLs removing the interior tags from such phrases.
     Cookies set with <code>SameSiteDefaultMode</code> now behave according to the current
     spec (no attribute is set) instead of generating a SameSite key without a value.
     </p>
+
+    <p><!-- CL 246338 -->
+      The <a href="/pkg/net/http/"><code>net/http</code></a> package now uses the new
+      <a href="/pkg/crypto/tls#Conn.HandshakeContext"><code>(*tls.Conn).HandshakeContext</code></a>
+      with the <a href="/pkg/net/http/#Request"><code>Request</code></a> context
+      when performing TLS handshakes in the client or server.
+    </p>
   </dd>
 </dl><!-- net/http -->
 
index 86dc0dd3b2e8c0d370511f7d1c726c3b33538f49..1370d26fe2cdf2e097b18a029f0852e961033e19 100644 (file)
@@ -7,6 +7,7 @@ package tls
 import (
        "bytes"
        "container/list"
+       "context"
        "crypto"
        "crypto/ecdsa"
        "crypto/ed25519"
@@ -444,6 +445,16 @@ type ClientHelloInfo struct {
        // config is embedded by the GetCertificate or GetConfigForClient caller,
        // for use with SupportsCertificate.
        config *Config
+
+       // ctx is the context of the handshake that is in progress.
+       ctx context.Context
+}
+
+// Context returns the context of the handshake that is in progress.
+// This context is a child of the context passed to HandshakeContext,
+// if any, and is canceled when the handshake concludes.
+func (c *ClientHelloInfo) Context() context.Context {
+       return c.ctx
 }
 
 // CertificateRequestInfo contains information from a server's
@@ -462,6 +473,16 @@ type CertificateRequestInfo struct {
 
        // Version is the TLS version that was negotiated for this connection.
        Version uint16
+
+       // ctx is the context of the handshake that is in progress.
+       ctx context.Context
+}
+
+// Context returns the context of the handshake that is in progress.
+// This context is a child of the context passed to HandshakeContext,
+// if any, and is canceled when the handshake concludes.
+func (c *CertificateRequestInfo) Context() context.Context {
+       return c.ctx
 }
 
 // RenegotiationSupport enumerates the different levels of support for TLS
index b9a1095862a7754e32ce567262b29610594f1ed7..2f5d4303c251672ac4b6c7ebddc6714a54fe02cf 100644 (file)
@@ -8,6 +8,7 @@ package tls
 
 import (
        "bytes"
+       "context"
        "crypto/cipher"
        "crypto/subtle"
        "crypto/x509"
@@ -26,7 +27,7 @@ type Conn struct {
        // constant
        conn        net.Conn
        isClient    bool
-       handshakeFn func() error // (*Conn).clientHandshake or serverHandshake
+       handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
 
        // handshakeStatus is 1 if the connection is currently transferring
        // application data (i.e. is not currently processing a handshake).
@@ -1192,7 +1193,7 @@ func (c *Conn) handleRenegotiation() error {
        defer c.handshakeMutex.Unlock()
 
        atomic.StoreUint32(&c.handshakeStatus, 0)
-       if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
+       if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
                c.handshakes++
        }
        return c.handshakeErr
@@ -1375,8 +1376,61 @@ func (c *Conn) closeNotify() error {
 // first Read or Write will call it automatically.
 //
 // For control over canceling or setting a timeout on a handshake, use
-// the Dialer's DialContext method.
+// HandshakeContext or the Dialer's DialContext method instead.
 func (c *Conn) Handshake() error {
+       return c.HandshakeContext(context.Background())
+}
+
+// HandshakeContext runs the client or server handshake
+// protocol if it has not yet been run.
+//
+// The provided Context must be non-nil. If the context is canceled before
+// the handshake is complete, the handshake is interrupted and an error is returned.
+// Once the handshake has completed, cancellation of the context will not affect the
+// connection.
+//
+// Most uses of this package need not call HandshakeContext explicitly: the
+// first Read or Write will call it automatically.
+func (c *Conn) HandshakeContext(ctx context.Context) error {
+       // Delegate to unexported method for named return
+       // without confusing documented signature.
+       return c.handshakeContext(ctx)
+}
+
+func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
+       handshakeCtx, cancel := context.WithCancel(ctx)
+       // Note: defer this before starting the "interrupter" goroutine
+       // so that we can tell the difference between the input being canceled and
+       // this cancellation. In the former case, we need to close the connection.
+       defer cancel()
+
+       // Start the "interrupter" goroutine, if this context might be canceled.
+       // (The background context cannot).
+       //
+       // The interrupter goroutine waits for the input context to be done and
+       // closes the connection if this happens before the function returns.
+       if ctx.Done() != nil {
+               done := make(chan struct{})
+               interruptRes := make(chan error, 1)
+               defer func() {
+                       close(done)
+                       if ctxErr := <-interruptRes; ctxErr != nil {
+                               // Return context error to user.
+                               ret = ctxErr
+                       }
+               }()
+               go func() {
+                       select {
+                       case <-handshakeCtx.Done():
+                               // Close the connection, discarding the error
+                               _ = c.conn.Close()
+                               interruptRes <- handshakeCtx.Err()
+                       case <-done:
+                               interruptRes <- nil
+                       }
+               }()
+       }
+
        c.handshakeMutex.Lock()
        defer c.handshakeMutex.Unlock()
 
@@ -1390,7 +1444,7 @@ func (c *Conn) Handshake() error {
        c.in.Lock()
        defer c.in.Unlock()
 
-       c.handshakeErr = c.handshakeFn()
+       c.handshakeErr = c.handshakeFn(handshakeCtx)
        if c.handshakeErr == nil {
                c.handshakes++
        } else {
index 46b0a770d5309436660038f77b3f6d08e32eaa60..d09a8c8ccfd90a7227df8ce7f5b65cb02f864c0a 100644 (file)
@@ -6,6 +6,7 @@ package tls
 
 import (
        "bytes"
+       "context"
        "crypto"
        "crypto/ecdsa"
        "crypto/ed25519"
@@ -23,6 +24,7 @@ import (
 
 type clientHandshakeState struct {
        c            *Conn
+       ctx          context.Context
        serverHello  *serverHelloMsg
        hello        *clientHelloMsg
        suite        *cipherSuite
@@ -133,7 +135,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) {
        return hello, params, nil
 }
 
-func (c *Conn) clientHandshake() (err error) {
+func (c *Conn) clientHandshake(ctx context.Context) (err error) {
        if c.config == nil {
                c.config = defaultConfig()
        }
@@ -197,6 +199,7 @@ func (c *Conn) clientHandshake() (err error) {
        if c.vers == VersionTLS13 {
                hs := &clientHandshakeStateTLS13{
                        c:           c,
+                       ctx:         ctx,
                        serverHello: serverHello,
                        hello:       hello,
                        ecdheParams: ecdheParams,
@@ -211,6 +214,7 @@ func (c *Conn) clientHandshake() (err error) {
 
        hs := &clientHandshakeState{
                c:           c,
+               ctx:         ctx,
                serverHello: serverHello,
                hello:       hello,
                session:     session,
@@ -539,7 +543,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
                certRequested = true
                hs.finishedHash.Write(certReq.marshal())
 
-               cri := certificateRequestInfoFromMsg(c.vers, certReq)
+               cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq)
                if chainToSend, err = c.getClientCertificate(cri); err != nil {
                        c.sendAlert(alertInternalError)
                        return err
@@ -879,10 +883,11 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
 
 // certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS
 // <= 1.2 CertificateRequest, making an effort to fill in missing information.
-func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo {
+func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo {
        cri := &CertificateRequestInfo{
                AcceptableCAs: certReq.certificateAuthorities,
                Version:       vers,
+               ctx:           ctx,
        }
 
        var rsaAvail, ecAvail bool
index 12b0254123e938a71dd6b8c0592db9cf235bf255..8889e2c8c33d44304962df40e53e6330c0ad0460 100644 (file)
@@ -6,6 +6,7 @@ package tls
 
 import (
        "bytes"
+       "context"
        "crypto/rsa"
        "crypto/x509"
        "encoding/base64"
@@ -20,6 +21,7 @@ import (
        "os/exec"
        "path/filepath"
        "reflect"
+       "runtime"
        "strconv"
        "strings"
        "testing"
@@ -2511,3 +2513,37 @@ func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
                        serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
        }
 }
+
+func TestClientHandshakeContextCancellation(t *testing.T) {
+       c, s := localPipe(t)
+       serverConfig := testConfig.Clone()
+       serverErr := make(chan error, 1)
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+       go func() {
+               defer close(serverErr)
+               defer s.Close()
+               conn := Server(s, serverConfig)
+               _, err := conn.readClientHello(ctx)
+               cancel()
+               serverErr <- err
+       }()
+       cli := Client(c, testConfig)
+       err := cli.HandshakeContext(ctx)
+       if err == nil {
+               t.Fatal("Client handshake did not error when the context was canceled")
+       }
+       if err != context.Canceled {
+               t.Errorf("Unexpected client handshake error: %v", err)
+       }
+       if err := <-serverErr; err != nil {
+               t.Errorf("Unexpected server error: %v", err)
+       }
+       if runtime.GOARCH == "wasm" {
+               t.Skip("conn.Close does not error as expected when called multiple times on WASM")
+       }
+       err = cli.Close()
+       if err == nil {
+               t.Error("Client connection was not closed when the context was canceled")
+       }
+}
index 9c61105cf73d02b18b3c110c15777146d515677b..0e4b38003527cb29a3778f2854cfca97f5355ded 100644 (file)
@@ -6,6 +6,7 @@ package tls
 
 import (
        "bytes"
+       "context"
        "crypto"
        "crypto/hmac"
        "crypto/rsa"
@@ -17,6 +18,7 @@ import (
 
 type clientHandshakeStateTLS13 struct {
        c           *Conn
+       ctx         context.Context
        serverHello *serverHelloMsg
        hello       *clientHelloMsg
        ecdheParams ecdheParameters
@@ -549,6 +551,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
                AcceptableCAs:    hs.certReq.certificateAuthorities,
                SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
                Version:          c.vers,
+               ctx:              hs.ctx,
        })
        if err != nil {
                return err
index 16d3e643f0b28ed45936142e6dacdb56385ee480..1fe026ae0e09682169d169aa99d949242d182a4b 100644 (file)
@@ -5,6 +5,7 @@
 package tls
 
 import (
+       "context"
        "crypto"
        "crypto/ecdsa"
        "crypto/ed25519"
@@ -22,6 +23,7 @@ import (
 // It's discarded once the handshake has completed.
 type serverHandshakeState struct {
        c            *Conn
+       ctx          context.Context
        clientHello  *clientHelloMsg
        hello        *serverHelloMsg
        suite        *cipherSuite
@@ -36,8 +38,8 @@ type serverHandshakeState struct {
 }
 
 // serverHandshake performs a TLS handshake as a server.
-func (c *Conn) serverHandshake() error {
-       clientHello, err := c.readClientHello()
+func (c *Conn) serverHandshake(ctx context.Context) error {
+       clientHello, err := c.readClientHello(ctx)
        if err != nil {
                return err
        }
@@ -45,6 +47,7 @@ func (c *Conn) serverHandshake() error {
        if c.vers == VersionTLS13 {
                hs := serverHandshakeStateTLS13{
                        c:           c,
+                       ctx:         ctx,
                        clientHello: clientHello,
                }
                return hs.handshake()
@@ -52,6 +55,7 @@ func (c *Conn) serverHandshake() error {
 
        hs := serverHandshakeState{
                c:           c,
+               ctx:         ctx,
                clientHello: clientHello,
        }
        return hs.handshake()
@@ -123,7 +127,7 @@ func (hs *serverHandshakeState) handshake() error {
 }
 
 // readClientHello reads a ClientHello message and selects the protocol version.
-func (c *Conn) readClientHello() (*clientHelloMsg, error) {
+func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
        msg, err := c.readHandshake()
        if err != nil {
                return nil, err
@@ -137,7 +141,7 @@ func (c *Conn) readClientHello() (*clientHelloMsg, error) {
        var configForClient *Config
        originalConfig := c.config
        if c.config.GetConfigForClient != nil {
-               chi := clientHelloInfo(c, clientHello)
+               chi := clientHelloInfo(ctx, c, clientHello)
                if configForClient, err = c.config.GetConfigForClient(chi); err != nil {
                        c.sendAlert(alertInternalError)
                        return nil, err
@@ -219,7 +223,7 @@ func (hs *serverHandshakeState) processClientHello() error {
                }
        }
 
-       hs.cert, err = c.config.getCertificate(clientHelloInfo(c, hs.clientHello))
+       hs.cert, err = c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
        if err != nil {
                if err == errNoCertificates {
                        c.sendAlert(alertUnrecognizedName)
@@ -813,7 +817,7 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
        return nil
 }
 
-func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
+func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
        supportedVersions := clientHello.supportedVersions
        if len(clientHello.supportedVersions) == 0 {
                supportedVersions = supportedVersionsFromMax(clientHello.vers)
@@ -829,5 +833,6 @@ func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
                SupportedVersions: supportedVersions,
                Conn:              c.conn,
                config:            c.config,
+               ctx:               ctx,
        }
 }
index a7a532431296f32714a96239f3c8167eeadabc21..c4416c379a4baeecd090fa001db8bb388752b49b 100644 (file)
@@ -6,6 +6,7 @@ package tls
 
 import (
        "bytes"
+       "context"
        "crypto"
        "crypto/elliptic"
        "crypto/x509"
@@ -17,6 +18,7 @@ import (
        "os"
        "os/exec"
        "path/filepath"
+       "runtime"
        "strings"
        "testing"
        "time"
@@ -36,10 +38,12 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa
                cli.writeRecord(recordTypeHandshake, m.marshal())
                c.Close()
        }()
+       ctx := context.Background()
        conn := Server(s, serverConfig)
-       ch, err := conn.readClientHello()
+       ch, err := conn.readClientHello(ctx)
        hs := serverHandshakeState{
                c:           conn,
+               ctx:         ctx,
                clientHello: ch,
        }
        if err == nil {
@@ -1418,9 +1422,11 @@ func TestSNIGivenOnFailure(t *testing.T) {
                c.Close()
        }()
        conn := Server(s, serverConfig)
-       ch, err := conn.readClientHello()
+       ctx := context.Background()
+       ch, err := conn.readClientHello(ctx)
        hs := serverHandshakeState{
                c:           conn,
+               ctx:         ctx,
                clientHello: ch,
        }
        if err == nil {
@@ -1673,3 +1679,43 @@ func TestMultipleCertificates(t *testing.T) {
                t.Errorf("expected RSA certificate, got %v", got)
        }
 }
+
+func TestServerHandshakeContextCancellation(t *testing.T) {
+       c, s := localPipe(t)
+       clientConfig := testConfig.Clone()
+       clientErr := make(chan error, 1)
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+       go func() {
+               defer close(clientErr)
+               defer c.Close()
+               clientHello := &clientHelloMsg{
+                       vers:               VersionTLS10,
+                       random:             make([]byte, 32),
+                       cipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
+                       compressionMethods: []uint8{compressionNone},
+               }
+               cli := Client(c, clientConfig)
+               _, err := cli.writeRecord(recordTypeHandshake, clientHello.marshal())
+               cancel()
+               clientErr <- err
+       }()
+       conn := Server(s, testConfig)
+       err := conn.HandshakeContext(ctx)
+       if err == nil {
+               t.Fatal("Server handshake did not error when the context was canceled")
+       }
+       if err != context.Canceled {
+               t.Errorf("Unexpected server handshake error: %v", err)
+       }
+       if err := <-clientErr; err != nil {
+               t.Errorf("Unexpected client error: %v", err)
+       }
+       if runtime.GOARCH == "wasm" {
+               t.Skip("conn.Close does not error as expected when called multiple times on WASM")
+       }
+       err = conn.Close()
+       if err == nil {
+               t.Error("Server connection was not closed when the context was canceled")
+       }
+}
index 92d55e0293a46e9d362fc7ec34db9fd990fe3932..25c37b92c54674cb3a293c561dcd3424f3e9db9b 100644 (file)
@@ -6,6 +6,7 @@ package tls
 
 import (
        "bytes"
+       "context"
        "crypto"
        "crypto/hmac"
        "crypto/rsa"
@@ -23,6 +24,7 @@ const maxClientPSKIdentities = 5
 
 type serverHandshakeStateTLS13 struct {
        c               *Conn
+       ctx             context.Context
        clientHello     *clientHelloMsg
        hello           *serverHelloMsg
        sentDummyCCS    bool
@@ -361,7 +363,7 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error {
                return c.sendAlert(alertMissingExtension)
        }
 
-       certificate, err := c.config.getCertificate(clientHelloInfo(c, hs.clientHello))
+       certificate, err := c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
        if err != nil {
                if err == errNoCertificates {
                        c.sendAlert(alertUnrecognizedName)
index 454aa0bbbc0b62746645ce246b9d7050a4a5332c..bf577cadeaad4eb2d4c683ee8e42352a384f19ca 100644 (file)
@@ -25,7 +25,6 @@ import (
        "io/ioutil"
        "net"
        "strings"
-       "time"
 )
 
 // Server returns a new TLS server side connection
@@ -116,28 +115,16 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
 }
 
 func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
-       // We want the Timeout and Deadline values from dialer to cover the
-       // whole process: TCP connection and TLS handshake. This means that we
-       // also need to start our own timers now.
-       timeout := netDialer.Timeout
-
-       if !netDialer.Deadline.IsZero() {
-               deadlineTimeout := time.Until(netDialer.Deadline)
-               if timeout == 0 || deadlineTimeout < timeout {
-                       timeout = deadlineTimeout
-               }
+       if netDialer.Timeout != 0 {
+               var cancel context.CancelFunc
+               ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
+               defer cancel()
        }
 
-       // hsErrCh is non-nil if we might not wait for Handshake to complete.
-       var hsErrCh chan error
-       if timeout != 0 || ctx.Done() != nil {
-               hsErrCh = make(chan error, 2)
-       }
-       if timeout != 0 {
-               timer := time.AfterFunc(timeout, func() {
-                       hsErrCh <- timeoutError{}
-               })
-               defer timer.Stop()
+       if !netDialer.Deadline.IsZero() {
+               var cancel context.CancelFunc
+               ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
+               defer cancel()
        }
 
        rawConn, err := netDialer.DialContext(ctx, network, addr)
@@ -164,34 +151,10 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf
        }
 
        conn := Client(rawConn, config)
-
-       if hsErrCh == nil {
-               err = conn.Handshake()
-       } else {
-               go func() {
-                       hsErrCh <- conn.Handshake()
-               }()
-
-               select {
-               case <-ctx.Done():
-                       err = ctx.Err()
-               case err = <-hsErrCh:
-                       if err != nil {
-                               // If the error was due to the context
-                               // closing, prefer the context's error, rather
-                               // than some random network teardown error.
-                               if e := ctx.Err(); e != nil {
-                                       err = e
-                               }
-                       }
-               }
-       }
-
-       if err != nil {
+       if err := conn.HandshakeContext(ctx); err != nil {
                rawConn.Close()
                return nil, err
        }
-
        return conn, nil
 }
 
index 4776d960e575883e95e1956c134f4be2c4834b05..6c7d2817051fa94621ec10598323deb8a2785ddc 100644 (file)
@@ -1831,7 +1831,7 @@ func (c *conn) serve(ctx context.Context) {
                if d := c.server.WriteTimeout; d != 0 {
                        c.rwc.SetWriteDeadline(time.Now().Add(d))
                }
-               if err := tlsConn.Handshake(); err != nil {
+               if err := tlsConn.HandshakeContext(ctx); err != nil {
                        // If the handshake failed due to the client not speaking
                        // TLS, assume they're speaking plaintext HTTP and write a
                        // 400 response on the TLS conn's underlying net.Conn.
index 29d7434f2a88997fae16ea3f9fc291f595f54aef..65ba6644154835b4cf35748d846af0b49ce1f5ae 100644 (file)
@@ -1502,7 +1502,7 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) {
 // Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS
 // tunnel, this function establishes a nested TLS session inside the encrypted channel.
 // The remote endpoint's name may be overridden by TLSClientConfig.ServerName.
-func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) error {
+func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace) error {
        // Initiate TLS and check remote host name against certificate.
        cfg := cloneTLSConfig(pconn.t.TLSClientConfig)
        if cfg.ServerName == "" {
@@ -1524,7 +1524,7 @@ func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) erro
                if trace != nil && trace.TLSHandshakeStart != nil {
                        trace.TLSHandshakeStart()
                }
-               err := tlsConn.Handshake()
+               err := tlsConn.HandshakeContext(ctx)
                if timer != nil {
                        timer.Stop()
                }
@@ -1580,7 +1580,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
                        if trace != nil && trace.TLSHandshakeStart != nil {
                                trace.TLSHandshakeStart()
                        }
-                       if err := tc.Handshake(); err != nil {
+                       if err := tc.HandshakeContext(ctx); err != nil {
                                go pconn.conn.Close()
                                if trace != nil && trace.TLSHandshakeDone != nil {
                                        trace.TLSHandshakeDone(tls.ConnectionState{}, err)
@@ -1604,7 +1604,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
                        if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil {
                                return nil, wrapErr(err)
                        }
-                       if err = pconn.addTLS(firstTLSHost, trace); err != nil {
+                       if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil {
                                return nil, wrapErr(err)
                        }
                }
@@ -1718,7 +1718,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
        }
 
        if cm.proxyURL != nil && cm.targetScheme == "https" {
-               if err := pconn.addTLS(cm.tlsHost(), trace); err != nil {
+               if err := pconn.addTLS(ctx, cm.tlsHost(), trace); err != nil {
                        return nil, err
                }
        }
index e69133e7868f4ddcd578eaf115d9556aa8bd3ee8..9086507d5766477dab7e6d5522c51853b00b808c 100644 (file)
@@ -3735,7 +3735,7 @@ func TestTransportDialTLSContext(t *testing.T) {
                if err != nil {
                        return nil, err
                }
-               return c, c.Handshake()
+               return c, c.HandshakeContext(ctx)
        }
 
        req, err := NewRequest("GET", ts.URL, nil)