From b6b4004d5a5bf7099ac9ab76777797236da7fe63 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 14 Apr 2016 17:47:25 -0700 Subject: [PATCH] net: context plumbing, add Dialer.DialContext For #12580 (http.Transport tracing/analytics) Updates #13021 Change-Id: I126e494a7bd872e42c388ecb58499ecbf0f014cc Reviewed-on: https://go-review.googlesource.com/22101 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Ian Lance Taylor Reviewed-by: Mikio Hara --- src/go/build/deps_test.go | 4 +- src/net/cgo_unix_test.go | 7 +- src/net/dial.go | 283 +++++++++++++++++++-------------- src/net/dial_test.go | 59 ++++--- src/net/dnsclient_unix.go | 15 +- src/net/dnsclient_unix_test.go | 22 +-- src/net/error_test.go | 7 +- src/net/fd_unix.go | 54 +++---- src/net/fd_windows.go | 43 ++--- src/net/hook.go | 16 +- src/net/iprawsock.go | 11 +- src/net/iprawsock_plan9.go | 6 +- src/net/iprawsock_posix.go | 10 +- src/net/ipsock.go | 8 +- src/net/ipsock_posix.go | 7 +- src/net/lookup.go | 41 ++--- src/net/lookup_plan9.go | 6 +- src/net/lookup_stub.go | 9 +- src/net/lookup_test.go | 14 +- src/net/lookup_unix.go | 14 +- src/net/lookup_windows.go | 113 ++++++++----- src/net/net.go | 17 ++ src/net/netgo_unix_test.go | 7 +- src/net/sock_posix.go | 10 +- src/net/tcpsock.go | 7 +- src/net/tcpsock_plan9.go | 15 +- src/net/tcpsock_posix.go | 19 ++- src/net/udpsock.go | 13 +- src/net/udpsock_plan9.go | 10 +- src/net/udpsock_posix.go | 14 +- src/net/unixsock.go | 7 +- src/net/unixsock_plan9.go | 8 +- src/net/unixsock_posix.go | 18 +-- 33 files changed, 517 insertions(+), 377 deletions(-) diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go index f1d19bb50c..2db5ba67d1 100644 --- a/src/go/build/deps_test.go +++ b/src/go/build/deps_test.go @@ -280,7 +280,9 @@ var pkgDeps = map[string][]string{ // Basic networking. // Because net must be used by any package that wants to // do networking portably, it must have a small dependency set: just L0+basic os. - "net": {"L0", "CGO", "math/rand", "os", "sort", "syscall", "time", "internal/syscall/windows", "internal/singleflight", "internal/race"}, + "net": {"L0", "CGO", + "context", "math/rand", "os", "sort", "syscall", "time", + "internal/syscall/windows", "internal/singleflight", "internal/race"}, // NET enables use of basic network-related packages. "NET": { diff --git a/src/net/cgo_unix_test.go b/src/net/cgo_unix_test.go index 4d5ab23fd3..5dc7b1a62d 100644 --- a/src/net/cgo_unix_test.go +++ b/src/net/cgo_unix_test.go @@ -7,7 +7,10 @@ package net -import "testing" +import ( + "context" + "testing" +) func TestCgoLookupIP(t *testing.T) { host := "localhost" @@ -18,7 +21,7 @@ func TestCgoLookupIP(t *testing.T) { if err != nil { t.Error(err) } - if _, err := goLookupIP(host); err != nil { + if _, err := goLookupIP(context.Background(), host); err != nil { t.Error(err) } } diff --git a/src/net/dial.go b/src/net/dial.go index 22992d5b7a..1f31e8f2cc 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -5,6 +5,7 @@ package net import ( + "context" "runtime" "time" ) @@ -61,21 +62,34 @@ type Dialer struct { // Cancel is an optional channel whose closure indicates that // the dial should be canceled. Not all types of dials support // cancelation. + // + // Deprecated: Use DialContext instead. Cancel <-chan struct{} } -// Return either now+Timeout or Deadline, whichever comes first. -// Or zero, if neither is set. -func (d *Dialer) deadline(now time.Time) time.Time { - if d.Timeout == 0 { - return d.Deadline +func minNonzeroTime(a, b time.Time) time.Time { + if a.IsZero() { + return b } - timeoutDeadline := now.Add(d.Timeout) - if d.Deadline.IsZero() || timeoutDeadline.Before(d.Deadline) { - return timeoutDeadline - } else { - return d.Deadline + if b.IsZero() || a.Before(b) { + return a } + return b +} + +// deadline returns the earliest of: +// - now+Timeout +// - d.Deadline +// - the context's deadline +// Or zero, if none of Timeout, Deadline, or context's deadline is set. +func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) { + if d.Timeout != 0 { // including negative, for historical reasons + earliest = now.Add(d.Timeout) + } + if d, ok := ctx.Deadline(); ok { + earliest = minNonzeroTime(earliest, d) + } + return minNonzeroTime(earliest, d.Deadline) } // partialDeadline returns the deadline to use for a single address, @@ -142,7 +156,7 @@ func parseNetwork(net string) (afnet string, proto int, err error) { // resolverAddrList resolves addr using hint and returns a list of // addresses. The result contains at least one address when error is // nil. -func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (addrList, error) { +func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) { afnet, _, err := parseNetwork(network) if err != nil { return nil, err @@ -152,6 +166,7 @@ func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (a } switch afnet { case "unix", "unixgram", "unixpacket": + // TODO(bradfitz): push down context addr, err := ResolveUnixAddr(afnet, addr) if err != nil { return nil, err @@ -161,7 +176,7 @@ func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (a } return addrList{addr}, nil } - addrs, err := internetAddrList(afnet, addr, deadline) + addrs, err := internetAddrList(ctx, afnet, addr) if err != nil || op != "dial" || hint == nil { return addrs, err } @@ -253,11 +268,10 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) { return d.Dial(network, address) } -// dialContext holds common state for all dial operations. -type dialContext struct { +// dialParam contains a Dial's parameters and configuration. +type dialParam struct { Dialer network, address string - finalDeadline time.Time } // Dial connects to the address on the named network. @@ -265,17 +279,50 @@ type dialContext struct { // See func Dial for a description of the network and address // parameters. func (d *Dialer) Dial(network, address string) (Conn, error) { - finalDeadline := d.deadline(time.Now()) - addrs, err := resolveAddrList("dial", network, address, d.LocalAddr, finalDeadline) + return d.DialContext(context.Background(), network, address) +} + +// DialContext connects to the address on the named network using +// the provided context. +// +// The provided Context must be non-nil. +// +// See func Dial for a description of the network and address +// parameters. +func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) { + if ctx == nil { + panic("nil context") + } + deadline := d.deadline(ctx, time.Now()) + if !deadline.IsZero() { + if d, ok := ctx.Deadline(); !ok || deadline.Before(d) { + subCtx, cancel := context.WithDeadline(ctx, deadline) + defer cancel() + ctx = subCtx + } + } + if oldCancel := d.Cancel; oldCancel != nil { + subCtx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + select { + case <-oldCancel: + cancel() + case <-subCtx.Done(): + } + }() + ctx = subCtx + } + + addrs, err := resolveAddrList(ctx, "dial", network, address, d.LocalAddr) if err != nil { return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err} } - ctx := &dialContext{ - Dialer: *d, - network: network, - address: address, - finalDeadline: finalDeadline, + dp := &dialParam{ + Dialer: *d, + network: network, + address: address, } // DualStack mode requires that dialTCP support cancelation. This is @@ -288,138 +335,128 @@ func (d *Dialer) Dial(network, address string) (Conn, error) { } var c Conn - if len(fallbacks) == 0 { - // dialParallel can accept an empty fallbacks list, - // but this shortcut avoids the goroutine/channel overhead. - c, err = dialSerial(ctx, primaries, ctx.Cancel) + if len(fallbacks) > 0 { + c, err = dialParallel(ctx, dp, primaries, fallbacks) } else { - c, err = dialParallel(ctx, primaries, fallbacks, ctx.Cancel) + c, err = dialSerial(ctx, dp, primaries) + } + if err != nil { + return nil, err } - if d.KeepAlive > 0 && err == nil { - if tc, ok := c.(*TCPConn); ok { - setKeepAlive(tc.fd, true) - setKeepAlivePeriod(tc.fd, d.KeepAlive) - testHookSetKeepAlive() - } + if tc, ok := c.(*TCPConn); ok && d.KeepAlive > 0 { + setKeepAlive(tc.fd, true) + setKeepAlivePeriod(tc.fd, d.KeepAlive) + testHookSetKeepAlive() } - return c, err + return c, nil } // dialParallel races two copies of dialSerial, giving the first a // head start. It returns the first established connection and // closes the others. Otherwise it returns an error from the first // primary address. -func dialParallel(ctx *dialContext, primaries, fallbacks addrList, userCancel <-chan struct{}) (Conn, error) { - results := make(chan dialResult, 2) - cancel := make(chan struct{}) +func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrList) (Conn, error) { + if len(fallbacks) == 0 { + return dialSerial(ctx, dp, primaries) + } - // Spawn the primary racer. - go dialSerialAsync(ctx, primaries, nil, cancel, results) + returned := make(chan struct{}) + defer close(returned) - // Spawn the fallback racer. - fallbackTimer := time.NewTimer(ctx.fallbackDelay()) - go dialSerialAsync(ctx, fallbacks, fallbackTimer, cancel, results) + type dialResult struct { + Conn + error + primary bool + done bool + } + results := make(chan dialResult) // unbuffered - // Wait for both racers to succeed or fail. - var primaryResult, fallbackResult dialResult - for !primaryResult.done || !fallbackResult.done { + startRacer := func(ctx context.Context, primary bool) { + ras := primaries + if !primary { + ras = fallbacks + } + c, err := dialSerial(ctx, dp, ras) select { - case <-userCancel: - // Forward an external cancelation request. - if cancel != nil { - close(cancel) - cancel = nil + case results <- dialResult{Conn: c, error: err, primary: primary, done: true}: + case <-returned: + if c != nil { + c.Close() } - userCancel = nil + } + } + + var primary, fallback dialResult + + // Start the main racer. + primaryCtx, primaryCancel := context.WithCancel(ctx) + defer primaryCancel() + go startRacer(primaryCtx, true) + + // Start the timer for the fallback racer. + fallbackTimer := time.NewTimer(dp.fallbackDelay()) + defer fallbackTimer.Stop() + + for { + select { + case <-fallbackTimer.C: + fallbackCtx, fallbackCancel := context.WithCancel(ctx) + defer fallbackCancel() + go startRacer(fallbackCtx, false) + case res := <-results: - // Drop the result into its assigned bucket. + if res.error == nil { + return res.Conn, nil + } if res.primary { - primaryResult = res + primary = res } else { - fallbackResult = res + fallback = res } - // On success, cancel the other racer (if one exists.) - if res.error == nil && cancel != nil { - close(cancel) - cancel = nil + if primary.done && fallback.done { + return nil, primary.error } - // If the fallbackTimer was pending, then either we've canceled the - // fallback because we no longer want it, or we haven't canceled yet - // and therefore want it to wake up immediately. - if fallbackTimer.Stop() && cancel != nil { + if res.primary && fallbackTimer.Stop() { + // If we were able to stop the timer, that means it + // was running (hadn't yet started the fallback), but + // we just got an error on the primary path, so start + // the fallback immediately (in 0 nanoseconds). fallbackTimer.Reset(0) } } } - - // Return, in order of preference: - // 1. The primary connection (but close the other if we got both.) - // 2. The fallback connection. - // 3. The primary error. - if primaryResult.error == nil { - if fallbackResult.error == nil { - fallbackResult.Conn.Close() - } - return primaryResult.Conn, nil - } else if fallbackResult.error == nil { - return fallbackResult.Conn, nil - } else { - return nil, primaryResult.error - } -} - -type dialResult struct { - Conn - error - primary bool - done bool -} - -// dialSerialAsync runs dialSerial after some delay, and returns the -// resulting connection through a channel. When racing two connections, -// the primary goroutine uses a nil timer to omit the delay. -func dialSerialAsync(ctx *dialContext, ras addrList, timer *time.Timer, cancel <-chan struct{}, results chan<- dialResult) { - if timer != nil { - // We're in the fallback goroutine; sleep before connecting. - select { - case <-timer.C: - case <-cancel: - // dialSerial will immediately return errCanceled in this case. - } - } - c, err := dialSerial(ctx, ras, cancel) - results <- dialResult{Conn: c, error: err, primary: timer == nil, done: true} } // dialSerial connects to a list of addresses in sequence, returning // either the first successful connection, or the first error. -func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, error) { +func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error) { var firstErr error // The error from the first address is most relevant. for i, ra := range ras { select { - case <-cancel: - return nil, &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: errCanceled} + case <-ctx.Done(): + return nil, &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())} default: } - partialDeadline, err := partialDeadline(time.Now(), ctx.finalDeadline, len(ras)-i) + deadline, _ := ctx.Deadline() + partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i) if err != nil { // Ran out of time. if firstErr == nil { - firstErr = &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: err} + firstErr = &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: err} } break } - - // If this dial is canceled, the implementation is expected to complete - // quickly, but it's still possible that we could return a spurious Conn, - // which the caller must Close. - dialer := func(d time.Time) (Conn, error) { - return dialSingle(ctx, ra, d, cancel) + dialCtx := ctx + if partialDeadline.Before(deadline) { + var cancel context.CancelFunc + dialCtx, cancel = context.WithDeadline(ctx, partialDeadline) + defer cancel() } - c, err := dial(ctx.network, ra, dialer, partialDeadline) + + c, err := dialSingle(dialCtx, dp, ra) if err == nil { return c, nil } @@ -429,7 +466,7 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e } if firstErr == nil { - firstErr = &OpError{Op: "dial", Net: ctx.network, Source: nil, Addr: nil, Err: errMissingAddress} + firstErr = &OpError{Op: "dial", Net: dp.network, Source: nil, Addr: nil, Err: errMissingAddress} } return nil, firstErr } @@ -437,26 +474,26 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e // dialSingle attempts to establish and returns a single connection to // the destination address. This must be called through the OS-specific // dial function, because some OSes don't implement the deadline feature. -func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan struct{}) (c Conn, err error) { - la := ctx.LocalAddr +func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) { + la := dp.LocalAddr switch ra := ra.(type) { case *TCPAddr: la, _ := la.(*TCPAddr) - c, err = testHookDialTCP(ctx.network, la, ra, deadline, cancel) + c, err = dialTCP(ctx, dp.network, la, ra) case *UDPAddr: la, _ := la.(*UDPAddr) - c, err = dialUDP(ctx.network, la, ra, deadline) + c, err = dialUDP(ctx, dp.network, la, ra) case *IPAddr: la, _ := la.(*IPAddr) - c, err = dialIP(ctx.network, la, ra, deadline) + c, err = dialIP(ctx, dp.network, la, ra) case *UnixAddr: la, _ := la.(*UnixAddr) - c, err = dialUnix(ctx.network, la, ra, deadline) + c, err = dialUnix(ctx, dp.network, la, ra) default: - return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: ctx.address}} + return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: dp.address}} } if err != nil { - return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer + return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer } return c, nil } @@ -469,7 +506,7 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan str // instead of just the interface with the given host address. // See Dial for more details about address syntax. func Listen(net, laddr string) (Listener, error) { - addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline) + addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil) if err != nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err} } @@ -496,7 +533,7 @@ func Listen(net, laddr string) (Listener, error) { // instead of just the interface with the given host address. // See Dial for the syntax of laddr. func ListenPacket(net, laddr string) (PacketConn, error) { - addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline) + addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil) if err != nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err} } diff --git a/src/net/dial_test.go b/src/net/dial_test.go index d4f04e0a4f..eb145f476c 100644 --- a/src/net/dial_test.go +++ b/src/net/dial_test.go @@ -6,6 +6,7 @@ package net import ( "bufio" + "context" "internal/testenv" "io" "net/internal/socktest" @@ -193,18 +194,11 @@ const ( // In some environments, the slow IPs may be explicitly unreachable, and fail // more quickly than expected. This test hook prevents dialTCP from returning // before the deadline. -func slowDialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) { - c, err := dialTCP(net, laddr, raddr, deadline, cancel) +func slowDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) { + c, err := doDialTCP(ctx, net, laddr, raddr) if ParseIP(slowDst4).Equal(raddr.IP) || ParseIP(slowDst6).Equal(raddr.IP) { // Wait for the deadline, or indefinitely if none exists. - var wait <-chan time.Time - if !deadline.IsZero() { - wait = time.After(deadline.Sub(time.Now())) - } - select { - case <-cancel: - case <-wait: - } + <-ctx.Done() } return c, err } @@ -356,15 +350,14 @@ func TestDialParallel(t *testing.T) { d := Dialer{ FallbackDelay: fallbackDelay, } - ctx := &dialContext{ - Dialer: d, - network: "tcp", - address: "?", - finalDeadline: d.deadline(time.Now()), - } startTime := time.Now() - c, err := dialParallel(ctx, primaries, fallbacks, nil) - elapsed := time.Now().Sub(startTime) + dp := &dialParam{ + Dialer: d, + network: "tcp", + address: "?", + } + c, err := dialParallel(context.Background(), dp, primaries, fallbacks) + elapsed := time.Since(startTime) if c != nil { c.Close() @@ -385,16 +378,16 @@ func TestDialParallel(t *testing.T) { } // Repeat each case, ensuring that it can be canceled quickly. - cancel := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) var wg sync.WaitGroup wg.Add(1) go func() { time.Sleep(5 * time.Millisecond) - close(cancel) + cancel() wg.Done() }() startTime = time.Now() - c, err = dialParallel(ctx, primaries, fallbacks, cancel) + c, err = dialParallel(ctx, dp, primaries, fallbacks) if c != nil { c.Close() } @@ -406,7 +399,7 @@ func TestDialParallel(t *testing.T) { } } -func lookupSlowFast(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { +func lookupSlowFast(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) { switch host { case "slow6loopback4": // Returns a slow IPv6 address, and a local IPv4 address. @@ -415,7 +408,7 @@ func lookupSlowFast(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, e {IP: ParseIP("127.0.0.1")}, }, nil default: - return fn(host) + return fn(ctx, host) } } @@ -530,22 +523,24 @@ func TestDialParallelSpuriousConnection(t *testing.T) { origTestHookDialTCP := testHookDialTCP defer func() { testHookDialTCP = origTestHookDialTCP }() - testHookDialTCP = func(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) { + testHookDialTCP = func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) { // Sleep long enough for Happy Eyeballs to kick in, and inhibit cancelation. // This forces dialParallel to juggle two successful connections. time.Sleep(fallbackDelay * 2) - cancel = nil - return dialTCP(net, laddr, raddr, deadline, cancel) + + // Now ignore the provided context (which will be canceled) and use a + // different one to make sure this completes with a valid connection, + // which we hope to be closed below: + return doDialTCP(context.Background(), net, laddr, raddr) } d := Dialer{ FallbackDelay: fallbackDelay, } - ctx := &dialContext{ - Dialer: d, - network: "tcp", - address: "?", - finalDeadline: d.deadline(time.Now()), + dp := &dialParam{ + Dialer: d, + network: "tcp", + address: "?", } makeAddr := func(ip string) addrList { @@ -557,7 +552,7 @@ func TestDialParallelSpuriousConnection(t *testing.T) { } // dialParallel returns one connection (and closes the other.) - c, err := dialParallel(ctx, makeAddr("127.0.0.1"), makeAddr("::1"), nil) + c, err := dialParallel(context.Background(), dp, makeAddr("127.0.0.1"), makeAddr("::1")) if err != nil { t.Fatal(err) } diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go index cf0b2680db..914dd767d3 100644 --- a/src/net/dnsclient_unix.go +++ b/src/net/dnsclient_unix.go @@ -16,6 +16,7 @@ package net import ( + "context" "errors" "io" "math/rand" @@ -399,11 +400,11 @@ func (o hostLookupOrder) String() string { // Normally we let cgo use the C library resolver instead of // depending on our lookup code, so that Go and C get the same // answers. -func goLookupHost(name string) (addrs []string, err error) { - return goLookupHostOrder(name, hostLookupFilesDNS) +func goLookupHost(ctx context.Context, name string) (addrs []string, err error) { + return goLookupHostOrder(ctx, name, hostLookupFilesDNS) } -func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err error) { +func goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []string, err error) { if order == hostLookupFilesDNS || order == hostLookupFiles { // Use entries from /etc/hosts if they match. addrs = lookupStaticHost(name) @@ -411,7 +412,7 @@ func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err return } } - ips, err := goLookupIPOrder(name, order) + ips, err := goLookupIPOrder(ctx, name, order) if err != nil { return } @@ -437,11 +438,11 @@ func goLookupIPFiles(name string) (addrs []IPAddr) { // goLookupIP is the native Go implementation of LookupIP. // The libc versions are in cgo_*.go. -func goLookupIP(name string) (addrs []IPAddr, err error) { - return goLookupIPOrder(name, hostLookupFilesDNS) +func goLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error) { + return goLookupIPOrder(ctx, name, hostLookupFilesDNS) } -func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err error) { +func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, err error) { if order == hostLookupFilesDNS || order == hostLookupFiles { addrs = goLookupIPFiles(name) if len(addrs) > 0 || order == hostLookupFiles { diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go index edf7c00f72..145a3b6a33 100644 --- a/src/net/dnsclient_unix_test.go +++ b/src/net/dnsclient_unix_test.go @@ -7,6 +7,7 @@ package net import ( + "context" "fmt" "internal/testenv" "io/ioutil" @@ -133,7 +134,7 @@ func TestAvoidDNSName(t *testing.T) { // Issue 13705: don't try to resolve onion addresses, etc func TestLookupTorOnion(t *testing.T) { - addrs, err := goLookupIP("foo.onion") + addrs, err := goLookupIP(context.Background(), "foo.onion") if len(addrs) > 0 { t.Errorf("unexpected addresses: %v", addrs) } @@ -249,7 +250,7 @@ func TestUpdateResolvConf(t *testing.T) { for j := 0; j < N; j++ { go func(name string) { defer wg.Done() - ips, err := goLookupIP(name) + ips, err := goLookupIP(context.Background(), name) if err != nil { t.Error(err) return @@ -397,7 +398,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) { t.Error(err) continue } - addrs, err := goLookupIP(tt.name) + addrs, err := goLookupIP(context.Background(), tt.name) if err != nil { // This test uses external network connectivity. // We need to take care with errors on both @@ -447,14 +448,14 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) { name := fmt.Sprintf("order %v", order) // First ensure that we get an error when contacting a non-existent host. - _, err := goLookupIPOrder("notarealhost", order) + _, err := goLookupIPOrder(context.Background(), "notarealhost", order) if err == nil { t.Errorf("%s: expected error while looking up name not in hosts file", name) continue } // Now check that we get an address when the name appears in the hosts file. - addrs, err := goLookupIPOrder("thor", order) // entry is in "testdata/hosts" + addrs, err := goLookupIPOrder(context.Background(), "thor", order) // entry is in "testdata/hosts" if err != nil { t.Errorf("%s: expected to successfully lookup host entry", name) continue @@ -510,7 +511,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { return r, nil } - _, err = goLookupIP(fqdn) + _, err = goLookupIP(context.Background(), fqdn) if err == nil { t.Fatal("expected an error") } @@ -523,17 +524,19 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { func BenchmarkGoLookupIP(b *testing.B) { testHookUninstaller.Do(uninstallTestHooks) + ctx := context.Background() for i := 0; i < b.N; i++ { - goLookupIP("www.example.com") + goLookupIP(ctx, "www.example.com") } } func BenchmarkGoLookupIPNoSuchHost(b *testing.B) { testHookUninstaller.Do(uninstallTestHooks) + ctx := context.Background() for i := 0; i < b.N; i++ { - goLookupIP("some.nonexistent") + goLookupIP(ctx, "some.nonexistent") } } @@ -553,9 +556,10 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { if err := conf.writeAndUpdate(lines); err != nil { b.Fatal(err) } + ctx := context.Background() for i := 0; i < b.N; i++ { - goLookupIP("www.example.com") + goLookupIP(ctx, "www.example.com") } } diff --git a/src/net/error_test.go b/src/net/error_test.go index 31cc0e5055..9f496d7d2d 100644 --- a/src/net/error_test.go +++ b/src/net/error_test.go @@ -5,6 +5,7 @@ package net import ( + "context" "fmt" "io" "io/ioutil" @@ -138,7 +139,7 @@ func TestDialError(t *testing.T) { origTestHookLookupIP := testHookLookupIP defer func() { testHookLookupIP = origTestHookLookupIP }() - testHookLookupIP = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { + testHookLookupIP = func(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) { return nil, &DNSError{Err: "dial error test", Name: "name", Server: "server", IsTimeout: true} } sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) { @@ -283,7 +284,7 @@ func TestListenError(t *testing.T) { origTestHookLookupIP := testHookLookupIP defer func() { testHookLookupIP = origTestHookLookupIP }() - testHookLookupIP = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { + testHookLookupIP = func(_ context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) { return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true} } sw.Set(socktest.FilterListen, func(so *socktest.Status) (socktest.AfterFilter, error) { @@ -343,7 +344,7 @@ func TestListenPacketError(t *testing.T) { origTestHookLookupIP := testHookLookupIP defer func() { testHookLookupIP = origTestHookLookupIP }() - testHookLookupIP = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { + testHookLookupIP = func(_ context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) { return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true} } diff --git a/src/net/fd_unix.go b/src/net/fd_unix.go index d47b4bef99..7ef10702ed 100644 --- a/src/net/fd_unix.go +++ b/src/net/fd_unix.go @@ -7,12 +7,12 @@ package net import ( + "context" "io" "os" "runtime" "sync/atomic" "syscall" - "time" ) // Network file descriptor. @@ -36,10 +36,6 @@ type netFD struct { func sysInit() { } -func dial(network string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { - return dialer(deadline) -} - func newFD(sysfd, family, sotype int, net string) (*netFD, error) { return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net}, nil } @@ -68,15 +64,17 @@ func (fd *netFD) name() string { return fd.net + ":" + ls + "->" + rs } -func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-chan struct{}) error { +func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error { // Do not need to call fd.writeLock here, // because fd is not yet accessible to user, // so no concurrent operations are possible. switch err := connectFunc(fd.sysfd, ra); err { case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR: case nil, syscall.EISCONN: - if !deadline.IsZero() && deadline.Before(time.Now()) { - return errTimeout + select { + case <-ctx.Done(): + return mapErr(ctx.Err()) + default: } if err := fd.init(); err != nil { return err @@ -98,27 +96,27 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-c if err := fd.init(); err != nil { return err } - if !deadline.IsZero() { + if deadline, _ := ctx.Deadline(); !deadline.IsZero() { fd.setWriteDeadline(deadline) defer fd.setWriteDeadline(noDeadline) } - if cancel != nil { - done := make(chan bool) - defer func() { - // This is unbuffered; wait for the goroutine before returning. - done <- true - }() - go func() { - select { - case <-cancel: - // Force the runtime's poller to immediately give - // up waiting for writability. - fd.setWriteDeadline(aLongTimeAgo) - <-done - case <-done: - } - }() - } + + // Wait for the goroutine converting context.Done into a write timeout + // to exist, otherwise our caller might cancel the context and + // cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial. + done := make(chan bool) // must be unbuffered + defer func() { done <- true }() + go func() { + select { + case <-ctx.Done(): + // Force the runtime's poller to immediately give + // up waiting for writability. + fd.setWriteDeadline(aLongTimeAgo) + <-done + case <-done: + } + }() + for { // Performing multiple connect system calls on a // non-blocking socket under Unix variants does not @@ -130,8 +128,8 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-c // details. if err := fd.pd.waitWrite(); err != nil { select { - case <-cancel: - return errCanceled + case <-ctx.Done(): + return mapErr(ctx.Err()) default: } return err diff --git a/src/net/fd_windows.go b/src/net/fd_windows.go index 100994525e..d1d91a6a5c 100644 --- a/src/net/fd_windows.go +++ b/src/net/fd_windows.go @@ -5,6 +5,7 @@ package net import ( + "context" "internal/race" "os" "runtime" @@ -320,14 +321,14 @@ func (fd *netFD) setAddr(laddr, raddr Addr) { runtime.SetFinalizer(fd, (*netFD).Close) } -func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-chan struct{}) error { +func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error { // Do not need to call fd.writeLock here, // because fd is not yet accessible to user, // so no concurrent operations are possible. if err := fd.init(); err != nil { return err } - if !deadline.IsZero() { + if deadline, _ := ctx.Deadline(); !deadline.IsZero() { fd.setWriteDeadline(deadline) defer fd.setWriteDeadline(noDeadline) } @@ -351,30 +352,30 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-c // Call ConnectEx API. o := &fd.wop o.sa = ra - if cancel != nil { - done := make(chan bool) - defer func() { - // This is unbuffered; wait for the goroutine before returning. - done <- true - }() - go func() { - select { - case <-cancel: - // Force the runtime's poller to immediately give - // up waiting for writability. - fd.setWriteDeadline(aLongTimeAgo) - <-done - case <-done: - } - }() - } + + // Wait for the goroutine converting context.Done into a write timeout + // to exist, otherwise our caller might cancel the context and + // cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial. + done := make(chan bool) // must be unbuffered + defer func() { done <- true }() + go func() { + select { + case <-ctx.Done(): + // Force the runtime's poller to immediately give + // up waiting for writability. + fd.setWriteDeadline(aLongTimeAgo) + <-done + case <-done: + } + }() + _, err := wsrv.ExecIO(o, "ConnectEx", func(o *operation) error { return connectExFunc(o.fd.sysfd, o.sa, nil, 0, nil, &o.o) }) if err != nil { select { - case <-cancel: - return errCanceled + case <-ctx.Done(): + return mapErr(ctx.Err()) default: if _, ok := err.(syscall.Errno); ok { err = os.NewSyscallError("connectex", err) diff --git a/src/net/hook.go b/src/net/hook.go index 9ab34c0e36..d7316ea438 100644 --- a/src/net/hook.go +++ b/src/net/hook.go @@ -4,9 +4,19 @@ package net +import "context" + var ( - testHookDialTCP = dialTCP - testHookHostsPath = "/etc/hosts" - testHookLookupIP = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { return fn(host) } + // if non-nil, overrides dialTCP. + testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) + + testHookHostsPath = "/etc/hosts" + testHookLookupIP = func( + ctx context.Context, + fn func(context.Context, string) ([]IPAddr, error), + host string, + ) ([]IPAddr, error) { + return fn(ctx, host) + } testHookSetKeepAlive = func() {} ) diff --git a/src/net/iprawsock.go b/src/net/iprawsock.go index 41cfb2311a..f4a4de82fc 100644 --- a/src/net/iprawsock.go +++ b/src/net/iprawsock.go @@ -4,7 +4,10 @@ package net -import "syscall" +import ( + "context" + "syscall" +) // IPAddr represents the address of an IP end point. type IPAddr struct { @@ -56,7 +59,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, error) { default: return nil, UnknownNetworkError(net) } - addrs, err := internetAddrList(afnet, addr, noDeadline) + addrs, err := internetAddrList(context.Background(), afnet, addr) if err != nil { return nil, err } @@ -171,7 +174,7 @@ func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} } // netProto, which must be "ip", "ip4", or "ip6" followed by a colon // and a protocol number or name. func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) { - c, err := dialIP(netProto, laddr, raddr, noDeadline) + c, err := dialIP(context.Background(), netProto, laddr, raddr) if err != nil { return nil, &OpError{Op: "dial", Net: netProto, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } @@ -183,7 +186,7 @@ func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) { // methods can be used to receive and send IP packets with per-packet // addressing. func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) { - c, err := listenIP(netProto, laddr) + c, err := listenIP(context.Background(), netProto, laddr) if err != nil { return nil, &OpError{Op: "listen", Net: netProto, Source: nil, Addr: laddr.opAddr(), Err: err} } diff --git a/src/net/iprawsock_plan9.go b/src/net/iprawsock_plan9.go index e08f271e9b..6aebea169c 100644 --- a/src/net/iprawsock_plan9.go +++ b/src/net/iprawsock_plan9.go @@ -5,8 +5,8 @@ package net import ( + "context" "syscall" - "time" ) func (c *IPConn) readFrom(b []byte) (int, *IPAddr, error) { @@ -25,10 +25,10 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error) return 0, 0, syscall.EPLAN9 } -func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, error) { +func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) { return nil, syscall.EPLAN9 } -func listenIP(netProto string, laddr *IPAddr) (*IPConn, error) { +func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) { return nil, syscall.EPLAN9 } diff --git a/src/net/iprawsock_posix.go b/src/net/iprawsock_posix.go index b959afe6b4..68dc307b60 100644 --- a/src/net/iprawsock_posix.go +++ b/src/net/iprawsock_posix.go @@ -7,8 +7,8 @@ package net import ( + "context" "syscall" - "time" ) // BUG(mikio): On every POSIX platform, reads from the "ip4" network @@ -120,7 +120,7 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error) return c.fd.writeMsg(b, oob, sa) } -func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, error) { +func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) { network, proto, err := parseNetwork(netProto) if err != nil { return nil, err @@ -133,14 +133,14 @@ func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, if raddr == nil { return nil, errMissingAddress } - fd, err := internetSocket(network, laddr, raddr, deadline, syscall.SOCK_RAW, proto, "dial", noCancel) + fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial") if err != nil { return nil, err } return newIPConn(fd), nil } -func listenIP(netProto string, laddr *IPAddr) (*IPConn, error) { +func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) { network, proto, err := parseNetwork(netProto) if err != nil { return nil, err @@ -150,7 +150,7 @@ func listenIP(netProto string, laddr *IPAddr) (*IPConn, error) { default: return nil, UnknownNetworkError(netProto) } - fd, err := internetSocket(network, laddr, nil, noDeadline, syscall.SOCK_RAW, proto, "listen", noCancel) + fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen") if err != nil { return nil, err } diff --git a/src/net/ipsock.go b/src/net/ipsock.go index dc13c17439..24daf173ac 100644 --- a/src/net/ipsock.go +++ b/src/net/ipsock.go @@ -6,7 +6,9 @@ package net -import "time" +import ( + "context" +) var ( // supportsIPv4 reports whether the platform supports IPv4 @@ -188,7 +190,7 @@ func JoinHostPort(host, port string) string { // address or a DNS name, and returns a list of internet protocol // family addresses. The result contains at least one address when // error is nil. -func internetAddrList(net, addr string, deadline time.Time) (addrList, error) { +func internetAddrList(ctx context.Context, net, addr string) (addrList, error) { var ( err error host, port string @@ -236,7 +238,7 @@ func internetAddrList(net, addr string, deadline time.Time) (addrList, error) { return addrList{inetaddr(IPAddr{IP: ip, Zone: zone})}, nil } // Try as a DNS name. - ips, err := lookupIPDeadline(host, deadline) + ips, err := lookupIPContext(ctx, host) if err != nil { return nil, err } diff --git a/src/net/ipsock_posix.go b/src/net/ipsock_posix.go index 644964e78d..abe90ac0e6 100644 --- a/src/net/ipsock_posix.go +++ b/src/net/ipsock_posix.go @@ -7,9 +7,9 @@ package net import ( + "context" "runtime" "syscall" - "time" ) // BUG(rsc,mikio): On DragonFly BSD and OpenBSD, listening on the @@ -152,9 +152,10 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family return syscall.AF_INET6, false } -func internetSocket(net string, laddr, raddr sockaddr, deadline time.Time, sotype, proto int, mode string, cancel <-chan struct{}) (fd *netFD, err error) { +// Internet sockets (TCP, UDP, IP) +func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string) (fd *netFD, err error) { family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode) - return socket(net, family, sotype, proto, ipv6only, laddr, raddr, deadline, cancel) + return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr) } func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) { diff --git a/src/net/lookup.go b/src/net/lookup.go index ab6886ddff..0d3ef79bab 100644 --- a/src/net/lookup.go +++ b/src/net/lookup.go @@ -5,8 +5,8 @@ package net import ( + "context" "internal/singleflight" - "time" ) // protocols contains minimal mappings between internet protocol @@ -33,7 +33,7 @@ func LookupHost(host string) (addrs []string, err error) { if ip := ParseIP(host); ip != nil { return []string{host}, nil } - return lookupHost(host) + return lookupHost(context.Background(), host) } // LookupIP looks up host using the local resolver. @@ -47,7 +47,7 @@ func LookupIP(host string) (ips []IP, err error) { if ip := ParseIP(host); ip != nil { return []IP{ip}, nil } - addrs, err := lookupIPMerge(host) + addrs, err := lookupIPMerge(context.Background(), host) if err != nil { return } @@ -63,9 +63,9 @@ var lookupGroup singleflight.Group // lookupIPMerge wraps lookupIP, but makes sure that for any given // host, only one lookup is in-flight at a time. The returned memory // is always owned by the caller. -func lookupIPMerge(host string) (addrs []IPAddr, err error) { +func lookupIPMerge(ctx context.Context, host string) (addrs []IPAddr, err error) { addrsi, err, shared := lookupGroup.Do(host, func() (interface{}, error) { - return testHookLookupIP(lookupIP, host) + return testHookLookupIP(ctx, lookupIP, host) }) return lookupIPReturn(addrsi, err, shared) } @@ -85,37 +85,26 @@ func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error return addrs, nil } -// lookupIPDeadline looks up a hostname with a deadline. -func lookupIPDeadline(host string, deadline time.Time) (addrs []IPAddr, err error) { - if deadline.IsZero() { - return lookupIPMerge(host) - } - - // We could push the deadline down into the name resolution - // functions. However, the most commonly used implementation - // calls getaddrinfo, which has no timeout. - - timeout := deadline.Sub(time.Now()) - if timeout <= 0 { - return nil, errTimeout - } - t := time.NewTimer(timeout) - defer t.Stop() +// lookupIPContext looks up a hostname with a context. +func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err error) { + // TODO(bradfitz): when adding trace hooks later here, make + // sure the tracing is done outside of the singleflight + // merging. Both callers should see the DNS lookup delay, even + // if it's only being done once. The r.Shared bit can be + // included in the trace for callers who need it. ch := lookupGroup.DoChan(host, func() (interface{}, error) { - return testHookLookupIP(lookupIP, host) + return testHookLookupIP(ctx, lookupIP, host) }) select { - case <-t.C: + case <-ctx.Done(): // The DNS lookup timed out for some reason. Force // future requests to start the DNS lookup again // rather than waiting for the current lookup to // complete. See issue 8602. lookupGroup.Forget(host) - - return nil, errTimeout - + return nil, mapErr(ctx.Err()) case r := <-ch: return lookupIPReturn(r.Val, r.Err, r.Shared) } diff --git a/src/net/lookup_plan9.go b/src/net/lookup_plan9.go index 34ee5354a5..4224263602 100644 --- a/src/net/lookup_plan9.go +++ b/src/net/lookup_plan9.go @@ -5,6 +5,7 @@ package net import ( + "context" "errors" "os" ) @@ -115,7 +116,7 @@ func lookupProtocol(name string) (proto int, err error) { return 0, UnknownNetworkError(name) } -func lookupHost(host string) (addrs []string, err error) { +func lookupHost(ctx context.Context, host string) (addrs []string, err error) { // Use netdir/cs instead of netdir/dns because cs knows about // host names in local network (e.g. from /lib/ndb/local) lines, err := queryCS("net", host, "1") @@ -146,7 +147,8 @@ loop: return } -func lookupIP(host string) (addrs []IPAddr, err error) { +func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { + // TODO(bradfitz): push down ctx lits, err := LookupHost(host) if err != nil { return diff --git a/src/net/lookup_stub.go b/src/net/lookup_stub.go index a8625905e4..38a4f0bae4 100644 --- a/src/net/lookup_stub.go +++ b/src/net/lookup_stub.go @@ -6,17 +6,20 @@ package net -import "syscall" +import ( + "context" + "syscall" +) func lookupProtocol(name string) (proto int, err error) { return 0, syscall.ENOPROTOOPT } -func lookupHost(host string) (addrs []string, err error) { +func lookupHost(ctx context.Context, host string) (addrs []string, err error) { return nil, syscall.ENOPROTOOPT } -func lookupIP(host string) (addrs []IPAddr, err error) { +func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { return nil, syscall.ENOPROTOOPT } diff --git a/src/net/lookup_test.go b/src/net/lookup_test.go index 1345751cfd..85bcfef6e9 100644 --- a/src/net/lookup_test.go +++ b/src/net/lookup_test.go @@ -6,6 +6,7 @@ package net import ( "bytes" + "context" "fmt" "internal/testenv" "runtime" @@ -14,7 +15,7 @@ import ( "time" ) -func lookupLocalhost(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { +func lookupLocalhost(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) { switch host { case "localhost": return []IPAddr{ @@ -22,7 +23,7 @@ func lookupLocalhost(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, {IP: IPv6loopback}, }, nil default: - return fn(host) + return fn(ctx, host) } } @@ -375,15 +376,20 @@ func TestLookupIPDeadline(t *testing.T) { const N = 5000 const timeout = 3 * time.Second + ctxHalfTimeout, cancel := context.WithTimeout(context.Background(), timeout/2) + defer cancel() + ctxTimeout, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + c := make(chan error, 2*N) for i := 0; i < N; i++ { name := fmt.Sprintf("%d.net-test.golang.org", i) go func() { - _, err := lookupIPDeadline(name, time.Now().Add(timeout/2)) + _, err := lookupIPContext(ctxHalfTimeout, name) c <- err }() go func() { - _, err := lookupIPDeadline(name, time.Now().Add(timeout)) + _, err := lookupIPContext(ctxTimeout, name) c <- err }() } diff --git a/src/net/lookup_unix.go b/src/net/lookup_unix.go index cd4ddbdb24..8d3fa47782 100644 --- a/src/net/lookup_unix.go +++ b/src/net/lookup_unix.go @@ -6,7 +6,10 @@ package net -import "sync" +import ( + "context" + "sync" +) var onceReadProtocols sync.Once @@ -49,7 +52,7 @@ func lookupProtocol(name string) (int, error) { return proto, nil } -func lookupHost(host string) (addrs []string, err error) { +func lookupHost(ctx context.Context, host string) (addrs []string, err error) { order := systemConf().hostLookupOrder(host) if order == hostLookupCgo { if addrs, err, ok := cgoLookupHost(host); ok { @@ -58,19 +61,20 @@ func lookupHost(host string) (addrs []string, err error) { // cgo not available (or netgo); fall back to Go's DNS resolver order = hostLookupFilesDNS } - return goLookupHostOrder(host, order) + return goLookupHostOrder(ctx, host, order) } -func lookupIP(host string) (addrs []IPAddr, err error) { +func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { order := systemConf().hostLookupOrder(host) if order == hostLookupCgo { + // TODO(bradfitz): push down ctx, or at least its deadline to start if addrs, err, ok := cgoLookupIP(host); ok { return addrs, err } // cgo not available (or netgo); fall back to Go's DNS resolver order = hostLookupFilesDNS } - return goLookupIPOrder(host, order) + return goLookupIPOrder(ctx, host, order) } func lookupPort(network, service string) (int, error) { diff --git a/src/net/lookup_windows.go b/src/net/lookup_windows.go index 13edc264e8..ce012ba873 100644 --- a/src/net/lookup_windows.go +++ b/src/net/lookup_windows.go @@ -5,6 +5,7 @@ package net import ( + "context" "os" "runtime" "syscall" @@ -51,8 +52,8 @@ func lookupProtocol(name string) (int, error) { return r.proto, r.err } -func lookupHost(name string) ([]string, error) { - ips, err := LookupIP(name) +func lookupHost(ctx context.Context, name string) ([]string, error) { + ips, err := lookupIP(ctx, name) if err != nil { return nil, err } @@ -83,59 +84,97 @@ func gethostbyname(name string) (addrs []IPAddr, err error) { return addrs, nil } -func oldLookupIP(name string) ([]IPAddr, error) { +func oldLookupIP(ctx context.Context, name string) ([]IPAddr, error) { // GetHostByName return value is stored in thread local storage. // Start new os thread before the call to prevent races. - type result struct { + type ret struct { addrs []IPAddr err error } - ch := make(chan result) + ch := make(chan ret, 1) go func() { acquireThread() defer releaseThread() runtime.LockOSThread() defer runtime.UnlockOSThread() addrs, err := gethostbyname(name) - ch <- result{addrs: addrs, err: err} + ch <- ret{addrs: addrs, err: err} }() - r := <-ch - if r.err != nil { - r.err = &DNSError{Err: r.err.Error(), Name: name} + select { + case r := <-ch: + if r.err != nil { + r.err = &DNSError{Err: r.err.Error(), Name: name} + } + return r.addrs, r.err + case <-ctx.Done(): + // TODO(bradfitz,brainman): cancel the ongoing + // gethostbyname? For now we just let it finish and + // write to the buffered channel. + return nil, &DNSError{ + Name: name, + Err: ctx.Err().Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } } - return r.addrs, r.err } -func newLookupIP(name string) ([]IPAddr, error) { - acquireThread() - defer releaseThread() - hints := syscall.AddrinfoW{ - Family: syscall.AF_UNSPEC, - Socktype: syscall.SOCK_STREAM, - Protocol: syscall.IPPROTO_IP, - } - var result *syscall.AddrinfoW - e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result) - if e != nil { - return nil, &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: name} +func newLookupIP(ctx context.Context, name string) ([]IPAddr, error) { + // TODO(bradfitz,brainman): use ctx? + + type ret struct { + addrs []IPAddr + err error } - defer syscall.FreeAddrInfoW(result) - addrs := make([]IPAddr, 0, 5) - for ; result != nil; result = result.Next { - addr := unsafe.Pointer(result.Addr) - switch result.Family { - case syscall.AF_INET: - a := (*syscall.RawSockaddrInet4)(addr).Addr - addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])}) - case syscall.AF_INET6: - a := (*syscall.RawSockaddrInet6)(addr).Addr - zone := zoneToString(int((*syscall.RawSockaddrInet6)(addr).Scope_id)) - addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone}) - default: - return nil, &DNSError{Err: syscall.EWINDOWS.Error(), Name: name} + ch := make(chan ret, 1) + go func() { + acquireThread() + defer releaseThread() + hints := syscall.AddrinfoW{ + Family: syscall.AF_UNSPEC, + Socktype: syscall.SOCK_STREAM, + Protocol: syscall.IPPROTO_IP, + } + var result *syscall.AddrinfoW + e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result) + if e != nil { + ch <- ret{err: &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: name}} + } + defer syscall.FreeAddrInfoW(result) + addrs := make([]IPAddr, 0, 5) + for ; result != nil; result = result.Next { + addr := unsafe.Pointer(result.Addr) + switch result.Family { + case syscall.AF_INET: + a := (*syscall.RawSockaddrInet4)(addr).Addr + addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])}) + case syscall.AF_INET6: + a := (*syscall.RawSockaddrInet6)(addr).Addr + zone := zoneToString(int((*syscall.RawSockaddrInet6)(addr).Scope_id)) + addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone}) + default: + ch <- ret{err: &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}} + } + } + ch <- ret{addrs: addrs} + }() + select { + case r := <-ch: + return r.addrs, r.err + case <-ctx.Done(): + // TODO(bradfitz,brainman): cancel the ongoing + // GetAddrInfoW? It would require conditionally using + // GetAddrInfoEx with lpOverlapped, which requires + // Windows 8 or newer. I guess we'll need oldLookupIP, + // newLookupIP, and newerLookUP. + // + // For now we just let it finish and write to the + // buffered channel. + return nil, &DNSError{ + Name: name, + Err: ctx.Err().Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, } } - return addrs, nil } func getservbyname(network, service string) (int, error) { diff --git a/src/net/net.go b/src/net/net.go index 3b37b336d1..27e9ca367d 100644 --- a/src/net/net.go +++ b/src/net/net.go @@ -79,6 +79,7 @@ On Windows, the resolver always uses C library functions, such as GetAddrInfo an package net import ( + "context" "errors" "io" "os" @@ -377,6 +378,22 @@ var ( ErrWriteToConnected = errors.New("use of WriteTo with pre-connected connection") ) +// mapErr maps from the context errors to the historical internal net +// error values. +// +// TODO(bradfitz): get rid of this after adjusting tests and making +// context.DeadlineExceeded implement net.Error? +func mapErr(err error) error { + switch err { + case context.Canceled: + return errCanceled + case context.DeadlineExceeded: + return errTimeout + default: + return err + } +} + // OpError is the error type usually returned by functions in the net // package. It describes the operation, network type, and address of // an error. diff --git a/src/net/netgo_unix_test.go b/src/net/netgo_unix_test.go index 1d950d67d7..0a118874c2 100644 --- a/src/net/netgo_unix_test.go +++ b/src/net/netgo_unix_test.go @@ -7,7 +7,10 @@ package net -import "testing" +import ( + "context" + "testing" +) func TestGoLookupIP(t *testing.T) { host := "localhost" @@ -18,7 +21,7 @@ func TestGoLookupIP(t *testing.T) { if err != nil { t.Error(err) } - if _, err := goLookupIP(host); err != nil { + if _, err := goLookupIP(context.Background(), host); err != nil { t.Error(err) } } diff --git a/src/net/sock_posix.go b/src/net/sock_posix.go index 3dddfef4c5..c3af27b596 100644 --- a/src/net/sock_posix.go +++ b/src/net/sock_posix.go @@ -7,9 +7,9 @@ package net import ( + "context" "os" "syscall" - "time" ) // A sockaddr represents a TCP, UDP, IP or Unix network endpoint @@ -34,7 +34,7 @@ type sockaddr interface { // socket returns a network file descriptor that is ready for // asynchronous I/O using the network poller. -func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, deadline time.Time, cancel <-chan struct{}) (fd *netFD, err error) { +func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr) (fd *netFD, err error) { s, err := sysSocket(family, sotype, proto) if err != nil { return nil, err @@ -86,7 +86,7 @@ func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr s return fd, nil } } - if err := fd.dial(laddr, raddr, deadline, cancel); err != nil { + if err := fd.dial(ctx, laddr, raddr); err != nil { fd.Close() return nil, err } @@ -117,7 +117,7 @@ func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr { return func(syscall.Sockaddr) Addr { return nil } } -func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time, cancel <-chan struct{}) error { +func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr) error { var err error var lsa syscall.Sockaddr if laddr != nil { @@ -134,7 +134,7 @@ func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time, cancel <-chan s if rsa, err = raddr.sockaddr(fd.family); err != nil { return err } - if err := fd.connect(lsa, rsa, deadline, cancel); err != nil { + if err := fd.connect(ctx, lsa, rsa); err != nil { return err } fd.isConnected = true diff --git a/src/net/tcpsock.go b/src/net/tcpsock.go index a5c3515c19..7cffcc58cb 100644 --- a/src/net/tcpsock.go +++ b/src/net/tcpsock.go @@ -5,6 +5,7 @@ package net import ( + "context" "io" "os" "syscall" @@ -60,7 +61,7 @@ func ResolveTCPAddr(net, addr string) (*TCPAddr, error) { default: return nil, UnknownNetworkError(net) } - addrs, err := internetAddrList(net, addr, noDeadline) + addrs, err := internetAddrList(context.Background(), net, addr) if err != nil { return nil, err } @@ -186,7 +187,7 @@ func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) { if raddr == nil { return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress} } - c, err := dialTCP(net, laddr, raddr, noDeadline, noCancel) + c, err := dialTCP(context.Background(), net, laddr, raddr) if err != nil { return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } @@ -285,7 +286,7 @@ func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) { if laddr == nil { laddr = &TCPAddr{} } - ln, err := listenTCP(net, laddr) + ln, err := listenTCP(context.Background(), net, laddr) if err != nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err} } diff --git a/src/net/tcpsock_plan9.go b/src/net/tcpsock_plan9.go index 698b834295..dd36c70d50 100644 --- a/src/net/tcpsock_plan9.go +++ b/src/net/tcpsock_plan9.go @@ -5,17 +5,24 @@ package net import ( + "context" "io" "os" - "time" ) func (c *TCPConn) readFrom(r io.Reader) (int64, error) { return genericReadFrom(c, r) } -func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) { - if !deadline.IsZero() { +func dialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) { + if testHookDialTCP != nil { + return testHookDialTCP(ctx, net, laddr, raddr) + } + return doDialTCP(ctx, net, laddr, raddr) +} + +func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) { + if d, _ := ctx.Deadline(); !d.IsZero() { panic("net.dialTCP: deadline not implemented on Plan 9") } // TODO(bradfitz,0intro): also use the cancel channel. @@ -63,7 +70,7 @@ func (ln *TCPListener) file() (*os.File, error) { return f, nil } -func listenTCP(network string, laddr *TCPAddr) (*TCPListener, error) { +func listenTCP(ctx context.Context, network string, laddr *TCPAddr) (*TCPListener, error) { fd, err := listenPlan9(network, laddr) if err != nil { return nil, err diff --git a/src/net/tcpsock_posix.go b/src/net/tcpsock_posix.go index 3902565f73..c9a8b6808e 100644 --- a/src/net/tcpsock_posix.go +++ b/src/net/tcpsock_posix.go @@ -7,10 +7,10 @@ package net import ( + "context" "io" "os" "syscall" - "time" ) func sockaddrToTCP(sa syscall.Sockaddr) Addr { @@ -47,8 +47,15 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) { return genericReadFrom(c, r) } -func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) { - fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", cancel) +func dialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) { + if testHookDialTCP != nil { + return testHookDialTCP(ctx, net, laddr, raddr) + } + return doDialTCP(ctx, net, laddr, raddr) +} + +func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) { + fd, err := internetSocket(ctx, net, laddr, raddr, syscall.SOCK_STREAM, 0, "dial") // TCP has a rarely used mechanism called a 'simultaneous connection' in // which Dial("tcp", addr1, addr2) run on the machine at addr1 can @@ -78,7 +85,7 @@ func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-cha if err == nil { fd.Close() } - fd, err = internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", cancel) + fd, err = internetSocket(ctx, net, laddr, raddr, syscall.SOCK_STREAM, 0, "dial") } if err != nil { @@ -141,8 +148,8 @@ func (ln *TCPListener) file() (*os.File, error) { return f, nil } -func listenTCP(network string, laddr *TCPAddr) (*TCPListener, error) { - fd, err := internetSocket(network, laddr, nil, noDeadline, syscall.SOCK_STREAM, 0, "listen", noCancel) +func listenTCP(ctx context.Context, network string, laddr *TCPAddr) (*TCPListener, error) { + fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_STREAM, 0, "listen") if err != nil { return nil, err } diff --git a/src/net/udpsock.go b/src/net/udpsock.go index e7e9796668..980f67c81f 100644 --- a/src/net/udpsock.go +++ b/src/net/udpsock.go @@ -4,7 +4,10 @@ package net -import "syscall" +import ( + "context" + "syscall" +) // UDPAddr represents the address of a UDP end point. type UDPAddr struct { @@ -55,7 +58,7 @@ func ResolveUDPAddr(net, addr string) (*UDPAddr, error) { default: return nil, UnknownNetworkError(net) } - addrs, err := internetAddrList(net, addr, noDeadline) + addrs, err := internetAddrList(context.Background(), net, addr) if err != nil { return nil, err } @@ -181,7 +184,7 @@ func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) { if raddr == nil { return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress} } - c, err := dialUDP(net, laddr, raddr, noDeadline) + c, err := dialUDP(context.Background(), net, laddr, raddr) if err != nil { return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } @@ -204,7 +207,7 @@ func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) { if laddr == nil { laddr = &UDPAddr{} } - c, err := listenUDP(net, laddr) + c, err := listenUDP(context.Background(), net, laddr) if err != nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err} } @@ -231,7 +234,7 @@ func ListenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPCon if gaddr == nil || gaddr.IP == nil { return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: errMissingAddress} } - c, err := listenMulticastUDP(network, ifi, gaddr) + c, err := listenMulticastUDP(context.Background(), network, ifi, gaddr) if err != nil { return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: err} } diff --git a/src/net/udpsock_plan9.go b/src/net/udpsock_plan9.go index 5f15427064..81edaf59fe 100644 --- a/src/net/udpsock_plan9.go +++ b/src/net/udpsock_plan9.go @@ -5,10 +5,10 @@ package net import ( + "context" "errors" "os" "syscall" - "time" ) func (c *UDPConn) readFrom(b []byte) (n int, addr *UDPAddr, err error) { @@ -55,8 +55,8 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error return 0, 0, syscall.EPLAN9 } -func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) { - if !deadline.IsZero() { +func dialUDP(ctx context.Context, net string, laddr, raddr *UDPAddr) (*UDPConn, error) { + if deadline, _ := ctx.Deadline(); !deadline.IsZero() { panic("net.dialUDP: deadline not implemented on Plan 9") } fd, err := dialPlan9(net, laddr, raddr) @@ -94,7 +94,7 @@ func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) { return h, b } -func listenUDP(network string, laddr *UDPAddr) (*UDPConn, error) { +func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, error) { l, err := listenPlan9(network, laddr) if err != nil { return nil, err @@ -111,6 +111,6 @@ func listenUDP(network string, laddr *UDPAddr) (*UDPConn, error) { return newUDPConn(fd), err } -func listenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) { +func listenMulticastUDP(ctx context.Context, network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) { return nil, syscall.EPLAN9 } diff --git a/src/net/udpsock_posix.go b/src/net/udpsock_posix.go index 4d3255c996..4924801ebb 100644 --- a/src/net/udpsock_posix.go +++ b/src/net/udpsock_posix.go @@ -7,8 +7,8 @@ package net import ( + "context" "syscall" - "time" ) func sockaddrToUDP(sa syscall.Sockaddr) Addr { @@ -90,24 +90,24 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error return c.fd.writeMsg(b, oob, sa) } -func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) { - fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_DGRAM, 0, "dial", noCancel) +func dialUDP(ctx context.Context, net string, laddr, raddr *UDPAddr) (*UDPConn, error) { + fd, err := internetSocket(ctx, net, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial") if err != nil { return nil, err } return newUDPConn(fd), nil } -func listenUDP(network string, laddr *UDPAddr) (*UDPConn, error) { - fd, err := internetSocket(network, laddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", noCancel) +func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, error) { + fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen") if err != nil { return nil, err } return newUDPConn(fd), nil } -func listenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) { - fd, err := internetSocket(network, gaddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", noCancel) +func listenMulticastUDP(ctx context.Context, network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) { + fd, err := internetSocket(ctx, network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen") if err != nil { return nil, err } diff --git a/src/net/unixsock.go b/src/net/unixsock.go index d1eb0b62ee..bacdaa41d9 100644 --- a/src/net/unixsock.go +++ b/src/net/unixsock.go @@ -5,6 +5,7 @@ package net import ( + "context" "os" "syscall" "time" @@ -188,7 +189,7 @@ func DialUnix(net string, laddr, raddr *UnixAddr) (*UnixConn, error) { default: return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(net)} } - c, err := dialUnix(net, laddr, raddr, noDeadline) + c, err := dialUnix(context.Background(), net, laddr, raddr) if err != nil { return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } @@ -290,7 +291,7 @@ func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) { if laddr == nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: errMissingAddress} } - ln, err := listenUnix(net, laddr) + ln, err := listenUnix(context.Background(), net, laddr) if err != nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err} } @@ -310,7 +311,7 @@ func ListenUnixgram(net string, laddr *UnixAddr) (*UnixConn, error) { if laddr == nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: errMissingAddress} } - c, err := listenUnixgram(net, laddr) + c, err := listenUnixgram(context.Background(), net, laddr) if err != nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err} } diff --git a/src/net/unixsock_plan9.go b/src/net/unixsock_plan9.go index 5d5b18f467..e70eb211bb 100644 --- a/src/net/unixsock_plan9.go +++ b/src/net/unixsock_plan9.go @@ -5,9 +5,9 @@ package net import ( + "context" "os" "syscall" - "time" ) func (c *UnixConn) readFrom(b []byte) (int, *UnixAddr, error) { @@ -26,7 +26,7 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err return 0, 0, syscall.EPLAN9 } -func dialUnix(network string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn, error) { +func dialUnix(ctx context.Context, network string, laddr, raddr *UnixAddr) (*UnixConn, error) { return nil, syscall.EPLAN9 } @@ -42,10 +42,10 @@ func (ln *UnixListener) file() (*os.File, error) { return nil, syscall.EPLAN9 } -func listenUnix(network string, laddr *UnixAddr) (*UnixListener, error) { +func listenUnix(ctx context.Context, network string, laddr *UnixAddr) (*UnixListener, error) { return nil, syscall.EPLAN9 } -func listenUnixgram(network string, laddr *UnixAddr) (*UnixConn, error) { +func listenUnixgram(ctx context.Context, network string, laddr *UnixAddr) (*UnixConn, error) { return nil, syscall.EPLAN9 } diff --git a/src/net/unixsock_posix.go b/src/net/unixsock_posix.go index 9275e1034c..5f0999c4c2 100644 --- a/src/net/unixsock_posix.go +++ b/src/net/unixsock_posix.go @@ -7,13 +7,13 @@ package net import ( + "context" "errors" "os" "syscall" - "time" ) -func unixSocket(net string, laddr, raddr sockaddr, mode string, deadline time.Time) (*netFD, error) { +func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string) (*netFD, error) { var sotype int switch net { case "unix": @@ -42,7 +42,7 @@ func unixSocket(net string, laddr, raddr sockaddr, mode string, deadline time.Ti return nil, errors.New("unknown mode: " + mode) } - fd, err := socket(net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, deadline, noCancel) + fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr) if err != nil { return nil, err } @@ -146,8 +146,8 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err return c.fd.writeMsg(b, oob, sa) } -func dialUnix(net string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn, error) { - fd, err := unixSocket(net, laddr, raddr, "dial", deadline) +func dialUnix(ctx context.Context, net string, laddr, raddr *UnixAddr) (*UnixConn, error) { + fd, err := unixSocket(ctx, net, laddr, raddr, "dial") if err != nil { return nil, err } @@ -187,16 +187,16 @@ func (ln *UnixListener) file() (*os.File, error) { return f, nil } -func listenUnix(network string, laddr *UnixAddr) (*UnixListener, error) { - fd, err := unixSocket(network, laddr, nil, "listen", noDeadline) +func listenUnix(ctx context.Context, network string, laddr *UnixAddr) (*UnixListener, error) { + fd, err := unixSocket(ctx, network, laddr, nil, "listen") if err != nil { return nil, err } return &UnixListener{fd: fd, path: fd.laddr.String(), unlink: true}, nil } -func listenUnixgram(network string, laddr *UnixAddr) (*UnixConn, error) { - fd, err := unixSocket(network, laddr, nil, "listen", noDeadline) +func listenUnixgram(ctx context.Context, network string, laddr *UnixAddr) (*UnixConn, error) { + fd, err := unixSocket(ctx, network, laddr, nil, "listen") if err != nil { return nil, err } -- 2.48.1