]> Cypherpunks repositories - gostls13.git/commitdiff
net: add Resolver type, Dialer.Resolver, and DefaultResolver
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 20 Sep 2016 19:45:37 +0000 (19:45 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 21 Sep 2016 18:35:40 +0000 (18:35 +0000)
The new Resolver type (a struct) has 9 Lookup methods, all taking a
context.Context.

There's now a new DefaultResolver global, like http's
DefaultTransport and DefaultClient.

net.Dialer now has an optional Resolver field to set the Resolver.

This also does finishes some resolver cleanup internally, deleting
lookupIPMerge and renaming lookupIPContext into Resolver.LookupIPAddr.

The Resolver currently doesn't let you tweak much, but it's a struct
specifically so we can add knobs in the future. Currently I just added
a bool to force the pure Go resolver. In the future we could let
people provide an interface to implement the methods, or add a Timeout
time.Duration, which would wrap all provided contexts in a
context.WithTimeout.

Fixes #16672

Change-Id: I7ba1f886704f06def7b6b5c4da9809db51bc1495
Reviewed-on: https://go-review.googlesource.com/29440
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/net/dial.go
src/net/dnsclient_unix.go
src/net/iprawsock.go
src/net/ipsock.go
src/net/lookup.go
src/net/lookup_nacl.go
src/net/lookup_plan9.go
src/net/lookup_test.go
src/net/lookup_windows.go
src/net/tcpsock.go
src/net/udpsock.go

index 48f3ad81c62bdb9b269877d8a975468a033736ed..dc982bdb8788d992fa72d88f0675de9e3406a984 100644 (file)
@@ -59,6 +59,9 @@ type Dialer struct {
        // that do not support keep-alives ignore this field.
        KeepAlive time.Duration
 
+       // Resolver optionally specifies an alternate resolver to use.
+       Resolver *Resolver
+
        // Cancel is an optional channel whose closure indicates that
        // the dial should be canceled. Not all types of dials support
        // cancelation.
@@ -92,6 +95,13 @@ func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Tim
        return minNonzeroTime(earliest, d.Deadline)
 }
 
+func (d *Dialer) resolver() *Resolver {
+       if d.Resolver != nil {
+               return d.Resolver
+       }
+       return DefaultResolver
+}
+
 // partialDeadline returns the deadline to use for a single address,
 // when multiple addresses are pending.
 func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
@@ -156,7 +166,7 @@ func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err
 // resolverAddrList resolves addr using hint and returns a list of
 // addresses. The result contains at least one address when error is
 // nil.
-func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
+func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
        afnet, _, err := parseNetwork(ctx, network)
        if err != nil {
                return nil, err
@@ -166,7 +176,6 @@ func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (
        }
        switch afnet {
        case "unix", "unixgram", "unixpacket":
-               // TODO(bradfitz): push down context
                addr, err := ResolveUnixAddr(afnet, addr)
                if err != nil {
                        return nil, err
@@ -176,7 +185,7 @@ func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (
                }
                return addrList{addr}, nil
        }
-       addrs, err := internetAddrList(ctx, afnet, addr)
+       addrs, err := r.internetAddrList(ctx, afnet, addr)
        if err != nil || op != "dial" || hint == nil {
                return addrs, err
        }
@@ -326,7 +335,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
                resolveCtx = context.WithValue(resolveCtx, nettrace.TraceKey{}, &shadow)
        }
 
-       addrs, err := resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
+       addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
        if err != nil {
                return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
        }
@@ -525,7 +534,7 @@ func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error)
 // instead of just the interface with the given host address.
 // See Dial for more details about address syntax.
 func Listen(net, laddr string) (Listener, error) {
-       addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil)
+       addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", net, laddr, nil)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
        }
@@ -552,7 +561,7 @@ func Listen(net, laddr string) (Listener, error) {
 // instead of just the interface with the given host address.
 // See Dial for the syntax of laddr.
 func ListenPacket(net, laddr string) (PacketConn, error) {
-       addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil)
+       addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", net, laddr, nil)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
        }
