]> Cypherpunks repositories - gostls13.git/commitdiff
net: add KeepAliveConfig and implement SetKeepAliveConfig
authorAndy Pan <panjf2000@gmail.com>
Tue, 14 Nov 2023 15:56:51 +0000 (23:56 +0800)
committerGopher Robot <gobot@golang.org>
Tue, 20 Feb 2024 06:04:31 +0000 (06:04 +0000)
Fixes #62254
Fixes #48622

Change-Id: Ida598e7fa914c8737fdbc1c813bcd68adb5119c3
Reviewed-on: https://go-review.googlesource.com/c/go/+/542275
Reviewed-by: Michael Knyszek <mknyszek@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
Run-TryBot: Andy Pan <panjf2000@gmail.com>
Auto-Submit: Ian Lance Taylor <iant@golang.org>

32 files changed:
api/next/62254.txt [new file with mode: 0644]
doc/next/6-stdlib/99-minor/net/62254.md [new file with mode: 0644]
doc/next/6-stdlib/99-minor/syscall (windows-386)/62254.md [new file with mode: 0644]
doc/next/6-stdlib/99-minor/syscall (windows-amd64)/62254.md [new file with mode: 0644]
src/net/dial.go
src/net/dial_test.go
src/net/file_plan9.go
src/net/file_unix.go
src/net/hook.go
src/net/mockserver_test.go
src/net/tcpconn_keepalive_conf_unix_test.go [new file with mode: 0644]
src/net/tcpconn_keepalive_darwin_test.go [new file with mode: 0644]
src/net/tcpconn_keepalive_dragonfly_test.go [new file with mode: 0644]
src/net/tcpconn_keepalive_solaris_test.go [new file with mode: 0644]
src/net/tcpconn_keepalive_test.go [new file with mode: 0644]
src/net/tcpconn_keepalive_unix_test.go [new file with mode: 0644]
src/net/tcpconn_keepalive_windows_test.go [new file with mode: 0644]
src/net/tcpsock.go
src/net/tcpsock_plan9.go
src/net/tcpsock_posix.go
src/net/tcpsock_test.go
src/net/tcpsock_unix.go [new file with mode: 0644]
src/net/tcpsock_windows.go [new file with mode: 0644]
src/net/tcpsockopt_darwin.go
src/net/tcpsockopt_dragonfly.go
src/net/tcpsockopt_openbsd.go
src/net/tcpsockopt_plan9.go
src/net/tcpsockopt_solaris.go
src/net/tcpsockopt_stub.go
src/net/tcpsockopt_unix.go
src/net/tcpsockopt_windows.go
src/syscall/types_windows.go

