]> Cypherpunks repositories - gostls13.git/commitdiff
net: move deadline logic into pollServer
authorDave Cheney <dave@cheney.net>
Wed, 28 Nov 2012 00:29:25 +0000 (11:29 +1100)
committerDave Cheney <dave@cheney.net>
Wed, 28 Nov 2012 00:29:25 +0000 (11:29 +1100)
Update #4434.

The proposal attempts to reduce the number of places where fd,{r,w}deadline is checked and updated in preparation for issue 4434. In doing so the deadline logic is simplified by letting the pollster return errTimeout from netFD.Wait{Read,Write} as part of the wakeup logic.

The behaviour of setting n = 0 has been restored to match rev 2a55e349097f, which was the previous change to fd_unix.go before CL 6851096.

R=jsing, bradfitz, mikioh.mikioh, rsc
CC=fullung, golang-dev
https://golang.org/cl/6850110

src/pkg/net/fd_unix.go
src/pkg/net/fd_unix_test.go

index 19d3ac9fe0272861d9218b338666259fad40adf0..9326b6278a0e1b2cb9f636946aba95a62ce5f6fb 100644 (file)
@@ -181,12 +181,10 @@ func (s *pollServer) CheckDeadlines() {
                                delete(s.pending, key)
                                if mode == 'r' {
                                        s.poll.DelFD(fd.sysfd, mode)
-                                       fd.rdeadline = -1
                                } else {
                                        s.poll.DelFD(fd.sysfd, mode)
-                                       fd.wdeadline = -1
                                }
-                               s.WakeFD(fd, mode, nil)
+                               s.WakeFD(fd, mode, errTimeout)
                        } else if nextDeadline == 0 || t < nextDeadline {
                                nextDeadline = t
                        }
@@ -329,14 +327,10 @@ func (fd *netFD) name() string {
 
 func (fd *netFD) connect(ra syscall.Sockaddr) error {
        err := syscall.Connect(fd.sysfd, ra)
-       hadTimeout := fd.wdeadline > 0
        if err == syscall.EINPROGRESS {
                if err = fd.pollServer.WaitWrite(fd); err != nil {
                        return err
                }
-               if hadTimeout && fd.wdeadline < 0 {
-                       return errTimeout
-               }
                var e int
                e, err = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR)
                if err != nil {
@@ -430,20 +424,15 @@ func (fd *netFD) Read(p []byte) (n int, err error) {
                        }
                }
                n, err = syscall.Read(int(fd.sysfd), p)
-               if err == syscall.EAGAIN {
+               if err != nil {
                        n = 0
-                       err = errTimeout
-                       if fd.rdeadline >= 0 {
+                       if err == syscall.EAGAIN {
                                if err = fd.pollServer.WaitRead(fd); err == nil {
                                        continue
                                }
                        }
                }
-               if err != nil {
-                       n = 0
-               } else if n == 0 && err == nil && fd.sotype != syscall.SOCK_DGRAM {
-                       err = io.EOF
-               }
+               err = chkReadErr(n, err, fd)
                break
        }
        if err != nil && err != io.EOF {
@@ -467,18 +456,15 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
                        }
                }
                n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0)
-               if err == syscall.EAGAIN {
+               if err != nil {
                        n = 0
-                       err = errTimeout
-                       if fd.rdeadline >= 0 {
+                       if err == syscall.EAGAIN {
                                if err = fd.pollServer.WaitRead(fd); err == nil {
                                        continue
                                }
                        }
                }
-               if err != nil {
-                       n = 0
-               }
+               err = chkReadErr(n, err, fd)
                break
        }
        if err != nil && err != io.EOF {
@@ -502,27 +488,30 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S
                        }
                }
                n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0)
-               if err == syscall.EAGAIN {
-                       n = 0
-                       err = errTimeout
-                       if fd.rdeadline >= 0 {
+               if err != nil {
+                       // TODO(dfc) should n and oobn be set to nil
+                       if err == syscall.EAGAIN {
                                if err = fd.pollServer.WaitRead(fd); err == nil {
                                        continue
                                }
                        }
                }
-               if err == nil && n == 0 {
-                       err = io.EOF
-               }
+               err = chkReadErr(n, err, fd)
                break
        }
        if err != nil && err != io.EOF {
                err = &OpError{"read", fd.net, fd.laddr, err}
-               return
        }
        return
 }
 
