]> Cypherpunks repositories - gostls13.git/commitdiff
net: detect changes to /etc/resolv.conf.
authorGuillaume J. Charmes <guillaume@charmes.net>
Thu, 15 May 2014 00:11:00 +0000 (17:11 -0700)
committerIan Lance Taylor <iant@golang.org>
Thu, 15 May 2014 00:11:00 +0000 (17:11 -0700)
Implement the changes as suggested by rsc.
Fixes #6670.

LGTM=josharian, iant
R=golang-codereviews, iant, josharian, mikioh.mikioh, alex, gobot
CC=golang-codereviews, rsc
https://golang.org/cl/83690045

src/pkg/net/dnsclient_unix.go
src/pkg/net/dnsclient_unix_test.go

index 2211e2190c7af304657fc8b9af880c54e5ce585a..3713efd0e3c9f27b3d95596d2ed3f00c0b82649c 100644 (file)
@@ -8,7 +8,6 @@
 // Has to be linked into package net for Dial.
 
 // TODO(rsc):
-//     Check periodically whether /etc/resolv.conf has changed.
 //     Could potentially handle many outstanding lookups faster.
 //     Could have a small cache.
 //     Random UDP source port (net.Dial should do that for us).
@@ -19,6 +18,7 @@ package net
 import (
        "io"
        "math/rand"
+       "os"
        "sync"
        "time"
 )
@@ -156,33 +156,90 @@ func convertRR_AAAA(records []dnsRR) []IP {
        return addrs
 }
 
-var cfg *dnsConfig
-var dnserr error
+var cfg struct {
+       ch        chan struct{}
+       mu        sync.RWMutex // protects dnsConfig and dnserr
+       dnsConfig *dnsConfig
+       dnserr    error
+}
+var onceLoadConfig sync.Once
 
 // Assume dns config file is /etc/resolv.conf here
-func loadConfig() { cfg, dnserr = dnsReadConfig("/etc/resolv.conf") }
+func loadDefaultConfig() {
+       loadConfig("/etc/resolv.conf", 5*time.Second, nil)
+}
 
