]> Cypherpunks repositories - gostls13.git/commitdiff
net: context plumbing, add Dialer.DialContext
authorBrad Fitzpatrick <bradfitz@golang.org>
Fri, 15 Apr 2016 00:47:25 +0000 (17:47 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 15 Apr 2016 22:48:12 +0000 (22:48 +0000)
For #12580 (http.Transport tracing/analytics)
Updates #13021

Change-Id: I126e494a7bd872e42c388ecb58499ecbf0f014cc
Reviewed-on: https://go-review.googlesource.com/22101
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Mikio Hara <mikioh.mikioh@gmail.com>
33 files changed:
src/go/build/deps_test.go
src/net/cgo_unix_test.go
src/net/dial.go
src/net/dial_test.go
src/net/dnsclient_unix.go
src/net/dnsclient_unix_test.go
src/net/error_test.go
src/net/fd_unix.go
src/net/fd_windows.go
src/net/hook.go
src/net/iprawsock.go
src/net/iprawsock_plan9.go
src/net/iprawsock_posix.go
src/net/ipsock.go
src/net/ipsock_posix.go
src/net/lookup.go
src/net/lookup_plan9.go
src/net/lookup_stub.go
src/net/lookup_test.go
src/net/lookup_unix.go
src/net/lookup_windows.go
src/net/net.go
src/net/netgo_unix_test.go
src/net/sock_posix.go
src/net/tcpsock.go
src/net/tcpsock_plan9.go
src/net/tcpsock_posix.go
src/net/udpsock.go
src/net/udpsock_plan9.go
src/net/udpsock_posix.go
src/net/unixsock.go
src/net/unixsock_plan9.go
src/net/unixsock_posix.go

index f1d19bb50cc901c3ca100af5e8bc237580881e9f..2db5ba67d1f769fe6d627d5bcb18c0b2e2098215 100644 (file)
@@ -280,7 +280,9 @@ var pkgDeps = map[string][]string{
        // Basic networking.
        // Because net must be used by any package that wants to
        // do networking portably, it must have a small dependency set: just L0+basic os.
-       "net": {"L0", "CGO", "math/rand", "os", "sort", "syscall", "time", "internal/syscall/windows", "internal/singleflight", "internal/race"},
+       "net": {"L0", "CGO",
+               "context", "math/rand", "os", "sort", "syscall", "time",
+               "internal/syscall/windows", "internal/singleflight", "internal/race"},
 
        // NET enables use of basic network-related packages.
        "NET": {
index 4d5ab23fd3647fa375102e91d7664b8903fc4b08..5dc7b1a62d466312a9249eca9ad748a6ee07e852 100644 (file)
@@ -7,7 +7,10 @@
 
 package net
 
-import "testing"
+import (
+       "context"
+       "testing"
+)
 
 func TestCgoLookupIP(t *testing.T) {
        host := "localhost"
@@ -18,7 +21,7 @@ func TestCgoLookupIP(t *testing.T) {
        if err != nil {
                t.Error(err)
        }
-       if _, err := goLookupIP(host); err != nil {
+       if _, err := goLookupIP(context.Background(), host); err != nil {
                t.Error(err)
        }
 }
index 22992d5b7a95a66f0f4f5af315fc776cb8b2a1f7..1f31e8f2cc73a4100608ae833997c6b9a427273b 100644 (file)
@@ -5,6 +5,7 @@
 package net
 
 import (
+       "context"
        "runtime"
        "time"
 )
@@ -61,21 +62,34 @@ type Dialer struct {
        // Cancel is an optional channel whose closure indicates that
        // the dial should be canceled. Not all types of dials support
        // cancelation.
+       //
+       // Deprecated: Use DialContext instead.
        Cancel <-chan struct{}
 }
 
-// Return either now+Timeout or Deadline, whichever comes first.
-// Or zero, if neither is set.
-func (d *Dialer) deadline(now time.Time) time.Time {
-       if d.Timeout == 0 {
-               return d.Deadline
+func minNonzeroTime(a, b time.Time) time.Time {
+       if a.IsZero() {
+               return b
        }
-       timeoutDeadline := now.Add(d.Timeout)
-       if d.Deadline.IsZero() || timeoutDeadline.Before(d.Deadline) {
-               return timeoutDeadline
-       } else {
-               return d.Deadline
+       if b.IsZero() || a.Before(b) {
+               return a
        }
+       return b
+}
+
+// deadline returns the earliest of:
+//   - now+Timeout
+//   - d.Deadline
+//   - the context's deadline
+// Or zero, if none of Timeout, Deadline, or context's deadline is set.
+func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
+       if d.Timeout != 0 { // including negative, for historical reasons
+               earliest = now.Add(d.Timeout)
+       }
+       if d, ok := ctx.Deadline(); ok {
+               earliest = minNonzeroTime(earliest, d)
+       }
+       return minNonzeroTime(earliest, d.Deadline)
 }
 
 // partialDeadline returns the deadline to use for a single address,
@@ -142,7 +156,7 @@ func parseNetwork(net string) (afnet string, proto int, err error) {
 // resolverAddrList resolves addr using hint and returns a list of
 // addresses. The result contains at least one address when error is
 // nil.
-func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (addrList, error) {
+func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
        afnet, _, err := parseNetwork(network)
        if err != nil {
                return nil, err
@@ -152,6 +166,7 @@ func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (a
        }
        switch afnet {
        case "unix", "unixgram", "unixpacket":
+               // TODO(bradfitz): push down context
                addr, err := ResolveUnixAddr(afnet, addr)
                if err != nil {
                        return nil, err
@@ -161,7 +176,7 @@ func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (a
                }
                return addrList{addr}, nil
        }
-       addrs, err := internetAddrList(afnet, addr, deadline)
+       addrs, err := internetAddrList(ctx, afnet, addr)
        if err != nil || op != "dial" || hint == nil {
                return addrs, err
        }
@@ -253,11 +268,10 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
        return d.Dial(network, address)
 }
 
-// dialContext holds common state for all dial operations.
-type dialContext struct {
+// dialParam contains a Dial's parameters and configuration.
+type dialParam struct {
        Dialer
        network, address string
-       finalDeadline    time.Time
 }
 
 // Dial connects to the address on the named network.
@@ -265,17 +279,50 @@ type dialContext struct {
 // See func Dial for a description of the network and address
 // parameters.
 func (d *Dialer) Dial(network, address string) (Conn, error) {
-       finalDeadline := d.deadline(time.Now())
-       addrs, err := resolveAddrList("dial", network, address, d.LocalAddr, finalDeadline)
+       return d.DialContext(context.Background(), network, address)
+}
+
+// DialContext connects to the address on the named network using
+// the provided context.
+//
+// The provided Context must be non-nil.
+//
+// See func Dial for a description of the network and address
+// parameters.
+func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
+       if ctx == nil {
+               panic("nil context")
+       }
+       deadline := d.deadline(ctx, time.Now())
+       if !deadline.IsZero() {
+               if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
+                       subCtx, cancel := context.WithDeadline(ctx, deadline)
+                       defer cancel()
+                       ctx = subCtx
+               }
+       }
+       if oldCancel := d.Cancel; oldCancel != nil {
+               subCtx, cancel := context.WithCancel(ctx)
+               defer cancel()
+               go func() {
+                       select {
+                       case <-oldCancel:
+                               cancel()
+                       case <-subCtx.Done():
+                       }
+               }()
+               ctx = subCtx
+       }
+
+       addrs, err := resolveAddrList(ctx, "dial", network, address, d.LocalAddr)
        if err != nil {
                return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
        }
 
-       ctx := &dialContext{
-               Dialer:        *d,
-               network:       network,
-               address:       address,
-               finalDeadline: finalDeadline,
+       dp := &dialParam{
+               Dialer:  *d,
+               network: network,
+               address: address,
        }
 
        // DualStack mode requires that dialTCP support cancelation. This is
@@ -288,138 +335,128 @@ func (d *Dialer) Dial(network, address string) (Conn, error) {
        }
 
        var c Conn
-       if len(fallbacks) == 0 {
-               // dialParallel can accept an empty fallbacks list,
-               // but this shortcut avoids the goroutine/channel overhead.
-               c, err = dialSerial(ctx, primaries, ctx.Cancel)
+       if len(fallbacks) > 0 {
+               c, err = dialParallel(ctx, dp, primaries, fallbacks)
        } else {
-               c, err = dialParallel(ctx, primaries, fallbacks, ctx.Cancel)
+               c, err = dialSerial(ctx, dp, primaries)
+       }
+       if err != nil {
+               return nil, err
        }
 
-       if d.KeepAlive > 0 && err == nil {
-               if tc, ok := c.(*TCPConn); ok {
-                       setKeepAlive(tc.fd, true)
-                       setKeepAlivePeriod(tc.fd, d.KeepAlive)
-                       testHookSetKeepAlive()
-               }
+       if tc, ok := c.(*TCPConn); ok && d.KeepAlive > 0 {
+               setKeepAlive(tc.fd, true)
+               setKeepAlivePeriod(tc.fd, d.KeepAlive)
+               testHookSetKeepAlive()
        }
-       return c, err
+       return c, nil
 }
 
 // dialParallel races two copies of dialSerial, giving the first a
 // head start. It returns the first established connection and
 // closes the others. Otherwise it returns an error from the first
 // primary address.
-func dialParallel(ctx *dialContext, primaries, fallbacks addrList, userCancel <-chan struct{}) (Conn, error) {
-       results := make(chan dialResult, 2)
-       cancel := make(chan struct{})
+func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrList) (Conn, error) {
+       if len(fallbacks) == 0 {
+               return dialSerial(ctx, dp, primaries)
+       }
 
-       // Spawn the primary racer.
-       go dialSerialAsync(ctx, primaries, nil, cancel, results)
+       returned := make(chan struct{})
+       defer close(returned)
 
-       // Spawn the fallback racer.
-       fallbackTimer := time.NewTimer(ctx.fallbackDelay())
-       go dialSerialAsync(ctx, fallbacks, fallbackTimer, cancel, results)
+       type dialResult struct {
+               Conn
+               error
+               primary bool
+               done    bool
+       }
+       results := make(chan dialResult) // unbuffered
 
-       // Wait for both racers to succeed or fail.
-       var primaryResult, fallbackResult dialResult
-       for !primaryResult.done || !fallbackResult.done {
+       startRacer := func(ctx context.Context, primary bool) {
+               ras := primaries
+               if !primary {
+                       ras = fallbacks
+               }
+               c, err := dialSerial(ctx, dp, ras)
                select {
-               case <-userCancel:
-                       // Forward an external cancelation request.
-                       if cancel != nil {
-                               close(cancel)
-                               cancel = nil
+               case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
+               case <-returned:
+                       if c != nil {
+                               c.Close()
                        }
-                       userCancel = nil
+               }
+       }
+
+       var primary, fallback dialResult
+
+       // Start the main racer.
+       primaryCtx, primaryCancel := context.WithCancel(ctx)
+       defer primaryCancel()
+       go startRacer(primaryCtx, true)
+
+       // Start the timer for the fallback racer.
+       fallbackTimer := time.NewTimer(dp.fallbackDelay())
+       defer fallbackTimer.Stop()
+
+       for {
+               select {
+               case <-fallbackTimer.C:
+                       fallbackCtx, fallbackCancel := context.WithCancel(ctx)
+                       defer fallbackCancel()
+                       go startRacer(fallbackCtx, false)
+
                case res := <-results:
-                       // Drop the result into its assigned bucket.
+                       if res.error == nil {
+                               return res.Conn, nil
+                       }
                        if res.primary {
-                               primaryResult = res
+                               primary = res
                        } else {
-                               fallbackResult = res
+                               fallback = res
                        }
-                       // On success, cancel the other racer (if one exists.)
-                       if res.error == nil && cancel != nil {
-                               close(cancel)
-                               cancel = nil
+                       if primary.done && fallback.done {
+                               return nil, primary.error
                        }
-                       // If the fallbackTimer was pending, then either we've canceled the
-                       // fallback because we no longer want it, or we haven't canceled yet
-                       // and therefore want it to wake up immediately.
-                       if fallbackTimer.Stop() && cancel != nil {
+                       if res.primary && fallbackTimer.Stop() {
+                               // If we were able to stop the timer, that means it
+                               // was running (hadn't yet started the fallback), but
+                               // we just got an error on the primary path, so start
+                               // the fallback immediately (in 0 nanoseconds).
                                fallbackTimer.Reset(0)
                        }
                }
        }
-
-       // Return, in order of preference:
-       // 1. The primary connection (but close the other if we got both.)
-       // 2. The fallback connection.
-       // 3. The primary error.
-       if primaryResult.error == nil {
-               if fallbackResult.error == nil {
-                       fallbackResult.Conn.Close()
-               }
-               return primaryResult.Conn, nil
-       } else if fallbackResult.error == nil {
-               return fallbackResult.Conn, nil
-       } else {
-               return nil, primaryResult.error
-       }
-}
-
-type dialResult struct {
-       Conn
-       error
-       primary bool
-       done    bool
-}
-
-// dialSerialAsync runs dialSerial after some delay, and returns the
-// resulting connection through a channel. When racing two connections,
-// the primary goroutine uses a nil timer to omit the delay.
-func dialSerialAsync(ctx *dialContext, ras addrList, timer *time.Timer, cancel <-chan struct{}, results chan<- dialResult) {
-       if timer != nil {
-               // We're in the fallback goroutine; sleep before connecting.
-               select {
-               case <-timer.C:
-               case <-cancel:
-                       // dialSerial will immediately return errCanceled in this case.
-               }
-       }
-       c, err := dialSerial(ctx, ras, cancel)
-       results <- dialResult{Conn: c, error: err, primary: timer == nil, done: true}
 }
 
 // dialSerial connects to a list of addresses in sequence, returning
 // either the first successful connection, or the first error.
-func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, error) {
+func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error) {
        var firstErr error // The error from the first address is most relevant.
 
        for i, ra := range ras {
                select {
-               case <-cancel:
-                       return nil, &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: errCanceled}
+               case <-ctx.Done():
+                       return nil, &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
                default:
                }
 
-               partialDeadline, err := partialDeadline(time.Now(), ctx.finalDeadline, len(ras)-i)
+               deadline, _ := ctx.Deadline()
+               partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
                if err != nil {
                        // Ran out of time.
                        if firstErr == nil {
-                               firstErr = &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: err}
+                               firstErr = &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: err}
                        }
                        break
                }
-
-               // If this dial is canceled, the implementation is expected to complete
-               // quickly, but it's still possible that we could return a spurious Conn,
-               // which the caller must Close.
-               dialer := func(d time.Time) (Conn, error) {
-                       return dialSingle(ctx, ra, d, cancel)
+               dialCtx := ctx
+               if partialDeadline.Before(deadline) {
+                       var cancel context.CancelFunc
+                       dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
+                       defer cancel()
                }
-               c, err := dial(ctx.network, ra, dialer, partialDeadline)
+
+               c, err := dialSingle(dialCtx, dp, ra)
                if err == nil {
                        return c, nil
                }
@@ -429,7 +466,7 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e
        }
 
        if firstErr == nil {
-               firstErr = &OpError{Op: "dial", Net: ctx.network, Source: nil, Addr: nil, Err: errMissingAddress}
+               firstErr = &OpError{Op: "dial", Net: dp.network, Source: nil, Addr: nil, Err: errMissingAddress}
        }
        return nil, firstErr
 }
@@ -437,26 +474,26 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e
 // dialSingle attempts to establish and returns a single connection to
 // the destination address. This must be called through the OS-specific
 // dial function, because some OSes don't implement the deadline feature.
-func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan struct{}) (c Conn, err error) {
-       la := ctx.LocalAddr
+func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) {
+       la := dp.LocalAddr
        switch ra := ra.(type) {
        case *TCPAddr:
                la, _ := la.(*TCPAddr)
-               c, err = testHookDialTCP(ctx.network, la, ra, deadline, cancel)
+               c, err = dialTCP(ctx, dp.network, la, ra)
        case *UDPAddr:
                la, _ := la.(*UDPAddr)
-               c, err = dialUDP(ctx.network, la, ra, deadline)
+               c, err = dialUDP(ctx, dp.network, la, ra)
        case *IPAddr:
                la, _ := la.(*IPAddr)
-               c, err = dialIP(ctx.network, la, ra, deadline)
+               c, err = dialIP(ctx, dp.network, la, ra)
        case *UnixAddr:
                la, _ := la.(*UnixAddr)
-               c, err = dialUnix(ctx.network, la, ra, deadline)
+               c, err = dialUnix(ctx, dp.network, la, ra)
        default:
-               return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: ctx.address}}
+               return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: dp.address}}
        }
        if err != nil {
-               return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer
+               return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer
        }
        return c, nil
 }
@@ -469,7 +506,7 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan str
 // instead of just the interface with the given host address.
 // See Dial for more details about address syntax.
 func Listen(net, laddr string) (Listener, error) {
-       addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline)
+       addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
        }
@@ -496,7 +533,7 @@ func Listen(net, laddr string) (Listener, error) {
 // instead of just the interface with the given host address.
 // See Dial for the syntax of laddr.
 func ListenPacket(net, laddr string) (PacketConn, error) {
-       addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline)
+       addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
        }
index d4f04e0a4fbad375427e98e826e295770efa8fc0..eb145f476cd6ee2d90b87a554a0c7d303d7477f8 100644 (file)
@@ -6,6 +6,7 @@ package net
 
 import (
        "bufio"
+       "context"
        "internal/testenv"
        "io"
        "net/internal/socktest"
@@ -193,18 +194,11 @@ const (
 // In some environments, the slow IPs may be explicitly unreachable, and fail
 // more quickly than expected. This test hook prevents dialTCP from returning
 // before the deadline.
-func slowDialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
-       c, err := dialTCP(net, laddr, raddr, deadline, cancel)
+func slowDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+       c, err := doDialTCP(ctx, net, laddr, raddr)
        if ParseIP(slowDst4).Equal(raddr.IP) || ParseIP(slowDst6).Equal(raddr.IP) {
                // Wait for the deadline, or indefinitely if none exists.
-               var wait <-chan time.Time
-               if !deadline.IsZero() {
-                       wait = time.After(deadline.Sub(time.Now()))
-               }
-               select {
-               case <-cancel:
-               case <-wait:
-               }
+               <-ctx.Done()
        }
        return c, err
 }
@@ -356,15 +350,14 @@ func TestDialParallel(t *testing.T) {
                d := Dialer{
                        FallbackDelay: fallbackDelay,
                }
-               ctx := &dialContext{
-                       Dialer:        d,
-                       network:       "tcp",
-                       address:       "?",
-                       finalDeadline: d.deadline(time.Now()),
-               }
                startTime := time.Now()
-               c, err := dialParallel(ctx, primaries, fallbacks, nil)
-               elapsed := time.Now().Sub(startTime)
+               dp := &dialParam{
+                       Dialer:  d,
+                       network: "tcp",
+                       address: "?",
+               }
+               c, err := dialParallel(context.Background(), dp, primaries, fallbacks)
+               elapsed := time.Since(startTime)
 
                if c != nil {
                        c.Close()
@@ -385,16 +378,16 @@ func TestDialParallel(t *testing.T) {
                }
 
                // Repeat each case, ensuring that it can be canceled quickly.
-               cancel := make(chan struct{})
+               ctx, cancel := context.WithCancel(context.Background())
                var wg sync.WaitGroup
                wg.Add(1)
                go func() {
                        time.Sleep(5 * time.Millisecond)
-                       close(cancel)
+                       cancel()
                        wg.Done()
                }()
                startTime = time.Now()
-               c, err = dialParallel(ctx, primaries, fallbacks, cancel)
+               c, err = dialParallel(ctx, dp, primaries, fallbacks)
                if c != nil {
                        c.Close()
                }
@@ -406,7 +399,7 @@ func TestDialParallel(t *testing.T) {
        }
 }
 
-func lookupSlowFast(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) {
+func lookupSlowFast(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
        switch host {
        case "slow6loopback4":
                // Returns a slow IPv6 address, and a local IPv4 address.
@@ -415,7 +408,7 @@ func lookupSlowFast(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, e
                        {IP: ParseIP("127.0.0.1")},
                }, nil
        default:
-               return fn(host)
+               return fn(ctx, host)
        }
 }
 
@@ -530,22 +523,24 @@ func TestDialParallelSpuriousConnection(t *testing.T) {
 
        origTestHookDialTCP := testHookDialTCP
        defer func() { testHookDialTCP = origTestHookDialTCP }()
-       testHookDialTCP = func(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
+       testHookDialTCP = func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
                // Sleep long enough for Happy Eyeballs to kick in, and inhibit cancelation.
                // This forces dialParallel to juggle two successful connections.
                time.Sleep(fallbackDelay * 2)
-               cancel = nil
-               return dialTCP(net, laddr, raddr, deadline, cancel)
+
+               // Now ignore the provided context (which will be canceled) and use a
+               // different one to make sure this completes with a valid connection,
+               // which we hope to be closed below:
+               return doDialTCP(context.Background(), net, laddr, raddr)
        }
 
        d := Dialer{
                FallbackDelay: fallbackDelay,
        }
-       ctx := &dialContext{
-               Dialer:        d,
-               network:       "tcp",
-               address:       "?",
-               finalDeadline: d.deadline(time.Now()),
+       dp := &dialParam{
+               Dialer:  d,
+               network: "tcp",
+               address: "?",
        }
 
        makeAddr := func(ip string) addrList {
@@ -557,7 +552,7 @@ func TestDialParallelSpuriousConnection(t *testing.T) {
        }
 
        // dialParallel returns one connection (and closes the other.)
-       c, err := dialParallel(ctx, makeAddr("127.0.0.1"), makeAddr("::1"), nil)
+       c, err := dialParallel(context.Background(), dp, makeAddr("127.0.0.1"), makeAddr("::1"))
        if err != nil {
                t.Fatal(err)
        }
index cf0b2680dbcdec5d6f8f99e5afb0c62e458600e6..914dd767d33bb26854310a15e9a99a34afe4fb10 100644 (file)
@@ -16,6 +16,7 @@
 package net
 
 import (
+       "context"
        "errors"
        "io"
        "math/rand"
@@ -399,11 +400,11 @@ func (o hostLookupOrder) String() string {
 // Normally we let cgo use the C library resolver instead of
 // depending on our lookup code, so that Go and C get the same
 // answers.
-func goLookupHost(name string) (addrs []string, err error) {
-       return goLookupHostOrder(name, hostLookupFilesDNS)
+func goLookupHost(ctx context.Context, name string) (addrs []string, err error) {
+       return goLookupHostOrder(ctx, name, hostLookupFilesDNS)
 }
 
-func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err error) {
+func goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []string, err error) {
        if order == hostLookupFilesDNS || order == hostLookupFiles {
                // Use entries from /etc/hosts if they match.
                addrs = lookupStaticHost(name)
@@ -411,7 +412,7 @@ func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err
                        return
                }
        }
-       ips, err := goLookupIPOrder(name, order)
+       ips, err := goLookupIPOrder(ctx, name, order)
        if err != nil {
                return
        }
@@ -437,11 +438,11 @@ func goLookupIPFiles(name string) (addrs []IPAddr) {
 
 // goLookupIP is the native Go implementation of LookupIP.
 // The libc versions are in cgo_*.go.
-func goLookupIP(name string) (addrs []IPAddr, err error) {
-       return goLookupIPOrder(name, hostLookupFilesDNS)
+func goLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error) {
+       return goLookupIPOrder(ctx, name, hostLookupFilesDNS)
 }
 
-func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err error) {
+func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, err error) {
        if order == hostLookupFilesDNS || order == hostLookupFiles {
                addrs = goLookupIPFiles(name)
                if len(addrs) > 0 || order == hostLookupFiles {
index edf7c00f72dd73b0a0e56986bdf7d0b52b893a45..145a3b6a33b1b4abe7c9c4d1da7abd48d2c50f6c 100644 (file)
@@ -7,6 +7,7 @@
 package net
 
 import (
+       "context"
        "fmt"
        "internal/testenv"
        "io/ioutil"
@@ -133,7 +134,7 @@ func TestAvoidDNSName(t *testing.T) {
 
 // Issue 13705: don't try to resolve onion addresses, etc
 func TestLookupTorOnion(t *testing.T) {
-       addrs, err := goLookupIP("foo.onion")
+       addrs, err := goLookupIP(context.Background(), "foo.onion")
        if len(addrs) > 0 {
                t.Errorf("unexpected addresses: %v", addrs)
        }
@@ -249,7 +250,7 @@ func TestUpdateResolvConf(t *testing.T) {
                        for j := 0; j < N; j++ {
                                go func(name string) {
                                        defer wg.Done()
-                                       ips, err := goLookupIP(name)
+                                       ips, err := goLookupIP(context.Background(), name)
                                        if err != nil {
                                                t.Error(err)
                                                return
@@ -397,7 +398,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
                        t.Error(err)
                        continue
                }
-               addrs, err := goLookupIP(tt.name)
+               addrs, err := goLookupIP(context.Background(), tt.name)
                if err != nil {
                        // This test uses external network connectivity.
                        // We need to take care with errors on both
@@ -447,14 +448,14 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
                name := fmt.Sprintf("order %v", order)
 
                // First ensure that we get an error when contacting a non-existent host.
-               _, err := goLookupIPOrder("notarealhost", order)
+               _, err := goLookupIPOrder(context.Background(), "notarealhost", order)
                if err == nil {
                        t.Errorf("%s: expected error while looking up name not in hosts file", name)
                        continue
                }
 
                // Now check that we get an address when the name appears in the hosts file.
-               addrs, err := goLookupIPOrder("thor", order) // entry is in "testdata/hosts"
+               addrs, err := goLookupIPOrder(context.Background(), "thor", order) // entry is in "testdata/hosts"
                if err != nil {
                        t.Errorf("%s: expected to successfully lookup host entry", name)
                        continue
@@ -510,7 +511,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
                return r, nil
        }
 
-       _, err = goLookupIP(fqdn)
+       _, err = goLookupIP(context.Background(), fqdn)
        if err == nil {
                t.Fatal("expected an error")
        }
@@ -523,17 +524,19 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
 
 func BenchmarkGoLookupIP(b *testing.B) {
        testHookUninstaller.Do(uninstallTestHooks)
+       ctx := context.Background()
 
        for i := 0; i < b.N; i++ {
-               goLookupIP("www.example.com")
+               goLookupIP(ctx, "www.example.com")
        }
 }
 
 func BenchmarkGoLookupIPNoSuchHost(b *testing.B) {
        testHookUninstaller.Do(uninstallTestHooks)
+       ctx := context.Background()
 
        for i := 0; i < b.N; i++ {
-               goLookupIP("some.nonexistent")
+               goLookupIP(ctx, "some.nonexistent")
        }
 }
 
@@ -553,9 +556,10 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
        if err := conf.writeAndUpdate(lines); err != nil {
                b.Fatal(err)
        }
+       ctx := context.Background()
 
        for i := 0; i < b.N; i++ {
-               goLookupIP("www.example.com")
+               goLookupIP(ctx, "www.example.com")
        }
 }
 
index 31cc0e505563ca1fc8c6d10133a2ac9e3358bceb..9f496d7d2db5e8c41a3fcbd7bd237afe3e81b8aa 100644 (file)
@@ -5,6 +5,7 @@
 package net
 
 import (
+       "context"
        "fmt"
        "io"
        "io/ioutil"
@@ -138,7 +139,7 @@ func TestDialError(t *testing.T) {
 
        origTestHookLookupIP := testHookLookupIP
        defer func() { testHookLookupIP = origTestHookLookupIP }()
-       testHookLookupIP = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) {
+       testHookLookupIP = func(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
                return nil, &DNSError{Err: "dial error test", Name: "name", Server: "server", IsTimeout: true}
        }
        sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) {
@@ -283,7 +284,7 @@ func TestListenError(t *testing.T) {
 
        origTestHookLookupIP := testHookLookupIP
        defer func() { testHookLookupIP = origTestHookLookupIP }()
-       testHookLookupIP = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) {
+       testHookLookupIP = func(_ context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
                return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true}
        }
        sw.Set(socktest.FilterListen, func(so *socktest.Status) (socktest.AfterFilter, error) {
@@ -343,7 +344,7 @@ func TestListenPacketError(t *testing.T) {
 
        origTestHookLookupIP := testHookLookupIP
        defer func() { testHookLookupIP = origTestHookLookupIP }()
-       testHookLookupIP = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) {
+       testHookLookupIP = func(_ context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
                return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true}
        }
 
index d47b4bef996a66b8a9c62f43ab5ca6e328381c58..7ef10702ed8c08112fc8382bd1958ca91f9e9176 100644 (file)
@@ -7,12 +7,12 @@
 package net
 
 import (
+       "context"
        "io"
        "os"
        "runtime"
        "sync/atomic"
        "syscall"
-       "time"
 )
 
 // Network file descriptor.
@@ -36,10 +36,6 @@ type netFD struct {
 func sysInit() {
 }
 
-func dial(network string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) {
-       return dialer(deadline)
-}
-
 func newFD(sysfd, family, sotype int, net string) (*netFD, error) {
        return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net}, nil
 }
@@ -68,15 +64,17 @@ func (fd *netFD) name() string {
        return fd.net + ":" + ls + "->" + rs
 }
 
-func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-chan struct{}) error {
+func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error {
        // Do not need to call fd.writeLock here,
        // because fd is not yet accessible to user,
        // so no concurrent operations are possible.
        switch err := connectFunc(fd.sysfd, ra); err {
        case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR:
        case nil, syscall.EISCONN:
-               if !deadline.IsZero() && deadline.Before(time.Now()) {
-                       return errTimeout
+               select {
+               case <-ctx.Done():
+                       return mapErr(ctx.Err())
+               default:
                }
                if err := fd.init(); err != nil {
                        return err
@@ -98,27 +96,27 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-c
        if err := fd.init(); err != nil {
                return err
        }
-       if !deadline.IsZero() {
+       if deadline, _ := ctx.Deadline(); !deadline.IsZero() {
                fd.setWriteDeadline(deadline)
                defer fd.setWriteDeadline(noDeadline)
        }
-       if cancel != nil {
-               done := make(chan bool)
-               defer func() {
-                       // This is unbuffered; wait for the goroutine before returning.
-                       done <- true
-               }()
-               go func() {
-                       select {
-                       case <-cancel:
-                               // Force the runtime's poller to immediately give
-                               // up waiting for writability.
-                               fd.setWriteDeadline(aLongTimeAgo)
-                               <-done
-                       case <-done:
-                       }
-               }()
-       }
+
+       // Wait for the goroutine converting context.Done into a write timeout
+       // to exist, otherwise our caller might cancel the context and
+       // cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial.
+       done := make(chan bool) // must be unbuffered
+       defer func() { done <- true }()
+       go func() {
+               select {
+               case <-ctx.Done():
+                       // Force the runtime's poller to immediately give
+                       // up waiting for writability.
+                       fd.setWriteDeadline(aLongTimeAgo)
+                       <-done
+               case <-done:
+               }
+       }()
+
        for {
                // Performing multiple connect system calls on a
                // non-blocking socket under Unix variants does not
@@ -130,8 +128,8 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-c
                // details.
                if err := fd.pd.waitWrite(); err != nil {
                        select {
-                       case <-cancel:
-                               return errCanceled
+                       case <-ctx.Done():
+                               return mapErr(ctx.Err())
                        default:
                        }
                        return err
index 100994525eb33cbd7d93ffe2724cb7d6496d36e7..d1d91a6a5c532e07361cc8cd54c797084ebea31b 100644 (file)
@@ -5,6 +5,7 @@
 package net
 
 import (
+       "context"
        "internal/race"
        "os"
        "runtime"
@@ -320,14 +321,14 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
        runtime.SetFinalizer(fd, (*netFD).Close)
 }
 
-func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-chan struct{}) error {
+func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error {
        // Do not need to call fd.writeLock here,
        // because fd is not yet accessible to user,
        // so no concurrent operations are possible.
        if err := fd.init(); err != nil {
                return err
        }
-       if !deadline.IsZero() {
+       if deadline, _ := ctx.Deadline(); !deadline.IsZero() {
                fd.setWriteDeadline(deadline)
                defer fd.setWriteDeadline(noDeadline)
        }
@@ -351,30 +352,30 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-c
        // Call ConnectEx API.
        o := &fd.wop
        o.sa = ra
-       if cancel != nil {
-               done := make(chan bool)
-               defer func() {
-                       // This is unbuffered; wait for the goroutine before returning.
-                       done <- true
-               }()
-               go func() {
-                       select {
-                       case <-cancel:
-                               // Force the runtime's poller to immediately give
-                               // up waiting for writability.
-                               fd.setWriteDeadline(aLongTimeAgo)
-                               <-done
-                       case <-done:
-                       }
-               }()
-       }
+
+       // Wait for the goroutine converting context.Done into a write timeout
+       // to exist, otherwise our caller might cancel the context and
+       // cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial.
+       done := make(chan bool) // must be unbuffered
+       defer func() { done <- true }()
+       go func() {
+               select {
+               case <-ctx.Done():
+                       // Force the runtime's poller to immediately give
+                       // up waiting for writability.
+                       fd.setWriteDeadline(aLongTimeAgo)
+                       <-done
+               case <-done:
+               }
+       }()
+
        _, err := wsrv.ExecIO(o, "ConnectEx", func(o *operation) error {
                return connectExFunc(o.fd.sysfd, o.sa, nil, 0, nil, &o.o)
        })
        if err != nil {
                select {
-               case <-cancel:
-                       return errCanceled
+               case <-ctx.Done():
+                       return mapErr(ctx.Err())
                default:
                        if _, ok := err.(syscall.Errno); ok {
                                err = os.NewSyscallError("connectex", err)
index 9ab34c0e36faf6d0451e9c4869c1eca7bb2411fe..d7316ea4383f5c359ee4e05cf25158719dce8a17 100644 (file)
@@ -4,9 +4,19 @@
 
 package net
 
+import "context"
+
 var (
-       testHookDialTCP      = dialTCP
-       testHookHostsPath    = "/etc/hosts"
-       testHookLookupIP     = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { return fn(host) }
+       // if non-nil, overrides dialTCP.
+       testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
+
+       testHookHostsPath = "/etc/hosts"
+       testHookLookupIP  = func(
+               ctx context.Context,
+               fn func(context.Context, string) ([]IPAddr, error),
+               host string,
+       ) ([]IPAddr, error) {
+               return fn(ctx, host)
+       }
        testHookSetKeepAlive = func() {}
 )
index 41cfb2311add3adcc1a059351719693ed0277cf2..f4a4de82fcdfa034dc49304381925495f9fa4ebc 100644 (file)
@@ -4,7 +4,10 @@
 
 package net
 
-import "syscall"
+import (
+       "context"
+       "syscall"
+)
 
 // IPAddr represents the address of an IP end point.
 type IPAddr struct {
@@ -56,7 +59,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, error) {
        default:
                return nil, UnknownNetworkError(net)
        }
-       addrs, err := internetAddrList(afnet, addr, noDeadline)
+       addrs, err := internetAddrList(context.Background(), afnet, addr)
        if err != nil {
                return nil, err
        }
@@ -171,7 +174,7 @@ func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} }
 // netProto, which must be "ip", "ip4", or "ip6" followed by a colon
 // and a protocol number or name.
 func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
-       c, err := dialIP(netProto, laddr, raddr, noDeadline)
+       c, err := dialIP(context.Background(), netProto, laddr, raddr)
        if err != nil {
                return nil, &OpError{Op: "dial", Net: netProto, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
        }
@@ -183,7 +186,7 @@ func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
 // methods can be used to receive and send IP packets with per-packet
 // addressing.
 func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
-       c, err := listenIP(netProto, laddr)
+       c, err := listenIP(context.Background(), netProto, laddr)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: netProto, Source: nil, Addr: laddr.opAddr(), Err: err}
        }
index e08f271e9b4c905f3f766889dfc2e88d4f39d053..6aebea169ca5121172143179d2046a3f165fec0b 100644 (file)
@@ -5,8 +5,8 @@
 package net
 
 import (
+       "context"
        "syscall"
-       "time"
 )
 
 func (c *IPConn) readFrom(b []byte) (int, *IPAddr, error) {
@@ -25,10 +25,10 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error)
        return 0, 0, syscall.EPLAN9
 }
 
-func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, error) {
+func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
        return nil, syscall.EPLAN9
 }
 
-func listenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
+func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) {
        return nil, syscall.EPLAN9
 }
index b959afe6b4af11dd23f8772f94e5a6134e0c1fce..68dc307b60685c92e06ceba11cfb2a3300b93e01 100644 (file)
@@ -7,8 +7,8 @@
 package net
 
 import (
+       "context"
        "syscall"
-       "time"
 )
 
 // BUG(mikio): On every POSIX platform, reads from the "ip4" network
@@ -120,7 +120,7 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error)
        return c.fd.writeMsg(b, oob, sa)
 }
 
-func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, error) {
+func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
        network, proto, err := parseNetwork(netProto)
        if err != nil {
                return nil, err
@@ -133,14 +133,14 @@ func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn,
        if raddr == nil {
                return nil, errMissingAddress
        }
-       fd, err := internetSocket(network, laddr, raddr, deadline, syscall.SOCK_RAW, proto, "dial", noCancel)
+       fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial")
        if err != nil {
                return nil, err
        }
        return newIPConn(fd), nil
 }
 
-func listenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
+func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) {
        network, proto, err := parseNetwork(netProto)
        if err != nil {
                return nil, err
@@ -150,7 +150,7 @@ func listenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
        default:
                return nil, UnknownNetworkError(netProto)
        }
-       fd, err := internetSocket(network, laddr, nil, noDeadline, syscall.SOCK_RAW, proto, "listen", noCancel)
+       fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen")
        if err != nil {
                return nil, err
        }
index dc13c1743960d8d4bbcc20d11baf2f2257b57d54..24daf173aceeaa430b8be57c5bb2bfbe3ca917d2 100644 (file)
@@ -6,7 +6,9 @@
 
 package net
 
-import "time"
+import (
+       "context"
+)
 
 var (
        // supportsIPv4 reports whether the platform supports IPv4
@@ -188,7 +190,7 @@ func JoinHostPort(host, port string) string {
 // address or a DNS name, and returns a list of internet protocol
 // family addresses. The result contains at least one address when
 // error is nil.
-func internetAddrList(net, addr string, deadline time.Time) (addrList, error) {
+func internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
        var (
                err        error
                host, port string
@@ -236,7 +238,7 @@ func internetAddrList(net, addr string, deadline time.Time) (addrList, error) {
                return addrList{inetaddr(IPAddr{IP: ip, Zone: zone})}, nil
        }
        // Try as a DNS name.
-       ips, err := lookupIPDeadline(host, deadline)
+       ips, err := lookupIPContext(ctx, host)
        if err != nil {
                return nil, err
        }
index 644964e78d8b07725438b3108c2471fe6f7655b4..abe90ac0e61792d9e20f29dbfb0ba318d53512e6 100644 (file)
@@ -7,9 +7,9 @@
 package net
 
 import (
+       "context"
        "runtime"
        "syscall"
-       "time"
 )
 
 // BUG(rsc,mikio): On DragonFly BSD and OpenBSD, listening on the
@@ -152,9 +152,10 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family
        return syscall.AF_INET6, false
 }
 
-func internetSocket(net string, laddr, raddr sockaddr, deadline time.Time, sotype, proto int, mode string, cancel <-chan struct{}) (fd *netFD, err error) {
+// Internet sockets (TCP, UDP, IP)
+func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string) (fd *netFD, err error) {
        family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
-       return socket(net, family, sotype, proto, ipv6only, laddr, raddr, deadline, cancel)
+       return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr)
 }
 
 func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) {
index ab6886ddff644a4cd51b3cdf939c2bfd2b922f3f..0d3ef79bab342e0ffd88779463a40be6a34f1f50 100644 (file)
@@ -5,8 +5,8 @@
 package net
 
 import (
+       "context"
        "internal/singleflight"
-       "time"
 )
 
 // protocols contains minimal mappings between internet protocol
@@ -33,7 +33,7 @@ func LookupHost(host string) (addrs []string, err error) {
        if ip := ParseIP(host); ip != nil {
                return []string{host}, nil
        }
-       return lookupHost(host)
+       return lookupHost(context.Background(), host)
 }
 
 // LookupIP looks up host using the local resolver.
@@ -47,7 +47,7 @@ func LookupIP(host string) (ips []IP, err error) {
        if ip := ParseIP(host); ip != nil {
                return []IP{ip}, nil
        }
-       addrs, err := lookupIPMerge(host)
+       addrs, err := lookupIPMerge(context.Background(), host)
        if err != nil {
                return
        }
@@ -63,9 +63,9 @@ var lookupGroup singleflight.Group
 // lookupIPMerge wraps lookupIP, but makes sure that for any given
 // host, only one lookup is in-flight at a time. The returned memory
 // is always owned by the caller.
-func lookupIPMerge(host string) (addrs []IPAddr, err error) {
+func lookupIPMerge(ctx context.Context, host string) (addrs []IPAddr, err error) {
        addrsi, err, shared := lookupGroup.Do(host, func() (interface{}, error) {
-               return testHookLookupIP(lookupIP, host)
+               return testHookLookupIP(ctx, lookupIP, host)
        })
        return lookupIPReturn(addrsi, err, shared)
 }
@@ -85,37 +85,26 @@ func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error
        return addrs, nil
 }
 
-// lookupIPDeadline looks up a hostname with a deadline.
-func lookupIPDeadline(host string, deadline time.Time) (addrs []IPAddr, err error) {
-       if deadline.IsZero() {
-               return lookupIPMerge(host)
-       }
-
-       // We could push the deadline down into the name resolution
-       // functions. However, the most commonly used implementation
-       // calls getaddrinfo, which has no timeout.
-
-       timeout := deadline.Sub(time.Now())
-       if timeout <= 0 {
-               return nil, errTimeout
-       }
-       t := time.NewTimer(timeout)
-       defer t.Stop()
+// lookupIPContext looks up a hostname with a context.
+func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err error) {
+       // TODO(bradfitz): when adding trace hooks later here, make
+       // sure the tracing is done outside of the singleflight
+       // merging. Both callers should see the DNS lookup delay, even
+       // if it's only being done once. The r.Shared bit can be
+       // included in the trace for callers who need it.
 
        ch := lookupGroup.DoChan(host, func() (interface{}, error) {
-               return testHookLookupIP(lookupIP, host)
+               return testHookLookupIP(ctx, lookupIP, host)
        })
 
        select {
-       case <-t.C:
+       case <-ctx.Done():
                // The DNS lookup timed out for some reason. Force
                // future requests to start the DNS lookup again
                // rather than waiting for the current lookup to
                // complete. See issue 8602.
                lookupGroup.Forget(host)
-
-               return nil, errTimeout
-
+               return nil, mapErr(ctx.Err())
        case r := <-ch:
                return lookupIPReturn(r.Val, r.Err, r.Shared)
        }
index 34ee5354a542d8cf837506d868b458453f67c020..4224263602bcacec827d9745b799a2aeea4c3ad1 100644 (file)
@@ -5,6 +5,7 @@
 package net
 
 import (
+       "context"
        "errors"
        "os"
 )
@@ -115,7 +116,7 @@ func lookupProtocol(name string) (proto int, err error) {
        return 0, UnknownNetworkError(name)
 }
 
-func lookupHost(host string) (addrs []string, err error) {
+func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
        // Use netdir/cs instead of netdir/dns because cs knows about
        // host names in local network (e.g. from /lib/ndb/local)
        lines, err := queryCS("net", host, "1")
@@ -146,7 +147,8 @@ loop:
        return
 }
 
-func lookupIP(host string) (addrs []IPAddr, err error) {
+func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
+       // TODO(bradfitz): push down ctx
        lits, err := LookupHost(host)
        if err != nil {
                return
index a8625905e47c06accdc5cdea4fcc110ac7569bbc..38a4f0bae480e8d97feb4dab855a2523f509640a 100644 (file)
@@ -6,17 +6,20 @@
 
 package net
 
-import "syscall"
+import (
+       "context"
+       "syscall"
+)
 
 func lookupProtocol(name string) (proto int, err error) {
        return 0, syscall.ENOPROTOOPT
 }
 
-func lookupHost(host string) (addrs []string, err error) {
+func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
        return nil, syscall.ENOPROTOOPT
 }
 
-func lookupIP(host string) (addrs []IPAddr, err error) {
+func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
        return nil, syscall.ENOPROTOOPT
 }
 
index 1345751cfda9ff48718004af7127de7ea0da1b36..85bcfef6e9ad4899161d039e33ce09752fdd07f9 100644 (file)
@@ -6,6 +6,7 @@ package net
 
 import (
        "bytes"
+       "context"
        "fmt"
        "internal/testenv"
        "runtime"
@@ -14,7 +15,7 @@ import (
        "time"
 )
 
-func lookupLocalhost(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) {
+func lookupLocalhost(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
        switch host {
        case "localhost":
                return []IPAddr{
@@ -22,7 +23,7 @@ func lookupLocalhost(fn func(string) ([]IPAddr, error), host string) ([]IPAddr,
                        {IP: IPv6loopback},
                }, nil
        default:
-               return fn(host)
+               return fn(ctx, host)
        }
 }
 
@@ -375,15 +376,20 @@ func TestLookupIPDeadline(t *testing.T) {
 
        const N = 5000
        const timeout = 3 * time.Second
+       ctxHalfTimeout, cancel := context.WithTimeout(context.Background(), timeout/2)
+       defer cancel()
+       ctxTimeout, cancel := context.WithTimeout(context.Background(), timeout)
+       defer cancel()
+
        c := make(chan error, 2*N)
        for i := 0; i < N; i++ {
                name := fmt.Sprintf("%d.net-test.golang.org", i)
                go func() {
-                       _, err := lookupIPDeadline(name, time.Now().Add(timeout/2))
+                       _, err := lookupIPContext(ctxHalfTimeout, name)
                        c <- err
                }()
                go func() {
-                       _, err := lookupIPDeadline(name, time.Now().Add(timeout))
+                       _, err := lookupIPContext(ctxTimeout, name)
                        c <- err
                }()
        }
index cd4ddbdb24ef3d9c58e3ff0e8dcdc16efe490e48..8d3fa4778284e1406f19ce1cb6a756a2cf34297e 100644 (file)
@@ -6,7 +6,10 @@
 
 package net
 
-import "sync"
+import (
+       "context"
+       "sync"
+)
 
 var onceReadProtocols sync.Once
 
@@ -49,7 +52,7 @@ func lookupProtocol(name string) (int, error) {
        return proto, nil
 }
 
-func lookupHost(host string) (addrs []string, err error) {
+func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
        order := systemConf().hostLookupOrder(host)
        if order == hostLookupCgo {
                if addrs, err, ok := cgoLookupHost(host); ok {
@@ -58,19 +61,20 @@ func lookupHost(host string) (addrs []string, err error) {
                // cgo not available (or netgo); fall back to Go's DNS resolver
                order = hostLookupFilesDNS
        }
-       return goLookupHostOrder(host, order)
+       return goLookupHostOrder(ctx, host, order)
 }
 
-func lookupIP(host string) (addrs []IPAddr, err error) {
+func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
        order := systemConf().hostLookupOrder(host)
        if order == hostLookupCgo {
+               // TODO(bradfitz): push down ctx, or at least its deadline to start
                if addrs, err, ok := cgoLookupIP(host); ok {
                        return addrs, err
                }
                // cgo not available (or netgo); fall back to Go's DNS resolver
                order = hostLookupFilesDNS
        }
-       return goLookupIPOrder(host, order)
+       return goLookupIPOrder(ctx, host, order)
 }
 
 func lookupPort(network, service string) (int, error) {
index 13edc264e82a59006112b6edc5fd6a8cb4ef16fd..ce012ba873fc70aec9c1982bc94c048570bb183e 100644 (file)
@@ -5,6 +5,7 @@
 package net
 
 import (
+       "context"
        "os"
        "runtime"
        "syscall"
@@ -51,8 +52,8 @@ func lookupProtocol(name string) (int, error) {
        return r.proto, r.err
 }
 
-func lookupHost(name string) ([]string, error) {
-       ips, err := LookupIP(name)
+func lookupHost(ctx context.Context, name string) ([]string, error) {
+       ips, err := lookupIP(ctx, name)
        if err != nil {
                return nil, err
        }
@@ -83,59 +84,97 @@ func gethostbyname(name string) (addrs []IPAddr, err error) {
        return addrs, nil
 }
 
-func oldLookupIP(name string) ([]IPAddr, error) {
+func oldLookupIP(ctx context.Context, name string) ([]IPAddr, error) {
        // GetHostByName return value is stored in thread local storage.
        // Start new os thread before the call to prevent races.
-       type result struct {
+       type ret struct {
                addrs []IPAddr
                err   error
        }
-       ch := make(chan result)
+       ch := make(chan ret, 1)
        go func() {
                acquireThread()
                defer releaseThread()
                runtime.LockOSThread()
                defer runtime.UnlockOSThread()
                addrs, err := gethostbyname(name)
-               ch <- result{addrs: addrs, err: err}
+               ch <- ret{addrs: addrs, err: err}
        }()
-       r := <-ch
-       if r.err != nil {
-               r.err = &DNSError{Err: r.err.Error(), Name: name}
+       select {
+       case r := <-ch:
+               if r.err != nil {
+                       r.err = &DNSError{Err: r.err.Error(), Name: name}
+               }
+               return r.addrs, r.err
+       case <-ctx.Done():
+               // TODO(bradfitz,brainman): cancel the ongoing
+               // gethostbyname?  For now we just let it finish and
+               // write to the buffered channel.
+               return nil, &DNSError{
+                       Name:      name,
+                       Err:       ctx.Err().Error(),
+                       IsTimeout: ctx.Err() == context.DeadlineExceeded,
+               }
        }
-       return r.addrs, r.err
 }
 
-func newLookupIP(name string) ([]IPAddr, error) {
-       acquireThread()
-       defer releaseThread()
-       hints := syscall.AddrinfoW{
-               Family:   syscall.AF_UNSPEC,
-               Socktype: syscall.SOCK_STREAM,
-               Protocol: syscall.IPPROTO_IP,
-       }
-       var result *syscall.AddrinfoW
-       e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
-       if e != nil {
-               return nil, &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: name}
+func newLookupIP(ctx context.Context, name string) ([]IPAddr, error) {
+       // TODO(bradfitz,brainman): use ctx?
+
+       type ret struct {
+               addrs []IPAddr
+               err   error
        }
-       defer syscall.FreeAddrInfoW(result)
-       addrs := make([]IPAddr, 0, 5)
-       for ; result != nil; result = result.Next {
-               addr := unsafe.Pointer(result.Addr)
-               switch result.Family {
-               case syscall.AF_INET:
-                       a := (*syscall.RawSockaddrInet4)(addr).Addr
-                       addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])})
-               case syscall.AF_INET6:
-                       a := (*syscall.RawSockaddrInet6)(addr).Addr
-                       zone := zoneToString(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
-                       addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone})
-               default:
-                       return nil, &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}
+       ch := make(chan ret, 1)
+       go func() {
+               acquireThread()
+               defer releaseThread()
+               hints := syscall.AddrinfoW{
+                       Family:   syscall.AF_UNSPEC,
+                       Socktype: syscall.SOCK_STREAM,
+                       Protocol: syscall.IPPROTO_IP,
+               }
+               var result *syscall.AddrinfoW
+               e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
+               if e != nil {
+                       ch <- ret{err: &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: name}}
+               }
+               defer syscall.FreeAddrInfoW(result)
+               addrs := make([]IPAddr, 0, 5)
+               for ; result != nil; result = result.Next {
+                       addr := unsafe.Pointer(result.Addr)
+                       switch result.Family {
+                       case syscall.AF_INET:
+                               a := (*syscall.RawSockaddrInet4)(addr).Addr
+                               addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])})
+                       case syscall.AF_INET6:
+                               a := (*syscall.RawSockaddrInet6)(addr).Addr
+                               zone := zoneToString(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
+                               addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone})
+                       default:
+                               ch <- ret{err: &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}}
+                       }
+               }
+               ch <- ret{addrs: addrs}
+       }()
+       select {
+       case r := <-ch:
+               return r.addrs, r.err
+       case <-ctx.Done():
+               // TODO(bradfitz,brainman): cancel the ongoing
+               // GetAddrInfoW? It would require conditionally using
+               // GetAddrInfoEx with lpOverlapped, which requires
+               // Windows 8 or newer. I guess we'll need oldLookupIP,
+               // newLookupIP, and newerLookUP.
+               //
+               // For now we just let it finish and write to the
+               // buffered channel.
+               return nil, &DNSError{
+                       Name:      name,
+                       Err:       ctx.Err().Error(),
+                       IsTimeout: ctx.Err() == context.DeadlineExceeded,
                }
        }
-       return addrs, nil
 }
 
 func getservbyname(network, service string) (int, error) {
index 3b37b336d1bdcffe64574537b325ff4326a1373b..27e9ca367d4054adaaa39ac6168d583ae6faf6b3 100644 (file)
@@ -79,6 +79,7 @@ On Windows, the resolver always uses C library functions, such as GetAddrInfo an
 package net
 
 import (
+       "context"
        "errors"
        "io"
        "os"
@@ -377,6 +378,22 @@ var (
        ErrWriteToConnected       = errors.New("use of WriteTo with pre-connected connection")
 )
 
+// mapErr maps from the context errors to the historical internal net
+// error values.
+//
+// TODO(bradfitz): get rid of this after adjusting tests and making
+// context.DeadlineExceeded implement net.Error?
+func mapErr(err error) error {
+       switch err {
+       case context.Canceled:
+               return errCanceled
+       case context.DeadlineExceeded:
+               return errTimeout
+       default:
+               return err
+       }
+}
+
 // OpError is the error type usually returned by functions in the net
 // package. It describes the operation, network type, and address of
 // an error.
index 1d950d67d7c232fa49c8444fdd9be9d669befdc5..0a118874c25aef7912c82054b1b6f45149bf013e 100644 (file)
@@ -7,7 +7,10 @@
 
 package net
 
-import "testing"
+import (
+       "context"
+       "testing"
+)
 
 func TestGoLookupIP(t *testing.T) {
        host := "localhost"
@@ -18,7 +21,7 @@ func TestGoLookupIP(t *testing.T) {
        if err != nil {
                t.Error(err)
        }
-       if _, err := goLookupIP(host); err != nil {
+       if _, err := goLookupIP(context.Background(), host); err != nil {
                t.Error(err)
        }
 }
index 3dddfef4c5e3e65bb915ffda91a861a25afe480d..c3af27b596fd7f8b51049671f12afadcd753bbc4 100644 (file)
@@ -7,9 +7,9 @@
 package net
 
 import (
+       "context"
        "os"
        "syscall"
-       "time"
 )
 
 // A sockaddr represents a TCP, UDP, IP or Unix network endpoint
@@ -34,7 +34,7 @@ type sockaddr interface {
 
 // socket returns a network file descriptor that is ready for
 // asynchronous I/O using the network poller.
-func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, deadline time.Time, cancel <-chan struct{}) (fd *netFD, err error) {
+func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr) (fd *netFD, err error) {
        s, err := sysSocket(family, sotype, proto)
        if err != nil {
                return nil, err
@@ -86,7 +86,7 @@ func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr s
                        return fd, nil
                }
        }
-       if err := fd.dial(laddr, raddr, deadline, cancel); err != nil {
+       if err := fd.dial(ctx, laddr, raddr); err != nil {
                fd.Close()
                return nil, err
        }
@@ -117,7 +117,7 @@ func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
        return func(syscall.Sockaddr) Addr { return nil }
 }
 
-func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time, cancel <-chan struct{}) error {
+func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr) error {
        var err error
        var lsa syscall.Sockaddr
        if laddr != nil {
@@ -134,7 +134,7 @@ func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time, cancel <-chan s
                if rsa, err = raddr.sockaddr(fd.family); err != nil {
                        return err
                }
-               if err := fd.connect(lsa, rsa, deadline, cancel); err != nil {
+               if err := fd.connect(ctx, lsa, rsa); err != nil {
                        return err
                }
                fd.isConnected = true
index a5c3515c19458ec8768eab515b4e6e1b4a6e85e9..7cffcc58cbbe83b51e5e5edf525ff448058e4ab2 100644 (file)
@@ -5,6 +5,7 @@
 package net
 
 import (
+       "context"
        "io"
        "os"
        "syscall"
@@ -60,7 +61,7 @@ func ResolveTCPAddr(net, addr string) (*TCPAddr, error) {
        default:
                return nil, UnknownNetworkError(net)
        }
-       addrs, err := internetAddrList(net, addr, noDeadline)
+       addrs, err := internetAddrList(context.Background(), net, addr)
        if err != nil {
                return nil, err
        }
@@ -186,7 +187,7 @@ func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
        if raddr == nil {
                return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
        }
-       c, err := dialTCP(net, laddr, raddr, noDeadline, noCancel)
+       c, err := dialTCP(context.Background(), net, laddr, raddr)
        if err != nil {
                return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
        }
@@ -285,7 +286,7 @@ func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) {
        if laddr == nil {
                laddr = &TCPAddr{}
        }
-       ln, err := listenTCP(net, laddr)
+       ln, err := listenTCP(context.Background(), net, laddr)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err}
        }
index 698b834295f9bece1413dfe6b97c4f70e10238ab..dd36c70d506a7bf040b4c96beb674d816017f9a8 100644 (file)
@@ -5,17 +5,24 @@
 package net
 
 import (
+       "context"
        "io"
        "os"
-       "time"
 )
 
 func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
        return genericReadFrom(c, r)
 }
 
-func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
-       if !deadline.IsZero() {
+func dialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+       if testHookDialTCP != nil {
+               return testHookDialTCP(ctx, net, laddr, raddr)
+       }
+       return doDialTCP(ctx, net, laddr, raddr)
+}
+
+func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+       if d, _ := ctx.Deadline(); !d.IsZero() {
                panic("net.dialTCP: deadline not implemented on Plan 9")
        }
        // TODO(bradfitz,0intro): also use the cancel channel.
@@ -63,7 +70,7 @@ func (ln *TCPListener) file() (*os.File, error) {
        return f, nil
 }
 
-func listenTCP(network string, laddr *TCPAddr) (*TCPListener, error) {
+func listenTCP(ctx context.Context, network string, laddr *TCPAddr) (*TCPListener, error) {
        fd, err := listenPlan9(network, laddr)
        if err != nil {
                return nil, err
index 3902565f7353cace6befb728cbd3e5426811911b..c9a8b6808ea8da13612dc221952074510741a3f3 100644 (file)
@@ -7,10 +7,10 @@
 package net
 
 import (
+       "context"
        "io"
        "os"
        "syscall"
-       "time"
 )
 
 func sockaddrToTCP(sa syscall.Sockaddr) Addr {
@@ -47,8 +47,15 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
        return genericReadFrom(c, r)
 }
 
-func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
-       fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", cancel)
+func dialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+       if testHookDialTCP != nil {
+               return testHookDialTCP(ctx, net, laddr, raddr)
+       }
+       return doDialTCP(ctx, net, laddr, raddr)
+}
+
+func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+       fd, err := internetSocket(ctx, net, laddr, raddr, syscall.SOCK_STREAM, 0, "dial")
 
        // TCP has a rarely used mechanism called a 'simultaneous connection' in
        // which Dial("tcp", addr1, addr2) run on the machine at addr1 can
@@ -78,7 +85,7 @@ func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-cha
                if err == nil {
                        fd.Close()
                }
-               fd, err = internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", cancel)
+               fd, err = internetSocket(ctx, net, laddr, raddr, syscall.SOCK_STREAM, 0, "dial")
        }
 
        if err != nil {
@@ -141,8 +148,8 @@ func (ln *TCPListener) file() (*os.File, error) {
        return f, nil
 }
 
-func listenTCP(network string, laddr *TCPAddr) (*TCPListener, error) {
-       fd, err := internetSocket(network, laddr, nil, noDeadline, syscall.SOCK_STREAM, 0, "listen", noCancel)
+func listenTCP(ctx context.Context, network string, laddr *TCPAddr) (*TCPListener, error) {
+       fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_STREAM, 0, "listen")
        if err != nil {
                return nil, err
        }
index e7e9796668333f01a454fd812510f547d4954b53..980f67c81f9fa67a61a3ef755aa8b587043ae835 100644 (file)
@@ -4,7 +4,10 @@
 
 package net
 
-import "syscall"
+import (
+       "context"
+       "syscall"
+)
 
 // UDPAddr represents the address of a UDP end point.
 type UDPAddr struct {
@@ -55,7 +58,7 @@ func ResolveUDPAddr(net, addr string) (*UDPAddr, error) {
        default:
                return nil, UnknownNetworkError(net)
        }
-       addrs, err := internetAddrList(net, addr, noDeadline)
+       addrs, err := internetAddrList(context.Background(), net, addr)
        if err != nil {
                return nil, err
        }
@@ -181,7 +184,7 @@ func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
        if raddr == nil {
                return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
        }
-       c, err := dialUDP(net, laddr, raddr, noDeadline)
+       c, err := dialUDP(context.Background(), net, laddr, raddr)
        if err != nil {
                return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
        }
@@ -204,7 +207,7 @@ func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) {
        if laddr == nil {
                laddr = &UDPAddr{}
        }
-       c, err := listenUDP(net, laddr)
+       c, err := listenUDP(context.Background(), net, laddr)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err}
        }
@@ -231,7 +234,7 @@ func ListenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPCon
        if gaddr == nil || gaddr.IP == nil {
                return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: errMissingAddress}
        }
-       c, err := listenMulticastUDP(network, ifi, gaddr)
+       c, err := listenMulticastUDP(context.Background(), network, ifi, gaddr)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: err}
        }
index 5f15427064093d3f4d9aebb06252a3b3286e1c38..81edaf59fe369a11f786e7412d5db85522673915 100644 (file)
@@ -5,10 +5,10 @@
 package net
 
 import (
+       "context"
        "errors"
        "os"
        "syscall"
-       "time"
 )
 
 func (c *UDPConn) readFrom(b []byte) (n int, addr *UDPAddr, err error) {
@@ -55,8 +55,8 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error
        return 0, 0, syscall.EPLAN9
 }
 
-func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) {
-       if !deadline.IsZero() {
+func dialUDP(ctx context.Context, net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
+       if deadline, _ := ctx.Deadline(); !deadline.IsZero() {
                panic("net.dialUDP: deadline not implemented on Plan 9")
        }
        fd, err := dialPlan9(net, laddr, raddr)
@@ -94,7 +94,7 @@ func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) {
        return h, b
 }
 
-func listenUDP(network string, laddr *UDPAddr) (*UDPConn, error) {
+func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, error) {
        l, err := listenPlan9(network, laddr)
        if err != nil {
                return nil, err
@@ -111,6 +111,6 @@ func listenUDP(network string, laddr *UDPAddr) (*UDPConn, error) {
        return newUDPConn(fd), err
 }
 
-func listenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
+func listenMulticastUDP(ctx context.Context, network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
        return nil, syscall.EPLAN9
 }
index 4d3255c996423f7ce9c154f0a7e3fc26b7245aab..4924801ebb258dd20454e326b1e5a4014b52f43f 100644 (file)
@@ -7,8 +7,8 @@
 package net
 
 import (
+       "context"
        "syscall"
-       "time"
 )
 
 func sockaddrToUDP(sa syscall.Sockaddr) Addr {
@@ -90,24 +90,24 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error
        return c.fd.writeMsg(b, oob, sa)
 }
 
-func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) {
-       fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_DGRAM, 0, "dial", noCancel)
+func dialUDP(ctx context.Context, net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
+       fd, err := internetSocket(ctx, net, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial")
        if err != nil {
                return nil, err
        }
        return newUDPConn(fd), nil
 }
 
-func listenUDP(network string, laddr *UDPAddr) (*UDPConn, error) {
-       fd, err := internetSocket(network, laddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", noCancel)
+func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, error) {
+       fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen")
        if err != nil {
                return nil, err
        }
        return newUDPConn(fd), nil
 }
 
-func listenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
-       fd, err := internetSocket(network, gaddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", noCancel)
+func listenMulticastUDP(ctx context.Context, network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
+       fd, err := internetSocket(ctx, network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen")
        if err != nil {
                return nil, err
        }
index d1eb0b62eeeee96c375f69513169c3a1ad16f8b2..bacdaa41d9079bc94cc04cd973932dee56819ed3 100644 (file)
@@ -5,6 +5,7 @@
 package net
 
 import (
+       "context"
        "os"
        "syscall"
        "time"
@@ -188,7 +189,7 @@ func DialUnix(net string, laddr, raddr *UnixAddr) (*UnixConn, error) {
        default:
                return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(net)}
        }
-       c, err := dialUnix(net, laddr, raddr, noDeadline)
+       c, err := dialUnix(context.Background(), net, laddr, raddr)
        if err != nil {
                return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
        }
@@ -290,7 +291,7 @@ func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) {
        if laddr == nil {
                return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: errMissingAddress}
        }
-       ln, err := listenUnix(net, laddr)
+       ln, err := listenUnix(context.Background(), net, laddr)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err}
        }
@@ -310,7 +311,7 @@ func ListenUnixgram(net string, laddr *UnixAddr) (*UnixConn, error) {
        if laddr == nil {
                return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: errMissingAddress}
        }
-       c, err := listenUnixgram(net, laddr)
+       c, err := listenUnixgram(context.Background(), net, laddr)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err}
        }
index 5d5b18f467c4db4c7868b376f37363117480ee1b..e70eb211bbf0babd15656b31546d59c3591e2ac9 100644 (file)
@@ -5,9 +5,9 @@
 package net
 
 import (
+       "context"
        "os"
        "syscall"
-       "time"
 )
 
 func (c *UnixConn) readFrom(b []byte) (int, *UnixAddr, error) {
@@ -26,7 +26,7 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err
        return 0, 0, syscall.EPLAN9
 }
 
-func dialUnix(network string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn, error) {
+func dialUnix(ctx context.Context, network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
        return nil, syscall.EPLAN9
 }
 
@@ -42,10 +42,10 @@ func (ln *UnixListener) file() (*os.File, error) {
        return nil, syscall.EPLAN9
 }
 
-func listenUnix(network string, laddr *UnixAddr) (*UnixListener, error) {
+func listenUnix(ctx context.Context, network string, laddr *UnixAddr) (*UnixListener, error) {
        return nil, syscall.EPLAN9
 }
 
-func listenUnixgram(network string, laddr *UnixAddr) (*UnixConn, error) {
+func listenUnixgram(ctx context.Context, network string, laddr *UnixAddr) (*UnixConn, error) {
        return nil, syscall.EPLAN9
 }
index 9275e1034cc4e0fbf0d269c9fdddf4dcc1e3422c..5f0999c4c270e1773bdc33f571ef3f62f2e5e3d9 100644 (file)
@@ -7,13 +7,13 @@
 package net
 
 import (
+       "context"
        "errors"
        "os"
        "syscall"
-       "time"
 )
 
-func unixSocket(net string, laddr, raddr sockaddr, mode string, deadline time.Time) (*netFD, error) {
+func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string) (*netFD, error) {
        var sotype int
        switch net {
        case "unix":
@@ -42,7 +42,7 @@ func unixSocket(net string, laddr, raddr sockaddr, mode string, deadline time.Ti
                return nil, errors.New("unknown mode: " + mode)
        }
 
-       fd, err := socket(net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, deadline, noCancel)
+       fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr)
        if err != nil {
                return nil, err
        }
@@ -146,8 +146,8 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err
        return c.fd.writeMsg(b, oob, sa)
 }
 
-func dialUnix(net string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn, error) {
-       fd, err := unixSocket(net, laddr, raddr, "dial", deadline)
+func dialUnix(ctx context.Context, net string, laddr, raddr *UnixAddr) (*UnixConn, error) {
+       fd, err := unixSocket(ctx, net, laddr, raddr, "dial")
        if err != nil {
                return nil, err
        }
@@ -187,16 +187,16 @@ func (ln *UnixListener) file() (*os.File, error) {
        return f, nil
 }
 
-func listenUnix(network string, laddr *UnixAddr) (*UnixListener, error) {
-       fd, err := unixSocket(network, laddr, nil, "listen", noDeadline)
+func listenUnix(ctx context.Context, network string, laddr *UnixAddr) (*UnixListener, error) {
+       fd, err := unixSocket(ctx, network, laddr, nil, "listen")
        if err != nil {
                return nil, err
        }
        return &UnixListener{fd: fd, path: fd.laddr.String(), unlink: true}, nil
 }
 
-func listenUnixgram(network string, laddr *UnixAddr) (*UnixConn, error) {
-       fd, err := unixSocket(network, laddr, nil, "listen", noDeadline)
+func listenUnixgram(ctx context.Context, network string, laddr *UnixAddr) (*UnixConn, error) {
+       fd, err := unixSocket(ctx, network, laddr, nil, "listen")
        if err != nil {
                return nil, err
        }