diff --git a/api/next/62254.txt b/api/next/62254.txt
new file mode 100644 (file)
index 0000000..49d3214
--- /dev/null
@@ -0,0 +1,12 @@
+pkg net, method (*TCPConn) SetKeepAliveConfig(KeepAliveConfig) error #62254
+pkg net, type Dialer struct, KeepAliveConfig KeepAliveConfig #62254
+pkg net, type KeepAliveConfig struct #62254
+pkg net, type KeepAliveConfig struct, Count int #62254
+pkg net, type KeepAliveConfig struct, Enable bool #62254
+pkg net, type KeepAliveConfig struct, Idle time.Duration #62254
+pkg net, type KeepAliveConfig struct, Interval time.Duration #62254
+pkg net, type ListenConfig struct, KeepAliveConfig KeepAliveConfig #62254
+pkg syscall (windows-386), const WSAENOPROTOOPT = 10042 #62254
+pkg syscall (windows-386), const WSAENOPROTOOPT Errno #62254
+pkg syscall (windows-amd64), const WSAENOPROTOOPT = 10042 #62254
+pkg syscall (windows-amd64), const WSAENOPROTOOPT Errno #62254
diff --git a/doc/next/6-stdlib/99-minor/net/62254.md b/doc/next/6-stdlib/99-minor/net/62254.md
new file mode 100644 (file)
index 0000000..1d32fd8
--- /dev/null
@@ -0,0 +1,4 @@
+The new type [`KeepAliveConfig`](/net#KeepAliveConfig) permits fine-tuning
+the keep-alive options for TCP connections, via a new
+[`TCPConn.SetKeepAliveConfig`](/net#TCPConn.SetKeepAliveConfig) method and
+new KeepAliveConfig fields for [`Dialer`](net#Dialer) and [`ListenConfig`](net#ListenConfig).
diff --git a/doc/next/6-stdlib/99-minor/syscall (windows-386)/62254.md b/doc/next/6-stdlib/99-minor/syscall (windows-386)/62254.md
new file mode 100644 (file)
index 0000000..fe9651a
--- /dev/null
@@ -0,0 +1 @@
+The syscall package now defines WSAENOPROTOOPT on Windows.
diff --git a/doc/next/6-stdlib/99-minor/syscall (windows-amd64)/62254.md b/doc/next/6-stdlib/99-minor/syscall (windows-amd64)/62254.md
new file mode 100644 (file)
index 0000000..e082778
--- /dev/null
@@ -0,0 +1 @@
+See `syscall (windows-386)/62254.md`.
index a6565c3ce5d13b8fcd81e73f194ecca72de379d2..28f346a372a78059ff4a1a489e9a38b7ff749479 100644 (file)
@@ -14,9 +14,16 @@ import (
 )
 
 const (
-       // defaultTCPKeepAlive is a default constant value for TCPKeepAlive times
-       // See go.dev/issue/31510
-       defaultTCPKeepAlive = 15 * time.Second
+       // defaultTCPKeepAliveIdle is a default constant value for TCP_KEEPIDLE.
+       // See go.dev/issue/31510 for details.
+       defaultTCPKeepAliveIdle = 15 * time.Second
+
+       // defaultTCPKeepAliveInterval is a default constant value for TCP_KEEPINTVL.
+       // It is the same as defaultTCPKeepAliveIdle, see go.dev/issue/31510 for details.
+       defaultTCPKeepAliveInterval = 15 * time.Second
+
+       // defaultTCPKeepAliveCount is a default constant value for TCP_KEEPCNT.
+       defaultTCPKeepAliveCount = 9
 
        // For the moment, MultiPath TCP is not used by default
        // See go.dev/issue/56539
@@ -116,13 +123,25 @@ type Dialer struct {
 
        // KeepAlive specifies the interval between keep-alive
        // probes for an active network connection.
+       //
+       // KeepAlive is ignored if KeepAliveConfig.Enable is true.
+       //
        // If zero, keep-alive probes are sent with a default value
        // (currently 15 seconds), if supported by the protocol and operating
        // system. Network protocols or operating systems that do
-       // not support keep-alives ignore this field.
+       // not support keep-alive ignore this field.
        // If negative, keep-alive probes are disabled.
        KeepAlive time.Duration
 
+       // KeepAliveConfig specifies the keep-alive probe configuration
+       // for an active network connection, when supported by the
+       // protocol and operating system.
+       //
+       // If KeepAliveConfig.Enable is true, keep-alive probes are enabled.
+       // If KeepAliveConfig.Enable is false and KeepAlive is negative,
+       // keep-alive probes are disabled.
+       KeepAliveConfig KeepAliveConfig
+
        // Resolver optionally specifies an alternate resolver to use.
        Resolver *Resolver
 
@@ -680,12 +699,24 @@ type ListenConfig struct {
 
        // KeepAlive specifies the keep-alive period for network
        // connections accepted by this listener.
-       // If zero, keep-alives are enabled if supported by the protocol
+       //
+       // KeepAlive is ignored if KeepAliveConfig.Enable is true.
+       //
+       // If zero, keep-alive are enabled if supported by the protocol
        // and operating system. Network protocols or operating systems
-       // that do not support keep-alives ignore this field.
-       // If negative, keep-alives are disabled.
+       // that do not support keep-alive ignore this field.
+       // If negative, keep-alive are disabled.
        KeepAlive time.Duration
 
+       // KeepAliveConfig specifies the keep-alive probe configuration
+       // for an active network connection, when supported by the
+       // protocol and operating system.
+       //
+       // If KeepAliveConfig.Enable is true, keep-alive probes are enabled.
+       // If KeepAliveConfig.Enable is false and KeepAlive is negative,
+       // keep-alive probes are disabled.
+       KeepAliveConfig KeepAliveConfig
+
        // If mptcpStatus is set to a value allowing Multipath TCP (MPTCP) to be
        // used, any call to Listen with "tcp(4|6)" as network will use MPTCP if
        // supported by the operating system.
index 1d0832e46ee56819fb32ba9b0ef541b95847254c..b3bedb2fa275c3ae62bc1fd3f79f74a975f4b676 100644 (file)
@@ -690,6 +690,10 @@ func TestDialerDualStack(t *testing.T) {
 }
 
 func TestDialerKeepAlive(t *testing.T) {
+       t.Cleanup(func() {
+               testHookSetKeepAlive = func(KeepAliveConfig) {}
+       })
+
        handler := func(ls *localServer, ln Listener) {
                for {
                        c, err := ln.Accept()
@@ -699,26 +703,30 @@ func TestDialerKeepAlive(t *testing.T) {
                        c.Close()
                }
        }
-       ls := newLocalServer(t, "tcp")
+       ln := newLocalListener(t, "tcp", &ListenConfig{
+               KeepAlive: -1, // prevent calling hook from accepting
+       })
+       ls := (&streamListener{Listener: ln}).newLocalServer()
        defer ls.teardown()
        if err := ls.buildup(handler); err != nil {
                t.Fatal(err)
        }
-       defer func() { testHookSetKeepAlive = func(time.Duration) {} }()
 
        tests := []struct {
                ka       time.Duration
                expected time.Duration
        }{
                {-1, -1},
-               {0, 15 * time.Second},
+               {0, 0},
                {5 * time.Second, 5 * time.Second},
                {30 * time.Second, 30 * time.Second},
        }
 
+       var got time.Duration = -1
+       testHookSetKeepAlive = func(cfg KeepAliveConfig) { got = cfg.Idle }
+
        for _, test := range tests {
-               var got time.Duration = -1
-               testHookSetKeepAlive = func(d time.Duration) { got = d }
+               got = -1
                d := Dialer{KeepAlive: test.ka}
                c, err := d.Dial("tcp", ls.Listener.Addr().String())
                if err != nil {
index 64aabf93ee54adc0abbc76a02872ac30f1aa1d04..6c2151c4098a8af643db8a41e6a47e84ad33a691 100644 (file)
@@ -100,7 +100,7 @@ func fileConn(f *os.File) (Conn, error) {
 
        switch fd.laddr.(type) {
        case *TCPAddr:
-               return newTCPConn(fd, defaultTCPKeepAlive, testHookSetKeepAlive), nil
+               return newTCPConn(fd, defaultTCPKeepAliveIdle, KeepAliveConfig{}, testPreHookSetKeepAlive, testHookSetKeepAlive), nil
        case *UDPAddr:
                return newUDPConn(fd), nil
        }
index 8b9fc38916f71be44d57ee01c4c084956aaf98df..c0212cef65dba65f351ae317aa9fea61450f0e5c 100644 (file)
@@ -74,7 +74,7 @@ func fileConn(f *os.File) (Conn, error) {
        }
        switch fd.laddr.(type) {
        case *TCPAddr:
-               return newTCPConn(fd, defaultTCPKeepAlive, testHookSetKeepAlive), nil
+               return newTCPConn(fd, defaultTCPKeepAliveIdle, KeepAliveConfig{}, testPreHookSetKeepAlive, testHookSetKeepAlive), nil
        case *UDPAddr:
                return newUDPConn(fd), nil
        case *IPAddr:
index eded34d48abe4e8d0f91ed5e133097572423537d..08d1aa893481f2003f8ea61352ee632742c5288e 100644 (file)
@@ -6,7 +6,6 @@ package net
 
 import (
        "context"
-       "time"
 )
 
 var (
@@ -21,7 +20,8 @@ var (
        ) ([]IPAddr, error) {
                return fn(ctx, network, host)
        }
-       testHookSetKeepAlive = func(time.Duration) {}
+       testPreHookSetKeepAlive = func(*netFD) {}
+       testHookSetKeepAlive    = func(KeepAliveConfig) {}
 
        // testHookStepTime sleeps until time has moved forward by a nonzero amount.
        // This helps to avoid flakes in timeout tests by ensuring that an implausibly
index f5ac32faddea01f0abecc7c7d1c091be665d2298..4d5e79a592652960be481121b4651e0144d3ff9c 100644 (file)
@@ -60,12 +60,7 @@ func newLocalListener(t testing.TB, network string, lcOpt ...*ListenConfig) List
        switch network {
        case "tcp":
                if supportsIPv4() {
-                       if !supportsIPv6() {
-                               return listen("tcp4", "127.0.0.1:0")
-                       }
-                       if ln, err := Listen("tcp4", "127.0.0.1:0"); err == nil {
-                               return ln
-                       }
+                       return listen("tcp4", "127.0.0.1:0")
                }
                if supportsIPv6() {
                        return listen("tcp6", "[::1]:0")
diff --git a/src/net/tcpconn_keepalive_conf_unix_test.go b/src/net/tcpconn_keepalive_conf_unix_test.go
new file mode 100644 (file)
index 0000000..7c39708
--- /dev/null
@@ -0,0 +1,102 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build aix || freebsd || linux || netbsd || darwin || dragonfly
+
+package net
+
+import "time"
+
+var testConfigs = []KeepAliveConfig{
+       {
+               Enable:   true,
+               Idle:     5 * time.Second,
+               Interval: 3 * time.Second,
+               Count:    10,
+       },
+       {
+               Enable:   true,
+               Idle:     0,
+               Interval: 0,
+               Count:    0,
+       },
+       {
+               Enable:   true,
+               Idle:     -1,
+               Interval: -1,
+               Count:    -1,
+       },
+       {
+               Enable:   true,
+               Idle:     -1,
+               Interval: 3 * time.Second,
+               Count:    10,
+       },
+       {
+               Enable:   true,
+               Idle:     5 * time.Second,
+               Interval: -1,
+               Count:    10,
+       },
+       {
+               Enable:   true,
+               Idle:     5 * time.Second,
+               Interval: 3 * time.Second,
+               Count:    -1,
+       },
+       {
+               Enable:   true,
+               Idle:     -1,
+               Interval: -1,
+               Count:    10,
+       },
+       {
+               Enable:   true,
+               Idle:     -1,
+               Interval: 3 * time.Second,
+               Count:    -1,
+       },
+       {
+               Enable:   true,
+               Idle:     5 * time.Second,
+               Interval: -1,
+               Count:    -1,
+       },
+       {
+               Enable:   true,
+               Idle:     0,
+               Interval: 3 * time.Second,
+               Count:    10,
+       },
+       {
+               Enable:   true,
+               Idle:     5 * time.Second,
+               Interval: 0,
+               Count:    10,
+       },
+       {
+               Enable:   true,
+               Idle:     5 * time.Second,
+               Interval: 3 * time.Second,
+               Count:    0,
+       },
+       {
+               Enable:   true,
+               Idle:     0,
+               Interval: 0,
+               Count:    10,
+       },
+       {
+               Enable:   true,
+               Idle:     0,
+               Interval: 3 * time.Second,
+               Count:    0,
+       },
+       {
+               Enable:   true,
+               Idle:     5 * time.Second,
+               Interval: 0,
+               Count:    0,
+       },
+}
diff --git a/src/net/tcpconn_keepalive_darwin_test.go b/src/net/tcpconn_keepalive_darwin_test.go
new file mode 100644 (file)
index 0000000..147e08c
--- /dev/null
@@ -0,0 +1,92 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build darwin
+
+package net
+
+import (
+       "syscall"
+       "testing"
+       "time"
+)
+
+func getCurrentKeepAliveSettings(fd int) (cfg KeepAliveConfig, err error) {
+       tcpKeepAlive, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE)
+       if err != nil {
+               return
+       }
+       tcpKeepAliveIdle, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPALIVE)
+       if err != nil {
+               return
+       }
+       tcpKeepAliveInterval, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, sysTCP_KEEPINTVL)
+       if err != nil {
+               return
+       }
+       tcpKeepAliveCount, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, sysTCP_KEEPCNT)
+       if err != nil {
+               return
+       }
+       cfg = KeepAliveConfig{
+               Enable:   tcpKeepAlive != 0,
+               Idle:     time.Duration(tcpKeepAliveIdle) * time.Second,
+               Interval: time.Duration(tcpKeepAliveInterval) * time.Second,
+               Count:    tcpKeepAliveCount,
+       }
+       return
+}
+
+func verifyKeepAliveSettings(t *testing.T, fd int, oldCfg, cfg KeepAliveConfig) {
+       if cfg.Idle == 0 {
+               cfg.Idle = defaultTCPKeepAliveIdle
+       }
+       if cfg.Interval == 0 {
+               cfg.Interval = defaultTCPKeepAliveInterval
+       }
+       if cfg.Count == 0 {
+               cfg.Count = defaultTCPKeepAliveCount
+       }
+       if cfg.Idle == -1 {
+               cfg.Idle = oldCfg.Idle
+       }
+       if cfg.Interval == -1 {
+               cfg.Interval = oldCfg.Interval
+       }
+       if cfg.Count == -1 {
+               cfg.Count = oldCfg.Count
+       }
+
+       tcpKeepAlive, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if (tcpKeepAlive != 0) != cfg.Enable {
+               t.Fatalf("SO_KEEPALIVE: got %t; want %t", tcpKeepAlive != 0, cfg.Enable)
+       }
+
+       tcpKeepAliveIdle, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPALIVE)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if time.Duration(tcpKeepAliveIdle)*time.Second != cfg.Idle {
+               t.Fatalf("TCP_KEEPIDLE: got %ds; want %v", tcpKeepAliveIdle, cfg.Idle)
+       }
+
+       tcpKeepAliveInterval, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, sysTCP_KEEPINTVL)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if time.Duration(tcpKeepAliveInterval)*time.Second != cfg.Interval {
+               t.Fatalf("TCP_KEEPINTVL: got %ds; want %v", tcpKeepAliveInterval, cfg.Interval)
+       }
+
+       tcpKeepAliveCount, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, sysTCP_KEEPCNT)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if tcpKeepAliveCount != cfg.Count {
+               t.Fatalf("TCP_KEEPCNT: got %d; want %d", tcpKeepAliveCount, cfg.Count)
+       }
+}
diff --git a/src/net/tcpconn_keepalive_dragonfly_test.go b/src/net/tcpconn_keepalive_dragonfly_test.go
new file mode 100644 (file)
index 0000000..61b073b
--- /dev/null
@@ -0,0 +1,92 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build dragonfly
+
+package net
+
+import (
+       "syscall"
+       "testing"
+       "time"
+)
+
+func getCurrentKeepAliveSettings(fd int) (cfg KeepAliveConfig, err error) {
+       tcpKeepAlive, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE)
+       if err != nil {
+               return
+       }
+       tcpKeepAliveIdle, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE)
+       if err != nil {
+               return
+       }
+       tcpKeepAliveInterval, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL)
+       if err != nil {
+               return
+       }
+       tcpKeepAliveCount, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPCNT)
+       if err != nil {
+               return
+       }
+       cfg = KeepAliveConfig{
+               Enable:   tcpKeepAlive != 0,
+               Idle:     time.Duration(tcpKeepAliveIdle) * time.Millisecond,
+               Interval: time.Duration(tcpKeepAliveInterval) * time.Millisecond,
+               Count:    tcpKeepAliveCount,
+       }
+       return
+}
+
+func verifyKeepAliveSettings(t *testing.T, fd int, oldCfg, cfg KeepAliveConfig) {
+       if cfg.Idle == 0 {
+               cfg.Idle = defaultTCPKeepAliveIdle
+       }
+       if cfg.Interval == 0 {
+               cfg.Interval = defaultTCPKeepAliveInterval
+       }
+       if cfg.Count == 0 {
+               cfg.Count = defaultTCPKeepAliveCount
+       }
+       if cfg.Idle == -1 {
+               cfg.Idle = oldCfg.Idle
+       }
+       if cfg.Interval == -1 {
+               cfg.Interval = oldCfg.Interval
+       }
+       if cfg.Count == -1 {
+               cfg.Count = oldCfg.Count
+       }
+
+       tcpKeepAlive, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if (tcpKeepAlive != 0) != cfg.Enable {
+               t.Fatalf("SO_KEEPALIVE: got %t; want %t", tcpKeepAlive != 0, cfg.Enable)
+       }
+
+       tcpKeepAliveIdle, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if time.Duration(tcpKeepAliveIdle)*time.Millisecond != cfg.Idle {
+               t.Fatalf("TCP_KEEPIDLE: got %dms; want %v", tcpKeepAliveIdle, cfg.Idle)
+       }
+
+       tcpKeepAliveInterval, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if time.Duration(tcpKeepAliveInterval)*time.Millisecond != cfg.Interval {
+               t.Fatalf("TCP_KEEPINTVL: got %dms; want %v", tcpKeepAliveInterval, cfg.Interval)
+       }
+
+       tcpKeepAliveCount, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPCNT)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if tcpKeepAliveCount != cfg.Count {
+               t.Fatalf("TCP_KEEPCNT: got %d; want %d", tcpKeepAliveCount, cfg.Count)
+       }
+}
diff --git a/src/net/tcpconn_keepalive_solaris_test.go b/src/net/tcpconn_keepalive_solaris_test.go
new file mode 100644 (file)
index 0000000..c6456c4
--- /dev/null
@@ -0,0 +1,89 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build solaris
+
+package net
+
+import (
+       "syscall"
+       "testing"
+       "time"
+)
+
+var testConfigs = []KeepAliveConfig{
+       {
+               Enable:   true,
+               Idle:     2 * time.Second,
+               Interval: -1,
+               Count:    -1,
+       },
+       {
+               Enable:   true,
+               Idle:     0,
+               Interval: -1,
+               Count:    -1,
+       },
+       {
+               Enable:   true,
+               Idle:     -1,
+               Interval: -1,
+               Count:    -1,
+       },
+}
+
+func getCurrentKeepAliveSettings(fd int) (cfg KeepAliveConfig, err error) {
+       tcpKeepAlive, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE)
+       if err != nil {
+               return
+       }
+       tcpKeepAliveIdle, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPALIVE_THRESHOLD)
+       if err != nil {
+               return
+       }
+       cfg = KeepAliveConfig{
+               Enable:   tcpKeepAlive != 0,
+               Idle:     time.Duration(tcpKeepAliveIdle) * time.Millisecond,
+               Interval: -1,
+               Count:    -1,
+       }
+       return
+}
+
+func verifyKeepAliveSettings(t *testing.T, fd int, oldCfg, cfg KeepAliveConfig) {
+       if cfg.Idle == 0 {
+               cfg.Idle = defaultTCPKeepAliveIdle
+       }
+       if cfg.Interval == 0 {
+               cfg.Interval = defaultTCPKeepAliveInterval
+       }
+       if cfg.Count == 0 {
+               cfg.Count = defaultTCPKeepAliveCount
+       }
+       if cfg.Idle == -1 {
+               cfg.Idle = oldCfg.Idle
+       }
+       if cfg.Interval == -1 {
+               cfg.Interval = oldCfg.Interval
+       }
+       if cfg.Count == -1 {
+               cfg.Count = oldCfg.Count
+       }
+
+       tcpKeepAlive, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if (tcpKeepAlive != 0) != cfg.Enable {
+               t.Fatalf("SO_KEEPALIVE: got %t; want %t", tcpKeepAlive != 0, cfg.Enable)
+       }
+
+       tcpKeepAliveIdle, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPALIVE_THRESHOLD)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if time.Duration(tcpKeepAliveIdle)*time.Millisecond != cfg.Idle {
+               t.Fatalf("TCP_KEEPIDLE: got %dms; want %v", tcpKeepAliveIdle, cfg.Idle)
+       }
+}
diff --git a/src/net/tcpconn_keepalive_test.go b/src/net/tcpconn_keepalive_test.go
new file mode 100644 (file)
index 0000000..f858d99
--- /dev/null
@@ -0,0 +1,195 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build aix || freebsd || linux || netbsd || dragonfly || darwin || solaris || windows
+
+package net
+
+import (
+       "runtime"
+       "testing"
+)
+
+func TestTCPConnDialerKeepAliveConfig(t *testing.T) {
+       // TODO(panjf2000): stop skipping this test on Solaris
+       //  when https://go.dev/issue/64251 is fixed.
+       if runtime.GOOS == "solaris" {
+               t.Skip("skipping on solaris for now")
+       }
+
+       t.Cleanup(func() {
+               testPreHookSetKeepAlive = func(*netFD) {}
+       })
+       var (
+               errHook error
+               oldCfg  KeepAliveConfig
+       )
+       testPreHookSetKeepAlive = func(nfd *netFD) {
+               oldCfg, errHook = getCurrentKeepAliveSettings(int(nfd.pfd.Sysfd))
+       }
+
+       handler := func(ls *localServer, ln Listener) {
+               for {
+                       c, err := ln.Accept()
+                       if err != nil {
+                               return
+                       }
+                       c.Close()
+               }
+       }
+       ln := newLocalListener(t, "tcp", &ListenConfig{
+               KeepAlive: -1, // prevent calling hook from accepting
+       })
+       ls := (&streamListener{Listener: ln}).newLocalServer()
+       defer ls.teardown()
+       if err := ls.buildup(handler); err != nil {
+               t.Fatal(err)
+       }
+
+       for _, cfg := range testConfigs {
+               d := Dialer{
+                       KeepAlive:       defaultTCPKeepAliveIdle, // should be ignored
+                       KeepAliveConfig: cfg}
+               c, err := d.Dial("tcp", ls.Listener.Addr().String())
+               if err != nil {
+                       t.Fatal(err)
+               }
+               defer c.Close()
+
+               if errHook != nil {
+                       t.Fatal(errHook)
+               }
+
+               sc, err := c.(*TCPConn).SyscallConn()
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if err := sc.Control(func(fd uintptr) {
+                       verifyKeepAliveSettings(t, int(fd), oldCfg, cfg)
+               }); err != nil {
+                       t.Fatal(err)
+               }
+       }
+}
+
+func TestTCPConnListenerKeepAliveConfig(t *testing.T) {
+       // TODO(panjf2000): stop skipping this test on Solaris
+       //  when https://go.dev/issue/64251 is fixed.
+       if runtime.GOOS == "solaris" {
+               t.Skip("skipping on solaris for now")
+       }
+
+       t.Cleanup(func() {
+               testPreHookSetKeepAlive = func(*netFD) {}
+       })
+       var (
+               errHook error
+               oldCfg  KeepAliveConfig
+       )
+       testPreHookSetKeepAlive = func(nfd *netFD) {
+               oldCfg, errHook = getCurrentKeepAliveSettings(int(nfd.pfd.Sysfd))
+       }
+
+       ch := make(chan Conn, 1)
+       handler := func(ls *localServer, ln Listener) {
+               c, err := ln.Accept()
+               if err != nil {
+                       return
+               }
+               ch <- c
+       }
+       for _, cfg := range testConfigs {
+               ln := newLocalListener(t, "tcp", &ListenConfig{
+                       KeepAlive:       defaultTCPKeepAliveIdle, // should be ignored
+                       KeepAliveConfig: cfg})
+               ls := (&streamListener{Listener: ln}).newLocalServer()
+               defer ls.teardown()
+               if err := ls.buildup(handler); err != nil {
+                       t.Fatal(err)
+               }
+               d := Dialer{KeepAlive: -1} // prevent calling hook from dialing
+               c, err := d.Dial("tcp", ls.Listener.Addr().String())
+               if err != nil {
+                       t.Fatal(err)
+               }
+               defer c.Close()
+
+               cc := <-ch
+               defer cc.Close()
+               if errHook != nil {
+                       t.Fatal(errHook)
+               }
+               sc, err := cc.(*TCPConn).SyscallConn()
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if err := sc.Control(func(fd uintptr) {
+                       verifyKeepAliveSettings(t, int(fd), oldCfg, cfg)
+               }); err != nil {
+                       t.Fatal(err)
+               }
+       }
+}
+
+func TestTCPConnSetKeepAliveConfig(t *testing.T) {
+       // TODO(panjf2000): stop skipping this test on Solaris
+       //  when https://go.dev/issue/64251 is fixed.
+       if runtime.GOOS == "solaris" {
+               t.Skip("skipping on solaris for now")
+       }
+
+       handler := func(ls *localServer, ln Listener) {
+               for {
+                       c, err := ln.Accept()
+                       if err != nil {
+                               return
+                       }
+                       c.Close()
+               }
+       }
+       ls := newLocalServer(t, "tcp")
+       defer ls.teardown()
+       if err := ls.buildup(handler); err != nil {
+               t.Fatal(err)
+       }
+       ra, err := ResolveTCPAddr("tcp", ls.Listener.Addr().String())
+       if err != nil {
+               t.Fatal(err)
+       }
+       for _, cfg := range testConfigs {
+               c, err := DialTCP("tcp", nil, ra)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               defer c.Close()
+
+               sc, err := c.SyscallConn()
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               var (
+                       errHook error
+                       oldCfg  KeepAliveConfig
+               )
+               if err := sc.Control(func(fd uintptr) {
+                       oldCfg, errHook = getCurrentKeepAliveSettings(int(fd))
+               }); err != nil {
+                       t.Fatal(err)
+               }
+               if errHook != nil {
+                       t.Fatal(errHook)
+               }
+
+               if err := c.SetKeepAliveConfig(cfg); err != nil {
+                       t.Fatal(err)
+               }
+
+               if err := sc.Control(func(fd uintptr) {
+                       verifyKeepAliveSettings(t, int(fd), oldCfg, cfg)
+               }); err != nil {
+                       t.Fatal(err)
+               }
+       }
+}
diff --git a/src/net/tcpconn_keepalive_unix_test.go b/src/net/tcpconn_keepalive_unix_test.go
new file mode 100644 (file)
index 0000000..8f74b6e
--- /dev/null
@@ -0,0 +1,92 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build aix || freebsd || linux || netbsd
+
+package net
+
+import (
+       "syscall"
+       "testing"
+       "time"
+)
+
+func getCurrentKeepAliveSettings(fd int) (cfg KeepAliveConfig, err error) {
+       tcpKeepAlive, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE)
+       if err != nil {
+               return
+       }
+       tcpKeepAliveIdle, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE)
+       if err != nil {
+               return
+       }
+       tcpKeepAliveInterval, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL)
+       if err != nil {
+               return
+       }
+       tcpKeepAliveCount, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPCNT)
+       if err != nil {
+               return
+       }
+       cfg = KeepAliveConfig{
+               Enable:   tcpKeepAlive != 0,
+               Idle:     time.Duration(tcpKeepAliveIdle) * time.Second,
+               Interval: time.Duration(tcpKeepAliveInterval) * time.Second,
+               Count:    tcpKeepAliveCount,
+       }
+       return
+}
+
+func verifyKeepAliveSettings(t *testing.T, fd int, oldCfg, cfg KeepAliveConfig) {
+       if cfg.Idle == 0 {
+               cfg.Idle = defaultTCPKeepAliveIdle
+       }
+       if cfg.Interval == 0 {
+               cfg.Interval = defaultTCPKeepAliveInterval
+       }
+       if cfg.Count == 0 {
+               cfg.Count = defaultTCPKeepAliveCount
+       }
+       if cfg.Idle == -1 {
+               cfg.Idle = oldCfg.Idle
+       }
+       if cfg.Interval == -1 {
+               cfg.Interval = oldCfg.Interval
+       }
+       if cfg.Count == -1 {
+               cfg.Count = oldCfg.Count
+       }
+
+       tcpKeepAlive, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if (tcpKeepAlive != 0) != cfg.Enable {
+               t.Fatalf("SO_KEEPALIVE: got %t; want %t", tcpKeepAlive != 0, cfg.Enable)
+       }
+
+       tcpKeepAliveIdle, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if time.Duration(tcpKeepAliveIdle)*time.Second != cfg.Idle {
+               t.Fatalf("TCP_KEEPIDLE: got %ds; want %v", tcpKeepAliveIdle, cfg.Idle)
+       }
+
+       tcpKeepAliveInterval, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if time.Duration(tcpKeepAliveInterval)*time.Second != cfg.Interval {
+               t.Fatalf("TCP_KEEPINTVL: got %ds; want %v", tcpKeepAliveInterval, cfg.Interval)
+       }
+
+       tcpKeepAliveCount, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPCNT)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if tcpKeepAliveCount != cfg.Count {
+               t.Fatalf("TCP_KEEPCNT: got %d; want %d", tcpKeepAliveCount, cfg.Count)
+       }
+}
diff --git a/src/net/tcpconn_keepalive_windows_test.go b/src/net/tcpconn_keepalive_windows_test.go
new file mode 100644 (file)
index 0000000..c3d6366
--- /dev/null
@@ -0,0 +1,33 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build windows
+
+package net
+
+import (
+       "testing"
+       "time"
+)
+
+var testConfigs = []KeepAliveConfig{
+       {
+               Enable:   true,
+               Idle:     2 * time.Second,
+               Interval: time.Second,
+               Count:    -1,
+       },
+}
+
+func getCurrentKeepAliveSettings(_ int) (cfg KeepAliveConfig, err error) {
+       // TODO(panjf2000): same as verifyKeepAliveSettings.
+       return
+}
+
+func verifyKeepAliveSettings(_ *testing.T, _ int, _, _ KeepAliveConfig) {
+       // TODO(panjf2000): Unlike Unix-like OS's, Windows doesn't provide
+       //      any ways to retrieve the current TCP keep-alive settings, therefore
+       //      we're not able to run the test suite similar to Unix-like OS's on Windows.
+       //  Try to find another proper approach to test the keep-alive settings on Windows.
+}
index 590516bff13034c8eb9f960cfb6e43d74664c5ee..5ffdbb035920837a6f8a7424e7cbc6d2d665b50c 100644 (file)
@@ -113,6 +113,36 @@ type TCPConn struct {
        conn
 }
 
