]> Cypherpunks repositories - gostls13.git/commitdiff
net: context aware Dialer.Dial functions
authorMichael Fraenkel <michael.fraenkel@gmail.com>
Sun, 30 Apr 2023 15:12:27 +0000 (09:12 -0600)
committerSean Liao <sean@liao.dev>
Mon, 11 Aug 2025 21:26:10 +0000 (14:26 -0700)
Add context aware dial functions for TCP, UDP, IP and Unix networks.

Fixes #49097
Updates #59897

Change-Id: I7523452e8e463a587a852e0555cec822d8dcb3dd
Reviewed-on: https://go-review.googlesource.com/c/go/+/490975
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Sean Liao <sean@liao.dev>
api/next/49097.txt [new file with mode: 0644]
doc/next/6-stdlib/99-minor/net/49097.md [new file with mode: 0644]
src/net/dial.go
src/net/dial_test.go
src/net/iprawsock.go
src/net/tcpsock.go
src/net/udpsock.go
src/net/unixsock.go

diff --git a/api/next/49097.txt b/api/next/49097.txt
new file mode 100644 (file)
index 0000000..f724095
--- /dev/null
@@ -0,0 +1,4 @@
+pkg net, method (*Dialer) DialIP(context.Context, string, netip.Addr, netip.Addr) (*IPConn, error) #49097
+pkg net, method (*Dialer) DialTCP(context.Context, string, netip.AddrPort, netip.AddrPort) (*TCPConn, error) #49097
+pkg net, method (*Dialer) DialUDP(context.Context, string, netip.AddrPort, netip.AddrPort) (*UDPConn, error) #49097
+pkg net, method (*Dialer) DialUnix(context.Context, string, *UnixAddr, *UnixAddr) (*UnixConn, error) #49097
diff --git a/doc/next/6-stdlib/99-minor/net/49097.md b/doc/next/6-stdlib/99-minor/net/49097.md
new file mode 100644 (file)
index 0000000..bb7947b
--- /dev/null
@@ -0,0 +1 @@
+Added context aware dial functions for TCP, UDP, IP and Unix networks.
index 6264984ceca182bc394a6851233c4806586975b7..a87c57603a813ce70e64584da345507d1144db7d 100644 (file)
@@ -9,6 +9,7 @@ import (
        "internal/bytealg"
        "internal/godebug"
        "internal/nettrace"
+       "net/netip"
        "syscall"
        "time"
 )
@@ -523,30 +524,8 @@ func (d *Dialer) Dial(network, address string) (Conn, error) {
 // See func [Dial] for a description of the network and address
 // parameters.
 func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
-       if ctx == nil {
-               panic("nil context")
-       }
-       deadline := d.deadline(ctx, time.Now())
-       if !deadline.IsZero() {
-               testHookStepTime()
-               if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
-                       subCtx, cancel := context.WithDeadline(ctx, deadline)
-                       defer cancel()
-                       ctx = subCtx
-               }
-       }
-       if oldCancel := d.Cancel; oldCancel != nil {
-               subCtx, cancel := context.WithCancel(ctx)
-               defer cancel()
-               go func() {
-                       select {
-                       case <-oldCancel:
-                               cancel()
-                       case <-subCtx.Done():
-                       }
-               }()
-               ctx = subCtx
-       }
+       ctx, cancel := d.dialCtx(ctx)
+       defer cancel()
 
        // Shadow the nettrace (if any) during resolve so Connect events don't fire for DNS lookups.
        resolveCtx := ctx
@@ -578,6 +557,97 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
        return sd.dialParallel(ctx, primaries, fallbacks)
 }
 
