]> Cypherpunks repositories - gostls13.git/commitdiff
net: keep waiting for valid DNS response until timeout
authorMatthew Dempsky <mdempsky@google.com>
Sat, 16 Apr 2016 02:19:58 +0000 (19:19 -0700)
committerMatthew Dempsky <mdempsky@google.com>
Fri, 22 Apr 2016 22:16:08 +0000 (22:16 +0000)
Prevents denial of service attacks from bogus UDP packets.

Fixes #13281.

Change-Id: Ifb51b17a1b0807bfd27b144d6037431701184e7b
Reviewed-on: https://go-review.googlesource.com/22126
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/dnsclient_unix.go
src/net/dnsclient_unix_test.go
src/net/dnsmsg.go
src/net/dnsmsg_test.go

index 5ae21012e3cf150a3c4c081b95fec923d2dda3df..6a1fdfccb851874e6b949b7a691e4a95f70b722c 100644 (file)
@@ -38,46 +38,67 @@ type dnsConn interface {
 
        SetDeadline(time.Time) error
 
-       // readDNSResponse reads a DNS response message from the DNS
-       // transport endpoint and returns the received DNS response
-       // message.
-       readDNSResponse() (*dnsMsg, error)
-
-       // writeDNSQuery writes a DNS query message to the DNS
-       // connection endpoint.
-       writeDNSQuery(*dnsMsg) error
+       // dnsRoundTrip executes a single DNS transaction, returning a
+       // DNS response message for the provided DNS query message.
+       dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
 }
 
-func (c *UDPConn) readDNSResponse() (*dnsMsg, error) {
-       b := make([]byte, 512) // see RFC 1035
-       n, err := c.Read(b)
-       if err != nil {
+func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
+       return dnsRoundTripUDP(c, query)
+}
+
+// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's
+// "UDP usage" transport mechanism. c should be a packet-oriented connection,
+// such as a *UDPConn.
+func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
+       b, ok := query.Pack()
+       if !ok {
+               return nil, errors.New("cannot marshal DNS message")
+       }
+       if _, err := c.Write(b); err != nil {
                return nil, err
        }
-       msg := &dnsMsg{}
-       if !msg.Unpack(b[:n]) {
-               return nil, errors.New("cannot unmarshal DNS message")
+
+       b = make([]byte, 512) // see RFC 1035
+       for {
+               n, err := c.Read(b)
+               if err != nil {
+                       return nil, err
+               }
+               resp := &dnsMsg{}
+               if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) {
+                       // Ignore invalid responses as they may be malicious
+                       // forgery attempts. Instead continue waiting until
+                       // timeout. See golang.org/issue/13281.
+                       continue
+               }
+               return resp, nil
        }
-       return msg, nil
 }
 
-func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error {
-       b, ok := msg.Pack()
+func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) {
+       return dnsRoundTripTCP(c, out)
+}
+
+// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's
+// "TCP usage" transport mechanism. c should be a stream-oriented connection,
+// such as a *TCPConn.
+func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
+       b, ok := query.Pack()
        if !ok {
-               return errors.New("cannot marshal DNS message")
+               return nil, errors.New("cannot marshal DNS message")
        }
+       l := len(b)
+       b = append([]byte{byte(l >> 8), byte(l)}, b...)
        if _, err := c.Write(b); err != nil {
-               return err
+               return nil, err
        }
-       return nil
-}
 
-func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
-       b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
+       b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
        if _, err := io.ReadFull(c, b[:2]); err != nil {
                return nil, err
        }
-       l := int(b[0])<<8 | int(b[1])
+       l = int(b[0])<<8 | int(b[1])
        if l > len(b) {
                b = make([]byte, l)
        }
@@ -85,24 +106,14 @@ func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
        if err != nil {
                return nil, err
        }
