]> Cypherpunks repositories - gostls13.git/commitdiff
net: separate DNS transport from DNS query-response interaction
authorMikio Hara <mikioh.mikioh@gmail.com>
Wed, 6 Aug 2014 00:58:47 +0000 (09:58 +0900)
committerMikio Hara <mikioh.mikioh@gmail.com>
Wed, 6 Aug 2014 00:58:47 +0000 (09:58 +0900)
Before fixing issue 6579 this CL separates DNS transport from
DNS message interaction to make it easier to add builtin DNS
resolver control logic.

Update #6579

LGTM=alex, kevlar
R=golang-codereviews, alex, gobot, iant, minux, kevlar
CC=golang-codereviews
https://golang.org/cl/101220044

src/pkg/net/dnsclient_unix.go
src/pkg/net/dnsclient_unix_test.go

index 3713efd0e3c9f27b3d95596d2ed3f00c0b82649c..eb4b5900de8fe336af3c92f46680af8e26940039 100644 (file)
@@ -16,6 +16,7 @@
 package net
 
 import (
+       "errors"
        "io"
        "math/rand"
        "os"
@@ -23,118 +24,187 @@ import (
        "time"
 )
 
-// Send a request on the connection and hope for a reply.
-// Up to cfg.attempts attempts.
-func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, error) {
-       _, useTCP := c.(*TCPConn)
-       if len(name) >= 256 {
-               return nil, &DNSError{Err: "name too long", Name: name}
+// A dnsConn represents a DNS transport endpoint.
+type dnsConn interface {
+       Conn
+
+       // readDNSResponse reads a DNS response message from the DNS
+       // transport endpoint and returns the received DNS response
+       // message.
+       readDNSResponse() (*dnsMsg, error)
+
+       // writeDNSQuery writes a DNS query message to the DNS
+       // connection endpoint.
+       writeDNSQuery(*dnsMsg) error
+}
+
+func (c *UDPConn) readDNSResponse() (*dnsMsg, error) {
+       b := make([]byte, 512) // see RFC 1035
+       n, err := c.Read(b)
+       if err != nil {
+               return nil, err
+       }
+       msg := &dnsMsg{}
+       if !msg.Unpack(b[:n]) {
+               return nil, errors.New("cannot unmarshal DNS message")
+       }
+       return msg, nil
+}
+
+func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error {
+       b, ok := msg.Pack()
+       if !ok {
+               return errors.New("cannot marshal DNS message")
+       }
+       if _, err := c.Write(b); err != nil {
+               return err
+       }
+       return nil
+}
+
+func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
+       b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
+       if _, err := io.ReadFull(c, b[:2]); err != nil {
+               return nil, err
+       }
+       l := int(b[0])<<8 | int(b[1])
+       if l > len(b) {
+               b = make([]byte, l)
+       }
+       n, err := io.ReadFull(c, b[:l])
+       if err != nil {
+               return nil, err
        }
-       out := new(dnsMsg)
-       out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
-       out.question = []dnsQuestion{
-               {name, qtype, dnsClassINET},
+       msg := &dnsMsg{}
+       if !msg.Unpack(b[:n]) {
+               return nil, errors.New("cannot unmarshal DNS message")
        }
-       out.recursion_desired = true
-       msg, ok := out.Pack()
+       return msg, nil
+}
+
+func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
+       b, ok := msg.Pack()
        if !ok {
-               return nil, &DNSError{Err: "internal error - cannot pack message", Name: name}
+               return errors.New("cannot marshal DNS message")
+       }
+       l := uint16(len(b))
+       b = append([]byte{byte(l >> 8), byte(l)}, b...)
+       if _, err := c.Write(b); err != nil {
+               return err
        }
-       if useTCP {
-               mlen := uint16(len(msg))
-               msg = append([]byte{byte(mlen >> 8), byte(mlen)}, msg...)
+       return nil
+}
+
+func (d *Dialer) dialDNS(network, server string) (dnsConn, error) {
+       switch network {
+       case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
+       default:
+               return nil, UnknownNetworkError(network)
        }
-       for attempt := 0; attempt < cfg.attempts; attempt++ {
-               n, err := c.Write(msg)
+       // Calling Dial here is scary -- we have to be sure not to
+       // dial a name that will require a DNS lookup, or Dial will
+       // call back here to translate it. The DNS config parser has
+       // already checked that all the cfg.servers[i] are IP
+       // addresses, which Dial will use without a DNS lookup.
+       c, err := d.Dial(network, server)
+       if err != nil {
+               return nil, err
+       }
+       switch network {
+       case "tcp", "tcp4", "tcp6":
+               return c.(*TCPConn), nil
+       case "udp", "udp4", "udp6":
+               return c.(*UDPConn), nil
+       }
+       panic("unreachable")
+}
+
+// exchange sends a query on the connection and hopes for a response.
+func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
+       d := Dialer{Timeout: timeout}
+       out := dnsMsg{
+               dnsMsgHdr: dnsMsgHdr{
+                       recursion_desired: true,
+               },
+               question: []dnsQuestion{
+                       {name, qtype, dnsClassINET},
+               },
+       }
+       for _, network := range []string{"udp", "tcp"} {
+               c, err := d.dialDNS(network, server)
                if err != nil {
                        return nil, err
                }
-
-               if cfg.timeout == 0 {
-                       c.SetReadDeadline(noDeadline)
-               } else {
-                       c.SetReadDeadline(time.Now().Add(time.Duration(cfg.timeout) * time.Second))
+               defer c.Close()
+               if timeout > 0 {
+                       c.SetDeadline(time.Now().Add(timeout))
                }
-               buf := make([]byte, 2000)
-               if useTCP {
-                       n, err = io.ReadFull(c, buf[:2])
-                       if err != nil {
-                               if e, ok := err.(Error); ok && e.Timeout() {
-                                       continue
-                               }
-                       }
-                       mlen := int(buf[0])<<8 | int(buf[1])
-                       if mlen > len(buf) {
-                               buf = make([]byte, mlen)
-                       }
-                       n, err = io.ReadFull(c, buf[:mlen])
-               } else {
-                       n, err = c.Read(buf)
+               out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
+               if err := c.writeDNSQuery(&out); err != nil {
+                       return nil, err
                }
+               in, err := c.readDNSResponse()
                if err != nil {
-                       if e, ok := err.(Error); ok && e.Timeout() {
-                               continue
-                       }
                        return nil, err
                }
-               buf = buf[:n]
-               in := new(dnsMsg)
-               if !in.Unpack(buf) || in.id != out.id {
+               if in.id != out.id {
+                       return nil, errors.New("DNS message ID mismatch")
+               }
+               if in.truncated { // see RFC 5966
                        continue
                }
                return in, nil
        }
-       var server string
-       if a := c.RemoteAddr(); a != nil {
-               server = a.String()
-       }
-       return nil, &DNSError{Err: "no answer from server", Name: name, Server: server, IsTimeout: true}
+       return nil, errors.New("no answer from DNS server")
 }
 
 // Do a lookup for a single name, which must be rooted
 // (otherwise answer will not find the answers).
-func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs []dnsRR, err error) {
+func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
        if len(cfg.servers) == 0 {
                return "", nil, &DNSError{Err: "no DNS servers", Name: name}
        }
-       for i := 0; i < len(cfg.servers); i++ {
-               // Calling Dial here is scary -- we have to be sure
-               // not to dial a name that will require a DNS lookup,
-               // or Dial will call back here to translate it.
-               // The DNS config parser has already checked that
-               // all the cfg.servers[i] are IP addresses, which
-               // Dial will use without a DNS lookup.
-               server := cfg.servers[i] + ":53"
-               c, cerr := Dial("udp", server)
-               if cerr != nil {
-                       err = cerr
-                       continue
-               }
-               msg, merr := exchange(cfg, c, name, qtype)
-               c.Close()
-               if merr != nil {
-                       err = merr
-                       continue
+       if len(name) >= 256 {
+               return "", nil, &DNSError{Err: "DNS name too long", Name: name}
+       }
+       timeout := time.Duration(cfg.timeout) * time.Second
+       var lastErr error
+       for _, server := range cfg.servers {
+               server += ":53"
+               lastErr = &DNSError{
+                       Err:       "no answer from DNS server",
+                       Name:      name,
+                       Server:    server,
+                       IsTimeout: true,
                }
-               if msg.truncated { // see RFC 5966
-                       c, cerr = Dial("tcp", server)
-                       if cerr != nil {
-                               err = cerr
-                               continue
+               for i := 0; i < cfg.attempts; i++ {
+                       msg, err := exchange(server, name, qtype, timeout)
+                       if err != nil {
+                               if nerr, ok := err.(Error); ok && nerr.Timeout() {
+                                       lastErr = &DNSError{
+                                               Err:       nerr.Error(),
+                                               Name:      name,
+                                               Server:    server,
+                                               IsTimeout: true,
+                                       }
+                                       continue
+
+                               }
+                               lastErr = &DNSError{
+                                       Err:    err.Error(),
+                                       Name:   name,
+                                       Server: server,
+                               }
+                               break
                        }
-                       msg, merr = exchange(cfg, c, name, qtype)
-                       c.Close()
-                       if merr != nil {
-                               err = merr
-                               continue
+                       cname, addrs, err := answer(name, server, msg, qtype)
+                       if err == nil || err.(*DNSError).Err == noSuchHost {
+                               return cname, addrs, err
                        }
-               }
-               cname, addrs, err = answer(name, server, msg, qtype)
-               if err == nil || err.(*DNSError).Err == noSuchHost {
-                       break
+                       lastErr = err
                }
        }
-       return
+       return "", nil, lastErr
 }
 
 func convertRR_A(records []dnsRR) []IP {
index 2350142d61052e06f40644baa3200016df3bb27a..39d82d9961d34c748a5c77fa1a655be924a4bf92 100644 (file)
@@ -16,19 +16,79 @@ import (
        "time"
 )
 
-func TestTCPLookup(t *testing.T) {
+var dnsTransportFallbackTests = []struct {
+       server  string
+       name    string
+       qtype   uint16
+       timeout int
+       rcode   int
+}{
+       // Querying "com." with qtype=255 usually makes an answer
+       // which requires more than 512 bytes.
+       {"8.8.8.8:53", "com.", dnsTypeALL, 2, dnsRcodeSuccess},
+       {"8.8.4.4:53", "com.", dnsTypeALL, 4, dnsRcodeSuccess},
+}
+
+func TestDNSTransportFallback(t *testing.T) {
        if testing.Short() || !*testExternal {
                t.Skip("skipping test to avoid external network")
        }
-       c, err := Dial("tcp", "8.8.8.8:53")
-       if err != nil {
-               t.Fatalf("Dial failed: %v", err)
+
+       for _, tt := range dnsTransportFallbackTests {
+               timeout := time.Duration(tt.timeout) * time.Second
+               msg, err := exchange(tt.server, tt.name, tt.qtype, timeout)
+               if err != nil {
+                       t.Error(err)
+                       continue
+               }
+               switch msg.rcode {
+               case tt.rcode, dnsRcodeServerFailure:
+               default:
+                       t.Errorf("got %v from %v; want %v", msg.rcode, tt.server, tt.rcode)
+                       continue
+               }
        }
-       defer c.Close()
-       cfg := &dnsConfig{timeout: 10, attempts: 3}
-       _, err = exchange(cfg, c, "com.", dnsTypeALL)
-       if err != nil {
-               t.Fatalf("exchange failed: %v", err)
+}
+
+// See RFC 6761 for further information about the reserved, pseudo
+// domain names.
+var specialDomainNameTests = []struct {
+       name  string
+       qtype uint16
+       rcode int
+}{
+       // Name resoltion APIs and libraries should not recongnize the
+       // followings as special.
+       {"1.0.168.192.in-addr.arpa.", dnsTypePTR, dnsRcodeNameError},
+       {"test.", dnsTypeALL, dnsRcodeNameError},
+       {"example.com.", dnsTypeALL, dnsRcodeSuccess},
+
+       // Name resoltion APIs and libraries should recongnize the
+       // followings as special and should not send any queries.
+       // Though, we test those names here for verifying nagative
+       // answers at DNS query-response interaction level.
+       {"localhost.", dnsTypeALL, dnsRcodeNameError},
+       {"invalid.", dnsTypeALL, dnsRcodeNameError},
+}
+
+func TestSpecialDomainName(t *testing.T) {
+       if testing.Short() || !*testExternal {
+               t.Skip("skipping test to avoid external network")
+       }
+
+       server := "8.8.8.8:53"
+       for _, tt := range specialDomainNameTests {
+               msg, err := exchange(server, tt.name, tt.qtype, 0)
+               if err != nil {
+                       t.Error(err)
+                       continue
+               }
+               switch msg.rcode {
+               case tt.rcode, dnsRcodeServerFailure:
+               default:
+                       t.Errorf("got %v from %v; want %v", msg.rcode, server, tt.rcode)
+                       continue
+               }
        }
 }