+func (d *Dialer) dialCtx(ctx context.Context) (context.Context, context.CancelFunc) {
+       if ctx == nil {
+               panic("nil context")
+       }
+       deadline := d.deadline(ctx, time.Now())
+       var cancel1, cancel2 context.CancelFunc
+       if !deadline.IsZero() {
+               testHookStepTime()
+               if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
+                       var subCtx context.Context
+                       subCtx, cancel1 = context.WithDeadline(ctx, deadline)
+                       ctx = subCtx
+               }
+       }
+       if oldCancel := d.Cancel; oldCancel != nil {
+               subCtx, cancel2 := context.WithCancel(ctx)
+               go func() {
+                       select {
+                       case <-oldCancel:
+                               cancel2()
+                       case <-subCtx.Done():
+                       }
+               }()
+               ctx = subCtx
+       }
+       return ctx, func() {
+               if cancel1 != nil {
+                       cancel1()
+               }
+               if cancel2 != nil {
+                       cancel2()
+               }
+       }
+}
+
+// DialTCP acts like Dial for TCP networks using the provided context.
+//
+// The provided Context must be non-nil. If the context expires before
+// the connection is complete, an error is returned. Once successfully
+// connected, any expiration of the context will not affect the
+// connection.
+//
+// The network must be a TCP network name; see func Dial for details.
+func (d *Dialer) DialTCP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*TCPConn, error) {
+       ctx, cancel := d.dialCtx(ctx)
+       defer cancel()
+       return dialTCP(ctx, d, network, TCPAddrFromAddrPort(laddr), TCPAddrFromAddrPort(raddr))
+}
+
+// DialUDP acts like Dial for UDP networks using the provided context.
+//
+// The provided Context must be non-nil. If the context expires before
+// the connection is complete, an error is returned. Once successfully
+// connected, any expiration of the context will not affect the
+// connection.
+//
+// The network must be a UDP network name; see func Dial for details.
+func (d *Dialer) DialUDP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*UDPConn, error) {
+       ctx, cancel := d.dialCtx(ctx)
+       defer cancel()
+       return dialUDP(ctx, d, network, UDPAddrFromAddrPort(laddr), UDPAddrFromAddrPort(raddr))
+}
+
+// DialIP acts like Dial for IP networks using the provided context.
+//
+// The provided Context must be non-nil. If the context expires before
+// the connection is complete, an error is returned. Once successfully
+// connected, any expiration of the context will not affect the
+// connection.
+//
+// The network must be an IP network name; see func Dial for details.
+func (d *Dialer) DialIP(ctx context.Context, network string, laddr netip.Addr, raddr netip.Addr) (*IPConn, error) {
+       ctx, cancel := d.dialCtx(ctx)
+       defer cancel()
+       return dialIP(ctx, d, network, ipAddrFromAddr(laddr), ipAddrFromAddr(raddr))
+}
+
+// DialUnix acts like Dial for Unix networks using the provided context.
+//
+// The provided Context must be non-nil. If the context expires before
+// the connection is complete, an error is returned. Once successfully
+// connected, any expiration of the context will not affect the
+// connection.
+//
+// The network must be a Unix network name; see func Dial for details.
+func (d *Dialer) DialUnix(ctx context.Context, network string, laddr *UnixAddr, raddr *UnixAddr) (*UnixConn, error) {
+       ctx, cancel := d.dialCtx(ctx)
+       defer cancel()
+       return dialUnix(ctx, d, network, laddr, raddr)
+}
+
 // dialParallel races two copies of dialSerial, giving the first a
 // head start. It returns the first established connection and
 // closes the others. Otherwise it returns an error from the first
index b3bedb2fa275c3ae62bc1fd3f79f74a975f4b676..829b80c33a198d4cf213d300ee0b1938db57a2b7 100644 (file)
@@ -11,6 +11,7 @@ import (
        "fmt"
        "internal/testenv"
        "io"
+       "net/netip"
        "os"
        "runtime"
        "strings"
@@ -1064,6 +1065,99 @@ func TestDialerControlContext(t *testing.T) {
        })
 }
 
+func TestDialContext(t *testing.T) {
+       switch runtime.GOOS {
+       case "plan9":
+               t.Skipf("not supported on %s", runtime.GOOS)
+       case "js", "wasip1":
+               t.Skipf("skipping: fake net does not support Dialer.ControlContext")
+       }
+
+       t.Run("StreamDial", func(t *testing.T) {
+               var err error
+               for i, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
+                       if !testableNetwork(network) {
+                               continue
+                       }
+                       ln := newLocalListener(t, network)
+                       defer ln.Close()
+                       var id int
+                       d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
+                               id = ctx.Value("id").(int)
+                               return controlOnConnSetup(network, address, c)
+                       }}
+                       var c Conn
+                       switch network {
+                       case "tcp", "tcp4", "tcp6":
+                               raddr, err := netip.ParseAddrPort(ln.Addr().String())
+                               if err != nil {
+                                       t.Error(err)
+                                       continue
+                               }
+                               c, err = d.DialTCP(context.WithValue(context.Background(), "id", i+1), network, (*TCPAddr)(nil).AddrPort(), raddr)
+                       case "unix", "unixpacket":
+                               raddr, err := ResolveUnixAddr(network, ln.Addr().String())
+                               if err != nil {
+                                       t.Error(err)
+                                       continue
+                               }
+                               c, err = d.DialUnix(context.WithValue(context.Background(), "id", i+1), network, nil, raddr)
+                       }
+                       if err != nil {
+                               t.Error(err)
+                               continue
+                       }
+                       if id != i+1 {
+                               t.Errorf("%s: got id %d, want %d", network, id, i+1)
+                       }
+                       c.Close()
+               }
+       })
+       t.Run("PacketDial", func(t *testing.T) {
+               var err error
+               for i, network := range []string{"udp", "udp4", "udp6", "unixgram"} {
+                       if !testableNetwork(network) {
+                               continue
+                       }
+                       c1 := newLocalPacketListener(t, network)
+                       if network == "unixgram" {
+                               defer os.Remove(c1.LocalAddr().String())
+                       }
+                       defer c1.Close()
+                       var id int
+                       d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
+                               id = ctx.Value("id").(int)
+                               return controlOnConnSetup(network, address, c)
+                       }}
+                       var c2 Conn
+                       switch network {
+                       case "udp", "udp4", "udp6":
+                               raddr, err := netip.ParseAddrPort(c1.LocalAddr().String())
+                               if err != nil {
+                                       t.Error(err)
+                                       continue
+                               }
+                               c2, err = d.DialUDP(context.WithValue(context.Background(), "id", i+1), network, (*UDPAddr)(nil).AddrPort(), raddr)
+                       case "unixgram":
+                               raddr, err := ResolveUnixAddr(network, c1.LocalAddr().String())
+                               if err != nil {
+                                       t.Error(err)
+                                       continue
+                               }
+                               c2, err = d.DialUnix(context.WithValue(context.Background(), "id", i+1), network, nil, raddr)
+                       }
+                       if err != nil {
+                               t.Error(err)
+                               continue
+                       }
+                       if id != i+1 {
+                               t.Errorf("%s: got id %d, want %d", network, id, i+1)
+                       }
+                       c2.Close()
+               }
+       })
+}
+
 // mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork
 // except on non-Linux, non-mobile builders it permits the test to
 // run in -short mode.
