From cbd72318b964bde9d95102571cd22d1919dbd37f Mon Sep 17 00:00:00 2001 From: Dan Peterson Date: Sat, 2 Apr 2016 18:07:24 -0300 Subject: [PATCH] net: search domain from hostname if no search directives Fixes #14897 Change-Id: Iffe7462983a5623a37aa0dc6f74c8c70e10c3244 Reviewed-on: https://go-review.googlesource.com/21464 Reviewed-by: Brad Fitzpatrick Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Matthew Dempsky --- src/net/dnsconfig_unix.go | 29 +++++++++++++++--- src/net/dnsconfig_unix_test.go | 54 ++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 4 deletions(-) diff --git a/src/net/dnsconfig_unix.go b/src/net/dnsconfig_unix.go index 181d47b36d..9893cb7e63 100644 --- a/src/net/dnsconfig_unix.go +++ b/src/net/dnsconfig_unix.go @@ -8,9 +8,15 @@ package net -import "time" +import ( + "os" + "time" +) -var defaultNS = []string{"127.0.0.1", "::1"} +var ( + defaultNS = []string{"127.0.0.1", "::1"} + getHostname = os.Hostname // variable for testing +) type dnsConfig struct { servers []string // servers to use @@ -26,8 +32,6 @@ type dnsConfig struct { } // See resolv.conf(5) on a Linux machine. -// TODO(rsc): Supposed to call uname() and chop the beginning -// of the host name to get the default search domain. func dnsReadConfig(filename string) *dnsConfig { conf := &dnsConfig{ ndots: 1, @@ -37,6 +41,7 @@ func dnsReadConfig(filename string) *dnsConfig { file, err := open(filename) if err != nil { conf.servers = defaultNS + conf.search = dnsDefaultSearch() conf.err = err return conf } @@ -45,6 +50,7 @@ func dnsReadConfig(filename string) *dnsConfig { conf.mtime = fi.ModTime() } else { conf.servers = defaultNS + conf.search = dnsDefaultSearch() conf.err = err return conf } @@ -122,9 +128,24 @@ func dnsReadConfig(filename string) *dnsConfig { if len(conf.servers) == 0 { conf.servers = defaultNS } + if len(conf.search) == 0 { + conf.search = dnsDefaultSearch() + } return conf } +func dnsDefaultSearch() []string { + hn, err := getHostname() + if err != nil { + // best effort + return nil + } + if i := byteIndex(hn, '.'); i >= 0 && i < len(hn)-1 { + return []string{hn[i+1:]} + } + return nil +} + func hasPrefix(s, prefix string) bool { return len(s) >= len(prefix) && s[:len(prefix)] == prefix } diff --git a/src/net/dnsconfig_unix_test.go b/src/net/dnsconfig_unix_test.go index 849c0da93b..f9ef79cba8 100644 --- a/src/net/dnsconfig_unix_test.go +++ b/src/net/dnsconfig_unix_test.go @@ -7,6 +7,7 @@ package net import ( + "errors" "os" "reflect" "testing" @@ -56,6 +57,7 @@ var dnsReadConfigTests = []struct { ndots: 1, timeout: 5, attempts: 2, + search: []string{"domain.local"}, }, }, { @@ -72,6 +74,10 @@ var dnsReadConfigTests = []struct { } func TestDNSReadConfig(t *testing.T) { + origGetHostname := getHostname + defer func() { getHostname = origGetHostname }() + getHostname = func() (string, error) { return "host.domain.local", nil } + for _, tt := range dnsReadConfigTests { conf := dnsReadConfig(tt.name) if conf.err != nil { @@ -85,6 +91,10 @@ func TestDNSReadConfig(t *testing.T) { } func TestDNSReadMissingFile(t *testing.T) { + origGetHostname := getHostname + defer func() { getHostname = origGetHostname }() + getHostname = func() (string, error) { return "host.domain.local", nil } + conf := dnsReadConfig("a-nonexistent-file") if !os.IsNotExist(conf.err) { t.Errorf("missing resolv.conf:\ngot: %v\nwant: %v", conf.err, os.ErrNotExist) @@ -95,8 +105,52 @@ func TestDNSReadMissingFile(t *testing.T) { ndots: 1, timeout: 5, attempts: 2, + search: []string{"domain.local"}, } if !reflect.DeepEqual(conf, want) { t.Errorf("missing resolv.conf:\ngot: %+v\nwant: %+v", conf, want) } } + +var dnsDefaultSearchTests = []struct { + name string + err error + want []string +}{ + { + name: "host.long.domain.local", + want: []string{"long.domain.local"}, + }, + { + name: "host.local", + want: []string{"local"}, + }, + { + name: "host", + want: nil, + }, + { + name: "host.domain.local", + err: errors.New("errored"), + want: nil, + }, + { + // ensures we don't return []string{""} + // which causes duplicate lookups + name: "foo.", + want: nil, + }, +} + +func TestDNSDefaultSearch(t *testing.T) { + origGetHostname := getHostname + defer func() { getHostname = origGetHostname }() + + for _, tt := range dnsDefaultSearchTests { + getHostname = func() (string, error) { return tt.name, tt.err } + got := dnsDefaultSearch() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("dnsDefaultSearch with hostname %q and error %+v = %q, wanted %q", tt.name, tt.err, got, tt.want) + } + } +} -- 2.48.1