]> Cypherpunks repositories - gostls13.git/commitdiff
net: unify TCP keepalive behavior
authordatabase64128 <free122448@hotmail.com>
Thu, 10 Nov 2022 08:20:29 +0000 (08:20 +0000)
committerGopher Robot <gobot@golang.org>
Thu, 10 Nov 2022 18:46:00 +0000 (18:46 +0000)
CL 107196 introduced a default TCP keepalive interval for Dialer and TCPListener (used by both ListenConfig and ListenTCP). Leaving DialTCP out was likely an oversight.

DialTCP's documentation says it "acts like Dial". Therefore it's natural to also expect DialTCP to enable TCP keepalive by default.

This commit addresses this disparity by moving the enablement logic down to the newTCPConn function, which is used by both dialer and listener.

Fixes #49345

Change-Id: I99c08b161c468ed0b993d1dbd2bd0d7e803f3826
GitHub-Last-Rev: 5c2f1cb0fbc5e83aa6cdbdf3ed4e23419d9bca65
GitHub-Pull-Request: golang/go#56565
Reviewed-on: https://go-review.googlesource.com/c/go/+/447917
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Ian Lance Taylor <iant@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>

src/net/dial.go
src/net/file_plan9.go
src/net/file_unix.go
src/net/tcpsock.go
src/net/tcpsock_plan9.go
src/net/tcpsock_posix.go
src/net/tcpsock_test.go

index 0461ab12cae34a6aceba5cc514c7cd4abd9d1d5c..e243f45ba3569105bb8a8b2d0fed090edacdd953 100644 (file)
@@ -437,21 +437,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
                primaries = addrs
        }
 
-       c, err := sd.dialParallel(ctx, primaries, fallbacks)
-       if err != nil {
-               return nil, err
-       }
-
-       if tc, ok := c.(*TCPConn); ok && d.KeepAlive >= 0 {
-               setKeepAlive(tc.fd, true)
-               ka := d.KeepAlive
-               if d.KeepAlive == 0 {
-                       ka = defaultTCPKeepAlive
-               }
-               setKeepAlivePeriod(tc.fd, ka)
-               testHookSetKeepAlive(ka)
-       }
-       return c, nil
+       return sd.dialParallel(ctx, primaries, fallbacks)
 }
 
 // dialParallel races two copies of dialSerial, giving the first a
