]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: add DialWithDialer.
authorAdam Langley <agl@golang.org>
Fri, 28 Feb 2014 14:40:12 +0000 (09:40 -0500)
committerAdam Langley <agl@golang.org>
Fri, 28 Feb 2014 14:40:12 +0000 (09:40 -0500)
While reviewing uses of the lower-level Client API in code, I found
that in many cases, code was using Client only because it needed a
timeout on the connection. DialWithDialer allows a timeout (and
 other values) to be specified without resorting to the low-level API.

LGTM=r
R=golang-codereviews, r, bradfitz
CC=golang-codereviews
https://golang.org/cl/68920045

src/pkg/crypto/tls/tls.go
src/pkg/crypto/tls/tls_test.go

index 40156a0013bc9fb982d2cd143d57c2759232f6e7..0b856c4e16befceeadb2e0c17293e5555f387fcf 100644 (file)
@@ -15,6 +15,7 @@ import (
        "io/ioutil"
        "net"
        "strings"
+       "time"
 )
 
 // Server returns a new TLS server side connection
@@ -76,24 +77,51 @@ func Listen(network, laddr string, config *Config) (net.Listener, error) {
        return NewListener(l, config), nil
 }
 
-// Dial connects to the given network address using net.Dial
-// and then initiates a TLS handshake, returning the resulting
-// TLS connection.
-// Dial interprets a nil configuration as equivalent to
-// the zero configuration; see the documentation of Config
-// for the defaults.
-func Dial(network, addr string, config *Config) (*Conn, error) {
-       raddr := addr
-       c, err := net.Dial(network, raddr)
+type timeoutError struct{}
+
+func (timeoutError) Error() string   { return "tls: DialWithDialer timed out" }
+func (timeoutError) Timeout() bool   { return true }
+func (timeoutError) Temporary() bool { return true }
+
+// DialWithDialer connects to the given network address using dialer.Dial and
+// then initiates a TLS handshake, returning the resulting TLS connection. Any
+// timeout or deadline given in the dialer apply to connection and TLS
+// handshake as a whole.
+//
+// DialWithDialer interprets a nil configuration as equivalent to the zero
+// configuration; see the documentation of Config for the defaults.
+func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
+       // We want the Timeout and Deadline values from dialer to cover the
+       // whole process: TCP connection and TLS handshake. This means that we
+       // also need to start our own timers now.
+       timeout := dialer.Timeout
+
+       if !dialer.Deadline.IsZero() {
+               deadlineTimeout := dialer.Deadline.Sub(time.Now())
+               if timeout == 0 || deadlineTimeout < timeout {
+                       timeout = deadlineTimeout
+               }
+       }
+
+       var errChannel chan error
+
+       if timeout != 0 {
+               errChannel = make(chan error, 2)
+               time.AfterFunc(timeout, func() {
+                       errChannel <- timeoutError{}
+               })
+       }
+
+       rawConn, err := dialer.Dial(network, addr)
        if err != nil {
                return nil, err
        }
 
-       colonPos := strings.LastIndex(raddr, ":")
+       colonPos := strings.LastIndex(addr, ":")
        if colonPos == -1 {
-               colonPos = len(raddr)
+               colonPos = len(addr)
        }
-       hostname := raddr[:colonPos]
+       hostname := addr[:colonPos]
 
        if config == nil {
                config = defaultConfig()
@@ -106,14 +134,37 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
                c.ServerName = hostname
                config = &c
        }
-       conn := Client(c, config)
-       if err = conn.Handshake(); err != nil {
-               c.Close()
+
+       conn := Client(rawConn, config)
+
+       if timeout == 0 {
+               err = conn.Handshake()
+       } else {
+               go func() {
+                       errChannel <- conn.Handshake()
+               }()
+
+               err = <-errChannel
+       }
+
+       if err != nil {
+               rawConn.Close()
                return nil, err
        }
+
        return conn, nil
 }
 
+// Dial connects to the given network address using net.Dial
+// and then initiates a TLS handshake, returning the resulting
+// TLS connection.
+// Dial interprets a nil configuration as equivalent to
+// the zero configuration; see the documentation of Config
+// for the defaults.
+func Dial(network, addr string, config *Config) (*Conn, error) {
+       return DialWithDialer(new(net.Dialer), network, addr, config)
+}
+
 // LoadX509KeyPair reads and parses a public/private key pair from a pair of
 // files. The files must contain PEM encoded data.
 func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) {
index 38229014cd662bba804577ed638acd139274999b..5b12610d0aa23dc140dc574c255f4e5344b75772 100644 (file)
@@ -5,7 +5,10 @@
 package tls
 
 import (
+       "net"
+       "strings"
        "testing"
+       "time"
 )
 
 var rsaCertPEM = `-----BEGIN CERTIFICATE-----
@@ -105,3 +108,45 @@ func TestX509MixedKeyPair(t *testing.T) {
                t.Error("Load of ECDSA certificate succeeded with RSA private key")
        }
 }
+
+func TestDialTimeout(t *testing.T) {
+       if testing.Short() {
+               t.Skip("skipping in short mode")
+       }
+
+       listener, err := net.Listen("tcp", "127.0.0.1:0")
+       if err != nil {
+               listener, err = net.Listen("tcp6", "[::1]:0")
+       }
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       addr := listener.Addr().String()
+       defer listener.Close()
+
+       complete := make(chan bool)
+       defer close(complete)
+
+       go func() {
+               conn, err := listener.Accept()
+               if err != nil {
+                       t.Error(err)
+                       return
+               }
+               <-complete
+               conn.Close()
+       }()
+
+       dialer := &net.Dialer{
+               Timeout: 10 * time.Millisecond,
+       }
+
+       if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil {
+               t.Fatal("DialWithTimeout completed successfully")
+       }
+
+       if !strings.Contains(err.Error(), "timed out") {
+               t.Errorf("resulting error not a timeout: %s", err)
+       }
+}