]> Cypherpunks repositories - gostls13.git/commitdiff
net: fix plan9 after context change, propagate contexts more
authorBrad Fitzpatrick <bradfitz@golang.org>
Sat, 16 Apr 2016 21:17:40 +0000 (14:17 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 18 Apr 2016 16:30:03 +0000 (16:30 +0000)
My previous https://golang.org/cl/22101 to add context throughout the
net package broke Plan 9, which isn't currently tested (#15251).

It also broke some old unsupported version of Windows (Windows 2000?)
which doesn't have the ConnectEx function, but that was only found
visually, since our minimum supported Windows version has ConnectEx.
This change simplifies the Windows and deletes the non-ConnectEx code
path.  Windows 2000 will work even less now, if it even worked
before. Windows XP remains our minimum supported version.

Specifically, the previous CL stopped using the "dial" function, which
0intro noted:
https://github.com/golang/go/issues/15333#issuecomment-210842761

This CL removes the dial function instead and makes plan9's net
implementation respect contexts, which likely fixes a number of
t.Skipped tests. I'm leaving that to 0intro to investigate.

In the process of propagating and respecting contexts for plan9, I had
to change some signatures to add contexts to more places and ended up
pushing contexts down into the Go-based DNS resolution as well,
replacing the pure-Go DNS implementation's use of "timeout
time.Duration" with a context instead.

Updates #11932
Updates #15328

Fixes #15333

Change-Id: I6ad1e62f38271cdd86b3f40921f2d0f23374936a
Reviewed-on: https://go-review.googlesource.com/22144
Reviewed-by: David du Colombier <0intro@gmail.com>
Reviewed-by: Mikio Hara <mikioh.mikioh@gmail.com>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

16 files changed:
src/net/dial.go
src/net/dial_gen.go [deleted file]
src/net/dnsclient_unix.go
src/net/dnsclient_unix_test.go
src/net/fd_plan9.go
src/net/fd_windows.go
src/net/iprawsock.go
src/net/iprawsock_posix.go
src/net/ipsock_plan9.go
src/net/lookup.go
src/net/lookup_plan9.go
src/net/lookup_stub.go
src/net/lookup_unix.go
src/net/lookup_windows.go
src/net/tcpsock_plan9.go
src/net/udpsock_plan9.go

index 1f31e8f2cc73a4100608ae833997c6b9a427273b..59e41f536b2a13e7e57edb9b3fb69e3a026a1e22 100644 (file)
@@ -124,7 +124,7 @@ func (d *Dialer) fallbackDelay() time.Duration {
        }
 }
 