index dfb23d2e8424d0fe5213a4fb65e9cbbd6cefb52f..64aabf93ee54adc0abbc76a02872ac30f1aa1d04 100644 (file)
@@ -100,7 +100,7 @@ func fileConn(f *os.File) (Conn, error) {
 
        switch fd.laddr.(type) {
        case *TCPAddr:
-               return newTCPConn(fd), nil
+               return newTCPConn(fd, defaultTCPKeepAlive, testHookSetKeepAlive), nil
        case *UDPAddr:
                return newUDPConn(fd), nil
        }
index 0df67db501211e5ba824a63cb0227b1bbb88939e..8b9fc38916f71be44d57ee01c4c084956aaf98df 100644 (file)
@@ -74,7 +74,7 @@ func fileConn(f *os.File) (Conn, error) {
        }
        switch fd.laddr.(type) {
        case *TCPAddr:
-               return newTCPConn(fd), nil
+               return newTCPConn(fd, defaultTCPKeepAlive, testHookSetKeepAlive), nil
        case *UDPAddr:
                return newUDPConn(fd), nil
        case *IPAddr:
index 6bad0e8f8bb93b70ffb3d6bfff53817e7f83f834..672170e6816328e1ea2084dba3fe6b3ac210bf36 100644 (file)
@@ -217,10 +217,19 @@ func (c *TCPConn) SetNoDelay(noDelay bool) error {
        return nil
 }
 
-func newTCPConn(fd *netFD) *TCPConn {
-       c := &TCPConn{conn{fd}}
-       setNoDelay(c.fd, true)
-       return c
+func newTCPConn(fd *netFD, keepAlive time.Duration, keepAliveHook func(time.Duration)) *TCPConn {
+       setNoDelay(fd, true)
+       if keepAlive == 0 {
+               keepAlive = defaultTCPKeepAlive
+       }
+       if keepAlive > 0 {
+               setKeepAlive(fd, true)
+               setKeepAlivePeriod(fd, keepAlive)
+               if keepAliveHook != nil {
+                       keepAliveHook(keepAlive)
+               }
+       }
+       return &TCPConn{conn{fd}}
 }
 
 // DialTCP acts like Dial for TCP networks.
index 435335e92e8edb651618e8da043434d765142f1e..d55948f69e4fe9f67c187a4cd8b481b9a8a2740c 100644 (file)
@@ -42,7 +42,7 @@ func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCP
        if err != nil {
                return nil, err
        }
-       return newTCPConn(fd), nil
+       return newTCPConn(fd, sd.Dialer.KeepAlive, testHookSetKeepAlive), nil
 }
 
 func (ln *TCPListener) ok() bool { return ln != nil && ln.fd != nil && ln.fd.ctl != nil }
@@ -52,16 +52,7 @@ func (ln *TCPListener) accept() (*TCPConn, error) {
        if err != nil {
                return nil, err
        }
-       tc := newTCPConn(fd)
-       if ln.lc.KeepAlive >= 0 {
-               setKeepAlive(fd, true)
-               ka := ln.lc.KeepAlive
-               if ln.lc.KeepAlive == 0 {
-                       ka = defaultTCPKeepAlive
-               }
-               setKeepAlivePeriod(fd, ka)
-       }
-       return tc, nil
+       return newTCPConn(fd, ln.lc.KeepAlive, nil), nil
 }
 
 func (ln *TCPListener) close() error {
index 463b456173cc370b9f37309d3338bebd70240f8f..0b3fa1ae0c3ae6424096dadbc26bffb518602d7c 100644 (file)
@@ -107,7 +107,7 @@ func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCP
        if err != nil {
                return nil, err
        }
-       return newTCPConn(fd), nil
+       return newTCPConn(fd, sd.Dialer.KeepAlive, testHookSetKeepAlive), nil
 }
 
 func selfConnect(fd *netFD, err error) bool {
@@ -149,16 +149,7 @@ func (ln *TCPListener) accept() (*TCPConn, error) {
        if err != nil {
                return nil, err
        }
-       tc := newTCPConn(fd)
-       if ln.lc.KeepAlive >= 0 {
-               setKeepAlive(fd, true)
-               ka := ln.lc.KeepAlive
-               if ln.lc.KeepAlive == 0 {
-                       ka = defaultTCPKeepAlive
-               }
-               setKeepAlivePeriod(fd, ka)
-       }
-       return tc, nil
+       return newTCPConn(fd, ln.lc.KeepAlive, nil), nil
 }
 
 func (ln *TCPListener) close() error {
index ae65788a73dc5ea73ee11af2f29740d84b45ebda..990d34706fdf83c224206572eede65f45f9e8da7 100644 (file)
@@ -808,3 +808,22 @@ func BenchmarkSetReadDeadline(b *testing.B) {
                deadline = deadline.Add(1)
        }
 }
+
+func TestDialTCPDefaultKeepAlive(t *testing.T) {
+       ln := newLocalListener(t, "tcp")
+       defer ln.Close()
+
+       got := time.Duration(-1)
+       testHookSetKeepAlive = func(d time.Duration) { got = d }
+       defer func() { testHookSetKeepAlive = func(time.Duration) {} }()
+
+       c, err := DialTCP("tcp", nil, ln.Addr().(*TCPAddr))
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer c.Close()
+
+       if got != defaultTCPKeepAlive {
+               t.Errorf("got keepalive %v; want %v", got, defaultTCPKeepAlive)
+       }
+}