]> Cypherpunks repositories - gostls13.git/commitdiff
net: permit use of Resolver.PreferGo, netgo on Windows and Plan 9
authorBrad Fitzpatrick <bradfitz@golang.org>
Sat, 28 May 2022 21:06:43 +0000 (14:06 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 30 May 2022 21:23:29 +0000 (21:23 +0000)
This reverts commit CL 401754 (440c9312c8) which reverted CL 400654,
thus reapplying CL 400654, re-adding the func init() { netGo = true }
to cgo_stub.go CL 400654 had originally removed (mistakenly during
development?) that had broken the darwin nocgo builder.

Fixes #33097

Change-Id: I90f59746d2ceb6b5d2bd832c9fc90068f8ff7417
Reviewed-on: https://go-review.googlesource.com/c/go/+/409234
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Reviewed-by: Keith Randall <khr@google.com>
15 files changed:
src/net/addrselect.go
src/net/conf.go
src/net/dnsclient_unix.go
src/net/dnsconfig.go [new file with mode: 0644]
src/net/dnsconfig_unix.go
src/net/dnsconfig_windows.go [new file with mode: 0644]
src/net/lookup.go
src/net/lookup_plan9.go
src/net/lookup_unix.go
src/net/lookup_windows.go
src/net/net.go
src/net/net_fake.go
src/net/netgo.go [new file with mode: 0644]
src/net/nss.go
src/net/resolverdialfunc_test.go [new file with mode: 0644]

index 8accdb89e14f458c98d352f7a95110e3afde586c..59380b94868faab05e75e6787786e2d1714b1bf2 100644 (file)
@@ -2,8 +2,6 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build unix
-
 // Minimal RFC 6724 address selection.
 
 package net
index 9d4752173e1afe3cd98f2583f90580c858ccd66d..b08bbc7d7a1c818f4eca4238ea24f424aa91add0 100644 (file)
@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build unix
+//go:build !js
 
 package net
 
@@ -21,7 +21,7 @@ type conf struct {
        forceCgoLookupHost bool
 
        netGo  bool // go DNS resolution forced
-       netCgo bool // cgo DNS resolution forced
+       netCgo bool // non-go DNS resolution forced (cgo, or win32)
 
        // machine has an /etc/mdns.allow file
        hasMDNSAllow bool
@@ -49,9 +49,23 @@ func initConfVal() {
        confVal.dnsDebugLevel = debugLevel
        confVal.netGo = netGo || dnsMode == "go"
        confVal.netCgo = netCgo || dnsMode == "cgo"
+       if !confVal.netGo && !confVal.netCgo && (runtime.GOOS == "windows" || runtime.GOOS == "plan9") {
+               // Neither of these platforms actually use cgo.
+               //
+               // The meaning of "cgo" mode in the net package is
+               // really "the native OS way", which for libc meant
+               // cgo on the original platforms that motivated
+               // PreferGo support before Windows and Plan9 got support,
+               // at which time the GODEBUG=netdns=go and GODEBUG=netdns=cgo
+               // names were already kinda locked in.
+               confVal.netCgo = true
+       }
 
        if confVal.dnsDebugLevel > 0 {
                defer func() {
+                       if confVal.dnsDebugLevel > 1 {
+                               println("go package net: confVal.netCgo =", confVal.netCgo, " netGo =", confVal.netGo)
+                       }
                        switch {
                        case confVal.netGo:
                                if netGo {
@@ -75,6 +89,10 @@ func initConfVal() {
                return
        }
 
+       if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
+               return
+       }
+
        // If any environment-specified resolver options are specified,
        // force cgo. Note that LOCALDOMAIN can change behavior merely
        // by being specified with the empty string.
@@ -129,7 +147,19 @@ func (c *conf) hostLookupOrder(r *Resolver, hostname string) (ret hostLookupOrde
        }
        fallbackOrder := hostLookupCgo
        if c.netGo || r.preferGo() {
-               fallbackOrder = hostLookupFilesDNS
+               switch c.goos {
+               case "windows":
+                       // TODO(bradfitz): implement files-based
+                       // lookup on Windows too? I guess /etc/hosts
+                       // kinda exists on Windows. But for now, only
+                       // do DNS.
+                       fallbackOrder = hostLookupDNS
+               default:
+                       fallbackOrder = hostLookupFilesDNS
+               }
+       }
+       if c.goos == "windows" || c.goos == "plan9" {
+               return fallbackOrder
        }
        if c.forceCgoLookupHost || c.resolv.unknownOpt || c.goos == "android" {
                return fallbackOrder
index 17435365907ad5d81b275985993d557b5c712668..088c81adee92777d43845bfc1a601354de5d5648 100644 (file)
@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build unix
+//go:build !js
 
 // DNS client: see RFC 1035.
 // Has to be linked into package net for Dial.
@@ -20,6 +20,7 @@ import (
        "internal/itoa"
        "io"
        "os"
+       "runtime"
        "sync"
        "time"
 
@@ -381,12 +382,21 @@ func (conf *resolverConfig) tryUpdate(name string) {
        }
        conf.lastChecked = now
 
-       var mtime time.Time
-       if fi, err := os.Stat(name); err == nil {
-               mtime = fi.ModTime()
-       }
-       if mtime.Equal(conf.dnsConfig.mtime) {
-               return
+       switch runtime.GOOS {
+       case "windows":
+               // There's no file on disk, so don't bother checking
+               // and failing.
+               //
+               // The Windows implementation of dnsReadConfig (called
+               // below) ignores the name.
+       default:
+               var mtime time.Time
+               if fi, err := os.Stat(name); err == nil {
+                       mtime = fi.ModTime()
+               }
+               if mtime.Equal(conf.dnsConfig.mtime) {
+                       return
+               }
        }
 
        dnsConf := dnsReadConfig(name)
diff --git a/src/net/dnsconfig.go b/src/net/dnsconfig.go
new file mode 100644 (file)
index 0000000..091b548
--- /dev/null
@@ -0,0 +1,43 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+       "os"
+       "sync/atomic"
+       "time"
+)
+
+var (
+       defaultNS   = []string{"127.0.0.1:53", "[::1]:53"}
+       getHostname = os.Hostname // variable for testing
+)
+
+type dnsConfig struct {
+       servers       []string      // server addresses (in host:port form) to use
+       search        []string      // rooted suffixes to append to local name
+       ndots         int           // number of dots in name to trigger absolute lookup
+       timeout       time.Duration // wait before giving up on a query, including retries
+       attempts      int           // lost packets before giving up on server
+       rotate        bool          // round robin among servers
+       unknownOpt    bool          // anything unknown was encountered
+       lookup        []string      // OpenBSD top-level database "lookup" order
+       err           error         // any error that occurs during open of resolv.conf
+       mtime         time.Time     // time of resolv.conf modification
+       soffset       uint32        // used by serverOffset
+       singleRequest bool          // use sequential A and AAAA queries instead of parallel queries
+       useTCP        bool          // force usage of TCP for DNS resolutions
+}
+
+// serverOffset returns an offset that can be used to determine
+// indices of servers in c.servers when making queries.
+// When the rotate option is enabled, this offset increases.
+// Otherwise it is always 0.
+func (c *dnsConfig) serverOffset() uint32 {
+       if c.rotate {
+               return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start
+       }
+       return 0
+}
index 7552bc51e653a72895d82a1cd634959e65a271cf..94cd09ec71066f0f4fd329cb5a0fb7436d15df04 100644 (file)
@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build unix
+//go:build !js && !windows
 
 // Read system DNS config from /etc/resolv.conf
 
@@ -10,32 +10,9 @@ package net
 
 import (
        "internal/bytealg"
-       "os"
-       "sync/atomic"
        "time"
 )
 
-var (
-       defaultNS   = []string{"127.0.0.1:53", "[::1]:53"}
-       getHostname = os.Hostname // variable for testing
-)
-
-type dnsConfig struct {
-       servers       []string      // server addresses (in host:port form) to use
-       search        []string      // rooted suffixes to append to local name
-       ndots         int           // number of dots in name to trigger absolute lookup
-       timeout       time.Duration // wait before giving up on a query, including retries
-       attempts      int           // lost packets before giving up on server
-       rotate        bool          // round robin among servers
-       unknownOpt    bool          // anything unknown was encountered
-       lookup        []string      // OpenBSD top-level database "lookup" order
-       err           error         // any error that occurs during open of resolv.conf
-       mtime         time.Time     // time of resolv.conf modification
-       soffset       uint32        // used by serverOffset
-       singleRequest bool          // use sequential A and AAAA queries instead of parallel queries
-       useTCP        bool          // force usage of TCP for DNS resolutions
-}
-
 // See resolv.conf(5) on a Linux machine.
 func dnsReadConfig(filename string) *dnsConfig {
        conf := &dnsConfig{
@@ -156,17 +133,6 @@ func dnsReadConfig(filename string) *dnsConfig {
        return conf
 }
 
-// serverOffset returns an offset that can be used to determine
-// indices of servers in c.servers when making queries.
-// When the rotate option is enabled, this offset increases.
-// Otherwise it is always 0.
-func (c *dnsConfig) serverOffset() uint32 {
-       if c.rotate {
-               return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start
-       }
-       return 0
-}
-
 func dnsDefaultSearch() []string {
        hn, err := getHostname()
        if err != nil {
diff --git a/src/net/dnsconfig_windows.go b/src/net/dnsconfig_windows.go
new file mode 100644 (file)
index 0000000..5d640da
--- /dev/null
@@ -0,0 +1,58 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+       "syscall"
+       "time"
+)
+
+func dnsReadConfig(ignoredFilename string) (conf *dnsConfig) {
+       conf = &dnsConfig{
+               ndots:    1,
+               timeout:  5 * time.Second,
+               attempts: 2,
+       }
+       defer func() {
+               if len(conf.servers) == 0 {
+                       conf.servers = defaultNS
+               }
+       }()
+       aas, err := adapterAddresses()
+       if err != nil {
+               return
+       }
+       // TODO(bradfitz): this just collects all the DNS servers on all
+       // the interfaces in some random order. It should order it by
+       // default route, or only use the default route(s) instead.
+       // In practice, however, it mostly works.
+       for _, aa := range aas {
+               for dns := aa.FirstDnsServerAddress; dns != nil; dns = dns.Next {
+                       sa, err := dns.Address.Sockaddr.Sockaddr()
+                       if err != nil {
+                               continue
+                       }
+                       var ip IP
+                       switch sa := sa.(type) {
+                       case *syscall.SockaddrInet4:
+                               ip = IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])
+                       case *syscall.SockaddrInet6:
+                               ip = make(IP, IPv6len)
+                               copy(ip, sa.Addr[:])
+                               if ip[0] == 0xfe && ip[1] == 0xc0 {
+                                       // Ignore these fec0/10 ones. Windows seems to
+                                       // populate them as defaults on its misc rando
+                                       // interfaces.
+                                       continue
+                               }
+                       default:
+                               // Unexpected type.
+                               continue
+                       }
+                       conf.servers = append(conf.servers, JoinHostPort(ip.String(), "53"))
+               }
+       }
+       return conf
+}
index 6fa90f354d46601b12721c30b63fd1fb1ce60897..7f3d20126c902c79e863f06d0f2e4fd49d8e38ee 100644 (file)
@@ -10,6 +10,8 @@ import (
        "internal/singleflight"
        "net/netip"
        "sync"
+
+       "golang.org/x/net/dns/dnsmessage"
 )
 
 // protocols contains minimal mappings between internet protocol
@@ -665,3 +667,227 @@ func (r *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error
 // method receives DNS records which contain invalid DNS names. This may be returned alongside
 // results which have had the malformed records filtered out.
 var errMalformedDNSRecordsDetail = "DNS response contained records which contain invalid names"
+
+// dial makes a new connection to the provided server (which must be
+// an IP address) with the provided network type, using either r.Dial
+// (if both r and r.Dial are non-nil) or else Dialer.DialContext.
+func (r *Resolver) dial(ctx context.Context, network, server string) (Conn, error) {
+       // Calling Dial here is scary -- we have to be sure not to
+       // dial a name that will require a DNS lookup, or Dial will
+       // call back here to translate it. The DNS config parser has
+       // already checked that all the cfg.servers are IP
+       // addresses, which Dial will use without a DNS lookup.
+       var c Conn
+       var err error
+       if r != nil && r.Dial != nil {
+               c, err = r.Dial(ctx, network, server)
+       } else {
+               var d Dialer
+               c, err = d.DialContext(ctx, network, server)
+       }
+       if err != nil {
+               return nil, mapErr(err)
+       }
+       return c, nil
+}
+
+// goLookupSRV returns the SRV records for a target name, built either
+// from its component service ("sip"), protocol ("tcp"), and name
+// ("example.com."), or from name directly (if service and proto are
+// both empty).
+//
+// In either case, the returned target name ("_sip._tcp.example.com.")
+// is also returned on success.
+//
+// The records are sorted by weight.
+func (r *Resolver) goLookupSRV(ctx context.Context, service, proto, name string) (target string, srvs []*SRV, err error) {
+       if service == "" && proto == "" {
+               target = name
+       } else {
+               target = "_" + service + "._" + proto + "." + name
+       }
+       p, server, err := r.lookup(ctx, target, dnsmessage.TypeSRV)
+       if err != nil {
+               return "", nil, err
+       }
+       var cname dnsmessage.Name
+       for {
+               h, err := p.AnswerHeader()
+               if err == dnsmessage.ErrSectionDone {
+                       break
+               }
+               if err != nil {
+                       return "", nil, &DNSError{
+                               Err:    "cannot unmarshal DNS message",
+                               Name:   name,
+                               Server: server,
+                       }
+               }
+               if h.Type != dnsmessage.TypeSRV {
+                       if err := p.SkipAnswer(); err != nil {
+                               return "", nil, &DNSError{
+                                       Err:    "cannot unmarshal DNS message",
+                                       Name:   name,
+                                       Server: server,
+                               }
+                       }
+                       continue
+               }
+               if cname.Length == 0 && h.Name.Length != 0 {
+                       cname = h.Name
+               }
+               srv, err := p.SRVResource()
+               if err != nil {
+                       return "", nil, &DNSError{
+                               Err:    "cannot unmarshal DNS message",
+                               Name:   name,
+                               Server: server,
+                       }
+               }
+               srvs = append(srvs, &SRV{Target: srv.Target.String(), Port: srv.Port, Priority: srv.Priority, Weight: srv.Weight})
+       }
+       byPriorityWeight(srvs).sort()
+       return cname.String(), srvs, nil
+}
+
+// goLookupMX returns the MX records for name.
+func (r *Resolver) goLookupMX(ctx context.Context, name string) ([]*MX, error) {
+       p, server, err := r.lookup(ctx, name, dnsmessage.TypeMX)
+       if err != nil {
+               return nil, err
+       }
+       var mxs []*MX
+       for {
+               h, err := p.AnswerHeader()
+               if err == dnsmessage.ErrSectionDone {
+                       break
+               }
+               if err != nil {
+                       return nil, &DNSError{
+                               Err:    "cannot unmarshal DNS message",
+                               Name:   name,
+                               Server: server,
+                       }
+               }
+               if h.Type != dnsmessage.TypeMX {
+                       if err := p.SkipAnswer(); err != nil {
+                               return nil, &DNSError{
+                                       Err:    "cannot unmarshal DNS message",
+                                       Name:   name,
+                                       Server: server,
+                               }
+                       }
+                       continue
+               }
+               mx, err := p.MXResource()
+               if err != nil {
+                       return nil, &DNSError{
+                               Err:    "cannot unmarshal DNS message",
+                               Name:   name,
+                               Server: server,
+                       }
+               }
+               mxs = append(mxs, &MX{Host: mx.MX.String(), Pref: mx.Pref})
+
+       }
+       byPref(mxs).sort()
+       return mxs, nil
+}
+
+// goLookupNS returns the NS records for name.
+func (r *Resolver) goLookupNS(ctx context.Context, name string) ([]*NS, error) {
+       p, server, err := r.lookup(ctx, name, dnsmessage.TypeNS)
+       if err != nil {
+               return nil, err
+       }
+       var nss []*NS
+       for {
+               h, err := p.AnswerHeader()
+               if err == dnsmessage.ErrSectionDone {
+                       break
+               }
+               if err != nil {
+                       return nil, &DNSError{
+                               Err:    "cannot unmarshal DNS message",
+                               Name:   name,
+                               Server: server,
+                       }
+               }
+               if h.Type != dnsmessage.TypeNS {
+                       if err := p.SkipAnswer(); err != nil {
+                               return nil, &DNSError{
+                                       Err:    "cannot unmarshal DNS message",
+                                       Name:   name,
+                                       Server: server,
+                               }
+                       }
+                       continue
+               }
+               ns, err := p.NSResource()
+               if err != nil {
+                       return nil, &DNSError{
+                               Err:    "cannot unmarshal DNS message",
+                               Name:   name,
+                               Server: server,
+                       }
+               }
+               nss = append(nss, &NS{Host: ns.NS.String()})
+       }
+       return nss, nil
+}
+
+// goLookupTXT returns the TXT records from name.
+func (r *Resolver) goLookupTXT(ctx context.Context, name string) ([]string, error) {
+       p, server, err := r.lookup(ctx, name, dnsmessage.TypeTXT)
+       if err != nil {
+               return nil, err
+       }
+       var txts []string
+       for {
+               h, err := p.AnswerHeader()
+               if err == dnsmessage.ErrSectionDone {
+                       break
+               }
+               if err != nil {
+                       return nil, &DNSError{
+                               Err:    "cannot unmarshal DNS message",
+                               Name:   name,
+                               Server: server,
+                       }
+               }
+               if h.Type != dnsmessage.TypeTXT {
+                       if err := p.SkipAnswer(); err != nil {
+                               return nil, &DNSError{
+                                       Err:    "cannot unmarshal DNS message",
+                                       Name:   name,
+                                       Server: server,
+                               }
+                       }
+                       continue
+               }
+               txt, err := p.TXTResource()
+               if err != nil {
+                       return nil, &DNSError{
+                               Err:    "cannot unmarshal DNS message",
+                               Name:   name,
+                               Server: server,
+                       }
+               }
+               // Multiple strings in one TXT record need to be
+               // concatenated without separator to be consistent
+               // with previous Go resolver.
+               n := 0
+               for _, s := range txt.TXT {
+                       n += len(s)
+               }
+               txtJoin := make([]byte, 0, n)
+               for _, s := range txt.TXT {
+                       txtJoin = append(txtJoin, s...)
+               }
+               if len(txts) == 0 {
+                       txts = make([]string, 0, 1)
+               }
+               txts = append(txts, string(txtJoin))
+       }
+       return txts, nil
+}
index d43a03b778d477a3327dfd00f96667a6f56b9507..445b1294e352dd683b91d25d9469c271c7845bec 100644 (file)
@@ -179,7 +179,27 @@ loop:
        return
 }
 
-func (r *Resolver) lookupIP(ctx context.Context, _, host string) (addrs []IPAddr, err error) {
+// preferGoOverPlan9 reports whether the resolver should use the
+// "PreferGo" implementation rather than asking plan9 services
+// for the answers.
+func (r *Resolver) preferGoOverPlan9() bool {
+       conf := systemConf()
+       order := conf.hostLookupOrder(r, "") // name is unused
+
+       // TODO(bradfitz): for now we only permit use of the PreferGo
+       // implementation when there's a non-nil Resolver with a
+       // non-nil Dialer. This is a sign that they the code is trying
+       // to use their DNS-speaking net.Conn (such as an in-memory
+       // DNS cache) and they don't want to actually hit the network.
+       // Once we add support for looking the default DNS servers
+       // from plan9, though, then we can relax this.
+       return order != hostLookupCgo && r != nil && r.Dial != nil
+}
+
+func (r *Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
+       if r.preferGoOverPlan9() {
+               return r.goLookupIP(ctx, network, host)
+       }
        lits, err := r.lookupHost(ctx, host)
        if err != nil {
                return
@@ -223,7 +243,10 @@ func (*Resolver) lookupPort(ctx context.Context, network, service string) (port
        return 0, unknownPortError
 }
 
-func (*Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) {
+func (r *Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) {
+       if r.preferGoOverPlan9() {
+               return r.goLookupCNAME(ctx, name)
+       }
        lines, err := queryDNS(ctx, name, "cname")
        if err != nil {
                if stringsHasSuffix(err.Error(), "dns failure") || stringsHasSuffix(err.Error(), "resource does not exist; negrcode 0") {
@@ -240,7 +263,10 @@ func (*Resolver) lookupCNAME(ctx context.Context, name string) (cname string, er
        return "", errors.New("bad response from ndb/dns")
 }
 
-func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) {
+func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) {
+       if r.preferGoOverPlan9() {
+               return r.goLookupSRV(ctx, service, proto, name)
+       }
        var target string
        if service == "" && proto == "" {
                target = name
@@ -269,7 +295,10 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cn
        return
 }
 
-func (*Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error) {
+func (r *Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error) {
+       if r.preferGoOverPlan9() {
+               return r.goLookupMX(ctx, name)
+       }
        lines, err := queryDNS(ctx, name, "mx")
        if err != nil {
                return
@@ -287,7 +316,10 @@ func (*Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error
        return
 }
 
-func (*Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error) {
+func (r *Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error) {
+       if r.preferGoOverPlan9() {
+               return r.goLookupNS(ctx, name)
+       }
        lines, err := queryDNS(ctx, name, "ns")
        if err != nil {
                return
@@ -302,7 +334,10 @@ func (*Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error
        return
 }
 
-func (*Resolver) lookupTXT(ctx context.Context, name string) (txt []string, err error) {
+func (r *Resolver) lookupTXT(ctx context.Context, name string) (txt []string, err error) {
+       if r.preferGoOverPlan9() {
+               return r.goLookupTXT(ctx, name)
+       }
        lines, err := queryDNS(ctx, name, "txt")
        if err != nil {
                return
@@ -315,7 +350,10 @@ func (*Resolver) lookupTXT(ctx context.Context, name string) (txt []string, err
        return
 }
 
-func (*Resolver) lookupAddr(ctx context.Context, addr string) (name []string, err error) {
+func (r *Resolver) lookupAddr(ctx context.Context, addr string) (name []string, err error) {
+       if r.preferGoOverPlan9() {
+               return r.goLookupPTR(ctx, addr)
+       }
        arpa, err := reverseaddr(addr)
        if err != nil {
                return
index ad4164d86517adbb659ac386474d6b598bde8ea9..4b885e938a06fe539db6eee7bc389ed22f845c06 100644 (file)
@@ -11,8 +11,6 @@ import (
        "internal/bytealg"
        "sync"
        "syscall"
-
-       "golang.org/x/net/dns/dnsmessage"
 )
 
 var onceReadProtocols sync.Once
@@ -55,26 +53,6 @@ func lookupProtocol(_ context.Context, name string) (int, error) {
        return lookupProtocolMap(name)
 }
 
-func (r *Resolver) dial(ctx context.Context, network, server string) (Conn, error) {
-       // Calling Dial here is scary -- we have to be sure not to
-       // dial a name that will require a DNS lookup, or Dial will
-       // call back here to translate it. The DNS config parser has
-       // already checked that all the cfg.servers are IP
-       // addresses, which Dial will use without a DNS lookup.
-       var c Conn
-       var err error
-       if r != nil && r.Dial != nil {
-               c, err = r.Dial(ctx, network, server)
-       } else {
-               var d Dialer
-               c, err = d.DialContext(ctx, network, server)
-       }
-       if err != nil {
-               return nil, mapErr(err)
-       }
-       return c, nil
-}
-
 func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
        order := systemConf().hostLookupOrder(r, host)
        if !r.preferGo() && order == hostLookupCgo {
@@ -129,194 +107,19 @@ func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error)
 }
 
 func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
-       var target string
-       if service == "" && proto == "" {
-               target = name
-       } else {
-               target = "_" + service + "._" + proto + "." + name
-       }
-       p, server, err := r.lookup(ctx, target, dnsmessage.TypeSRV)
-       if err != nil {
-               return "", nil, err
-       }
-       var srvs []*SRV
-       var cname dnsmessage.Name
-       for {
-               h, err := p.AnswerHeader()
-               if err == dnsmessage.ErrSectionDone {
-                       break
-               }
-               if err != nil {
-                       return "", nil, &DNSError{
-                               Err:    "cannot unmarshal DNS message",
-                               Name:   name,
-                               Server: server,
-                       }
-               }
-               if h.Type != dnsmessage.TypeSRV {
-                       if err := p.SkipAnswer(); err != nil {
-                               return "", nil, &DNSError{
-                                       Err:    "cannot unmarshal DNS message",
-                                       Name:   name,
-                                       Server: server,
-                               }
-                       }
-                       continue
-               }
-               if cname.Length == 0 && h.Name.Length != 0 {
-                       cname = h.Name
-               }
-               srv, err := p.SRVResource()
-               if err != nil {
-                       return "", nil, &DNSError{
-                               Err:    "cannot unmarshal DNS message",
-                               Name:   name,
-                               Server: server,
-                       }
-               }
-               srvs = append(srvs, &SRV{Target: srv.Target.String(), Port: srv.Port, Priority: srv.Priority, Weight: srv.Weight})
-       }
-       byPriorityWeight(srvs).sort()
-       return cname.String(), srvs, nil
+       return r.goLookupSRV(ctx, service, proto, name)
 }
 
 func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
-       p, server, err := r.lookup(ctx, name, dnsmessage.TypeMX)
-       if err != nil {
-               return nil, err
-       }
-       var mxs []*MX
-       for {
-               h, err := p.AnswerHeader()
-               if err == dnsmessage.ErrSectionDone {
-                       break
-               }
-               if err != nil {
-                       return nil, &DNSError{
-                               Err:    "cannot unmarshal DNS message",
-                               Name:   name,
-                               Server: server,
-                       }
-               }
-               if h.Type != dnsmessage.TypeMX {
-                       if err := p.SkipAnswer(); err != nil {
-                               return nil, &DNSError{
-                                       Err:    "cannot unmarshal DNS message",
-                                       Name:   name,
-                                       Server: server,
-                               }
-                       }
-                       continue
-               }
-               mx, err := p.MXResource()
-               if err != nil {
-                       return nil, &DNSError{
-                               Err:    "cannot unmarshal DNS message",
-                               Name:   name,
-                               Server: server,
-                       }
-               }
-               mxs = append(mxs, &MX{Host: mx.MX.String(), Pref: mx.Pref})
-
-       }
-       byPref(mxs).sort()
-       return mxs, nil
+       return r.goLookupMX(ctx, name)
 }
 
 func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
-       p, server, err := r.lookup(ctx, name, dnsmessage.TypeNS)
-       if err != nil {
-               return nil, err
-       }
-       var nss []*NS
-       for {
-               h, err := p.AnswerHeader()
-               if err == dnsmessage.ErrSectionDone {
-                       break
-               }
-               if err != nil {
-                       return nil, &DNSError{
-                               Err:    "cannot unmarshal DNS message",
-                               Name:   name,
-                               Server: server,
-                       }
-               }
-               if h.Type != dnsmessage.TypeNS {
-                       if err := p.SkipAnswer(); err != nil {
-                               return nil, &DNSError{
-                                       Err:    "cannot unmarshal DNS message",
-                                       Name:   name,
-                                       Server: server,
-                               }
-                       }
-                       continue
-               }
-               ns, err := p.NSResource()
-               if err != nil {
-                       return nil, &DNSError{
-                               Err:    "cannot unmarshal DNS message",
-                               Name:   name,
-                               Server: server,
-                       }
-               }
-               nss = append(nss, &NS{Host: ns.NS.String()})
-       }
-       return nss, nil
+       return r.goLookupNS(ctx, name)
 }
 
 func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
-       p, server, err := r.lookup(ctx, name, dnsmessage.TypeTXT)
-       if err != nil {
-               return nil, err
-       }
-       var txts []string
-       for {
-               h, err := p.AnswerHeader()
-               if err == dnsmessage.ErrSectionDone {
-                       break
-               }
-               if err != nil {
-                       return nil, &DNSError{
-                               Err:    "cannot unmarshal DNS message",
-                               Name:   name,
-                               Server: server,
-                       }
-               }
-               if h.Type != dnsmessage.TypeTXT {
-                       if err := p.SkipAnswer(); err != nil {
-                               return nil, &DNSError{
-                                       Err:    "cannot unmarshal DNS message",
-                                       Name:   name,
-                                       Server: server,
-                               }
-                       }
-                       continue
-               }
-               txt, err := p.TXTResource()
-               if err != nil {
-                       return nil, &DNSError{
-                               Err:    "cannot unmarshal DNS message",
-                               Name:   name,
-                               Server: server,
-                       }
-               }
-               // Multiple strings in one TXT record need to be
-               // concatenated without separator to be consistent
-               // with previous Go resolver.
-               n := 0
-               for _, s := range txt.TXT {
-                       n += len(s)
-               }
-               txtJoin := make([]byte, 0, n)
-               for _, s := range txt.TXT {
-                       txtJoin = append(txtJoin, s...)
-               }
-               if len(txts) == 0 {
-                       txts = make([]string, 0, 1)
-               }
-               txts = append(txts, string(txtJoin))
-       }
-       return txts, nil
+       return r.goLookupTXT(ctx, name)
 }
 
 func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) {
index 27e5f86910e0fcff75b14b80731b4bebb46d21d4..051f47da392c3ecf426c27bd4c176ce737ec7c8c 100644 (file)
@@ -82,7 +82,19 @@ func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error
        return addrs, nil
 }
 
+// preferGoOverWindows reports whether the resolver should use the
+// pure Go implementation rather than making win32 calls to ask the
+// kernel for its answer.
+func (r *Resolver) preferGoOverWindows() bool {
+       conf := systemConf()
+       order := conf.hostLookupOrder(r, "") // name is unused
+       return order != hostLookupCgo
+}
+
 func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr, error) {
+       if r.preferGoOverWindows() {
+               return r.goLookupIP(ctx, network, name)
+       }
        // TODO(bradfitz,brainman): use ctx more. See TODO below.
 
        var family int32 = syscall.AF_UNSPEC
@@ -169,7 +181,7 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr
 }
 
 func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
-       if r.preferGo() {
+       if r.preferGoOverWindows() {
                return lookupPortMap(network, service)
        }
 
@@ -217,12 +229,15 @@ func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int
        return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
 }
 
-func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
+func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
+       if r.preferGoOverWindows() {
+               return r.goLookupCNAME(ctx, name)
+       }
        // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
        acquireThread()
        defer releaseThread()
-       var r *syscall.DNSRecord
-       e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
+       var rec *syscall.DNSRecord
+       e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &rec, nil)
        // windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
        if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
                // if there are no aliases, the canonical name is the input name
@@ -231,14 +246,17 @@ func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
        if e != nil {
                return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
        }
-       defer syscall.DnsRecordListFree(r, 1)
+       defer syscall.DnsRecordListFree(rec, 1)
 
-       resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
+       resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), rec)
        cname := windows.UTF16PtrToString(resolved)
        return absDomainName(cname), nil
 }
 
-func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
+func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
+       if r.preferGoOverWindows() {
+               return r.goLookupSRV(ctx, service, proto, name)
+       }
        // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
        acquireThread()
        defer releaseThread()
@@ -248,15 +266,15 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (st
        } else {
                target = "_" + service + "._" + proto + "." + name
        }
-       var r *syscall.DNSRecord
-       e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil)
+       var rec *syscall.DNSRecord
+       e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &rec, nil)
        if e != nil {
                return "", nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: target}
        }
-       defer syscall.DnsRecordListFree(r, 1)
+       defer syscall.DnsRecordListFree(rec, 1)
 
        srvs := make([]*SRV, 0, 10)
-       for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
+       for _, p := range validRecs(rec, syscall.DNS_TYPE_SRV, target) {
                v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
                srvs = append(srvs, &SRV{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:])), v.Port, v.Priority, v.Weight})
        }
@@ -264,19 +282,22 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (st
        return absDomainName(target), srvs, nil
 }
 
-func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
+func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
+       if r.preferGoOverWindows() {
+               return r.goLookupMX(ctx, name)
+       }
        // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
        acquireThread()
        defer releaseThread()
-       var r *syscall.DNSRecord
-       e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil)
+       var rec *syscall.DNSRecord
+       e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &rec, nil)
        if e != nil {
                return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
        }
-       defer syscall.DnsRecordListFree(r, 1)
+       defer syscall.DnsRecordListFree(rec, 1)
 
        mxs := make([]*MX, 0, 10)
-       for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
+       for _, p := range validRecs(rec, syscall.DNS_TYPE_MX, name) {
                v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
                mxs = append(mxs, &MX{absDomainName(windows.UTF16PtrToString(v.NameExchange)), v.Preference})
        }
@@ -284,38 +305,44 @@ func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
        return mxs, nil
 }
 
-func (*Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
+func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
+       if r.preferGoOverWindows() {
+               return r.goLookupNS(ctx, name)
+       }
        // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
        acquireThread()
        defer releaseThread()
-       var r *syscall.DNSRecord
-       e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil)
+       var rec *syscall.DNSRecord
+       e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &rec, nil)
        if e != nil {
                return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
        }
-       defer syscall.DnsRecordListFree(r, 1)
+       defer syscall.DnsRecordListFree(rec, 1)
 
        nss := make([]*NS, 0, 10)
-       for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
+       for _, p := range validRecs(rec, syscall.DNS_TYPE_NS, name) {
                v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
                nss = append(nss, &NS{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))})
        }
        return nss, nil
 }
 
-func (*Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
+func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
+       if r.preferGoOverWindows() {
+               return r.lookupTXT(ctx, name)
+       }
        // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
        acquireThread()
        defer releaseThread()
-       var r *syscall.DNSRecord
-       e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil)
+       var rec *syscall.DNSRecord
+       e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &rec, nil)
        if e != nil {
                return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
        }
-       defer syscall.DnsRecordListFree(r, 1)
+       defer syscall.DnsRecordListFree(rec, 1)
 
        txts := make([]string, 0, 10)
-       for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) {
+       for _, p := range validRecs(rec, syscall.DNS_TYPE_TEXT, name) {
                d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
                s := ""
                for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount:d.StringCount] {
@@ -326,7 +353,11 @@ func (*Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
        return txts, nil
 }
 
-func (*Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) {
+func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) {
+       if r.preferGoOverWindows() {
+               return r.goLookupPTR(ctx, addr)
+       }
+
        // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
        acquireThread()
        defer releaseThread()
@@ -334,15 +365,15 @@ func (*Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error)
        if err != nil {
                return nil, err
        }
-       var r *syscall.DNSRecord
-       e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil)
+       var rec *syscall.DNSRecord
+       e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &rec, nil)
        if e != nil {
                return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: addr}
        }
-       defer syscall.DnsRecordListFree(r, 1)
+       defer syscall.DnsRecordListFree(rec, 1)
 
        ptrs := make([]string, 0, 10)
-       for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
+       for _, p := range validRecs(rec, syscall.DNS_TYPE_PTR, arpa) {
                v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
                ptrs = append(ptrs, absDomainName(windows.UTF16PtrToString(v.Host)))
        }
index 759d5d8aa937a96613b9d4c8ca1093ca0ef79ed7..ff56c31c5634370e4a895e6bb2560fe9d2c75388 100644 (file)
@@ -61,7 +61,7 @@ The resolver decision can be overridden by setting the netdns value of the
 GODEBUG environment variable (see package runtime) to go or cgo, as in:
 
        export GODEBUG=netdns=go    # force pure Go resolver
-       export GODEBUG=netdns=cgo   # force cgo resolver
+       export GODEBUG=netdns=cgo   # force native resolver (cgo, win32)
 
 The decision can also be forced while building the Go source tree
 by setting the netgo or netcgo build tag.
@@ -73,7 +73,8 @@ join the two settings by a plus sign, as in GODEBUG=netdns=go+1.
 
 On Plan 9, the resolver always accesses /net/cs and /net/dns.
 
-On Windows, the resolver always uses C library functions, such as GetAddrInfo and DnsQuery.
+On Windows, in Go 1.18.x and earlier, the resolver always used C
+library functions, such as GetAddrInfo and DnsQuery.
 */
 package net
 
index ee5644c67f087f77c021aa747a61c5f368b18b1e..6d07d6297a4c56fcdae9de3ea04f8e13bba8f3fd 100644 (file)
@@ -16,6 +16,8 @@ import (
        "sync"
        "syscall"
        "time"
+
+       "golang.org/x/net/dns/dnsmessage"
 )
 
 var listenersMu sync.Mutex
@@ -314,3 +316,7 @@ func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
 func (fd *netFD) dup() (f *os.File, err error) {
        return nil, syscall.ENOSYS
 }
+
+func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
+       panic("unreachable")
+}
diff --git a/src/net/netgo.go b/src/net/netgo.go
new file mode 100644 (file)
index 0000000..f91c91b
--- /dev/null
@@ -0,0 +1,9 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build netgo
+
+package net
+
+func init() { netGo = true }
index 5df71dc268fad5bf98e7cc16242563477eb9a24b..c4c608fb61864760be09654e17c4df961fcd6cdc 100644 (file)
@@ -2,8 +2,6 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build unix
-
 package net
 
 import (
diff --git a/src/net/resolverdialfunc_test.go b/src/net/resolverdialfunc_test.go
new file mode 100644 (file)
index 0000000..034c636
--- /dev/null
@@ -0,0 +1,328 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js
+
+// Test that Resolver.Dial can be a func returning an in-memory net.Conn
+// speaking DNS.
+
+package net
+
+import (
+       "bytes"
+       "context"
+       "errors"
+       "fmt"
+       "reflect"
+       "sort"
+       "testing"
+       "time"
+
+       "golang.org/x/net/dns/dnsmessage"
+)
+
+func TestResolverDialFunc(t *testing.T) {
+       r := &Resolver{
+               PreferGo: true,
+               Dial: newResolverDialFunc(&resolverDialHandler{
+                       StartDial: func(network, address string) error {
+                               t.Logf("StartDial(%q, %q) ...", network, address)
+                               return nil
+                       },
+                       Question: func(h dnsmessage.Header, q dnsmessage.Question) {
+                               t.Logf("Header: %+v for %q (type=%v, class=%v)", h,
+                                       q.Name.String(), q.Type, q.Class)
+                       },
+                       // TODO: add test without HandleA* hooks specified at all, that Go
+                       // doesn't issue retries; map to something terminal.
+                       HandleA: func(w AWriter, name string) error {
+                               w.AddIP([4]byte{1, 2, 3, 4})
+                               w.AddIP([4]byte{5, 6, 7, 8})
+                               return nil
+                       },
+                       HandleAAAA: func(w AAAAWriter, name string) error {
+                               w.AddIP([16]byte{1: 1, 15: 15})
+                               w.AddIP([16]byte{2: 2, 14: 14})
+                               return nil
+                       },
+                       HandleSRV: func(w SRVWriter, name string) error {
+                               w.AddSRV(1, 2, 80, "foo.bar.")
+                               w.AddSRV(2, 3, 81, "bar.baz.")
+                               return nil
+                       },
+               }),
+       }
+       ctx := context.Background()
+       const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld."
+
+       t.Run("LookupIP", func(t *testing.T) {
+               ips, err := r.LookupIP(ctx, "ip", fakeDomain)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !reflect.DeepEqual(got, want) {
+                       t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want)
+               }
+       })
+
+       t.Run("LookupSRV", func(t *testing.T) {
+               _, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               want := []*SRV{
+                       {
+                               Target:   "foo.bar.",
+                               Port:     80,
+                               Priority: 1,
+                               Weight:   2,
+                       },
+                       {
+                               Target:   "bar.baz.",
+                               Port:     81,
+                               Priority: 2,
+                               Weight:   3,
+                       },
+               }
+               if !reflect.DeepEqual(got, want) {
+                       t.Errorf("wrong result. got:")
+                       for _, r := range got {
+                               t.Logf("  - %+v", r)
+                       }
+               }
+       })
+}
+
+func sortedIPStrings(ips []IP) []string {
+       ret := make([]string, len(ips))
+       for i, ip := range ips {
+               ret[i] = ip.String()
+       }
+       sort.Strings(ret)
+       return ret
+}
+
+func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) {
+       return func(ctx context.Context, network, address string) (Conn, error) {
+               a := &resolverFuncConn{
+                       h:       h,
+                       network: network,
+                       address: address,
+                       ttl:     10, // 10 second default if unset
+               }
+               if h.StartDial != nil {
+                       if err := h.StartDial(network, address); err != nil {
+                               return nil, err
+                       }
+               }
+               return a, nil
+       }
+}
+
+type resolverDialHandler struct {
+       // StartDial, if non-nil, is called when Go first calls Resolver.Dial.
+       // Any error returned aborts the dial and is returned unwrapped.
+       StartDial func(network, address string) error
+
+       Question func(dnsmessage.Header, dnsmessage.Question)
+
+       // err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2).
+       // A nil error means success.
+       HandleA    func(w AWriter, name string) error
+       HandleAAAA func(w AAAAWriter, name string) error
+       HandleSRV  func(w SRVWriter, name string) error
+}
+
+type ResponseWriter struct{ a *resolverFuncConn }
+
+func (w ResponseWriter) header() dnsmessage.ResourceHeader {
+       q := w.a.q
+       return dnsmessage.ResourceHeader{
+               Name:  q.Name,
+               Type:  q.Type,
+               Class: q.Class,
+               TTL:   w.a.ttl,
+       }
+}
+
+// SetTTL sets the TTL for subsequent written resources.
+// Once a resource has been written, SetTTL calls are no-ops.
+// That is, it can only be called at most once, before anything
+// else is written.
+func (w ResponseWriter) SetTTL(seconds uint32) {
+       // ... intention is last one wins and mutates all previously
+       // written records too, but that's a little annoying.
+       // But it's also annoying if the requirement is it needs to be set
+       // last.
+       // And it's also annoying if it's possible for users to set
+       // different TTLs per Answer.
+       if w.a.wrote {
+               return
+       }
+       w.a.ttl = seconds
+
+}
+
+type AWriter struct{ ResponseWriter }
+
+func (w AWriter) AddIP(v4 [4]byte) {
+       w.a.wrote = true
+       err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4})
+       if err != nil {
+               panic(err)
+       }
+}
+
+type AAAAWriter struct{ ResponseWriter }
+
+func (w AAAAWriter) AddIP(v6 [16]byte) {
+       w.a.wrote = true
+       err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6})
+       if err != nil {
+               panic(err)
+       }
+}
+
+type SRVWriter struct{ ResponseWriter }
+
+// AddSRV adds a SRV record. The target name must end in a period and
+// be 63 bytes or fewer.
+func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error {
+       targetName, err := dnsmessage.NewName(target)
+       if err != nil {
+               return err
+       }
+       w.a.wrote = true
+       err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{
+               Priority: priority,
+               Weight:   weight,
+               Port:     port,
+               Target:   targetName,
+       })
+       if err != nil {
+               panic(err) // internal fault, not user
+       }
+       return nil
+}
+
+var (
+       ErrNotExist = errors.New("name does not exist") // maps to RCode3, NXDOMAIN
+       ErrRefused  = errors.New("refused")             // maps to RCode5, REFUSED
+)
+
+type resolverFuncConn struct {
+       h       *resolverDialHandler
+       ctx     context.Context
+       network string
+       address string
+       builder *dnsmessage.Builder
+       q       dnsmessage.Question
+       ttl     uint32
+       wrote   bool
+
+       rbuf bytes.Buffer
+}
+
+func (*resolverFuncConn) Close() error                       { return nil }
+func (*resolverFuncConn) LocalAddr() Addr                    { return someaddr{} }
+func (*resolverFuncConn) RemoteAddr() Addr                   { return someaddr{} }
+func (*resolverFuncConn) SetDeadline(t time.Time) error      { return nil }
+func (*resolverFuncConn) SetReadDeadline(t time.Time) error  { return nil }
+func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil }
+
+func (a *resolverFuncConn) Read(p []byte) (n int, err error) {
+       return a.rbuf.Read(p)
+}
+
+func (a *resolverFuncConn) Write(packet []byte) (n int, err error) {
+       if len(packet) < 2 {
+               return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet))
+       }
+       reqLen := int(packet[0])<<8 | int(packet[1])
+       req := packet[2:]
+       if len(req) != reqLen {
+               return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req))
+       }
+
+       var parser dnsmessage.Parser
+       h, err := parser.Start(req)
+       if err != nil {
+               // TODO: hook
+               return 0, err
+       }
+       q, err := parser.Question()
+       hadQ := (err == nil)
+       if err == nil && a.h.Question != nil {
+               a.h.Question(h, q)
+       }
+       if err != nil && err != dnsmessage.ErrSectionDone {
+               return 0, err
+       }
+
+       resh := h
+       resh.Response = true
+       resh.Authoritative = true
+       if hadQ {
+               resh.RCode = dnsmessage.RCodeSuccess
+       } else {
+               resh.RCode = dnsmessage.RCodeNotImplemented
+       }
+       a.rbuf.Grow(514)
+       a.rbuf.WriteByte('X') // reserved header for beu16 length
+       a.rbuf.WriteByte('Y') // reserved header for beu16 length
+       builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh)
+       a.builder = &builder
+       if hadQ {
+               a.q = q
+               a.builder.StartQuestions()
+               err := a.builder.Question(q)
+               if err != nil {
+                       return 0, fmt.Errorf("Question: %w", err)
+               }
+               a.builder.StartAnswers()
+               switch q.Type {
+               case dnsmessage.TypeA:
+                       if a.h.HandleA != nil {
+                               resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String()))
+                       }
+               case dnsmessage.TypeAAAA:
+                       if a.h.HandleAAAA != nil {
+                               resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String()))
+                       }
+               case dnsmessage.TypeSRV:
+                       if a.h.HandleSRV != nil {
+                               resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String()))
+                       }
+               }
+       }
+       tcpRes, err := builder.Finish()
+       if err != nil {
+               return 0, fmt.Errorf("Finish: %w", err)
+       }
+
+       n = len(tcpRes) - 2
+       tcpRes[0] = byte(n >> 8)
+       tcpRes[1] = byte(n)
+       a.rbuf.Write(tcpRes[2:])
+
+       return len(packet), nil
+}
+
+type someaddr struct{}
+
+func (someaddr) Network() string { return "unused" }
+func (someaddr) String() string  { return "unused-someaddr" }
+
+func mapRCode(err error) dnsmessage.RCode {
+       switch err {
+       case nil:
+               return dnsmessage.RCodeSuccess
+       case ErrNotExist:
+               return dnsmessage.RCodeNameError
+       case ErrRefused:
+               return dnsmessage.RCodeRefused
+       default:
+               return dnsmessage.RCodeServerFailure
+       }
+}