]> Cypherpunks repositories - gostls13.git/commitdiff
net: add ControlContext to Dialer
authorcuiweixie <cuiweixie@gmail.com>
Sun, 23 Oct 2022 00:41:38 +0000 (08:41 +0800)
committerGopher Robot <gobot@golang.org>
Fri, 4 Nov 2022 02:00:13 +0000 (02:00 +0000)
Fixes #55301

Change-Id: Ie8abcd383eee9af75038bde908ac638f43d33b7e
Reviewed-on: https://go-review.googlesource.com/c/go/+/444955
Reviewed-by: Bryan Mills <bcmills@google.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: xie cui <523516579@qq.com>

api/next/55301.txt [new file with mode: 0644]
src/net/dial.go
src/net/dial_test.go
src/net/iprawsock_posix.go
src/net/ipsock_posix.go
src/net/net_fake.go
src/net/sock_posix.go
src/net/tcpsock_posix.go
src/net/udpsock_posix.go
src/net/unixsock_posix.go

diff --git a/api/next/55301.txt b/api/next/55301.txt
new file mode 100644 (file)
index 0000000..e86ecfb
--- /dev/null
@@ -0,0 +1 @@
+pkg net, type Dialer struct, ControlContext func(context.Context, string, string, syscall.RawConn) error #55301
\ No newline at end of file
index c5383425666eb51b7911d9bc3f41d25ed414abbf..0461ab12cae34a6aceba5cc514c7cd4abd9d1d5c 100644 (file)
@@ -95,7 +95,19 @@ type Dialer struct {
        // Network and address parameters passed to Control method are not
        // necessarily the ones passed to Dial. For example, passing "tcp" to Dial
        // will cause the Control function to be called with "tcp4" or "tcp6".
+       //
+       // Control is ignored if ControlContext is not nil.
        Control func(network, address string, c syscall.RawConn) error
+
+       // If ControlContext is not nil, it is called after creating the network
+       // connection but before actually dialing.
+       //
+       // Network and address parameters passed to Control method are not
+       // necessarily the ones passed to Dial. For example, passing "tcp" to Dial
+       // will cause the Control function to be called with "tcp4" or "tcp6".
+       //
+       // If ControlContext is not nil, Control is ignored.
+       ControlContext func(cxt context.Context, network, address string, c syscall.RawConn) error
 }
 
 func (d *Dialer) dualStack() bool { return d.FallbackDelay >= 0 }
index 23e4a7a10ce1814a3702bf5b8e9bd59c2b8f240c..b04607e48f150996bc9cfdf5749ebc886bb5e1b2 100644 (file)
@@ -17,6 +17,7 @@ import (
        "runtime"
        "strings"
        "sync"
+       "syscall"
        "testing"
        "time"
 )
@@ -939,6 +940,36 @@ func TestDialerControl(t *testing.T) {
        })
 }
 
+func TestDialerControlContext(t *testing.T) {
+       switch runtime.GOOS {
+       case "plan9":
+               t.Skipf("%s does not have full support of socktest", runtime.GOOS)
+       }
+       t.Run("StreamDial", func(t *testing.T) {
+               for i, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
+                       if !testableNetwork(network) {
+                               continue
+                       }
+                       ln := newLocalListener(t, network)
+                       defer ln.Close()
+                       var id int
+                       d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
+                               id = ctx.Value("id").(int)
+                               return controlOnConnSetup(network, address, c)
+                       }}
+                       c, err := d.DialContext(context.WithValue(context.Background(), "id", i+1), network, ln.Addr().String())
+                       if err != nil {
+                               t.Error(err)
+                               continue
+                       }
+                       if id != i+1 {
+                               t.Errorf("got id %d, want %d", id, i+1)
+                       }
+                       c.Close()
+               }
+       })
+}
+
 // mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork
 // except that it won't skip testing on non-mobile builders.
 func mustHaveExternalNetwork(t *testing.T) {
index 64112b08dd4028c7bc069739409fef8e04cc6042..7b4d23927ff3f9d95fdb166e014ffb862d327d1b 100644 (file)
@@ -122,7 +122,13 @@ func (sd *sysDialer) dialIP(ctx context.Context, laddr, raddr *IPAddr) (*IPConn,
        default:
                return nil, UnknownNetworkError(sd.network)
        }
-       fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", sd.Dialer.Control)
+       ctrlCtxFn := sd.Dialer.ControlContext
+       if ctrlCtxFn == nil && sd.Dialer.Control != nil {
+               ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+                       return sd.Dialer.Control(network, address, c)
+               }
+       }
+       fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", ctrlCtxFn)
        if err != nil {
                return nil, err
        }
@@ -139,7 +145,13 @@ func (sl *sysListener) listenIP(ctx context.Context, laddr *IPAddr) (*IPConn, er
        default:
                return nil, UnknownNetworkError(sl.network)
        }
