From: Brad Fitzpatrick Date: Wed, 15 Jan 2020 19:27:32 +0000 (+0000) Subject: crypto/tls: add Dialer X-Git-Tag: go1.15beta1~486 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=40a144b94f14c3f3fbe06e097a236a5543ada57f;p=gostls13.git crypto/tls: add Dialer Fixes #18482 Change-Id: I99d65dc5d824c00093ea61e7445fc121314af87f Reviewed-on: https://go-review.googlesource.com/c/go/+/214977 Run-TryBot: Brad Fitzpatrick Reviewed-by: Ian Lance Taylor --- diff --git a/doc/go1.15.html b/doc/go1.15.html index 7da012f46c..bb5628cb19 100644 --- a/doc/go1.15.html +++ b/doc/go1.15.html @@ -133,6 +133,18 @@ TODO TODO

+
crypto/tls
+
+

+ The new + Dialer + type and its + DialContext + method permits using a context to both connect and handshake with a TLS server. +

+
+
+
flag

diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go index 6bda73e085..bf2111cb97 100644 --- a/src/crypto/tls/conn.go +++ b/src/crypto/tls/conn.go @@ -1334,8 +1334,12 @@ func (c *Conn) closeNotify() error { // Handshake runs the client or server handshake // protocol if it has not yet been run. -// Most uses of this package need not call Handshake -// explicitly: the first Read or Write will call it automatically. +// +// Most uses of this package need not call Handshake explicitly: the +// first Read or Write will call it automatically. +// +// For control over canceling or setting a timeout on a handshake, use +// the Dialer's DialContext method. func (c *Conn) Handshake() error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() diff --git a/src/crypto/tls/tls.go b/src/crypto/tls/tls.go index d98abdaea1..36d98d39eb 100644 --- a/src/crypto/tls/tls.go +++ b/src/crypto/tls/tls.go @@ -13,6 +13,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -111,29 +112,35 @@ func (timeoutError) Temporary() bool { return true } // DialWithDialer interprets a nil configuration as equivalent to the zero // configuration; see the documentation of Config for the defaults. func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { + return dial(context.Background(), dialer, network, addr, config) +} + +func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { // We want the Timeout and Deadline values from dialer to cover the // whole process: TCP connection and TLS handshake. This means that we // also need to start our own timers now. - timeout := dialer.Timeout + timeout := netDialer.Timeout - if !dialer.Deadline.IsZero() { - deadlineTimeout := time.Until(dialer.Deadline) + if !netDialer.Deadline.IsZero() { + deadlineTimeout := time.Until(netDialer.Deadline) if timeout == 0 || deadlineTimeout < timeout { timeout = deadlineTimeout } } - var errChannel chan error - + // hsErrCh is non-nil if we might not wait for Handshake to complete. + var hsErrCh chan error + if timeout != 0 || ctx.Done() != nil { + hsErrCh = make(chan error, 2) + } if timeout != 0 { - errChannel = make(chan error, 2) timer := time.AfterFunc(timeout, func() { - errChannel <- timeoutError{} + hsErrCh <- timeoutError{} }) defer timer.Stop() } - rawConn, err := dialer.Dial(network, addr) + rawConn, err := netDialer.DialContext(ctx, network, addr) if err != nil { return nil, err } @@ -158,14 +165,26 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (* conn := Client(rawConn, config) - if timeout == 0 { + if hsErrCh == nil { err = conn.Handshake() } else { go func() { - errChannel <- conn.Handshake() + hsErrCh <- conn.Handshake() }() - err = <-errChannel + select { + case <-ctx.Done(): + err = ctx.Err() + case err = <-hsErrCh: + if err != nil { + // If the error was due to the context + // closing, prefer the context's error, rather + // than some random network teardown error. + if e := ctx.Err(); e != nil { + err = e + } + } + } } if err != nil { @@ -186,6 +205,54 @@ func Dial(network, addr string, config *Config) (*Conn, error) { return DialWithDialer(new(net.Dialer), network, addr, config) } +// Dialer dials TLS connections given a configuration and a Dialer for the +// underlying connection. +type Dialer struct { + // NetDialer is the optional dialer to use for the TLS connections' + // underlying TCP connections. + // A nil NetDialer is equivalent to the net.Dialer zero value. + NetDialer *net.Dialer + + // Config is the TLS configuration to use for new connections. + // A nil configuration is equivalent to the zero + // configuration; see the documentation of Config for the + // defaults. + Config *Config +} + +// Dial connects to the given network address and initiates a TLS +// handshake, returning the resulting TLS connection. +// +// The returned Conn, if any, will always be of type *Conn. +func (d *Dialer) Dial(network, addr string) (net.Conn, error) { + return d.DialContext(context.Background(), network, addr) +} + +func (d *Dialer) netDialer() *net.Dialer { + if d.NetDialer != nil { + return d.NetDialer + } + return new(net.Dialer) +} + +// Dial connects to the given network address and initiates a TLS +// handshake, returning the resulting TLS connection. +// +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +// +// The returned Conn, if any, will always be of type *Conn. +func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + c, err := dial(ctx, d.netDialer(), network, addr, d.Config) + if err != nil { + // Don't return c (a typed nil) in an interface. + return nil, err + } + return c, nil +} + // LoadX509KeyPair reads and parses a public/private key pair from a pair // of files. The files must contain PEM encoded data. The certificate file // may contain intermediate certificates following the leaf certificate to diff --git a/src/crypto/tls/tls_test.go b/src/crypto/tls/tls_test.go index 42fd5e1b8c..85005d4950 100644 --- a/src/crypto/tls/tls_test.go +++ b/src/crypto/tls/tls_test.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/x509" "encoding/json" @@ -272,6 +273,47 @@ func TestDeadlineOnWrite(t *testing.T) { } } +type readerFunc func([]byte) (int, error) + +func (f readerFunc) Read(b []byte) (int, error) { return f(b) } + +// TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake. +// (The other cases are all handled by the existing dial tests in this package, which +// all also flow through the same code shared code paths) +func TestDialer(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + unblockServer := make(chan struct{}) // close-only + defer close(unblockServer) + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + <-unblockServer + }() + + ctx, cancel := context.WithCancel(context.Background()) + d := Dialer{Config: &Config{ + Rand: readerFunc(func(b []byte) (n int, err error) { + // By the time crypto/tls wants randomness, that means it has a TCP + // connection, so we're past the Dialer's dial and now blocked + // in a handshake. Cancel our context and see if we get unstuck. + // (Our TCP listener above never reads or writes, so the Handshake + // would otherwise be stuck forever) + cancel() + return len(b), nil + }), + ServerName: "foo", + }} + _, err := d.DialContext(ctx, "tcp", ln.Addr().String()) + if err != context.Canceled { + t.Errorf("err = %v; want context.Canceled", err) + } +} + func isTimeoutError(err error) bool { if ne, ok := err.(net.Error); ok { return ne.Timeout() diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go index 24f79cabb3..fad165cf60 100644 --- a/src/go/build/deps_test.go +++ b/src/go/build/deps_test.go @@ -408,7 +408,7 @@ var pkgDeps = map[string][]string{ // SSL/TLS. "crypto/tls": { "L4", "CRYPTO-MATH", "OS", "golang.org/x/crypto/cryptobyte", "golang.org/x/crypto/hkdf", - "container/list", "crypto/x509", "encoding/pem", "net", "syscall", "crypto/ed25519", + "container/list", "context", "crypto/x509", "encoding/pem", "net", "syscall", "crypto/ed25519", }, "crypto/x509": { "L4", "CRYPTO-MATH", "OS", "CGO", "crypto/ed25519",