--- /dev/null
+pkg net, type Dialer struct, ControlContext func(context.Context, string, string, syscall.RawConn) error #55301
\ No newline at end of file
// Network and address parameters passed to Control method are not
// necessarily the ones passed to Dial. For example, passing "tcp" to Dial
// will cause the Control function to be called with "tcp4" or "tcp6".
+ //
+ // Control is ignored if ControlContext is not nil.
Control func(network, address string, c syscall.RawConn) error
+
+ // If ControlContext is not nil, it is called after creating the network
+ // connection but before actually dialing.
+ //
+ // Network and address parameters passed to Control method are not
+ // necessarily the ones passed to Dial. For example, passing "tcp" to Dial
+ // will cause the Control function to be called with "tcp4" or "tcp6".
+ //
+ // If ControlContext is not nil, Control is ignored.
+ ControlContext func(cxt context.Context, network, address string, c syscall.RawConn) error
}
func (d *Dialer) dualStack() bool { return d.FallbackDelay >= 0 }
"runtime"
"strings"
"sync"
+ "syscall"
"testing"
"time"
)
})
}
+func TestDialerControlContext(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("%s does not have full support of socktest", runtime.GOOS)
+ }
+ t.Run("StreamDial", func(t *testing.T) {
+ for i, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ ln := newLocalListener(t, network)
+ defer ln.Close()
+ var id int
+ d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
+ id = ctx.Value("id").(int)
+ return controlOnConnSetup(network, address, c)
+ }}
+ c, err := d.DialContext(context.WithValue(context.Background(), "id", i+1), network, ln.Addr().String())
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ if id != i+1 {
+ t.Errorf("got id %d, want %d", id, i+1)
+ }
+ c.Close()
+ }
+ })
+}
+
// mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork
// except that it won't skip testing on non-mobile builders.
func mustHaveExternalNetwork(t *testing.T) {
default:
return nil, UnknownNetworkError(sd.network)
}
- fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", sd.Dialer.Control)
+ ctrlCtxFn := sd.Dialer.ControlContext
+ if ctrlCtxFn == nil && sd.Dialer.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sd.Dialer.Control(network, address, c)
+ }
+ }
+ fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", ctrlCtxFn)
if err != nil {
return nil, err
}
default:
return nil, UnknownNetworkError(sl.network)
}
- fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", sl.ListenConfig.Control)
+ var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+ if sl.ListenConfig.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sl.ListenConfig.Control(network, address, c)
+ }
+ }
+ fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}
return syscall.AF_INET6, false
}
-func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) {
+func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) {
if (runtime.GOOS == "aix" || runtime.GOOS == "windows" || runtime.GOOS == "openbsd") && mode == "dial" && raddr.isWildcard() {
raddr = raddr.toLocal(net)
}
family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
- return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlFn)
+ return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlCtxFn)
}
func ipToSockaddrInet4(ip IP, port int) (syscall.SockaddrInet4, error) {
// socket returns a network file descriptor that is ready for
// asynchronous I/O using the network poller.
-func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) (*netFD, error) {
+func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) {
fd := &netFD{family: family, sotype: sotype, net: net}
if laddr != nil && raddr == nil { // listener
// socket returns a network file descriptor that is ready for
// asynchronous I/O using the network poller.
-func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) {
+func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) {
s, err := sysSocket(family, sotype, proto)
if err != nil {
return nil, err
if laddr != nil && raddr == nil {
switch sotype {
case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
- if err := fd.listenStream(laddr, listenerBacklog(), ctrlFn); err != nil {
+ if err := fd.listenStream(ctx, laddr, listenerBacklog(), ctrlCtxFn); err != nil {
fd.Close()
return nil, err
}
return fd, nil
case syscall.SOCK_DGRAM:
- if err := fd.listenDatagram(laddr, ctrlFn); err != nil {
+ if err := fd.listenDatagram(ctx, laddr, ctrlCtxFn); err != nil {
fd.Close()
return nil, err
}
return fd, nil
}
}
- if err := fd.dial(ctx, laddr, raddr, ctrlFn); err != nil {
+ if err := fd.dial(ctx, laddr, raddr, ctrlCtxFn); err != nil {
fd.Close()
return nil, err
}
return func(syscall.Sockaddr) Addr { return nil }
}
-func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error {
- if ctrlFn != nil {
- c, err := newRawConn(fd)
+func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
+ var c *rawConn
+ var err error
+ if ctrlCtxFn != nil {
+ c, err = newRawConn(fd)
if err != nil {
return err
}
} else if laddr != nil {
ctrlAddr = laddr.String()
}
- if err := ctrlFn(fd.ctrlNetwork(), ctrlAddr, c); err != nil {
+ if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), ctrlAddr, c); err != nil {
return err
}
}
- var err error
+
var lsa syscall.Sockaddr
if laddr != nil {
if lsa, err = laddr.sockaddr(fd.family); err != nil {
return nil
}
-func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, string, syscall.RawConn) error) error {
+func (fd *netFD) listenStream(ctx context.Context, laddr sockaddr, backlog int, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
var err error
if err = setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil {
return err
if lsa, err = laddr.sockaddr(fd.family); err != nil {
return err
}
- if ctrlFn != nil {
+
+ if ctrlCtxFn != nil {
c, err := newRawConn(fd)
if err != nil {
return err
}
- if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil {
+ if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil {
return err
}
}
+
if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
return os.NewSyscallError("bind", err)
}
return nil
}
-func (fd *netFD) listenDatagram(laddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error {
+func (fd *netFD) listenDatagram(ctx context.Context, laddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
switch addr := laddr.(type) {
case *UDPAddr:
// We provide a socket that listens to a wildcard
if lsa, err = laddr.sockaddr(fd.family); err != nil {
return err
}
- if ctrlFn != nil {
+
+ if ctrlCtxFn != nil {
c, err := newRawConn(fd)
if err != nil {
return err
}
- if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil {
+ if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil {
return err
}
}
}
func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
- fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control)
+ ctrlCtxFn := sd.Dialer.ControlContext
+ if ctrlCtxFn == nil && sd.Dialer.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sd.Dialer.Control(network, address, c)
+ }
+ }
+ fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", ctrlCtxFn)
// TCP has a rarely used mechanism called a 'simultaneous connection' in
// which Dial("tcp", addr1, addr2) run on the machine at addr1 can
if err == nil {
fd.Close()
}
- fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control)
+ fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", ctrlCtxFn)
}
if err != nil {
}
func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) {
- fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen", sl.ListenConfig.Control)
+ var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+ if sl.ListenConfig.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sl.ListenConfig.Control(network, address, c)
+ }
+ }
+ fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}
}
func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) {
- fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", sd.Dialer.Control)
+ ctrlCtxFn := sd.Dialer.ControlContext
+ if ctrlCtxFn == nil && sd.Dialer.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sd.Dialer.Control(network, address, c)
+ }
+ }
+ fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", ctrlCtxFn)
if err != nil {
return nil, err
}
}
func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, error) {
- fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control)
+ var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+ if sl.ListenConfig.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sl.ListenConfig.Control(network, address, c)
+ }
+ }
+ fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}
}
func (sl *sysListener) listenMulticastUDP(ctx context.Context, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
- fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control)
+ var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+ if sl.ListenConfig.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sl.ListenConfig.Control(network, address, c)
+ }
+ }
+ fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}
"syscall"
)
-func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctrlFn func(string, string, syscall.RawConn) error) (*netFD, error) {
+func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctxCtrlFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) {
var sotype int
switch net {
case "unix":
return nil, errors.New("unknown mode: " + mode)
}
- fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctrlFn)
+ fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctxCtrlFn)
if err != nil {
return nil, err
}
}
func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConn, error) {
- fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", sd.Dialer.Control)
+ ctrlCtxFn := sd.Dialer.ControlContext
+ if ctrlCtxFn == nil && sd.Dialer.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sd.Dialer.Control(network, address, c)
+ }
+ }
+ fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", ctrlCtxFn)
if err != nil {
return nil, err
}
}
func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixListener, error) {
- fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control)
+ var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+ if sl.ListenConfig.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sl.ListenConfig.Control(network, address, c)
+ }
+ }
+ fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}
}
func (sl *sysListener) listenUnixgram(ctx context.Context, laddr *UnixAddr) (*UnixConn, error) {
- fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control)
+ var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+ if sl.ListenConfig.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sl.ListenConfig.Control(network, address, c)
+ }
+ }
+ fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}