SetDeadline(time.Time) error
- // readDNSResponse reads a DNS response message from the DNS
- // transport endpoint and returns the received DNS response
- // message.
- readDNSResponse() (*dnsMsg, error)
-
- // writeDNSQuery writes a DNS query message to the DNS
- // connection endpoint.
- writeDNSQuery(*dnsMsg) error
+ // dnsRoundTrip executes a single DNS transaction, returning a
+ // DNS response message for the provided DNS query message.
+ dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
}
-func (c *UDPConn) readDNSResponse() (*dnsMsg, error) {
- b := make([]byte, 512) // see RFC 1035
- n, err := c.Read(b)
- if err != nil {
+func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
+ return dnsRoundTripUDP(c, query)
+}
+
+// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's
+// "UDP usage" transport mechanism. c should be a packet-oriented connection,
+// such as a *UDPConn.
+func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
+ b, ok := query.Pack()
+ if !ok {
+ return nil, errors.New("cannot marshal DNS message")
+ }
+ if _, err := c.Write(b); err != nil {
return nil, err
}
- msg := &dnsMsg{}
- if !msg.Unpack(b[:n]) {
- return nil, errors.New("cannot unmarshal DNS message")
+
+ b = make([]byte, 512) // see RFC 1035
+ for {
+ n, err := c.Read(b)
+ if err != nil {
+ return nil, err
+ }
+ resp := &dnsMsg{}
+ if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) {
+ // Ignore invalid responses as they may be malicious
+ // forgery attempts. Instead continue waiting until
+ // timeout. See golang.org/issue/13281.
+ continue
+ }
+ return resp, nil
}
- return msg, nil
}
-func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error {
- b, ok := msg.Pack()
+func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) {
+ return dnsRoundTripTCP(c, out)
+}
+
+// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's
+// "TCP usage" transport mechanism. c should be a stream-oriented connection,
+// such as a *TCPConn.
+func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
+ b, ok := query.Pack()
if !ok {
- return errors.New("cannot marshal DNS message")
+ return nil, errors.New("cannot marshal DNS message")
}
+ l := len(b)
+ b = append([]byte{byte(l >> 8), byte(l)}, b...)
if _, err := c.Write(b); err != nil {
- return err
+ return nil, err
}
- return nil
-}
-func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
- b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
+ b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
if _, err := io.ReadFull(c, b[:2]); err != nil {
return nil, err
}
- l := int(b[0])<<8 | int(b[1])
+ l = int(b[0])<<8 | int(b[1])
if l > len(b) {
b = make([]byte, l)
}
if err != nil {
return nil, err
}
- msg := &dnsMsg{}
- if !msg.Unpack(b[:n]) {
+ resp := &dnsMsg{}
+ if !resp.Unpack(b[:n]) {
return nil, errors.New("cannot unmarshal DNS message")
}
- return msg, nil
-}
-
-func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
- b, ok := msg.Pack()
- if !ok {
- return errors.New("cannot marshal DNS message")
+ if !resp.IsResponseTo(query) {
+ return nil, errors.New("invalid DNS response")
}
- l := uint16(len(b))
- b = append([]byte{byte(l >> 8), byte(l)}, b...)
- if _, err := c.Write(b); err != nil {
- return err
- }
- return nil
+ return resp, nil
}
func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) {
c.SetDeadline(d)
}
out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
- if err := c.writeDNSQuery(&out); err != nil {
- return nil, mapErr(err)
- }
- in, err := c.readDNSResponse()
+ in, err := c.dnsRoundTrip(&out)
if err != nil {
return nil, mapErr(err)
}
- if in.id != out.id {
- return nil, errors.New("DNS message ID mismatch")
- }
if in.truncated { // see RFC 5966
continue
}
}
type fakeDNSConn struct {
- // last query
- qmu sync.Mutex // guards q
- q *dnsMsg
// reply handler
rh func(*dnsMsg) (*dnsMsg, error)
}
return nil
}
-func (f *fakeDNSConn) writeDNSQuery(q *dnsMsg) error {
- f.qmu.Lock()
- defer f.qmu.Unlock()
- f.q = q
- return nil
+func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) {
+ return f.rh(q)
}
-func (f *fakeDNSConn) readDNSResponse() (*dnsMsg, error) {
- f.qmu.Lock()
- q := f.q
- f.qmu.Unlock()
- return f.rh(q)
+// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
+func TestIgnoreDNSForgeries(t *testing.T) {
+ const TestAddr uint32 = 0x80420001
+
+ c, s := Pipe()
+ go func() {
+ b := make([]byte, 512)
+ n, err := s.Read(b)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ msg := &dnsMsg{}
+ if !msg.Unpack(b[:n]) {
+ t.Fatal("invalid DNS query")
+ }
+
+ s.Write([]byte("garbage DNS response packet"))
+
+ msg.response = true
+ msg.id++ // make invalid ID
+ b, ok := msg.Pack()
+ if !ok {
+ t.Fatal("failed to pack DNS response")
+ }
+ s.Write(b)
+
+ msg.id-- // restore original ID
+ msg.answer = []dnsRR{
+ &dnsRR_A{
+ Hdr: dnsRR_Header{
+ Name: "www.example.com.",
+ Rrtype: dnsTypeA,
+ Class: dnsClassINET,
+ Rdlength: 4,
+ },
+ A: TestAddr,
+ },
+ }
+
+ b, ok = msg.Pack()
+ if !ok {
+ t.Fatal("failed to pack DNS response")
+ }
+ s.Write(b)
+ }()
+
+ msg := &dnsMsg{
+ dnsMsgHdr: dnsMsgHdr{
+ id: 42,
+ },
+ question: []dnsQuestion{
+ {
+ Name: "www.example.com.",
+ Qtype: dnsTypeA,
+ Qclass: dnsClassINET,
+ },
+ },
+ }
+
+ resp, err := dnsRoundTripUDP(c, msg)
+ if err != nil {
+ t.Fatalf("dnsRoundTripUDP failed: %v", err)
+ }
+
+ if got := resp.answer[0].(*dnsRR_A).A; got != TestAddr {
+ t.Error("got address %v, want %v", got, TestAddr)
+ }
}
}
}
+func TestIsResponseTo(t *testing.T) {
+ // Sample DNS query.
+ query := dnsMsg{
+ dnsMsgHdr: dnsMsgHdr{
+ id: 42,
+ },
+ question: []dnsQuestion{
+ {
+ Name: "www.example.com.",
+ Qtype: dnsTypeA,
+ Qclass: dnsClassINET,
+ },
+ },
+ }
+
+ resp := query
+ resp.response = true
+ if !resp.IsResponseTo(&query) {
+ t.Error("got false, want true")
+ }
+
+ badResponses := []dnsMsg{
+ // Different ID.
+ {
+ dnsMsgHdr: dnsMsgHdr{
+ id: 43,
+ response: true,
+ },
+ question: []dnsQuestion{
+ {
+ Name: "www.example.com.",
+ Qtype: dnsTypeA,
+ Qclass: dnsClassINET,
+ },
+ },
+ },
+
+ // Different query name.
+ {
+ dnsMsgHdr: dnsMsgHdr{
+ id: 42,
+ response: true,
+ },
+ question: []dnsQuestion{
+ {
+ Name: "www.google.com.",
+ Qtype: dnsTypeA,
+ Qclass: dnsClassINET,
+ },
+ },
+ },
+
+ // Different query type.
+ {
+ dnsMsgHdr: dnsMsgHdr{
+ id: 42,
+ response: true,
+ },
+ question: []dnsQuestion{
+ {
+ Name: "www.example.com.",
+ Qtype: dnsTypeAAAA,
+ Qclass: dnsClassINET,
+ },
+ },
+ },
+
+ // Different query class.
+ {
+ dnsMsgHdr: dnsMsgHdr{
+ id: 42,
+ response: true,
+ },
+ question: []dnsQuestion{
+ {
+ Name: "www.example.com.",
+ Qtype: dnsTypeA,
+ Qclass: dnsClassCSNET,
+ },
+ },
+ },
+
+ // No questions.
+ {
+ dnsMsgHdr: dnsMsgHdr{
+ id: 42,
+ response: true,
+ },
+ },
+
+ // Extra questions.
+ {
+ dnsMsgHdr: dnsMsgHdr{
+ id: 42,
+ response: true,
+ },
+ question: []dnsQuestion{
+ {
+ Name: "www.example.com.",
+ Qtype: dnsTypeA,
+ Qclass: dnsClassINET,
+ },
+ {
+ Name: "www.golang.org.",
+ Qtype: dnsTypeAAAA,
+ Qclass: dnsClassINET,
+ },
+ },
+ },
+ }
+
+ for i := range badResponses {
+ if badResponses[i].IsResponseTo(&query) {
+ t.Error("%v: got true, want false", i)
+ }
+ }
+}
+
// Valid DNS SRV reply
const dnsSRVReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" +
"6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" +