]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Transport.TLSHandshakeTimeout; set it by default
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 25 Feb 2014 16:08:15 +0000 (08:08 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 25 Feb 2014 16:08:15 +0000 (08:08 -0800)
Update #3362

LGTM=agl
R=agl
CC=golang-codereviews
https://golang.org/cl/68150045

src/pkg/net/http/transport.go
src/pkg/net/http/transport_test.go

index cdad339a03c58fef84a81601d9ac6816a363c2b6..1a7b459fe1bbc18050dba434f567d3a4c080ee43 100644 (file)
@@ -36,6 +36,7 @@ var DefaultTransport RoundTripper = &Transport{
                Timeout:   30 * time.Second,
                KeepAlive: 30 * time.Second,
        }).Dial,
+       TLSHandshakeTimeout: 10 * time.Second,
 }
 
 // DefaultMaxIdleConnsPerHost is the default value of Transport's
@@ -69,6 +70,10 @@ type Transport struct {
        // tls.Client. If nil, the default configuration is used.
        TLSClientConfig *tls.Config
 
+       // TLSHandshakeTimeout specifies the maximum amount of time waiting to
+       // wait for a TLS handshake. Zero means no timeout.
+       TLSHandshakeTimeout time.Duration
+
        // DisableKeepAlives, if true, prevents re-use of TCP connections
        // between different HTTP requests.
        DisableKeepAlives bool
@@ -542,16 +547,33 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
                                cfg = &clone
                        }
                }
-               conn = tls.Client(conn, cfg)
-               if err = conn.(*tls.Conn).Handshake(); err != nil {
+               plainConn := conn
+               tlsConn := tls.Client(plainConn, cfg)
+               errc := make(chan error, 2)
+               var timer *time.Timer // for canceling TLS handshake
+               if d := t.TLSHandshakeTimeout; d != 0 {
+                       timer = time.AfterFunc(d, func() {
+                               errc <- tlsHandshakeTimeoutError{}
+                       })
+               }
+               go func() {
+                       err := tlsConn.Handshake()
+                       if timer != nil {
+                               timer.Stop()
+                       }
+                       errc <- err
+               }()
+               if err := <-errc; err != nil {
+                       plainConn.Close()
                        return nil, err
                }
                if !cfg.InsecureSkipVerify {
-                       if err = conn.(*tls.Conn).VerifyHostname(cfg.ServerName); err != nil {
+                       if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
+                               plainConn.Close()
                                return nil, err
                        }
                }
-               pconn.conn = conn
+               pconn.conn = tlsConn
        }
 
        pconn.br = bufio.NewReader(pconn.conn)
@@ -1084,3 +1106,9 @@ type readerAndCloser struct {
        io.Reader
        io.Closer
 }
+
+type tlsHandshakeTimeoutError struct{}
+
+func (tlsHandshakeTimeoutError) Timeout() bool   { return true }
+func (tlsHandshakeTimeoutError) Temporary() bool { return true }
+func (tlsHandshakeTimeoutError) Error() string   { return "net/http: TLS handshake timeout" }
index 2678d71b1dec19b0e9257ecadd69c61d3648c3ff..510679e53b24ce211cfd114f48650d97d0040e8d 100644 (file)
@@ -1722,6 +1722,73 @@ func TestTransportClosesRequestBody(t *testing.T) {
        }
 }
 
+func TestTransportTLSHandshakeTimeout(t *testing.T) {
+       defer afterTest(t)
+       if testing.Short() {
+               t.Skip("skipping in short mode")
+       }
+       ln := newLocalListener(t)
+       defer ln.Close()
+       testdonec := make(chan struct{})
+       defer close(testdonec)
+
+       go func() {
+               c, err := ln.Accept()
+               if err != nil {
+                       t.Error(err)
+                       return
+               }
+               <-testdonec
+               c.Close()
+       }()
+
+       getdonec := make(chan struct{})
+       go func() {
+               defer close(getdonec)
+               tr := &Transport{
+                       Dial: func(_, _ string) (net.Conn, error) {
+                               return net.Dial("tcp", ln.Addr().String())
+                       },
+                       TLSHandshakeTimeout: 250 * time.Millisecond,
+               }
+               cl := &Client{Transport: tr}
+               _, err := cl.Get("https://dummy.tld/")
+               if err == nil {
+                       t.Fatal("expected error")
+               }
+               ue, ok := err.(*url.Error)
+               if !ok {
+                       t.Fatalf("expected url.Error; got %#v", err)
+               }
+               ne, ok := ue.Err.(net.Error)
+               if !ok {
+                       t.Fatalf("expected net.Error; got %#v", err)
+               }
+               if !ne.Timeout() {
+                       t.Error("expected timeout error; got %v", err)
+               }
+               if !strings.Contains(err.Error(), "handshake timeout") {
+                       t.Error("expected 'handshake timeout' in error; got %v", err)
+               }
+       }()
+       select {
+       case <-getdonec:
+       case <-time.After(5 * time.Second):
+               t.Error("test timeout; TLS handshake hung?")
+       }
+}
+
+func newLocalListener(t *testing.T) net.Listener {
+       ln, err := net.Listen("tcp", "127.0.0.1:0")
+       if err != nil {
+               ln, err = net.Listen("tcp6", "[::1]:0")
+       }
+       if err != nil {
+               t.Fatal(err)
+       }
+       return ln
+}
+
 type countCloseReader struct {
        n *int
        io.Reader