index 76dded9ca16e120e6a1a97a2f0087b8f62d15011..80a80fef7d3e4a132af66f3b8f5b88d6ecbc308e 100644 (file)
@@ -6,6 +6,7 @@ package net
 
 import (
        "context"
+       "net/netip"
        "syscall"
 )
 
@@ -24,6 +25,13 @@ import (
 // BUG(mikio): On JS and Plan 9, methods and functions related
 // to IPConn are not implemented.
 
+func ipAddrFromAddr(addr netip.Addr) *IPAddr {
+       return &IPAddr{
+               IP:   addr.AsSlice(),
+               Zone: addr.Zone(),
+       }
+}
+
 // IPAddr represents the address of an IP end point.
 type IPAddr struct {
        IP   IP
@@ -206,11 +214,18 @@ func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} }
 // If the IP field of raddr is nil or an unspecified IP address, the
 // local system is assumed.
 func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) {
+       return dialIP(context.Background(), nil, network, laddr, raddr)
+}
+
+func dialIP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *IPAddr) (*IPConn, error) {
        if raddr == nil {
                return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
        }
        sd := &sysDialer{network: network, address: raddr.String()}
-       c, err := sd.dialIP(context.Background(), laddr, raddr)
+       if dialer != nil {
+               sd.Dialer = *dialer
+       }
+       c, err := sd.dialIP(ctx, laddr, raddr)
        if err != nil {
                return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
        }
index 9d215db1b2eec35e959adda4370dd659c3bdfe8a..376bf238c70d0710948f0b78c953bd6f3f521192 100644 (file)
@@ -315,6 +315,10 @@ func newTCPConn(fd *netFD, keepAliveIdle time.Duration, keepAliveCfg KeepAliveCo
 // If the IP field of raddr is nil or an unspecified IP address, the
 // local system is assumed.
 func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+       return dialTCP(context.Background(), nil, network, laddr, raddr)
+}
+
+func dialTCP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
        switch network {
        case "tcp", "tcp4", "tcp6":
        default:
@@ -328,10 +332,13 @@ func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
                c   *TCPConn
                err error
        )
+       if dialer != nil {
+               sd.Dialer = *dialer
+       }
        if sd.MultipathTCP() {
-               c, err = sd.dialMPTCP(context.Background(), laddr, raddr)
+               c, err = sd.dialMPTCP(ctx, laddr, raddr)
        } else {
-               c, err = sd.dialTCP(context.Background(), laddr, raddr)
+               c, err = sd.dialTCP(ctx, laddr, raddr)
        }
        if err != nil {
                return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
index 35da018c307afbf2adaa161187ea42450fde91ec..f9a3bee867d340cd318722b97081e125a4fa5260 100644 (file)
@@ -285,6 +285,10 @@ func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} }
 // If the IP field of raddr is nil or an unspecified IP address, the
 // local system is assumed.
 func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
+       return dialUDP(context.Background(), nil, network, laddr, raddr)
+}
+
+func dialUDP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
        switch network {
        case "udp", "udp4", "udp6":
        default:
@@ -294,7 +298,10 @@ func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
                return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
        }
        sd := &sysDialer{network: network, address: raddr.String()}
-       c, err := sd.dialUDP(context.Background(), laddr, raddr)
+       if dialer != nil {
+               sd.Dialer = *dialer
+       }
+       c, err := sd.dialUDP(ctx, laddr, raddr)
        if err != nil {
                return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
        }
index c93ef91d5730e6a674b2c7a9a21f1066cee8055d..0ee79f35dec8a4acb312ab082739b0663a80ce81 100644 (file)
@@ -201,13 +201,20 @@ func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{conn{fd}} }
 // If laddr is non-nil, it is used as the local address for the
 // connection.
 func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
+       return dialUnix(context.Background(), nil, network, laddr, raddr)
+}
+
+func dialUnix(ctx context.Context, dialer *Dialer, network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
        switch network {
        case "unix", "unixgram", "unixpacket":
        default:
                return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)}
        }
        sd := &sysDialer{network: network, address: raddr.String()}
-       c, err := sd.dialUnix(context.Background(), laddr, raddr)
+       if dialer != nil {
+               sd.Dialer = *dialer
+       }
+       c, err := sd.dialUnix(ctx, laddr, raddr)
        if err != nil {
                return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
        }