index 98be7a873dce39faac1c7d025e8c26eec508e950..130e4c958a06e68739e7529e7a5f8f2a973af7d2 100644 (file)
@@ -456,8 +456,9 @@ func goLookupIPFiles(name string) (addrs []IPAddr) {
 
 // goLookupIP is the native Go implementation of LookupIP.
 // The libc versions are in cgo_*.go.
-func goLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error) {
-       return goLookupIPOrder(ctx, name, hostLookupFilesDNS)
+func goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
+       order := systemConf().hostLookupOrder(host)
+       return goLookupIPOrder(ctx, host, order)
 }
 
 func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, err error) {
index 95761b3a9ce264b2390ecd8359604320e2291bbf..a7a4531fde6449e8521469a4b93a9bab9ebfd665 100644 (file)
@@ -65,7 +65,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, error) {
        default:
                return nil, UnknownNetworkError(net)
        }
-       addrs, err := internetAddrList(context.Background(), afnet, addr)
+       addrs, err := DefaultResolver.internetAddrList(context.Background(), afnet, addr)
        if err != nil {
                return nil, err
        }
index 24daf173aceeaa430b8be57c5bb2bfbe3ca917d2..c04813fa32f12d3df9de9f3bbdafaba708bac14a 100644 (file)
@@ -190,7 +190,7 @@ func JoinHostPort(host, port string) string {
 // address or a DNS name, and returns a list of internet protocol
 // family addresses. The result contains at least one address when
 // error is nil.
-func internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
+func (r *Resolver) internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
        var (
                err        error
                host, port string
@@ -202,7 +202,7 @@ func internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
                        if host, port, err = SplitHostPort(addr); err != nil {
                                return nil, err
                        }
-                       if portnum, err = LookupPort(net, port); err != nil {
+                       if portnum, err = r.LookupPort(ctx, net, port); err != nil {
                                return nil, err
                        }
                }
@@ -238,7 +238,7 @@ func internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
                return addrList{inetaddr(IPAddr{IP: ip, Zone: zone})}, nil
        }
        // Try as a DNS name.
-       ips, err := lookupIPContext(ctx, host)
+       ips, err := r.LookupIPAddr(ctx, host)
        if err != nil {
                return nil, err
        }
index 12ea3022ef8077675325073adf79edcb17a340b2..d1e2e0063d579dc41dc1200610ad1c3e0191c444 100644 (file)
@@ -84,86 +84,75 @@ func lookupPortMap(network, service string) (port int, error error) {
        return 0, &AddrError{Err: "unknown port", Addr: network + "/" + service}
 }
 
+// DefaultResolver is the resolver used by the package-level Lookup
+// functions and by Dialers without a specified Resolver.
+var DefaultResolver = &Resolver{}
+
+// A Resolver looks up names and numbers.
+//
+// A nil *Resolver is equivalent to a zero Resolver.
+type Resolver struct {
+       // PreferGo controls whether Go's built-in DNS resolver is preferred
+       // on platforms where it's available. It is equivalent to setting
+       // GODEBUG=netdns=go, but scoped to just this resolver.
+       PreferGo bool
+
+       // TODO(bradfitz): optional interface impl override hook
+       // TODO(bradfitz): Timeout time.Duration?
+}
+
+func (r *Resolver) lookupIPFunc() func(context.Context, string) ([]IPAddr, error) {
+       if r != nil && r.PreferGo {
+               return goLookupIP
+       }
+       return lookupIP
+}
+
 // LookupHost looks up the given host using the local resolver.
-// It returns an array of that host's addresses.
+// It returns a slice of that host's addresses.
 func LookupHost(host string) (addrs []string, err error) {
-       // Make sure that no matter what we do later, host=="" is rejected.
-       // ParseIP, for example, does accept empty strings.
-       if host == "" {
-               return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host}
-       }
-       if ip := ParseIP(host); ip != nil {
-               return []string{host}, nil
-       }
-       return lookupHost(context.Background(), host)
+       return DefaultResolver.LookupHost(context.Background(), host)
 }
 
