]> Cypherpunks repositories - gostls13.git/commitdiff
net: filter bad names from Lookup functions instead of hard failing
authorRoland Shoemaker <roland@golang.org>
Fri, 2 Jul 2021 17:25:49 +0000 (10:25 -0700)
committerRoland Shoemaker <roland@golang.org>
Thu, 8 Jul 2021 17:53:43 +0000 (17:53 +0000)
Instead of hard failing on a single bad record, filter the bad records
and return anything valid. This only applies to the methods which can
return multiple records, LookupMX, LookupNS, LookupSRV, and LookupAddr.

When bad results are filtered out, also return an error, indicating
that this filtering has happened.

Updates #46241
Fixes #46979

Change-Id: I6493e0002beaf89f5a9795333a93605abd30d171
Reviewed-on: https://go-review.googlesource.com/c/go/+/332549
Trust: Roland Shoemaker <roland@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
src/net/dnsclient_unix_test.go
src/net/lookup.go

index 59cdd2bf3e2867c08b0297448940f39eef26ca46..350ad5def797bc9b95aa4a9c1a56b58e8f5db7f2 100644 (file)
@@ -1846,6 +1846,17 @@ func TestCVE202133195(t *testing.T) {
                                                        Target: dnsmessage.MustNewName("<html>.golang.org."),
                                                },
                                        },
+                                       dnsmessage.Resource{
+                                               Header: dnsmessage.ResourceHeader{
+                                                       Name:   n,
+                                                       Type:   dnsmessage.TypeSRV,
+                                                       Class:  dnsmessage.ClassINET,
+                                                       Length: 4,
+                                               },
+                                               Body: &dnsmessage.SRVResource{
+                                                       Target: dnsmessage.MustNewName("good.golang.org."),
+                                               },
+                                       },
                                )
                        case dnsmessage.TypeMX:
                                r.Answers = append(r.Answers,
@@ -1860,6 +1871,17 @@ func TestCVE202133195(t *testing.T) {
                                                        MX: dnsmessage.MustNewName("<html>.golang.org."),
                                                },
                                        },
+                                       dnsmessage.Resource{
+                                               Header: dnsmessage.ResourceHeader{
+                                                       Name:   dnsmessage.MustNewName("good.golang.org."),
+                                                       Type:   dnsmessage.TypeMX,
+                                                       Class:  dnsmessage.ClassINET,
+                                                       Length: 4,
+                                               },
+                                               Body: &dnsmessage.MXResource{
+                                                       MX: dnsmessage.MustNewName("good.golang.org."),
+                                               },
+                                       },
                                )
                        case dnsmessage.TypeNS:
                                r.Answers = append(r.Answers,
@@ -1874,6 +1896,17 @@ func TestCVE202133195(t *testing.T) {
                                                        NS: dnsmessage.MustNewName("<html>.golang.org."),
                                                },
                                        },
+                                       dnsmessage.Resource{
+                                               Header: dnsmessage.ResourceHeader{
+                                                       Name:   dnsmessage.MustNewName("good.golang.org."),
+                                                       Type:   dnsmessage.TypeNS,
+                                                       Class:  dnsmessage.ClassINET,
+                                                       Length: 4,
+                                               },
+                                               Body: &dnsmessage.NSResource{
+                                                       NS: dnsmessage.MustNewName("good.golang.org."),
+                                               },
+                                       },
                                )
                        case dnsmessage.TypePTR:
                                r.Answers = append(r.Answers,
@@ -1888,6 +1921,17 @@ func TestCVE202133195(t *testing.T) {
                                                        PTR: dnsmessage.MustNewName("<html>.golang.org."),
                                                },
                                        },
+                                       dnsmessage.Resource{
+                                               Header: dnsmessage.ResourceHeader{
+                                                       Name:   dnsmessage.MustNewName("good.golang.org."),
+                                                       Type:   dnsmessage.TypePTR,
+                                                       Class:  dnsmessage.ClassINET,
+                                                       Length: 4,
+                                               },
+                                               Body: &dnsmessage.PTRResource{
+                                                       PTR: dnsmessage.MustNewName("good.golang.org."),
+                                               },
+                                       },
                                )
                        }
                        return r, nil
