]> Cypherpunks repositories - gostls13.git/commitdiff
net: add Dialer.Cancel to cancel pending dials
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 14 Dec 2015 22:21:48 +0000 (22:21 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 15 Dec 2015 21:15:15 +0000 (21:15 +0000)
Dialer.Cancel is a new optional <-chan struct{} channel whose closure
indicates that the dial should be canceled. It is compatible with the
x/net/context and http.Request.Cancel types.

Tested by hand with:

package main

    import (
            "log"
            "net"
            "time"
    )

    func main() {
            log.Printf("start.")
            var d net.Dialer
            cancel := make(chan struct{})
            time.AfterFunc(2*time.Second, func() {
                    log.Printf("timeout firing")
                    close(cancel)
            })
            d.Cancel = cancel
            c, err := d.Dial("tcp", "192.168.0.1:22")
            if err != nil {
                    log.Print(err)
                    return
            }
            log.Fatalf("unexpected connect: %v", c)
    }

Which says:

    2015/12/14 22:24:58 start.
    2015/12/14 22:25:00 timeout firing
    2015/12/14 22:25:00 dial tcp 192.168.0.1:22: operation was canceled

Fixes #11225

Change-Id: I2ef39e3a540e29fe6bfec03ab7a629a6b187fcb3
Reviewed-on: https://go-review.googlesource.com/17821
Reviewed-by: Russ Cox <rsc@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

12 files changed:
src/net/dial.go
src/net/dial_test.go
src/net/fd_unix.go
src/net/fd_windows.go
src/net/iprawsock_posix.go
src/net/ipsock_posix.go
src/net/net.go
src/net/sock_posix.go
src/net/tcpsock_plan9.go
src/net/tcpsock_posix.go
src/net/udpsock_posix.go
src/net/unixsock_posix.go

index cb4ec216d53f3c3497f9a63f59ed4fdc19c450b7..55863016fe565cea88495e11fd40ad9bc1651fdc 100644 (file)
@@ -57,6 +57,11 @@ type Dialer struct {
        // If zero, keep-alives are not enabled. Network protocols
        // that do not support keep-alives ignore this field.
        KeepAlive time.Duration
+
+       // Cancel is an optional channel whose closure indicates that
+       // the dial should be canceled. Not all types of dials support
+       // cancelation.
+       Cancel <-chan struct{}
 }
 
 // Return either now+Timeout or Deadline, whichever comes first.
@@ -361,7 +366,7 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time) (c Conn, err erro
        switch ra := ra.(type) {
        case *TCPAddr:
                la, _ := la.(*TCPAddr)
-               c, err = testHookDialTCP(ctx.network, la, ra, deadline)
+               c, err = testHookDialTCP(ctx.network, la, ra, deadline, ctx.Cancel)
        case *UDPAddr:
                la, _ := la.(*UDPAddr)
                c, err = dialUDP(ctx.network, la, ra, deadline)
index bd3b2dd9b1b99dc6f32ee524db5b83427296de6b..dbaca9efcee0f9f37a0a6512551b3b4b90ac70f2 100644 (file)
@@ -5,6 +5,7 @@
 package net
 
 import (
+       "internal/testenv"
        "io"
        "net/internal/socktest"
        "runtime"
@@ -236,8 +237,8 @@ const (
 // In some environments, the slow IPs may be explicitly unreachable, and fail
 // more quickly than expected. This test hook prevents dialTCP from returning
 // before the deadline.
-func slowDialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, error) {
-       c, err := dialTCP(net, laddr, raddr, deadline)
+func slowDialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
+       c, err := dialTCP(net, laddr, raddr, deadline, cancel)
        if ParseIP(slowDst4).Equal(raddr.IP) || ParseIP(slowDst6).Equal(raddr.IP) {
                time.Sleep(deadline.Sub(time.Now()))
        }
@@ -716,3 +717,64 @@ func TestDialerKeepAlive(t *testing.T) {
                }
        }
 }
+
+func TestDialCancel(t *testing.T) {
+       if runtime.GOOS == "plan9" || runtime.GOOS == "nacl" {
+               // plan9 is not implemented and nacl doesn't have
+               // external network access.
+               t.Skip("skipping on %s", runtime.GOOS)
+       }
+       onGoBuildFarm := testenv.Builder() != ""
+       if testing.Short() && !onGoBuildFarm {
+               t.Skip("skipping in short mode")
+       }
+
+       blackholeIPPort := JoinHostPort(slowDst4, "1234")
+       if !supportsIPv4 {
+               blackholeIPPort = JoinHostPort(slowDst6, "1234")
+       }
+
+       ticker := time.NewTicker(10 * time.Millisecond)
+       defer ticker.Stop()
+
+       const cancelTick = 5 // the timer tick we cancel the dial at
+       const timeoutTick = 100
+
+       var d Dialer
+       cancel := make(chan struct{})
+       d.Cancel = cancel
+       errc := make(chan error, 1)
+       connc := make(chan Conn, 1)
+       go func() {
+               if c, err := d.Dial("tcp", blackholeIPPort); err != nil {
+                       errc <- err
+               } else {
+                       connc <- c
+               }
+       }()
+       ticks := 0
+       for {
+               select {
+               case <-ticker.C:
+                       ticks++
+                       if ticks == cancelTick {
+                               close(cancel)
+                       }
+                       if ticks == timeoutTick {
+                               t.Fatal("timeout waiting for dial to fail")
+                       }
+               case c := <-connc:
+                       c.Close()
+                       t.Fatal("unexpected successful connection")
+               case err := <-errc:
+                       if ticks < cancelTick {
+                               t.Fatalf("dial error after %d ticks (%d before cancel sent): %v",
+                                       ticks, cancelTick-ticks, err)
+                       }
+                       if oe, ok := err.(*OpError); !ok || oe.Err != errCanceled {
+                               t.Fatalf("dial error = %v (%T); want OpError with Err == errCanceled", err, err)
+                       }
+                       return // success.
+               }
+       }
+}
index 6463b0df435376afebbb4db3209e9451793940cf..2639eab1c4f8c4061818bcddd95c90eb22040881 100644 (file)
@@ -68,7 +68,7 @@ func (fd *netFD) name() string {
        return fd.net + ":" + ls + "->" + rs
 }
 
-func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
+func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-chan struct{}) error {
        // Do not need to call fd.writeLock here,
        // because fd is not yet accessible to user,
        // so no concurrent operations are possible.
@@ -102,6 +102,19 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
                fd.setWriteDeadline(deadline)
                defer fd.setWriteDeadline(noDeadline)
        }
+       if cancel != nil {
+               done := make(chan bool)
+               defer close(done)
+               go func() {
+                       select {
+                       case <-cancel:
+                               // Force the runtime's poller to immediately give
+                               // up waiting for writability.
+                               fd.setWriteDeadline(aLongTimeAgo)
+                       case <-done:
+                       }
+               }()
+       }
        for {
                // Performing multiple connect system calls on a
                // non-blocking socket under Unix variants does not
@@ -112,6 +125,11 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
                // succeeded or failed. See issue 7474 for further
                // details.
                if err := fd.pd.WaitWrite(); err != nil {
+                       select {
+                       case <-cancel:
+                               return errCanceled
+                       default:
+                       }
                        return err
                }
                nerr, err := getsockoptIntFunc(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR)
index aa2b13c5def7efa4f0b20727f70aadfc0edb278a..de6a9cbf58cc84e2585d10937bab7b52cdcad266 100644 (file)
@@ -320,7 +320,7 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
        runtime.SetFinalizer(fd, (*netFD).Close)
 }
 
-func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
+func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-chan struct{}) error {
        // Do not need to call fd.writeLock here,
        // because fd is not yet accessible to user,
        // so no concurrent operations are possible.
@@ -351,14 +351,38 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
        // Call ConnectEx API.
        o := &fd.wop
        o.sa = ra
+       if cancel != nil {
+               done := make(chan struct{})
+               defer close(done)
+               go func() {
+                       select {
+                       case <-cancel:
+                               // TODO(bradfitz,brainman): cancel the dial operation
+                               // somehow. Brad doesn't know Windows but is going to
+                               // try this:
+                               if canCancelIO {
+                                       syscall.CancelIoEx(o.fd.sysfd, &o.o)
+                               } else {
+                                       wsrv.req <- ioSrvReq{o, nil}
+                                       <-o.errc
+                               }
+                       case <-done:
+                       }
+               }()
+       }
        _, err := wsrv.ExecIO(o, "ConnectEx", func(o *operation) error {
                return connectExFunc(o.fd.sysfd, o.sa, nil, 0, nil, &o.o)
        })
        if err != nil {
-               if _, ok := err.(syscall.Errno); ok {
-                       err = os.NewSyscallError("connectex", err)
+               select {
+               case <-cancel:
+                       return errCanceled
+               default:
+                       if _, ok := err.(syscall.Errno); ok {
+                               err = os.NewSyscallError("connectex", err)
+                       }
+                       return err
                }
-               return err
        }
        // Refresh socket properties.
        return os.NewSyscallError("setsockopt", syscall.Setsockopt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd))))
