]> Cypherpunks repositories - gostls13.git/commitdiff
net: add support for /etc/hosts aliases using go resolver
authorMateusz Poliwczak <mpoliwczak34@gmail.com>
Thu, 10 Nov 2022 18:07:57 +0000 (18:07 +0000)
committerGopher Robot <gobot@golang.org>
Thu, 10 Nov 2022 18:29:14 +0000 (18:29 +0000)
It adds support for /etc/hosts aliases and fixes the difference between the glibc cgo and the go DNS resolver.
Examples: https://pastebin.com/Fv6UcAVr

Fixes #44741

Change-Id: I98c484fced900731fbad800278b296028a45f044
GitHub-Last-Rev: 3d47e44f11c350df906d0c986e41891dd6e8d929
GitHub-Pull-Request: golang/go#51004
Reviewed-on: https://go-review.googlesource.com/c/go/+/382996
Run-TryBot: Ian Lance Taylor <iant@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@google.com>
src/net/dnsclient_unix.go
src/net/dnsclient_unix_test.go
src/net/hosts.go
src/net/hosts_test.go
src/net/testdata/aliases [new file with mode: 0644]

index aed6a337dec2c1728464431522532bee3d64296e..7cb30c0402ee2221cdd53f4d81264f99066e6166 100644 (file)
@@ -559,7 +559,7 @@ func (r *Resolver) goLookupHost(ctx context.Context, name string) (addrs []strin
 func (r *Resolver) goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []string, err error) {
        if order == hostLookupFilesDNS || order == hostLookupFiles {
                // Use entries from /etc/hosts if they match.
-               addrs = lookupStaticHost(name)
+               addrs, _ = lookupStaticHost(name)
                if len(addrs) > 0 || order == hostLookupFiles {
                        return
                }
@@ -576,8 +576,9 @@ func (r *Resolver) goLookupHostOrder(ctx context.Context, name string, order hos
 }
 
 // lookup entries from /etc/hosts
-func goLookupIPFiles(name string) (addrs []IPAddr) {
-       for _, haddr := range lookupStaticHost(name) {
+func goLookupIPFiles(name string) (addrs []IPAddr, canonical string) {
+       addr, canonical := lookupStaticHost(name)
+       for _, haddr := range addr {
                haddr, zone := splitHostZone(haddr)
                if ip := ParseIP(haddr); ip != nil {
                        addr := IPAddr{IP: ip, Zone: zone}
@@ -585,7 +586,7 @@ func goLookupIPFiles(name string) (addrs []IPAddr) {
                }
        }
        sortByRFC6724(addrs)
-       return
+       return addrs, canonical
 }
 
 // goLookupIP is the native Go implementation of LookupIP.
@@ -598,11 +599,23 @@ func (r *Resolver) goLookupIP(ctx context.Context, network, host string) (addrs
 
 func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name string, order hostLookupOrder) (addrs []IPAddr, cname dnsmessage.Name, err error) {
        if order == hostLookupFilesDNS || order == hostLookupFiles {
-               addrs = goLookupIPFiles(name)
-               if len(addrs) > 0 || order == hostLookupFiles {
-                       return addrs, dnsmessage.Name{}, nil
+               var canonical string
+               addrs, canonical = goLookupIPFiles(name)
+
+               if len(addrs) > 0 {
+                       var err error
+                       cname, err = dnsmessage.NewName(canonical)
+                       if err != nil {
+                               return nil, dnsmessage.Name{}, err
+                       }
+                       return addrs, cname, nil
+               }
+
+               if order == hostLookupFiles {
+                       return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
                }
        }
+
        if !isDomainName(name) {
                // See comment in func lookup above about use of errNoSuchHost.
                return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
@@ -776,9 +789,18 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin
        sortByRFC6724(addrs)
        if len(addrs) == 0 && !(network == "CNAME" && cname.Length > 0) {
                if order == hostLookupDNSFiles {
-                       addrs = goLookupIPFiles(name)
+                       var canonical string
+                       addrs, canonical = goLookupIPFiles(name)
+                       if len(addrs) > 0 {
+                               var err error
+                               cname, err = dnsmessage.NewName(canonical)
+                               if err != nil {
+                                       return nil, dnsmessage.Name{}, err
+                               }
+                               return addrs, cname, nil
+                       }
                }
-               if len(addrs) == 0 && lastErr != nil {
+               if lastErr != nil {
                        return nil, dnsmessage.Name{}, lastErr
                }
        }
index 63d3c51163472eae29d600cd10d89e7bae45152d..a9a55671c24550c7da4652ce642c488b771b0b37 100644 (file)
@@ -2170,6 +2170,56 @@ func TestRootNS(t *testing.T) {
        }
 }
 
+func TestGoLookupIPCNAMEOrderHostsAliasesFilesOnlyMode(t *testing.T) {
+       defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+       testHookHostsPath = "testdata/aliases"
+       mode := hostLookupFiles
+
+       for _, v := range lookupStaticHostAliasesTest {
+               testGoLookupIPCNAMEOrderHostsAliases(t, mode, v.lookup, absDomainName(v.res))
+       }
+}
+
+func TestGoLookupIPCNAMEOrderHostsAliasesFilesDNSMode(t *testing.T) {
+       defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+       testHookHostsPath = "testdata/aliases"
+       mode := hostLookupFilesDNS
+
+       for _, v := range lookupStaticHostAliasesTest {
+               testGoLookupIPCNAMEOrderHostsAliases(t, mode, v.lookup, absDomainName(v.res))
+       }
+}
+
+var goLookupIPCNAMEOrderDNSFilesModeTests = []struct {
+       lookup, res string
+}{
+       // 127.0.1.1
+       {"invalid.invalid", "invalid.test"},
+}
+
+func TestGoLookupIPCNAMEOrderHostsAliasesDNSFilesMode(t *testing.T) {
+       defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+       testHookHostsPath = "testdata/aliases"
+       mode := hostLookupDNSFiles
+
+       for _, v := range goLookupIPCNAMEOrderDNSFilesModeTests {
+               testGoLookupIPCNAMEOrderHostsAliases(t, mode, v.lookup, absDomainName(v.res))
+       }
+}
+
+func testGoLookupIPCNAMEOrderHostsAliases(t *testing.T, mode hostLookupOrder, lookup, lookupRes string) {
+       ins := []string{lookup, absDomainName(lookup), strings.ToLower(lookup), strings.ToUpper(lookup)}
+       for _, in := range ins {
+               _, res, err := goResolver.goLookupIPCNAMEOrder(context.Background(), "ip", in, mode)
+               if err != nil {
+                       t.Errorf("expected err == nil, but got error: %v", err)
+               }
+               if res.String() != lookupRes {
+                       t.Errorf("goLookupIPCNAMEOrder(%v): got %v, want %v", in, res, lookupRes)
+               }
+       }
+}
+
 // Test that we advertise support for a larger DNS packet size.
 // This isn't a great test as it just tests the dnsmessage package
 // against itself.
index e604031920b1fe55a495679f64b740683161fdfd..2ba85365694a009500806706df4c9e3c8d487e2d 100644 (file)
@@ -28,6 +28,11 @@ func parseLiteralIP(addr string) string {
        return ip.String() + "%" + zone
 }
 
+type byName struct {
+       addrs         []string
+       canonicalName string
+}
+
 // hosts contains known host entries.
 var hosts struct {
        sync.Mutex
@@ -36,7 +41,7 @@ var hosts struct {
        // name. It would be part of DNS labels, a FQDN or an absolute
        // FQDN.
        // For now the key is converted to lower case for convenience.
-       byName map[string][]string
+       byName map[string]byName
 
        // Key for the list of host names must be a literal IP address
        // including IPv6 address with zone identifier.
@@ -62,8 +67,9 @@ func readHosts() {
                return
        }
 
-       hs := make(map[string][]string)
+       hs := make(map[string]byName)
        is := make(map[string][]string)
+
        var file *file
        if file, _ = open(hp); file == nil {
                return
@@ -81,13 +87,32 @@ func readHosts() {
                if addr == "" {
                        continue
                }
+
+               var canonical string
                for i := 1; i < len(f); i++ {
                        name := absDomainName(f[i])
                        h := []byte(f[i])
                        lowerASCIIBytes(h)
                        key := absDomainName(string(h))
-                       hs[key] = append(hs[key], addr)
+
+                       if i == 1 {
+                               canonical = key
+                       }
+
                        is[addr] = append(is[addr], name)
+
+                       if v,ok := hs[key]; ok {
+                               hs[key] = byName{
+                                       addrs:         append(v.addrs, addr),
+                                       canonicalName: v.canonicalName,
+                               }
+                               continue
+                       }
+
+                       hs[key] = byName{
+                               addrs:         []string{addr},
+                               canonicalName: canonical,
+                       }
                }
        }
        // Update the data cache.
@@ -100,8 +125,8 @@ func readHosts() {
        file.close()
 }
 
-// lookupStaticHost looks up the addresses for the given host from /etc/hosts.
-func lookupStaticHost(host string) []string {
+// lookupStaticHost looks up the addresses and the cannonical name for the given host from /etc/hosts.
+func lookupStaticHost(host string) ([]string, string) {
        hosts.Lock()
        defer hosts.Unlock()
        readHosts()
@@ -111,13 +136,13 @@ func lookupStaticHost(host string) []string {
                        lowerASCIIBytes(lowerHost)
                        host = string(lowerHost)
                }
-               if ips, ok := hosts.byName[absDomainName(host)]; ok {
-                       ipsCp := make([]string, len(ips))
-                       copy(ipsCp, ips)
-                       return ipsCp
+               if byName, ok := hosts.byName[absDomainName(host)]; ok {
+                       ipsCp := make([]string, len(byName.addrs))
+                       copy(ipsCp, byName.addrs)
+                       return ipsCp, byName.canonicalName
                }
        }
-       return nil
+       return nil, ""
 }
 
 // lookupStaticAddr looks up the hosts for the given address from /etc/hosts.
index 72919140e9c79f032b493e2642b863c3b026755b..b3f189e641373a7eca430db401cd0db9772365c7 100644 (file)
@@ -72,7 +72,7 @@ func TestLookupStaticHost(t *testing.T) {
 func testStaticHost(t *testing.T, hostsPath string, ent staticHostEntry) {
        ins := []string{ent.in, absDomainName(ent.in), strings.ToLower(ent.in), strings.ToUpper(ent.in)}
        for _, in := range ins {
-               addrs := lookupStaticHost(in)
+               addrs, _ := lookupStaticHost(in)
                if !reflect.DeepEqual(addrs, ent.out) {
                        t.Errorf("%s, lookupStaticHost(%s) = %v; want %v", hostsPath, in, addrs, ent.out)
                }
@@ -157,7 +157,7 @@ func TestHostCacheModification(t *testing.T) {
        ent := staticHostEntry{"localhost", []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}}
        testStaticHost(t, testHookHostsPath, ent)
        // Modify the addresses return by lookupStaticHost.
-       addrs := lookupStaticHost(ent.in)
+       addrs, _ := lookupStaticHost(ent.in)
        for i := range addrs {
                addrs[i] += "junk"
        }
@@ -173,3 +173,42 @@ func TestHostCacheModification(t *testing.T) {
        }
        testStaticAddr(t, testHookHostsPath, ent)
 }
+
+var lookupStaticHostAliasesTest = []struct {
+       lookup, res string
+}{
+       // 127.0.0.1
+       {"test", "test"},
+       // 127.0.0.2
+       {"test2.example.com", "test2.example.com"},
+       {"2.test", "test2.example.com"},
+       // 127.0.0.3
+       {"test3.example.com", "3.test"},
+       {"3.test", "3.test"},
+       // 127.0.0.4
+       {"example.com", "example.com"},
+       // 127.0.0.5
+       {"test5.example.com", "test4.example.com"},
+       {"5.test", "test4.example.com"},
+       {"4.test", "test4.example.com"},
+       {"test4.example.com", "test4.example.com"},
+}
+
+func TestLookupStaticHostAliases(t *testing.T) {
+       defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+
+       testHookHostsPath = "testdata/aliases"
+       for _, ent := range lookupStaticHostAliasesTest {
+               testLookupStaticHostAliases(t, ent.lookup, absDomainName(ent.res))
+       }
+}
+
+func testLookupStaticHostAliases(t *testing.T, lookup, lookupRes string) {
+       ins := []string{lookup, absDomainName(lookup), strings.ToLower(lookup), strings.ToUpper(lookup)}
+       for _, in := range ins {
+               _, res := lookupStaticHost(in)
+               if res != lookupRes {
+                       t.Errorf("lookupStaticHost(%v): got %v, want %v", in, res, lookupRes)
+               }
+       }
+}
diff --git a/src/net/testdata/aliases b/src/net/testdata/aliases
new file mode 100644 (file)
index 0000000..9330ba0
--- /dev/null
@@ -0,0 +1,8 @@
+127.0.0.1 test
+127.0.0.2 test2.example.com 2.test
+127.0.0.3 3.test test3.example.com
+127.0.0.4 example.com
+127.0.0.5 test4.example.com 4.test 5.test test5.example.com
+
+# must be a non resolvable domain on the internet
+127.0.1.1 invalid.test invalid.invalid