]> Cypherpunks repositories - gostls13.git/commitdiff
net: fix accept/connect deadline handling
authorDmitriy Vyukov <dvyukov@google.com>
Thu, 7 Mar 2013 13:03:40 +0000 (17:03 +0400)
committerDmitriy Vyukov <dvyukov@google.com>
Thu, 7 Mar 2013 13:03:40 +0000 (17:03 +0400)
Ensure that accept/connect respect deadline,
even if the operation can be executed w/o blocking.
Note this changes external behavior, but it makes
it consistent with read/write.
Factor out deadline check into pollServer.PrepareRead/Write,
in preparation for edge triggered pollServer.
Ensure that pollServer.WaitRead/Write are not called concurrently
by adding rio/wio locks around connect/accept.

R=golang-dev, mikioh.mikioh, bradfitz, iant
CC=golang-dev
https://golang.org/cl/7436048

src/pkg/net/fd_unix.go
src/pkg/net/timeout_test.go

index 0540df8255b3c3cfbc442f7c38d788d6209c988e..d7a83b6393e825b354d607137d5487a2a5e106d3 100644 (file)
@@ -234,6 +234,20 @@ func (s *pollServer) Run() {
        }
 }
 
+func (s *pollServer) PrepareRead(fd *netFD) error {
+       if fd.rdeadline.expired() {
+               return errTimeout
+       }
+       return nil
+}
+
+func (s *pollServer) PrepareWrite(fd *netFD) error {
+       if fd.wdeadline.expired() {
+               return errTimeout
+       }
+       return nil
+}
+
 func (s *pollServer) WaitRead(fd *netFD) error {
        err := s.AddFD(fd, 'r')
        if err == nil {
@@ -331,6 +345,11 @@ func (fd *netFD) name() string {
 }
 
 func (fd *netFD) connect(ra syscall.Sockaddr) error {
+       fd.wio.Lock()
+       defer fd.wio.Unlock()
+       if err := fd.pollServer.PrepareWrite(fd); err != nil {
+               return err
+       }
        err := syscall.Connect(fd.sysfd, ra)
        if err == syscall.EINPROGRESS {
                if err = fd.pollServer.WaitWrite(fd); err != nil {
@@ -425,11 +444,10 @@ func (fd *netFD) Read(p []byte) (n int, err error) {
                return 0, err
        }
        defer fd.decref()
+       if err := fd.pollServer.PrepareRead(fd); err != nil {
+               return 0, &OpError{"read", fd.net, fd.raddr, err}
+       }
        for {
-               if fd.rdeadline.expired() {
-                       err = errTimeout
-                       break
-               }
                n, err = syscall.Read(int(fd.sysfd), p)
                if err != nil {
                        n = 0
@@ -455,11 +473,10 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
                return 0, nil, err
        }
        defer fd.decref()
+       if err := fd.pollServer.PrepareRead(fd); err != nil {
+               return 0, nil, &OpError{"read", fd.net, fd.laddr, err}
+       }
        for {
-               if fd.rdeadline.expired() {
-                       err = errTimeout
-                       break
-               }
                n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0)
                if err != nil {
                        n = 0
@@ -485,11 +502,10 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S
                return 0, 0, 0, nil, err
        }
        defer fd.decref()
+       if err := fd.pollServer.PrepareRead(fd); err != nil {
+               return 0, 0, 0, nil, &OpError{"read", fd.net, fd.laddr, err}
+       }
        for {
-               if fd.rdeadline.expired() {
-                       err = errTimeout
-                       break
-               }
                n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0)
                if err != nil {
                        // TODO(dfc) should n and oobn be set to 0
@@ -522,11 +538,10 @@ func (fd *netFD) Write(p []byte) (nn int, err error) {
                return 0, err
        }
        defer fd.decref()
+       if err := fd.pollServer.PrepareWrite(fd); err != nil {
+               return 0, &OpError{"write", fd.net, fd.raddr, err}
+       }
        for {
-               if fd.wdeadline.expired() {
-                       err = errTimeout
-                       break
-               }
                var n int
                n, err = syscall.Write(int(fd.sysfd), p[nn:])
                if n > 0 {
@@ -562,11 +577,10 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
                return 0, err
        }
        defer fd.decref()
+       if err := fd.pollServer.PrepareWrite(fd); err != nil {
+               return 0, &OpError{"write", fd.net, fd.raddr, err}
+       }
        for {
-               if fd.wdeadline.expired() {
-                       err = errTimeout
-                       break
-               }
                err = syscall.Sendto(fd.sysfd, p, 0, sa)
                if err == syscall.EAGAIN {
                        if err = fd.pollServer.WaitWrite(fd); err == nil {
@@ -590,11 +604,10 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
                return 0, 0, err
        }
        defer fd.decref()
+       if err := fd.pollServer.PrepareWrite(fd); err != nil {
+               return 0, 0, &OpError{"write", fd.net, fd.raddr, err}
+       }
        for {
-               if fd.wdeadline.expired() {
-                       err = errTimeout
-                       break
-               }
                err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
                if err == syscall.EAGAIN {
                        if err = fd.pollServer.WaitWrite(fd); err == nil {
@@ -613,6 +626,8 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
 }
 
 func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err error) {
+       fd.rio.Lock()
+       defer fd.rio.Unlock()
        if err := fd.incref(false); err != nil {
                return nil, err
        }
@@ -620,6 +635,9 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e
 
        var s int
        var rsa syscall.Sockaddr
+       if err = fd.pollServer.PrepareRead(fd); err != nil {
+               return nil, &OpError{"accept", fd.net, fd.laddr, err}
+       }
        for {
                s, rsa, err = accept(fd.sysfd)
                if err != nil {
index 0260efcc0b622df96e02a524e5468d0417ee3f20..2e92147b8e33e923390f92c9701a815d241b37dc 100644 (file)
@@ -532,7 +532,7 @@ func TestReadDeadlineDataAvailable(t *testing.T) {
        defer ln.Close()
 
        servec := make(chan copyRes)
-       const msg = "data client shouldn't read, even though it it'll be waiting"
+       const msg = "data client shouldn't read, even though it'll be waiting"
        go func() {
                c, err := ln.Accept()
                if err != nil {
@@ -596,6 +596,64 @@ func TestWriteDeadlineBufferAvailable(t *testing.T) {
        }
 }
 
+// TestAcceptDeadlineConnectionAvailable tests that accept deadlines work, even
+// if there's incoming connections available.
+func TestAcceptDeadlineConnectionAvailable(t *testing.T) {
+       switch runtime.GOOS {
+       case "plan9":
+               t.Skipf("skipping test on %q", runtime.GOOS)
+       }
+
+       ln := newLocalListener(t).(*TCPListener)
+       defer ln.Close()
+
+       go func() {
+               c, err := Dial("tcp", ln.Addr().String())
+               if err != nil {
+                       t.Fatalf("Dial: %v", err)
+               }
+               defer c.Close()
+               var buf [1]byte
+               c.Read(buf[:]) // block until the connection or listener is closed
+       }()
+       time.Sleep(10 * time.Millisecond)
+       ln.SetDeadline(time.Now().Add(-5 * time.Second)) // in the past
+       c, err := ln.Accept()
+       if err == nil {
+               defer c.Close()
+       }
+       if !isTimeout(err) {
+               t.Fatalf("Accept: got %v; want timeout", err)
+       }
+}
+
+// TestConnectDeadlineInThePast tests that connect deadlines work, even
+// if the connection can be established w/o blocking.
+func TestConnectDeadlineInThePast(t *testing.T) {
+       switch runtime.GOOS {
+       case "plan9":
+               t.Skipf("skipping test on %q", runtime.GOOS)
+       }
+
+       ln := newLocalListener(t).(*TCPListener)
+       defer ln.Close()
+
+       go func() {
+               c, err := ln.Accept()
+               if err == nil {
+                       defer c.Close()
+               }
+       }()
+       time.Sleep(10 * time.Millisecond)
+       c, err := DialTimeout("tcp", ln.Addr().String(), -5*time.Second) // in the past
+       if err == nil {
+               defer c.Close()
+       }
+       if !isTimeout(err) {
+               t.Fatalf("DialTimeout: got %v; want timeout", err)
+       }
+}
+
 // TestProlongTimeout tests concurrent deadline modification.
 // Known to cause data races in the past.
 func TestProlongTimeout(t *testing.T) {