]> Cypherpunks repositories - gostls13.git/commitdiff
net: fix CNAME resolving on Windows
authorEgon Elbre <egonelbre@gmail.com>
Fri, 15 Aug 2014 06:37:19 +0000 (16:37 +1000)
committerAlex Brainman <alex.brainman@gmail.com>
Fri, 15 Aug 2014 06:37:19 +0000 (16:37 +1000)
Fixes #8492

LGTM=alex.brainman
R=golang-codereviews, alex.brainman
CC=golang-codereviews
https://golang.org/cl/122200043

src/pkg/net/lookup_windows.go
src/pkg/net/lookup_windows_test.go [new file with mode: 0644]
src/pkg/syscall/syscall_windows.go
src/pkg/syscall/zsyscall_windows_386.go
src/pkg/syscall/zsyscall_windows_amd64.go
src/pkg/syscall/ztypes_windows.go

index 130364231d49137e34ec6153aef5e513444c3fac..6a925b0a7adad630e569ce237905c12400730747 100644 (file)
@@ -210,14 +210,21 @@ func lookupCNAME(name string) (cname string, err error) {
        defer releaseThread()
        var r *syscall.DNSRecord
        e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
+       // windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
+       if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
+               // if there are no aliases, the canonical name is the input name
+               if name == "" || name[len(name)-1] != '.' {
+                       return name + ".", nil
+               }
+               return name, nil
+       }
        if e != nil {
                return "", os.NewSyscallError("LookupCNAME", e)
        }
        defer syscall.DnsRecordListFree(r, 1)
-       if r != nil && r.Type == syscall.DNS_TYPE_CNAME {
-               v := (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0]))
-               cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."
-       }
+
+       resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
+       cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(resolved))[:]) + "."
        return
 }
 
@@ -236,8 +243,9 @@ func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err
                return "", nil, os.NewSyscallError("LookupSRV", e)
        }
        defer syscall.DnsRecordListFree(r, 1)
+
        addrs = make([]*SRV, 0, 10)
-       for p := r; p != nil && p.Type == syscall.DNS_TYPE_SRV; p = p.Next {
+       for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
                v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
                addrs = append(addrs, &SRV{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]), v.Port, v.Priority, v.Weight})
        }
@@ -254,8 +262,9 @@ func lookupMX(name string) (mx []*MX, err error) {
                return nil, os.NewSyscallError("LookupMX", e)
        }
        defer syscall.DnsRecordListFree(r, 1)
+
        mx = make([]*MX, 0, 10)
-       for p := r; p != nil && p.Type == syscall.DNS_TYPE_MX; p = p.Next {
+       for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
                v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
                mx = append(mx, &MX{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:]) + ".", v.Preference})
        }
@@ -272,8 +281,9 @@ func lookupNS(name string) (ns []*NS, err error) {
                return nil, os.NewSyscallError("LookupNS", e)
        }
        defer syscall.DnsRecordListFree(r, 1)
+
        ns = make([]*NS, 0, 10)
-       for p := r; p != nil && p.Type == syscall.DNS_TYPE_NS; p = p.Next {
+       for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
                v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
                ns = append(ns, &NS{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."})
        }
@@ -289,9 +299,10 @@ func lookupTXT(name string) (txt []string, err error) {
                return nil, os.NewSyscallError("LookupTXT", e)
        }
        defer syscall.DnsRecordListFree(r, 1)
+
        txt = make([]string, 0, 10)
-       if r != nil && r.Type == syscall.DNS_TYPE_TEXT {
-               d := (*syscall.DNSTXTData)(unsafe.Pointer(&r.Data[0]))
+       for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) {
+               d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
                for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] {
                        s := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:])
                        txt = append(txt, s)
@@ -313,10 +324,58 @@ func lookupAddr(addr string) (name []string, err error) {
                return nil, os.NewSyscallError("LookupAddr", e)
        }
        defer syscall.DnsRecordListFree(r, 1)
+
        name = make([]string, 0, 10)
-       for p := r; p != nil && p.Type == syscall.DNS_TYPE_PTR; p = p.Next {
+       for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
                v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
                name = append(name, syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))
        }
        return name, nil
 }
