From 90b40c0496440fbd57538eb4ba303164ed923d93 Mon Sep 17 00:00:00 2001 From: cuiweixie Date: Sun, 23 Oct 2022 08:41:38 +0800 Subject: [PATCH] net: add ControlContext to Dialer Fixes #55301 Change-Id: Ie8abcd383eee9af75038bde908ac638f43d33b7e Reviewed-on: https://go-review.googlesource.com/c/go/+/444955 Reviewed-by: Bryan Mills Reviewed-by: Ian Lance Taylor Run-TryBot: Ian Lance Taylor Auto-Submit: Ian Lance Taylor TryBot-Result: Gopher Robot Run-TryBot: xie cui <523516579@qq.com> --- api/next/55301.txt | 1 + src/net/dial.go | 12 ++++++++++++ src/net/dial_test.go | 31 +++++++++++++++++++++++++++++++ src/net/iprawsock_posix.go | 16 ++++++++++++++-- src/net/ipsock_posix.go | 4 ++-- src/net/net_fake.go | 2 +- src/net/sock_posix.go | 35 ++++++++++++++++++++--------------- src/net/tcpsock_posix.go | 18 +++++++++++++++--- src/net/udpsock_posix.go | 24 +++++++++++++++++++++--- src/net/unixsock_posix.go | 28 +++++++++++++++++++++++----- 10 files changed, 140 insertions(+), 31 deletions(-) create mode 100644 api/next/55301.txt diff --git a/api/next/55301.txt b/api/next/55301.txt new file mode 100644 index 0000000000..e86ecfb7a2 --- /dev/null +++ b/api/next/55301.txt @@ -0,0 +1 @@ +pkg net, type Dialer struct, ControlContext func(context.Context, string, string, syscall.RawConn) error #55301 \ No newline at end of file diff --git a/src/net/dial.go b/src/net/dial.go index c538342566..0461ab12ca 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -95,7 +95,19 @@ type Dialer struct { // Network and address parameters passed to Control method are not // necessarily the ones passed to Dial. For example, passing "tcp" to Dial // will cause the Control function to be called with "tcp4" or "tcp6". + // + // Control is ignored if ControlContext is not nil. Control func(network, address string, c syscall.RawConn) error + + // If ControlContext is not nil, it is called after creating the network + // connection but before actually dialing. + // + // Network and address parameters passed to Control method are not + // necessarily the ones passed to Dial. For example, passing "tcp" to Dial + // will cause the Control function to be called with "tcp4" or "tcp6". + // + // If ControlContext is not nil, Control is ignored. + ControlContext func(cxt context.Context, network, address string, c syscall.RawConn) error } func (d *Dialer) dualStack() bool { return d.FallbackDelay >= 0 } diff --git a/src/net/dial_test.go b/src/net/dial_test.go index 23e4a7a10c..b04607e48f 100644 --- a/src/net/dial_test.go +++ b/src/net/dial_test.go @@ -17,6 +17,7 @@ import ( "runtime" "strings" "sync" + "syscall" "testing" "time" ) @@ -939,6 +940,36 @@ func TestDialerControl(t *testing.T) { }) } +func TestDialerControlContext(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("%s does not have full support of socktest", runtime.GOOS) + } + t.Run("StreamDial", func(t *testing.T) { + for i, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} { + if !testableNetwork(network) { + continue + } + ln := newLocalListener(t, network) + defer ln.Close() + var id int + d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error { + id = ctx.Value("id").(int) + return controlOnConnSetup(network, address, c) + }} + c, err := d.DialContext(context.WithValue(context.Background(), "id", i+1), network, ln.Addr().String()) + if err != nil { + t.Error(err) + continue + } + if id != i+1 { + t.Errorf("got id %d, want %d", id, i+1) + } + c.Close() + } + }) +} + // mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork // except that it won't skip testing on non-mobile builders. func mustHaveExternalNetwork(t *testing.T) { diff --git a/src/net/iprawsock_posix.go b/src/net/iprawsock_posix.go index 64112b08dd..7b4d23927f 100644 --- a/src/net/iprawsock_posix.go +++ b/src/net/iprawsock_posix.go @@ -122,7 +122,13 @@ func (sd *sysDialer) dialIP(ctx context.Context, laddr, raddr *IPAddr) (*IPConn, default: return nil, UnknownNetworkError(sd.network) } - fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", sd.Dialer.Control) + ctrlCtxFn := sd.Dialer.ControlContext + if ctrlCtxFn == nil && sd.Dialer.Control != nil { + ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error { + return sd.Dialer.Control(network, address, c) + } + } + fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", ctrlCtxFn) if err != nil { return nil, err } @@ -139,7 +145,13 @@ func (sl *sysListener) listenIP(ctx context.Context, laddr *IPAddr) (*IPConn, er default: return nil, UnknownNetworkError(sl.network) } - fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", sl.ListenConfig.Control) + var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error + if sl.ListenConfig.Control != nil { + ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error { + return sl.ListenConfig.Control(network, address, c) + } + } + fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", ctrlCtxFn) if err != nil { return nil, err } diff --git a/src/net/ipsock_posix.go b/src/net/ipsock_posix.go index 7bb66f2d6c..7fd676bd2c 100644 --- a/src/net/ipsock_posix.go +++ b/src/net/ipsock_posix.go @@ -134,12 +134,12 @@ func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (fam return syscall.AF_INET6, false } -func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) { +func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) { if (runtime.GOOS == "aix" || runtime.GOOS == "windows" || runtime.GOOS == "openbsd") && mode == "dial" && raddr.isWildcard() { raddr = raddr.toLocal(net) } family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode) - return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlFn) + return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlCtxFn) } func ipToSockaddrInet4(ip IP, port int) (syscall.SockaddrInet4, error) { diff --git a/src/net/net_fake.go b/src/net/net_fake.go index 6d07d6297a..7e3a35fa67 100644 --- a/src/net/net_fake.go +++ b/src/net/net_fake.go @@ -57,7 +57,7 @@ type netFD struct { // socket returns a network file descriptor that is ready for // asynchronous I/O using the network poller. -func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) (*netFD, error) { +func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) { fd := &netFD{family: family, sotype: sotype, net: net} if laddr != nil && raddr == nil { // listener diff --git a/src/net/sock_posix.go b/src/net/sock_posix.go index 4431c3a6b3..b3e1806ba9 100644 --- a/src/net/sock_posix.go +++ b/src/net/sock_posix.go @@ -15,7 +15,7 @@ import ( // socket returns a network file descriptor that is ready for // asynchronous I/O using the network poller. -func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) { +func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) { s, err := sysSocket(family, sotype, proto) if err != nil { return nil, err @@ -54,20 +54,20 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only if laddr != nil && raddr == nil { switch sotype { case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET: - if err := fd.listenStream(laddr, listenerBacklog(), ctrlFn); err != nil { + if err := fd.listenStream(ctx, laddr, listenerBacklog(), ctrlCtxFn); err != nil { fd.Close() return nil, err } return fd, nil case syscall.SOCK_DGRAM: - if err := fd.listenDatagram(laddr, ctrlFn); err != nil { + if err := fd.listenDatagram(ctx, laddr, ctrlCtxFn); err != nil { fd.Close() return nil, err } return fd, nil } } - if err := fd.dial(ctx, laddr, raddr, ctrlFn); err != nil { + if err := fd.dial(ctx, laddr, raddr, ctrlCtxFn); err != nil { fd.Close() return nil, err } @@ -113,9 +113,11 @@ func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr { return func(syscall.Sockaddr) Addr { return nil } } -func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error { - if ctrlFn != nil { - c, err := newRawConn(fd) +func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error { + var c *rawConn + var err error + if ctrlCtxFn != nil { + c, err = newRawConn(fd) if err != nil { return err } @@ -125,11 +127,11 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(st } else if laddr != nil { ctrlAddr = laddr.String() } - if err := ctrlFn(fd.ctrlNetwork(), ctrlAddr, c); err != nil { + if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), ctrlAddr, c); err != nil { return err } } - var err error + var lsa syscall.Sockaddr if laddr != nil { if lsa, err = laddr.sockaddr(fd.family); err != nil { @@ -172,7 +174,7 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(st return nil } -func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, string, syscall.RawConn) error) error { +func (fd *netFD) listenStream(ctx context.Context, laddr sockaddr, backlog int, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error { var err error if err = setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil { return err @@ -181,15 +183,17 @@ func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, s if lsa, err = laddr.sockaddr(fd.family); err != nil { return err } - if ctrlFn != nil { + + if ctrlCtxFn != nil { c, err := newRawConn(fd) if err != nil { return err } - if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil { + if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil { return err } } + if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil { return os.NewSyscallError("bind", err) } @@ -204,7 +208,7 @@ func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, s return nil } -func (fd *netFD) listenDatagram(laddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error { +func (fd *netFD) listenDatagram(ctx context.Context, laddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error { switch addr := laddr.(type) { case *UDPAddr: // We provide a socket that listens to a wildcard @@ -233,12 +237,13 @@ func (fd *netFD) listenDatagram(laddr sockaddr, ctrlFn func(string, string, sysc if lsa, err = laddr.sockaddr(fd.family); err != nil { return err } - if ctrlFn != nil { + + if ctrlCtxFn != nil { c, err := newRawConn(fd) if err != nil { return err } - if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil { + if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil { return err } } diff --git a/src/net/tcpsock_posix.go b/src/net/tcpsock_posix.go index 1c91170c50..463b456173 100644 --- a/src/net/tcpsock_posix.go +++ b/src/net/tcpsock_posix.go @@ -65,7 +65,13 @@ func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPCo } func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) { - fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control) + ctrlCtxFn := sd.Dialer.ControlContext + if ctrlCtxFn == nil && sd.Dialer.Control != nil { + ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error { + return sd.Dialer.Control(network, address, c) + } + } + fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", ctrlCtxFn) // TCP has a rarely used mechanism called a 'simultaneous connection' in // which Dial("tcp", addr1, addr2) run on the machine at addr1 can @@ -95,7 +101,7 @@ func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCP if err == nil { fd.Close() } - fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control) + fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", ctrlCtxFn) } if err != nil { @@ -168,7 +174,13 @@ func (ln *TCPListener) file() (*os.File, error) { } func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) { - fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen", sl.ListenConfig.Control) + var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error + if sl.ListenConfig.Control != nil { + ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error { + return sl.ListenConfig.Control(network, address, c) + } + } + fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen", ctrlCtxFn) if err != nil { return nil, err } diff --git a/src/net/udpsock_posix.go b/src/net/udpsock_posix.go index 5b021d24ae..ffeec81cff 100644 --- a/src/net/udpsock_posix.go +++ b/src/net/udpsock_posix.go @@ -203,7 +203,13 @@ func (c *UDPConn) writeMsgAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn } func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) { - fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", sd.Dialer.Control) + ctrlCtxFn := sd.Dialer.ControlContext + if ctrlCtxFn == nil && sd.Dialer.Control != nil { + ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error { + return sd.Dialer.Control(network, address, c) + } + } + fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", ctrlCtxFn) if err != nil { return nil, err } @@ -211,7 +217,13 @@ func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPCo } func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, error) { - fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control) + var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error + if sl.ListenConfig.Control != nil { + ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error { + return sl.ListenConfig.Control(network, address, c) + } + } + fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn) if err != nil { return nil, err } @@ -219,7 +231,13 @@ func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, } func (sl *sysListener) listenMulticastUDP(ctx context.Context, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) { - fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control) + var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error + if sl.ListenConfig.Control != nil { + ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error { + return sl.ListenConfig.Control(network, address, c) + } + } + fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn) if err != nil { return nil, err } diff --git a/src/net/unixsock_posix.go b/src/net/unixsock_posix.go index b244dbdbbd..c16b483603 100644 --- a/src/net/unixsock_posix.go +++ b/src/net/unixsock_posix.go @@ -13,7 +13,7 @@ import ( "syscall" ) -func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctrlFn func(string, string, syscall.RawConn) error) (*netFD, error) { +func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctxCtrlFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) { var sotype int switch net { case "unix": @@ -42,7 +42,7 @@ func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode str return nil, errors.New("unknown mode: " + mode) } - fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctrlFn) + fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctxCtrlFn) if err != nil { return nil, err } @@ -155,7 +155,13 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err } func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConn, error) { - fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", sd.Dialer.Control) + ctrlCtxFn := sd.Dialer.ControlContext + if ctrlCtxFn == nil && sd.Dialer.Control != nil { + ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error { + return sd.Dialer.Control(network, address, c) + } + } + fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", ctrlCtxFn) if err != nil { return nil, err } @@ -211,7 +217,13 @@ func (l *UnixListener) SetUnlinkOnClose(unlink bool) { } func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixListener, error) { - fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control) + var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error + if sl.ListenConfig.Control != nil { + ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error { + return sl.ListenConfig.Control(network, address, c) + } + } + fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", ctrlCtxFn) if err != nil { return nil, err } @@ -219,7 +231,13 @@ func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixLi } func (sl *sysListener) listenUnixgram(ctx context.Context, laddr *UnixAddr) (*UnixConn, error) { - fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control) + var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error + if sl.ListenConfig.Control != nil { + ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error { + return sl.ListenConfig.Control(network, address, c) + } + } + fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", ctrlCtxFn) if err != nil { return nil, err } -- 2.50.0