]> Cypherpunks repositories - gostls13.git/commitdiff
net: add sequential and RFC 6555-compliant TCP dialing.
authorPaul Marks <pmarks@google.com>
Fri, 10 Apr 2015 21:15:54 +0000 (14:15 -0700)
committerPaul Marks <pmarks@google.com>
Tue, 16 Jun 2015 02:38:21 +0000 (02:38 +0000)
dialSerial connects to a list of addresses in sequence.  If a
timeout is specified, then each address gets an equal fraction of the
remaining time, with a magic constant (2 seconds) to prevent
"dial a million addresses" from allotting zero time to each.

Normally, net.Dial passes the DNS stub resolver's output to dialSerial.
If an error occurs (like destination/port unreachable), it quickly skips
to the next address, but a blackhole in the network will cause the
connection to hang until the timeout elapses.  This is how UNIXy clients
traditionally behave, and is usually sufficient for non-broken networks.

The DualStack flag enables dialParallel, which implements Happy Eyeballs
by racing two dialSerial goroutines, giving the preferred family a
head start (300ms by default).  This allows clients to avoid long
timeouts when the network blackholes IPv4 xor IPv6.

Fixes #8453
Fixes #8455
Fixes #8847

Change-Id: Ie415809c9226a1f7342b0217dcdd8f224ae19058
Reviewed-on: https://go-review.googlesource.com/8768
Reviewed-by: Mikio Hara <mikioh.mikioh@gmail.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/dial.go
src/net/dial_test.go
src/net/hook.go
src/net/net.go

