From 432158b69a50e292b625d08dcfacd0604acbabd3 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 7 Jun 2022 16:53:53 -0700 Subject: [PATCH] net: fix testHookDialTCP race CL 410754 introduces a race accessing the global testHookDialTCP hook. Avoiding this race is difficult, since Dial can return while goroutines it starts are still running. Add a version of this hook to sysDialer, so it can be set on a per-test basis. (Perhaps other uses of this hook should be moved to use the sysDialer-local hook, but this change fixes the immediate data race.) For #52173. Change-Id: I8fb9be13957e91f92919cae7be213c38ad2af75a Reviewed-on: https://go-review.googlesource.com/c/go/+/410957 Run-TryBot: Damien Neil Reviewed-by: Ian Lance Taylor Reviewed-by: Cherry Mui TryBot-Result: Gopher Robot --- src/net/dial.go | 1 + src/net/dial_test.go | 11 +++++------ src/net/tcpsock_plan9.go | 7 +++++-- src/net/tcpsock_posix.go | 7 +++++-- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/net/dial.go b/src/net/dial.go index b24bd2f5f4..c538342566 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -341,6 +341,7 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) { type sysDialer struct { Dialer network, address string + testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) } // Dial connects to the address on the named network. diff --git a/src/net/dial_test.go b/src/net/dial_test.go index 0550acb01d..e49b4a61d6 100644 --- a/src/net/dial_test.go +++ b/src/net/dial_test.go @@ -234,9 +234,7 @@ func TestDialParallel(t *testing.T) { for i, tt := range testCases { i, tt := i, tt t.Run(fmt.Sprint(i), func(t *testing.T) { - origTestHookDialTCP := testHookDialTCP - defer func() { testHookDialTCP = origTestHookDialTCP }() - testHookDialTCP = func(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) { + dialTCP := func(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) { n := "tcp6" if raddr.IP.To4() != nil { n = "tcp4" @@ -262,9 +260,10 @@ func TestDialParallel(t *testing.T) { } startTime := time.Now() sd := &sysDialer{ - Dialer: d, - network: "tcp", - address: "?", + Dialer: d, + network: "tcp", + address: "?", + testHookDialTCP: dialTCP, } c, err := sd.dialParallel(context.Background(), primaries, fallbacks) elapsed := time.Since(startTime) diff --git a/src/net/tcpsock_plan9.go b/src/net/tcpsock_plan9.go index 768d03b06c..435335e92e 100644 --- a/src/net/tcpsock_plan9.go +++ b/src/net/tcpsock_plan9.go @@ -15,8 +15,11 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) { } func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) { - if testHookDialTCP != nil { - return testHookDialTCP(ctx, sd.network, laddr, raddr) + if h := sd.testHookDialTCP; h != nil { + return h(ctx, sd.network, laddr, raddr) + } + if h := testHookDialTCP; h != nil { + return h(ctx, sd.network, laddr, raddr) } return sd.doDialTCP(ctx, laddr, raddr) } diff --git a/src/net/tcpsock_posix.go b/src/net/tcpsock_posix.go index bc3d324e6b..1c91170c50 100644 --- a/src/net/tcpsock_posix.go +++ b/src/net/tcpsock_posix.go @@ -55,8 +55,11 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) { } func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) { - if testHookDialTCP != nil { - return testHookDialTCP(ctx, sd.network, laddr, raddr) + if h := sd.testHookDialTCP; h != nil { + return h(ctx, sd.network, laddr, raddr) + } + if h := testHookDialTCP; h != nil { + return h(ctx, sd.network, laddr, raddr) } return sd.doDialTCP(ctx, laddr, raddr) } -- 2.48.1