]> Cypherpunks repositories - gostls13.git/commitdiff
net: avoid race on test hooks with DNS goroutines
authorIan Lance Taylor <iant@golang.org>
Fri, 8 Dec 2017 04:30:28 +0000 (20:30 -0800)
committerIan Lance Taylor <iant@golang.org>
Fri, 8 Dec 2017 05:12:13 +0000 (05:12 +0000)
The DNS code can start goroutines and not wait for them to complete.
This does no harm, but in tests this can cause a race condition with
the test hooks that are installed and unintalled around the tests.
Add a WaitGroup that tests of DNS can use to avoid the race.

Fixes #21090

Change-Id: I6c1443a9c2378e8b89d0ab1d6390c0e3e726b0ce
Reviewed-on: https://go-review.googlesource.com/82795
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/internal/singleflight/singleflight.go
src/net/cgo_unix_test.go
src/net/dnsclient_unix.go
src/net/dnsclient_unix_test.go
src/net/lookup.go
src/net/lookup_test.go
src/net/netgo_unix_test.go

index de81ac87b9f6cc53f4a285f5b56a826cccaa7a48..1e9960d575d4c7219fcf5cbb5e20ab44b5bddb12 100644 (file)
@@ -65,8 +65,10 @@ func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, e
 }
 
 // DoChan is like Do but returns a channel that will receive the
