Add context aware dial functions for TCP, UDP, IP and Unix networks.
Fixes #49097
Updates #59897
Change-Id: I7523452e8e463a587a852e0555cec822d8dcb3dd
Reviewed-on: https://go-review.googlesource.com/c/go/+/490975
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Sean Liao <sean@liao.dev>
--- /dev/null
+pkg net, method (*Dialer) DialIP(context.Context, string, netip.Addr, netip.Addr) (*IPConn, error) #49097
+pkg net, method (*Dialer) DialTCP(context.Context, string, netip.AddrPort, netip.AddrPort) (*TCPConn, error) #49097
+pkg net, method (*Dialer) DialUDP(context.Context, string, netip.AddrPort, netip.AddrPort) (*UDPConn, error) #49097
+pkg net, method (*Dialer) DialUnix(context.Context, string, *UnixAddr, *UnixAddr) (*UnixConn, error) #49097
--- /dev/null
+Added context aware dial functions for TCP, UDP, IP and Unix networks.
"internal/bytealg"
"internal/godebug"
"internal/nettrace"
+ "net/netip"
"syscall"
"time"
)
// See func [Dial] for a description of the network and address
// parameters.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
- if ctx == nil {
- panic("nil context")
- }
- deadline := d.deadline(ctx, time.Now())
- if !deadline.IsZero() {
- testHookStepTime()
- if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
- subCtx, cancel := context.WithDeadline(ctx, deadline)
- defer cancel()
- ctx = subCtx
- }
- }
- if oldCancel := d.Cancel; oldCancel != nil {
- subCtx, cancel := context.WithCancel(ctx)
- defer cancel()
- go func() {
- select {
- case <-oldCancel:
- cancel()
- case <-subCtx.Done():
- }
- }()
- ctx = subCtx
- }
+ ctx, cancel := d.dialCtx(ctx)
+ defer cancel()
// Shadow the nettrace (if any) during resolve so Connect events don't fire for DNS lookups.
resolveCtx := ctx
return sd.dialParallel(ctx, primaries, fallbacks)
}
+func (d *Dialer) dialCtx(ctx context.Context) (context.Context, context.CancelFunc) {
+ if ctx == nil {
+ panic("nil context")
+ }
+ deadline := d.deadline(ctx, time.Now())
+ var cancel1, cancel2 context.CancelFunc
+ if !deadline.IsZero() {
+ testHookStepTime()
+ if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
+ var subCtx context.Context
+ subCtx, cancel1 = context.WithDeadline(ctx, deadline)
+ ctx = subCtx
+ }
+ }
+ if oldCancel := d.Cancel; oldCancel != nil {
+ subCtx, cancel2 := context.WithCancel(ctx)
+ go func() {
+ select {
+ case <-oldCancel:
+ cancel2()
+ case <-subCtx.Done():
+ }
+ }()
+ ctx = subCtx
+ }
+ return ctx, func() {
+ if cancel1 != nil {
+ cancel1()
+ }
+ if cancel2 != nil {
+ cancel2()
+ }
+ }
+}
+
+// DialTCP acts like Dial for TCP networks using the provided context.
+//
+// The provided Context must be non-nil. If the context expires before
+// the connection is complete, an error is returned. Once successfully
+// connected, any expiration of the context will not affect the
+// connection.
+//
+// The network must be a TCP network name; see func Dial for details.
+func (d *Dialer) DialTCP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*TCPConn, error) {
+ ctx, cancel := d.dialCtx(ctx)
+ defer cancel()
+ return dialTCP(ctx, d, network, TCPAddrFromAddrPort(laddr), TCPAddrFromAddrPort(raddr))
+}
+
+// DialUDP acts like Dial for UDP networks using the provided context.
+//
+// The provided Context must be non-nil. If the context expires before
+// the connection is complete, an error is returned. Once successfully
+// connected, any expiration of the context will not affect the
+// connection.
+//
+// The network must be a UDP network name; see func Dial for details.
+func (d *Dialer) DialUDP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*UDPConn, error) {
+ ctx, cancel := d.dialCtx(ctx)
+ defer cancel()
+ return dialUDP(ctx, d, network, UDPAddrFromAddrPort(laddr), UDPAddrFromAddrPort(raddr))
+}
+
+// DialIP acts like Dial for IP networks using the provided context.
+//
+// The provided Context must be non-nil. If the context expires before
+// the connection is complete, an error is returned. Once successfully
+// connected, any expiration of the context will not affect the
+// connection.
+//
+// The network must be an IP network name; see func Dial for details.
+func (d *Dialer) DialIP(ctx context.Context, network string, laddr netip.Addr, raddr netip.Addr) (*IPConn, error) {
+ ctx, cancel := d.dialCtx(ctx)
+ defer cancel()
+ return dialIP(ctx, d, network, ipAddrFromAddr(laddr), ipAddrFromAddr(raddr))
+}
+
+// DialUnix acts like Dial for Unix networks using the provided context.
+//
+// The provided Context must be non-nil. If the context expires before
+// the connection is complete, an error is returned. Once successfully
+// connected, any expiration of the context will not affect the
+// connection.
+//
+// The network must be a Unix network name; see func Dial for details.
+func (d *Dialer) DialUnix(ctx context.Context, network string, laddr *UnixAddr, raddr *UnixAddr) (*UnixConn, error) {
+ ctx, cancel := d.dialCtx(ctx)
+ defer cancel()
+ return dialUnix(ctx, d, network, laddr, raddr)
+}
+
// dialParallel races two copies of dialSerial, giving the first a
// head start. It returns the first established connection and
// closes the others. Otherwise it returns an error from the first
"fmt"
"internal/testenv"
"io"
+ "net/netip"
"os"
"runtime"
"strings"
})
}
+func TestDialContext(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ case "js", "wasip1":
+ t.Skipf("skipping: fake net does not support Dialer.ControlContext")
+ }
+
+ t.Run("StreamDial", func(t *testing.T) {
+ var err error
+ for i, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ ln := newLocalListener(t, network)
+ defer ln.Close()
+ var id int
+ d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
+ id = ctx.Value("id").(int)
+ return controlOnConnSetup(network, address, c)
+ }}
+ var c Conn
+ switch network {
+ case "tcp", "tcp4", "tcp6":
+ raddr, err := netip.ParseAddrPort(ln.Addr().String())
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ c, err = d.DialTCP(context.WithValue(context.Background(), "id", i+1), network, (*TCPAddr)(nil).AddrPort(), raddr)
+ case "unix", "unixpacket":
+ raddr, err := ResolveUnixAddr(network, ln.Addr().String())
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ c, err = d.DialUnix(context.WithValue(context.Background(), "id", i+1), network, nil, raddr)
+ }
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ if id != i+1 {
+ t.Errorf("%s: got id %d, want %d", network, id, i+1)
+ }
+ c.Close()
+ }
+ })
+ t.Run("PacketDial", func(t *testing.T) {
+ var err error
+ for i, network := range []string{"udp", "udp4", "udp6", "unixgram"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ c1 := newLocalPacketListener(t, network)
+ if network == "unixgram" {
+ defer os.Remove(c1.LocalAddr().String())
+ }
+ defer c1.Close()
+ var id int
+ d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
+ id = ctx.Value("id").(int)
+ return controlOnConnSetup(network, address, c)
+ }}
+ var c2 Conn
+ switch network {
+ case "udp", "udp4", "udp6":
+ raddr, err := netip.ParseAddrPort(c1.LocalAddr().String())
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ c2, err = d.DialUDP(context.WithValue(context.Background(), "id", i+1), network, (*UDPAddr)(nil).AddrPort(), raddr)
+ case "unixgram":
+ raddr, err := ResolveUnixAddr(network, c1.LocalAddr().String())
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ c2, err = d.DialUnix(context.WithValue(context.Background(), "id", i+1), network, nil, raddr)
+ }
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ if id != i+1 {
+ t.Errorf("%s: got id %d, want %d", network, id, i+1)
+ }
+ c2.Close()
+ }
+ })
+}
+
// mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork
// except on non-Linux, non-mobile builders it permits the test to
// run in -short mode.
import (
"context"
+ "net/netip"
"syscall"
)
// BUG(mikio): On JS and Plan 9, methods and functions related
// to IPConn are not implemented.
+func ipAddrFromAddr(addr netip.Addr) *IPAddr {
+ return &IPAddr{
+ IP: addr.AsSlice(),
+ Zone: addr.Zone(),
+ }
+}
+
// IPAddr represents the address of an IP end point.
type IPAddr struct {
IP IP
// If the IP field of raddr is nil or an unspecified IP address, the
// local system is assumed.
func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) {
+ return dialIP(context.Background(), nil, network, laddr, raddr)
+}
+
+func dialIP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *IPAddr) (*IPConn, error) {
if raddr == nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
sd := &sysDialer{network: network, address: raddr.String()}
- c, err := sd.dialIP(context.Background(), laddr, raddr)
+ if dialer != nil {
+ sd.Dialer = *dialer
+ }
+ c, err := sd.dialIP(ctx, laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
// If the IP field of raddr is nil or an unspecified IP address, the
// local system is assumed.
func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ return dialTCP(context.Background(), nil, network, laddr, raddr)
+}
+
+func dialTCP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
switch network {
case "tcp", "tcp4", "tcp6":
default:
c *TCPConn
err error
)
+ if dialer != nil {
+ sd.Dialer = *dialer
+ }
if sd.MultipathTCP() {
- c, err = sd.dialMPTCP(context.Background(), laddr, raddr)
+ c, err = sd.dialMPTCP(ctx, laddr, raddr)
} else {
- c, err = sd.dialTCP(context.Background(), laddr, raddr)
+ c, err = sd.dialTCP(ctx, laddr, raddr)
}
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
// If the IP field of raddr is nil or an unspecified IP address, the
// local system is assumed.
func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
+ return dialUDP(context.Background(), nil, network, laddr, raddr)
+}
+
+func dialUDP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
switch network {
case "udp", "udp4", "udp6":
default:
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
sd := &sysDialer{network: network, address: raddr.String()}
- c, err := sd.dialUDP(context.Background(), laddr, raddr)
+ if dialer != nil {
+ sd.Dialer = *dialer
+ }
+ c, err := sd.dialUDP(ctx, laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
// If laddr is non-nil, it is used as the local address for the
// connection.
func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
+ return dialUnix(context.Background(), nil, network, laddr, raddr)
+}
+
+func dialUnix(ctx context.Context, dialer *Dialer, network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
switch network {
case "unix", "unixgram", "unixpacket":
default:
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)}
}
sd := &sysDialer{network: network, address: raddr.String()}
- c, err := sd.dialUnix(context.Background(), laddr, raddr)
+ if dialer != nil {
+ sd.Dialer = *dialer
+ }
+ c, err := sd.dialUnix(ctx, laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}