]> Cypherpunks repositories - gostls13.git/commitdiff
net: avoid infinite recursion in Windows Resolver.lookupTXT
authorIan Lance Taylor <iant@golang.org>
Tue, 14 Jun 2022 22:53:59 +0000 (15:53 -0700)
committerGopher Robot <gobot@golang.org>
Wed, 15 Jun 2022 00:03:12 +0000 (00:03 +0000)
For #33097

Change-Id: I6138dc844f0b29b01c78a02efc1e1b1ad719b803
Reviewed-on: https://go-review.googlesource.com/c/go/+/412139
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Auto-Submit: Ian Lance Taylor <iant@google.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Run-TryBot: Ian Lance Taylor <iant@google.com>

src/net/lookup_windows.go
src/net/lookup_windows_test.go

index 051f47da392c3ecf426c27bd4c176ce737ec7c8c..9ff39c74a4e2083ca79adb81e136a52b87887910 100644 (file)
@@ -329,7 +329,7 @@ func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
 
 func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
        if r.preferGoOverWindows() {
-               return r.lookupTXT(ctx, name)
+               return r.goLookupTXT(ctx, name)
        }
        // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
        acquireThread()
index 9254733364595933d311ffd0493a7d576e6c3643..823ec088b839a93b91fd23aef53a8cbf8416484e 100644 (file)
@@ -6,6 +6,7 @@ package net
 
 import (
        "bytes"
+       "context"
        "encoding/json"
        "errors"
        "fmt"
@@ -26,104 +27,117 @@ func toJson(v any) string {
        return string(data)
 }
 
+func testLookup(t *testing.T, fn func(*testing.T, *Resolver, string)) {
+       for _, def := range []bool{true, false} {
+               def := def
+               for _, server := range nslookupTestServers {
+                       server := server
+                       var name string
+                       if def {
+                               name = "default/"
+                       } else {
+                               name = "go/"
+                       }
+                       t.Run(name+server, func(t *testing.T) {
+                               t.Parallel()
+                               r := DefaultResolver
+                               if !def {
+                                       r = &Resolver{PreferGo: true}
+                               }
+                               fn(t, r, server)
+                       })
+               }
+       }
+}
+
 func TestNSLookupMX(t *testing.T) {
        testenv.MustHaveExternalNetwork(t)
 
-       for _, server := range nslookupTestServers {
-               mx, err := LookupMX(server)
+       testLookup(t, func(t *testing.T, r *Resolver, server string) {
+               mx, err := r.LookupMX(context.Background(), server)
                if err != nil {
-                       t.Error(err)
-                       continue
+                       t.Fatal(err)
                }
                if len(mx) == 0 {
-                       t.Errorf("no results")
-                       continue
+                       t.Fatal("no results")
                }
                expected, err := nslookupMX(server)
                if err != nil {
-                       t.Logf("skipping failed nslookup %s test: %s", server, err)
+                       t.Skipf("skipping failed nslookup %s test: %s", server, err)
                }
                sort.Sort(byPrefAndHost(expected))
                sort.Sort(byPrefAndHost(mx))
                if !reflect.DeepEqual(expected, mx) {
                        t.Errorf("different results %s:\texp:%v\tgot:%v", server, toJson(expected), toJson(mx))
                }
-       }
+       })
 }
 
 func TestNSLookupCNAME(t *testing.T) {
        testenv.MustHaveExternalNetwork(t)
 
-       for _, server := range nslookupTestServers {
-               cname, err := LookupCNAME(server)
+       testLookup(t, func(t *testing.T, r *Resolver, server string) {
+               cname, err := r.LookupCNAME(context.Background(), server)
                if err != nil {
-                       t.Errorf("failed %s: %s", server, err)
-                       continue
+                       t.Fatalf("failed %s: %s", server, err)
                }
                if cname == "" {
-                       t.Errorf("no result %s", server)
+                       t.Fatalf("no result %s", server)
                }
                expected, err := nslookupCNAME(server)
                if err != nil {
-                       t.Logf("skipping failed nslookup %s test: %s", server, err)
-                       continue
+                       t.Skipf("skipping failed nslookup %s test: %s", server, err)
                }
                if expected != cname {
                        t.Errorf("different results %s:\texp:%v\tgot:%v", server, expected, cname)
                }
-       }
+       })
 }
 
 func TestNSLookupNS(t *testing.T) {
        testenv.MustHaveExternalNetwork(t)
 
-       for _, server := range nslookupTestServers {
-               ns, err := LookupNS(server)
+       testLookup(t, func(t *testing.T, r *Resolver, server string) {
+               ns, err := r.LookupNS(context.Background(), server)
                if err != nil {
-                       t.Errorf("failed %s: %s", server, err)
-                       continue
+                       t.Fatalf("failed %s: %s", server, err)
                }
                if len(ns) == 0 {
-                       t.Errorf("no results")
-                       continue
+                       t.Fatal("no results")
                }
                expected, err := nslookupNS(server)
                if err != nil {
-                       t.Logf("skipping failed nslookup %s test: %s", server, err)
-                       continue
+                       t.Skipf("skipping failed nslookup %s test: %s", server, err)
                }
                sort.Sort(byHost(expected))
                sort.Sort(byHost(ns))
                if !reflect.DeepEqual(expected, ns) {
                        t.Errorf("different results %s:\texp:%v\tgot:%v", toJson(server), toJson(expected), ns)
                }
-       }
+       })
 }
 
 func TestNSLookupTXT(t *testing.T) {
        testenv.MustHaveExternalNetwork(t)
 
-       for _, server := range nslookupTestServers {
-               txt, err := LookupTXT(server)
+       testLookup(t, func(t *testing.T, r *Resolver, server string) {
+               txt, err := r.LookupTXT(context.Background(), server)
                if err != nil {
-                       t.Errorf("failed %s: %s", server, err)
-                       continue
+                       t.Fatalf("failed %s: %s", server, err)
                }
                if len(txt) == 0 {
-                       t.Errorf("no results")
-                       continue
+                       t.Fatalf("no results")
                }
                expected, err := nslookupTXT(server)
                if err != nil {
-                       t.Logf("skipping failed nslookup %s test: %s", server, err)
-                       continue
+                       t.Skipf("skipping failed nslookup %s test: %s", server, err)
                }
                sort.Strings(expected)
                sort.Strings(txt)
                if !reflect.DeepEqual(expected, txt) {
                        t.Errorf("different results %s:\texp:%v\tgot:%v", server, toJson(expected), toJson(txt))
                }
-       }
+       })
 }
 
 func TestLookupLocalPTR(t *testing.T) {