-       msg := &dnsMsg{}
-       if !msg.Unpack(b[:n]) {
+       resp := &dnsMsg{}
+       if !resp.Unpack(b[:n]) {
                return nil, errors.New("cannot unmarshal DNS message")
        }
-       return msg, nil
-}
-
-func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
-       b, ok := msg.Pack()
-       if !ok {
-               return errors.New("cannot marshal DNS message")
+       if !resp.IsResponseTo(query) {
+               return nil, errors.New("invalid DNS response")
        }
-       l := uint16(len(b))
-       b = append([]byte{byte(l >> 8), byte(l)}, b...)
-       if _, err := c.Write(b); err != nil {
-               return err
-       }
-       return nil
+       return resp, nil
 }
 
 func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) {
@@ -150,16 +161,10 @@ func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg,
                        c.SetDeadline(d)
                }
                out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
-               if err := c.writeDNSQuery(&out); err != nil {
-                       return nil, mapErr(err)
-               }
-               in, err := c.readDNSResponse()
+               in, err := c.dnsRoundTrip(&out)
                if err != nil {
                        return nil, mapErr(err)
                }
-               if in.id != out.id {
-                       return nil, errors.New("DNS message ID mismatch")
-               }
                if in.truncated { // see RFC 5966
                        continue
                }
index 761fb23f142778c90221ede64f3f7be0f73059a3..0b78adb853835c37434f51123f6f9af7da114af1 100644 (file)
@@ -567,9 +567,6 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
 }
 
 type fakeDNSConn struct {
-       // last query
-       qmu sync.Mutex // guards q
-       q   *dnsMsg
        // reply handler
        rh func(*dnsMsg) (*dnsMsg, error)
 }
@@ -586,16 +583,76 @@ func (f *fakeDNSConn) SetDeadline(time.Time) error {
        return nil
 }
 
-func (f *fakeDNSConn) writeDNSQuery(q *dnsMsg) error {
-       f.qmu.Lock()
-       defer f.qmu.Unlock()
-       f.q = q
-       return nil
+func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) {
+       return f.rh(q)
 }
 
-func (f *fakeDNSConn) readDNSResponse() (*dnsMsg, error) {
-       f.qmu.Lock()
-       q := f.q
-       f.qmu.Unlock()
-       return f.rh(q)
+// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
+func TestIgnoreDNSForgeries(t *testing.T) {
+       const TestAddr uint32 = 0x80420001
+
+       c, s := Pipe()
+       go func() {
+               b := make([]byte, 512)
+               n, err := s.Read(b)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               msg := &dnsMsg{}
+               if !msg.Unpack(b[:n]) {
+                       t.Fatal("invalid DNS query")
+               }
+
+               s.Write([]byte("garbage DNS response packet"))
+
+               msg.response = true
+               msg.id++ // make invalid ID
+               b, ok := msg.Pack()
+               if !ok {
+                       t.Fatal("failed to pack DNS response")
+               }
+               s.Write(b)
+
+               msg.id-- // restore original ID
+               msg.answer = []dnsRR{
+                       &dnsRR_A{
+                               Hdr: dnsRR_Header{
+                                       Name:     "www.example.com.",
+                                       Rrtype:   dnsTypeA,
+                                       Class:    dnsClassINET,
+                                       Rdlength: 4,
+                               },
+                               A: TestAddr,
+                       },
+               }
+
+               b, ok = msg.Pack()
+               if !ok {
+                       t.Fatal("failed to pack DNS response")
+               }
+               s.Write(b)
+       }()
+
+       msg := &dnsMsg{
+               dnsMsgHdr: dnsMsgHdr{
+                       id: 42,
+               },
+               question: []dnsQuestion{
+                       {
+                               Name:   "www.example.com.",
+                               Qtype:  dnsTypeA,
+                               Qclass: dnsClassINET,
+                       },
+               },
+       }
+
+       resp, err := dnsRoundTripUDP(c, msg)
+       if err != nil {
+               t.Fatalf("dnsRoundTripUDP failed: %v", err)
+       }
+
+       if got := resp.answer[0].(*dnsRR_A).A; got != TestAddr {
+               t.Error("got address %v, want %v", got, TestAddr)
+       }
 }
index c01381f1903e7737395618df221473415664bce1..5e339c5fbf058d3bbaccf30bc8a30d30d4368ddf 100644 (file)
@@ -934,3 +934,23 @@ func (dns *dnsMsg) String() string {
        }
        return s
 }