index 4f0c6cb0eac4f6327361aed23940b0876255b884..4b430de4bde711a7fda85e30f554606c0ec6ac50 100644 (file)
@@ -21,6 +21,9 @@ type Dialer struct {
        //
        // The default is no timeout.
        //
+       // When dialing a name with multiple IP addresses, the timeout
+       // may be divided between them.
+       //
        // With or without a timeout, the operating system may impose
        // its own earlier timeout. For instance, TCP timeouts are
        // often around 3 minutes.
@@ -38,13 +41,17 @@ type Dialer struct {
        // If nil, a local address is automatically chosen.
        LocalAddr Addr
 
-       // DualStack allows a single dial to attempt to establish
-       // multiple IPv4 and IPv6 connections and to return the first
-       // established connection when the network is "tcp" and the
-       // destination is a host name that has multiple address family
-       // DNS records.
+       // DualStack enables RFC 6555-compliant "Happy Eyeballs" dialing
+       // when the network is "tcp" and the destination is a host name
+       // with both IPv4 and IPv6 addresses. This allows a client to
+       // tolerate networks where one address family is silently broken.
        DualStack bool
 
+       // FallbackDelay specifies the length of time to wait before
+       // spawning a fallback connection, when DualStack is enabled.
+       // If zero, a default delay of 300ms is used.
+       FallbackDelay time.Duration
+
        // KeepAlive specifies the keep-alive period for an active
        // network connection.
        // If zero, keep-alives are not enabled. Network protocols
@@ -54,11 +61,11 @@ type Dialer struct {
 
 // Return either now+Timeout or Deadline, whichever comes first.
 // Or zero, if neither is set.
-func (d *Dialer) deadline() time.Time {
+func (d *Dialer) deadline(now time.Time) time.Time {
        if d.Timeout == 0 {
                return d.Deadline
        }
-       timeoutDeadline := time.Now().Add(d.Timeout)
+       timeoutDeadline := now.Add(d.Timeout)
        if d.Deadline.IsZero() || timeoutDeadline.Before(d.Deadline) {
                return timeoutDeadline
        } else {
@@ -66,6 +73,39 @@ func (d *Dialer) deadline() time.Time {
        }
 }
 
+// partialDeadline returns the deadline to use for a single address,
+// when multiple addresses are pending.
+func (d *Dialer) partialDeadline(now time.Time, addrsRemaining int) (time.Time, error) {
+       deadline := d.deadline(now)
+       if deadline.IsZero() {
+               return deadline, nil
+       }
+       timeRemaining := deadline.Sub(now)
+       if timeRemaining <= 0 {
+               return time.Time{}, errTimeout
+       }
+       // Tentatively allocate equal time to each remaining address.
+       timeout := timeRemaining / time.Duration(addrsRemaining)
+       // If the time per address is too short, steal from the end of the list.
+       const saneMinimum = 2 * time.Second
+       if timeout < saneMinimum {
+               if timeRemaining < saneMinimum {
+                       timeout = timeRemaining
+               } else {
+                       timeout = saneMinimum
+               }
+       }
+       return now.Add(timeout), nil
+}
+
+func (d *Dialer) fallbackDelay() time.Duration {
+       if d.FallbackDelay > 0 {
+               return d.FallbackDelay
+       } else {
+               return 300 * time.Millisecond
+       }
+}
+
 func parseNetwork(net string) (afnet string, proto int, err error) {
        i := last(net, ':')
        if i < 0 { // no colon
@@ -154,30 +194,44 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
        return d.Dial(network, address)
 }
 
+// dialContext holds common state for all dial operations.
+type dialContext struct {
+       Dialer
+       network, address string
+}
+
 // Dial connects to the address on the named network.
 //
 // See func Dial for a description of the network and address
 // parameters.
 func (d *Dialer) Dial(network, address string) (Conn, error) {
-       addrs, err := resolveAddrList("dial", network, address, d.deadline())
+       addrs, err := resolveAddrList("dial", network, address, d.deadline(time.Now()))
        if err != nil {
                return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
        }
-       var dialer func(deadline time.Time) (Conn, error)
+
+       ctx := &dialContext{
+               Dialer:  *d,
+               network: network,
+               address: address,
+       }
+
+       var primaries, fallbacks addrList
        if d.DualStack && network == "tcp" {
-               primaries, fallbacks := addrs.partition(isIPv4)
-               if len(fallbacks) > 0 {
-                       dialer = func(deadline time.Time) (Conn, error) {
-                               return dialMulti(network, address, d.LocalAddr, addrList{primaries[0], fallbacks[0]}, deadline)
-                       }
-               }
+               primaries, fallbacks = addrs.partition(isIPv4)
+       } else {
+               primaries = addrs
        }
-       if dialer == nil {
-               dialer = func(deadline time.Time) (Conn, error) {
-                       return dialSingle(network, address, d.LocalAddr, addrs.first(isIPv4), deadline)
-               }
+
+       var c Conn
+       if len(fallbacks) == 0 {
+               // dialParallel can accept an empty fallbacks list,
+               // but this shortcut avoids the goroutine/channel overhead.
+               c, err = dialSerial(ctx, primaries, nil)
+       } else {
+               c, err = dialParallel(ctx, primaries, fallbacks)
        }
-       c, err := dial(network, addrs.first(isIPv4), dialer, d.deadline())
+
        if d.KeepAlive > 0 && err == nil {
                if tc, ok := c.(*TCPConn); ok {
                        setKeepAlive(tc.fd, true)
@@ -188,70 +242,135 @@ func (d *Dialer) Dial(network, address string) (Conn, error) {
        return c, err
 }
 
-// dialMulti attempts to establish connections to each destination of
-// the list of addresses. It will return the first established
-// connection and close the other connections. Otherwise it returns
-// error on the last attempt.
-func dialMulti(net, addr string, la Addr, ras addrList, deadline time.Time) (Conn, error) {
-       type racer struct {
-               Conn
-               error
+// dialParallel races two copies of dialSerial, giving the first a
+// head start. It returns the first established connection and
+// closes the others. Otherwise it returns an error from the first
+// primary address.
+func dialParallel(ctx *dialContext, primaries, fallbacks addrList) (Conn, error) {
+       results := make(chan dialResult) // unbuffered, so dialSerialAsync can detect race loss & cleanup
+       cancel := make(chan struct{})
+       defer close(cancel)
+
+       // Spawn the primary racer.
+       go dialSerialAsync(ctx, primaries, nil, cancel, results)
+
+       // Spawn the fallback racer.
+       fallbackTimer := time.NewTimer(ctx.fallbackDelay())
+       go dialSerialAsync(ctx, fallbacks, fallbackTimer, cancel, results)
+
+       var primaryErr error
+       for nracers := 2; nracers > 0; nracers-- {
+               res := <-results
+               // If we're still waiting for a connection, then hasten the delay.
+               // Otherwise, disable the Timer and let cancel take over.
+               if fallbackTimer.Stop() && res.error != nil {
+                       fallbackTimer.Reset(0)
+               }
+               if res.error == nil {
+                       return res.Conn, nil
+               }
+               if res.primary {
+                       primaryErr = res.error
+               }
+       }
+       return nil, primaryErr
+}
+
+type dialResult struct {
+       Conn
+       error
+       primary bool
+}
+
+// dialSerialAsync runs dialSerial after some delay, and returns the
+// resulting connection through a channel. When racing two connections,
+// the primary goroutine uses a nil timer to omit the delay.
+func dialSerialAsync(ctx *dialContext, ras addrList, timer *time.Timer, cancel <-chan struct{}, results chan<- dialResult) {
+       if timer != nil {
+               // We're in the fallback goroutine; sleep before connecting.
+               select {
+               case <-timer.C:
+               case <-cancel:
+                       return
+               }
        }
-       // Sig controls the flow of dial results on lane. It passes a
-       // token to the next racer and also indicates the end of flow
-       // by using closed channel.
-       sig := make(chan bool, 1)
-       lane := make(chan racer, 1)
-       for _, ra := range ras {
-               go func(ra Addr) {
-                       c, err := dialSingle(net, addr, la, ra, deadline)
-                       if _, ok := <-sig; ok {
-                               lane <- racer{c, err}
-                       } else if err == nil {
-                               // We have to return the resources
-                               // that belong to the other
-                               // connections here for avoiding
-                               // unnecessary resource starvation.
-                               c.Close()
-                       }
-               }(ra)
+       c, err := dialSerial(ctx, ras, cancel)
+       select {
+       case results <- dialResult{c, err, timer == nil}:
+               // We won the race.
+       case <-cancel:
+               // The other goroutine won the race.
+               if c != nil {
+                       c.Close()
+               }
        }
-       defer close(sig)
-       lastErr := errTimeout
-       nracers := len(ras)
-       for nracers > 0 {
-               sig <- true
-               racer := <-lane
-               if racer.error == nil {
-                       return racer.Conn, nil
+}
+
+// dialSerial connects to a list of addresses in sequence, returning
+// either the first successful connection, or the first error.
+func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, error) {
+       var firstErr error // The error from the first address is most relevant.
+
+       for i, ra := range ras {
+               select {
+               case <-cancel:
+                       return nil, &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: errCanceled}
+               default:
+               }
+
+               partialDeadline, err := ctx.partialDeadline(time.Now(), len(ras)-i)
+               if err != nil {
+                       // Ran out of time.
+                       if firstErr == nil {
+                               firstErr = &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: err}
+                       }
+                       break
                }
-               lastErr = racer.error
-               nracers--
+
+               // dialTCP does not support cancelation (see golang.org/issue/11225),
+               // so if cancel fires, we'll continue trying to connect until the next
+               // timeout, or return a spurious connection for the caller to close.
+               dialer := func(d time.Time) (Conn, error) {
+                       return dialSingle(ctx, ra, d)
+               }
+               c, err := dial(ctx.network, ra, dialer, partialDeadline)
+               if err == nil {
+                       return c, nil
+               }
+               if firstErr == nil {
+                       firstErr = err
+               }
+       }
+
+       if firstErr == nil {
+               firstErr = &OpError{Op: "dial", Net: ctx.network, Source: nil, Addr: nil, Err: errMissingAddress}
        }
-       return nil, lastErr
+       return nil, firstErr
 }
 
 // dialSingle attempts to establish and returns a single connection to
-// the destination address.
-func dialSingle(net, addr string, la, ra Addr, deadline time.Time) (c Conn, err error) {
+// the destination address. This must be called through the OS-specific
+// dial function, because some OSes don't implement the deadline feature.
+func dialSingle(ctx *dialContext, ra Addr, deadline time.Time) (c Conn, err error) {
+       la := ctx.LocalAddr
        if la != nil && la.Network() != ra.Network() {
-               return nil, &OpError{Op: "dial", Net: net, Source: la, Addr: ra, Err: errors.New("mismatched local address type " + la.Network())}
+               return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: errors.New("mismatched local address type " + la.Network())}
        }
        switch ra := ra.(type) {
        case *TCPAddr:
                la, _ := la.(*TCPAddr)
-               c, err = dialTCP(net, la, ra, deadline)
+               c, err = testHookDialTCP(ctx.network, la, ra, deadline)
        case *UDPAddr:
                la, _ := la.(*UDPAddr)
-               c, err = dialUDP(net, la, ra, deadline)
+               c, err = dialUDP(ctx.network, la, ra, deadline)
        case *IPAddr:
                la, _ := la.(*IPAddr)
-               c, err = dialIP(net, la, ra, deadline)
+               c, err = dialIP(ctx.network, la, ra, deadline)
        case *UnixAddr:
                la, _ := la.(*UnixAddr)
-               c, err = dialUnix(net, la, ra, deadline)
+               c, err = dialUnix(ctx.network, la, ra, deadline)
        default:
-               return nil, &OpError{Op: "dial", Net: net, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: addr}}
+               return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: ctx.address}}
        }
        if err != nil {
                return nil, err // c is non-nil interface containing nil pointer
index f5141bcd5eec22dfed3e89c936b1a6d994da6d77..9848f30e3c6912a9e13cf7651b5892af63fe499d 100644 (file)
@@ -5,6 +5,7 @@
 package net
 
 import (
+       "io"
        "net/internal/socktest"
        "runtime"
        "sync"
@@ -207,6 +208,360 @@ func TestDialerDualStackFDLeak(t *testing.T) {
        }
 }
 
+// Define a pair of blackholed (IPv4, IPv6) addresses, for which dialTCP is
+// expected to hang until the timeout elapses. These addresses are reserved
+// for benchmarking by RFC 6890.
+const (
+       slowDst4    = "192.18.0.254"
+       slowDst6    = "2001:2::254"
+       slowTimeout = 1 * time.Second
+)
+
+// In some environments, the slow IPs may be explicitly unreachable, and fail
+// more quickly than expected. This test hook prevents dialTCP from returning
+// before the deadline.
+func slowDialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, error) {
+       c, err := dialTCP(net, laddr, raddr, deadline)
+       if ParseIP(slowDst4).Equal(raddr.IP) || ParseIP(slowDst6).Equal(raddr.IP) {
+               time.Sleep(deadline.Sub(time.Now()))
+       }
+       return c, err
+}
+
+func dialClosedPort() time.Duration {
+       l, err := Listen("tcp", "127.0.0.1:0")
+       if err != nil {
+               return 999 * time.Hour
+       }
+       addr := l.Addr().String()
+       l.Close()
+       // On OpenBSD, interference from TestSelfConnect is mysteriously
+       // causing the first attempt to hang for a few seconds, so we throw
+       // away the first result and keep the second.
+       for i := 1; ; i++ {
+               startTime := time.Now()
+               c, err := Dial("tcp", addr)
+               if err == nil {
+                       c.Close()
+               }
+               elapsed := time.Now().Sub(startTime)
+               if i == 2 {
+                       return elapsed
+               }
+       }
+}
+
+func TestDialParallel(t *testing.T) {
+       if testing.Short() || !*testExternal {
+               t.Skip("avoid external network")
+       }
+       if !supportsIPv4 || !supportsIPv6 {
+               t.Skip("both IPv4 and IPv6 are required")
+       }
+
+       // Determine the time required to dial a closed port.
+       // On Windows, this takes roughly 1 second, but other platforms
+       // are expected to be instantaneous.
+       closedPortDelay := dialClosedPort()
+       var expectClosedPortDelay time.Duration
+       if runtime.GOOS == "windows" {
+               expectClosedPortDelay = 1095 * time.Millisecond
+       } else {
+               expectClosedPortDelay = 95 * time.Millisecond
+       }
+       if closedPortDelay > expectClosedPortDelay {
+               t.Errorf("got %v; want <= %v", closedPortDelay, expectClosedPortDelay)
+       }
+
+       const instant time.Duration = 0
+       const fallbackDelay = 200 * time.Millisecond
+
+       // Some cases will run quickly when "connection refused" is fast,
+       // or trigger the fallbackDelay on Windows.  This value holds the
+       // lesser of the two delays.
+       var closedPortOrFallbackDelay time.Duration
+       if closedPortDelay < fallbackDelay {
+               closedPortOrFallbackDelay = closedPortDelay
+       } else {
+               closedPortOrFallbackDelay = fallbackDelay
+       }
+
+       origTestHookDialTCP := testHookDialTCP
+       defer func() { testHookDialTCP = origTestHookDialTCP }()
+       testHookDialTCP = slowDialTCP
+
+       nCopies := func(s string, n int) []string {
+               out := make([]string, n)
+               for i := 0; i < n; i++ {
+                       out[i] = s
+               }
+               return out
+       }
+
+       var testCases = []struct {
+               primaries       []string
+               fallbacks       []string
+               teardownNetwork string
+               expectOk        bool
+               expectElapsed   time.Duration
+       }{
+               // These should just work on the first try.
+               {[]string{"127.0.0.1"}, []string{}, "", true, instant},
+               {[]string{"::1"}, []string{}, "", true, instant},
+               {[]string{"127.0.0.1", "::1"}, []string{slowDst6}, "tcp6", true, instant},
+               {[]string{"::1", "127.0.0.1"}, []string{slowDst4}, "tcp4", true, instant},
+               // Primary is slow; fallback should kick in.
+               {[]string{slowDst4}, []string{"::1"}, "", true, fallbackDelay},
+               // Skip a "connection refused" in the primary thread.
+               {[]string{"127.0.0.1", "::1"}, []string{}, "tcp4", true, closedPortDelay},
+               {[]string{"::1", "127.0.0.1"}, []string{}, "tcp6", true, closedPortDelay},
+               // Skip a "connection refused" in the fallback thread.
+               {[]string{slowDst4, slowDst6}, []string{"::1", "127.0.0.1"}, "tcp6", true, fallbackDelay + closedPortDelay},
+               // Primary refused, fallback without delay.
+               {[]string{"127.0.0.1"}, []string{"::1"}, "tcp4", true, closedPortOrFallbackDelay},
+               {[]string{"::1"}, []string{"127.0.0.1"}, "tcp6", true, closedPortOrFallbackDelay},
+               // Everything is refused.
+               {[]string{"127.0.0.1"}, []string{}, "tcp4", false, closedPortDelay},
+               // Nothing to do; fail instantly.
+               {[]string{}, []string{}, "", false, instant},
+               // Connecting to tons of addresses should not trip the deadline.
+               {nCopies("::1", 1000), []string{}, "", true, instant},
+       }
+
+       handler := func(dss *dualStackServer, ln Listener) {
+               for {
+                       c, err := ln.Accept()
+                       if err != nil {
+                               return
+                       }
+                       c.Close()
+               }
+       }
+
+       // Convert a list of IP strings into TCPAddrs.
+       makeAddrs := func(ips []string, port string) addrList {
+               var out addrList
+               for _, ip := range ips {
+                       addr, err := ResolveTCPAddr("tcp", JoinHostPort(ip, port))
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       out = append(out, addr)
+               }
+               return out
+       }
+
+       for i, tt := range testCases {
+               dss, err := newDualStackServer([]streamListener{
+                       {network: "tcp4", address: "127.0.0.1"},
+                       {network: "tcp6", address: "::1"},
+               })
+               if err != nil {
+                       t.Fatal(err)
+               }
+               defer dss.teardown()
+               if err := dss.buildup(handler); err != nil {
+                       t.Fatal(err)
+               }
+               if tt.teardownNetwork != "" {
+                       // Destroy one of the listening sockets, creating an unreachable port.
+                       dss.teardownNetwork(tt.teardownNetwork)
+               }
+
+               primaries := makeAddrs(tt.primaries, dss.port)
+               fallbacks := makeAddrs(tt.fallbacks, dss.port)
+               ctx := &dialContext{
+                       Dialer: Dialer{
+                               FallbackDelay: fallbackDelay,
+                               Timeout:       slowTimeout,
+                       },
+                       network: "tcp",
+                       address: "?",
+               }
+               startTime := time.Now()
+               c, err := dialParallel(ctx, primaries, fallbacks)
+               elapsed := time.Now().Sub(startTime)
+
+               if c != nil {
+                       c.Close()
+               }
+
+               if tt.expectOk && err != nil {
+                       t.Errorf("#%d: got %v; want nil", i, err)
+               } else if !tt.expectOk && err == nil {
+                       t.Errorf("#%d: got nil; want non-nil", i)
+               }
+
+               expectElapsedMin := tt.expectElapsed - 95*time.Millisecond
+               expectElapsedMax := tt.expectElapsed + 95*time.Millisecond
+               if !(elapsed >= expectElapsedMin) {
+                       t.Errorf("#%d: got %v; want >= %v", i, elapsed, expectElapsedMin)
+               } else if !(elapsed <= expectElapsedMax) {
+                       t.Errorf("#%d: got %v; want <= %v", i, elapsed, expectElapsedMax)
+               }
+       }
+       // Wait for any slowDst4/slowDst6 connections to timeout.
+       time.Sleep(slowTimeout * 3 / 2)
+}
+
+func lookupSlowFast(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) {
+       switch host {
+       case "slow6loopback4":
+               // Returns a slow IPv6 address, and a local IPv4 address.
+               return []IPAddr{
+                       {IP: ParseIP(slowDst6)},
+                       {IP: ParseIP("127.0.0.1")},
+               }, nil
+       default:
+               return fn(host)
+       }
+}
+
+func TestDialerFallbackDelay(t *testing.T) {
+       if testing.Short() || !*testExternal {
+               t.Skip("avoid external network")
+       }
+       if !supportsIPv4 || !supportsIPv6 {
+               t.Skip("both IPv4 and IPv6 are required")
+       }
+
+       origTestHookLookupIP := testHookLookupIP
+       defer func() { testHookLookupIP = origTestHookLookupIP }()
+       testHookLookupIP = lookupSlowFast
+
+       origTestHookDialTCP := testHookDialTCP
+       defer func() { testHookDialTCP = origTestHookDialTCP }()
+       testHookDialTCP = slowDialTCP
+
+       var testCases = []struct {
+               dualstack     bool
+               delay         time.Duration
+               expectElapsed time.Duration
+       }{
+               // Use a very brief delay, which should fallback immediately.
+               {true, 1 * time.Nanosecond, 0},
+               // Use a 200ms explicit timeout.
+               {true, 200 * time.Millisecond, 200 * time.Millisecond},
+               // The default is 300ms.
+               {true, 0, 300 * time.Millisecond},
+               // This case is last, in order to wait for hanging slowDst6 connections.
+               {false, 0, slowTimeout},
+       }
+
+       handler := func(dss *dualStackServer, ln Listener) {
+               for {
+                       c, err := ln.Accept()
+                       if err != nil {
+                               return
+                       }
+                       c.Close()
+               }
+       }
+       dss, err := newDualStackServer([]streamListener{
+               {network: "tcp", address: "127.0.0.1"},
+       })
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer dss.teardown()
+       if err := dss.buildup(handler); err != nil {
+               t.Fatal(err)
+       }
+
+       for i, tt := range testCases {
+               d := &Dialer{DualStack: tt.dualstack, FallbackDelay: tt.delay, Timeout: slowTimeout}
+
+               startTime := time.Now()
+               c, err := d.Dial("tcp", JoinHostPort("slow6loopback4", dss.port))
+               elapsed := time.Now().Sub(startTime)
+               if err == nil {
+                       c.Close()
+               } else if tt.dualstack {
+                       t.Error(err)
+               }
+               expectMin := tt.expectElapsed - 1*time.Millisecond
+               expectMax := tt.expectElapsed + 95*time.Millisecond
+               if !(elapsed >= expectMin) {
+                       t.Errorf("#%d: got %v; want >= %v", i, elapsed, expectMin)
+               }
+               if !(elapsed <= expectMax) {
+                       t.Errorf("#%d: got %v; want <= %v", i, elapsed, expectMax)
+               }
+       }
+}
+
+func TestDialSerialAsyncSpuriousConnection(t *testing.T) {
+       ln, err := newLocalListener("tcp")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer ln.Close()
+
+       ctx := &dialContext{
+               network: "tcp",
+               address: "?",
+       }
+
+       results := make(chan dialResult)
+       cancel := make(chan struct{})
+
+       // Spawn a connection in the background.
+       go dialSerialAsync(ctx, addrList{ln.Addr()}, nil, cancel, results)
+
+       // Receive it at the server.
+       c, err := ln.Accept()
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer c.Close()
+
+       // Tell dialSerialAsync that someone else won the race.
+       close(cancel)
+
+       // The connection should close itself, without sending data.
+       c.SetReadDeadline(time.Now().Add(1 * time.Second))
+       var b [1]byte
+       if _, err := c.Read(b[:]); err != io.EOF {
+               t.Errorf("got %v; want %v", err, io.EOF)
+       }
+}
+
+func TestDialerPartialDeadline(t *testing.T) {
+       now := time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC)
+       var testCases = []struct {
+               now            time.Time
+               deadline       time.Time
+               addrs          int
+               expectDeadline time.Time
+               expectErr      error
+       }{
+               // Regular division.
+               {now, now.Add(12 * time.Second), 1, now.Add(12 * time.Second), nil},
+               {now, now.Add(12 * time.Second), 2, now.Add(6 * time.Second), nil},
+               {now, now.Add(12 * time.Second), 3, now.Add(4 * time.Second), nil},
+               // Bump against the 2-second sane minimum.
+               {now, now.Add(12 * time.Second), 999, now.Add(2 * time.Second), nil},
+               // Total available is now below the sane minimum.
+               {now, now.Add(1900 * time.Millisecond), 999, now.Add(1900 * time.Millisecond), nil},
+               // Null deadline.
+               {now, noDeadline, 1, noDeadline, nil},
+               // Step the clock forward and cross the deadline.
+               {now.Add(-1 * time.Millisecond), now, 1, now, nil},
+               {now.Add(0 * time.Millisecond), now, 1, noDeadline, errTimeout},
+               {now.Add(1 * time.Millisecond), now, 1, noDeadline, errTimeout},
+       }
+       for i, tt := range testCases {
+               d := Dialer{Deadline: tt.deadline}
+               deadline, err := d.partialDeadline(tt.now, tt.addrs)
+               if err != tt.expectErr {
+                       t.Errorf("#%d: got %v; want %v", i, err, tt.expectErr)
+               }
+               if deadline != tt.expectDeadline {
+                       t.Errorf("#%d: got %v; want %v", i, deadline, tt.expectDeadline)
+               }
+       }
+}
+
 func TestDialerLocalAddr(t *testing.T) {
        ch := make(chan error, 1)
        handler := func(ls *localServer, ln Listener) {
@@ -262,33 +617,36 @@ func TestDialerDualStack(t *testing.T) {
                        c.Close()
                }
        }
-       dss, err := newDualStackServer([]streamListener{
-               {network: "tcp4", address: "127.0.0.1"},
-               {network: "tcp6", address: "::1"},
-       })
-       if err != nil {
-               t.Fatal(err)
-       }
-       defer dss.teardown()
-       if err := dss.buildup(handler); err != nil {
-               t.Fatal(err)
-       }
 
        const T = 100 * time.Millisecond
-       d := &Dialer{DualStack: true, Timeout: T}
-       for range dss.lns {
-               c, err := d.Dial("tcp", JoinHostPort("localhost", dss.port))
+       for _, dualstack := range []bool{false, true} {
+               dss, err := newDualStackServer([]streamListener{
+                       {network: "tcp4", address: "127.0.0.1"},
+                       {network: "tcp6", address: "::1"},
+               })
                if err != nil {
-                       t.Error(err)
-                       continue
+                       t.Fatal(err)
                }
-               switch addr := c.LocalAddr().(*TCPAddr); {
-               case addr.IP.To4() != nil:
-                       dss.teardownNetwork("tcp4")
-               case addr.IP.To16() != nil && addr.IP.To4() == nil:
-                       dss.teardownNetwork("tcp6")
+               defer dss.teardown()
+               if err := dss.buildup(handler); err != nil {
+                       t.Fatal(err)
+               }
+
+               d := &Dialer{DualStack: dualstack, Timeout: T}
+               for range dss.lns {
+                       c, err := d.Dial("tcp", JoinHostPort("localhost", dss.port))
+                       if err != nil {
+                               t.Error(err)
+                               continue
+                       }
+                       switch addr := c.LocalAddr().(*TCPAddr); {
+                       case addr.IP.To4() != nil:
+                               dss.teardownNetwork("tcp4")
+                       case addr.IP.To16() != nil && addr.IP.To4() == nil:
+                               dss.teardownNetwork("tcp6")
+                       }
+                       c.Close()
                }
-               c.Close()
        }
        time.Sleep(2 * T) // wait for the dial racers to stop
 }
index f8de28b8bc07fa78a9a143ac05217807fb70ca24..9ab34c0e36faf6d0451e9c4869c1eca7bb2411fe 100644 (file)
@@ -5,6 +5,7 @@
 package net
 
 var (
+       testHookDialTCP      = dialTCP
        testHookHostsPath    = "/etc/hosts"
        testHookLookupIP     = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { return fn(host) }
        testHookSetKeepAlive = func() {}
index fbeac81d2707c75f71f6e6de56a4446aa821678d..cd1372fd02423e832ce6765157be953b84837908 100644 (file)
@@ -321,6 +321,7 @@ var (
 
        // For both read and write operations.
        errTimeout          error = &timeoutError{}
+       errCanceled               = errors.New("operation was canceled")
        errClosing                = errors.New("use of closed network connection")
        ErrWriteToConnected       = errors.New("use of WriteTo with pre-connected connection")
 )