"context"
"internal/nettrace"
"internal/poll"
+ "syscall"
"time"
)
//
// Deprecated: Use DialContext instead.
Cancel <-chan struct{}
+
+ // If Control is not nil, it is called after creating the network
+ // connection but before actually dialing.
+ //
+ // Network and address parameters passed to Control method are not
+ // necessarily the ones passed to Dial. For example, passing "tcp" to Dial
+ // will cause the Control function to be called with "tcp4" or "tcp6".
+ Control func(network, address string, c syscall.RawConn) error
}
func minNonzeroTime(a, b time.Time) time.Time {
return c, nil
}
+// ListenConfig contains options for listening to an address.
+type ListenConfig struct {
+ // If Control is not nil, it is called after creating the network
+ // connection but before binding it to the operating system.
+ //
+ // Network and address parameters passed to Control method are not
+ // necessarily the ones passed to Listen. For example, passing "tcp" to
+ // Listen will cause the Control function to be called with "tcp4" or "tcp6".
+ Control func(network, address string, c syscall.RawConn) error
+}
+
+// Listen announces on the local network address.
+//
+// See func Listen for a description of the network and address
+// parameters.
+func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) {
+ addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
+ }
+ sl := &sysListener{
+ ListenConfig: *lc,
+ network: network,
+ address: address,
+ }
+ var l Listener
+ la := addrs.first(isIPv4)
+ switch la := la.(type) {
+ case *TCPAddr:
+ l, err = sl.listenTCP(ctx, la)
+ case *UnixAddr:
+ l, err = sl.listenUnix(ctx, la)
+ default:
+ return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
+ }
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err} // l is non-nil interface containing nil pointer
+ }
+ return l, nil
+}
+
+// ListenPacket announces on the local network address.
+//
+// See func ListenPacket for a description of the network and address
+// parameters.
+func (lc *ListenConfig) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) {
+ addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
+ }
+ sl := &sysListener{
+ ListenConfig: *lc,
+ network: network,
+ address: address,
+ }
+ var c PacketConn
+ la := addrs.first(isIPv4)
+ switch la := la.(type) {
+ case *UDPAddr:
+ c, err = sl.listenUDP(ctx, la)
+ case *IPAddr:
+ c, err = sl.listenIP(ctx, la)
+ case *UnixAddr:
+ c, err = sl.listenUnixgram(ctx, la)
+ default:
+ return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
+ }
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err} // c is non-nil interface containing nil pointer
+ }
+ return c, nil
+}
+
// sysListener contains a Listen's parameters and configuration.
type sysListener struct {
+ ListenConfig
network, address string
}
// See func Dial for a description of the network and address
// parameters.
func Listen(network, address string) (Listener, error) {
- addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", network, address, nil)
- if err != nil {
- return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
- }
- var l Listener
- switch la := addrs.first(isIPv4).(type) {
- case *TCPAddr:
- l, err = ListenTCP(network, la)
- case *UnixAddr:
- l, err = ListenUnix(network, la)
- default:
- return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
- }
- if err != nil {
- return nil, err // l is non-nil interface containing nil pointer
- }
- return l, nil
+ var lc ListenConfig
+ return lc.Listen(context.Background(), network, address)
}
// ListenPacket announces on the local network address.
// See func Dial for a description of the network and address
// parameters.
func ListenPacket(network, address string) (PacketConn, error) {
- addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", network, address, nil)
- if err != nil {
- return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
- }
- var l PacketConn
- switch la := addrs.first(isIPv4).(type) {
- case *UDPAddr:
- l, err = ListenUDP(network, la)
- case *IPAddr:
- l, err = ListenIP(network, la)
- case *UnixAddr:
- l, err = ListenUnixgram(network, la)
- default:
- return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
- }
- if err != nil {
- return nil, err // l is non-nil interface containing nil pointer
- }
- return l, nil
+ var lc ListenConfig
+ return lc.ListenPacket(context.Background(), network, address)
}
c.Close()
}
+func TestDialerControl(t *testing.T) {
+ switch runtime.GOOS {
+ case "nacl", "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ t.Run("StreamDial", func(t *testing.T) {
+ for _, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ ln, err := newLocalListener(network)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ defer ln.Close()
+ d := Dialer{Control: controlOnConnSetup}
+ c, err := d.Dial(network, ln.Addr().String())
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ c.Close()
+ }
+ })
+ t.Run("PacketDial", func(t *testing.T) {
+ for _, network := range []string{"udp", "udp4", "udp6", "unixgram"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ c1, err := newLocalPacketListener(network)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ if network == "unixgram" {
+ defer os.Remove(c1.LocalAddr().String())
+ }
+ defer c1.Close()
+ d := Dialer{Control: controlOnConnSetup}
+ c2, err := d.Dial(network, c1.LocalAddr().String())
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ c2.Close()
+ }
+ })
+}
+
// mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork
// except that it won't skip testing on non-iOS builders.
func mustHaveExternalNetwork(t *testing.T) {
default:
return nil, UnknownNetworkError(sd.network)
}
- fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial")
+ fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", sd.Dialer.Control)
if err != nil {
return nil, err
}
default:
return nil, UnknownNetworkError(sl.network)
}
- fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen")
+ fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", sl.ListenConfig.Control)
if err != nil {
return nil, err
}
return syscall.AF_INET6, false
}
-func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string) (fd *netFD, err error) {
+func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) {
if (runtime.GOOS == "windows" || runtime.GOOS == "openbsd" || runtime.GOOS == "nacl") && mode == "dial" && raddr.isWildcard() {
raddr = raddr.toLocal(net)
}
family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
- return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr)
+ return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlFn)
}
func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) {
package net
import (
+ "context"
"fmt"
"internal/testenv"
"os"
}
ln2.Close()
}
+
+func TestListenConfigControl(t *testing.T) {
+ switch runtime.GOOS {
+ case "nacl", "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ t.Run("StreamListen", func(t *testing.T) {
+ for _, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ ln, err := newLocalListener(network)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ address := ln.Addr().String()
+ ln.Close()
+ lc := ListenConfig{Control: controlOnConnSetup}
+ ln, err = lc.Listen(context.Background(), network, address)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ ln.Close()
+ }
+ })
+ t.Run("PacketListen", func(t *testing.T) {
+ for _, network := range []string{"udp", "udp4", "udp6", "unixgram"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ c, err := newLocalPacketListener(network)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ address := c.LocalAddr().String()
+ c.Close()
+ if network == "unixgram" {
+ os.Remove(address)
+ }
+ lc := ListenConfig{Control: controlOnConnSetup}
+ c, err = lc.ListenPacket(context.Background(), network, address)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ c.Close()
+ }
+ })
+}
func controlRawConn(c syscall.RawConn, addr Addr) error {
return errors.New("not supported")
}
+
+func controlOnConnSetup(network string, address string, c syscall.RawConn) error {
+ return nil
+}
package net
-import "syscall"
+import (
+ "errors"
+ "syscall"
+)
func readRawConn(c syscall.RawConn, b []byte) (int, error) {
var operr error
}
return nil
}
+
+func controlOnConnSetup(network string, address string, c syscall.RawConn) error {
+ var operr error
+ var fn func(uintptr)
+ switch network {
+ case "tcp", "udp", "ip":
+ return errors.New("ambiguous network: " + network)
+ case "unix", "unixpacket", "unixgram":
+ fn = func(s uintptr) {
+ _, operr = syscall.GetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_ERROR)
+ }
+ default:
+ switch network[len(network)-1] {
+ case '4':
+ fn = func(s uintptr) {
+ operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1)
+ }
+ case '6':
+ fn = func(s uintptr) {
+ operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1)
+ }
+ default:
+ return errors.New("unknown network: " + network)
+ }
+ }
+ if err := c.Control(fn); err != nil {
+ return err
+ }
+ if operr != nil {
+ return operr
+ }
+ return nil
+}
package net
import (
+ "errors"
"syscall"
"unsafe"
)
}
return nil
}
+
+func controlOnConnSetup(network string, address string, c syscall.RawConn) error {
+ var operr error
+ var fn func(uintptr)
+ switch network {
+ case "tcp", "udp", "ip":
+ return errors.New("ambiguous network: " + network)
+ default:
+ switch network[len(network)-1] {
+ case '4':
+ fn = func(s uintptr) {
+ operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1)
+ }
+ case '6':
+ fn = func(s uintptr) {
+ operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1)
+ }
+ default:
+ return errors.New("unknown network: " + network)
+ }
+ }
+ if err := c.Control(fn); err != nil {
+ return err
+ }
+ if operr != nil {
+ return operr
+ }
+ return nil
+}
// socket returns a network file descriptor that is ready for
// asynchronous I/O using the network poller.
-func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr) (fd *netFD, err error) {
+func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) {
s, err := sysSocket(family, sotype, proto)
if err != nil {
return nil, err
if laddr != nil && raddr == nil {
switch sotype {
case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
- if err := fd.listenStream(laddr, listenerBacklog); err != nil {
+ if err := fd.listenStream(laddr, listenerBacklog, ctrlFn); err != nil {
fd.Close()
return nil, err
}
return fd, nil
case syscall.SOCK_DGRAM:
- if err := fd.listenDatagram(laddr); err != nil {
+ if err := fd.listenDatagram(laddr, ctrlFn); err != nil {
fd.Close()
return nil, err
}
return fd, nil
}
}
- if err := fd.dial(ctx, laddr, raddr); err != nil {
+ if err := fd.dial(ctx, laddr, raddr, ctrlFn); err != nil {
fd.Close()
return nil, err
}
return fd, nil
}
+func (fd *netFD) ctrlNetwork() string {
+ switch fd.net {
+ case "unix", "unixgram", "unixpacket":
+ return fd.net
+ }
+ switch fd.net[len(fd.net)-1] {
+ case '4', '6':
+ return fd.net
+ }
+ if fd.family == syscall.AF_INET {
+ return fd.net + "4"
+ }
+ return fd.net + "6"
+}
+
func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
switch fd.family {
case syscall.AF_INET, syscall.AF_INET6:
return func(syscall.Sockaddr) Addr { return nil }
}
-func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr) error {
+func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error {
+ if ctrlFn != nil {
+ c, err := newRawConn(fd)
+ if err != nil {
+ return err
+ }
+ var ctrlAddr string
+ if raddr != nil {
+ ctrlAddr = raddr.String()
+ } else if laddr != nil {
+ ctrlAddr = laddr.String()
+ }
+ if err := ctrlFn(fd.ctrlNetwork(), ctrlAddr, c); err != nil {
+ return err
+ }
+ }
var err error
var lsa syscall.Sockaddr
if laddr != nil {
if lsa, err = laddr.sockaddr(fd.family); err != nil {
return err
} else if lsa != nil {
- if err := syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
+ if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
return os.NewSyscallError("bind", err)
}
}
return nil
}
-func (fd *netFD) listenStream(laddr sockaddr, backlog int) error {
- if err := setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil {
+func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, string, syscall.RawConn) error) error {
+ var err error
+ if err = setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil {
return err
}
- if lsa, err := laddr.sockaddr(fd.family); err != nil {
+ var lsa syscall.Sockaddr
+ if lsa, err = laddr.sockaddr(fd.family); err != nil {
return err
- } else if lsa != nil {
- if err := syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
- return os.NewSyscallError("bind", err)
+ }
+ if ctrlFn != nil {
+ c, err := newRawConn(fd)
+ if err != nil {
+ return err
+ }
+ if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil {
+ return err
}
}
- if err := listenFunc(fd.pfd.Sysfd, backlog); err != nil {
+ if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
+ return os.NewSyscallError("bind", err)
+ }
+ if err = listenFunc(fd.pfd.Sysfd, backlog); err != nil {
return os.NewSyscallError("listen", err)
}
- if err := fd.init(); err != nil {
+ if err = fd.init(); err != nil {
return err
}
- lsa, _ := syscall.Getsockname(fd.pfd.Sysfd)
+ lsa, _ = syscall.Getsockname(fd.pfd.Sysfd)
fd.setAddr(fd.addrFunc()(lsa), nil)
return nil
}
-func (fd *netFD) listenDatagram(laddr sockaddr) error {
+func (fd *netFD) listenDatagram(laddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error {
switch addr := laddr.(type) {
case *UDPAddr:
// We provide a socket that listens to a wildcard
laddr = &addr
}
}
- if lsa, err := laddr.sockaddr(fd.family); err != nil {
+ var err error
+ var lsa syscall.Sockaddr
+ if lsa, err = laddr.sockaddr(fd.family); err != nil {
return err
- } else if lsa != nil {
- if err := syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
- return os.NewSyscallError("bind", err)
+ }
+ if ctrlFn != nil {
+ c, err := newRawConn(fd)
+ if err != nil {
+ return err
+ }
+ if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil {
+ return err
}
}
- if err := fd.init(); err != nil {
+ if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
+ return os.NewSyscallError("bind", err)
+ }
+ if err = fd.init(); err != nil {
return err
}
- lsa, _ := syscall.Getsockname(fd.pfd.Sysfd)
+ lsa, _ = syscall.Getsockname(fd.pfd.Sysfd)
fd.setAddr(fd.addrFunc()(lsa), nil)
return nil
}
}
func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
- fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial")
+ fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control)
// TCP has a rarely used mechanism called a 'simultaneous connection' in
// which Dial("tcp", addr1, addr2) run on the machine at addr1 can
if err == nil {
fd.Close()
}
- fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial")
+ fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control)
}
if err != nil {
}
func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) {
- fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen")
+ fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen", sl.ListenConfig.Control)
if err != nil {
return nil, err
}
}
func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) {
- fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial")
+ fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", sd.Dialer.Control)
if err != nil {
return nil, err
}
}
func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, error) {
- fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen")
+ fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control)
if err != nil {
return nil, err
}
}
func (sl *sysListener) listenMulticastUDP(ctx context.Context, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
- fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen")
+ fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control)
if err != nil {
return nil, err
}
"syscall"
)
-func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string) (*netFD, error) {
+func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctrlFn func(string, string, syscall.RawConn) error) (*netFD, error) {
var sotype int
switch net {
case "unix":
return nil, errors.New("unknown mode: " + mode)
}
- fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr)
+ fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctrlFn)
if err != nil {
return nil, err
}
}
func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConn, error) {
- fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial")
+ fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", sd.Dialer.Control)
if err != nil {
return nil, err
}
}
func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixListener, error) {
- fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen")
+ fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control)
if err != nil {
return nil, err
}
}
func (sl *sysListener) listenUnixgram(ctx context.Context, laddr *UnixAddr) (*UnixConn, error) {
- fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen")
+ fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control)
if err != nil {
return nil, err
}