@@ -1903,59 +1947,139 @@ func TestCVE202133195(t *testing.T) {
        defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
        testHookHostsPath = "testdata/hosts"
 
-       _, err := r.LookupCNAME(context.Background(), "golang.org")
-       if expected := "lookup golang.org: CNAME target is invalid"; err == nil || err.Error() != expected {
-               t.Errorf("Resolver.LookupCNAME returned unexpected error, got %q, want %q", err, expected)
-       }
-       _, err = LookupCNAME("golang.org")
-       if expected := "lookup golang.org: CNAME target is invalid"; err == nil || err.Error() != expected {
-               t.Errorf("LookupCNAME returned unexpected error, got %q, want %q", err, expected)
-       }
-
-       _, _, err = r.LookupSRV(context.Background(), "target", "tcp", "golang.org")
-       if expected := "lookup golang.org: SRV target is invalid"; err == nil || err.Error() != expected {
-               t.Errorf("Resolver.LookupSRV returned unexpected error, got %q, want %q", err, expected)
-       }
-       _, _, err = LookupSRV("target", "tcp", "golang.org")
-       if expected := "lookup golang.org: SRV target is invalid"; err == nil || err.Error() != expected {
-               t.Errorf("LookupSRV returned unexpected error, got %q, want %q", err, expected)
-       }
-
-       _, _, err = r.LookupSRV(context.Background(), "hdr", "tcp", "golang.org.")
-       if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected {
-               t.Errorf("Resolver.LookupSRV returned unexpected error, got %q, want %q", err, expected)
-       }
-       _, _, err = LookupSRV("hdr", "tcp", "golang.org.")
-       if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected {
-               t.Errorf("LookupSRV returned unexpected error, got %q, want %q", err, expected)
-       }
-
-       _, err = r.LookupMX(context.Background(), "golang.org")
-       if expected := "lookup golang.org: MX target is invalid"; err == nil || err.Error() != expected {
-               t.Errorf("Resolver.LookupMX returned unexpected error, got %q, want %q", err, expected)
-       }
-       _, err = LookupMX("golang.org")
-       if expected := "lookup golang.org: MX target is invalid"; err == nil || err.Error() != expected {
-               t.Errorf("LookupMX returned unexpected error, got %q, want %q", err, expected)
+       tests := []struct {
+               name string
+               f    func(*testing.T)
+       }{
+               {
+                       name: "CNAME",
+                       f: func(t *testing.T) {
+                               expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+                               _, err := r.LookupCNAME(context.Background(), "golang.org")
+                               if err.Error() != expectedErr.Error() {
+                                       t.Fatalf("unexpected error: %s", err)
+                               }
+                               _, err = LookupCNAME("golang.org")
+                               if err.Error() != expectedErr.Error() {
+                                       t.Fatalf("unexpected error: %s", err)
+                               }
+                       },
+               },
+               {
+                       name: "SRV (bad record)",
+                       f: func(t *testing.T) {
+                               expected := []*SRV{
+                                       {
+                                               Target: "good.golang.org.",
+                                       },
+                               }
+                               expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+                               _, records, err := r.LookupSRV(context.Background(), "target", "tcp", "golang.org")
+                               if err.Error() != expectedErr.Error() {
+                                       t.Fatalf("unexpected error: %s", err)
+                               }
+                               if !reflect.DeepEqual(records, expected) {
+                                       t.Error("Unexpected record set")
+                               }
+                               _, records, err = LookupSRV("target", "tcp", "golang.org")
+                               if err.Error() != expectedErr.Error() {
+                                       t.Errorf("unexpected error: %s", err)
+                               }
+                               if !reflect.DeepEqual(records, expected) {
+                                       t.Error("Unexpected record set")
+                               }
+                       },
+               },
+               {
+                       name: "SRV (bad header)",
+                       f: func(t *testing.T) {
+                               _, _, err := r.LookupSRV(context.Background(), "hdr", "tcp", "golang.org.")
+                               if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected {
+                                       t.Errorf("Resolver.LookupSRV returned unexpected error, got %q, want %q", err, expected)
+                               }
+                               _, _, err = LookupSRV("hdr", "tcp", "golang.org.")
+                               if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected {
+                                       t.Errorf("LookupSRV returned unexpected error, got %q, want %q", err, expected)
+                               }
+                       },
+               },
+               {
+                       name: "MX",
+                       f: func(t *testing.T) {
+                               expected := []*MX{
+                                       {
+                                               Host: "good.golang.org.",
+                                       },
+                               }
+                               expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+                               records, err := r.LookupMX(context.Background(), "golang.org")
+                               if err.Error() != expectedErr.Error() {
+                                       t.Fatalf("unexpected error: %s", err)
+                               }
+                               if !reflect.DeepEqual(records, expected) {
+                                       t.Error("Unexpected record set")
+                               }
+                               records, err = LookupMX("golang.org")
+                               if err.Error() != expectedErr.Error() {
+                                       t.Fatalf("unexpected error: %s", err)
+                               }
+                               if !reflect.DeepEqual(records, expected) {
+                                       t.Error("Unexpected record set")
+                               }
+                       },
+               },
+               {
+                       name: "NS",
+                       f: func(t *testing.T) {
+                               expected := []*NS{
+                                       {
+                                               Host: "good.golang.org.",
+                                       },
+                               }
+                               expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+                               records, err := r.LookupNS(context.Background(), "golang.org")
+                               if err.Error() != expectedErr.Error() {
+                                       t.Fatalf("unexpected error: %s", err)
+                               }
+                               if !reflect.DeepEqual(records, expected) {
+                                       t.Error("Unexpected record set")
+                               }
+                               records, err = LookupNS("golang.org")
+                               if err.Error() != expectedErr.Error() {
+                                       t.Fatalf("unexpected error: %s", err)
+                               }
+                               if !reflect.DeepEqual(records, expected) {
+                                       t.Error("Unexpected record set")
+                               }
+                       },
+               },
+               {
+                       name: "Addr",
+                       f: func(t *testing.T) {
+                               expected := []string{"good.golang.org."}
+                               expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "192.0.2.42"}
+                               records, err := r.LookupAddr(context.Background(), "192.0.2.42")
+                               if err.Error() != expectedErr.Error() {
+                                       t.Fatalf("unexpected error: %s", err)
+                               }
+                               if !reflect.DeepEqual(records, expected) {
+                                       t.Error("Unexpected record set")
+                               }
+                               records, err = LookupAddr("192.0.2.42")
+                               if err.Error() != expectedErr.Error() {
+                                       t.Fatalf("unexpected error: %s", err)
+                               }
+                               if !reflect.DeepEqual(records, expected) {
+                                       t.Error("Unexpected record set")
+                               }
+                       },
+               },
        }
 
-       _, err = r.LookupNS(context.Background(), "golang.org")
-       if expected := "lookup golang.org: NS target is invalid"; err == nil || err.Error() != expected {
-               t.Errorf("Resolver.LookupNS returned unexpected error, got %q, want %q", err, expected)
-       }
-       _, err = LookupNS("golang.org")
-       if expected := "lookup golang.org: NS target is invalid"; err == nil || err.Error() != expected {
-               t.Errorf("LookupNS returned unexpected error, got %q, want %q", err, expected)
+       for _, tc := range tests {
+               t.Run(tc.name, tc.f)
        }
 
-       _, err = r.LookupAddr(context.Background(), "192.0.2.42")
-       if expected := "lookup 192.0.2.42: PTR target is invalid"; err == nil || err.Error() != expected {
-               t.Errorf("Resolver.LookupAddr returned unexpected error, got %q, want %q", err, expected)
-       }
-       _, err = LookupAddr("192.0.2.42")
-       if expected := "lookup 192.0.2.42: PTR target is invalid"; err == nil || err.Error() != expected {
-               t.Errorf("LookupAddr returned unexpected error, got %q, want %q", err, expected)
-       }
 }
 
 func TestNullMX(t *testing.T) {
index b5af3a0f86735c0c25c6baef69ca9f989a9396db..d350ef7fc03b0b6259a562e5f28b04a40157fdf0 100644 (file)
@@ -424,7 +424,7 @@ func (r *Resolver) LookupCNAME(ctx context.Context, host string) (string, error)
                return "", err
        }
        if !isDomainName(cname) {
-               return "", &DNSError{Err: "CNAME target is invalid", Name: host}
+               return "", &DNSError{Err: errMalformedDNSRecordsDetail, Name: host}
        }
        return cname, nil
 }
