From d8a7990ffad9aebfb7261df7afb3049da4a09986 Mon Sep 17 00:00:00 2001 From: Ben Burkert Date: Thu, 8 Jun 2017 13:19:28 -0700 Subject: [PATCH] net: support all PacketConn and Conn returned by Resolver.Dial Allow the Resolver.Dial func to return instances of Conn other than *TCPConn and *UDPConn. If the Conn is also a PacketConn, assume DNS messages transmitted over the Conn adhere to section 4.2.1. "UDP usage". Otherwise, follow section 4.2.2. "TCP usage". Provides a hook mechanism so that DNS queries generated by the net package may be answered or modified before being sent to over the network. Updates #19910 Change-Id: Ib089a28ad4a1848bbeaf624ae889f1e82d56655b Reviewed-on: https://go-review.googlesource.com/45153 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Brad Fitzpatrick --- src/net/dnsclient_unix.go | 24 +++++------ src/net/dnsclient_unix_test.go | 78 ++++++++++++++++++++++++++++------ src/net/lookup.go | 8 ++-- src/net/lookup_unix.go | 10 ++--- 4 files changed, 86 insertions(+), 34 deletions(-) diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go index 75d70d3989..acbf6c3b2a 100644 --- a/src/net/dnsclient_unix.go +++ b/src/net/dnsclient_unix.go @@ -36,14 +36,14 @@ type dnsConn interface { dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) } -func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) { - return dnsRoundTripUDP(c, query) +// dnsPacketConn implements the dnsConn interface for RFC 1035's +// "UDP usage" transport mechanism. Conn is a packet-oriented connection, +// such as a *UDPConn. +type dnsPacketConn struct { + Conn } -// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's -// "UDP usage" transport mechanism. c should be a packet-oriented connection, -// such as a *UDPConn. -func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) { +func (c *dnsPacketConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) { b, ok := query.Pack() if !ok { return nil, errors.New("cannot marshal DNS message") @@ -69,14 +69,14 @@ func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) { } } -func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) { - return dnsRoundTripTCP(c, out) +// dnsStreamConn implements the dnsConn interface for RFC 1035's +// "TCP usage" transport mechanism. Conn is a stream-oriented connection, +// such as a *TCPConn. +type dnsStreamConn struct { + Conn } -// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's -// "TCP usage" transport mechanism. c should be a stream-oriented connection, -// such as a *TCPConn. -func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) { +func (c *dnsStreamConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) { b, ok := query.Pack() if !ok { return nil, errors.New("cannot marshal DNS message") diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go index d0ac4302b1..73b628c1b5 100644 --- a/src/net/dnsclient_unix_test.go +++ b/src/net/dnsclient_unix_test.go @@ -8,6 +8,7 @@ package net import ( "context" + "errors" "fmt" "internal/poll" "io/ioutil" @@ -43,11 +44,14 @@ var dnsTransportFallbackTests = []struct { func TestDNSTransportFallback(t *testing.T) { fake := fakeDNSServer{ - rh: func(n, _ string, _ *dnsMsg, _ time.Time) (*dnsMsg, error) { + rh: func(n, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { r := &dnsMsg{ dnsMsgHdr: dnsMsgHdr{ - rcode: dnsRcodeSuccess, + id: q.id, + response: true, + rcode: dnsRcodeSuccess, }, + question: q.question, } if n == "udp" { r.truncated = true @@ -98,8 +102,10 @@ func TestSpecialDomainName(t *testing.T) { fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { r := &dnsMsg{ dnsMsgHdr: dnsMsgHdr{ - id: q.id, + id: q.id, + response: true, }, + question: q.question, } switch q.question[0].Name { @@ -612,8 +618,10 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { r := &dnsMsg{ dnsMsgHdr: dnsMsgHdr{ - id: q.id, + id: q.id, + response: true, }, + question: q.question, } switch q.question[0].Name { @@ -751,7 +759,7 @@ type fakeDNSServer struct { } func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) { - return &fakeDNSConn{nil, server, n, s, time.Time{}}, nil + return &fakeDNSConn{nil, server, n, s, nil, time.Time{}}, nil } type fakeDNSConn struct { @@ -759,6 +767,7 @@ type fakeDNSConn struct { server *fakeDNSServer n string s string + q *dnsMsg t time.Time } @@ -766,15 +775,45 @@ func (f *fakeDNSConn) Close() error { return nil } +func (f *fakeDNSConn) Read(b []byte) (int, error) { + resp, err := f.server.rh(f.n, f.s, f.q, f.t) + if err != nil { + return 0, err + } + + bb, ok := resp.Pack() + if !ok { + return 0, errors.New("cannot marshal DNS message") + } + if len(b) < len(bb) { + return 0, errors.New("read would fragment DNS message") + } + + copy(b, bb) + return len(bb), nil +} + +func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) { + return 0, nil, nil +} + +func (f *fakeDNSConn) Write(b []byte) (int, error) { + f.q = new(dnsMsg) + if !f.q.Unpack(b) { + return 0, errors.New("cannot unmarshal DNS message") + } + return len(b), nil +} + +func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) { + return 0, nil +} + func (f *fakeDNSConn) SetDeadline(t time.Time) error { f.t = t return nil } -func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) { - return f.server.rh(f.n, f.s, q, f.t) -} - // UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281). func TestIgnoreDNSForgeries(t *testing.T) { c, s := Pipe() @@ -837,7 +876,8 @@ func TestIgnoreDNSForgeries(t *testing.T) { }, } - resp, err := dnsRoundTripUDP(c, msg) + dc := &dnsPacketConn{c} + resp, err := dc.dnsRoundTrip(msg) if err != nil { t.Fatalf("dnsRoundTripUDP failed: %v", err) } @@ -1113,7 +1153,14 @@ func TestStrictErrorsLookupIP(t *testing.T) { case resolveOpError: return nil, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")} case resolveServfail: - return &dnsMsg{dnsMsgHdr: dnsMsgHdr{id: q.id, rcode: dnsRcodeServerFailure}}, nil + return &dnsMsg{ + dnsMsgHdr: dnsMsgHdr{ + id: q.id, + response: true, + rcode: dnsRcodeServerFailure, + }, + question: q.question, + }, nil case resolveTimeout: return nil, poll.ErrTimeout default: @@ -1123,7 +1170,14 @@ func TestStrictErrorsLookupIP(t *testing.T) { switch q.question[0].Name { case searchX, name + ".": // Return NXDOMAIN to utilize the search list. - return &dnsMsg{dnsMsgHdr: dnsMsgHdr{id: q.id, rcode: dnsRcodeNameError}}, nil + return &dnsMsg{ + dnsMsgHdr: dnsMsgHdr{ + id: q.id, + response: true, + rcode: dnsRcodeNameError, + }, + question: q.question, + }, nil case searchY: // Return records below. default: diff --git a/src/net/lookup.go b/src/net/lookup.go index 818f91c3dc..abc56de533 100644 --- a/src/net/lookup.go +++ b/src/net/lookup.go @@ -111,9 +111,11 @@ type Resolver struct { // Go's built-in DNS resolver to make TCP and UDP connections // to DNS services. The provided addr will always be an IP // address and not a hostname. - // The Conn returned must be a *TCPConn or *UDPConn as - // requested by the network parameter. If nil, the default - // dialer is used. + // If the Conn returned is also a PacketConn, sent and received DNS + // messages must adhere to section 4.2.1. "UDP usage" of RFC 1035. + // Otherwise, DNS messages transmitted over Conn must adhere to section + // 4.2.2. "TCP usage". + // If nil, the default dialer is used. Dial func(ctx context.Context, network, addr string) (Conn, error) // TODO(bradfitz): optional interface impl override hook diff --git a/src/net/lookup_unix.go b/src/net/lookup_unix.go index a485d706a5..896836393b 100644 --- a/src/net/lookup_unix.go +++ b/src/net/lookup_unix.go @@ -8,8 +8,6 @@ package net import ( "context" - "errors" - "reflect" "sync" ) @@ -70,12 +68,10 @@ func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, e if err != nil { return nil, mapErr(err) } - dc, ok := c.(dnsConn) - if !ok { - c.Close() - return nil, errors.New("net: Resolver.Dial returned unsupported connection type " + reflect.TypeOf(c).String()) + if _, ok := c.(PacketConn); ok { + return &dnsPacketConn{c}, nil } - return dc, nil + return &dnsStreamConn{c}, nil } func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) { -- 2.48.1