From: Brad Fitzpatrick Date: Tue, 25 Feb 2014 16:08:15 +0000 (-0800) Subject: net/http: add Transport.TLSHandshakeTimeout; set it by default X-Git-Tag: go1.3beta1~580 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=fd4b4b56c4a1fd3426fc9ab4c36ec1b270089d29;p=gostls13.git net/http: add Transport.TLSHandshakeTimeout; set it by default Update #3362 LGTM=agl R=agl CC=golang-codereviews https://golang.org/cl/68150045 --- diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go index cdad339a03..1a7b459fe1 100644 --- a/src/pkg/net/http/transport.go +++ b/src/pkg/net/http/transport.go @@ -36,6 +36,7 @@ var DefaultTransport RoundTripper = &Transport{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, }).Dial, + TLSHandshakeTimeout: 10 * time.Second, } // DefaultMaxIdleConnsPerHost is the default value of Transport's @@ -69,6 +70,10 @@ type Transport struct { // tls.Client. If nil, the default configuration is used. TLSClientConfig *tls.Config + // TLSHandshakeTimeout specifies the maximum amount of time waiting to + // wait for a TLS handshake. Zero means no timeout. + TLSHandshakeTimeout time.Duration + // DisableKeepAlives, if true, prevents re-use of TCP connections // between different HTTP requests. DisableKeepAlives bool @@ -542,16 +547,33 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { cfg = &clone } } - conn = tls.Client(conn, cfg) - if err = conn.(*tls.Conn).Handshake(); err != nil { + plainConn := conn + tlsConn := tls.Client(plainConn, cfg) + errc := make(chan error, 2) + var timer *time.Timer // for canceling TLS handshake + if d := t.TLSHandshakeTimeout; d != 0 { + timer = time.AfterFunc(d, func() { + errc <- tlsHandshakeTimeoutError{} + }) + } + go func() { + err := tlsConn.Handshake() + if timer != nil { + timer.Stop() + } + errc <- err + }() + if err := <-errc; err != nil { + plainConn.Close() return nil, err } if !cfg.InsecureSkipVerify { - if err = conn.(*tls.Conn).VerifyHostname(cfg.ServerName); err != nil { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + plainConn.Close() return nil, err } } - pconn.conn = conn + pconn.conn = tlsConn } pconn.br = bufio.NewReader(pconn.conn) @@ -1084,3 +1106,9 @@ type readerAndCloser struct { io.Reader io.Closer } + +type tlsHandshakeTimeoutError struct{} + +func (tlsHandshakeTimeoutError) Timeout() bool { return true } +func (tlsHandshakeTimeoutError) Temporary() bool { return true } +func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go index 2678d71b1d..510679e53b 100644 --- a/src/pkg/net/http/transport_test.go +++ b/src/pkg/net/http/transport_test.go @@ -1722,6 +1722,73 @@ func TestTransportClosesRequestBody(t *testing.T) { } } +func TestTransportTLSHandshakeTimeout(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping in short mode") + } + ln := newLocalListener(t) + defer ln.Close() + testdonec := make(chan struct{}) + defer close(testdonec) + + go func() { + c, err := ln.Accept() + if err != nil { + t.Error(err) + return + } + <-testdonec + c.Close() + }() + + getdonec := make(chan struct{}) + go func() { + defer close(getdonec) + tr := &Transport{ + Dial: func(_, _ string) (net.Conn, error) { + return net.Dial("tcp", ln.Addr().String()) + }, + TLSHandshakeTimeout: 250 * time.Millisecond, + } + cl := &Client{Transport: tr} + _, err := cl.Get("https://dummy.tld/") + if err == nil { + t.Fatal("expected error") + } + ue, ok := err.(*url.Error) + if !ok { + t.Fatalf("expected url.Error; got %#v", err) + } + ne, ok := ue.Err.(net.Error) + if !ok { + t.Fatalf("expected net.Error; got %#v", err) + } + if !ne.Timeout() { + t.Error("expected timeout error; got %v", err) + } + if !strings.Contains(err.Error(), "handshake timeout") { + t.Error("expected 'handshake timeout' in error; got %v", err) + } + }() + select { + case <-getdonec: + case <-time.After(5 * time.Second): + t.Error("test timeout; TLS handshake hung?") + } +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + t.Fatal(err) + } + return ln +} + type countCloseReader struct { n *int io.Reader