@@ -440,7 +440,9 @@ func (r *Resolver) LookupCNAME(ctx context.Context, host string) (string, error)
 // and proto are empty strings, LookupSRV looks up name directly.
 //
 // The returned service names are validated to be properly
-// formatted presentation-format domain names.
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the the remaining results, if any.
 func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) {
        return DefaultResolver.LookupSRV(context.Background(), service, proto, name)
 }
@@ -456,7 +458,9 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err
 // and proto are empty strings, LookupSRV looks up name directly.
 //
 // The returned service names are validated to be properly
-// formatted presentation-format domain names.
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the the remaining results, if any.
 func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
        cname, addrs, err := r.lookupSRV(ctx, service, proto, name)
        if err != nil {
@@ -465,21 +469,28 @@ func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) (
        if cname != "" && !isDomainName(cname) {
                return "", nil, &DNSError{Err: "SRV header name is invalid", Name: name}
        }
+       filteredAddrs := make([]*SRV, 0, len(addrs))
        for _, addr := range addrs {
                if addr == nil {
                        continue
                }
                if !isDomainName(addr.Target) {
-                       return "", nil, &DNSError{Err: "SRV target is invalid", Name: name}
+                       continue
                }
+               filteredAddrs = append(filteredAddrs, addr)
+       }
+       if len(addrs) != len(filteredAddrs) {
+               return cname, filteredAddrs, &DNSError{Err: errMalformedDNSRecordsDetail, Name: name}
        }
-       return cname, addrs, nil
+       return cname, filteredAddrs, nil
 }
 
 // LookupMX returns the DNS MX records for the given domain name sorted by preference.
 //
 // The returned mail server names are validated to be properly