-       fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", sl.ListenConfig.Control)
+       var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+       if sl.ListenConfig.Control != nil {
+               ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+                       return sl.ListenConfig.Control(network, address, c)
+               }
+       }
+       fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", ctrlCtxFn)
        if err != nil {
                return nil, err
        }
index 7bb66f2d6cce3d814c9a840f813c978b3597733e..7fd676bd2c2561d1b4bc5d430117d142b13edafa 100644 (file)
@@ -134,12 +134,12 @@ func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (fam
        return syscall.AF_INET6, false
 }
 
-func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) {
+func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) {
        if (runtime.GOOS == "aix" || runtime.GOOS == "windows" || runtime.GOOS == "openbsd") && mode == "dial" && raddr.isWildcard() {
                raddr = raddr.toLocal(net)
        }
        family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
-       return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlFn)
+       return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlCtxFn)
 }
 
 func ipToSockaddrInet4(ip IP, port int) (syscall.SockaddrInet4, error) {
index 6d07d6297a4c56fcdae9de3ea04f8e13bba8f3fd..7e3a35fa67b07f69b1512f742c4631127a0b2363 100644 (file)
@@ -57,7 +57,7 @@ type netFD struct {
 
 // socket returns a network file descriptor that is ready for
 // asynchronous I/O using the network poller.
-func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) (*netFD, error) {
+func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) {
        fd := &netFD{family: family, sotype: sotype, net: net}
 
        if laddr != nil && raddr == nil { // listener
index 4431c3a6b3d422c7c38607bfd8eafc2200d36488..b3e1806ba9f48cb0734683e483269f4bd1c377cc 100644 (file)
@@ -15,7 +15,7 @@ import (
 
 // socket returns a network file descriptor that is ready for
 // asynchronous I/O using the network poller.
-func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) {
+func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) {
        s, err := sysSocket(family, sotype, proto)
        if err != nil {
                return nil, err
@@ -54,20 +54,20 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only
        if laddr != nil && raddr == nil {
                switch sotype {
                case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
-                       if err := fd.listenStream(laddr, listenerBacklog(), ctrlFn); err != nil {
+                       if err := fd.listenStream(ctx, laddr, listenerBacklog(), ctrlCtxFn); err != nil {
                                fd.Close()
                                return nil, err
                        }
                        return fd, nil
                case syscall.SOCK_DGRAM:
-                       if err := fd.listenDatagram(laddr, ctrlFn); err != nil {
+                       if err := fd.listenDatagram(ctx, laddr, ctrlCtxFn); err != nil {
                                fd.Close()
                                return nil, err
                        }
                        return fd, nil
                }
        }
-       if err := fd.dial(ctx, laddr, raddr, ctrlFn); err != nil {
+       if err := fd.dial(ctx, laddr, raddr, ctrlCtxFn); err != nil {
                fd.Close()
                return nil, err
        }
@@ -113,9 +113,11 @@ func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
        return func(syscall.Sockaddr) Addr { return nil }
 }
 
-func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error {
-       if ctrlFn != nil {
-               c, err := newRawConn(fd)
+func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
+       var c *rawConn
+       var err error
+       if ctrlCtxFn != nil {
+               c, err = newRawConn(fd)
                if err != nil {
                        return err
                }
@@ -125,11 +127,11 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(st
                } else if laddr != nil {
                        ctrlAddr = laddr.String()
                }
-               if err := ctrlFn(fd.ctrlNetwork(), ctrlAddr, c); err != nil {
+               if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), ctrlAddr, c); err != nil {
                        return err
                }
        }
-       var err error
+
        var lsa syscall.Sockaddr
        if laddr != nil {
                if lsa, err = laddr.sockaddr(fd.family); err != nil {
@@ -172,7 +174,7 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(st
        return nil
 }
 
-func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, string, syscall.RawConn) error) error {
+func (fd *netFD) listenStream(ctx context.Context, laddr sockaddr, backlog int, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
        var err error
        if err = setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil {
                return err
@@ -181,15 +183,17 @@ func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, s
        if lsa, err = laddr.sockaddr(fd.family); err != nil {
                return err
        }
-       if ctrlFn != nil {
+
+       if ctrlCtxFn != nil {
                c, err := newRawConn(fd)
                if err != nil {
                        return err
                }
-               if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil {
+               if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil {
                        return err
                }
        }
+
        if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
                return os.NewSyscallError("bind", err)
        }
@@ -204,7 +208,7 @@ func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, s
        return nil
 }
 
-func (fd *netFD) listenDatagram(laddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error {
+func (fd *netFD) listenDatagram(ctx context.Context, laddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
        switch addr := laddr.(type) {
        case *UDPAddr:
                // We provide a socket that listens to a wildcard
@@ -233,12 +237,13 @@ func (fd *netFD) listenDatagram(laddr sockaddr, ctrlFn func(string, string, sysc
        if lsa, err = laddr.sockaddr(fd.family); err != nil {
                return err
        }
-       if ctrlFn != nil {
+
+       if ctrlCtxFn != nil {
                c, err := newRawConn(fd)
                if err != nil {
                        return err
                }
-               if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil {
+               if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil {
                        return err
                }
        }
index 1c91170c50091569aa8d15545c75bf007a1b3995..463b456173cc370b9f37309d3338bebd70240f8f 100644 (file)
@@ -65,7 +65,13 @@ func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPCo
 }
 
 func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
-       fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control)
+       ctrlCtxFn := sd.Dialer.ControlContext
+       if ctrlCtxFn == nil && sd.Dialer.Control != nil {
+               ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+                       return sd.Dialer.Control(network, address, c)
+               }
+       }
+       fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", ctrlCtxFn)
 
        // TCP has a rarely used mechanism called a 'simultaneous connection' in
        // which Dial("tcp", addr1, addr2) run on the machine at addr1 can
@@ -95,7 +101,7 @@ func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCP
                if err == nil {
                        fd.Close()
                }
-               fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control)
+               fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", ctrlCtxFn)
        }
 
        if err != nil {
@@ -168,7 +174,13 @@ func (ln *TCPListener) file() (*os.File, error) {
 }
 
 func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) {
-       fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen", sl.ListenConfig.Control)
+       var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+       if sl.ListenConfig.Control != nil {
+               ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+                       return sl.ListenConfig.Control(network, address, c)
+               }
+       }
+       fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen", ctrlCtxFn)
        if err != nil {
                return nil, err
        }
index 5b021d24ae60208621b211fb24ebc790ed4fb4b3..ffeec81cffdf13657b43045457a29e5603db5c5c 100644 (file)
@@ -203,7 +203,13 @@ func (c *UDPConn) writeMsgAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn
 }
 
 func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) {
-       fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", sd.Dialer.Control)
+       ctrlCtxFn := sd.Dialer.ControlContext
+       if ctrlCtxFn == nil && sd.Dialer.Control != nil {
+               ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+                       return sd.Dialer.Control(network, address, c)
+               }
+       }
+       fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", ctrlCtxFn)
        if err != nil {
                return nil, err
        }
@@ -211,7 +217,13 @@ func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPCo
 }
 
 func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, error) {
-       fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control)
+       var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+       if sl.ListenConfig.Control != nil {
+               ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+                       return sl.ListenConfig.Control(network, address, c)
+               }
+       }
+       fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn)
        if err != nil {
                return nil, err
        }
@@ -219,7 +231,13 @@ func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn,
 }
 
 func (sl *sysListener) listenMulticastUDP(ctx context.Context, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
-       fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control)
+       var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+       if sl.ListenConfig.Control != nil {
+               ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+                       return sl.ListenConfig.Control(network, address, c)
+               }
+       }
+       fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn)
        if err != nil {
                return nil, err
        }
index b244dbdbbd08f22c17768438ece5f5ac96d8c15b..c16b483603f6ccecd3d34a4c640679667252ca11 100644 (file)
@@ -13,7 +13,7 @@ import (
        "syscall"
 )
 
-func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctrlFn func(string, string, syscall.RawConn) error) (*netFD, error) {
+func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctxCtrlFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) {
        var sotype int
        switch net {
        case "unix":
@@ -42,7 +42,7 @@ func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode str
                return nil, errors.New("unknown mode: " + mode)
        }
 
-       fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctrlFn)
+       fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctxCtrlFn)
        if err != nil {
                return nil, err
        }
@@ -155,7 +155,13 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err
 }
 
 func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConn, error) {
-       fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", sd.Dialer.Control)
+       ctrlCtxFn := sd.Dialer.ControlContext
+       if ctrlCtxFn == nil && sd.Dialer.Control != nil {
+               ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+                       return sd.Dialer.Control(network, address, c)
+               }
+       }
+       fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", ctrlCtxFn)
        if err != nil {
                return nil, err
        }
@@ -211,7 +217,13 @@ func (l *UnixListener) SetUnlinkOnClose(unlink bool) {
 }
 
 func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixListener, error) {
-       fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control)
+       var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+       if sl.ListenConfig.Control != nil {
+               ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+                       return sl.ListenConfig.Control(network, address, c)
+               }
+       }
+       fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", ctrlCtxFn)
        if err != nil {
                return nil, err
        }
@@ -219,7 +231,13 @@ func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixLi
 }
 
 func (sl *sysListener) listenUnixgram(ctx context.Context, laddr *UnixAddr) (*UnixConn, error) {
-       fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control)
+       var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+       if sl.ListenConfig.Control != nil {
+               ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+                       return sl.ListenConfig.Control(network, address, c)
+               }
+       }
+       fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", ctrlCtxFn)
        if err != nil {
                return nil, err
        }