]> Cypherpunks repositories - gostls13.git/commitdiff
net: allow Resolver to use a custom dialer
authorMatt Harden <matt.harden@gmail.com>
Mon, 20 Feb 2017 13:58:55 +0000 (05:58 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 12 May 2017 18:08:12 +0000 (18:08 +0000)
In some cases it is desirable to customize the way the DNS server is
contacted, for instance to use a specific LocalAddr. While most
operating-system level resolvers do not allow this, we have the
opportunity to do so with the Go resolver. Most of the code was
already in place to allow tests to override the dialer. This exposes
that functionality, and as a side effect eliminates the need for a
testing hook.

Fixes #17404

Change-Id: I1c5e570f8edbcf630090f8ec6feb52e379e3e5c0
Reviewed-on: https://go-review.googlesource.com/37260
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/go/build/deps_test.go
src/net/dnsclient_unix.go
src/net/dnsclient_unix_test.go
src/net/lookup.go
src/net/lookup_unix.go

index 2cceb5a2e28735ad4ffe94acc2513dad54184f29..ec8dd0678811fa3393f36cb6e2da73431e50c756 100644 (file)
@@ -304,7 +304,7 @@ var pkgDeps = map[string][]string{
        // do networking portably, it must have a small dependency set: just L0+basic os.
        "net": {
                "L0", "CGO",
-               "context", "math/rand", "os", "sort", "syscall", "time",
+               "context", "math/rand", "os", "reflect", "sort", "syscall", "time",
                "internal/nettrace", "internal/poll",
                "internal/syscall/windows", "internal/singleflight", "internal/race",
                "golang_org/x/net/lif", "golang_org/x/net/route",
index 6613dc7593c4710c2ee8c6af616caa757b13d941..75d70d3989377c5e160916d8109c4f2d220f96ca 100644 (file)
@@ -25,13 +25,6 @@ import (
        "time"
 )
 
-// A dnsDialer provides dialing suitable for DNS queries.
-type dnsDialer interface {
-       dialDNS(ctx context.Context, network, addr string) (dnsConn, error)
-}
-
-var testHookDNSDialer = func() dnsDialer { return &Dialer{} }
-
 // A dnsConn represents a DNS transport endpoint.
 type dnsConn interface {
        io.Closer
@@ -116,33 +109,8 @@ func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
        return resp, nil
 }
 
-func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) {
-       switch network {
-       case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
-       default:
-               return nil, UnknownNetworkError(network)
-       }
-       // Calling Dial here is scary -- we have to be sure not to
-       // dial a name that will require a DNS lookup, or Dial will
-       // call back here to translate it. The DNS config parser has
-       // already checked that all the cfg.servers are IP
-       // addresses, which Dial will use without a DNS lookup.
-       c, err := d.DialContext(ctx, network, server)
-       if err != nil {
-               return nil, mapErr(err)
-       }
-       switch network {
-       case "tcp", "tcp4", "tcp6":
-               return c.(*TCPConn), nil
-       case "udp", "udp4", "udp6":
-               return c.(*UDPConn), nil
-       }
-       panic("unreachable")
-}
-
 // exchange sends a query on the connection and hopes for a response.
-func exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
-       d := testHookDNSDialer()
+func (r *Resolver) exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
        out := dnsMsg{
                dnsMsgHdr: dnsMsgHdr{
                        recursion_desired: true,
@@ -158,7 +126,7 @@ func exchange(ctx context.Context, server, name string, qtype uint16, timeout ti
                ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
                defer cancel()
 
-               c, err := d.dialDNS(ctx, network, server)
+               c, err := r.dial(ctx, network, server)
                if err != nil {
                        return nil, err
                }
@@ -181,7 +149,7 @@ func exchange(ctx context.Context, server, name string, qtype uint16, timeout ti
 
 // Do a lookup for a single name, which must be rooted
 // (otherwise answer will not find the answers).
-func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
+func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
        var lastErr error
        serverOffset := cfg.serverOffset()
        sLen := uint32(len(cfg.servers))
@@ -190,7 +158,7 @@ func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16)
                for j := uint32(0); j < sLen; j++ {
                        server := cfg.servers[(serverOffset+j)%sLen]
 
-                       msg, err := exchange(ctx, server, name, qtype, cfg.timeout)
+                       msg, err := r.exchange(ctx, server, name, qtype, cfg.timeout)
                        if err != nil {
                                lastErr = &DNSError{
                                        Err:    err.Error(),
@@ -333,7 +301,7 @@ func (r *Resolver) lookup(ctx context.Context, name string, qtype uint16) (cname
        conf := resolvConf.dnsConfig
        resolvConf.mu.RUnlock()
        for _, fqdn := range conf.nameList(name) {
-               cname, rrs, err = tryOneName(ctx, conf, fqdn, qtype)
+               cname, rrs, err = r.tryOneName(ctx, conf, fqdn, qtype)
                if err == nil {
                        break
                }
@@ -512,7 +480,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order
        for _, fqdn := range conf.nameList(name) {
                for _, qtype := range qtypes {
                        go func(qtype uint16) {
-                               cname, rrs, err := tryOneName(ctx, conf, fqdn, qtype)
+                               cname, rrs, err := r.tryOneName(ctx, conf, fqdn, qtype)
                                lane <- racer{cname, rrs, err}
                        }(qtype)
                }
index a23e5f622225bea52b6aef431db9ae3a39699040..d0ac4302b105e00d8ac781ce37122597eb14d468 100644 (file)
@@ -10,7 +10,6 @@ import (
        "context"
        "fmt"
        "internal/poll"
-       "internal/testenv"
        "io/ioutil"
        "os"
        "path"
@@ -21,6 +20,8 @@ import (
        "time"
 )
 
+var goResolver = Resolver{PreferGo: true}
+
 // Test address from 192.0.2.0/24 block, reserved by RFC 5737 for documentation.
 const TestAddr uint32 = 0xc0000201
 
@@ -41,18 +42,30 @@ var dnsTransportFallbackTests = []struct {
 }
 
 func TestDNSTransportFallback(t *testing.T) {
-       testenv.MustHaveExternalNetwork(t)
-
+       fake := fakeDNSServer{
+               rh: func(n, _ string, _ *dnsMsg, _ time.Time) (*dnsMsg, error) {
+                       r := &dnsMsg{
+                               dnsMsgHdr: dnsMsgHdr{
+                                       rcode: dnsRcodeSuccess,
+                               },
+                       }
+                       if n == "udp" {
+                               r.truncated = true
+                       }
+                       return r, nil
+               },
+       }
+       r := Resolver{PreferGo: true, Dial: fake.DialContext}
        for _, tt := range dnsTransportFallbackTests {
                ctx, cancel := context.WithCancel(context.Background())
                defer cancel()
-               msg, err := exchange(ctx, tt.server, tt.name, tt.qtype, time.Second)
+               msg, err := r.exchange(ctx, tt.server, tt.name, tt.qtype, time.Second)
                if err != nil {
                        t.Error(err)
                        continue
                }
                switch msg.rcode {
-               case tt.rcode, dnsRcodeServerFailure:
+               case tt.rcode:
                default:
                        t.Errorf("got %v from %v; want %v", msg.rcode, tt.server, tt.rcode)
                        continue
@@ -82,13 +95,28 @@ var specialDomainNameTests = []struct {
 }
 
 func TestSpecialDomainName(t *testing.T) {
-       testenv.MustHaveExternalNetwork(t)
+       fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
+               r := &dnsMsg{
+                       dnsMsgHdr: dnsMsgHdr{
+                               id: q.id,
+                       },
+               }
 
+               switch q.question[0].Name {
+               case "example.com.":
+                       r.rcode = dnsRcodeSuccess
+               default:
+                       r.rcode = dnsRcodeNameError
+               }
+
+               return r, nil
+       }}
+       r := Resolver{PreferGo: true, Dial: fake.DialContext}
        server := "8.8.8.8:53"
        for _, tt := range specialDomainNameTests {
                ctx, cancel := context.WithCancel(context.Background())
                defer cancel()
-               msg, err := exchange(ctx, server, tt.name, tt.qtype, 3*time.Second)
+               msg, err := r.exchange(ctx, server, tt.name, tt.qtype, 3*time.Second)
                if err != nil {
                        t.Error(err)
                        continue
@@ -143,15 +171,40 @@ func TestAvoidDNSName(t *testing.T) {
        }
 }
 
+var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
+       r := &dnsMsg{
+               dnsMsgHdr: dnsMsgHdr{
+                       id:       q.id,
+                       response: true,
+               },
+               question: q.question,
+       }
+       if len(q.question) == 1 && q.question[0].Qtype == dnsTypeA {
+               r.answer = []dnsRR{
+                       &dnsRR_A{
+                               Hdr: dnsRR_Header{
+                                       Name:     q.question[0].Name,
+                                       Rrtype:   dnsTypeA,
+                                       Class:    dnsClassINET,
+                                       Rdlength: 4,
+                               },
+                               A: TestAddr,
+                       },
+               }
+       }
+       return r, nil
+}}
+
 // Issue 13705: don't try to resolve onion addresses, etc
 func TestLookupTorOnion(t *testing.T) {
-       addrs, err := DefaultResolver.goLookupIP(context.Background(), "foo.onion")
-       if len(addrs) > 0 {
-               t.Errorf("unexpected addresses: %v", addrs)
-       }
+       r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
+       addrs, err := r.LookupIPAddr(context.Background(), "foo.onion")
        if err != nil {
                t.Fatalf("lookup = %v; want nil", err)
        }
+       if len(addrs) > 0 {
+               t.Errorf("unexpected addresses: %v", addrs)
+       }
 }
 
 type resolvConfTest struct {
@@ -241,7 +294,7 @@ var updateResolvConfTests = []struct {
 }
 
 func TestUpdateResolvConf(t *testing.T) {
-       testenv.MustHaveExternalNetwork(t)
+       r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
 
        conf, err := newResolvConfTest()
        if err != nil {
@@ -261,7 +314,7 @@ func TestUpdateResolvConf(t *testing.T) {
                        for j := 0; j < N; j++ {
                                go func(name string) {
                                        defer wg.Done()
-                                       ips, err := DefaultResolver.goLookupIP(context.Background(), name)
+                                       ips, err := r.LookupIPAddr(context.Background(), name)
                                        if err != nil {
                                                t.Error(err)
                                                return
@@ -396,7 +449,60 @@ var goLookupIPWithResolverConfigTests = []struct {
 }
 
 func TestGoLookupIPWithResolverConfig(t *testing.T) {
-       testenv.MustHaveExternalNetwork(t)
+       fake := fakeDNSServer{func(n, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
+               switch s {
+               case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
+                       break
+               default:
+                       time.Sleep(10 * time.Millisecond)
+                       return nil, poll.ErrTimeout
+               }
+               r := &dnsMsg{
+                       dnsMsgHdr: dnsMsgHdr{
+                               id:       q.id,
+                               response: true,
+                       },
+                       question: q.question,
+               }
+               for _, question := range q.question {
+                       switch question.Qtype {
+                       case dnsTypeA:
+                               switch question.Name {
+                               case "hostname.as112.net.":
+                                       break
+                               case "ipv4.google.com.":
+                                       r.answer = append(r.answer, &dnsRR_A{
+                                               Hdr: dnsRR_Header{
+                                                       Name:     q.question[0].Name,
+                                                       Rrtype:   dnsTypeA,
+                                                       Class:    dnsClassINET,
+                                                       Rdlength: 4,
+                                               },
+                                               A: TestAddr,
+                                       })
+                               default:
+
+                               }
+                       case dnsTypeAAAA:
+                               switch question.Name {
+                               case "hostname.as112.net.":
+                                       break
+                               case "ipv6.google.com.":
+                                       r.answer = append(r.answer, &dnsRR_AAAA{
+                                               Hdr: dnsRR_Header{
+                                                       Name:     q.question[0].Name,
+                                                       Rrtype:   dnsTypeAAAA,
+                                                       Class:    dnsClassINET,
+                                                       Rdlength: 16,
+                                               },
+                                               AAAA: TestAddr6,
+                                       })
+                               }
+                       }
+               }
+               return r, nil
+       }}
+       r := Resolver{PreferGo: true, Dial: fake.DialContext}
 
        conf, err := newResolvConfTest()
        if err != nil {
@@ -409,14 +515,8 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
                        t.Error(err)
                        continue
                }
-               addrs, err := DefaultResolver.goLookupIP(context.Background(), tt.name)
+               addrs, err := r.LookupIPAddr(context.Background(), tt.name)
                if err != nil {
-                       // This test uses external network connectivity.
-                       // We need to take care with errors on both
-                       // DNS message exchange layer and DNS
-                       // transport layer because goLookupIP may fail
-                       // when the IP connectivity on node under test
-                       // gets lost during its run.
                        if err, ok := err.(*DNSError); !ok || tt.error != nil && (err.Name != tt.error.(*DNSError).Name || err.Server != tt.error.(*DNSError).Server || err.IsTimeout != tt.error.(*DNSError).IsTimeout) {
                                t.Errorf("got %v; want %v", err, tt.error)
                        }
@@ -441,7 +541,17 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
 
 // Test that goLookupIPOrder falls back to the host file when no DNS servers are available.
 func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
-       testenv.MustHaveExternalNetwork(t)
+       fake := fakeDNSServer{func(n, s string, q *dnsMsg, tm time.Time) (*dnsMsg, error) {
+               r := &dnsMsg{
+                       dnsMsgHdr: dnsMsgHdr{
+                               id:       q.id,
+                               response: true,
+                       },
+                       question: q.question,
+               }
+               return r, nil
+       }}
+       r := Resolver{PreferGo: true, Dial: fake.DialContext}
 
        // Add a config that simulates no dns servers being available.
        conf, err := newResolvConfTest()
@@ -459,14 +569,14 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
                name := fmt.Sprintf("order %v", order)
 
                // First ensure that we get an error when contacting a non-existent host.
-               _, _, err := DefaultResolver.goLookupIPCNAMEOrder(context.Background(), "notarealhost", order)
+               _, _, err := r.goLookupIPCNAMEOrder(context.Background(), "notarealhost", order)
                if err == nil {
                        t.Errorf("%s: expected error while looking up name not in hosts file", name)
                        continue
                }
 
                // Now check that we get an address when the name appears in the hosts file.
-               addrs, _, err := DefaultResolver.goLookupIPCNAMEOrder(context.Background(), "thor", order) // entry is in "testdata/hosts"
+               addrs, _, err := r.goLookupIPCNAMEOrder(context.Background(), "thor", order) // entry is in "testdata/hosts"
                if err != nil {
                        t.Errorf("%s: expected to successfully lookup host entry", name)
                        continue
@@ -489,9 +599,6 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
 func TestErrorForOriginalNameWhenSearching(t *testing.T) {
        const fqdn = "doesnotexist.domain"
 
-       origTestHookDNSDialer := testHookDNSDialer
-       defer func() { testHookDNSDialer = origTestHookDNSDialer }()
-
        conf, err := newResolvConfTest()
        if err != nil {
                t.Fatal(err)
@@ -502,10 +609,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
                t.Fatal(err)
        }
 
-       d := &fakeDNSDialer{}
-       testHookDNSDialer = func() dnsDialer { return d }
-
-       d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
+       fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
                r := &dnsMsg{
                        dnsMsgHdr: dnsMsgHdr{
                                id: q.id,
@@ -520,7 +624,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
                }
 
                return r, nil
-       }
+       }}
 
        cases := []struct {
                strictErrors bool
@@ -530,8 +634,8 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
                {false, &DNSError{Name: fqdn, Err: errNoSuchHost.Error()}},
        }
        for _, tt := range cases {
-               r := Resolver{StrictErrors: tt.strictErrors}
-               _, err = r.goLookupIP(context.Background(), fqdn)
+               r := Resolver{PreferGo: true, StrictErrors: tt.strictErrors, Dial: fake.DialContext}
+               _, err = r.LookupIPAddr(context.Background(), fqdn)
                if err == nil {
                        t.Fatal("expected an error")
                }
@@ -545,9 +649,6 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
 
 // Issue 15434. If a name server gives a lame referral, continue to the next.
 func TestIgnoreLameReferrals(t *testing.T) {
-       origTestHookDNSDialer := testHookDNSDialer
-       defer func() { testHookDNSDialer = origTestHookDNSDialer }()
-
        conf, err := newResolvConfTest()
        if err != nil {
                t.Fatal(err)
@@ -559,10 +660,7 @@ func TestIgnoreLameReferrals(t *testing.T) {
                t.Fatal(err)
        }
 
-       d := &fakeDNSDialer{}
-       testHookDNSDialer = func() dnsDialer { return d }
-
-       d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
+       fake := fakeDNSServer{func(_, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
                t.Log(s, q)
                r := &dnsMsg{
                        dnsMsgHdr: dnsMsgHdr{
@@ -590,9 +688,10 @@ func TestIgnoreLameReferrals(t *testing.T) {
                }
 
                return r, nil
-       }
+       }}
+       r := Resolver{PreferGo: true, Dial: fake.DialContext}
 
-       addrs, err := DefaultResolver.goLookupIP(context.Background(), "www.golang.org")
+       addrs, err := r.LookupIPAddr(context.Background(), "www.golang.org")
        if err != nil {
                t.Fatal(err)
        }
@@ -611,7 +710,7 @@ func BenchmarkGoLookupIP(b *testing.B) {
        ctx := context.Background()
 
        for i := 0; i < b.N; i++ {
-               DefaultResolver.goLookupIP(ctx, "www.example.com")
+               goResolver.LookupIPAddr(ctx, "www.example.com")
        }
 }
 
@@ -620,7 +719,7 @@ func BenchmarkGoLookupIPNoSuchHost(b *testing.B) {
        ctx := context.Background()
 
        for i := 0; i < b.N; i++ {
-               DefaultResolver.goLookupIP(ctx, "some.nonexistent")
+               goResolver.LookupIPAddr(ctx, "some.nonexistent")
        }
 }
 
@@ -643,23 +742,24 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
        ctx := context.Background()
 
        for i := 0; i < b.N; i++ {
-               DefaultResolver.goLookupIP(ctx, "www.example.com")
+               goResolver.LookupIPAddr(ctx, "www.example.com")
        }
 }
 
-type fakeDNSDialer struct {
-       // reply handler
-       rh func(s string, q *dnsMsg, t time.Time) (*dnsMsg, error)
+type fakeDNSServer struct {
+       rh func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error)
 }
 
-func (f *fakeDNSDialer) dialDNS(_ context.Context, n, s string) (dnsConn, error) {
-       return &fakeDNSConn{f.rh, s, time.Time{}}, nil
+func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
+       return &fakeDNSConn{nil, server, n, s, time.Time{}}, nil
 }
 
 type fakeDNSConn struct {
-       rh func(s string, q *dnsMsg, t time.Time) (*dnsMsg, error)
-       s  string
-       t  time.Time
+       Conn
+       server *fakeDNSServer
+       n      string
+       s      string
+       t      time.Time
 }
 
 func (f *fakeDNSConn) Close() error {
@@ -672,7 +772,7 @@ func (f *fakeDNSConn) SetDeadline(t time.Time) error {
 }
 
 func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) {
-       return f.rh(f.s, q, f.t)
+       return f.server.rh(f.n, f.s, q, f.t)
 }
 
 // UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
@@ -749,9 +849,6 @@ func TestIgnoreDNSForgeries(t *testing.T) {
 
 // Issue 16865. If a name server times out, continue to the next.
 func TestRetryTimeout(t *testing.T) {
-       origTestHookDNSDialer := testHookDNSDialer
-       defer func() { testHookDNSDialer = origTestHookDNSDialer }()
-
        conf, err := newResolvConfTest()
        if err != nil {
                t.Fatal(err)
@@ -766,12 +863,9 @@ func TestRetryTimeout(t *testing.T) {
                t.Fatal(err)
        }
 
-       d := &fakeDNSDialer{}
-       testHookDNSDialer = func() dnsDialer { return d }
-
        var deadline0 time.Time
 
-       d.rh = func(s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
+       fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
                t.Log(s, q, deadline)
 
                if deadline.IsZero() {
@@ -789,9 +883,10 @@ func TestRetryTimeout(t *testing.T) {
                }
 
                return mockTXTResponse(q), nil
-       }
+       }}
+       r := &Resolver{PreferGo: true, Dial: fake.DialContext}
 
-       _, err = LookupTXT("www.golang.org")
+       _, err = r.LookupTXT(context.Background(), "www.golang.org")
        if err != nil {
                t.Fatal(err)
        }
@@ -810,9 +905,6 @@ func TestRotate(t *testing.T) {
 }
 
 func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
-       origTestHookDNSDialer := testHookDNSDialer
-       defer func() { testHookDNSDialer = origTestHookDNSDialer }()
-
        conf, err := newResolvConfTest()
        if err != nil {
                t.Fatal(err)
@@ -831,18 +923,16 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
                t.Fatal(err)
        }
 
-       d := &fakeDNSDialer{}
-       testHookDNSDialer = func() dnsDialer { return d }
-
        var usedServers []string
-       d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
+       fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
                usedServers = append(usedServers, s)
                return mockTXTResponse(q), nil
-       }
+       }}
+       r := Resolver{PreferGo: true, Dial: fake.DialContext}
 
        // len(nameservers) + 1 to allow rotation to get back to start
        for i := 0; i < len(nameservers)+1; i++ {
-               if _, err := LookupTXT("www.golang.org"); err != nil {
+               if _, err := r.LookupTXT(context.Background(), "www.golang.org"); err != nil {
                        t.Fatal(err)
                }
        }
@@ -878,9 +968,6 @@ func mockTXTResponse(q *dnsMsg) *dnsMsg {
 // Issue 17448. With StrictErrors enabled, temporary errors should make
 // LookupIP fail rather than return a partial result.
 func TestStrictErrorsLookupIP(t *testing.T) {
-       origTestHookDNSDialer := testHookDNSDialer
-       defer func() { testHookDNSDialer = origTestHookDNSDialer }()
-
        conf, err := newResolvConfTest()
        if err != nil {
                t.Fatal(err)
@@ -1017,10 +1104,7 @@ func TestStrictErrorsLookupIP(t *testing.T) {
        }
 
        for i, tt := range cases {
-               d := &fakeDNSDialer{}
-               testHookDNSDialer = func() dnsDialer { return d }
-
-               d.rh = func(s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
+               fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
                        t.Log(s, q)
 
                        switch tt.resolveWhich(&q.question[0]) {
@@ -1082,11 +1166,11 @@ func TestStrictErrorsLookupIP(t *testing.T) {
                                return nil, fmt.Errorf("Unexpected Qtype: %v", q.question[0].Qtype)
                        }
                        return r, nil
-               }
+               }}
 
                for _, strict := range []bool{true, false} {
-                       r := Resolver{StrictErrors: strict}
-                       ips, err := r.goLookupIP(context.Background(), name)
+                       r := Resolver{PreferGo: true, StrictErrors: strict, Dial: fake.DialContext}
+                       ips, err := r.LookupIPAddr(context.Background(), name)
 
                        var wantErr error
                        if strict {
@@ -1118,9 +1202,6 @@ func TestStrictErrorsLookupIP(t *testing.T) {
 // Issue 17448. With StrictErrors enabled, temporary errors should make
 // LookupTXT stop walking the search list.
 func TestStrictErrorsLookupTXT(t *testing.T) {
-       origTestHookDNSDialer := testHookDNSDialer
-       defer func() { testHookDNSDialer = origTestHookDNSDialer }()
-
        conf, err := newResolvConfTest()
        if err != nil {
                t.Fatal(err)
@@ -1141,10 +1222,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
        const searchY = "test.y.golang.org."
        const txt = "Hello World"
 
-       d := &fakeDNSDialer{}
-       testHookDNSDialer = func() dnsDialer { return d }
-
-       d.rh = func(s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
+       fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
                t.Log(s, q)
 
                switch q.question[0].Name {
@@ -1155,10 +1233,10 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
                default:
                        return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name)
                }
-       }
+       }}
 
        for _, strict := range []bool{true, false} {
-               r := Resolver{StrictErrors: strict}
+               r := Resolver{StrictErrors: strict, Dial: fake.DialContext}
                _, rrs, err := r.lookup(context.Background(), name, dnsTypeTXT)
                var wantErr error
                var wantRRs int
index 463b374aff2166615ecbe010866e4906882e7088..818f91c3dc75c62f05a57cc222a81a78570ca06d 100644 (file)
@@ -107,6 +107,15 @@ type Resolver struct {
        // with resolvers that process AAAA queries incorrectly.
        StrictErrors bool
 
+       // Dial optionally specifies an alternate dialer for use by
+       // Go's built-in DNS resolver to make TCP and UDP connections
+       // to DNS services. The provided addr will always be an IP
+       // address and not a hostname.
+       // The Conn returned must be a *TCPConn or *UDPConn as
+       // requested by the network parameter. If nil, the default
+       // dialer is used.
+       Dial func(ctx context.Context, network, addr string) (Conn, error)
+
        // TODO(bradfitz): optional interface impl override hook
        // TODO(bradfitz): Timeout time.Duration?
 }
index 158cc94a99dfa77539b8b65096d42ebaac5e19d8..a485d706a585ad94b5e79a5d0665df8adfa9aac6 100644 (file)
@@ -8,6 +8,8 @@ package net
 
 import (
        "context"
+       "errors"
+       "reflect"
        "sync"
 )
 
@@ -51,6 +53,31 @@ func lookupProtocol(_ context.Context, name string) (int, error) {
        return lookupProtocolMap(name)
 }
 
+func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, error) {
+       // Calling Dial here is scary -- we have to be sure not to
+       // dial a name that will require a DNS lookup, or Dial will
+       // call back here to translate it. The DNS config parser has
+       // already checked that all the cfg.servers are IP
+       // addresses, which Dial will use without a DNS lookup.
+       var c Conn
+       var err error
+       if r.Dial != nil {
+               c, err = r.Dial(ctx, network, server)
+       } else {
+               var d Dialer
+               c, err = d.DialContext(ctx, network, server)
+       }
+       if err != nil {
+               return nil, mapErr(err)
+       }
+       dc, ok := c.(dnsConn)
+       if !ok {
+               c.Close()
+               return nil, errors.New("net: Resolver.Dial returned unsupported connection type " + reflect.TypeOf(c).String())
+       }
+       return dc, nil
+}
+
 func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
        order := systemConf().hostLookupOrder(host)
        if !r.PreferGo && order == hostLookupCgo {