From: Michael Fraenkel Date: Sun, 30 Apr 2023 15:12:27 +0000 (-0600) Subject: net: context aware Dialer.Dial functions X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=2b804abf0712d45801671232585e0011902a5c48;p=gostls13.git net: context aware Dialer.Dial functions Add context aware dial functions for TCP, UDP, IP and Unix networks. Fixes #49097 Updates #59897 Change-Id: I7523452e8e463a587a852e0555cec822d8dcb3dd Reviewed-on: https://go-review.googlesource.com/c/go/+/490975 LUCI-TryBot-Result: Go LUCI Reviewed-by: Dmitri Shuralyov Reviewed-by: David Chase Reviewed-by: Sean Liao --- diff --git a/api/next/49097.txt b/api/next/49097.txt new file mode 100644 index 0000000000..f7240954c6 --- /dev/null +++ b/api/next/49097.txt @@ -0,0 +1,4 @@ +pkg net, method (*Dialer) DialIP(context.Context, string, netip.Addr, netip.Addr) (*IPConn, error) #49097 +pkg net, method (*Dialer) DialTCP(context.Context, string, netip.AddrPort, netip.AddrPort) (*TCPConn, error) #49097 +pkg net, method (*Dialer) DialUDP(context.Context, string, netip.AddrPort, netip.AddrPort) (*UDPConn, error) #49097 +pkg net, method (*Dialer) DialUnix(context.Context, string, *UnixAddr, *UnixAddr) (*UnixConn, error) #49097 diff --git a/doc/next/6-stdlib/99-minor/net/49097.md b/doc/next/6-stdlib/99-minor/net/49097.md new file mode 100644 index 0000000000..bb7947b0a1 --- /dev/null +++ b/doc/next/6-stdlib/99-minor/net/49097.md @@ -0,0 +1 @@ +Added context aware dial functions for TCP, UDP, IP and Unix networks. diff --git a/src/net/dial.go b/src/net/dial.go index 6264984cec..a87c57603a 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -9,6 +9,7 @@ import ( "internal/bytealg" "internal/godebug" "internal/nettrace" + "net/netip" "syscall" "time" ) @@ -523,30 +524,8 @@ func (d *Dialer) Dial(network, address string) (Conn, error) { // See func [Dial] for a description of the network and address // parameters. func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) { - if ctx == nil { - panic("nil context") - } - deadline := d.deadline(ctx, time.Now()) - if !deadline.IsZero() { - testHookStepTime() - if d, ok := ctx.Deadline(); !ok || deadline.Before(d) { - subCtx, cancel := context.WithDeadline(ctx, deadline) - defer cancel() - ctx = subCtx - } - } - if oldCancel := d.Cancel; oldCancel != nil { - subCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - select { - case <-oldCancel: - cancel() - case <-subCtx.Done(): - } - }() - ctx = subCtx - } + ctx, cancel := d.dialCtx(ctx) + defer cancel() // Shadow the nettrace (if any) during resolve so Connect events don't fire for DNS lookups. resolveCtx := ctx @@ -578,6 +557,97 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn return sd.dialParallel(ctx, primaries, fallbacks) } +func (d *Dialer) dialCtx(ctx context.Context) (context.Context, context.CancelFunc) { + if ctx == nil { + panic("nil context") + } + deadline := d.deadline(ctx, time.Now()) + var cancel1, cancel2 context.CancelFunc + if !deadline.IsZero() { + testHookStepTime() + if d, ok := ctx.Deadline(); !ok || deadline.Before(d) { + var subCtx context.Context + subCtx, cancel1 = context.WithDeadline(ctx, deadline) + ctx = subCtx + } + } + if oldCancel := d.Cancel; oldCancel != nil { + subCtx, cancel2 := context.WithCancel(ctx) + go func() { + select { + case <-oldCancel: + cancel2() + case <-subCtx.Done(): + } + }() + ctx = subCtx + } + return ctx, func() { + if cancel1 != nil { + cancel1() + } + if cancel2 != nil { + cancel2() + } + } +} + +// DialTCP acts like Dial for TCP networks using the provided context. +// +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +// +// The network must be a TCP network name; see func Dial for details. +func (d *Dialer) DialTCP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*TCPConn, error) { + ctx, cancel := d.dialCtx(ctx) + defer cancel() + return dialTCP(ctx, d, network, TCPAddrFromAddrPort(laddr), TCPAddrFromAddrPort(raddr)) +} + +// DialUDP acts like Dial for UDP networks using the provided context. +// +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +// +// The network must be a UDP network name; see func Dial for details. +func (d *Dialer) DialUDP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*UDPConn, error) { + ctx, cancel := d.dialCtx(ctx) + defer cancel() + return dialUDP(ctx, d, network, UDPAddrFromAddrPort(laddr), UDPAddrFromAddrPort(raddr)) +} + +// DialIP acts like Dial for IP networks using the provided context. +// +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +// +// The network must be an IP network name; see func Dial for details. +func (d *Dialer) DialIP(ctx context.Context, network string, laddr netip.Addr, raddr netip.Addr) (*IPConn, error) { + ctx, cancel := d.dialCtx(ctx) + defer cancel() + return dialIP(ctx, d, network, ipAddrFromAddr(laddr), ipAddrFromAddr(raddr)) +} + +// DialUnix acts like Dial for Unix networks using the provided context. +// +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +// +// The network must be a Unix network name; see func Dial for details. +func (d *Dialer) DialUnix(ctx context.Context, network string, laddr *UnixAddr, raddr *UnixAddr) (*UnixConn, error) { + ctx, cancel := d.dialCtx(ctx) + defer cancel() + return dialUnix(ctx, d, network, laddr, raddr) +} + // dialParallel races two copies of dialSerial, giving the first a // head start. It returns the first established connection and // closes the others. Otherwise it returns an error from the first diff --git a/src/net/dial_test.go b/src/net/dial_test.go index b3bedb2fa2..829b80c33a 100644 --- a/src/net/dial_test.go +++ b/src/net/dial_test.go @@ -11,6 +11,7 @@ import ( "fmt" "internal/testenv" "io" + "net/netip" "os" "runtime" "strings" @@ -1064,6 +1065,99 @@ func TestDialerControlContext(t *testing.T) { }) } +func TestDialContext(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("not supported on %s", runtime.GOOS) + case "js", "wasip1": + t.Skipf("skipping: fake net does not support Dialer.ControlContext") + } + + t.Run("StreamDial", func(t *testing.T) { + var err error + for i, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} { + if !testableNetwork(network) { + continue + } + ln := newLocalListener(t, network) + defer ln.Close() + var id int + d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error { + id = ctx.Value("id").(int) + return controlOnConnSetup(network, address, c) + }} + var c Conn + switch network { + case "tcp", "tcp4", "tcp6": + raddr, err := netip.ParseAddrPort(ln.Addr().String()) + if err != nil { + t.Error(err) + continue + } + c, err = d.DialTCP(context.WithValue(context.Background(), "id", i+1), network, (*TCPAddr)(nil).AddrPort(), raddr) + case "unix", "unixpacket": + raddr, err := ResolveUnixAddr(network, ln.Addr().String()) + if err != nil { + t.Error(err) + continue + } + c, err = d.DialUnix(context.WithValue(context.Background(), "id", i+1), network, nil, raddr) + } + if err != nil { + t.Error(err) + continue + } + if id != i+1 { + t.Errorf("%s: got id %d, want %d", network, id, i+1) + } + c.Close() + } + }) + t.Run("PacketDial", func(t *testing.T) { + var err error + for i, network := range []string{"udp", "udp4", "udp6", "unixgram"} { + if !testableNetwork(network) { + continue + } + c1 := newLocalPacketListener(t, network) + if network == "unixgram" { + defer os.Remove(c1.LocalAddr().String()) + } + defer c1.Close() + var id int + d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error { + id = ctx.Value("id").(int) + return controlOnConnSetup(network, address, c) + }} + var c2 Conn + switch network { + case "udp", "udp4", "udp6": + raddr, err := netip.ParseAddrPort(c1.LocalAddr().String()) + if err != nil { + t.Error(err) + continue + } + c2, err = d.DialUDP(context.WithValue(context.Background(), "id", i+1), network, (*UDPAddr)(nil).AddrPort(), raddr) + case "unixgram": + raddr, err := ResolveUnixAddr(network, c1.LocalAddr().String()) + if err != nil { + t.Error(err) + continue + } + c2, err = d.DialUnix(context.WithValue(context.Background(), "id", i+1), network, nil, raddr) + } + if err != nil { + t.Error(err) + continue + } + if id != i+1 { + t.Errorf("%s: got id %d, want %d", network, id, i+1) + } + c2.Close() + } + }) +} + // mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork // except on non-Linux, non-mobile builders it permits the test to // run in -short mode. diff --git a/src/net/iprawsock.go b/src/net/iprawsock.go index 76dded9ca1..80a80fef7d 100644 --- a/src/net/iprawsock.go +++ b/src/net/iprawsock.go @@ -6,6 +6,7 @@ package net import ( "context" + "net/netip" "syscall" ) @@ -24,6 +25,13 @@ import ( // BUG(mikio): On JS and Plan 9, methods and functions related // to IPConn are not implemented. +func ipAddrFromAddr(addr netip.Addr) *IPAddr { + return &IPAddr{ + IP: addr.AsSlice(), + Zone: addr.Zone(), + } +} + // IPAddr represents the address of an IP end point. type IPAddr struct { IP IP @@ -206,11 +214,18 @@ func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} } // If the IP field of raddr is nil or an unspecified IP address, the // local system is assumed. func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) { + return dialIP(context.Background(), nil, network, laddr, raddr) +} + +func dialIP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *IPAddr) (*IPConn, error) { if raddr == nil { return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress} } sd := &sysDialer{network: network, address: raddr.String()} - c, err := sd.dialIP(context.Background(), laddr, raddr) + if dialer != nil { + sd.Dialer = *dialer + } + c, err := sd.dialIP(ctx, laddr, raddr) if err != nil { return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } diff --git a/src/net/tcpsock.go b/src/net/tcpsock.go index 9d215db1b2..376bf238c7 100644 --- a/src/net/tcpsock.go +++ b/src/net/tcpsock.go @@ -315,6 +315,10 @@ func newTCPConn(fd *netFD, keepAliveIdle time.Duration, keepAliveCfg KeepAliveCo // If the IP field of raddr is nil or an unspecified IP address, the // local system is assumed. func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) { + return dialTCP(context.Background(), nil, network, laddr, raddr) +} + +func dialTCP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *TCPAddr) (*TCPConn, error) { switch network { case "tcp", "tcp4", "tcp6": default: @@ -328,10 +332,13 @@ func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) { c *TCPConn err error ) + if dialer != nil { + sd.Dialer = *dialer + } if sd.MultipathTCP() { - c, err = sd.dialMPTCP(context.Background(), laddr, raddr) + c, err = sd.dialMPTCP(ctx, laddr, raddr) } else { - c, err = sd.dialTCP(context.Background(), laddr, raddr) + c, err = sd.dialTCP(ctx, laddr, raddr) } if err != nil { return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} diff --git a/src/net/udpsock.go b/src/net/udpsock.go index 35da018c30..f9a3bee867 100644 --- a/src/net/udpsock.go +++ b/src/net/udpsock.go @@ -285,6 +285,10 @@ func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} } // If the IP field of raddr is nil or an unspecified IP address, the // local system is assumed. func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) { + return dialUDP(context.Background(), nil, network, laddr, raddr) +} + +func dialUDP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *UDPAddr) (*UDPConn, error) { switch network { case "udp", "udp4", "udp6": default: @@ -294,7 +298,10 @@ func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) { return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress} } sd := &sysDialer{network: network, address: raddr.String()} - c, err := sd.dialUDP(context.Background(), laddr, raddr) + if dialer != nil { + sd.Dialer = *dialer + } + c, err := sd.dialUDP(ctx, laddr, raddr) if err != nil { return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } diff --git a/src/net/unixsock.go b/src/net/unixsock.go index c93ef91d57..0ee79f35de 100644 --- a/src/net/unixsock.go +++ b/src/net/unixsock.go @@ -201,13 +201,20 @@ func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{conn{fd}} } // If laddr is non-nil, it is used as the local address for the // connection. func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConn, error) { + return dialUnix(context.Background(), nil, network, laddr, raddr) +} + +func dialUnix(ctx context.Context, dialer *Dialer, network string, laddr, raddr *UnixAddr) (*UnixConn, error) { switch network { case "unix", "unixgram", "unixpacket": default: return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)} } sd := &sysDialer{network: network, address: raddr.String()} - c, err := sd.dialUnix(context.Background(), laddr, raddr) + if dialer != nil { + sd.Dialer = *dialer + } + c, err := sd.dialUnix(ctx, laddr, raddr) if err != nil { return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} }