]> Cypherpunks repositories - gostls13.git/commitdiff
net: search domain from hostname if no search directives
authorDan Peterson <dpiddy@gmail.com>
Sat, 2 Apr 2016 21:07:24 +0000 (18:07 -0300)
committerMatthew Dempsky <mdempsky@google.com>
Wed, 27 Apr 2016 21:14:32 +0000 (21:14 +0000)
Fixes #14897

Change-Id: Iffe7462983a5623a37aa0dc6f74c8c70e10c3244
Reviewed-on: https://go-review.googlesource.com/21464
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
src/net/dnsconfig_unix.go
src/net/dnsconfig_unix_test.go

index 181d47b36d0e3b1ad7752096062e3212ab7fde47..9893cb7e634ad022e60f94ff1177cbf7b04668e0 100644 (file)
@@ -8,9 +8,15 @@
 
 package net
 
-import "time"
+import (
+       "os"
+       "time"
+)
 
-var defaultNS = []string{"127.0.0.1", "::1"}
+var (
+       defaultNS   = []string{"127.0.0.1", "::1"}
+       getHostname = os.Hostname // variable for testing
+)
 
 type dnsConfig struct {
        servers    []string  // servers to use
@@ -26,8 +32,6 @@ type dnsConfig struct {
 }
 
 // See resolv.conf(5) on a Linux machine.
-// TODO(rsc): Supposed to call uname() and chop the beginning
-// of the host name to get the default search domain.
 func dnsReadConfig(filename string) *dnsConfig {
        conf := &dnsConfig{
                ndots:    1,
@@ -37,6 +41,7 @@ func dnsReadConfig(filename string) *dnsConfig {
        file, err := open(filename)
        if err != nil {
                conf.servers = defaultNS
+               conf.search = dnsDefaultSearch()
                conf.err = err
                return conf
        }
@@ -45,6 +50,7 @@ func dnsReadConfig(filename string) *dnsConfig {
                conf.mtime = fi.ModTime()
        } else {
                conf.servers = defaultNS
+               conf.search = dnsDefaultSearch()
                conf.err = err
                return conf
        }
@@ -122,9 +128,24 @@ func dnsReadConfig(filename string) *dnsConfig {
        if len(conf.servers) == 0 {
                conf.servers = defaultNS
        }
+       if len(conf.search) == 0 {
+               conf.search = dnsDefaultSearch()
+       }
        return conf
 }
 
+func dnsDefaultSearch() []string {
+       hn, err := getHostname()
+       if err != nil {
+               // best effort
+               return nil
+       }
+       if i := byteIndex(hn, '.'); i >= 0 && i < len(hn)-1 {
+               return []string{hn[i+1:]}
+       }
+       return nil
+}
+
 func hasPrefix(s, prefix string) bool {
        return len(s) >= len(prefix) && s[:len(prefix)] == prefix
 }
index 849c0da93bafb84d89fd3db3b060907a88242e01..f9ef79cba86a5550bc842e2a4b871748cbec6cba 100644 (file)
@@ -7,6 +7,7 @@
 package net
 
 import (
+       "errors"
        "os"
        "reflect"
        "testing"
@@ -56,6 +57,7 @@ var dnsReadConfigTests = []struct {
                        ndots:    1,
                        timeout:  5,
                        attempts: 2,
+                       search:   []string{"domain.local"},
                },
        },
        {
@@ -72,6 +74,10 @@ var dnsReadConfigTests = []struct {
 }
 
 func TestDNSReadConfig(t *testing.T) {
+       origGetHostname := getHostname
+       defer func() { getHostname = origGetHostname }()
+       getHostname = func() (string, error) { return "host.domain.local", nil }
+
        for _, tt := range dnsReadConfigTests {
                conf := dnsReadConfig(tt.name)
                if conf.err != nil {
@@ -85,6 +91,10 @@ func TestDNSReadConfig(t *testing.T) {
 }
 
 func TestDNSReadMissingFile(t *testing.T) {
+       origGetHostname := getHostname
+       defer func() { getHostname = origGetHostname }()
+       getHostname = func() (string, error) { return "host.domain.local", nil }
+
        conf := dnsReadConfig("a-nonexistent-file")
        if !os.IsNotExist(conf.err) {
                t.Errorf("missing resolv.conf:\ngot: %v\nwant: %v", conf.err, os.ErrNotExist)
@@ -95,8 +105,52 @@ func TestDNSReadMissingFile(t *testing.T) {
                ndots:    1,
                timeout:  5,
                attempts: 2,
+               search:   []string{"domain.local"},
        }
        if !reflect.DeepEqual(conf, want) {
                t.Errorf("missing resolv.conf:\ngot: %+v\nwant: %+v", conf, want)
        }
 }
+
+var dnsDefaultSearchTests = []struct {
+       name string
+       err  error
+       want []string
+}{
+       {
+               name: "host.long.domain.local",
+               want: []string{"long.domain.local"},
+       },
+       {
+               name: "host.local",
+               want: []string{"local"},
+       },
+       {
+               name: "host",
+               want: nil,
+       },
+       {
+               name: "host.domain.local",
+               err:  errors.New("errored"),
+               want: nil,
+       },
+       {
+               // ensures we don't return []string{""}
+               // which causes duplicate lookups
+               name: "foo.",
+               want: nil,
+       },
+}
+
+func TestDNSDefaultSearch(t *testing.T) {
+       origGetHostname := getHostname
+       defer func() { getHostname = origGetHostname }()
+
+       for _, tt := range dnsDefaultSearchTests {
+               getHostname = func() (string, error) { return tt.name, tt.err }
+               got := dnsDefaultSearch()
+               if !reflect.DeepEqual(got, tt.want) {
+                       t.Errorf("dnsDefaultSearch with hostname %q and error %+v = %q, wanted %q", tt.name, tt.err, got, tt.want)
+               }
+       }
+}