index 9417606ce9487464ee91a217b62b01b8db2596bc..93fee3e232e114d3236196e3ef570b86497474c2 100644 (file)
@@ -220,7 +220,7 @@ func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn,
        if raddr == nil {
                return nil, &OpError{Op: "dial", Net: netProto, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
        }
-       fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_RAW, proto, "dial")
+       fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_RAW, proto, "dial", noCancel)
        if err != nil {
                return nil, &OpError{Op: "dial", Net: netProto, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
        }
@@ -241,7 +241,7 @@ func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
        default:
                return nil, &OpError{Op: "listen", Net: netProto, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(netProto)}
        }
-       fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_RAW, proto, "listen")
+       fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_RAW, proto, "listen", noCancel)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: netProto, Source: nil, Addr: laddr.opAddr(), Err: err}
        }
index 4419aaf8a06d70483c69a33aa744bed58ec0e5d6..2bddd46a156e5bbc10f2f9c45a029e4d7b106e9d 100644 (file)
@@ -156,9 +156,9 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family
 
 // Internet sockets (TCP, UDP, IP)
 
-func internetSocket(net string, laddr, raddr sockaddr, deadline time.Time, sotype, proto int, mode string) (fd *netFD, err error) {
+func internetSocket(net string, laddr, raddr sockaddr, deadline time.Time, sotype, proto int, mode string, cancel <-chan struct{}) (fd *netFD, err error) {
        family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
-       return socket(net, family, sotype, proto, ipv6only, laddr, raddr, deadline)
+       return socket(net, family, sotype, proto, ipv6only, laddr, raddr, deadline, cancel)
 }
 
 func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) {
index 89212e6e26167054bff1fad35f83268695e627e0..d9d23fae8f6c3a77de201b48bea91b1591c8e058 100644 (file)
@@ -426,7 +426,16 @@ func (e *OpError) Error() string {
        return s
 }
 
-var noDeadline = time.Time{}
+var (
+       // aLongTimeAgo is a non-zero time, far in the past, used for
+       // immediate cancelation of dials.
+       aLongTimeAgo = time.Unix(233431200, 0)
+
+       // nonDeadline and noCancel are just zero values for
+       // readability with functions taking too many parameters.
+       noDeadline = time.Time{}
+       noCancel   = (chan struct{})(nil)
+)
 
 type timeout interface {
        Timeout() bool
index 4d2cfde3f1c8832bd7a30817f3e1bd4f4010f53c..46767215672c9c862c47af8ce44c0066f174275a 100644 (file)
@@ -34,7 +34,7 @@ type sockaddr interface {
 
 // socket returns a network file descriptor that is ready for
 // asynchronous I/O using the network poller.
-func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, deadline time.Time) (fd *netFD, err error) {
+func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, deadline time.Time, cancel <-chan struct{}) (fd *netFD, err error) {
        s, err := sysSocket(family, sotype, proto)
        if err != nil {
                return nil, err
@@ -86,7 +86,7 @@ func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr s
                        return fd, nil
                }
        }
-       if err := fd.dial(laddr, raddr, deadline); err != nil {
+       if err := fd.dial(laddr, raddr, deadline, cancel); err != nil {
                fd.Close()
                return nil, err
        }
@@ -117,7 +117,7 @@ func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
        return func(syscall.Sockaddr) Addr { return nil }
 }
 
-func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time) error {
+func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time, cancel <-chan struct{}) error {
        var err error
        var lsa syscall.Sockaddr
        if laddr != nil {
@@ -134,7 +134,7 @@ func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time) error {
                if rsa, err = raddr.sockaddr(fd.family); err != nil {
                        return err
                }
-               if err := fd.connect(lsa, rsa, deadline); err != nil {
+               if err := fd.connect(lsa, rsa, deadline, cancel); err != nil {
                        return err
                }
                fd.isConnected = true
index 9f23703abb4325ce4705fdd1c86c8c9dc460fb0e..afccbfe8a749bc85ec864b382529d8d0a95a65db 100644 (file)
@@ -107,13 +107,14 @@ func (c *TCPConn) SetNoDelay(noDelay bool) error {
 // which must be "tcp", "tcp4", or "tcp6".  If laddr is not nil, it is
 // used as the local address for the connection.
 func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
-       return dialTCP(net, laddr, raddr, noDeadline)
+       return dialTCP(net, laddr, raddr, noDeadline, noCancel)
 }
 
-func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, error) {
+func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
        if !deadline.IsZero() {
                panic("net.dialTCP: deadline not implemented on Plan 9")
        }
+       // TODO(bradfitz,0intro): also use the cancel channel.
        switch net {
        case "tcp", "tcp4", "tcp6":
        default:
index 7e49b769e1c98c841e3bfc4798db639d6b5bb9cf..0e12d54300265d18ca7404928c22bac2d8bcbd43 100644 (file)
@@ -164,11 +164,11 @@ func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
        if raddr == nil {
                return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
        }
-       return dialTCP(net, laddr, raddr, noDeadline)
+       return dialTCP(net, laddr, raddr, noDeadline, noCancel)
 }
 
-func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, error) {
-       fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial")
+func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
+       fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", cancel)
 
        // TCP has a rarely used mechanism called a 'simultaneous connection' in
        // which Dial("tcp", addr1, addr2) run on the machine at addr1 can
@@ -198,7 +198,7 @@ func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, e
                if err == nil {
                        fd.Close()
                }
-               fd, err = internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial")
+               fd, err = internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", cancel)
        }
 
        if err != nil {
@@ -326,7 +326,7 @@ func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) {
        if laddr == nil {
                laddr = &TCPAddr{}
        }
-       fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_STREAM, 0, "listen")
+       fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_STREAM, 0, "listen", noCancel)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr, Err: err}
        }
