From 1f8b2a69ec871c1e4c33b6df4b2127bbafd67495 Mon Sep 17 00:00:00 2001 From: Adam Langley Date: Fri, 28 Feb 2014 09:40:12 -0500 Subject: [PATCH] crypto/tls: add DialWithDialer. While reviewing uses of the lower-level Client API in code, I found that in many cases, code was using Client only because it needed a timeout on the connection. DialWithDialer allows a timeout (and other values) to be specified without resorting to the low-level API. LGTM=r R=golang-codereviews, r, bradfitz CC=golang-codereviews https://golang.org/cl/68920045 --- src/pkg/crypto/tls/tls.go | 81 +++++++++++++++++++++++++++------- src/pkg/crypto/tls/tls_test.go | 45 +++++++++++++++++++ 2 files changed, 111 insertions(+), 15 deletions(-) diff --git a/src/pkg/crypto/tls/tls.go b/src/pkg/crypto/tls/tls.go index 40156a0013..0b856c4e16 100644 --- a/src/pkg/crypto/tls/tls.go +++ b/src/pkg/crypto/tls/tls.go @@ -15,6 +15,7 @@ import ( "io/ioutil" "net" "strings" + "time" ) // Server returns a new TLS server side connection @@ -76,24 +77,51 @@ func Listen(network, laddr string, config *Config) (net.Listener, error) { return NewListener(l, config), nil } -// Dial connects to the given network address using net.Dial -// and then initiates a TLS handshake, returning the resulting -// TLS connection. -// Dial interprets a nil configuration as equivalent to -// the zero configuration; see the documentation of Config -// for the defaults. -func Dial(network, addr string, config *Config) (*Conn, error) { - raddr := addr - c, err := net.Dial(network, raddr) +type timeoutError struct{} + +func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +// DialWithDialer connects to the given network address using dialer.Dial and +// then initiates a TLS handshake, returning the resulting TLS connection. Any +// timeout or deadline given in the dialer apply to connection and TLS +// handshake as a whole. +// +// DialWithDialer interprets a nil configuration as equivalent to the zero +// configuration; see the documentation of Config for the defaults. +func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { + // We want the Timeout and Deadline values from dialer to cover the + // whole process: TCP connection and TLS handshake. This means that we + // also need to start our own timers now. + timeout := dialer.Timeout + + if !dialer.Deadline.IsZero() { + deadlineTimeout := dialer.Deadline.Sub(time.Now()) + if timeout == 0 || deadlineTimeout < timeout { + timeout = deadlineTimeout + } + } + + var errChannel chan error + + if timeout != 0 { + errChannel = make(chan error, 2) + time.AfterFunc(timeout, func() { + errChannel <- timeoutError{} + }) + } + + rawConn, err := dialer.Dial(network, addr) if err != nil { return nil, err } - colonPos := strings.LastIndex(raddr, ":") + colonPos := strings.LastIndex(addr, ":") if colonPos == -1 { - colonPos = len(raddr) + colonPos = len(addr) } - hostname := raddr[:colonPos] + hostname := addr[:colonPos] if config == nil { config = defaultConfig() @@ -106,14 +134,37 @@ func Dial(network, addr string, config *Config) (*Conn, error) { c.ServerName = hostname config = &c } - conn := Client(c, config) - if err = conn.Handshake(); err != nil { - c.Close() + + conn := Client(rawConn, config) + + if timeout == 0 { + err = conn.Handshake() + } else { + go func() { + errChannel <- conn.Handshake() + }() + + err = <-errChannel + } + + if err != nil { + rawConn.Close() return nil, err } + return conn, nil } +// Dial connects to the given network address using net.Dial +// and then initiates a TLS handshake, returning the resulting +// TLS connection. +// Dial interprets a nil configuration as equivalent to +// the zero configuration; see the documentation of Config +// for the defaults. +func Dial(network, addr string, config *Config) (*Conn, error) { + return DialWithDialer(new(net.Dialer), network, addr, config) +} + // LoadX509KeyPair reads and parses a public/private key pair from a pair of // files. The files must contain PEM encoded data. func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) { diff --git a/src/pkg/crypto/tls/tls_test.go b/src/pkg/crypto/tls/tls_test.go index 38229014cd..5b12610d0a 100644 --- a/src/pkg/crypto/tls/tls_test.go +++ b/src/pkg/crypto/tls/tls_test.go @@ -5,7 +5,10 @@ package tls import ( + "net" + "strings" "testing" + "time" ) var rsaCertPEM = `-----BEGIN CERTIFICATE----- @@ -105,3 +108,45 @@ func TestX509MixedKeyPair(t *testing.T) { t.Error("Load of ECDSA certificate succeeded with RSA private key") } } + +func TestDialTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + listener, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + t.Fatal(err) + } + + addr := listener.Addr().String() + defer listener.Close() + + complete := make(chan bool) + defer close(complete) + + go func() { + conn, err := listener.Accept() + if err != nil { + t.Error(err) + return + } + <-complete + conn.Close() + }() + + dialer := &net.Dialer{ + Timeout: 10 * time.Millisecond, + } + + if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil { + t.Fatal("DialWithTimeout completed successfully") + } + + if !strings.Contains(err.Error(), "timed out") { + t.Errorf("resulting error not a timeout: %s", err) + } +} -- 2.48.1