-// LookupIP looks up host using the local resolver.
-// It returns an array of that host's IPv4 and IPv6 addresses.
-func LookupIP(host string) (ips []IP, err error) {
+// LookupHost looks up the given host using the local resolver.
+// It returns a slice of that host's addresses.
+func (r *Resolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) {
        // Make sure that no matter what we do later, host=="" is rejected.
        // ParseIP, for example, does accept empty strings.
        if host == "" {
                return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host}
        }
        if ip := ParseIP(host); ip != nil {
-               return []IP{ip}, nil
-       }
-       addrs, err := lookupIPMerge(context.Background(), host)
-       if err != nil {
-               return
-       }
-       ips = make([]IP, len(addrs))
-       for i, addr := range addrs {
-               ips[i] = addr.IP
+               return []string{host}, nil
        }
-       return
+       return lookupHost(ctx, host)
 }
 
-var lookupGroup singleflight.Group
-
-// lookupIPMerge wraps lookupIP, but makes sure that for any given
-// host, only one lookup is in-flight at a time. The returned memory
-// is always owned by the caller.
-func lookupIPMerge(ctx context.Context, host string) (addrs []IPAddr, err error) {
-       addrsi, err, shared := lookupGroup.Do(host, func() (interface{}, error) {
-               return testHookLookupIP(ctx, lookupIP, host)
-       })
-       return lookupIPReturn(addrsi, err, shared)
-}
-
-// lookupIPReturn turns the return values from singleflight.Do into
-// the return values from LookupIP.
-func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error) {
+// LookupIP looks up host using the local resolver.
+// It returns a slice of that host's IPv4 and IPv6 addresses.
+func LookupIP(host string) ([]IP, error) {
+       addrs, err := DefaultResolver.LookupIPAddr(context.Background(), host)
        if err != nil {
                return nil, err
        }
-       addrs := addrsi.([]IPAddr)
-       if shared {
-               clone := make([]IPAddr, len(addrs))
-               copy(clone, addrs)
-               addrs = clone
+       ips := make([]IP, len(addrs))
+       for i, ia := range addrs {
+               ips[i] = ia.IP
        }
-       return addrs, nil
+       return ips, nil
 }
 
-// ipAddrsEface returns an empty interface slice of addrs.
-func ipAddrsEface(addrs []IPAddr) []interface{} {
-       s := make([]interface{}, len(addrs))
-       for i, v := range addrs {
-               s[i] = v
+// LookupIPAddr looks up host using the local resolver.
+// It returns a slice of that host's IPv4 and IPv6 addresses.
+func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, error) {
+       // Make sure that no matter what we do later, host=="" is rejected.
+       // ParseIP, for example, does accept empty strings.
+       if host == "" {
+               return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host}
+       }
+       if ip := ParseIP(host); ip != nil {
+               return []IPAddr{{IP: ip}}, nil
        }
-       return s
-}
-
-// lookupIPContext looks up a hostname with a context.
-//
-// TODO(bradfitz): rename this function. All the other
-// build-tag-specific lookupIP funcs also take a context now, so this
-// name is no longer great. Maybe make this lookupIPMerge and ditch
-// the other one, making its callers call this instead with a
-// context.Background().
-func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err error) {
        trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
        if trace != nil && trace.DNSStart != nil {
                trace.DNSStart(host)
@@ -171,7 +160,7 @@ func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err erro
        // The underlying resolver func is lookupIP by default but it
        // can be overridden by tests. This is needed by net/http, so it
        // uses a context key instead of unexported variables.
-       resolverFunc := lookupIP
+       resolverFunc := r.lookupIPFunc()
        if alt, _ := ctx.Value(nettrace.LookupIPAltResolverKey{}).(func(context.Context, string) ([]IPAddr, error)); alt != nil {
                resolverFunc = alt
        }
@@ -201,11 +190,46 @@ func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err erro
        }
 }
 
+// lookupGroup merges LookupIPAddr calls together for lookups
+// for the same host. The lookupGroup key is is the LookupIPAddr.host
+// argument.
+// The return values are ([]IPAddr, error).
+var lookupGroup singleflight.Group
+
+// lookupIPReturn turns the return values from singleflight.Do into
+// the return values from LookupIP.
+func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error) {
+       if err != nil {
+               return nil, err
+       }
+       addrs := addrsi.([]IPAddr)
+       if shared {
+               clone := make([]IPAddr, len(addrs))
+               copy(clone, addrs)
+               addrs = clone
+       }
+       return addrs, nil
+}
+
+// ipAddrsEface returns an empty interface slice of addrs.
+func ipAddrsEface(addrs []IPAddr) []interface{} {
+       s := make([]interface{}, len(addrs))
+       for i, v := range addrs {
+               s[i] = v
+       }
+       return s
+}
+
 // LookupPort looks up the port for the given network and service.
 func LookupPort(network, service string) (port int, err error) {
+       return DefaultResolver.LookupPort(context.Background(), network, service)
+}
+
+// LookupPort looks up the port for the given network and service.
+func (r *Resolver) LookupPort(ctx context.Context, network, service string) (port int, err error) {
        port, needsLookup := parsePort(service)
        if needsLookup {
-               port, err = lookupPort(context.Background(), network, service)
+               port, err = lookupPort(ctx, network, service)
                if err != nil {
                        return 0, err
                }
@@ -224,6 +248,14 @@ func LookupCNAME(name string) (cname string, err error) {
        return lookupCNAME(context.Background(), name)
 }
 
+// LookupCNAME returns the canonical DNS host for the given name.
+// Callers that do not care about the canonical name can call
+// LookupHost or LookupIP directly; both take care of resolving
+// the canonical name as part of the lookup.
+func (r *Resolver) LookupCNAME(ctx context.Context, name string) (cname string, err error) {
+       return lookupCNAME(ctx, name)
+}
+
 // LookupSRV tries to resolve an SRV query of the given service,
 // protocol, and domain name. The proto is "tcp" or "udp".
 // The returned records are sorted by priority and randomized
@@ -237,23 +269,57 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err
        return lookupSRV(context.Background(), service, proto, name)
 }
 
+// LookupSRV tries to resolve an SRV query of the given service,
+// protocol, and domain name. The proto is "tcp" or "udp".
+// The returned records are sorted by priority and randomized
+// by weight within a priority.
+//
+// LookupSRV constructs the DNS name to look up following RFC 2782.
+// That is, it looks up _service._proto.name. To accommodate services
+// publishing SRV records under non-standard names, if both service
+// and proto are empty strings, LookupSRV looks up name directly.
+func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) {
+       return lookupSRV(ctx, service, proto, name)
+}
+
 // LookupMX returns the DNS MX records for the given domain name sorted by preference.
