]> Cypherpunks repositories - gostls13.git/commitdiff
net: support all PacketConn and Conn returned by Resolver.Dial
authorBen Burkert <ben@benburkert.com>
Thu, 8 Jun 2017 20:19:28 +0000 (13:19 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 8 Jun 2017 21:53:49 +0000 (21:53 +0000)
Allow the Resolver.Dial func to return instances of Conn other than
*TCPConn and *UDPConn. If the Conn is also a PacketConn, assume DNS
messages transmitted over the Conn adhere to section 4.2.1. "UDP usage".
Otherwise, follow section 4.2.2. "TCP usage".

Provides a hook mechanism so that DNS queries generated by the net
package may be answered or modified before being sent to over the
network.

Updates #19910

Change-Id: Ib089a28ad4a1848bbeaf624ae889f1e82d56655b
Reviewed-on: https://go-review.googlesource.com/45153
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/dnsclient_unix.go
src/net/dnsclient_unix_test.go
src/net/lookup.go
src/net/lookup_unix.go

index 75d70d3989377c5e160916d8109c4f2d220f96ca..acbf6c3b2ad99cec3d73377e904f83de1ab6ec76 100644 (file)
@@ -36,14 +36,14 @@ type dnsConn interface {
        dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
 }
 
-func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
-       return dnsRoundTripUDP(c, query)
+// dnsPacketConn implements the dnsConn interface for RFC 1035's
+// "UDP usage" transport mechanism. Conn is a packet-oriented connection,
+// such as a *UDPConn.
+type dnsPacketConn struct {
+       Conn
 }
 
-// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's
-// "UDP usage" transport mechanism. c should be a packet-oriented connection,
-// such as a *UDPConn.
-func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
+func (c *dnsPacketConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
        b, ok := query.Pack()
        if !ok {
                return nil, errors.New("cannot marshal DNS message")
@@ -69,14 +69,14 @@ func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
        }
 }
 
-func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) {
-       return dnsRoundTripTCP(c, out)
+// dnsStreamConn implements the dnsConn interface for RFC 1035's
+// "TCP usage" transport mechanism. Conn is a stream-oriented connection,
+// such as a *TCPConn.
+type dnsStreamConn struct {
+       Conn
 }
 
-// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's
-// "TCP usage" transport mechanism. c should be a stream-oriented connection,
-// such as a *TCPConn.
-func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
+func (c *dnsStreamConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
        b, ok := query.Pack()
        if !ok {
                return nil, errors.New("cannot marshal DNS message")
index d0ac4302b105e00d8ac781ce37122597eb14d468..73b628c1b5f328d50dfafa4bd6ecfb19432e62d1 100644 (file)
@@ -8,6 +8,7 @@ package net
 
 import (
        "context"
+       "errors"
        "fmt"
        "internal/poll"
        "io/ioutil"
@@ -43,11 +44,14 @@ var dnsTransportFallbackTests = []struct {
 
 func TestDNSTransportFallback(t *testing.T) {
        fake := fakeDNSServer{
-               rh: func(n, _ string, _ *dnsMsg, _ time.Time) (*dnsMsg, error) {
+               rh: func(n, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
                        r := &dnsMsg{
                                dnsMsgHdr: dnsMsgHdr{
-                                       rcode: dnsRcodeSuccess,
+                                       id:       q.id,
+                                       response: true,
+                                       rcode:    dnsRcodeSuccess,
                                },
+                               question: q.question,
                        }
                        if n == "udp" {
                                r.truncated = true
@@ -98,8 +102,10 @@ func TestSpecialDomainName(t *testing.T) {
        fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
                r := &dnsMsg{
                        dnsMsgHdr: dnsMsgHdr{
-                               id: q.id,
+                               id:       q.id,
+                               response: true,
                        },
+                       question: q.question,
                }
 
                switch q.question[0].Name {
@@ -612,8 +618,10 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
        fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
                r := &dnsMsg{
                        dnsMsgHdr: dnsMsgHdr{
-                               id: q.id,
+                               id:       q.id,
+                               response: true,
                        },
+                       question: q.question,
                }
 
                switch q.question[0].Name {
@@ -751,7 +759,7 @@ type fakeDNSServer struct {
 }
 
 func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
-       return &fakeDNSConn{nil, server, n, s, time.Time{}}, nil
+       return &fakeDNSConn{nil, server, n, s, nil, time.Time{}}, nil
 }
 
 type fakeDNSConn struct {
@@ -759,6 +767,7 @@ type fakeDNSConn struct {
        server *fakeDNSServer
        n      string
        s      string
+       q      *dnsMsg
        t      time.Time
 }
 
@@ -766,15 +775,45 @@ func (f *fakeDNSConn) Close() error {
        return nil
 }
 
+func (f *fakeDNSConn) Read(b []byte) (int, error) {
+       resp, err := f.server.rh(f.n, f.s, f.q, f.t)
+       if err != nil {
+               return 0, err
+       }
+
+       bb, ok := resp.Pack()
+       if !ok {
+               return 0, errors.New("cannot marshal DNS message")
+       }
+       if len(b) < len(bb) {
+               return 0, errors.New("read would fragment DNS message")
+       }
+
+       copy(b, bb)
+       return len(bb), nil
+}
+
+func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) {
+       return 0, nil, nil
+}
+
+func (f *fakeDNSConn) Write(b []byte) (int, error) {
+       f.q = new(dnsMsg)
+       if !f.q.Unpack(b) {
+               return 0, errors.New("cannot unmarshal DNS message")
+       }
+       return len(b), nil
+}
+
+func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) {
+       return 0, nil
+}
+
 func (f *fakeDNSConn) SetDeadline(t time.Time) error {
        f.t = t
        return nil
 }
 
-func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) {
-       return f.server.rh(f.n, f.s, q, f.t)
-}
-
 // UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
 func TestIgnoreDNSForgeries(t *testing.T) {
        c, s := Pipe()
@@ -837,7 +876,8 @@ func TestIgnoreDNSForgeries(t *testing.T) {
                },
        }
 
-       resp, err := dnsRoundTripUDP(c, msg)
+       dc := &dnsPacketConn{c}
+       resp, err := dc.dnsRoundTrip(msg)
        if err != nil {
                t.Fatalf("dnsRoundTripUDP failed: %v", err)
        }
@@ -1113,7 +1153,14 @@ func TestStrictErrorsLookupIP(t *testing.T) {
                        case resolveOpError:
                                return nil, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
                        case resolveServfail:
-                               return &dnsMsg{dnsMsgHdr: dnsMsgHdr{id: q.id, rcode: dnsRcodeServerFailure}}, nil
+                               return &dnsMsg{
+                                       dnsMsgHdr: dnsMsgHdr{
+                                               id:       q.id,
+                                               response: true,
+                                               rcode:    dnsRcodeServerFailure,
+                                       },
+                                       question: q.question,
+                               }, nil
                        case resolveTimeout:
                                return nil, poll.ErrTimeout
                        default:
@@ -1123,7 +1170,14 @@ func TestStrictErrorsLookupIP(t *testing.T) {
                        switch q.question[0].Name {
                        case searchX, name + ".":
                                // Return NXDOMAIN to utilize the search list.
-                               return &dnsMsg{dnsMsgHdr: dnsMsgHdr{id: q.id, rcode: dnsRcodeNameError}}, nil
+                               return &dnsMsg{
+                                       dnsMsgHdr: dnsMsgHdr{
+                                               id:       q.id,
+                                               response: true,
+                                               rcode:    dnsRcodeNameError,
+                                       },
+                                       question: q.question,
+                               }, nil
                        case searchY:
                                // Return records below.
                        default:
index 818f91c3dc75c62f05a57cc222a81a78570ca06d..abc56de533864c603d8747f3224f01c886e26dbe 100644 (file)
@@ -111,9 +111,11 @@ type Resolver struct {
        // Go's built-in DNS resolver to make TCP and UDP connections
        // to DNS services. The provided addr will always be an IP
        // address and not a hostname.
-       // The Conn returned must be a *TCPConn or *UDPConn as
-       // requested by the network parameter. If nil, the default
-       // dialer is used.
+       // If the Conn returned is also a PacketConn, sent and received DNS
+       // messages must adhere to section 4.2.1. "UDP usage" of RFC 1035.
+       // Otherwise, DNS messages transmitted over Conn must adhere to section
+       // 4.2.2. "TCP usage".
+       // If nil, the default dialer is used.
        Dial func(ctx context.Context, network, addr string) (Conn, error)
 
        // TODO(bradfitz): optional interface impl override hook
index a485d706a585ad94b5e79a5d0665df8adfa9aac6..896836393b0b1bfaef359a5efc1c095bcd6818f9 100644 (file)
@@ -8,8 +8,6 @@ package net
 
 import (
        "context"
-       "errors"
-       "reflect"
        "sync"
 )
 
@@ -70,12 +68,10 @@ func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, e
        if err != nil {
                return nil, mapErr(err)
        }
-       dc, ok := c.(dnsConn)
-       if !ok {
-               c.Close()
-               return nil, errors.New("net: Resolver.Dial returned unsupported connection type " + reflect.TypeOf(c).String())
+       if _, ok := c.(PacketConn); ok {
+               return &dnsPacketConn{c}, nil
        }
-       return dc, nil
+       return &dnsStreamConn{c}, nil
 }
 
 func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {