]> Cypherpunks repositories - gostls13.git/commitdiff
net: add ListenConfig, Dialer.Control to permit socket opts before listen/dial
authorAudrius Butkevicius <audrius.butkevicius@gmail.com>
Mon, 28 May 2018 01:47:21 +0000 (02:47 +0100)
committerIan Lance Taylor <iant@golang.org>
Wed, 30 May 2018 22:54:22 +0000 (22:54 +0000)
Existing implementation does not provide a way to set options such as
SO_REUSEPORT, that has to be set prior the socket being bound.

New exposed API:
pkg net, method (*ListenConfig) Listen(context.Context, string, string) (Listener, error)
pkg net, method (*ListenConfig) ListenPacket(context.Context, string, string) (PacketConn, error)
pkg net, type ListenConfig struct
pkg net, type ListenConfig struct, Control func(string, string, syscall.RawConn) error
pkg net, type Dialer struct, Control func(string, string, syscall.RawConn) error

Fixes #9661

Change-Id: If4d275711f823df72d3ac5cc3858651a6a57cccb
Reviewed-on: https://go-review.googlesource.com/72810
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
12 files changed:
src/net/dial.go
src/net/dial_test.go
src/net/iprawsock_posix.go
src/net/ipsock_posix.go
src/net/listen_test.go
src/net/rawconn_stub_test.go
src/net/rawconn_unix_test.go
src/net/rawconn_windows_test.go
src/net/sock_posix.go
src/net/tcpsock_posix.go
src/net/udpsock_posix.go
src/net/unixsock_posix.go

index 3ea049ca460e79a05e04b820d922a4dbad594a98..b1a5ca7cd53f58d3c6f62b780bbdb16e715b50f7 100644 (file)
@@ -8,6 +8,7 @@ import (
        "context"
        "internal/nettrace"
        "internal/poll"
+       "syscall"
        "time"
 )
 
@@ -70,6 +71,14 @@ type Dialer struct {
        //
        // Deprecated: Use DialContext instead.
        Cancel <-chan struct{}
+
+       // If Control is not nil, it is called after creating the network
+       // connection but before actually dialing.
+       //
+       // Network and address parameters passed to Control method are not
+       // necessarily the ones passed to Dial. For example, passing "tcp" to Dial
+       // will cause the Control function to be called with "tcp4" or "tcp6".
+       Control func(network, address string, c syscall.RawConn) error
 }
 
 func minNonzeroTime(a, b time.Time) time.Time {
@@ -563,8 +572,82 @@ func (sd *sysDialer) dialSingle(ctx context.Context, ra Addr) (c Conn, err error
        return c, nil
 }
 
+// ListenConfig contains options for listening to an address.
+type ListenConfig struct {
+       // If Control is not nil, it is called after creating the network
+       // connection but before binding it to the operating system.
+       //
+       // Network and address parameters passed to Control method are not
+       // necessarily the ones passed to Listen. For example, passing "tcp" to
+       // Listen will cause the Control function to be called with "tcp4" or "tcp6".
+       Control func(network, address string, c syscall.RawConn) error
+}
+
+// Listen announces on the local network address.
+//
+// See func Listen for a description of the network and address
+// parameters.
+func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) {
+       addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
+       if err != nil {
+               return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
+       }
+       sl := &sysListener{
+               ListenConfig: *lc,
+               network:      network,
+               address:      address,
+       }
+       var l Listener
+       la := addrs.first(isIPv4)
+       switch la := la.(type) {
+       case *TCPAddr:
+               l, err = sl.listenTCP(ctx, la)
+       case *UnixAddr:
+               l, err = sl.listenUnix(ctx, la)
+       default:
+               return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
+       }
+       if err != nil {
+               return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err} // l is non-nil interface containing nil pointer
+       }
+       return l, nil
+}
+
+// ListenPacket announces on the local network address.
+//
+// See func ListenPacket for a description of the network and address
+// parameters.
+func (lc *ListenConfig) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) {
+       addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
+       if err != nil {
+               return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
+       }
+       sl := &sysListener{
+               ListenConfig: *lc,
+               network:      network,
+               address:      address,
+       }
+       var c PacketConn
+       la := addrs.first(isIPv4)
+       switch la := la.(type) {
+       case *UDPAddr:
+               c, err = sl.listenUDP(ctx, la)
+       case *IPAddr:
+               c, err = sl.listenIP(ctx, la)
+       case *UnixAddr:
+               c, err = sl.listenUnixgram(ctx, la)
+       default:
+               return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
+       }
+       if err != nil {
+               return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err} // c is non-nil interface containing nil pointer
+       }
+       return c, nil
+}
+
 // sysListener contains a Listen's parameters and configuration.
 type sysListener struct {
+       ListenConfig
        network, address string
 }
 
