dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
}
-func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
- return dnsRoundTripUDP(c, query)
+// dnsPacketConn implements the dnsConn interface for RFC 1035's
+// "UDP usage" transport mechanism. Conn is a packet-oriented connection,
+// such as a *UDPConn.
+type dnsPacketConn struct {
+ Conn
}
-// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's
-// "UDP usage" transport mechanism. c should be a packet-oriented connection,
-// such as a *UDPConn.
-func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
+func (c *dnsPacketConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
b, ok := query.Pack()
if !ok {
return nil, errors.New("cannot marshal DNS message")
}
}
-func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) {
- return dnsRoundTripTCP(c, out)
+// dnsStreamConn implements the dnsConn interface for RFC 1035's
+// "TCP usage" transport mechanism. Conn is a stream-oriented connection,
+// such as a *TCPConn.
+type dnsStreamConn struct {
+ Conn
}
-// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's
-// "TCP usage" transport mechanism. c should be a stream-oriented connection,
-// such as a *TCPConn.
-func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
+func (c *dnsStreamConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
b, ok := query.Pack()
if !ok {
return nil, errors.New("cannot marshal DNS message")
import (
"context"
+ "errors"
"fmt"
"internal/poll"
"io/ioutil"
func TestDNSTransportFallback(t *testing.T) {
fake := fakeDNSServer{
- rh: func(n, _ string, _ *dnsMsg, _ time.Time) (*dnsMsg, error) {
+ rh: func(n, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
- rcode: dnsRcodeSuccess,
+ id: q.id,
+ response: true,
+ rcode: dnsRcodeSuccess,
},
+ question: q.question,
}
if n == "udp" {
r.truncated = true
fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
- id: q.id,
+ id: q.id,
+ response: true,
},
+ question: q.question,
}
switch q.question[0].Name {
fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
- id: q.id,
+ id: q.id,
+ response: true,
},
+ question: q.question,
}
switch q.question[0].Name {
}
func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
- return &fakeDNSConn{nil, server, n, s, time.Time{}}, nil
+ return &fakeDNSConn{nil, server, n, s, nil, time.Time{}}, nil
}
type fakeDNSConn struct {
server *fakeDNSServer
n string
s string
+ q *dnsMsg
t time.Time
}
return nil
}
+func (f *fakeDNSConn) Read(b []byte) (int, error) {
+ resp, err := f.server.rh(f.n, f.s, f.q, f.t)
+ if err != nil {
+ return 0, err
+ }
+
+ bb, ok := resp.Pack()
+ if !ok {
+ return 0, errors.New("cannot marshal DNS message")
+ }
+ if len(b) < len(bb) {
+ return 0, errors.New("read would fragment DNS message")
+ }
+
+ copy(b, bb)
+ return len(bb), nil
+}
+
+func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) {
+ return 0, nil, nil
+}
+
+func (f *fakeDNSConn) Write(b []byte) (int, error) {
+ f.q = new(dnsMsg)
+ if !f.q.Unpack(b) {
+ return 0, errors.New("cannot unmarshal DNS message")
+ }
+ return len(b), nil
+}
+
+func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) {
+ return 0, nil
+}
+
func (f *fakeDNSConn) SetDeadline(t time.Time) error {
f.t = t
return nil
}
-func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) {
- return f.server.rh(f.n, f.s, q, f.t)
-}
-
// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
func TestIgnoreDNSForgeries(t *testing.T) {
c, s := Pipe()
},
}
- resp, err := dnsRoundTripUDP(c, msg)
+ dc := &dnsPacketConn{c}
+ resp, err := dc.dnsRoundTrip(msg)
if err != nil {
t.Fatalf("dnsRoundTripUDP failed: %v", err)
}
case resolveOpError:
return nil, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
case resolveServfail:
- return &dnsMsg{dnsMsgHdr: dnsMsgHdr{id: q.id, rcode: dnsRcodeServerFailure}}, nil
+ return &dnsMsg{
+ dnsMsgHdr: dnsMsgHdr{
+ id: q.id,
+ response: true,
+ rcode: dnsRcodeServerFailure,
+ },
+ question: q.question,
+ }, nil
case resolveTimeout:
return nil, poll.ErrTimeout
default:
switch q.question[0].Name {
case searchX, name + ".":
// Return NXDOMAIN to utilize the search list.
- return &dnsMsg{dnsMsgHdr: dnsMsgHdr{id: q.id, rcode: dnsRcodeNameError}}, nil
+ return &dnsMsg{
+ dnsMsgHdr: dnsMsgHdr{
+ id: q.id,
+ response: true,
+ rcode: dnsRcodeNameError,
+ },
+ question: q.question,
+ }, nil
case searchY:
// Return records below.
default:
// Go's built-in DNS resolver to make TCP and UDP connections
// to DNS services. The provided addr will always be an IP
// address and not a hostname.
- // The Conn returned must be a *TCPConn or *UDPConn as
- // requested by the network parameter. If nil, the default
- // dialer is used.
+ // If the Conn returned is also a PacketConn, sent and received DNS
+ // messages must adhere to section 4.2.1. "UDP usage" of RFC 1035.
+ // Otherwise, DNS messages transmitted over Conn must adhere to section
+ // 4.2.2. "TCP usage".
+ // If nil, the default dialer is used.
Dial func(ctx context.Context, network, addr string) (Conn, error)
// TODO(bradfitz): optional interface impl override hook