index 61868c4b0cfc2a39c1daa75ad7b1fcb9cb9370df..932c6ce713fe462790840752457281addba71327 100644 (file)
@@ -189,7 +189,7 @@ func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
 }
 
 func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) {
-       fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_DGRAM, 0, "dial")
+       fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_DGRAM, 0, "dial", noCancel)
        if err != nil {
                return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
        }
@@ -212,7 +212,7 @@ func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) {
        if laddr == nil {
                laddr = &UDPAddr{}
        }
-       fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen")
+       fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", noCancel)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr, Err: err}
        }
@@ -239,7 +239,7 @@ func ListenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPCon
        if gaddr == nil || gaddr.IP == nil {
                return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: errMissingAddress}
        }
-       fd, err := internetSocket(network, gaddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen")
+       fd, err := internetSocket(network, gaddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", noCancel)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr, Err: err}
        }
index fc44c1a458e43b44dc03f3282662973cc3499046..fb2397e26f2b5b038355689bd6ee1e859cc80040 100644 (file)
@@ -42,7 +42,7 @@ func unixSocket(net string, laddr, raddr sockaddr, mode string, deadline time.Ti
                return nil, errors.New("unknown mode: " + mode)
        }
 
-       fd, err := socket(net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, deadline)
+       fd, err := socket(net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, deadline, noCancel)
        if err != nil {
                return nil, err
        }