-// results when they are ready.
-func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result {
+// results when they are ready. The second result is true if the function
+// will eventually be called, false if it will not (because there is
+// a pending request with this key).
+func (g *Group) DoChan(key string, fn func() (interface{}, error)) (<-chan Result, bool) {
        ch := make(chan Result, 1)
        g.mu.Lock()
        if g.m == nil {
@@ -76,7 +78,7 @@ func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result
                c.dups++
                c.chans = append(c.chans, ch)
                g.mu.Unlock()
-               return ch
+               return ch, false
        }
        c := &call{chans: []chan<- Result{ch}}
        c.wg.Add(1)
@@ -85,7 +87,7 @@ func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result
 
        go g.doCall(c, key, fn)
 
-       return ch
+       return ch, true
 }
 
 // doCall handles the single call for a key.
index e861c7aa1f90d8a1372a08d858d778d16c36fff3..b476a6d62686ea5e15dae052871d4106de894a42 100644 (file)
@@ -13,6 +13,7 @@ import (
 )
 
 func TestCgoLookupIP(t *testing.T) {
+       defer dnsWaitGroup.Wait()
        ctx := context.Background()
        _, err, ok := cgoLookupIP(ctx, "localhost")
        if !ok {
@@ -24,6 +25,7 @@ func TestCgoLookupIP(t *testing.T) {
 }
 
 func TestCgoLookupIPWithCancel(t *testing.T) {
+       defer dnsWaitGroup.Wait()
        ctx, cancel := context.WithCancel(context.Background())
        defer cancel()
        _, err, ok := cgoLookupIP(ctx, "localhost")
@@ -36,6 +38,7 @@ func TestCgoLookupIPWithCancel(t *testing.T) {
 }
 
 func TestCgoLookupPort(t *testing.T) {
+       defer dnsWaitGroup.Wait()
        ctx := context.Background()
        _, err, ok := cgoLookupPort(ctx, "tcp", "smtp")
        if !ok {
@@ -47,6 +50,7 @@ func TestCgoLookupPort(t *testing.T) {
 }
 
 func TestCgoLookupPortWithCancel(t *testing.T) {
+       defer dnsWaitGroup.Wait()
        ctx, cancel := context.WithCancel(context.Background())
        defer cancel()
        _, err, ok := cgoLookupPort(ctx, "tcp", "smtp")
@@ -59,6 +63,7 @@ func TestCgoLookupPortWithCancel(t *testing.T) {
 }
 
 func TestCgoLookupPTR(t *testing.T) {
+       defer dnsWaitGroup.Wait()
        ctx := context.Background()
        _, err, ok := cgoLookupPTR(ctx, "127.0.0.1")
        if !ok {
@@ -70,6 +75,7 @@ func TestCgoLookupPTR(t *testing.T) {
 }
 
 func TestCgoLookupPTRWithCancel(t *testing.T) {
+       defer dnsWaitGroup.Wait()
        ctx, cancel := context.WithCancel(context.Background())
        defer cancel()
        _, err, ok := cgoLookupPTR(ctx, "127.0.0.1")
index acbf6c3b2ad99cec3d73377e904f83de1ab6ec76..9026fd8c74b1790fc4bbba636798fb41eab7d4d7 100644 (file)
@@ -479,7 +479,9 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order
        var lastErr error
        for _, fqdn := range conf.nameList(name) {
                for _, qtype := range qtypes {
+                       dnsWaitGroup.Add(1)
                        go func(qtype uint16) {
+                               defer dnsWaitGroup.Done()
                                cname, rrs, err := r.tryOneName(ctx, conf, fqdn, qtype)
                                lane <- racer{cname, rrs, err}
                        }(qtype)
index 73b628c1b5f328d50dfafa4bd6ecfb19432e62d1..295ed9770c0066c37ace6e97f5ddaac3d415bd7f 100644 (file)
@@ -203,6 +203,7 @@ var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.
 
 // Issue 13705: don't try to resolve onion addresses, etc
 func TestLookupTorOnion(t *testing.T) {
+       defer dnsWaitGroup.Wait()
        r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
        addrs, err := r.LookupIPAddr(context.Background(), "foo.onion")
        if err != nil {
@@ -300,6 +301,8 @@ var updateResolvConfTests = []struct {
 }
 
 func TestUpdateResolvConf(t *testing.T) {
+       defer dnsWaitGroup.Wait()
+
        r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
 
        conf, err := newResolvConfTest()
@@ -455,6 +458,8 @@ var goLookupIPWithResolverConfigTests = []struct {
 }
 
 func TestGoLookupIPWithResolverConfig(t *testing.T) {
+       defer dnsWaitGroup.Wait()
+
        fake := fakeDNSServer{func(n, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
                switch s {
                case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
@@ -547,6 +552,8 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
 
 // Test that goLookupIPOrder falls back to the host file when no DNS servers are available.
 func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
+       defer dnsWaitGroup.Wait()
+
        fake := fakeDNSServer{func(n, s string, q *dnsMsg, tm time.Time) (*dnsMsg, error) {
                r := &dnsMsg{
                        dnsMsgHdr: dnsMsgHdr{
@@ -603,6 +610,8 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
 // querying the original name instead of an error encountered
 // querying a generated name.
 func TestErrorForOriginalNameWhenSearching(t *testing.T) {
+       defer dnsWaitGroup.Wait()
+
        const fqdn = "doesnotexist.domain"
 
        conf, err := newResolvConfTest()
@@ -657,6 +666,8 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
 
 // Issue 15434. If a name server gives a lame referral, continue to the next.
 func TestIgnoreLameReferrals(t *testing.T) {
+       defer dnsWaitGroup.Wait()
+
        conf, err := newResolvConfTest()
        if err != nil {
                t.Fatal(err)
@@ -889,6 +900,8 @@ func TestIgnoreDNSForgeries(t *testing.T) {
 
 // Issue 16865. If a name server times out, continue to the next.
 func TestRetryTimeout(t *testing.T) {
+       defer dnsWaitGroup.Wait()
+
        conf, err := newResolvConfTest()
        if err != nil {
                t.Fatal(err)
@@ -945,6 +958,8 @@ func TestRotate(t *testing.T) {
 }
 
 func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
+       defer dnsWaitGroup.Wait()
+
        conf, err := newResolvConfTest()
        if err != nil {
                t.Fatal(err)
@@ -1008,6 +1023,8 @@ func mockTXTResponse(q *dnsMsg) *dnsMsg {
 // Issue 17448. With StrictErrors enabled, temporary errors should make
 // LookupIP fail rather than return a partial result.
 func TestStrictErrorsLookupIP(t *testing.T) {
+       defer dnsWaitGroup.Wait()
+
        conf, err := newResolvConfTest()
        if err != nil {
                t.Fatal(err)
@@ -1256,6 +1273,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
 // Issue 17448. With StrictErrors enabled, temporary errors should make
 // LookupTXT stop walking the search list.
 func TestStrictErrorsLookupTXT(t *testing.T) {
+       defer dnsWaitGroup.Wait()
+
        conf, err := newResolvConfTest()
        if err != nil {
                t.Fatal(err)
@@ -1312,3 +1331,25 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
                }
        }
 }
+
+// Test for a race between uninstalling the test hooks and closing a
+// socket connection. This used to fail when testing with -race.
+func TestDNSGoroutineRace(t *testing.T) {
+       defer dnsWaitGroup.Wait()
+
+       fake := fakeDNSServer{func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error) {
+               time.Sleep(10 * time.Microsecond)
+               return nil, poll.ErrTimeout
+       }}
+       r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+       // The timeout here is less than the timeout used by the server,
+       // so the goroutine started to query the (fake) server will hang
+       // around after this test is done if we don't call dnsWaitGroup.Wait.
+       ctx, cancel := context.WithTimeout(context.Background(), 2*time.Microsecond)
+       defer cancel()
+       _, err := r.LookupIPAddr(ctx, "where.are.they.now")
+       if err == nil {
+               t.Fatal("fake DNS lookup unexpectedly succeeded")
+       }
+}
index c9f327050afad8a7c6323d63985b5c3fd18d18f5..85e472932fcda7a8bb8d6bbfea488e202cb0636a 100644 (file)
@@ -8,6 +8,7 @@ import (
        "context"
        "internal/nettrace"
        "internal/singleflight"
+       "sync"
 )
 
 // protocols contains minimal mappings between internet protocol
@@ -53,6 +54,10 @@ var services = map[string]map[string]int{
        },
 }
 
+// dnsWaitGroup can be used by tests to wait for all DNS goroutines to
+// complete. This avoids races on the test hooks.
+var dnsWaitGroup sync.WaitGroup
+
 const maxProtoLength = len("RSVP-E2E-IGNORE") + 10 // with room to grow
 
 func lookupProtocolMap(name string) (int, error) {
@@ -189,9 +194,14 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, err
                resolverFunc = alt
        }
 
-       ch := lookupGroup.DoChan(host, func() (interface{}, error) {
+       dnsWaitGroup.Add(1)
+       ch, called := lookupGroup.DoChan(host, func() (interface{}, error) {
+               defer dnsWaitGroup.Done()
                return testHookLookupIP(ctx, resolverFunc, host)
        })
+       if !called {
+               dnsWaitGroup.Done()
+       }
 
        select {
        case <-ctx.Done():
index e3bf114a8e2cc620b000da1c20f7dba81c6e28d4..bfb872551c04c77f8bf9afc91a5189b419aba84b 100644 (file)
@@ -105,6 +105,8 @@ func TestLookupGmailMX(t *testing.T) {
                t.Skip("IPv4 is required")
        }
 
+       defer dnsWaitGroup.Wait()
+
        for _, tt := range lookupGmailMXTests {
                mxs, err := LookupMX(tt.name)
                if err != nil {
@@ -137,6 +139,8 @@ func TestLookupGmailNS(t *testing.T) {
                t.Skip("IPv4 is required")
        }
 
+       defer dnsWaitGroup.Wait()
+
        for _, tt := range lookupGmailNSTests {
                nss, err := LookupNS(tt.name)
                if err != nil {
@@ -170,6 +174,8 @@ func TestLookupGmailTXT(t *testing.T) {
                t.Skip("IPv4 is required")
        }
 
+       defer dnsWaitGroup.Wait()
+
        for _, tt := range lookupGmailTXTTests {
                txts, err := LookupTXT(tt.name)
                if err != nil {
@@ -205,6 +211,8 @@ func TestLookupGooglePublicDNSAddr(t *testing.T) {
                t.Skip("both IPv4 and IPv6 are required")
        }
 
+       defer dnsWaitGroup.Wait()
+
        for _, tt := range lookupGooglePublicDNSAddrTests {
                names, err := LookupAddr(tt.addr)
                if err != nil {
@@ -226,6 +234,8 @@ func TestLookupIPv6LinkLocalAddr(t *testing.T) {
                t.Skip("IPv6 is required")
        }
 
+       defer dnsWaitGroup.Wait()
+
        addrs, err := LookupHost("localhost")
        if err != nil {
                t.Fatal(err)
@@ -262,6 +272,8 @@ func TestLookupCNAME(t *testing.T) {
                t.Skip("IPv4 is required")
        }
 
+       defer dnsWaitGroup.Wait()
+
        for _, tt := range lookupCNAMETests {
                cname, err := LookupCNAME(tt.name)
                if err != nil {
@@ -289,6 +301,8 @@ func TestLookupGoogleHost(t *testing.T) {
                t.Skip("IPv4 is required")
        }
 
+       defer dnsWaitGroup.Wait()
+
        for _, tt := range lookupGoogleHostTests {
                addrs, err := LookupHost(tt.name)
                if err != nil {
@@ -313,6 +327,8 @@ func TestLookupLongTXT(t *testing.T) {
                testenv.MustHaveExternalNetwork(t)
        }
 
+       defer dnsWaitGroup.Wait()
+
        txts, err := LookupTXT("golang.rsc.io")
        if err != nil {
                t.Fatal(err)
@@ -343,6 +359,8 @@ func TestLookupGoogleIP(t *testing.T) {
                t.Skip("IPv4 is required")
        }
 
+       defer dnsWaitGroup.Wait()
+
        for _, tt := range lookupGoogleIPTests {
                ips, err := LookupIP(tt.name)
                if err != nil {
@@ -378,6 +396,7 @@ var revAddrTests = []struct {
 }
 
 func TestReverseAddress(t *testing.T) {
+       defer dnsWaitGroup.Wait()
        for i, tt := range revAddrTests {
                a, err := reverseaddr(tt.Addr)
                if len(tt.ErrPrefix) > 0 && err == nil {
@@ -401,6 +420,8 @@ func TestDNSFlood(t *testing.T) {
                t.Skip("test disabled; use -dnsflood to enable")
        }
 
+       defer dnsWaitGroup.Wait()
+
        var N = 5000
        if runtime.GOOS == "darwin" {
                // On Darwin this test consumes kernel threads much
@@ -482,6 +503,8 @@ func TestLookupDotsWithLocalSource(t *testing.T) {
                testenv.MustHaveExternalNetwork(t)
        }
 
+       defer dnsWaitGroup.Wait()
+
        for i, fn := range []func() func(){forceGoDNS, forceCgoDNS} {
                fixup := fn()
                if fixup == nil {
@@ -527,6 +550,8 @@ func TestLookupDotsWithRemoteSource(t *testing.T) {
                t.Skip("IPv4 is required")
        }
 
+       defer dnsWaitGroup.Wait()
+
        if fixup := forceGoDNS(); fixup != nil {
                testDots(t, "go")
                fixup()
@@ -747,6 +772,9 @@ func TestLookupNonLDH(t *testing.T) {
        if runtime.GOOS == "nacl" {
                t.Skip("skip on nacl")
        }
+
+       defer dnsWaitGroup.Wait()
+
        if fixup := forceGoDNS(); fixup != nil {
                defer fixup()
        }
index 47901b03cf5ce1affd31d401eaacc047322e2db1..f2244ea69c40ba53822ecd7b7f49e77580368050 100644 (file)
@@ -13,6 +13,7 @@ import (
 )
 
 func TestGoLookupIP(t *testing.T) {
+       defer dnsWaitGroup.Wait()
        host := "localhost"
        ctx := context.Background()
        _, err, ok := cgoLookupIP(ctx, host)