-// formatted presentation-format domain names.
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the the remaining results, if any.
 //
 // LookupMX uses context.Background internally; to specify the context, use
 // Resolver.LookupMX.
@@ -490,12 +501,15 @@ func LookupMX(name string) ([]*MX, error) {
 // LookupMX returns the DNS MX records for the given domain name sorted by preference.
 //
 // The returned mail server names are validated to be properly
-// formatted presentation-format domain names.
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the the remaining results, if any.
 func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) {
        records, err := r.lookupMX(ctx, name)
        if err != nil {
                return nil, err
        }
+       filteredMX := make([]*MX, 0, len(records))
        for _, mx := range records {
                if mx == nil {
                        continue
@@ -503,16 +517,22 @@ func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) {
                // Bypass the hostname validity check for targets which contain only a dot,
                // as this is used to represent a 'Null' MX record.
                if mx.Host != "." && !isDomainName(mx.Host) {
-                       return nil, &DNSError{Err: "MX target is invalid", Name: name}
+                       continue
                }
+               filteredMX = append(filteredMX, mx)
+       }
+       if len(records) != len(filteredMX) {
+               return filteredMX, &DNSError{Err: errMalformedDNSRecordsDetail, Name: name}
        }
-       return records, nil
+       return filteredMX, nil
 }
 
 // LookupNS returns the DNS NS records for the given domain name.
 //
 // The returned name server names are validated to be properly
-// formatted presentation-format domain names.
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the the remaining results, if any.
 //
 // LookupNS uses context.Background internally; to specify the context, use
 // Resolver.LookupNS.
@@ -523,21 +543,28 @@ func LookupNS(name string) ([]*NS, error) {
 // LookupNS returns the DNS NS records for the given domain name.
 //
 // The returned name server names are validated to be properly
-// formatted presentation-format domain names.
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the the remaining results, if any.
 func (r *Resolver) LookupNS(ctx context.Context, name string) ([]*NS, error) {
        records, err := r.lookupNS(ctx, name)
        if err != nil {
                return nil, err
        }
+       filteredNS := make([]*NS, 0, len(records))
        for _, ns := range records {
                if ns == nil {
                        continue
                }
                if !isDomainName(ns.Host) {
-                       return nil, &DNSError{Err: "NS target is invalid", Name: name}
+                       continue
                }
+               filteredNS = append(filteredNS, ns)
+       }
+       if len(records) != len(filteredNS) {
+               return filteredNS, &DNSError{Err: errMalformedDNSRecordsDetail, Name: name}
        }
-       return records, nil
+       return filteredNS, nil
 }
 
 // LookupTXT returns the DNS TXT records for the given domain name.
@@ -557,7 +584,8 @@ func (r *Resolver) LookupTXT(ctx context.Context, name string) ([]string, error)
 // of names mapping to that address.
 //
 // The returned names are validated to be properly formatted presentation-format
-// domain names.
+// domain names. If the response contains invalid names, those records are filtered
+// out and an error will be returned alongside the the remaining results, if any.
 //
 // When using the host C library resolver, at most one result will be
 // returned. To bypass the host resolver, use a custom Resolver.
@@ -572,16 +600,26 @@ func LookupAddr(addr string) (names []string, err error) {
 // of names mapping to that address.
 //
 // The returned names are validated to be properly formatted presentation-format
-// domain names.
+// domain names. If the response contains invalid names, those records are filtered
+// out and an error will be returned alongside the the remaining results, if any.
 func (r *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error) {
        names, err := r.lookupAddr(ctx, addr)
        if err != nil {
                return nil, err
        }
+       filteredNames := make([]string, 0, len(names))
        for _, name := range names {
-               if !isDomainName(name) {
-                       return nil, &DNSError{Err: "PTR target is invalid", Name: addr}
+               if isDomainName(name) {
+                       filteredNames = append(filteredNames, name)
                }
        }
-       return names, nil
+       if len(names) != len(filteredNames) {
+               return filteredNames, &DNSError{Err: errMalformedDNSRecordsDetail, Name: addr}
+       }
+       return filteredNames, nil
 }
+
+// errMalformedDNSRecordsDetail is the DNSError detail which is returned when a Resolver.Lookup...
+// method recieves DNS records which contain invalid DNS names. This may be returned alongside
+// results which have had the malformed records filtered out.
+var errMalformedDNSRecordsDetail = "DNS response contained records which contain invalid names"