-func LookupMX(name string) (mxs []*MX, err error) {
+func LookupMX(name string) ([]*MX, error) {
        return lookupMX(context.Background(), name)
 }
 
+// LookupMX returns the DNS MX records for the given domain name sorted by preference.
+func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) {
+       return lookupMX(ctx, name)
+}
+
 // LookupNS returns the DNS NS records for the given domain name.
-func LookupNS(name string) (nss []*NS, err error) {
+func LookupNS(name string) ([]*NS, error) {
        return lookupNS(context.Background(), name)
 }
 
+// LookupNS returns the DNS NS records for the given domain name.
+func (r *Resolver) LookupNS(ctx context.Context, name string) ([]*NS, error) {
+       return lookupNS(ctx, name)
+}
+
 // LookupTXT returns the DNS TXT records for the given domain name.
-func LookupTXT(name string) (txts []string, err error) {
+func LookupTXT(name string) ([]string, error) {
        return lookupTXT(context.Background(), name)
 }
 
+// LookupTXT returns the DNS TXT records for the given domain name.
+func (r *Resolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
+       return lookupTXT(ctx, name)
+}
+
 // LookupAddr performs a reverse lookup for the given address, returning a list
 // of names mapping to that address.
 func LookupAddr(addr string) (names []string, err error) {
        return lookupAddr(context.Background(), addr)
 }
+
+// LookupAddr performs a reverse lookup for the given address, returning a list
+// of names mapping to that address.
+func (r *Resolver) LookupAddr(ctx context.Context, addr string) (names []string, err error) {
+       return lookupAddr(ctx, addr)
+}
index 48c0d1938ecf1020f61cf4eae8b6fad04cb29f20..83ecdb50f5730462ef26db455d2fc058d7129870 100644 (file)
@@ -19,6 +19,10 @@ func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
        return nil, syscall.ENOPROTOOPT
 }
 
+func goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
+       return nil, syscall.ENOPROTOOPT
+}
+
 func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
        return nil, syscall.ENOPROTOOPT
 }