@@ -587,23 +670,8 @@ type sysListener struct {
 // See func Dial for a description of the network and address
 // parameters.
 func Listen(network, address string) (Listener, error) {
-       addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", network, address, nil)
-       if err != nil {
-               return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
-       }
-       var l Listener
-       switch la := addrs.first(isIPv4).(type) {
-       case *TCPAddr:
-               l, err = ListenTCP(network, la)
-       case *UnixAddr:
-               l, err = ListenUnix(network, la)
-       default:
-               return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
-       }
-       if err != nil {
-               return nil, err // l is non-nil interface containing nil pointer
-       }
-       return l, nil
+       var lc ListenConfig
+       return lc.Listen(context.Background(), network, address)
 }
 
 // ListenPacket announces on the local network address.
@@ -629,23 +697,6 @@ func Listen(network, address string) (Listener, error) {
 // See func Dial for a description of the network and address
 // parameters.
 func ListenPacket(network, address string) (PacketConn, error) {
-       addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", network, address, nil)
-       if err != nil {
-               return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
-       }
-       var l PacketConn
-       switch la := addrs.first(isIPv4).(type) {
-       case *UDPAddr:
-               l, err = ListenUDP(network, la)
-       case *IPAddr:
-               l, err = ListenIP(network, la)
-       case *UnixAddr:
-               l, err = ListenUnixgram(network, la)
-       default:
-               return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
-       }
-       if err != nil {
-               return nil, err // l is non-nil interface containing nil pointer
-       }
-       return l, nil
+       var lc ListenConfig
+       return lc.ListenPacket(context.Background(), network, address)
 }
index 96d8921ec85250f798714d4216c695d77f62193a..3934ad864836056dfc10ce7d4bce9aca87bd9ffd 100644 (file)
@@ -912,6 +912,57 @@ func TestDialListenerAddr(t *testing.T) {
        c.Close()
 }
 
+func TestDialerControl(t *testing.T) {
+       switch runtime.GOOS {
+       case "nacl", "plan9":
+               t.Skipf("not supported on %s", runtime.GOOS)
+       }
+
+       t.Run("StreamDial", func(t *testing.T) {
+               for _, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
+                       if !testableNetwork(network) {
+                               continue
+                       }
+                       ln, err := newLocalListener(network)
+                       if err != nil {
+                               t.Error(err)
+                               continue
+                       }
+                       defer ln.Close()
+                       d := Dialer{Control: controlOnConnSetup}
+                       c, err := d.Dial(network, ln.Addr().String())
+                       if err != nil {
+                               t.Error(err)
+                               continue
+                       }
+                       c.Close()
+               }
+       })
+       t.Run("PacketDial", func(t *testing.T) {
+               for _, network := range []string{"udp", "udp4", "udp6", "unixgram"} {
+                       if !testableNetwork(network) {
+                               continue
+                       }
+                       c1, err := newLocalPacketListener(network)
+                       if err != nil {
+                               t.Error(err)
+                               continue
+                       }
+                       if network == "unixgram" {
+                               defer os.Remove(c1.LocalAddr().String())
+                       }
+                       defer c1.Close()
+                       d := Dialer{Control: controlOnConnSetup}
+                       c2, err := d.Dial(network, c1.LocalAddr().String())
+                       if err != nil {
+                               t.Error(err)
+                               continue
+                       }
+                       c2.Close()
+               }
+       })
+}
+
 // mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork
 // except that it won't skip testing on non-iOS builders.
 func mustHaveExternalNetwork(t *testing.T) {
index 7dafd20bf68619b96b9470c7fc807f07adc92fd7..b2f57916433eb807a21efe5c5f00c0a4221600ad 100644 (file)
@@ -122,7 +122,7 @@ func (sd *sysDialer) dialIP(ctx context.Context, laddr, raddr *IPAddr) (*IPConn,
        default:
                return nil, UnknownNetworkError(sd.network)
        }
-       fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial")
+       fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", sd.Dialer.Control)
        if err != nil {
                return nil, err
        }
@@ -139,7 +139,7 @@ func (sl *sysListener) listenIP(ctx context.Context, laddr *IPAddr) (*IPConn, er
        default:
                return nil, UnknownNetworkError(sl.network)
        }
-       fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen")
+       fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", sl.ListenConfig.Control)
        if err != nil {
                return nil, err
        }
index 8372aaa7423cf4e59136845f057cbbbe443d83ee..eddd4118fa231ad5eee87759c87a5ffebfb5c4db 100644 (file)
@@ -133,12 +133,12 @@ func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (fam
        return syscall.AF_INET6, false
 }
 
-func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string) (fd *netFD, err error) {
+func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) {
        if (runtime.GOOS == "windows" || runtime.GOOS == "openbsd" || runtime.GOOS == "nacl") && mode == "dial" && raddr.isWildcard() {
                raddr = raddr.toLocal(net)
        }
        family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
-       return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr)
+       return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlFn)
 }
 
 func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) {
index 96624f98ce53fde50ea55f6327d614548ad97030..ffd38d79506d65935eaf73aedf5ae0963fe23330 100644 (file)
@@ -7,6 +7,7 @@
 package net
 
 import (
+       "context"
        "fmt"
        "internal/testenv"
        "os"
@@ -729,3 +730,56 @@ func TestClosingListener(t *testing.T) {
        }
        ln2.Close()
 }
+
+func TestListenConfigControl(t *testing.T) {
+       switch runtime.GOOS {
+       case "nacl", "plan9":
+               t.Skipf("not supported on %s", runtime.GOOS)
+       }
+
+       t.Run("StreamListen", func(t *testing.T) {
+               for _, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
+                       if !testableNetwork(network) {
+                               continue
+                       }
+                       ln, err := newLocalListener(network)
+                       if err != nil {
+                               t.Error(err)
+                               continue
+                       }
+                       address := ln.Addr().String()
+                       ln.Close()
+                       lc := ListenConfig{Control: controlOnConnSetup}
+                       ln, err = lc.Listen(context.Background(), network, address)
+                       if err != nil {
+                               t.Error(err)
+                               continue
+                       }
+                       ln.Close()
+               }
+       })
+       t.Run("PacketListen", func(t *testing.T) {
+               for _, network := range []string{"udp", "udp4", "udp6", "unixgram"} {
+                       if !testableNetwork(network) {
+                               continue
+                       }
+                       c, err := newLocalPacketListener(network)
+                       if err != nil {
+                               t.Error(err)
+                               continue
+                       }
+                       address := c.LocalAddr().String()
+                       c.Close()
+                       if network == "unixgram" {
+                               os.Remove(address)
+                       }
+                       lc := ListenConfig{Control: controlOnConnSetup}
+                       c, err = lc.ListenPacket(context.Background(), network, address)
+                       if err != nil {
+                               t.Error(err)
+                               continue
+                       }
+                       c.Close()
+               }
+       })
+}
index 391b4d188e2e25b4189b38f96bc275e8365c6762..3e3b6bf5b2a1aa6260791d042621791031a4fb82 100644 (file)
@@ -22,3 +22,7 @@ func writeRawConn(c syscall.RawConn, b []byte) error {
 func controlRawConn(c syscall.RawConn, addr Addr) error {
        return errors.New("not supported")
 }
+
+func controlOnConnSetup(network string, address string, c syscall.RawConn) error {
+       return nil
+}
index 2fe4d2c6bace69d78bce014ee42e7c71f3a489f4..a720a8a4a3e274450ac2f468edef8c89b6234b49 100644 (file)
@@ -6,7 +6,10 @@
 
 package net
 
-import "syscall"
+import (
+       "errors"
+       "syscall"
+)
 
 func readRawConn(c syscall.RawConn, b []byte) (int, error) {
        var operr error
@@ -89,3 +92,36 @@ func controlRawConn(c syscall.RawConn, addr Addr) error {
        }
        return nil
 }
+
+func controlOnConnSetup(network string, address string, c syscall.RawConn) error {
+       var operr error
+       var fn func(uintptr)
+       switch network {
+       case "tcp", "udp", "ip":
+               return errors.New("ambiguous network: " + network)
+       case "unix", "unixpacket", "unixgram":
+               fn = func(s uintptr) {
+                       _, operr = syscall.GetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_ERROR)
+               }
+       default:
+               switch network[len(network)-1] {
+               case '4':
+                       fn = func(s uintptr) {
+                               operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1)
+                       }
+               case '6':
+                       fn = func(s uintptr) {
+                               operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1)
+                       }
+               default:
+                       return errors.New("unknown network: " + network)
+               }
+       }
+       if err := c.Control(fn); err != nil {
+               return err
+       }
+       if operr != nil {
+               return operr
+       }
+       return nil
+}
index 6df101e9de4b2523dc232216db4a3effe6cb076d..2774c97e5c82e37937f8f23efab4ad9a2b76e780 100644 (file)
@@ -5,6 +5,7 @@
 package net
 
 import (
+       "errors"
        "syscall"
        "unsafe"
 )