+
+const dnsSectionMask = 0x0003
+
+// returns only results applicable to name and resolves CNAME entries
+func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord {
+       cname := syscall.StringToUTF16Ptr(name)
+       if dnstype != syscall.DNS_TYPE_CNAME {
+               cname = resolveCNAME(cname, r)
+       }
+       rec := make([]*syscall.DNSRecord, 0, 10)
+       for p := r; p != nil; p = p.Next {
+               if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
+                       continue
+               }
+               if p.Type != dnstype {
+                       continue
+               }
+               if !syscall.DnsNameCompare(cname, p.Name) {
+                       continue
+               }
+               rec = append(rec, p)
+       }
+       return rec
+}
+
+// returns the last CNAME in chain
+func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 {
+       // limit cname resolving to 10 in case of a infinite CNAME loop
+Cname:
+       for cnameloop := 0; cnameloop < 10; cnameloop++ {
+               for p := r; p != nil; p = p.Next {
+                       if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
+                               continue
+                       }
+                       if p.Type != syscall.DNS_TYPE_CNAME {
+                               continue
+                       }
+                       if !syscall.DnsNameCompare(name, p.Name) {
+                               continue
+                       }
+                       name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host
+                       continue Cname
+               }
+               break
+       }
+       return name
+}
diff --git a/src/pkg/net/lookup_windows_test.go b/src/pkg/net/lookup_windows_test.go
new file mode 100644 (file)
index 0000000..7495b5b
--- /dev/null
@@ -0,0 +1,243 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+       "bytes"
+       "encoding/json"
+       "errors"
+       "os/exec"
+       "reflect"
+       "regexp"
+       "sort"
+       "strconv"
+       "strings"
+       "testing"
+)
+
+var nslookupTestServers = []string{"mail.golang.com", "gmail.com"}
+
+func toJson(v interface{}) string {
+       data, _ := json.Marshal(v)
+       return string(data)
+}
+
+func TestLookupMX(t *testing.T) {
+       if testing.Short() || !*testExternal {
+               t.Skip("skipping test to avoid external network")
+       }
+       for _, server := range nslookupTestServers {
+               mx, err := LookupMX(server)
+               if err != nil {
+                       t.Errorf("failed %s: %s", server, err)
+                       continue
+               }
+               if len(mx) == 0 {
+                       t.Errorf("no results")
+                       continue
+               }
+               expected, err := nslookupMX(server)
+               if err != nil {
+                       t.Logf("skipping failed nslookup %s test: %s", server, err)
+               }
+               sort.Sort(byPrefAndHost(expected))
+               sort.Sort(byPrefAndHost(mx))
+               if !reflect.DeepEqual(expected, mx) {
+                       t.Errorf("different results %s:\texp:%v\tgot:%v", server, toJson(expected), toJson(mx))
+               }
+       }
+}
+
+func TestLookupCNAME(t *testing.T) {
+       if testing.Short() || !*testExternal {
+               t.Skip("skipping test to avoid external network")
+       }
+       for _, server := range nslookupTestServers {
+               cname, err := LookupCNAME(server)
+               if err != nil {
+                       t.Errorf("failed %s: %s", server, err)
+                       continue
+               }
+               if cname == "" {
+                       t.Errorf("no result %s", server)
+               }
+               expected, err := nslookupCNAME(server)
+               if err != nil {
+                       t.Logf("skipping failed nslookup %s test: %s", server, err)
+                       continue
+               }
+               if expected != cname {
+                       t.Errorf("different results %s:\texp:%v\tgot:%v", server, expected, cname)
+               }
+       }
+}
+
+func TestLookupNS(t *testing.T) {
+       if testing.Short() || !*testExternal {
+               t.Skip("skipping test to avoid external network")
+       }
+       for _, server := range nslookupTestServers {
+               ns, err := LookupNS(server)
+               if err != nil {
+                       t.Errorf("failed %s: %s", server, err)
+                       continue
+               }
+               if len(ns) == 0 {
+                       t.Errorf("no results")
+                       continue
+               }
+               expected, err := nslookupNS(server)
+               if err != nil {
+                       t.Logf("skipping failed nslookup %s test: %s", server, err)
+                       continue
+               }
+               sort.Sort(byHost(expected))
+               sort.Sort(byHost(ns))
+               if !reflect.DeepEqual(expected, ns) {
+                       t.Errorf("different results %s:\texp:%v\tgot:%v", toJson(server), toJson(expected), ns)
+               }
+       }
+}
+
+func TestLookupTXT(t *testing.T) {
+       if testing.Short() || !*testExternal {
+               t.Skip("skipping test to avoid external network")
+       }
+       for _, server := range nslookupTestServers {
+               txt, err := LookupTXT(server)
+               if err != nil {
+                       t.Errorf("failed %s: %s", server, err)
+                       continue
+               }
+               if len(txt) == 0 {
+                       t.Errorf("no results")
+                       continue
+               }
+               expected, err := nslookupTXT(server)
+               if err != nil {
+                       t.Logf("skipping failed nslookup %s test: %s", server, err)
+                       continue
+               }
+               sort.Strings(expected)
+               sort.Strings(txt)
+               if !reflect.DeepEqual(expected, txt) {
+                       t.Errorf("different results %s:\texp:%v\tgot:%v", server, toJson(expected), toJson(txt))
+               }
+       }
+}
+
+type byPrefAndHost []*MX
+
+func (s byPrefAndHost) Len() int { return len(s) }
+func (s byPrefAndHost) Less(i, j int) bool {
+       if s[i].Pref != s[j].Pref {
+               return s[i].Pref < s[j].Pref
+       }
+       return s[i].Host < s[j].Host
+}
+func (s byPrefAndHost) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+type byHost []*NS
+
+func (s byHost) Len() int           { return len(s) }
+func (s byHost) Less(i, j int) bool { return s[i].Host < s[j].Host }
+func (s byHost) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
+
+func fqdn(s string) string {
+       if len(s) == 0 || s[len(s)-1] != '.' {
+               return s + "."
+       }
+       return s
+}
+
+func nslookup(qtype, name string) (string, error) {
+       var out bytes.Buffer
+       var err bytes.Buffer
+       cmd := exec.Command("nslookup", "-querytype="+qtype, name)
+       cmd.Stdout = &out
+       cmd.Stderr = &err
+       if err := cmd.Run(); err != nil {
+               return "", err
+       }
+       r := strings.Replace(out.String(), "\r\n", "\n", -1)
+       // nslookup stderr output contains also debug information such as
+       // "Non-authoritative answer" and it doesn't return the correct errcode
+       if strings.Contains(err.String(), "can't find") {
+               return r, errors.New(err.String())
+       }
+       return r, nil
+}
+
+func nslookupMX(name string) (mx []*MX, err error) {
+       var r string
+       if r, err = nslookup("mx", name); err != nil {
+               return
+       }
+       mx = make([]*MX, 0, 10)
+       // linux nslookup syntax
+       // golang.org      mail exchanger = 2 alt1.aspmx.l.google.com.
+       rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+mail exchanger\s*=\s*([0-9]+)\s*([a-z0-9.\-]+)$`)
+       for _, ans := range rx.FindAllStringSubmatch(r, -1) {
+               pref, _ := strconv.Atoi(ans[2])
+               mx = append(mx, &MX{fqdn(ans[3]), uint16(pref)})
+       }
+       // windows nslookup syntax
+       // gmail.com       MX preference = 30, mail exchanger = alt3.gmail-smtp-in.l.google.com
+       rx = regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+MX preference\s*=\s*([0-9]+)\s*,\s*mail exchanger\s*=\s*([a-z0-9.\-]+)$`)
+       for _, ans := range rx.FindAllStringSubmatch(r, -1) {
+               pref, _ := strconv.Atoi(ans[2])
+               mx = append(mx, &MX{fqdn(ans[3]), uint16(pref)})
+       }
+       return
+}
+
+func nslookupNS(name string) (ns []*NS, err error) {
+       var r string
+       if r, err = nslookup("ns", name); err != nil {
+               return
+       }
+       ns = make([]*NS, 0, 10)
+       // golang.org      nameserver = ns1.google.com.
+       rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+nameserver\s*=\s*([a-z0-9.\-]+)$`)
+       for _, ans := range rx.FindAllStringSubmatch(r, -1) {
+               ns = append(ns, &NS{fqdn(ans[2])})
+       }
+       return
+}
+
+func nslookupCNAME(name string) (cname string, err error) {
+       var r string
+       if r, err = nslookup("cname", name); err != nil {
+               return
+       }
+       // mail.golang.com canonical name = golang.org.
+       rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+canonical name\s*=\s*([a-z0-9.\-]+)$`)
+       // assumes the last CNAME is the correct one
+       last := name
+       for _, ans := range rx.FindAllStringSubmatch(r, -1) {
+               last = ans[2]
+       }
+       return fqdn(last), nil
+}
+
+func nslookupTXT(name string) (txt []string, err error) {
+       var r string
+       if r, err = nslookup("txt", name); err != nil {
+               return
+       }
+       txt = make([]string, 0, 10)
+       // linux
+       // golang.org      text = "v=spf1 redirect=_spf.google.com"
+
+       // windows
+       // golang.org      text =
+       //
+       //    "v=spf1 redirect=_spf.google.com"
+       rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+text\s*=\s*"(.*)"$`)
+       for _, ans := range rx.FindAllStringSubmatch(r, -1) {
+               txt = append(txt, ans[2])
+       }
+       return
+}
index 1fe1ae0fab464efa0d6e80fc2e021caa11cc3f2d..32a7aed00188343d75203e77ca9b6241777f5e3a 100644 (file)
@@ -549,6 +549,7 @@ const socket_error = uintptr(^uint32(0))
 //sys  GetProtoByName(name string) (p *Protoent, err error) [failretval==nil] = ws2_32.getprotobyname
 //sys  DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status error) = dnsapi.DnsQuery_W
 //sys  DnsRecordListFree(rl *DNSRecord, freetype uint32) = dnsapi.DnsRecordListFree
+//sys  DnsNameCompare(name1 *uint16, name2 *uint16) (same bool) = dnsapi.DnsNameCompare_W
 //sys  GetAddrInfoW(nodename *uint16, servicename *uint16, hints *AddrinfoW, result **AddrinfoW) (sockerr error) = ws2_32.GetAddrInfoW
 //sys  FreeAddrInfoW(addrinfo *AddrinfoW) = ws2_32.FreeAddrInfoW
 //sys  GetIfEntry(pIfRow *MibIfRow) (errcode error) = iphlpapi.GetIfEntry
index d55211ee75488251b646a76a2547c679a088ebe5..1f44750b7f3c8b243922c6b77984a82b64ade8c8 100644 (file)
@@ -1,4 +1,4 @@
-// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go syscall_windows_386.go
+// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go
 // MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT
 
 package syscall
@@ -139,6 +139,7 @@ var (
        procgetprotobyname                     = modws2_32.NewProc("getprotobyname")
        procDnsQuery_W                         = moddnsapi.NewProc("DnsQuery_W")
        procDnsRecordListFree                  = moddnsapi.NewProc("DnsRecordListFree")
+       procDnsNameCompare_W                   = moddnsapi.NewProc("DnsNameCompare_W")
        procGetAddrInfoW                       = modws2_32.NewProc("GetAddrInfoW")
        procFreeAddrInfoW                      = modws2_32.NewProc("FreeAddrInfoW")
        procGetIfEntry                         = modiphlpapi.NewProc("GetIfEntry")
@@ -1634,6 +1635,12 @@ func DnsRecordListFree(rl *DNSRecord, freetype uint32) {
        return
 }
 
+func DnsNameCompare(name1 *uint16, name2 *uint16) (same bool) {
+       r0, _, _ := Syscall(procDnsNameCompare_W.Addr(), 2, uintptr(unsafe.Pointer(name1)), uintptr(unsafe.Pointer(name2)), 0)
+       same = r0 != 0
+       return
+}
+
 func GetAddrInfoW(nodename *uint16, servicename *uint16, hints *AddrinfoW, result **AddrinfoW) (sockerr error) {
        r0, _, _ := Syscall6(procGetAddrInfoW.Addr(), 4, uintptr(unsafe.Pointer(nodename)), uintptr(unsafe.Pointer(servicename)), uintptr(unsafe.Pointer(hints)), uintptr(unsafe.Pointer(result)), 0, 0)
        if r0 != 0 {
index 47affab73da72df52251684c595feb1a64e4974e..1f44750b7f3c8b243922c6b77984a82b64ade8c8 100644 (file)
@@ -1,4 +1,4 @@
-// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go syscall_windows_amd64.go
+// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go
 // MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT
 
 package syscall
@@ -139,6 +139,7 @@ var (
        procgetprotobyname                     = modws2_32.NewProc("getprotobyname")
        procDnsQuery_W                         = moddnsapi.NewProc("DnsQuery_W")
        procDnsRecordListFree                  = moddnsapi.NewProc("DnsRecordListFree")
+       procDnsNameCompare_W                   = moddnsapi.NewProc("DnsNameCompare_W")
        procGetAddrInfoW                       = modws2_32.NewProc("GetAddrInfoW")
        procFreeAddrInfoW                      = modws2_32.NewProc("FreeAddrInfoW")
        procGetIfEntry                         = modiphlpapi.NewProc("GetIfEntry")
@@ -1634,6 +1635,12 @@ func DnsRecordListFree(rl *DNSRecord, freetype uint32) {
        return
 }
 
+func DnsNameCompare(name1 *uint16, name2 *uint16) (same bool) {
+       r0, _, _ := Syscall(procDnsNameCompare_W.Addr(), 2, uintptr(unsafe.Pointer(name1)), uintptr(unsafe.Pointer(name2)), 0)
+       same = r0 != 0
+       return
+}
+
 func GetAddrInfoW(nodename *uint16, servicename *uint16, hints *AddrinfoW, result **AddrinfoW) (sockerr error) {
        r0, _, _ := Syscall6(procGetAddrInfoW.Addr(), 4, uintptr(unsafe.Pointer(nodename)), uintptr(unsafe.Pointer(servicename)), uintptr(unsafe.Pointer(hints)), uintptr(unsafe.Pointer(result)), 0, 0)
        if r0 != 0 {
index 8b3625f14678dfae12c1184942aaea4086f4b9c7..1363da01a88ba51b6d4202ada9945c2ec1da39bb 100644 (file)
@@ -689,6 +689,18 @@ const (
        DNS_TYPE_NBSTAT  = 0xff01
 )
 
+const (
+       DNS_INFO_NO_RECORDS = 0x251D
+)
+
+const (
+       // flags inside DNSRecord.Dw
+       DnsSectionQuestion   = 0x0000
+       DnsSectionAnswer     = 0x0001
+       DnsSectionAuthority  = 0x0002
+       DnsSectionAdditional = 0x0003
+)
+
 type DNSSRVData struct {
        Target   *uint16
        Priority uint16