index 2d974146cd83b13e7cdcaf5609871c7eee44b8d9..3abaf090bacfa761f2c74da621111cf1a8af066d 100644 (file)
@@ -148,6 +148,8 @@ loop:
        return
 }
 
+var goLookupIP = lookupIP
+
 func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
        lits, err := lookupHost(ctx, host)
        if err != nil {
index 5de9f39b08a2c060a6a4e01e67d21d187094a721..acf7ffdf7906ffc7a5c35ce89195a8544a054779 100644 (file)
@@ -398,11 +398,11 @@ func TestDNSFlood(t *testing.T) {
        for i := 0; i < N; i++ {
                name := fmt.Sprintf("%d.net-test.golang.org", i)
                go func() {
-                       _, err := lookupIPContext(ctxHalfTimeout, name)
+                       _, err := DefaultResolver.LookupIPAddr(ctxHalfTimeout, name)
                        c <- err
                }()
                go func() {
-                       _, err := lookupIPContext(ctxTimeout, name)
+                       _, err := DefaultResolver.LookupIPAddr(ctxTimeout, name)
                        c <- err
                }()
        }
index 5f65c2d00d4de5c40415047310ba4fe9635fe8a0..9435fef8393f11dfeded7bce8e6402218879e8b7 100644 (file)
@@ -66,8 +66,14 @@ func lookupHost(ctx context.Context, name string) ([]string, error) {
        return addrs, nil
 }
 
+// goLookupIP isn't a Pure Go implementation on Windows.
+// TODO(bradfitz): should it be? Not sure it can be. It's always used syscall.GetAddrInfoW.
+func goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
+       return lookupIP(ctx, host)
+}
+
 func lookupIP(ctx context.Context, name string) ([]IPAddr, error) {
-       // TODO(bradfitz,brainman): use ctx?
+       // TODO(bradfitz,brainman): use ctx more. See TODO below.
 
        type ret struct {
                addrs []IPAddr
index e02e6c9c7d325a1e70f4e84acf224ef530122f6d..1f7f59a3b6e3536ba4fb5bf657e732f901c992ff 100644 (file)
@@ -64,7 +64,7 @@ func ResolveTCPAddr(net, addr string) (*TCPAddr, error) {
        default:
                return nil, UnknownNetworkError(net)
        }
-       addrs, err := internetAddrList(context.Background(), net, addr)
+       addrs, err := DefaultResolver.internetAddrList(context.Background(), net, addr)
        if err != nil {
                return nil, err
        }
index a859f4d4c0e1a8202877d8df8d4367c26aecfd74..e54eee837aab578c58ea77815645802455039ba7 100644 (file)
@@ -67,7 +67,7 @@ func ResolveUDPAddr(net, addr string) (*UDPAddr, error) {
        default:
                return nil, UnknownNetworkError(net)
        }
-       addrs, err := internetAddrList(context.Background(), net, addr)
+       addrs, err := DefaultResolver.internetAddrList(context.Background(), net, addr)
        if err != nil {
                return nil, err
        }