]> Cypherpunks repositories - gostls13.git/commitdiff
net: don't crash on unexpected DNS SRV responses
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 3 May 2011 14:10:48 +0000 (07:10 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 3 May 2011 14:10:48 +0000 (07:10 -0700)
Fixes #1350

R=rsc, bradfitzwork
CC=golang-dev
https://golang.org/cl/4432089

src/pkg/net/dnsclient.go
src/pkg/net/dnsmsg.go
src/pkg/net/dnsmsg_test.go [new file with mode: 0644]

index 89f2409bf6548ad66728e457141a8ff9ab5a93e1..3466003fab8eae531237df4c91a96ec8e16809cc 100644 (file)
@@ -121,15 +121,19 @@ func answer(name, server string, dns *dnsMsg, qtype uint16) (cname string, addrs
 Cname:
        for cnameloop := 0; cnameloop < 10; cnameloop++ {
                addrs = addrs[0:0]
-               for i := 0; i < len(dns.answer); i++ {
-                       rr := dns.answer[i]
+               for _, rr := range dns.answer {
+                       if _, justHeader := rr.(*dnsRR_Header); justHeader {
+                               // Corrupt record: we only have a
+                               // header. That header might say it's
+                               // of type qtype, but we don't
+                               // actually have it. Skip.
+                               continue
+                       }
                        h := rr.Header()
                        if h.Class == dnsClassINET && h.Name == name {
                                switch h.Rrtype {
                                case qtype:
-                                       n := len(addrs)
-                                       addrs = addrs[0 : n+1]
-                                       addrs[n] = rr
+                                       addrs = append(addrs, rr)
                                case dnsTypeCNAME:
                                        // redirect to cname
                                        name = rr.(*dnsRR_CNAME).Cname
@@ -181,8 +185,7 @@ func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs
 
 func convertRR_A(records []dnsRR) []IP {
        addrs := make([]IP, len(records))
-       for i := 0; i < len(records); i++ {
-               rr := records[i]
+       for i, rr := range records {
                a := rr.(*dnsRR_A).A
                addrs[i] = IPv4(byte(a>>24), byte(a>>16), byte(a>>8), byte(a))
        }
@@ -191,8 +194,7 @@ func convertRR_A(records []dnsRR) []IP {
 
 func convertRR_AAAA(records []dnsRR) []IP {
        addrs := make([]IP, len(records))
-       for i := 0; i < len(records); i++ {
-               rr := records[i]
+       for i, rr := range records {
                a := make(IP, 16)
                copy(a, rr.(*dnsRR_AAAA).AAAA[:])
                addrs[i] = a
@@ -384,9 +386,7 @@ func goLookupCNAME(name string) (cname string, err os.Error) {
        if err != nil {
                return
        }
-       if len(rr) >= 0 {
-               cname = rr[0].(*dnsRR_CNAME).Cname
-       }
+       cname = rr[0].(*dnsRR_CNAME).Cname
        return
 }
 
@@ -410,8 +410,8 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.
                return
        }
        addrs = make([]*SRV, len(records))
-       for i := 0; i < len(records); i++ {
-               r := records[i].(*dnsRR_SRV)
+       for i, rr := range records {
+               r := rr.(*dnsRR_SRV)
                addrs[i] = &SRV{r.Target, r.Port, r.Priority, r.Weight}
        }
        return
index 7b8e5c6d3f366d4adf4c11818e2a79d03c7773f8..731efe26a444be7c81d4f8ce5410f2c09cb1e4d5 100644 (file)
@@ -715,24 +715,35 @@ func (dns *dnsMsg) Unpack(msg []byte) bool {
 
        // Arrays.
        dns.question = make([]dnsQuestion, dh.Qdcount)
-       dns.answer = make([]dnsRR, dh.Ancount)
-       dns.ns = make([]dnsRR, dh.Nscount)
-       dns.extra = make([]dnsRR, dh.Arcount)
+       dns.answer = make([]dnsRR, 0, dh.Ancount)
+       dns.ns = make([]dnsRR, 0, dh.Nscount)
+       dns.extra = make([]dnsRR, 0, dh.Arcount)
+
+       var rec dnsRR
 
        for i := 0; i < len(dns.question); i++ {
                off, ok = unpackStruct(&dns.question[i], msg, off)
        }
-       for i := 0; i < len(dns.answer); i++ {
-               dns.answer[i], off, ok = unpackRR(msg, off)
-       }
-       for i := 0; i < len(dns.ns); i++ {
-               dns.ns[i], off, ok = unpackRR(msg, off)
+       for i := 0; i < int(dh.Ancount); i++ {
+               rec, off, ok = unpackRR(msg, off)
+               if !ok {
+                       return false
+               }
+               dns.answer = append(dns.answer, rec)
        }
-       for i := 0; i < len(dns.extra); i++ {
-               dns.extra[i], off, ok = unpackRR(msg, off)
+       for i := 0; i < int(dh.Nscount); i++ {
+               rec, off, ok = unpackRR(msg, off)
+               if !ok {
+                       return false
+               }
+               dns.ns = append(dns.ns, rec)
        }
-       if !ok {
-               return false
+       for i := 0; i < int(dh.Arcount); i++ {
+               rec, off, ok = unpackRR(msg, off)
+               if !ok {
+                       return false
+               }
+               dns.extra = append(dns.extra, rec)
        }
        //      if off != len(msg) {
        //              println("extra bytes in dns packet", off, "<", len(msg));
diff --git a/src/pkg/net/dnsmsg_test.go b/src/pkg/net/dnsmsg_test.go
new file mode 100644 (file)
index 0000000..06152a0
--- /dev/null
@@ -0,0 +1,100 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+       "encoding/hex"
+       "testing"
+)
+
+func TestDNSParseSRVReply(t *testing.T) {
+       data, err := hex.DecodeString(dnsSRVReply)
+       if err != nil {
+               t.Fatal(err)
+       }
+       msg := new(dnsMsg)
+       ok := msg.Unpack(data)
+       if !ok {
+               t.Fatalf("unpacking packet failed")
+       }
+       if g, e := len(msg.answer), 5; g != e {
+               t.Errorf("len(msg.answer) = %d; want %d", g, e)
+       }
+       for idx, rr := range msg.answer {
+               if g, e := rr.Header().Rrtype, uint16(dnsTypeSRV); g != e {
+                       t.Errorf("rr[%d].Header().Rrtype = %d; want %d", idx, g, e)
+               }
+               if _, ok := rr.(*dnsRR_SRV); !ok {
+                       t.Errorf("answer[%d] = %T; want *dnsRR_SRV", idx, rr)
+               }
+       }
+       _, addrs, err := answer("_xmpp-server._tcp.google.com.", "foo:53", msg, uint16(dnsTypeSRV))
+       if err != nil {
+               t.Fatalf("answer: %v", err)
+       }
+       if g, e := len(addrs), 5; g != e {
+               t.Errorf("len(addrs) = %d; want %d", g, e)
+               t.Logf("addrs = %#v", addrs)
+       }
+}
+
+func TestDNSParseCorruptSRVReply(t *testing.T) {
+       data, err := hex.DecodeString(dnsSRVCorruptReply)
+       if err != nil {
+               t.Fatal(err)
+       }
+       msg := new(dnsMsg)
+       ok := msg.Unpack(data)
+       if !ok {
+               t.Fatalf("unpacking packet failed")
+       }
+       if g, e := len(msg.answer), 5; g != e {
+               t.Errorf("len(msg.answer) = %d; want %d", g, e)
+       }
+       for idx, rr := range msg.answer {
+               if g, e := rr.Header().Rrtype, uint16(dnsTypeSRV); g != e {
+                       t.Errorf("rr[%d].Header().Rrtype = %d; want %d", idx, g, e)
+               }
+               if idx == 4 {
+                       if _, ok := rr.(*dnsRR_Header); !ok {
+                               t.Errorf("answer[%d] = %T; want *dnsRR_Header", idx, rr)
+                       }
+               } else {
+                       if _, ok := rr.(*dnsRR_SRV); !ok {
+                               t.Errorf("answer[%d] = %T; want *dnsRR_SRV", idx, rr)
+                       }
+               }
+       }
+       _, addrs, err := answer("_xmpp-server._tcp.google.com.", "foo:53", msg, uint16(dnsTypeSRV))
+       if err != nil {
+               t.Fatalf("answer: %v", err)
+       }
+       if g, e := len(addrs), 4; g != e {
+               t.Errorf("len(addrs) = %d; want %d", g, e)
+               t.Logf("addrs = %#v", addrs)
+       }
+}
+
+// Valid DNS SRV reply
+const dnsSRVReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" +
+       "6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" +
+       "73657276657234016c06676f6f676c6503636f6d00c00c002100010000012c00210014" +
+       "000014950c786d70702d73657276657232016c06676f6f676c6503636f6d00c00c0021" +
+       "00010000012c00210014000014950c786d70702d73657276657233016c06676f6f676c" +
+       "6503636f6d00c00c002100010000012c00200005000014950b786d70702d7365727665" +
+       "72016c06676f6f676c6503636f6d00c00c002100010000012c00210014000014950c78" +
+       "6d70702d73657276657231016c06676f6f676c6503636f6d00"
+
+// Corrupt DNS SRV reply, with its final RR having a bogus length
+// (perhaps it was truncated, or it's malicious) The mutation is the
+// capital "FF" below, instead of the proper "21".
+const dnsSRVCorruptReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" +
+       "6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" +
+       "73657276657234016c06676f6f676c6503636f6d00c00c002100010000012c00210014" +
+       "000014950c786d70702d73657276657232016c06676f6f676c6503636f6d00c00c0021" +
+       "00010000012c00210014000014950c786d70702d73657276657233016c06676f6f676c" +
+       "6503636f6d00c00c002100010000012c00200005000014950b786d70702d7365727665" +
+       "72016c06676f6f676c6503636f6d00c00c002100010000012c00FF0014000014950c78" +
+       "6d70702d73657276657231016c06676f6f676c6503636f6d00"