]> Cypherpunks repositories - gostls13.git/commitdiff
net: DialTimeout
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 20 Dec 2011 21:17:39 +0000 (13:17 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 20 Dec 2011 21:17:39 +0000 (13:17 -0800)
Fixes #240

R=adg, dsymonds, rsc, r, mikioh.mikioh
CC=golang-dev
https://golang.org/cl/5491062

src/pkg/net/dial.go
src/pkg/net/dial_test.go [new file with mode: 0644]

index 43866dcb51d18ec6c7cbeb1502e3fe6fcde49c54..00acb8477d22dee3b555402b1aa7f50a4f675735 100644 (file)
@@ -4,6 +4,10 @@
 
 package net
 
+import (
+       "time"
+)
+
 func resolveNetAddr(op, net, addr string) (a Addr, err error) {
        if addr == "" {
                return nil, &OpError{op, net, nil, errMissingAddress}
@@ -42,11 +46,15 @@ func resolveNetAddr(op, net, addr string) (a Addr, err error) {
 //     Dial("tcp", "google.com:80")
 //     Dial("tcp", "[de:ad:be:ef::ca:fe]:80")
 //
-func Dial(net, addr string) (c Conn, err error) {
+func Dial(net, addr string) (Conn, error) {
        addri, err := resolveNetAddr("dial", net, addr)
        if err != nil {
                return nil, err
        }
+       return dialAddr(net, addr, addri)
+}
+
+func dialAddr(net, addr string, addri Addr) (c Conn, err error) {
        switch ra := addri.(type) {
        case *TCPAddr:
                c, err = DialTCP(net, nil, ra)
@@ -65,6 +73,62 @@ func Dial(net, addr string) (c Conn, err error) {
        return
 }
 
+// DialTimeout acts like Dial but takes a timeout.
+// The timeout includes name resolution, if required.
+func DialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
+       // TODO(bradfitz): the timeout should be pushed down into the
+       // net package's event loop, so on timeout to dead hosts we
+       // don't have a goroutine sticking around for the default of
+       // ~3 minutes.
+       t := time.NewTimer(timeout)
+       defer t.Stop()
+       type pair struct {
+               Conn
+               error
+       }
+       ch := make(chan pair, 1)
+       resolvedAddr := make(chan Addr, 1)
+       go func() {
+               addri, err := resolveNetAddr("dial", net, addr)
+               if err != nil {
+                       ch <- pair{nil, err}
+                       return
+               }
+               resolvedAddr <- addri // in case we need it for OpError
+               c, err := dialAddr(net, addr, addri)
+               ch <- pair{c, err}
+       }()
+       select {
+       case <-t.C:
+               // Try to use the real Addr in our OpError, if we resolved it
+               // before the timeout. Otherwise we just use stringAddr.
+               var addri Addr
+               select {
+               case a := <-resolvedAddr:
+                       addri = a
+               default:
+                       addri = &stringAddr{net, addr}
+               }
+               err := &OpError{
+                       Op:   "dial",
+                       Net:  net,
+                       Addr: addri,
+                       Err:  &timeoutError{},
+               }
+               return nil, err
+       case p := <-ch:
+               return p.Conn, p.error
+       }
+       panic("unreachable")
+}
+
+type stringAddr struct {
+       net, addr string
+}
+
+func (a stringAddr) Network() string { return a.net }
+func (a stringAddr) String() string  { return a.addr }
+
 // Listen announces on the local network address laddr.
 // The network string net must be a stream-oriented
 // network: "tcp", "tcp4", "tcp6", or "unix", or "unixpacket".
diff --git a/src/pkg/net/dial_test.go b/src/pkg/net/dial_test.go
new file mode 100644 (file)
index 0000000..16b7263
--- /dev/null
@@ -0,0 +1,88 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+       "runtime"
+       "testing"
+       "time"
+)
+
+func newLocalListener(t *testing.T) Listener {
+       ln, err := Listen("tcp", "127.0.0.1:0")
+       if err != nil {
+               ln, err = Listen("tcp6", "[::1]:0")
+       }
+       if err != nil {
+               t.Fatal(err)
+       }
+       return ln
+}
+
+func TestDialTimeout(t *testing.T) {
+       ln := newLocalListener(t)
+       defer ln.Close()
+
+       errc := make(chan error)
+
+       const SOMAXCONN = 0x80 // copied from syscall, but not always available
+       const numConns = SOMAXCONN + 10
+
+       // TODO(bradfitz): It's hard to test this in a portable
+       // way. This is unforunate, but works for now.
+       switch runtime.GOOS {
+       case "linux":
+               // The kernel will start accepting TCP connections before userspace
+               // gets a chance to not accept them, so fire off a bunch to fill up
+               // the kernel's backlog.  Then we test we get a failure after that.
+               for i := 0; i < numConns; i++ {
+                       go func() {
+                               _, err := DialTimeout("tcp", ln.Addr().String(), 200*time.Millisecond)
+                               errc <- err
+                       }()
+               }
+       case "darwin":
+               // At least OS X 10.7 seems to accept any number of
+               // connections, ignoring listen's backlog, so resort
+               // to connecting to a hopefully-dead 127/8 address.
+               go func() {
+                       _, err := DialTimeout("tcp", "127.0.71.111:80", 200*time.Millisecond)
+                       errc <- err
+               }()
+       default:
+               // TODO(bradfitz): this probably doesn't work on
+               // Windows? SOMAXCONN is huge there.  I'm not sure how
+               // listen works there.
+               // OpenBSD may have a reject route to 10/8.
+               // FreeBSD likely works, but is untested.
+               t.Logf("skipping test on %q; untested.", runtime.GOOS)
+               return
+       }
+
+       connected := 0
+       for {
+               select {
+               case <-time.After(15 * time.Second):
+                       t.Fatal("too slow")
+               case err := <-errc:
+                       if err == nil {
+                               connected++
+                               if connected == numConns {
+                                       t.Fatal("all connections connected; expected some to time out")
+                               }
+                       } else {
+                               terr, ok := err.(timeout)
+                               if !ok {
+                                       t.Fatalf("got error %q; want error with timeout interface", err)
+                               }
+                               if !terr.Timeout() {
+                                       t.Fatalf("got error %q; not a timeout", err)
+                               }
+                               // Pass. We saw a timeout error.
+                               return
+                       }
+               }
+       }
+}