]> Cypherpunks repositories - gostls13.git/commitdiff
net: prevent cancelation goroutine from adjusting fd timeout after connect
authorBrad Fitzpatrick <bradfitz@golang.org>
Thu, 28 Jul 2016 11:42:11 +0000 (13:42 +0200)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 2 Aug 2016 00:55:45 +0000 (00:55 +0000)
This was previously fixed in https://golang.org/cl/21497 but not enough.

Fixes #16523

Change-Id: I678543a656304c82d654e25e12fb094cd6cc87e8
Reviewed-on: https://go-review.googlesource.com/25330
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/dial_unix_test.go [new file with mode: 0644]
src/net/fd_unix.go
src/net/hook_unix.go

diff --git a/src/net/dial_unix_test.go b/src/net/dial_unix_test.go
new file mode 100644 (file)
index 0000000..4705254
--- /dev/null
@@ -0,0 +1,108 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin dragonfly freebsd linux netbsd openbsd solaris
+
+package net
+
+import (
+       "context"
+       "syscall"
+       "testing"
+       "time"
+)
+
+// Issue 16523
+func TestDialContextCancelRace(t *testing.T) {
+       oldConnectFunc := connectFunc
+       oldGetsockoptIntFunc := getsockoptIntFunc
+       oldTestHookCanceledDial := testHookCanceledDial
+       defer func() {
+               connectFunc = oldConnectFunc
+               getsockoptIntFunc = oldGetsockoptIntFunc
+               testHookCanceledDial = oldTestHookCanceledDial
+       }()
+
+       ln, err := newLocalListener("tcp")
+       if err != nil {
+               t.Fatal(err)
+       }
+       listenerDone := make(chan struct{})
+       go func() {
+               defer close(listenerDone)
+               c, err := ln.Accept()
+               if err == nil {
+                       c.Close()
+               }
+       }()
+       defer func() { <-listenerDone }()
+       defer ln.Close()
+
+       sawCancel := make(chan bool, 1)
+       testHookCanceledDial = func() {
+               sawCancel <- true
+       }
+
+       ctx, cancelCtx := context.WithCancel(context.Background())
+
+       connectFunc = func(fd int, addr syscall.Sockaddr) error {
+               err := oldConnectFunc(fd, addr)
+               t.Logf("connect(%d, addr) = %v", fd, err)
+               if err == nil {
+                       // On some operating systems, localhost
+                       // connects _sometimes_ succeed immediately.
+                       // Prevent that, so we exercise the code path
+                       // we're interested in testing. This seems
+                       // harmless. It makes FreeBSD 10.10 work when
+                       // run with many iterations. It failed about
+                       // half the time previously.
+                       return syscall.EINPROGRESS
+               }
+               return err
+       }
+
+       getsockoptIntFunc = func(fd, level, opt int) (val int, err error) {
+               val, err = oldGetsockoptIntFunc(fd, level, opt)
+               t.Logf("getsockoptIntFunc(%d, %d, %d) = (%v, %v)", fd, level, opt, val, err)
+               if level == syscall.SOL_SOCKET && opt == syscall.SO_ERROR && err == nil && val == 0 {
+                       t.Logf("canceling context")
+
+                       // Cancel the context at just the moment which
+                       // caused the race in issue 16523.
+                       cancelCtx()
+
+                       // And wait for the "interrupter" goroutine to
+                       // cancel the dial by messing with its write
+                       // timeout before returning.
+                       select {
+                       case <-sawCancel:
+                               t.Logf("saw cancel")
+                       case <-time.After(5 * time.Second):
+                               t.Errorf("didn't see cancel after 5 seconds")
+                       }
+               }
+               return
+       }
+
+       var d Dialer
+       c, err := d.DialContext(ctx, "tcp", ln.Addr().String())
+       if err == nil {
+               c.Close()
+               t.Fatal("unexpected successful dial; want context canceled error")
+       }
+
+       select {
+       case <-ctx.Done():
+       case <-time.After(5 * time.Second):
+               t.Fatal("expected context to be canceled")
+       }
+
+       oe, ok := err.(*OpError)
+       if !ok || oe.Op != "dial" {
+               t.Fatalf("Dial error = %#v; want dial *OpError", err)
+       }
+       if oe.Err != ctx.Err() {
+               t.Errorf("DialContext = (%v, %v); want OpError with error %v", c, err, ctx.Err())
+       }
+}
index 0f80bc79ac270eb0cddadefe1f570fe526c4dacc..11dde76977e40e2f52d476cd3ce19d9465b27224 100644 (file)
@@ -64,7 +64,7 @@ func (fd *netFD) name() string {
        return fd.net + ":" + ls + "->" + rs
 }
 
-func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error {
+func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (ret error) {
        // Do not need to call fd.writeLock here,
        // because fd is not yet accessible to user,
        // so no concurrent operations are possible.
@@ -101,21 +101,44 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error {
                defer fd.setWriteDeadline(noDeadline)
        }
 
-       // Wait for the goroutine converting context.Done into a write timeout
-       // to exist, otherwise our caller might cancel the context and
-       // cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial.
-       done := make(chan bool) // must be unbuffered
-       defer func() { done <- true }()
-       go func() {
-               select {
-               case <-ctx.Done():
-                       // Force the runtime's poller to immediately give
-                       // up waiting for writability.
-                       fd.setWriteDeadline(aLongTimeAgo)
-                       <-done
-               case <-done:
-               }
-       }()
+       // Start the "interrupter" goroutine, if this context might be canceled.
+       // (The background context cannot)
+       //
+       // The interrupter goroutine waits for the context to be done and
+       // interrupts the dial (by altering the fd's write deadline, which
+       // wakes up waitWrite).
+       if ctx != context.Background() {
+               // Wait for the interrupter goroutine to exit before returning
+               // from connect.
+               done := make(chan struct{})
+               interruptRes := make(chan error)
+               defer func() {
+                       close(done)
+                       if ctxErr := <-interruptRes; ctxErr != nil && ret == nil {
+                               // The interrupter goroutine called setWriteDeadline,
+                               // but the connect code below had returned from
+                               // waitWrite already and did a successful connect (ret
+                               // == nil). Because we've now poisoned the connection
+                               // by making it unwritable, don't return a successful
+                               // dial. This was issue 16523.
+                               ret = ctxErr
+                               fd.Close() // prevent a leak
+                       }
+               }()
+               go func() {
+                       select {
+                       case <-ctx.Done():
+                               // Force the runtime's poller to immediately give up
+                               // waiting for writability, unblocking waitWrite
+                               // below.
+                               fd.setWriteDeadline(aLongTimeAgo)
+                               testHookCanceledDial()
+                               interruptRes <- ctx.Err()
+                       case <-done:
+                               interruptRes <- nil
+                       }
+               }()
+       }
 
        for {
                // Performing multiple connect system calls on a
index 361ca5980c38b24abc526e8a346fa6c2c0114efa..cf52567fcfdad88ce65d3f5d39650bdb4e7a6610 100644 (file)
@@ -9,7 +9,8 @@ package net
 import "syscall"
 
 var (
-       testHookDialChannel = func() {} // see golang.org/issue/5349
+       testHookDialChannel  = func() {} // for golang.org/issue/5349
+       testHookCanceledDial = func() {} // for golang.org/issue/16523
 
        // Placeholders for socket system calls.
        socketFunc        func(int, int, int) (int, error)         = syscall.Socket