+// KeepAliveConfig contains TCP keep-alive options.
+//
+// If the Idle, Interval, or Count fields are zero, a default value is chosen.
+// If a field is negative, the corresponding socket-level option will be left unchanged.
+//
+// Note that Windows doesn't support setting the KeepAliveIdle and KeepAliveInterval separately.
+// It's recommended to set both Idle and Interval to non-negative values on Windows if you
+// intend to customize the TCP keep-alive settings.
+// By contrast, if only one of Idle and Interval is set to a non-negative value, the other will
+// be set to the system default value, and ultimately, set both Idle and Interval to negative
+// values if you want to leave them unchanged.
+type KeepAliveConfig struct {
+       // If Enable is true, keep-alive probes are enabled.
+       Enable bool
+
+       // Idle is the time that the connection must be idle before
+       // the first keep-alive probe is sent.
+       // If zero, a default value of 15 seconds is used.
+       Idle time.Duration
+
+       // Interval is the time between keep-alive probes.
+       // If zero, a default value of 15 seconds is used.
+       Interval time.Duration
+
+       // Count is the maximum number of keep-alive probes that
+       // can go unanswered before dropping a connection.
+       // If zero, a default value of 9 is used.
+       Count int
+}
+
 // SyscallConn returns a raw network connection.
 // This implements the [syscall.Conn] interface.
 func (c *TCPConn) SyscallConn() (syscall.RawConn, error) {
@@ -206,12 +236,16 @@ func (c *TCPConn) SetKeepAlive(keepalive bool) error {
        return nil
 }
 
-// SetKeepAlivePeriod sets period between keep-alives.
+// SetKeepAlivePeriod sets the idle duration the connection
+// needs to remain idle before TCP starts sending keepalive probes.
+//
+// Note that calling this method on Windows will reset the KeepAliveInterval
+// to the default system value, which is normally 1 second.
 func (c *TCPConn) SetKeepAlivePeriod(d time.Duration) error {
        if !c.ok() {
                return syscall.EINVAL
        }
-       if err := setKeepAlivePeriod(c.fd, d); err != nil {
+       if err := setKeepAliveIdle(c.fd, d); err != nil {
                return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
        }
        return nil
@@ -247,19 +281,25 @@ func (c *TCPConn) MultipathTCP() (bool, error) {
        return isUsingMultipathTCP(c.fd), nil
 }
 
-func newTCPConn(fd *netFD, keepAlive time.Duration, keepAliveHook func(time.Duration)) *TCPConn {
+func newTCPConn(fd *netFD, keepAliveIdle time.Duration, keepAliveCfg KeepAliveConfig, preKeepAliveHook func(*netFD), keepAliveHook func(KeepAliveConfig)) *TCPConn {
        setNoDelay(fd, true)
-       if keepAlive == 0 {
-               keepAlive = defaultTCPKeepAlive
+       if !keepAliveCfg.Enable && keepAliveIdle >= 0 {
+               keepAliveCfg = KeepAliveConfig{
+                       Enable: true,
+                       Idle:   keepAliveIdle,
+               }
        }
-       if keepAlive > 0 {
-               setKeepAlive(fd, true)
-               setKeepAlivePeriod(fd, keepAlive)
+       c := &TCPConn{conn{fd}}
+       if keepAliveCfg.Enable {
+               if preKeepAliveHook != nil {
+                       preKeepAliveHook(fd)
+               }
+               c.SetKeepAliveConfig(keepAliveCfg)
                if keepAliveHook != nil {
-                       keepAliveHook(keepAlive)
+                       keepAliveHook(keepAliveCfg)
                }
        }
-       return &TCPConn{conn{fd}}
+       return c
 }
 
 // DialTCP acts like [Dial] for TCP networks.
index 463dedcf44cdedf424edbf8c88c6ff3fbbed21ba..430ed29ed42d0bf0021ca801e5e519c84523a30a 100644 (file)
@@ -46,7 +46,7 @@ func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCP
        if err != nil {
                return nil, err
        }
-       return newTCPConn(fd, sd.Dialer.KeepAlive, testHookSetKeepAlive), nil
+       return newTCPConn(fd, sd.Dialer.KeepAlive, sd.Dialer.KeepAliveConfig, testPreHookSetKeepAlive, testHookSetKeepAlive), nil
 }
 
 func (ln *TCPListener) ok() bool { return ln != nil && ln.fd != nil && ln.fd.ctl != nil }
@@ -56,7 +56,7 @@ func (ln *TCPListener) accept() (*TCPConn, error) {
        if err != nil {
                return nil, err
        }
-       return newTCPConn(fd, ln.lc.KeepAlive, nil), nil
+       return newTCPConn(fd, ln.lc.KeepAlive, ln.lc.KeepAliveConfig, testPreHookSetKeepAlive, testHookSetKeepAlive), nil
 }
 
 func (ln *TCPListener) close() error {
index 01b5ec9ed0564243952ea36a444962ec250e0e9e..a25494d9c059e1fbd4f7aad52368e002e5095a69 100644 (file)
@@ -118,7 +118,7 @@ func (sd *sysDialer) doDialTCPProto(ctx context.Context, laddr, raddr *TCPAddr,
        if err != nil {
                return nil, err
        }
-       return newTCPConn(fd, sd.Dialer.KeepAlive, testHookSetKeepAlive), nil
+       return newTCPConn(fd, sd.Dialer.KeepAlive, sd.Dialer.KeepAliveConfig, testPreHookSetKeepAlive, testHookSetKeepAlive), nil
 }
 
 func selfConnect(fd *netFD, err error) bool {
@@ -160,7 +160,7 @@ func (ln *TCPListener) accept() (*TCPConn, error) {
        if err != nil {
                return nil, err
        }
-       return newTCPConn(fd, ln.lc.KeepAlive, nil), nil
+       return newTCPConn(fd, ln.lc.KeepAlive, ln.lc.KeepAliveConfig, testPreHookSetKeepAlive, testHookSetKeepAlive), nil
 }
 
 func (ln *TCPListener) close() error {
index b37e936ff82e103f000910afbb8243cb8773b383..9ed49a925b4b39c13f14033ff329f0d323180b20 100644 (file)
@@ -775,8 +775,8 @@ func TestDialTCPDefaultKeepAlive(t *testing.T) {
        defer ln.Close()
 
        got := time.Duration(-1)
-       testHookSetKeepAlive = func(d time.Duration) { got = d }
-       defer func() { testHookSetKeepAlive = func(time.Duration) {} }()
+       testHookSetKeepAlive = func(cfg KeepAliveConfig) { got = cfg.Idle }
+       defer func() { testHookSetKeepAlive = func(KeepAliveConfig) {} }()
 
        c, err := DialTCP("tcp", nil, ln.Addr().(*TCPAddr))
        if err != nil {
@@ -784,8 +784,8 @@ func TestDialTCPDefaultKeepAlive(t *testing.T) {
        }
        defer c.Close()
 
-       if got != defaultTCPKeepAlive {
-               t.Errorf("got keepalive %v; want %v", got, defaultTCPKeepAlive)
+       if got != 0 {
+               t.Errorf("got keepalive %v; want %v", got, defaultTCPKeepAliveIdle)
        }
 }
 
diff --git a/src/net/tcpsock_unix.go b/src/net/tcpsock_unix.go
new file mode 100644 (file)
index 0000000..b5c05f4
--- /dev/null
@@ -0,0 +1,31 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !windows
+
+package net
+
+import "syscall"
+
+// SetKeepAliveConfig configures keep-alive messages sent by the operating system.
+func (c *TCPConn) SetKeepAliveConfig(config KeepAliveConfig) error {
+       if !c.ok() {
+               return syscall.EINVAL
+       }
+
+       if err := setKeepAlive(c.fd, config.Enable); err != nil {
+               return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+       }
+       if err := setKeepAliveIdle(c.fd, config.Idle); err != nil {
+               return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+       }
+       if err := setKeepAliveInterval(c.fd, config.Interval); err != nil {
+               return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+       }
+       if err := setKeepAliveCount(c.fd, config.Count); err != nil {
+               return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+       }
+
+       return nil
+}
diff --git a/src/net/tcpsock_windows.go b/src/net/tcpsock_windows.go
new file mode 100644 (file)
index 0000000..8ec71ab
--- /dev/null
@@ -0,0 +1,26 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import "syscall"
+
+// SetKeepAliveConfig configures keep-alive messages sent by the operating system.
+func (c *TCPConn) SetKeepAliveConfig(config KeepAliveConfig) error {
+       if !c.ok() {
+               return syscall.EINVAL
+       }
+
+       if err := setKeepAlive(c.fd, config.Enable); err != nil {
+               return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+       }
+       if err := setKeepAliveIdleAndInterval(c.fd, config.Idle, config.Interval); err != nil {
+               return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+       }
+       if err := setKeepAliveCount(c.fd, config.Count); err != nil {
+               return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+       }
+
+       return nil
+}
index 53c6756e33e00b3f107d3012fc11109f88219228..efe7f63323f05fba2d63952b4e52a0d1c0a7468b 100644 (file)
@@ -10,16 +10,48 @@ import (
        "time"
 )
 
-// syscall.TCP_KEEPINTVL is missing on some darwin architectures.
-const sysTCP_KEEPINTVL = 0x101
+// syscall.TCP_KEEPINTVL and syscall.TCP_KEEPCNT might be missing on some darwin architectures.
+const (
+       sysTCP_KEEPINTVL = 0x101
+       sysTCP_KEEPCNT   = 0x102
+)
+
+func setKeepAliveIdle(fd *netFD, d time.Duration) error {
+       if d == 0 {
+               d = defaultTCPKeepAliveIdle
+       } else if d < 0 {
+               return nil
+       }
 
-func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
        // The kernel expects seconds so round to next highest second.
        secs := int(roundDurationUp(d, time.Second))
-       if err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, sysTCP_KEEPINTVL, secs); err != nil {
-               return wrapSyscallError("setsockopt", err)
-       }
        err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPALIVE, secs)
        runtime.KeepAlive(fd)
        return wrapSyscallError("setsockopt", err)
 }
+
+func setKeepAliveInterval(fd *netFD, d time.Duration) error {
+       if d == 0 {
+               d = defaultTCPKeepAliveInterval
+       } else if d < 0 {
+               return nil
+       }
+
+       // The kernel expects seconds so round to next highest second.
+       secs := int(roundDurationUp(d, time.Second))
+       err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, sysTCP_KEEPINTVL, secs)
+       runtime.KeepAlive(fd)
+       return wrapSyscallError("setsockopt", err)
+}
+
+func setKeepAliveCount(fd *netFD, n int) error {
+       if n == 0 {
+               n = defaultTCPKeepAliveCount
+       } else if n < 0 {
+               return nil
+       }
+
+       err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, sysTCP_KEEPCNT, n)
+       runtime.KeepAlive(fd)
+       return wrapSyscallError("setsockopt", err)
+}
index b473c02b6867db4a5afb4fcb4c0b2d947ad64aab..612baaea31b0cb4b4efa437862045e618fa726ac 100644 (file)
@@ -10,14 +10,44 @@ import (
        "time"
 )
 
-func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
+func setKeepAliveIdle(fd *netFD, d time.Duration) error {
+       if d == 0 {
+               d = defaultTCPKeepAliveIdle
+       } else if d < 0 {
+               return nil
+       }
+
        // The kernel expects milliseconds so round to next highest
        // millisecond.
        msecs := int(roundDurationUp(d, time.Millisecond))
-       if err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, msecs); err != nil {
-               return wrapSyscallError("setsockopt", err)
-       }
        err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE, msecs)
        runtime.KeepAlive(fd)
        return wrapSyscallError("setsockopt", err)
 }
+
+func setKeepAliveInterval(fd *netFD, d time.Duration) error {
+       if d == 0 {
+               d = defaultTCPKeepAliveInterval
+       } else if d < 0 {
+               return nil
+       }
+
+       // The kernel expects milliseconds so round to next highest
+       // millisecond.
+       msecs := int(roundDurationUp(d, time.Millisecond))
+       err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, msecs)
+       runtime.KeepAlive(fd)
+       return wrapSyscallError("setsockopt", err)
+}
+
+func setKeepAliveCount(fd *netFD, n int) error {
+       if n == 0 {
+               n = defaultTCPKeepAliveCount
+       } else if n < 0 {
+               return nil
+       }
+
+       err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPCNT, n)
+       runtime.KeepAlive(fd)
+       return wrapSyscallError("setsockopt", err)
+}
index 10e1bef3e5aaed12b5a05cdcd43f611f8e1a2254..d21b77c4068827986d1d6dbcf7f73c0f5321ba55 100644 (file)
@@ -9,7 +9,28 @@ import (
        "time"
 )
 
-func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
+func setKeepAliveIdle(_ *netFD, d time.Duration) error {
+       if d < 0 {
+               return nil
+       }
+       // OpenBSD has no user-settable per-socket TCP keepalive
+       // options.
+       return syscall.ENOPROTOOPT
+}
+
+func setKeepAliveInterval(_ *netFD, d time.Duration) error {
+       if d < 0 {
+               return nil
+       }
+       // OpenBSD has no user-settable per-socket TCP keepalive
+       // options.
+       return syscall.ENOPROTOOPT
+}
+
+func setKeepAliveCount(_ *netFD, n int) error {
+       if n < 0 {
+               return nil
+       }
        // OpenBSD has no user-settable per-socket TCP keepalive
        // options.
        return syscall.ENOPROTOOPT
index 264359dcf3daf15bd14c8d82eadafeffcf566dbb..017e87518aeb9157ab83cdd1738612c7dea45157 100644 (file)
@@ -12,13 +12,31 @@ import (
        "time"
 )
 
-func setNoDelay(fd *netFD, noDelay bool) error {
+func setNoDelay(_ *netFD, _ bool) error {
        return syscall.EPLAN9
 }
 
 // Set keep alive period.
-func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
+func setKeepAliveIdle(fd *netFD, d time.Duration) error {
+       if d < 0 {
+               return nil
+       }
+
        cmd := "keepalive " + itoa.Itoa(int(d/time.Millisecond))
        _, e := fd.ctl.WriteAt([]byte(cmd), 0)
        return e
 }
+
+func setKeepAliveInterval(_ *netFD, d time.Duration) error {
+       if d < 0 {
+               return nil
+       }
+       return syscall.EPLAN9
+}
+
+func setKeepAliveCount(_ *netFD, n int) error {
+       if n < 0 {
+               return nil
+       }
+       return syscall.EPLAN9
+}
index f15e589dc058488850533a8ff510a9901cde149b..44eb9cd09e7558ecebe152175c5eea238a7c25cc 100644 (file)
@@ -10,11 +10,31 @@ import (
        "time"
 )
 
-func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
+func setKeepAliveIdle(fd *netFD, d time.Duration) error {
+       if d == 0 {
+               d = defaultTCPKeepAliveIdle
+       } else if d < 0 {
+               return nil
+       }
+
        // The kernel expects milliseconds so round to next highest
        // millisecond.
        msecs := int(roundDurationUp(d, time.Millisecond))
 
+       // TODO(panjf2000): the system call here always returns an error of invalid argument,
+       //       this was never discovered due to the lack of tests for TCP keep-alive on various
+       //       platforms in Go's test suite. Try to dive deep and figure out the reason later.
+       // Check out https://go.dev/issue/64251 for more details.
+       err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPALIVE_THRESHOLD, msecs)
+       runtime.KeepAlive(fd)
+       return wrapSyscallError("setsockopt", err)
+}
+
+func setKeepAliveInterval(_ *netFD, d time.Duration) error {
+       if d < 0 {
+               return nil
+       }
+
        // Normally we'd do
        //      syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, secs)
        // here, but we can't because Solaris does not have TCP_KEEPINTVL.
@@ -25,8 +45,12 @@ func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
        // and do it anyway, like on Darwin, because Solaris might eventually
        // allocate a constant with a different meaning for the value of
        // TCP_KEEPINTVL on illumos.
+       return syscall.ENOPROTOOPT
+}
 
-       err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPALIVE_THRESHOLD, msecs)
-       runtime.KeepAlive(fd)
-       return wrapSyscallError("setsockopt", err)
+func setKeepAliveCount(_ *netFD, n int) error {
+       if n < 0 {
+               return nil
+       }
+       return syscall.ENOPROTOOPT
 }
index cef07cd6484e51b7c761da575d0a5ad2a923693a..b789e0ae934e34d77e873ef13393150df9c273ac 100644 (file)
@@ -15,6 +15,14 @@ func setNoDelay(fd *netFD, noDelay bool) error {
        return syscall.ENOPROTOOPT
 }
 
-func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
+func setKeepAliveIdle(fd *netFD, d time.Duration) error {
+       return syscall.ENOPROTOOPT
+}
+
+func setKeepAliveInterval(fd *netFD, d time.Duration) error {
+       return syscall.ENOPROTOOPT
+}
+
+func setKeepAliveCount(fd *netFD, n int) error {
        return syscall.ENOPROTOOPT
 }
index bdcdc4023983944dbe39b8483786df7555418f8c..eb01663c528be6f81b57c7d1602a99ec661684a1 100644 (file)
@@ -12,13 +12,42 @@ import (
        "time"
 )
 
-func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
+func setKeepAliveIdle(fd *netFD, d time.Duration) error {
+       if d == 0 {
+               d = defaultTCPKeepAliveIdle
+       } else if d < 0 {
+               return nil
+       }
+
        // The kernel expects seconds so round to next highest second.
        secs := int(roundDurationUp(d, time.Second))
-       if err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, secs); err != nil {
-               return wrapSyscallError("setsockopt", err)
-       }
        err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE, secs)
        runtime.KeepAlive(fd)
        return wrapSyscallError("setsockopt", err)
 }
