Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
+ TLSHandshakeTimeout: 10 * time.Second,
}
// DefaultMaxIdleConnsPerHost is the default value of Transport's
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
+ // TLSHandshakeTimeout specifies the maximum amount of time waiting to
+ // wait for a TLS handshake. Zero means no timeout.
+ TLSHandshakeTimeout time.Duration
+
// DisableKeepAlives, if true, prevents re-use of TCP connections
// between different HTTP requests.
DisableKeepAlives bool
cfg = &clone
}
}
- conn = tls.Client(conn, cfg)
- if err = conn.(*tls.Conn).Handshake(); err != nil {
+ plainConn := conn
+ tlsConn := tls.Client(plainConn, cfg)
+ errc := make(chan error, 2)
+ var timer *time.Timer // for canceling TLS handshake
+ if d := t.TLSHandshakeTimeout; d != 0 {
+ timer = time.AfterFunc(d, func() {
+ errc <- tlsHandshakeTimeoutError{}
+ })
+ }
+ go func() {
+ err := tlsConn.Handshake()
+ if timer != nil {
+ timer.Stop()
+ }
+ errc <- err
+ }()
+ if err := <-errc; err != nil {
+ plainConn.Close()
return nil, err
}
if !cfg.InsecureSkipVerify {
- if err = conn.(*tls.Conn).VerifyHostname(cfg.ServerName); err != nil {
+ if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
+ plainConn.Close()
return nil, err
}
}
- pconn.conn = conn
+ pconn.conn = tlsConn
}
pconn.br = bufio.NewReader(pconn.conn)
io.Reader
io.Closer
}
+
+type tlsHandshakeTimeoutError struct{}
+
+func (tlsHandshakeTimeoutError) Timeout() bool { return true }
+func (tlsHandshakeTimeoutError) Temporary() bool { return true }
+func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
}
}
+func TestTransportTLSHandshakeTimeout(t *testing.T) {
+ defer afterTest(t)
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ ln := newLocalListener(t)
+ defer ln.Close()
+ testdonec := make(chan struct{})
+ defer close(testdonec)
+
+ go func() {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ <-testdonec
+ c.Close()
+ }()
+
+ getdonec := make(chan struct{})
+ go func() {
+ defer close(getdonec)
+ tr := &Transport{
+ Dial: func(_, _ string) (net.Conn, error) {
+ return net.Dial("tcp", ln.Addr().String())
+ },
+ TLSHandshakeTimeout: 250 * time.Millisecond,
+ }
+ cl := &Client{Transport: tr}
+ _, err := cl.Get("https://dummy.tld/")
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ ue, ok := err.(*url.Error)
+ if !ok {
+ t.Fatalf("expected url.Error; got %#v", err)
+ }
+ ne, ok := ue.Err.(net.Error)
+ if !ok {
+ t.Fatalf("expected net.Error; got %#v", err)
+ }
+ if !ne.Timeout() {
+ t.Error("expected timeout error; got %v", err)
+ }
+ if !strings.Contains(err.Error(), "handshake timeout") {
+ t.Error("expected 'handshake timeout' in error; got %v", err)
+ }
+ }()
+ select {
+ case <-getdonec:
+ case <-time.After(5 * time.Second):
+ t.Error("test timeout; TLS handshake hung?")
+ }
+}
+
+func newLocalListener(t *testing.T) net.Listener {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ ln, err = net.Listen("tcp6", "[::1]:0")
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ return ln
+}
+
type countCloseReader struct {
n *int
io.Reader