+func chkReadErr(n int, err error, fd *netFD) error {
+       if n == 0 && err == nil && fd.sotype != syscall.SOCK_DGRAM && fd.sotype != syscall.SOCK_RAW {
+               return io.EOF
+       }
+       return err
+}
+
 func (fd *netFD) Write(p []byte) (int, error) {
        fd.wio.Lock()
        defer fd.wio.Unlock()
@@ -548,11 +537,8 @@ func (fd *netFD) Write(p []byte) (int, error) {
                        break
                }
                if err == syscall.EAGAIN {
-                       err = errTimeout
-                       if fd.wdeadline >= 0 {
-                               if err = fd.pollServer.WaitWrite(fd); err == nil {
-                                       continue
-                               }
+                       if err = fd.pollServer.WaitWrite(fd); err == nil {
+                               continue
                        }
                }
                if err != nil {
@@ -586,11 +572,8 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
                }
                err = syscall.Sendto(fd.sysfd, p, 0, sa)
                if err == syscall.EAGAIN {
-                       err = errTimeout
-                       if fd.wdeadline >= 0 {
-                               if err = fd.pollServer.WaitWrite(fd); err == nil {
-                                       continue
-                               }
+                       if err = fd.pollServer.WaitWrite(fd); err == nil {
+                               continue
                        }
                }
                break
@@ -619,11 +602,8 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
                }
                err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
                if err == syscall.EAGAIN {
-                       err = errTimeout
-                       if fd.wdeadline >= 0 {
-                               if err = fd.pollServer.WaitWrite(fd); err == nil {
-                                       continue
-                               }
+                       if err = fd.pollServer.WaitWrite(fd); err == nil {
+                               continue
                        }
                }
                break
@@ -654,11 +634,8 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e
                if err != nil {
                        syscall.ForkLock.RUnlock()
                        if err == syscall.EAGAIN {
-                               err = errTimeout
-                               if fd.rdeadline >= 0 {
-                                       if err = fd.pollServer.WaitRead(fd); err == nil {
-                                               continue
-                                       }
+                               if err = fd.pollServer.WaitRead(fd); err == nil {
+                                       continue
                                }
                        } else if err == syscall.ECONNABORTED {
                                // This means that a socket on the listen queue was closed
index 5e1d2e05c81e4a43e50348d79e868a553a8d9c59..fd1385ef9342d65e3bbb1a71f04b10614cb636f2 100644 (file)
@@ -7,6 +7,8 @@
 package net
 
 import (
+       "io"
+       "syscall"
        "testing"
 )
 
@@ -57,3 +59,48 @@ func TestAddFDReturnsError(t *testing.T) {
        }
        t.Error("unexpected error:", err)
 }
+
+var chkReadErrTests = []struct {
+       n        int
+       err      error
+       fd       *netFD
+       expected error
+}{
+
+       {100, nil, &netFD{sotype: syscall.SOCK_STREAM}, nil},
+       {100, io.EOF, &netFD{sotype: syscall.SOCK_STREAM}, io.EOF},
+       {100, errClosing, &netFD{sotype: syscall.SOCK_STREAM}, errClosing},
+       {0, nil, &netFD{sotype: syscall.SOCK_STREAM}, io.EOF},
+       {0, io.EOF, &netFD{sotype: syscall.SOCK_STREAM}, io.EOF},
+       {0, errClosing, &netFD{sotype: syscall.SOCK_STREAM}, errClosing},
+
+       {100, nil, &netFD{sotype: syscall.SOCK_DGRAM}, nil},
+       {100, io.EOF, &netFD{sotype: syscall.SOCK_DGRAM}, io.EOF},
+       {100, errClosing, &netFD{sotype: syscall.SOCK_DGRAM}, errClosing},
+       {0, nil, &netFD{sotype: syscall.SOCK_DGRAM}, nil},
+       {0, io.EOF, &netFD{sotype: syscall.SOCK_DGRAM}, io.EOF},
+       {0, errClosing, &netFD{sotype: syscall.SOCK_DGRAM}, errClosing},
+
+       {100, nil, &netFD{sotype: syscall.SOCK_SEQPACKET}, nil},
+       {100, io.EOF, &netFD{sotype: syscall.SOCK_SEQPACKET}, io.EOF},
+       {100, errClosing, &netFD{sotype: syscall.SOCK_SEQPACKET}, errClosing},
+       {0, nil, &netFD{sotype: syscall.SOCK_SEQPACKET}, io.EOF},
+       {0, io.EOF, &netFD{sotype: syscall.SOCK_SEQPACKET}, io.EOF},
+       {0, errClosing, &netFD{sotype: syscall.SOCK_SEQPACKET}, errClosing},
+
+       {100, nil, &netFD{sotype: syscall.SOCK_RAW}, nil},
+       {100, io.EOF, &netFD{sotype: syscall.SOCK_RAW}, io.EOF},
+       {100, errClosing, &netFD{sotype: syscall.SOCK_RAW}, errClosing},
+       {0, nil, &netFD{sotype: syscall.SOCK_RAW}, nil},
+       {0, io.EOF, &netFD{sotype: syscall.SOCK_RAW}, io.EOF},
+       {0, errClosing, &netFD{sotype: syscall.SOCK_RAW}, errClosing},
+}
+
+func TestChkReadErr(t *testing.T) {
+       for _, tt := range chkReadErrTests {
+               actual := chkReadErr(tt.n, tt.err, tt.fd)
+               if actual != tt.expected {
+                       t.Errorf("chkReadError(%v, %v, %v): expected %v, actual %v", tt.n, tt.err, tt.fd.sotype, tt.expected, actual)
+               }
+       }
+}