+
+func setKeepAliveInterval(fd *netFD, d time.Duration) error {
+       if d == 0 {
+               d = defaultTCPKeepAliveInterval
+       } else if d < 0 {
+               return nil
+       }
+
+       // The kernel expects seconds so round to next highest second.
+       secs := int(roundDurationUp(d, time.Second))
+       err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, secs)
+       runtime.KeepAlive(fd)
+       return wrapSyscallError("setsockopt", err)
+}
+
+func setKeepAliveCount(fd *netFD, n int) error {
+       if n == 0 {
+               n = defaultTCPKeepAliveCount
+       } else if n < 0 {
+               return nil
+       }
+
+       err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPCNT, n)
+       runtime.KeepAlive(fd)
+       return wrapSyscallError("setsockopt", err)
+}
index 4a0b09465eea839d19a7baa91bcf0842bd311ca2..274fc4d9c487f10d3d6bb1ab016c30f2ef11fa87 100644 (file)
@@ -12,14 +12,72 @@ import (
        "unsafe"
 )
 
-func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
+// Default values of KeepAliveTime and KeepAliveInterval on Windows,
+// check out https://learn.microsoft.com/en-us/windows/win32/winsock/sio-keepalive-vals#remarks for details.
+const (
+       defaultKeepAliveIdle     = 2 * time.Hour
+       defaultKeepAliveInterval = time.Second
+)
+
+func setKeepAliveIdle(fd *netFD, d time.Duration) error {
+       return setKeepAliveIdleAndInterval(fd, d, -1)
+}
+
+func setKeepAliveInterval(fd *netFD, d time.Duration) error {
+       return setKeepAliveIdleAndInterval(fd, -1, d)
+}
+
+func setKeepAliveCount(_ *netFD, n int) error {
+       if n < 0 {
+               return nil
+       }
+
+       // This value is not capable to be changed on Windows.
+       return syscall.WSAENOPROTOOPT
+}
+
+func setKeepAliveIdleAndInterval(fd *netFD, idle, interval time.Duration) error {
+       // WSAIoctl with SIO_KEEPALIVE_VALS control code requires all fields in
+       // `tcp_keepalive` struct to be provided.
+       // Otherwise, if any of the fields were not provided, just leaving them
+       // zero will knock off any existing values of keep-alive.
+       // Unfortunately, Windows doesn't support retrieving current keep-alive
+       // settings in any form programmatically, which disable us to first retrieve
+       // the current keep-alive settings, then set it without unwanted corruption.
+       switch {
+       case idle < 0 && interval >= 0:
+               // Given that we can't set KeepAliveInterval alone, and this code path
+               // is new, it doesn't exist before, so we just return an error.
+               return syscall.WSAENOPROTOOPT
+       case idle >= 0 && interval < 0:
+               // Although we can't set KeepAliveTime alone either, this existing code
+               // path had been backing up [SetKeepAlivePeriod] which used to be set both
+               // KeepAliveTime and KeepAliveInterval to 15 seconds.
+               // Now we will use the default of KeepAliveInterval on Windows if user doesn't
+               // provide one.
+               interval = defaultKeepAliveInterval
+       case idle < 0 && interval < 0:
+               // Nothing to do, just bail out.
+               return nil
+       case idle >= 0 && interval >= 0:
+               // Go ahead.
+       }
+
+       if idle == 0 {
+               idle = defaultTCPKeepAliveIdle
+       }
+       if interval == 0 {
+               interval = defaultTCPKeepAliveInterval
+       }
+
        // The kernel expects milliseconds so round to next highest
        // millisecond.
-       msecs := uint32(roundDurationUp(d, time.Millisecond))
+       tcpKeepAliveIdle := uint32(roundDurationUp(idle, time.Millisecond))
+       tcpKeepAliveInterval := uint32(roundDurationUp(interval, time.Millisecond))
        ka := syscall.TCPKeepalive{
                OnOff:    1,
-               Time:     msecs,
-               Interval: msecs,
+               Time:     tcpKeepAliveIdle,
+               Interval: tcpKeepAliveInterval,
        }
        ret := uint32(0)
        size := uint32(unsafe.Sizeof(ka))
index b338ec47001f850c9ecca012f466e7a4276a48fc..6743675b95909bdcc4ff033909d2663202ee15f3 100644 (file)
@@ -27,6 +27,7 @@ const (
        ERROR_NOT_FOUND           Errno = 1168
        ERROR_PRIVILEGE_NOT_HELD  Errno = 1314
        WSAEACCES                 Errno = 10013
+       WSAENOPROTOOPT            Errno = 10042
        WSAECONNABORTED           Errno = 10053
        WSAECONNRESET             Errno = 10054
 )