From: Brad Fitzpatrick Date: Mon, 24 Feb 2014 21:14:48 +0000 (-0800) Subject: net: add Dialer.KeepAlive option X-Git-Tag: go1.3beta1~595 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=fdfbb406d12a36b09425ecdc76f32345c1af8ffd;p=gostls13.git net: add Dialer.KeepAlive option LGTM=rsc R=rsc CC=golang-codereviews https://golang.org/cl/68380043 --- diff --git a/src/pkg/net/dial.go b/src/pkg/net/dial.go index 70b66e70d1..93569c253c 100644 --- a/src/pkg/net/dial.go +++ b/src/pkg/net/dial.go @@ -44,6 +44,12 @@ type Dialer struct { // destination is a host name that has multiple address family // DNS records. DualStack bool + + // KeepAlive specifies the keep-alive period for an active + // network connection. + // If zero, keep-alives are not enabled. Network protocols + // that do not support keep-alives ignore this field. + KeepAlive time.Duration } // Return either now+Timeout or Deadline, whichever comes first. @@ -162,9 +168,19 @@ func (d *Dialer) Dial(network, address string) (Conn, error) { return dialMulti(network, address, d.LocalAddr, ras, deadline) } } - return dial(network, ra.toAddr(), dialer, d.deadline()) + c, err := dial(network, ra.toAddr(), dialer, d.deadline()) + if d.KeepAlive > 0 && err == nil { + if tc, ok := c.(*TCPConn); ok { + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(d.KeepAlive) + testHookSetKeepAlive() + } + } + return c, err } +var testHookSetKeepAlive = func() {} // changed by dial_test.go + // dialMulti attempts to establish connections to each destination of // the list of addresses. It will return the first established // connection and close the other connections. Otherwise it returns diff --git a/src/pkg/net/dial_test.go b/src/pkg/net/dial_test.go index bd89780e8a..15ab10dfd4 100644 --- a/src/pkg/net/dial_test.go +++ b/src/pkg/net/dial_test.go @@ -555,3 +555,36 @@ func TestDialDualStackLocalhost(t *testing.T) { } } } + +func TestDialerKeepAlive(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + defer func() { + testHookSetKeepAlive = func() {} + }() + go func() { + for { + c, err := ln.Accept() + if err != nil { + return + } + c.Close() + } + }() + for _, keepAlive := range []bool{false, true} { + got := false + testHookSetKeepAlive = func() { got = true } + var d Dialer + if keepAlive { + d.KeepAlive = 30 * time.Second + } + c, err := d.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + c.Close() + if got != keepAlive { + t.Errorf("Dialer.KeepAlive = %v: SetKeepAlive called = %v, want %v", d.KeepAlive, got, !got) + } + } +}