@@ -96,3 +97,32 @@ func controlRawConn(c syscall.RawConn, addr Addr) error {
        }
        return nil
 }
+
+func controlOnConnSetup(network string, address string, c syscall.RawConn) error {
+       var operr error
+       var fn func(uintptr)
+       switch network {
+       case "tcp", "udp", "ip":
+               return errors.New("ambiguous network: " + network)
+       default:
+               switch network[len(network)-1] {
+               case '4':
+                       fn = func(s uintptr) {
+                               operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1)
+                       }
+               case '6':
+                       fn = func(s uintptr) {
+                               operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1)
+                       }
+               default:
+                       return errors.New("unknown network: " + network)
+               }
+       }
+       if err := c.Control(fn); err != nil {
+               return err
+       }
+       if operr != nil {
+               return operr
+       }
+       return nil
+}
index 8cfc42eb7e66fb777146fd7d69945110b67c1427..00ff3fd39394ce902a55031d38a04e35efd393b2 100644 (file)
@@ -38,7 +38,7 @@ type sockaddr interface {
 
 // socket returns a network file descriptor that is ready for
 // asynchronous I/O using the network poller.
-func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr) (fd *netFD, err error) {
+func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) {
        s, err := sysSocket(family, sotype, proto)
        if err != nil {
                return nil, err
@@ -77,26 +77,41 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only
        if laddr != nil && raddr == nil {
                switch sotype {
                case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
-                       if err := fd.listenStream(laddr, listenerBacklog); err != nil {
+                       if err := fd.listenStream(laddr, listenerBacklog, ctrlFn); err != nil {
                                fd.Close()
                                return nil, err
                        }
                        return fd, nil
                case syscall.SOCK_DGRAM:
-                       if err := fd.listenDatagram(laddr); err != nil {
+                       if err := fd.listenDatagram(laddr, ctrlFn); err != nil {
                                fd.Close()
                                return nil, err
                        }
                        return fd, nil
                }
        }
-       if err := fd.dial(ctx, laddr, raddr); err != nil {
+       if err := fd.dial(ctx, laddr, raddr, ctrlFn); err != nil {
                fd.Close()
                return nil, err
        }
        return fd, nil
 }
 
+func (fd *netFD) ctrlNetwork() string {
+       switch fd.net {
+       case "unix", "unixgram", "unixpacket":
+               return fd.net
+       }
+       switch fd.net[len(fd.net)-1] {
+       case '4', '6':
+               return fd.net
+       }
+       if fd.family == syscall.AF_INET {
+               return fd.net + "4"
+       }
+       return fd.net + "6"
+}
+
 func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
        switch fd.family {
        case syscall.AF_INET, syscall.AF_INET6:
@@ -121,14 +136,29 @@ func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
        return func(syscall.Sockaddr) Addr { return nil }
 }
 
-func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr) error {
+func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error {
+       if ctrlFn != nil {
+               c, err := newRawConn(fd)
+               if err != nil {
+                       return err
+               }
+               var ctrlAddr string
+               if raddr != nil {
+                       ctrlAddr = raddr.String()
+               } else if laddr != nil {
+                       ctrlAddr = laddr.String()
+               }
+               if err := ctrlFn(fd.ctrlNetwork(), ctrlAddr, c); err != nil {
+                       return err
+               }
+       }
        var err error
        var lsa syscall.Sockaddr
        if laddr != nil {
                if lsa, err = laddr.sockaddr(fd.family); err != nil {
                        return err
                } else if lsa != nil {
-                       if err := syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
+                       if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
                                return os.NewSyscallError("bind", err)
                        }
                }
@@ -165,29 +195,39 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr) error {
        return nil
 }
 
-func (fd *netFD) listenStream(laddr sockaddr, backlog int) error {
-       if err := setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil {
+func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, string, syscall.RawConn) error) error {
+       var err error
+       if err = setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil {
                return err
        }
-       if lsa, err := laddr.sockaddr(fd.family); err != nil {
+       var lsa syscall.Sockaddr
+       if lsa, err = laddr.sockaddr(fd.family); err != nil {
                return err
-       } else if lsa != nil {
-               if err := syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
-                       return os.NewSyscallError("bind", err)
+       }
+       if ctrlFn != nil {
+               c, err := newRawConn(fd)
+               if err != nil {
+                       return err
+               }
+               if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil {
+                       return err
                }
        }
-       if err := listenFunc(fd.pfd.Sysfd, backlog); err != nil {
+       if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
+               return os.NewSyscallError("bind", err)
+       }
+       if err = listenFunc(fd.pfd.Sysfd, backlog); err != nil {
                return os.NewSyscallError("listen", err)
        }
-       if err := fd.init(); err != nil {
+       if err = fd.init(); err != nil {
                return err
        }
-       lsa, _ := syscall.Getsockname(fd.pfd.Sysfd)
+       lsa, _ = syscall.Getsockname(fd.pfd.Sysfd)
        fd.setAddr(fd.addrFunc()(lsa), nil)
        return nil
 }
 
-func (fd *netFD) listenDatagram(laddr sockaddr) error {
+func (fd *netFD) listenDatagram(laddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error {
        switch addr := laddr.(type) {
        case *UDPAddr:
                // We provide a socket that listens to a wildcard
@@ -211,17 +251,27 @@ func (fd *netFD) listenDatagram(laddr sockaddr) error {
                        laddr = &addr
                }
        }
-       if lsa, err := laddr.sockaddr(fd.family); err != nil {
+       var err error
+       var lsa syscall.Sockaddr
+       if lsa, err = laddr.sockaddr(fd.family); err != nil {
                return err
-       } else if lsa != nil {
-               if err := syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
-                       return os.NewSyscallError("bind", err)
+       }
+       if ctrlFn != nil {
+               c, err := newRawConn(fd)
+               if err != nil {
+                       return err
+               }
+               if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil {
+                       return err
                }
        }
-       if err := fd.init(); err != nil {
+       if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
+               return os.NewSyscallError("bind", err)
+       }
+       if err = fd.init(); err != nil {
                return err
        }
-       lsa, _ := syscall.Getsockname(fd.pfd.Sysfd)
+       lsa, _ = syscall.Getsockname(fd.pfd.Sysfd)
        fd.setAddr(fd.addrFunc()(lsa), nil)
        return nil
 }
index 6061c16986c22a1a4f585f5130bb0f860a8ef9e8..bcf7592d35f791789de74ac79d8bd20d9fa30b22 100644 (file)
@@ -62,7 +62,7 @@ func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPCo
 }
 
 func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
-       fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial")
+       fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control)
 
        // TCP has a rarely used mechanism called a 'simultaneous connection' in
        // which Dial("tcp", addr1, addr2) run on the machine at addr1 can
@@ -92,7 +92,7 @@ func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCP
                if err == nil {
                        fd.Close()
                }
-               fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial")
+               fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control)
        }
 
        if err != nil {
@@ -156,7 +156,7 @@ func (ln *TCPListener) file() (*os.File, error) {
 }
 
 func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) {
-       fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen")
+       fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen", sl.ListenConfig.Control)
        if err != nil {
                return nil, err
        }
index 4e96f4781df47bfbedd35d35ef5dbd38bdc98536..8f4b71c01ec8023e49fd98cb4938a34c09337ce6 100644 (file)
@@ -95,7 +95,7 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error
 }
 
 func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) {
-       fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial")
+       fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", sd.Dialer.Control)
        if err != nil {
                return nil, err
        }
@@ -103,7 +103,7 @@ func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPCo
 }
 
 func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, error) {
-       fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen")
+       fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control)
        if err != nil {
                return nil, err
        }
@@ -111,7 +111,7 @@ func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn,
 }
 
 func (sl *sysListener) listenMulticastUDP(ctx context.Context, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
-       fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen")
+       fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control)
        if err != nil {
                return nil, err
        }
index f627567af5f0f38cb0de8ac6867d0a73c75fb100..2495da1d253fa675fe9315c4374083886ffd7289 100644 (file)
@@ -13,7 +13,7 @@ import (
        "syscall"
 )
 
-func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string) (*netFD, error) {
+func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctrlFn func(string, string, syscall.RawConn) error) (*netFD, error) {
        var sotype int
        switch net {
        case "unix":
@@ -42,7 +42,7 @@ func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode str
                return nil, errors.New("unknown mode: " + mode)
        }
 
-       fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr)
+       fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctrlFn)
        if err != nil {
                return nil, err
        }
@@ -151,7 +151,7 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err
 }
 
 func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConn, error) {
-       fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial")
+       fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", sd.Dialer.Control)
        if err != nil {
                return nil, err
        }
@@ -207,7 +207,7 @@ func (l *UnixListener) SetUnlinkOnClose(unlink bool) {
 }
 
 func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixListener, error) {
-       fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen")
+       fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control)
        if err != nil {
                return nil, err
        }
@@ -215,7 +215,7 @@ func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixLi
 }
 
 func (sl *sysListener) listenUnixgram(ctx context.Context, laddr *UnixAddr) (*UnixConn, error) {
-       fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen")
+       fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control)
        if err != nil {
                return nil, err
        }