-var onceLoadConfig sync.Once
+func loadConfig(resolvConfPath string, reloadTime time.Duration, quit <-chan chan struct{}) {
+       var mtime time.Time
+       cfg.ch = make(chan struct{}, 1)
+       if fi, err := os.Stat(resolvConfPath); err != nil {
+               cfg.dnserr = err
+       } else {
+               mtime = fi.ModTime()
+               cfg.dnsConfig, cfg.dnserr = dnsReadConfig(resolvConfPath)
+       }
+       go func() {
+               for {
+                       time.Sleep(reloadTime)
+                       select {
+                       case qresp := <-quit:
+                               qresp <- struct{}{}
+                               return
+                       case <-cfg.ch:
+                       }
+
+                       // In case of error, we keep the previous config
+                       fi, err := os.Stat(resolvConfPath)
+                       if err != nil {
+                               continue
+                       }
+                       // If the resolv.conf mtime didn't change, do not reload
+                       m := fi.ModTime()
+                       if m.Equal(mtime) {
+                               continue
+                       }
+                       mtime = m
+                       // In case of error, we keep the previous config
+                       ncfg, err := dnsReadConfig(resolvConfPath)
+                       if err != nil || len(ncfg.servers) == 0 {
+                               continue
+                       }
+                       cfg.mu.Lock()
+                       cfg.dnsConfig = ncfg
+                       cfg.dnserr = nil
+                       cfg.mu.Unlock()
+               }
+       }()
+}
 
 func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err error) {
        if !isDomainName(name) {
                return name, nil, &DNSError{Err: "invalid domain name", Name: name}
        }
-       onceLoadConfig.Do(loadConfig)
-       if dnserr != nil || cfg == nil {
-               err = dnserr
+       onceLoadConfig.Do(loadDefaultConfig)
+
+       select {
+       case cfg.ch <- struct{}{}:
+       default:
+       }
+
+       cfg.mu.RLock()
+       defer cfg.mu.RUnlock()
+
+       if cfg.dnserr != nil || cfg.dnsConfig == nil {
+               err = cfg.dnserr
                return
        }
        // If name is rooted (trailing dot) or has enough dots,
        // try it by itself first.
        rooted := len(name) > 0 && name[len(name)-1] == '.'
-       if rooted || count(name, '.') >= cfg.ndots {
+       if rooted || count(name, '.') >= cfg.dnsConfig.ndots {
                rname := name
                if !rooted {
                        rname += "."
                }
                // Can try as ordinary name.
-               cname, addrs, err = tryOneName(cfg, rname, qtype)
+               cname, addrs, err = tryOneName(cfg.dnsConfig, rname, qtype)
                if err == nil {
                        return
                }
@@ -192,12 +249,12 @@ func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err error)
        }
 
        // Otherwise, try suffixes.
-       for i := 0; i < len(cfg.search); i++ {
-               rname := name + "." + cfg.search[i]
+       for i := 0; i < len(cfg.dnsConfig.search); i++ {
+               rname := name + "." + cfg.dnsConfig.search[i]
                if rname[len(rname)-1] != '.' {
                        rname += "."
                }
-               cname, addrs, err = tryOneName(cfg, rname, qtype)
+               cname, addrs, err = tryOneName(cfg.dnsConfig, rname, qtype)
                if err == nil {
                        return
                }
@@ -208,7 +265,7 @@ func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err error)
        if !rooted {
                rname += "."
        }
-       cname, addrs, err = tryOneName(cfg, rname, qtype)
+       cname, addrs, err = tryOneName(cfg.dnsConfig, rname, qtype)
        if err == nil {
                return
        }
@@ -233,11 +290,6 @@ func goLookupHost(name string) (addrs []string, err error) {
        if len(addrs) > 0 {
                return
        }
-       onceLoadConfig.Do(loadConfig)
-       if dnserr != nil || cfg == nil {
-               err = dnserr
-               return
-       }
        ips, err := goLookupIP(name)
        if err != nil {
                return
@@ -268,11 +320,6 @@ func goLookupIP(name string) (addrs []IP, err error) {
                        return
                }
        }
-       onceLoadConfig.Do(loadConfig)
-       if dnserr != nil || cfg == nil {
-               err = dnserr
-               return
-       }
        var records []dnsRR
        var cname string
        var err4, err6 error
@@ -308,11 +355,6 @@ func goLookupIP(name string) (addrs []IP, err error) {
 // depending on our lookup code, so that Go and C get the same
 // answers.
 func goLookupCNAME(name string) (cname string, err error) {
-       onceLoadConfig.Do(loadConfig)
-       if dnserr != nil || cfg == nil {
-               err = dnserr
-               return
-       }
        _, rr, err := lookup(name, dnsTypeCNAME)
        if err != nil {
                return
index a2fdda35656c2ce91df3808d7f79d34833692113..2350142d61052e06f40644baa3200016df3bb27a 100644 (file)
@@ -7,7 +7,13 @@
 package net
 
 import (
+       "io"
+       "io/ioutil"
+       "os"
+       "path"
+       "reflect"
        "testing"
+       "time"
 )
 
 func TestTCPLookup(t *testing.T) {
@@ -25,3 +31,129 @@ func TestTCPLookup(t *testing.T) {
                t.Fatalf("exchange failed: %v", err)
        }
 }
+
+type resolvConfTest struct {
+       *testing.T
+       dir     string
+       path    string
+       started bool
+       quitc   chan chan struct{}
+}
+
+func newResolvConfTest(t *testing.T) *resolvConfTest {
+       dir, err := ioutil.TempDir("", "resolvConfTest")
+       if err != nil {
+               t.Fatalf("could not create temp dir: %v", err)
+       }
+
+       // Disable the default loadConfig
+       onceLoadConfig.Do(func() {})
+
+       r := &resolvConfTest{
+               T:     t,
+               dir:   dir,
+               path:  path.Join(dir, "resolv.conf"),
+               quitc: make(chan chan struct{}),
+       }
+
+       return r
+}
+
+func (r *resolvConfTest) Start() {
+       loadConfig(r.path, 100*time.Millisecond, r.quitc)
+       r.started = true
+}
+
+func (r *resolvConfTest) SetConf(s string) {
+       // Make sure the file mtime will be different once we're done here,
+       // even on systems with coarse (1s) mtime resolution.
+       time.Sleep(time.Second)
+
+       f, err := os.OpenFile(r.path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
+       if err != nil {
+               r.Fatalf("failed to create temp file %s: %v", r.path, err)
+       }
+       if _, err := io.WriteString(f, s); err != nil {
+               f.Close()
+               r.Fatalf("failed to write temp file: %v", err)
+       }
+       f.Close()
+
+       if r.started {
+               cfg.ch <- struct{}{} // fill buffer
+               cfg.ch <- struct{}{} // wait for reload to begin
+               cfg.ch <- struct{}{} // wait for reload to complete
+       }
+}
+
+func (r *resolvConfTest) WantServers(want []string) {
+       cfg.mu.RLock()
+       defer cfg.mu.RUnlock()
+       if got := cfg.dnsConfig.servers; !reflect.DeepEqual(got, want) {
+               r.Fatalf("Unexpected dns server loaded, got %v want %v", got, want)
+       }
+}
+
+func (r *resolvConfTest) Close() {
+       resp := make(chan struct{})
+       r.quitc <- resp
+       <-resp
+       if err := os.RemoveAll(r.dir); err != nil {
+               r.Logf("failed to remove temp dir %s: %v", r.dir, err)
+       }
+}
+
+func TestReloadResolvConfFail(t *testing.T) {
+       if testing.Short() || !*testExternal {
+               t.Skip("skipping test to avoid external network")
+       }
+
+       r := newResolvConfTest(t)
+       defer r.Close()
+
+       // resolv.conf.tmp does not exist yet
+       r.Start()
+       if _, err := goLookupIP("golang.org"); err == nil {
+               t.Fatal("goLookupIP(missing) succeeded")
+       }
+
+       r.SetConf("nameserver 8.8.8.8")
+       if _, err := goLookupIP("golang.org"); err != nil {
+               t.Fatalf("goLookupIP(missing; good) failed: %v", err)
+       }
+
+       // Using a bad resolv.conf while we had a good
+       // one before should not update the config
+       r.SetConf("")
+       if _, err := goLookupIP("golang.org"); err != nil {
+               t.Fatalf("goLookupIP(missing; good; bad) failed: %v", err)
+       }
+}
+
+func TestReloadResolvConfChange(t *testing.T) {
+       if testing.Short() || !*testExternal {
+               t.Skip("skipping test to avoid external network")
+       }
+
+       r := newResolvConfTest(t)
+       defer r.Close()
+
+       r.SetConf("nameserver 8.8.8.8")
+       r.Start()
+
+       if _, err := goLookupIP("golang.org"); err != nil {
+               t.Fatalf("goLookupIP(good) failed: %v", err)
+       }
+       r.WantServers([]string{"[8.8.8.8]"})
+
+       // Using a bad resolv.conf when we had a good one
+       // before should not update the config
+       r.SetConf("")
+       if _, err := goLookupIP("golang.org"); err != nil {
+               t.Fatalf("goLookupIP(good; bad) failed: %v", err)
+       }
+
+       // A new good config should get picked up
+       r.SetConf("nameserver 8.8.4.4")
+       r.WantServers([]string{"[8.8.4.4]"})
+}