]> Cypherpunks repositories - gostls13.git/commitdiff
net: clean up builtin DNS stub resolver, fix tests
authorMikio Hara <mikioh.mikioh@gmail.com>
Thu, 11 Jun 2015 03:46:01 +0000 (12:46 +0900)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 13 Jul 2015 19:29:25 +0000 (19:29 +0000)
This change does clean up as preparation for fixing #11081.

- renames cfg to resolvConf for clarification
- adds a new type resolverConfig and its methods: init, update,
  tryAcquireSema, releaseSema for mutual exclusion of resolv.conf data
- deflakes, simplifies tests for resolv.conf data; previously the tests
  sometimes left some garbage in the data

Change-Id: I277ced853fddc3791dde40ab54dbd5c78114b78c
Reviewed-on: https://go-review.googlesource.com/10931
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>

src/net/dnsclient_unix.go
src/net/dnsclient_unix_test.go

index 8a1745f3cb9b091d5c6c49de3c3718d306193009..8f636055ab936f2020a0b8c7257d9fb74a9b75ef 100644 (file)
@@ -195,91 +195,106 @@ func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, err
        return "", nil, lastErr
 }
 
-func convertRR_A(records []dnsRR) []IP {
-       addrs := make([]IP, len(records))
-       for i, rr := range records {
-               a := rr.(*dnsRR_A).A
-               addrs[i] = IPv4(byte(a>>24), byte(a>>16), byte(a>>8), byte(a))
+// addrRecordList converts and returns a list of IP addresses from DNS
+// address records (both A and AAAA). Other record types are ignored.
+func addrRecordList(rrs []dnsRR) []IPAddr {
+       addrs := make([]IPAddr, 0, 4)
+       for _, rr := range rrs {
+               switch rr := rr.(type) {
+               case *dnsRR_A:
+                       addrs = append(addrs, IPAddr{IP: IPv4(byte(rr.A>>24), byte(rr.A>>16), byte(rr.A>>8), byte(rr.A))})
+               case *dnsRR_AAAA:
+                       ip := make(IP, IPv6len)
+                       copy(ip, rr.AAAA[:])
+                       addrs = append(addrs, IPAddr{IP: ip})
+               }
        }
        return addrs
 }
 
-func convertRR_AAAA(records []dnsRR) []IP {
-       addrs := make([]IP, len(records))
-       for i, rr := range records {
-               a := make(IP, IPv6len)
-               copy(a, rr.(*dnsRR_AAAA).AAAA[:])
-               addrs[i] = a
-       }
-       return addrs
-}
+// A resolverConfig represents a DNS stub resolver configuration.
+type resolverConfig struct {
+       initOnce sync.Once // guards init of resolverConfig
 
-// cfg is used for the storage and reparsing of /etc/resolv.conf
-var cfg struct {
-       // ch is used as a semaphore that only allows one lookup at a time to
-       // recheck resolv.conf.  It acts as guard for lastChecked and modTime.
-       ch          chan struct{}
-       lastChecked time.Time // last time resolv.conf was checked
-       modTime     time.Time // time of resolv.conf modification
+       // ch is used as a semaphore that only allows one lookup at a
+       // time to recheck resolv.conf.
+       ch          chan struct{} // guards lastChecked and modTime
+       lastChecked time.Time     // last time resolv.conf was checked
+       modTime     time.Time     // time of resolv.conf modification
 
        mu        sync.RWMutex // protects dnsConfig
        dnsConfig *dnsConfig   // parsed resolv.conf structure used in lookups
 }
 
-var onceLoadConfig sync.Once
+var resolvConf resolverConfig
 
-func initCfg() {
+// init initializes conf and is only called via conf.initOnce.
+func (conf *resolverConfig) init() {
        // Set dnsConfig, modTime, and lastChecked so we don't parse
        // resolv.conf twice the first time.
-       cfg.dnsConfig = systemConf().resolv
-       if cfg.dnsConfig == nil {
-               cfg.dnsConfig = dnsReadConfig("/etc/resolv.conf")
+       conf.dnsConfig = systemConf().resolv
+       if conf.dnsConfig == nil {
+               conf.dnsConfig = dnsReadConfig("/etc/resolv.conf")
        }
 
        if fi, err := os.Stat("/etc/resolv.conf"); err == nil {
-               cfg.modTime = fi.ModTime()
+               conf.modTime = fi.ModTime()
        }
-       cfg.lastChecked = time.Now()
+       conf.lastChecked = time.Now()
 
-       // Prepare ch so that only one loadConfig may run at once
-       cfg.ch = make(chan struct{}, 1)
-       cfg.ch <- struct{}{}
+       // Prepare ch so that only one update of resolverConfig may
+       // run at once.
+       conf.ch = make(chan struct{}, 1)
 }
 
-func loadConfig(resolvConfPath string) {
-       onceLoadConfig.Do(initCfg)
+// tryUpdate tries to update conf with the named resolv.conf file.
+// The name variable only exists for testing. It is otherwise always
+// "/etc/resolv.conf".
+func (conf *resolverConfig) tryUpdate(name string) {
+       conf.initOnce.Do(conf.init)
 
-       // ensure only one loadConfig at a time checks /etc/resolv.conf
-       select {
-       case <-cfg.ch:
-               defer func() { cfg.ch <- struct{}{} }()
-       default:
+       // Ensure only one update at a time checks resolv.conf.
+       if !conf.tryAcquireSema() {
                return
        }
+       defer conf.releaseSema()
 
        now := time.Now()
-       if cfg.lastChecked.After(now.Add(-5 * time.Second)) {
+       if conf.lastChecked.After(now.Add(-5 * time.Second)) {
                return
        }
-       cfg.lastChecked = now
+       conf.lastChecked = now
 
-       if fi, err := os.Stat(resolvConfPath); err == nil {
-               if fi.ModTime().Equal(cfg.modTime) {
+       if fi, err := os.Stat(name); err == nil {
+               if fi.ModTime().Equal(conf.modTime) {
                        return
                }
-               cfg.modTime = fi.ModTime()
+               conf.modTime = fi.ModTime()
        } else {
                // If modTime wasn't set prior, assume nothing has changed.
-               if cfg.modTime.IsZero() {
+               if conf.modTime.IsZero() {
                        return
                }
-               cfg.modTime = time.Time{}
+               conf.modTime = time.Time{}
+       }
+
+       dnsConf := dnsReadConfig(name)
+       conf.mu.Lock()
+       conf.dnsConfig = dnsConf
+       conf.mu.Unlock()
+}
+
+func (conf *resolverConfig) tryAcquireSema() bool {
+       select {
+       case conf.ch <- struct{}{}:
+               return true
+       default:
+               return false
        }
+}
 
-       ncfg := dnsReadConfig(resolvConfPath)
-       cfg.mu.Lock()
-       cfg.dnsConfig = ncfg
-       cfg.mu.Unlock()
+func (conf *resolverConfig) releaseSema() {
+       <-conf.ch
 }
 
 func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
@@ -287,10 +302,10 @@ func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
                return name, nil, &DNSError{Err: "invalid domain name", Name: name}
        }
 
-       loadConfig("/etc/resolv.conf")
-       cfg.mu.RLock()
-       resolv := cfg.dnsConfig
-       cfg.mu.RUnlock()
+       resolvConf.tryUpdate("/etc/resolv.conf")
+       resolvConf.mu.RLock()
+       resolv := resolvConf.dnsConfig
+       resolvConf.mu.RUnlock()
 
        // If name is rooted (trailing dot) or has enough dots,
        // try it by itself first.
@@ -441,18 +456,7 @@ func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err er
                        lastErr = racer.error
                        continue
                }
-               switch racer.qtype {
-               case dnsTypeA:
-                       for _, ip := range convertRR_A(racer.rrs) {
-                               addr := IPAddr{IP: ip}
-                               addrs = append(addrs, addr)
-                       }
-               case dnsTypeAAAA:
-                       for _, ip := range convertRR_AAAA(racer.rrs) {
-                               addr := IPAddr{IP: ip}
-                               addrs = append(addrs, addr)
-                       }
-               }
+               addrs = append(addrs, addrRecordList(racer.rrs)...)
        }
        if len(addrs) == 0 {
                if lastErr != nil {
@@ -472,11 +476,11 @@ func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err er
 // depending on our lookup code, so that Go and C get the same
 // answers.
 func goLookupCNAME(name string) (cname string, err error) {
-       _, rr, err := lookup(name, dnsTypeCNAME)
+       _, rrs, err := lookup(name, dnsTypeCNAME)
        if err != nil {
                return
        }
-       cname = rr[0].(*dnsRR_CNAME).Cname
+       cname = rrs[0].(*dnsRR_CNAME).Cname
        return
 }
 
index 06c9ad31344ffe743fe18372fbb782aa881e7561..c6bfc67abc24d573b16b7a4c143304e399cf0549 100644 (file)
@@ -7,11 +7,13 @@
 package net
 
 import (
-       "io"
+       "fmt"
        "io/ioutil"
        "os"
        "path"
        "reflect"
+       "strings"
+       "sync"
        "testing"
        "time"
 )
@@ -93,117 +95,133 @@ func TestSpecialDomainName(t *testing.T) {
 }
 
 type resolvConfTest struct {
-       *testing.T
        dir  string
        path string
+       *resolverConfig
 }
 
-func newResolvConfTest(t *testing.T) *resolvConfTest {
+func newResolvConfTest() (*resolvConfTest, error) {
        dir, err := ioutil.TempDir("", "go-resolvconftest")
        if err != nil {
-               t.Fatal(err)
+               return nil, err
        }
-
-       r := &resolvConfTest{
-               T:    t,
-               dir:  dir,
-               path: path.Join(dir, "resolv.conf"),
+       conf := &resolvConfTest{
+               dir:            dir,
+               path:           path.Join(dir, "resolv.conf"),
+               resolverConfig: &resolvConf,
        }
-
-       return r
+       conf.initOnce.Do(conf.init)
+       return conf, nil
 }
 
-func (r *resolvConfTest) SetConf(s string) {
-       // Make sure the file mtime will be different once we're done here,
-       // even on systems with coarse (1s) mtime resolution.
-       time.Sleep(time.Second)
-
-       f, err := os.OpenFile(r.path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
+func (conf *resolvConfTest) writeAndUpdate(lines []string) error {
+       f, err := os.OpenFile(conf.path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
        if err != nil {
-               r.Fatalf("failed to create temp file %s: %v", r.path, err)
+               return err
        }
-       if _, err := io.WriteString(f, s); err != nil {
+       if _, err := f.WriteString(strings.Join(lines, "\n")); err != nil {
                f.Close()
-               r.Fatalf("failed to write temp file: %v", err)
+               return err
        }
        f.Close()
-       cfg.lastChecked = time.Time{}
-       loadConfig(r.path)
-}
-
-func (r *resolvConfTest) WantServers(want []string) {
-       cfg.mu.RLock()
-       defer cfg.mu.RUnlock()
-       if got := cfg.dnsConfig.servers; !reflect.DeepEqual(got, want) {
-               r.Fatalf("unexpected dns server loaded, got %v want %v", got, want)
+       if err := conf.forceUpdate(conf.path); err != nil {
+               return err
        }
+       return nil
 }
 
-func (r *resolvConfTest) Close() {
-       if err := os.RemoveAll(r.dir); err != nil {
-               r.Logf("failed to remove temp dir %s: %v", r.dir, err)
+func (conf *resolvConfTest) forceUpdate(name string) error {
+       dnsConf := dnsReadConfig(name)
+       conf.mu.Lock()
+       conf.dnsConfig = dnsConf
+       conf.mu.Unlock()
+       for i := 0; i < 5; i++ {
+               if conf.tryAcquireSema() {
+                       conf.lastChecked = time.Time{}
+                       conf.releaseSema()
+                       return nil
+               }
        }
+       return fmt.Errorf("tryAcquireSema for %s failed", name)
 }
 
-func TestReloadResolvConfFail(t *testing.T) {
-       if testing.Short() || !*testExternal {
-               t.Skip("avoid external network")
-       }
-
-       r := newResolvConfTest(t)
-       defer r.Close()
-
-       r.SetConf("nameserver 8.8.8.8")
-
-       if _, err := goLookupIP("golang.org"); err != nil {
-               t.Fatal(err)
-       }
-
-       // Using an empty resolv.conf should use localhost as servers
-       r.SetConf("")
+func (conf *resolvConfTest) servers() []string {
+       conf.mu.RLock()
+       servers := conf.dnsConfig.servers
+       conf.mu.RUnlock()
+       return servers
+}
 
-       if len(cfg.dnsConfig.servers) != len(defaultNS) {
-               t.Fatalf("goLookupIP(missing; good; bad) failed: servers=%v, want: %v", cfg.dnsConfig.servers, defaultNS)
-       }
+func (conf *resolvConfTest) teardown() error {
+       err := conf.forceUpdate("/etc/resolv.conf")
+       os.RemoveAll(conf.dir)
+       return err
+}
 
-       for i := range cfg.dnsConfig.servers {
-               if cfg.dnsConfig.servers[i] != defaultNS[i] {
-                       t.Fatalf("goLookupIP(missing; good; bad) failed: servers=%v, want: %v", cfg.dnsConfig.servers, defaultNS)
-               }
-       }
+var updateResolvConfTests = []struct {
+       name    string   // query name
+       lines   []string // resolver configuration lines
+       servers []string // expected name servers
+}{
+       {
+               name:    "golang.org",
+               lines:   []string{"nameserver 8.8.8.8"},
+               servers: []string{"8.8.8.8"},
+       },
+       {
+               name:    "",
+               lines:   nil, // an empty resolv.conf should use defaultNS as name servers
+               servers: defaultNS,
+       },
+       {
+               name:    "www.example.com",
+               lines:   []string{"nameserver 8.8.4.4"},
+               servers: []string{"8.8.4.4"},
+       },
 }
 
-func TestReloadResolvConfChange(t *testing.T) {
+func TestUpdateResolvConf(t *testing.T) {
        if testing.Short() || !*testExternal {
                t.Skip("avoid external network")
        }
 
-       r := newResolvConfTest(t)
-       defer r.Close()
-
-       r.SetConf("nameserver 8.8.8.8")
-
-       if _, err := goLookupIP("golang.org"); err != nil {
+       conf, err := newResolvConfTest()
+       if err != nil {
                t.Fatal(err)
        }
-       r.WantServers([]string{"8.8.8.8"})
+       defer conf.teardown()
 
-       // Using an empty resolv.conf should use localhost as servers
-       r.SetConf("")
-
-       if len(cfg.dnsConfig.servers) != len(defaultNS) {
-               t.Fatalf("goLookupIP(missing; good; bad) failed: servers=%v, want: %v", cfg.dnsConfig.servers, defaultNS)
-       }
-
-       for i := range cfg.dnsConfig.servers {
-               if cfg.dnsConfig.servers[i] != defaultNS[i] {
-                       t.Fatalf("goLookupIP(missing; good; bad) failed: servers=%v, want: %v", cfg.dnsConfig.servers, defaultNS)
+       for i, tt := range updateResolvConfTests {
+               if err := conf.writeAndUpdate(tt.lines); err != nil {
+                       t.Error(err)
+                       continue
+               }
+               if tt.name != "" {
+                       var wg sync.WaitGroup
+                       const N = 10
+                       wg.Add(N)
+                       for j := 0; j < N; j++ {
+                               go func(name string) {
+                                       defer wg.Done()
+                                       ips, err := goLookupIP(name)
+                                       if err != nil {
+                                               t.Error(err)
+                                               return
+                                       }
+                                       if len(ips) == 0 {
+                                               t.Errorf("no records for %s", name)
+                                               return
+                                       }
+                               }(tt.name)
+                       }
+                       wg.Wait()
+               }
+               servers := conf.servers()
+               if !reflect.DeepEqual(servers, tt.servers) {
+                       t.Errorf("#%d: got %v; want %v", i, servers, tt.servers)
+                       continue
                }
        }
-
-       // A new good config should get picked up
-       r.SetConf("nameserver 8.8.4.4")
-       r.WantServers([]string{"8.8.4.4"})
 }
 
 func BenchmarkGoLookupIP(b *testing.B) {
@@ -225,14 +243,21 @@ func BenchmarkGoLookupIPNoSuchHost(b *testing.B) {
 func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
        testHookUninstaller.Do(uninstallTestHooks)
 
-       // This looks ugly but it's safe as long as benchmarks are run
-       // sequentially in package testing.
-       <-cfg.ch // keep config from being reloaded upon lookup
-       orig := cfg.dnsConfig
-       cfg.dnsConfig.servers = append([]string{"203.0.113.254"}, cfg.dnsConfig.servers...) // use TEST-NET-3 block, see RFC 5737
+       conf, err := newResolvConfTest()
+       if err != nil {
+               b.Fatal(err)
+       }
+       defer conf.teardown()
+
+       lines := []string{
+               "nameserver 203.0.113.254", // use TEST-NET-3 block, see RFC 5737
+               "nameserver 8.8.8.8",
+       }
+       if err := conf.writeAndUpdate(lines); err != nil {
+               b.Fatal(err)
+       }
+
        for i := 0; i < b.N; i++ {
                goLookupIP("www.example.com")
        }
-       cfg.dnsConfig = orig
-       cfg.ch <- struct{}{}
 }