]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: add Dialer
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 15 Jan 2020 19:27:32 +0000 (19:27 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 20 Apr 2020 20:33:36 +0000 (20:33 +0000)
Fixes #18482

Change-Id: I99d65dc5d824c00093ea61e7445fc121314af87f
Reviewed-on: https://go-review.googlesource.com/c/go/+/214977
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
doc/go1.15.html
src/crypto/tls/conn.go
src/crypto/tls/tls.go
src/crypto/tls/tls_test.go
src/go/build/deps_test.go

index 7da012f46c5aa97c22bbde9c699b2a526cfdfcdb..bb5628cb19a10dd1d2fbed4d6459f18a87c54e95 100644 (file)
@@ -133,6 +133,18 @@ TODO
 TODO
 </p>
 
+<dl id="crypto/tls"><dt><a href="/crypto/tls/">crypto/tls</a></dt>
+  <dd>
+    <p><!-- CL 214977 -->
+      The new
+      <a href="/pkg/crypto/tls/#Dialer"><code>Dialer</code></a>
+      type and its
+      <a href="/pkg/crypto/tls/#Dialer.DialContext"><code>DialContext</code></a>
+      method permits using a context to both connect and handshake with a TLS server.
+    </p>
+  </dd>
+</dl>
+
 <dl id="flag"><dt><a href="/pkg/flag/">flag</a></dt>
   <dd>
     <p><!-- CL 221427 -->
index 6bda73e0857f8df013212f5dfe380e415cc50c99..bf2111cb97437ec0e69df18a6feb806ebcb2728b 100644 (file)
@@ -1334,8 +1334,12 @@ func (c *Conn) closeNotify() error {
 
 // Handshake runs the client or server handshake
 // protocol if it has not yet been run.
-// Most uses of this package need not call Handshake
-// explicitly: the first Read or Write will call it automatically.
+//
+// Most uses of this package need not call Handshake explicitly: the
+// first Read or Write will call it automatically.
+//
+// For control over canceling or setting a timeout on a handshake, use
+// the Dialer's DialContext method.
 func (c *Conn) Handshake() error {
        c.handshakeMutex.Lock()
        defer c.handshakeMutex.Unlock()
index d98abdaea15a959df5e1e63db7f238365b5cb572..36d98d39eb60b986a7873554a7180254933be7ab 100644 (file)
@@ -13,6 +13,7 @@ package tls
 
 import (
        "bytes"
+       "context"
        "crypto"
        "crypto/ecdsa"
        "crypto/ed25519"
@@ -111,29 +112,35 @@ func (timeoutError) Temporary() bool { return true }
 // DialWithDialer interprets a nil configuration as equivalent to the zero
 // configuration; see the documentation of Config for the defaults.
 func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
+       return dial(context.Background(), dialer, network, addr, config)
+}
+
+func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
        // We want the Timeout and Deadline values from dialer to cover the
        // whole process: TCP connection and TLS handshake. This means that we
        // also need to start our own timers now.
-       timeout := dialer.Timeout
+       timeout := netDialer.Timeout
 
-       if !dialer.Deadline.IsZero() {
-               deadlineTimeout := time.Until(dialer.Deadline)
+       if !netDialer.Deadline.IsZero() {
+               deadlineTimeout := time.Until(netDialer.Deadline)
                if timeout == 0 || deadlineTimeout < timeout {
                        timeout = deadlineTimeout
                }
        }
 
-       var errChannel chan error
-
+       // hsErrCh is non-nil if we might not wait for Handshake to complete.
+       var hsErrCh chan error
+       if timeout != 0 || ctx.Done() != nil {
+               hsErrCh = make(chan error, 2)
+       }
        if timeout != 0 {
-               errChannel = make(chan error, 2)
                timer := time.AfterFunc(timeout, func() {
-                       errChannel <- timeoutError{}
+                       hsErrCh <- timeoutError{}
                })
                defer timer.Stop()
        }
 
-       rawConn, err := dialer.Dial(network, addr)
+       rawConn, err := netDialer.DialContext(ctx, network, addr)
        if err != nil {
                return nil, err
        }
@@ -158,14 +165,26 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
 
        conn := Client(rawConn, config)
 
-       if timeout == 0 {
+       if hsErrCh == nil {
                err = conn.Handshake()
        } else {
                go func() {
-                       errChannel <- conn.Handshake()
+                       hsErrCh <- conn.Handshake()
                }()
 
-               err = <-errChannel
+               select {
+               case <-ctx.Done():
+                       err = ctx.Err()
+               case err = <-hsErrCh:
+                       if err != nil {
+                               // If the error was due to the context
+                               // closing, prefer the context's error, rather
+                               // than some random network teardown error.
+                               if e := ctx.Err(); e != nil {
+                                       err = e
+                               }
+                       }
+               }
        }
 
        if err != nil {
@@ -186,6 +205,54 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
        return DialWithDialer(new(net.Dialer), network, addr, config)
 }
 
+// Dialer dials TLS connections given a configuration and a Dialer for the
+// underlying connection.
+type Dialer struct {
+       // NetDialer is the optional dialer to use for the TLS connections'
+       // underlying TCP connections.
+       // A nil NetDialer is equivalent to the net.Dialer zero value.
+       NetDialer *net.Dialer
+
+       // Config is the TLS configuration to use for new connections.
+       // A nil configuration is equivalent to the zero
+       // configuration; see the documentation of Config for the
+       // defaults.
+       Config *Config
+}
+
+// Dial connects to the given network address and initiates a TLS
+// handshake, returning the resulting TLS connection.
+//
+// The returned Conn, if any, will always be of type *Conn.
+func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
+       return d.DialContext(context.Background(), network, addr)
+}
+
+func (d *Dialer) netDialer() *net.Dialer {
+       if d.NetDialer != nil {
+               return d.NetDialer
+       }
+       return new(net.Dialer)
+}
+
+// Dial connects to the given network address and initiates a TLS
+// handshake, returning the resulting TLS connection.
+//
+// The provided Context must be non-nil. If the context expires before
+// the connection is complete, an error is returned. Once successfully
+// connected, any expiration of the context will not affect the
+// connection.
+//
+// The returned Conn, if any, will always be of type *Conn.
+func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
+       c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
+       if err != nil {
+               // Don't return c (a typed nil) in an interface.
+               return nil, err
+       }
+       return c, nil
+}
+
 // LoadX509KeyPair reads and parses a public/private key pair from a pair
 // of files. The files must contain PEM encoded data. The certificate file
 // may contain intermediate certificates following the leaf certificate to
index 42fd5e1b8c6ad783172c7d8f5a6e2d649ee2620d..85005d4950ae17e2687c228572df1d3ae1783abf 100644 (file)
@@ -6,6 +6,7 @@ package tls
 
 import (
        "bytes"
+       "context"
        "crypto"
        "crypto/x509"
        "encoding/json"
@@ -272,6 +273,47 @@ func TestDeadlineOnWrite(t *testing.T) {
        }
 }
 
+type readerFunc func([]byte) (int, error)
+
+func (f readerFunc) Read(b []byte) (int, error) { return f(b) }
+
+// TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake.
+// (The other cases are all handled by the existing dial tests in this package, which
+// all also flow through the same code shared code paths)
+func TestDialer(t *testing.T) {
+       ln := newLocalListener(t)
+       defer ln.Close()
+
+       unblockServer := make(chan struct{}) // close-only
+       defer close(unblockServer)
+       go func() {
+               conn, err := ln.Accept()
+               if err != nil {
+                       return
+               }
+               defer conn.Close()
+               <-unblockServer
+       }()
+
+       ctx, cancel := context.WithCancel(context.Background())
+       d := Dialer{Config: &Config{
+               Rand: readerFunc(func(b []byte) (n int, err error) {
+                       // By the time crypto/tls wants randomness, that means it has a TCP
+                       // connection, so we're past the Dialer's dial and now blocked
+                       // in a handshake. Cancel our context and see if we get unstuck.
+                       // (Our TCP listener above never reads or writes, so the Handshake
+                       // would otherwise be stuck forever)
+                       cancel()
+                       return len(b), nil
+               }),
+               ServerName: "foo",
+       }}
+       _, err := d.DialContext(ctx, "tcp", ln.Addr().String())
+       if err != context.Canceled {
+               t.Errorf("err = %v; want context.Canceled", err)
+       }
+}
+
 func isTimeoutError(err error) bool {
        if ne, ok := err.(net.Error); ok {
                return ne.Timeout()
index 24f79cabb303d1eef47c594d2e405c500cae1596..fad165cf603f5186a0b02bfc6328dad17d8b2903 100644 (file)
@@ -408,7 +408,7 @@ var pkgDeps = map[string][]string{
        // SSL/TLS.
        "crypto/tls": {
                "L4", "CRYPTO-MATH", "OS", "golang.org/x/crypto/cryptobyte", "golang.org/x/crypto/hkdf",
-               "container/list", "crypto/x509", "encoding/pem", "net", "syscall", "crypto/ed25519",
+               "container/list", "context", "crypto/x509", "encoding/pem", "net", "syscall", "crypto/ed25519",
        },
        "crypto/x509": {
                "L4", "CRYPTO-MATH", "OS", "CGO", "crypto/ed25519",