]> Cypherpunks repositories - gostls13.git/commitdiff
net: add Dialer.KeepAlive option
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 24 Feb 2014 21:14:48 +0000 (13:14 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 24 Feb 2014 21:14:48 +0000 (13:14 -0800)
LGTM=rsc
R=rsc
CC=golang-codereviews
https://golang.org/cl/68380043

src/pkg/net/dial.go
src/pkg/net/dial_test.go

index 70b66e70d1564f4ee2adb70c3df9a2c709a0d922..93569c253cdafefbd8f567a7c2ec5e58c6077e4a 100644 (file)
@@ -44,6 +44,12 @@ type Dialer struct {
        // destination is a host name that has multiple address family
        // DNS records.
        DualStack bool
+
+       // KeepAlive specifies the keep-alive period for an active
+       // network connection.
+       // If zero, keep-alives are not enabled. Network protocols
+       // that do not support keep-alives ignore this field.
+       KeepAlive time.Duration
 }
 
 // Return either now+Timeout or Deadline, whichever comes first.
@@ -162,9 +168,19 @@ func (d *Dialer) Dial(network, address string) (Conn, error) {
                        return dialMulti(network, address, d.LocalAddr, ras, deadline)
                }
        }
-       return dial(network, ra.toAddr(), dialer, d.deadline())
+       c, err := dial(network, ra.toAddr(), dialer, d.deadline())
+       if d.KeepAlive > 0 && err == nil {
+               if tc, ok := c.(*TCPConn); ok {
+                       tc.SetKeepAlive(true)
+                       tc.SetKeepAlivePeriod(d.KeepAlive)
+                       testHookSetKeepAlive()
+               }
+       }
+       return c, err
 }
 
+var testHookSetKeepAlive = func() {} // changed by dial_test.go
+
 // dialMulti attempts to establish connections to each destination of
 // the list of addresses. It will return the first established
 // connection and close the other connections. Otherwise it returns
index bd89780e8ae93af2d501b4f74cf37b0955ac1c29..15ab10dfd4599a1a05ea8be115e9184facf7486b 100644 (file)
@@ -555,3 +555,36 @@ func TestDialDualStackLocalhost(t *testing.T) {
                }
        }
 }
+
+func TestDialerKeepAlive(t *testing.T) {
+       ln := newLocalListener(t)
+       defer ln.Close()
+       defer func() {
+               testHookSetKeepAlive = func() {}
+       }()
+       go func() {
+               for {
+                       c, err := ln.Accept()
+                       if err != nil {
+                               return
+                       }
+                       c.Close()
+               }
+       }()
+       for _, keepAlive := range []bool{false, true} {
+               got := false
+               testHookSetKeepAlive = func() { got = true }
+               var d Dialer
+               if keepAlive {
+                       d.KeepAlive = 30 * time.Second
+               }
+               c, err := d.Dial("tcp", ln.Addr().String())
+               if err != nil {
+                       t.Fatal(err)
+               }
+               c.Close()
+               if got != keepAlive {
+                       t.Errorf("Dialer.KeepAlive = %v: SetKeepAlive called = %v, want %v", d.KeepAlive, got, !got)
+               }
+       }
+}