-func parseNetwork(net string) (afnet string, proto int, err error) {
+func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err error) {
        i := last(net, ':')
        if i < 0 { // no colon
                switch net {
@@ -143,7 +143,7 @@ func parseNetwork(net string) (afnet string, proto int, err error) {
                protostr := net[i+1:]
                proto, i, ok := dtoi(protostr, 0)
                if !ok || i != len(protostr) {
-                       proto, err = lookupProtocol(protostr)
+                       proto, err = lookupProtocol(ctx, protostr)
                        if err != nil {
                                return "", 0, err
                        }
@@ -157,7 +157,7 @@ func parseNetwork(net string) (afnet string, proto int, err error) {
 // addresses. The result contains at least one address when error is
 // nil.
 func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
-       afnet, _, err := parseNetwork(network)
+       afnet, _, err := parseNetwork(ctx, network)
        if err != nil {
                return nil, err
        }
@@ -472,8 +472,7 @@ func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error)
 }
 
 // dialSingle attempts to establish and returns a single connection to
-// the destination address. This must be called through the OS-specific
-// dial function, because some OSes don't implement the deadline feature.
+// the destination address.
 func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) {
        la := dp.LocalAddr
        switch ra := ra.(type) {
diff --git a/src/net/dial_gen.go b/src/net/dial_gen.go
deleted file mode 100644 (file)
index a2cd8cb..0000000
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2012 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build windows plan9
-
-package net
-
-import "time"
-
-// dialChannel is the simple pure-Go implementation of dial, still
-// used on operating systems where the deadline hasn't been pushed
-// down into the pollserver. (Plan 9 and some old versions of Windows)
-func dialChannel(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) {
-       if deadline.IsZero() {
-               return dialer(noDeadline)
-       }
-       timeout := deadline.Sub(time.Now())
-       if timeout <= 0 {
-               return nil, &OpError{Op: "dial", Net: net, Source: nil, Addr: ra, Err: errTimeout}
-       }
-       t := time.NewTimer(timeout)
-       defer t.Stop()
-       type racer struct {
-               Conn
-               error
-       }
-       ch := make(chan racer, 1)
-       go func() {
-               testHookDialChannel()
-               c, err := dialer(noDeadline)
-               ch <- racer{c, err}
-       }()
-       select {
-       case <-t.C:
-               return nil, &OpError{Op: "dial", Net: net, Source: nil, Addr: ra, Err: errTimeout}
-       case racer := <-ch:
-               return racer.Conn, racer.error
-       }
-}
index 914dd767d33bb26854310a15e9a99a34afe4fb10..5ae21012e3cf150a3c4c081b95fec923d2dda3df 100644 (file)
@@ -27,10 +27,10 @@ import (
 
 // A dnsDialer provides dialing suitable for DNS queries.
 type dnsDialer interface {
-       dialDNS(string, string) (dnsConn, error)
+       dialDNS(ctx context.Context, network, addr string) (dnsConn, error)
 }
 
-var testHookDNSDialer = func(d time.Duration) dnsDialer { return &Dialer{Timeout: d} }
+var testHookDNSDialer = func() dnsDialer { return &Dialer{} }
 
 // A dnsConn represents a DNS transport endpoint.
 type dnsConn interface {
@@ -105,7 +105,7 @@ func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
        return nil
 }
 
-func (d *Dialer) dialDNS(network, server string) (dnsConn, error) {
+func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) {
        switch network {
        case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
        default:
@@ -116,9 +116,9 @@ func (d *Dialer) dialDNS(network, server string) (dnsConn, error) {
        // call back here to translate it. The DNS config parser has
        // already checked that all the cfg.servers[i] are IP
        // addresses, which Dial will use without a DNS lookup.
-       c, err := d.Dial(network, server)
+       c, err := d.DialContext(ctx, network, server)
        if err != nil {
-               return nil, err
+               return nil, mapErr(err)
        }
        switch network {
        case "tcp", "tcp4", "tcp6":
@@ -130,8 +130,8 @@ func (d *Dialer) dialDNS(network, server string) (dnsConn, error) {
 }
 
 // exchange sends a query on the connection and hopes for a response.
-func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
-       d := testHookDNSDialer(timeout)
+func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg, error) {
+       d := testHookDNSDialer()
        out := dnsMsg{
                dnsMsgHdr: dnsMsgHdr{
                        recursion_desired: true,
@@ -141,21 +141,21 @@ func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg
                },
        }
        for _, network := range []string{"udp", "tcp"} {
-               c, err := d.dialDNS(network, server)
+               c, err := d.dialDNS(ctx, network, server)
                if err != nil {
                        return nil, err
                }
                defer c.Close()
-               if timeout > 0 {
-                       c.SetDeadline(time.Now().Add(timeout))
+               if d, ok := ctx.Deadline(); ok && !d.IsZero() {
+                       c.SetDeadline(d)
                }
                out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
                if err := c.writeDNSQuery(&out); err != nil {
-                       return nil, err
+                       return nil, mapErr(err)
                }
                in, err := c.readDNSResponse()
                if err != nil {
-                       return nil, err
+                       return nil, mapErr(err)
                }
                if in.id != out.id {
                        return nil, errors.New("DNS message ID mismatch")
@@ -170,16 +170,24 @@ func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg
 
 // Do a lookup for a single name, which must be rooted
 // (otherwise answer will not find the answers).
-func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
+func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
        if len(cfg.servers) == 0 {
                return "", nil, &DNSError{Err: "no DNS servers", Name: name}
        }
+
        timeout := time.Duration(cfg.timeout) * time.Second
+       deadline := time.Now().Add(timeout)
+       if old, ok := ctx.Deadline(); !ok || deadline.Before(old) {
+               var cancel context.CancelFunc
+               ctx, cancel = context.WithDeadline(ctx, deadline)
+               defer cancel()
+       }
+
        var lastErr error
        for i := 0; i < cfg.attempts; i++ {
                for _, server := range cfg.servers {
                        server = JoinHostPort(server, "53")
-                       msg, err := exchange(server, name, qtype, timeout)
+                       msg, err := exchange(ctx, server, name, qtype)
                        if err != nil {
                                lastErr = &DNSError{
                                        Err:    err.Error(),
@@ -297,7 +305,7 @@ func (conf *resolverConfig) releaseSema() {
        <-conf.ch
 }
 
-func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
+func lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
        if !isDomainName(name) {
                return "", nil, &DNSError{Err: "invalid domain name", Name: name}
        }
@@ -306,7 +314,7 @@ func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
        conf := resolvConf.dnsConfig
        resolvConf.mu.RUnlock()
        for _, fqdn := range conf.nameList(name) {
-               cname, rrs, err = tryOneName(conf, fqdn, qtype)
+               cname, rrs, err = tryOneName(ctx, conf, fqdn, qtype)
                if err == nil {
                        break
                }
@@ -467,7 +475,7 @@ func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (a
        for _, fqdn := range conf.nameList(name) {
                for _, qtype := range qtypes {
                        go func(qtype uint16) {
-                               _, rrs, err := tryOneName(conf, fqdn, qtype)
+                               _, rrs, err := tryOneName(ctx, conf, fqdn, qtype)
                                lane <- racer{fqdn, rrs, err}
                        }(qtype)
                }
@@ -510,8 +518,8 @@ func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (a
 // Normally we let cgo use the C library resolver instead of
 // depending on our lookup code, so that Go and C get the same
 // answers.
-func goLookupCNAME(name string) (cname string, err error) {
-       _, rrs, err := lookup(name, dnsTypeCNAME)
+func goLookupCNAME(ctx context.Context, name string) (cname string, err error) {
+       _, rrs, err := lookup(ctx, name, dnsTypeCNAME)
        if err != nil {
                return
        }
@@ -524,7 +532,7 @@ func goLookupCNAME(name string) (cname string, err error) {
 // only if cgoLookupPTR is the stub in cgo_stub.go).
 // Normally we let cgo use the C library resolver instead of depending
 // on our lookup code, so that Go and C get the same answers.
-func goLookupPTR(addr string) ([]string, error) {
+func goLookupPTR(ctx context.Context, addr string) ([]string, error) {
        names := lookupStaticAddr(addr)
        if len(names) > 0 {
                return names, nil
@@ -533,7 +541,7 @@ func goLookupPTR(addr string) ([]string, error) {
        if err != nil {
                return nil, err
        }
-       _, rrs, err := lookup(arpa, dnsTypePTR)
+       _, rrs, err := lookup(ctx, arpa, dnsTypePTR)
        if err != nil {
                return nil, err
        }
index 145a3b6a33b1b4abe7c9c4d1da7abd48d2c50f6c..761fb23f142778c90221ede64f3f7be0f73059a3 100644 (file)
@@ -37,8 +37,9 @@ func TestDNSTransportFallback(t *testing.T) {
        testenv.MustHaveExternalNetwork(t)
 
        for _, tt := range dnsTransportFallbackTests {
-               timeout := time.Duration(tt.timeout) * time.Second
-               msg, err := exchange(tt.server, tt.name, tt.qtype, timeout)
+               ctx, cancel := context.WithTimeout(context.Background(), time.Duration(tt.timeout)*time.Second)
+               defer cancel()
+               msg, err := exchange(ctx, tt.server, tt.name, tt.qtype)
                if err != nil {
                        t.Error(err)
                        continue
@@ -78,7 +79,9 @@ func TestSpecialDomainName(t *testing.T) {
 
        server := "8.8.8.8:53"
        for _, tt := range specialDomainNameTests {
-               msg, err := exchange(server, tt.name, tt.qtype, 3*time.Second)
+               ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+               defer cancel()
+               msg, err := exchange(ctx, server, tt.name, tt.qtype)
                if err != nil {
                        t.Error(err)
                        continue
@@ -492,7 +495,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
        }
 
        d := &fakeDNSConn{}
-       testHookDNSDialer = func(time.Duration) dnsDialer { return d }
+       testHookDNSDialer = func() dnsDialer { return d }
 
        d.rh = func(q *dnsMsg) (*dnsMsg, error) {
                r := &dnsMsg{
@@ -571,7 +574,7 @@ type fakeDNSConn struct {
        rh func(*dnsMsg) (*dnsMsg, error)
 }
 
-func (f *fakeDNSConn) dialDNS(n, s string) (dnsConn, error) {
+func (f *fakeDNSConn) dialDNS(_ context.Context, n, s string) (dnsConn, error) {
        return f, nil
 }
 
index d0e9c53fca67b341447e0e232e4ffcfc859b6890..35d162431782381f17767c08368c6acd517ec6db 100644 (file)
@@ -32,12 +32,6 @@ func sysInit() {
        netdir = "/net"
 }
 
-func dial(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) {
-       // On plan9, use the relatively inefficient
-       // goroutine-racing implementation.
-       return dialChannel(net, ra, dialer, deadline)
-}
-
 func newFD(net, name string, ctl, data *os.File, laddr, raddr Addr) (*netFD, error) {
        return &netFD{net: net, n: name, dir: netdir + "/" + net + "/" + name, ctl: ctl, data: data, laddr: laddr, raddr: raddr}, nil
 }
index d1d91a6a5c532e07361cc8cd54c797084ebea31b..ca46bf9361019886b998ce54d0df2b13fa45bf79 100644 (file)
@@ -11,7 +11,6 @@ import (
        "runtime"
        "sync"
        "syscall"
-       "time"
        "unsafe"
 )
 
@@ -70,22 +69,15 @@ func sysInit() {
        }
 }
 
+// canUseConnectEx reports whether we can use the ConnectEx Windows API call
+// for the given network type.
 func canUseConnectEx(net string) bool {
        switch net {
-       case "udp", "udp4", "udp6", "ip", "ip4", "ip6":
-               // ConnectEx windows API does not support connectionless sockets.
-               return false
+       case "tcp", "tcp4", "tcp6":
+               return true
        }
-       return syscall.LoadConnectEx() == nil
-}
-
-func dial(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) {
-       if !canUseConnectEx(net) {
-               // Use the relatively inefficient goroutine-racing
-               // implementation of DialTimeout.
-               return dialChannel(net, ra, dialer, deadline)
-       }
-       return dialer(deadline)
+       // ConnectEx windows API does not support connectionless sockets.
+       return false
 }
 
 // operation contains superset of data necessary to perform all async IO.
@@ -328,12 +320,13 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error {
        if err := fd.init(); err != nil {
                return err
        }
-       if deadline, _ := ctx.Deadline(); !deadline.IsZero() {
+       if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
                fd.setWriteDeadline(deadline)
                defer fd.setWriteDeadline(noDeadline)
        }
        if !canUseConnectEx(fd.net) {
-               return os.NewSyscallError("connect", connectFunc(fd.sysfd, ra))
+               err := connectFunc(fd.sysfd, ra)
+               return os.NewSyscallError("connect", err)
        }
        // ConnectEx windows API requires an unconnected, previously bound socket.
        if la == nil {
index f4a4de82fcdfa034dc49304381925495f9fa4ebc..173b3cb4114296ef5e1d4e6606456e6631c5ffdc 100644 (file)
@@ -50,7 +50,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, error) {
        if net == "" { // a hint wildcard for Go 1.0 undocumented behavior
                net = "ip"
        }
-       afnet, _, err := parseNetwork(net)
+       afnet, _, err := parseNetwork(context.Background(), net)
        if err != nil {
                return nil, err
        }
index 68dc307b60685c92e06ceba11cfb2a3300b93e01..3e0b060a8a4a177dca2d0982ad245057111de575 100644 (file)
@@ -121,7 +121,7 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error)
 }
 
 func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
-       network, proto, err := parseNetwork(netProto)
+       network, proto, err := parseNetwork(ctx, netProto)
        if err != nil {
                return nil, err
        }
@@ -141,7 +141,7 @@ func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn
 }
 
 func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) {
-       network, proto, err := parseNetwork(netProto)
+       network, proto, err := parseNetwork(ctx, netProto)
        if err != nil {
                return nil, err
        }
index f7c2b4468833388fb1108faadb06e07725d1aa02..2b84683eeb5bfd6e3f5ab1a55195f6f93fe050bc 100644 (file)
@@ -7,6 +7,7 @@
 package net
 
 import (
+       "context"
        "os"
        "syscall"
 )
@@ -99,7 +100,7 @@ func readPlan9Addr(proto, filename string) (addr Addr, err error) {
        return addr, nil
 }
 
-func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, err error) {
+func startPlan9(ctx context.Context, net string, addr Addr) (ctl *os.File, dest, proto, name string, err error) {
        var (
                ip   IP
                port int
@@ -118,7 +119,7 @@ func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string,
                return
        }
 
-       clone, dest, err := queryCS1(proto, ip, port)
+       clone, dest, err := queryCS1(ctx, proto, ip, port)
        if err != nil {
                return
        }
@@ -135,8 +136,8 @@ func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string,
        return f, dest, proto, string(buf[:n]), nil
 }
 
-func netErr(e error) {
-       oe, ok := e.(*OpError)
+func fixErr(err error) {
+       oe, ok := err.(*OpError)
        if !ok {
                return
        }
@@ -165,9 +166,34 @@ func netErr(e error) {
        }
 }
 
-func dialPlan9(net string, laddr, raddr Addr) (fd *netFD, err error) {
-       defer func() { netErr(err) }()
-       f, dest, proto, name, err := startPlan9(net, raddr)
+func dialPlan9(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, err error) {
+       defer func() { fixErr(err) }()
+       type res struct {
+               fd  *netFD
+               err error
+       }
+       resc := make(chan res)
+       go func() {
+               testHookDialChannel()
+               fd, err := dialPlan9Blocking(ctx, net, laddr, raddr)
+               select {
+               case resc <- res{fd, err}:
+               case <-ctx.Done():
+                       if fd != nil {
+                               fd.Close()
+                       }
+               }
+       }()
+       select {
+       case res := <-resc:
+               return res.fd, res.err
+       case <-ctx.Done():
+               return nil, mapErr(ctx.Err())
+       }
+}
+
+func dialPlan9Blocking(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, err error) {
+       f, dest, proto, name, err := startPlan9(ctx, net, raddr)
        if err != nil {
                return nil, err
        }
@@ -190,9 +216,9 @@ func dialPlan9(net string, laddr, raddr Addr) (fd *netFD, err error) {
        return newFD(proto, name, f, data, laddr, raddr)
 }
 
-func listenPlan9(net string, laddr Addr) (fd *netFD, err error) {
-       defer func() { netErr(err) }()
-       f, dest, proto, name, err := startPlan9(net, laddr)
+func listenPlan9(ctx context.Context, net string, laddr Addr) (fd *netFD, err error) {
+       defer func() { fixErr(err) }()
+       f, dest, proto, name, err := startPlan9(ctx, net, laddr)
        if err != nil {
                return nil, err
        }
@@ -214,7 +240,7 @@ func (fd *netFD) netFD() (*netFD, error) {
 }
 
 func (fd *netFD) acceptPlan9() (nfd *netFD, err error) {
-       defer func() { netErr(err) }()
+       defer func() { fixErr(err) }()
        if err := fd.readLock(); err != nil {
                return nil, err
        }
index 8f02787422be9063252cbf6a4962d092cc6f7f39..5e60011165dda2f828da35fe3eb3117b45caac63 100644 (file)
@@ -114,7 +114,7 @@ func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err erro
 func LookupPort(network, service string) (port int, err error) {
        port, needsLookup := parsePort(service)
        if needsLookup {
-               port, err = lookupPort(network, service)
+               port, err = lookupPort(context.Background(), network, service)
                if err != nil {
                        return 0, err
                }
@@ -130,7 +130,7 @@ func LookupPort(network, service string) (port int, err error) {
 // LookupHost or LookupIP directly; both take care of resolving
 // the canonical name as part of the lookup.
 func LookupCNAME(name string) (cname string, err error) {
-       return lookupCNAME(name)
+       return lookupCNAME(context.Background(), name)
 }
 
 // LookupSRV tries to resolve an SRV query of the given service,
@@ -143,26 +143,26 @@ func LookupCNAME(name string) (cname string, err error) {
 // publishing SRV records under non-standard names, if both service
 // and proto are empty strings, LookupSRV looks up name directly.
 func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) {
-       return lookupSRV(service, proto, name)
+       return lookupSRV(context.Background(), service, proto, name)
 }
 
 // LookupMX returns the DNS MX records for the given domain name sorted by preference.
 func LookupMX(name string) (mxs []*MX, err error) {
-       return lookupMX(name)
+       return lookupMX(context.Background(), name)
 }
 
 // LookupNS returns the DNS NS records for the given domain name.
 func LookupNS(name string) (nss []*NS, err error) {
-       return lookupNS(name)
+       return lookupNS(context.Background(), name)
 }
 
 // LookupTXT returns the DNS TXT records for the given domain name.
 func LookupTXT(name string) (txts []string, err error) {
-       return lookupTXT(name)
+       return lookupTXT(context.Background(), name)
 }
 
 // LookupAddr performs a reverse lookup for the given address, returning a list
 // of names mapping to that address.
 func LookupAddr(addr string) (names []string, err error) {
-       return lookupAddr(addr)
+       return lookupAddr(context.Background(), addr)
 }
index 4224263602bcacec827d9745b799a2aeea4c3ad1..73147a2d3f7c40d69f6ddda33ff53647cf7ef482 100644 (file)
@@ -10,7 +10,7 @@ import (
        "os"
 )
 
-func query(filename, query string, bufSize int) (res []string, err error) {
+func query(ctx context.Context, filename, query string, bufSize int) (res []string, err error) {
        file, err := os.OpenFile(filename, os.O_RDWR, 0)
        if err != nil {
                return
@@ -40,7 +40,7 @@ func query(filename, query string, bufSize int) (res []string, err error) {
        return
 }
 
-func queryCS(net, host, service string) (res []string, err error) {
+func queryCS(ctx context.Context, net, host, service string) (res []string, err error) {
        switch net {
        case "tcp4", "tcp6":
                net = "tcp"
@@ -50,15 +50,15 @@ func queryCS(net, host, service string) (res []string, err error) {
        if host == "" {
                host = "*"
        }
-       return query(netdir+"/cs", net+"!"+host+"!"+service, 128)
+       return query(ctx, netdir+"/cs", net+"!"+host+"!"+service, 128)
 }
 
-func queryCS1(net string, ip IP, port int) (clone, dest string, err error) {
+func queryCS1(ctx context.Context, net string, ip IP, port int) (clone, dest string, err error) {
        ips := "*"
        if len(ip) != 0 && !ip.IsUnspecified() {
                ips = ip.String()
        }
-       lines, err := queryCS(net, ips, itoa(port))
+       lines, err := queryCS(ctx, net, ips, itoa(port))
        if err != nil {
                return
        }
@@ -70,8 +70,8 @@ func queryCS1(net string, ip IP, port int) (clone, dest string, err error) {
        return
 }
 
-func queryDNS(addr string, typ string) (res []string, err error) {
-       return query(netdir+"/dns", addr+" "+typ, 1024)
+func queryDNS(ctx context.Context, addr string, typ string) (res []string, err error) {
+       return query(ctx, netdir+"/dns", addr+" "+typ, 1024)
 }
 
 // toLower returns a lower-case version of in. Restricting us to
@@ -97,8 +97,8 @@ func toLower(in string) string {
 
 // lookupProtocol looks up IP protocol name and returns
 // the corresponding protocol number.
-func lookupProtocol(name string) (proto int, err error) {
-       lines, err := query(netdir+"/cs", "!protocol="+toLower(name), 128)
+func lookupProtocol(ctx context.Context, name string) (proto int, err error) {
+       lines, err := query(ctx, netdir+"/cs", "!protocol="+toLower(name), 128)
        if err != nil {
                return 0, err
        }
@@ -119,7 +119,7 @@ func lookupProtocol(name string) (proto int, err error) {
 func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
        // Use netdir/cs instead of netdir/dns because cs knows about
        // host names in local network (e.g. from /lib/ndb/local)
-       lines, err := queryCS("net", host, "1")
+       lines, err := queryCS(ctx, "net", host, "1")
        if err != nil {
                return
        }
@@ -148,8 +148,7 @@ loop:
 }
 
 func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
-       // TODO(bradfitz): push down ctx
-       lits, err := LookupHost(host)
+       lits, err := lookupHost(ctx, host)
        if err != nil {
                return
        }
@@ -163,14 +162,14 @@ func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
        return
 }
 
-func lookupPort(network, service string) (port int, err error) {
+func lookupPort(ctx context.Context, network, service string) (port int, err error) {
        switch network {
        case "tcp4", "tcp6":
                network = "tcp"
        case "udp4", "udp6":
                network = "udp"
        }
-       lines, err := queryCS(network, "127.0.0.1", service)
+       lines, err := queryCS(ctx, network, "127.0.0.1", service)
        if err != nil {
                return
        }
@@ -192,8 +191,8 @@ func lookupPort(network, service string) (port int, err error) {
        return 0, unknownPortError
 }
 
-func lookupCNAME(name string) (cname string, err error) {
-       lines, err := queryDNS(name, "cname")
+func lookupCNAME(ctx context.Context, name string) (cname string, err error) {
+       lines, err := queryDNS(ctx, name, "cname")
        if err != nil {
                return
        }
@@ -205,14 +204,14 @@ func lookupCNAME(name string) (cname string, err error) {
        return "", errors.New("bad response from ndb/dns")
 }
 
-func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) {
+func lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) {
        var target string
        if service == "" && proto == "" {
                target = name
        } else {
                target = "_" + service + "._" + proto + "." + name
        }
-       lines, err := queryDNS(target, "srv")
+       lines, err := queryDNS(ctx, target, "srv")
        if err != nil {
                return
        }
@@ -234,8 +233,8 @@ func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err
        return
 }
 
-func lookupMX(name string) (mx []*MX, err error) {
-       lines, err := queryDNS(name, "mx")
+func lookupMX(ctx context.Context, name string) (mx []*MX, err error) {
+       lines, err := queryDNS(ctx, name, "mx")
        if err != nil {
                return
        }
@@ -252,8 +251,8 @@ func lookupMX(name string) (mx []*MX, err error) {
        return
 }
 
-func lookupNS(name string) (ns []*NS, err error) {
-       lines, err := queryDNS(name, "ns")
+func lookupNS(ctx context.Context, name string) (ns []*NS, err error) {
+       lines, err := queryDNS(ctx, name, "ns")
        if err != nil {
                return
        }
@@ -267,8 +266,8 @@ func lookupNS(name string) (ns []*NS, err error) {
        return
 }
 
-func lookupTXT(name string) (txt []string, err error) {
-       lines, err := queryDNS(name, "txt")
+func lookupTXT(ctx context.Context, name string) (txt []string, err error) {
+       lines, err := queryDNS(ctx, name, "txt")
        if err != nil {
                return
        }
@@ -280,12 +279,12 @@ func lookupTXT(name string) (txt []string, err error) {
        return
 }
 
-func lookupAddr(addr string) (name []string, err error) {
+func lookupAddr(ctx context.Context, addr string) (name []string, err error) {
        arpa, err := reverseaddr(addr)
        if err != nil {
                return
        }
-       lines, err := queryDNS(arpa, "ptr")
+       lines, err := queryDNS(ctx, arpa, "ptr")
        if err != nil {
                return
        }
index 38a4f0bae480e8d97feb4dab855a2523f509640a..bd096b39652a55e38a7d1191d350c17c982ee647 100644 (file)
@@ -11,7 +11,7 @@ import (
        "syscall"
 )
 
-func lookupProtocol(name string) (proto int, err error) {
+func lookupProtocol(ctx context.Context, name string) (proto int, err error) {
        return 0, syscall.ENOPROTOOPT
 }
 
@@ -23,30 +23,30 @@ func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
        return nil, syscall.ENOPROTOOPT
 }
 
-func lookupPort(network, service string) (port int, err error) {
+func lookupPort(ctx context.Context, network, service string) (port int, err error) {
        return 0, syscall.ENOPROTOOPT
 }
 
-func lookupCNAME(name string) (cname string, err error) {
+func lookupCNAME(ctx context.Context, name string) (cname string, err error) {
        return "", syscall.ENOPROTOOPT
 }
 
-func lookupSRV(service, proto, name string) (cname string, srvs []*SRV, err error) {
+func lookupSRV(ctx context.Context, service, proto, name string) (cname string, srvs []*SRV, err error) {
        return "", nil, syscall.ENOPROTOOPT
 }
 
-func lookupMX(name string) (mxs []*MX, err error) {
+func lookupMX(ctx context.Context, name string) (mxs []*MX, err error) {
        return nil, syscall.ENOPROTOOPT
 }
 
-func lookupNS(name string) (nss []*NS, err error) {
+func lookupNS(ctx context.Context, name string) (nss []*NS, err error) {
        return nil, syscall.ENOPROTOOPT
 }
 
-func lookupTXT(name string) (txts []string, err error) {
+func lookupTXT(ctx context.Context, name string) (txts []string, err error) {
        return nil, syscall.ENOPROTOOPT
 }
 
-func lookupAddr(addr string) (ptrs []string, err error) {
+func lookupAddr(ctx context.Context, addr string) (ptrs []string, err error) {
        return nil, syscall.ENOPROTOOPT
 }
index 8d3fa4778284e1406f19ce1cb6a756a2cf34297e..5461fe8a411e8d60e6eb8a71c43382356285fe90 100644 (file)
@@ -43,7 +43,7 @@ func readProtocols() {
 
 // lookupProtocol looks up IP protocol name in /etc/protocols and
 // returns correspondent protocol number.
-func lookupProtocol(name string) (int, error) {
+func lookupProtocol(_ context.Context, name string) (int, error) {
        onceReadProtocols.Do(readProtocols)
        proto, found := protocols[name]
        if !found {
@@ -77,7 +77,12 @@ func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
        return goLookupIPOrder(ctx, host, order)
 }
 
-func lookupPort(network, service string) (int, error) {
+func lookupPort(ctx context.Context, network, service string) (int, error) {
+       // TODO: use the context if there ever becomes a need. Related
+       // is issue 15321. But port lookup generally just involves
+       // local files, and the os package has no context support. The
+       // files might be on a remote filesystem, though. This should
+       // probably race goroutines if ctx != context.Background().
        if systemConf().canUseCgo() {
                if port, err, ok := cgoLookupPort(network, service); ok {
                        return port, err
@@ -86,23 +91,24 @@ func lookupPort(network, service string) (int, error) {
        return goLookupPort(network, service)
 }
 
-func lookupCNAME(name string) (string, error) {
+func lookupCNAME(ctx context.Context, name string) (string, error) {
        if systemConf().canUseCgo() {
+               // TODO: use ctx. issue 15321. Or race goroutines.
                if cname, err, ok := cgoLookupCNAME(name); ok {
                        return cname, err
                }
        }
-       return goLookupCNAME(name)
+       return goLookupCNAME(ctx, name)
 }
 
-func lookupSRV(service, proto, name string) (string, []*SRV, error) {
+func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
        var target string
        if service == "" && proto == "" {
                target = name
        } else {
                target = "_" + service + "._" + proto + "." + name
        }
-       cname, rrs, err := lookup(target, dnsTypeSRV)
+       cname, rrs, err := lookup(ctx, target, dnsTypeSRV)
        if err != nil {
                return "", nil, err
        }
@@ -115,8 +121,8 @@ func lookupSRV(service, proto, name string) (string, []*SRV, error) {
        return cname, srvs, nil
 }
 
-func lookupMX(name string) ([]*MX, error) {
-       _, rrs, err := lookup(name, dnsTypeMX)
+func lookupMX(ctx context.Context, name string) ([]*MX, error) {
+       _, rrs, err := lookup(ctx, name, dnsTypeMX)
        if err != nil {
                return nil, err
        }
@@ -129,8 +135,8 @@ func lookupMX(name string) ([]*MX, error) {
        return mxs, nil
 }
 
-func lookupNS(name string) ([]*NS, error) {
-       _, rrs, err := lookup(name, dnsTypeNS)
+func lookupNS(ctx context.Context, name string) ([]*NS, error) {
+       _, rrs, err := lookup(ctx, name, dnsTypeNS)
        if err != nil {
                return nil, err
        }
@@ -141,8 +147,8 @@ func lookupNS(name string) ([]*NS, error) {
        return nss, nil
 }
 
-func lookupTXT(name string) ([]string, error) {
-       _, rrs, err := lookup(name, dnsTypeTXT)
+func lookupTXT(ctx context.Context, name string) ([]string, error) {
+       _, rrs, err := lookup(ctx, name, dnsTypeTXT)
        if err != nil {
                return nil, err
        }
@@ -153,11 +159,11 @@ func lookupTXT(name string) ([]string, error) {
        return txts, nil
 }
 
-func lookupAddr(addr string) ([]string, error) {
+func lookupAddr(ctx context.Context, addr string) ([]string, error) {
        if systemConf().canUseCgo() {
                if ptrs, err, ok := cgoLookupPTR(addr); ok {
                        return ptrs, err
                }
        }
-       return goLookupPTR(addr)
+       return goLookupPTR(ctx, addr)
 }
index ce012ba873fc70aec9c1982bc94c048570bb183e..7a04cc89984c096064684cb946c6f724a2ce9d87 100644 (file)
@@ -26,30 +26,37 @@ func getprotobyname(name string) (proto int, err error) {
 }
 
 // lookupProtocol looks up IP protocol name and returns correspondent protocol number.
-func lookupProtocol(name string) (int, error) {
+func lookupProtocol(ctx context.Context, name string) (int, error) {
        // GetProtoByName return value is stored in thread local storage.
        // Start new os thread before the call to prevent races.
        type result struct {
                proto int
                err   error
        }
-       ch := make(chan result)
+       ch := make(chan result) // unbuffered
        go func() {
                acquireThread()
                defer releaseThread()
                runtime.LockOSThread()
                defer runtime.UnlockOSThread()
                proto, err := getprotobyname(name)
-               ch <- result{proto: proto, err: err}
+               select {
+               case ch <- result{proto: proto, err: err}:
+               case <-ctx.Done():
+               }
        }()
-       r := <-ch
-       if r.err != nil {
-               if proto, ok := protocols[name]; ok {
-                       return proto, nil
+       select {
+       case r := <-ch:
+               if r.err != nil {
+                       if proto, ok := protocols[name]; ok {
+                               return proto, nil
+                       }
+                       r.err = &DNSError{Err: r.err.Error(), Name: name}
                }
-               r.err = &DNSError{Err: r.err.Error(), Name: name}
+               return r.proto, r.err
+       case <-ctx.Done():
+               return 0, mapErr(ctx.Err())
        }
-       return r.proto, r.err
 }
 
 func lookupHost(ctx context.Context, name string) ([]string, error) {
@@ -193,30 +200,38 @@ func getservbyname(network, service string) (int, error) {
        return int(syscall.Ntohs(s.Port)), nil
 }
 
-func oldLookupPort(network, service string) (int, error) {
+func oldLookupPort(ctx context.Context, network, service string) (int, error) {
        // GetServByName return value is stored in thread local storage.
        // Start new os thread before the call to prevent races.
        type result struct {
                port int
                err  error
        }
-       ch := make(chan result)
+       ch := make(chan result) // unbuffered
        go func() {
                acquireThread()
                defer releaseThread()
                runtime.LockOSThread()
                defer runtime.UnlockOSThread()
                port, err := getservbyname(network, service)
-               ch <- result{port: port, err: err}
+               select {
+               case ch <- result{port: port, err: err}:
+               case <-ctx.Done():
+               }
        }()
-       r := <-ch
-       if r.err != nil {
-               r.err = &DNSError{Err: r.err.Error(), Name: network + "/" + service}
+       select {
+       case r := <-ch:
+               if r.err != nil {
+                       r.err = &DNSError{Err: r.err.Error(), Name: network + "/" + service}
+               }
+               return r.port, r.err
+       case <-ctx.Done():
+               return 0, mapErr(ctx.Err())
        }
-       return r.port, r.err
 }
 
-func newLookupPort(network, service string) (int, error) {
+func newLookupPort(ctx context.Context, network, service string) (int, error) {
+       // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
        acquireThread()
        defer releaseThread()
        var stype int32
@@ -252,7 +267,8 @@ func newLookupPort(network, service string) (int, error) {
        return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
 }
 
-func lookupCNAME(name string) (string, error) {
+func lookupCNAME(ctx context.Context, name string) (string, error) {
+       // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
        acquireThread()
        defer releaseThread()
        var r *syscall.DNSRecord
@@ -272,7 +288,8 @@ func lookupCNAME(name string) (string, error) {
        return absDomainName([]byte(cname)), nil
 }
 
-func lookupSRV(service, proto, name string) (string, []*SRV, error) {
+func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
+       // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
        acquireThread()
        defer releaseThread()
        var target string
@@ -297,7 +314,8 @@ func lookupSRV(service, proto, name string) (string, []*SRV, error) {
        return absDomainName([]byte(target)), srvs, nil
 }
 
-func lookupMX(name string) ([]*MX, error) {
+func lookupMX(ctx context.Context, name string) ([]*MX, error) {
+       // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
        acquireThread()
        defer releaseThread()
        var r *syscall.DNSRecord
@@ -316,7 +334,8 @@ func lookupMX(name string) ([]*MX, error) {
        return mxs, nil
 }
 
-func lookupNS(name string) ([]*NS, error) {
+func lookupNS(ctx context.Context, name string) ([]*NS, error) {
+       // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
        acquireThread()
        defer releaseThread()
        var r *syscall.DNSRecord
@@ -334,7 +353,8 @@ func lookupNS(name string) ([]*NS, error) {
        return nss, nil
 }
 
-func lookupTXT(name string) ([]string, error) {
+func lookupTXT(ctx context.Context, name string) ([]string, error) {
+       // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
        acquireThread()
        defer releaseThread()
        var r *syscall.DNSRecord
@@ -355,7 +375,8 @@ func lookupTXT(name string) ([]string, error) {
        return txts, nil
 }
 
-func lookupAddr(addr string) ([]string, error) {
+func lookupAddr(ctx context.Context, addr string) ([]string, error) {
+       // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
        acquireThread()
        defer releaseThread()
        arpa, err := reverseaddr(addr)
index 08ad9be8f41a4d8757061ab9b5027a73f4a9b284..d2860607f8b9be253388dd412555c208d3036b7c 100644 (file)
@@ -22,10 +22,6 @@ func dialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn,
 }
 
 func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
-       if d, _ := ctx.Deadline(); !d.IsZero() {
-               // TODO: deadline not implemented on Plan 9 (see golang.og/issue/11932)
-       }
-       // TODO(bradfitz,0intro): also use the cancel channel.
        switch net {
        case "tcp", "tcp4", "tcp6":
        default:
@@ -34,7 +30,7 @@ func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn
        if raddr == nil {
                return nil, errMissingAddress
        }
-       fd, err := dialPlan9(net, laddr, raddr)
+       fd, err := dialPlan9(ctx, net, laddr, raddr)
        if err != nil {
                return nil, err
        }
@@ -71,7 +67,7 @@ func (ln *TCPListener) file() (*os.File, error) {
 }
 
 func listenTCP(ctx context.Context, network string, laddr *TCPAddr) (*TCPListener, error) {
-       fd, err := listenPlan9(network, laddr)
+       fd, err := listenPlan9(ctx, network, laddr)
        if err != nil {
                return nil, err
        }
index 3b3d8d7615d3bf88bc4043a59716baba5ca8415e..666f20622f611fe15688dd458c6c8244a9f6b188 100644 (file)
@@ -56,10 +56,7 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error
 }
 
 func dialUDP(ctx context.Context, net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
-       if deadline, _ := ctx.Deadline(); !deadline.IsZero() {
-               // TODO: deadline not implemented on Plan 9 (see golang.og/issue/11932)
-       }
-       fd, err := dialPlan9(net, laddr, raddr)
+       fd, err := dialPlan9(ctx, net, laddr, raddr)
        if err != nil {
                return nil, err
        }
@@ -95,7 +92,7 @@ func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) {
 }
 
 func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, error) {
-       l, err := listenPlan9(network, laddr)
+       l, err := listenPlan9(ctx, network, laddr)
        if err != nil {
                return nil, err
        }