]> Cypherpunks repositories - gostls13.git/commitdiff
net: fix testHookDialTCP race
authorDamien Neil <dneil@google.com>
Tue, 7 Jun 2022 23:53:53 +0000 (16:53 -0700)
committerDamien Neil <dneil@google.com>
Wed, 8 Jun 2022 17:11:00 +0000 (17:11 +0000)
CL 410754 introduces a race accessing the global testHookDialTCP hook.
Avoiding this race is difficult, since Dial can return while
goroutines it starts are still running. Add a version of this
hook to sysDialer, so it can be set on a per-test basis.

(Perhaps other uses of this hook should be moved to use the
sysDialer-local hook, but this change fixes the immediate data race.)

For #52173.

Change-Id: I8fb9be13957e91f92919cae7be213c38ad2af75a
Reviewed-on: https://go-review.googlesource.com/c/go/+/410957
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>

src/net/dial.go
src/net/dial_test.go
src/net/tcpsock_plan9.go
src/net/tcpsock_posix.go

index b24bd2f5f43e5c53474309ae8ee86118d64afd7f..c5383425666eb51b7911d9bc3f41d25ed414abbf 100644 (file)
@@ -341,6 +341,7 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
 type sysDialer struct {
        Dialer
        network, address string
+       testHookDialTCP  func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
 }
 
 // Dial connects to the address on the named network.
index 0550acb01dbed3c35204a806a9e0ee36e6758d1e..e49b4a61d611f74543e37842d38fc945cf4ef1bd 100644 (file)
@@ -234,9 +234,7 @@ func TestDialParallel(t *testing.T) {
        for i, tt := range testCases {
                i, tt := i, tt
                t.Run(fmt.Sprint(i), func(t *testing.T) {
-                       origTestHookDialTCP := testHookDialTCP
-                       defer func() { testHookDialTCP = origTestHookDialTCP }()
-                       testHookDialTCP = func(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+                       dialTCP := func(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
                                n := "tcp6"
                                if raddr.IP.To4() != nil {
                                        n = "tcp4"
@@ -262,9 +260,10 @@ func TestDialParallel(t *testing.T) {
                        }
                        startTime := time.Now()
                        sd := &sysDialer{
-                               Dialer:  d,
-                               network: "tcp",
-                               address: "?",
+                               Dialer:          d,
+                               network:         "tcp",
+                               address:         "?",
+                               testHookDialTCP: dialTCP,
                        }
                        c, err := sd.dialParallel(context.Background(), primaries, fallbacks)
                        elapsed := time.Since(startTime)
index 768d03b06cefa0ab96bea21859226fae1023eef6..435335e92e8edb651618e8da043434d765142f1e 100644 (file)
@@ -15,8 +15,11 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
 }
 
 func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
-       if testHookDialTCP != nil {
-               return testHookDialTCP(ctx, sd.network, laddr, raddr)
+       if h := sd.testHookDialTCP; h != nil {
+               return h(ctx, sd.network, laddr, raddr)
+       }
+       if h := testHookDialTCP; h != nil {
+               return h(ctx, sd.network, laddr, raddr)
        }
        return sd.doDialTCP(ctx, laddr, raddr)
 }
index bc3d324e6ba2d7adff60c580f00c5ee45501354c..1c91170c50091569aa8d15545c75bf007a1b3995 100644 (file)
@@ -55,8 +55,11 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
 }
 
 func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
-       if testHookDialTCP != nil {
-               return testHookDialTCP(ctx, sd.network, laddr, raddr)
+       if h := sd.testHookDialTCP; h != nil {
+               return h(ctx, sd.network, laddr, raddr)
+       }
+       if h := testHookDialTCP; h != nil {
+               return h(ctx, sd.network, laddr, raddr)
        }
        return sd.doDialTCP(ctx, laddr, raddr)
 }