+
+// IsResponseTo reports whether m is an acceptable response to query.
+func (m *dnsMsg) IsResponseTo(query *dnsMsg) bool {
+       if !m.response {
+               return false
+       }
+       if m.id != query.id {
+               return false
+       }
+       if len(m.question) != len(query.question) {
+               return false
+       }
+       for i, q := range m.question {
+               q2 := query.question[i]
+               if !equalASCIILabel(q.Name, q2.Name) || q.Qtype != q2.Qtype || q.Qclass != q2.Qclass {
+                       return false
+               }
+       }
+       return true
+}
index 841c32fa84baddb202797c106673e496bfb20cc0..25bd98cff7695055b30333cb9751db717331f977 100644 (file)
@@ -280,6 +280,124 @@ func TestDNSParseTXTCorruptTXTLengthReply(t *testing.T) {
        }
 }
 
+func TestIsResponseTo(t *testing.T) {
+       // Sample DNS query.
+       query := dnsMsg{
+               dnsMsgHdr: dnsMsgHdr{
+                       id: 42,
+               },
+               question: []dnsQuestion{
+                       {
+                               Name:   "www.example.com.",
+                               Qtype:  dnsTypeA,
+                               Qclass: dnsClassINET,
+                       },
+               },
+       }
+
+       resp := query
+       resp.response = true
+       if !resp.IsResponseTo(&query) {
+               t.Error("got false, want true")
+       }
+
+       badResponses := []dnsMsg{
+               // Different ID.
+               {
+                       dnsMsgHdr: dnsMsgHdr{
+                               id:       43,
+                               response: true,
+                       },
+                       question: []dnsQuestion{
+                               {
+                                       Name:   "www.example.com.",
+                                       Qtype:  dnsTypeA,
+                                       Qclass: dnsClassINET,
+                               },
+                       },
+               },
+
+               // Different query name.
+               {
+                       dnsMsgHdr: dnsMsgHdr{
+                               id:       42,
+                               response: true,
+                       },
+                       question: []dnsQuestion{
+                               {
+                                       Name:   "www.google.com.",
+                                       Qtype:  dnsTypeA,
+                                       Qclass: dnsClassINET,
+                               },
+                       },
+               },
+
+               // Different query type.
+               {
+                       dnsMsgHdr: dnsMsgHdr{
+                               id:       42,
+                               response: true,
+                       },
+                       question: []dnsQuestion{
+                               {
+                                       Name:   "www.example.com.",
+                                       Qtype:  dnsTypeAAAA,
+                                       Qclass: dnsClassINET,
+                               },
+                       },
+               },
+
+               // Different query class.
+               {
+                       dnsMsgHdr: dnsMsgHdr{
+                               id:       42,
+                               response: true,
+                       },
+                       question: []dnsQuestion{
+                               {
+                                       Name:   "www.example.com.",
+                                       Qtype:  dnsTypeA,
+                                       Qclass: dnsClassCSNET,
+                               },
+                       },
+               },
+
+               // No questions.
+               {
+                       dnsMsgHdr: dnsMsgHdr{
+                               id:       42,
+                               response: true,
+                       },
+               },
+
+               // Extra questions.
+               {
+                       dnsMsgHdr: dnsMsgHdr{
+                               id:       42,
+                               response: true,
+                       },
+                       question: []dnsQuestion{
+                               {
+                                       Name:   "www.example.com.",
+                                       Qtype:  dnsTypeA,
+                                       Qclass: dnsClassINET,
+                               },
+                               {
+                                       Name:   "www.golang.org.",
+                                       Qtype:  dnsTypeAAAA,
+                                       Qclass: dnsClassINET,
+                               },
+                       },
+               },
+       }
+
+       for i := range badResponses {
+               if badResponses[i].IsResponseTo(&query) {
+                       t.Error("%v: got true, want false", i)
+               }
+       }
+}
+
 // Valid DNS SRV reply
 const dnsSRVReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" +
        "6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" +