]> Cypherpunks repositories - gostls13.git/commitdiff
net: fix netFD.Close races
authorDevon H. O'Dell <devon.odell@gmail.com>
Wed, 2 Dec 2009 07:28:57 +0000 (23:28 -0800)
committerRuss Cox <rsc@golang.org>
Wed, 2 Dec 2009 07:28:57 +0000 (23:28 -0800)
Fixes #271.
Fixes #321.

R=rsc, agl, cw
CC=golang-dev
https://golang.org/cl/163052

src/pkg/net/fd.go
src/pkg/net/sock.go
src/pkg/net/tcpsock.go
src/pkg/net/udpsock.go
src/pkg/net/unixsock.go

index f134b0f78ab46e5f04155fa4d7957e3dc6701464..e1592eb2691f3faed858ae00474ad667e1715423 100644 (file)
@@ -15,14 +15,18 @@ import (
 
 // Network file descriptor.
 type netFD struct {
+       // locking/lifetime of sysfd
+       sysmu   sync.Mutex;
+       sysref  int;
+       closing bool;
+
        // immutable until Close
-       fd      int;
+       sysfd   int;
        family  int;
        proto   int;
-       file    *os.File;
+       sysfile *os.File;
        cr      chan *netFD;
        cw      chan *netFD;
-       cc      chan *netFD;
        net     string;
        laddr   Addr;
        raddr   Addr;
@@ -68,13 +72,13 @@ type netFD struct {
 // channel will be empty for the next process's request.  A larger buffer
 // might help batch requests.
 //
-// In order to prevent race conditions, pollServer has an additional cc channel
-// that receives fds to be closed. pollServer doesn't make the close system
-// call, it just sets fd.file = nil and fd.fd = -1. Because of this, pollServer
-// is always in sync with the kernel's view of a given descriptor.
+// To avoid races in closing, all fd operations are locked and
+// refcounted. when netFD.Close() is called, it calls syscall.Shutdown
+// and sets a closing flag. Only when the last reference is removed
+// will the fd be closed.
 
 type pollServer struct {
-       cr, cw, cc      chan *netFD;    // buffered >= 1
+       cr, cw          chan *netFD;    // buffered >= 1
        pr, pw          *os.File;
        pending         map[int]*netFD;
        poll            *pollster;      // low-level OS hooks
@@ -85,7 +89,6 @@ func newPollServer() (s *pollServer, err os.Error) {
        s = new(pollServer);
        s.cr = make(chan *netFD, 1);
        s.cw = make(chan *netFD, 1);
-       s.cc = make(chan *netFD, 1);
        if s.pr, s.pw, err = os.Pipe(); err != nil {
                return nil, err
        }
@@ -114,16 +117,7 @@ func newPollServer() (s *pollServer, err os.Error) {
 }
 
 func (s *pollServer) AddFD(fd *netFD, mode int) {
-       // This check verifies that the underlying file descriptor hasn't been
-       // closed in the mean time. Any time a netFD is closed, the closing
-       // goroutine makes a round trip to the pollServer which sets file = nil
-       // and fd = -1. The goroutine then closes the actual file descriptor.
-       // Thus fd.fd mirrors the kernel's view of the file descriptor.
-
-       // TODO(rsc,agl): There is still a race in Read and Write,
-       // because they optimistically try to use the fd and don't
-       // call into the PollServer unless they get EAGAIN.
-       intfd := fd.fd;
+       intfd := fd.sysfd;
        if intfd < 0 {
                // fd closed underfoot
                if mode == 'r' {
@@ -213,10 +207,10 @@ func (s *pollServer) CheckDeadlines() {
                        if t <= now {
                                s.pending[key] = nil, false;
                                if mode == 'r' {
-                                       s.poll.DelFD(fd.fd, mode);
+                                       s.poll.DelFD(fd.sysfd, mode);
                                        fd.rdeadline = -1;
                                } else {
-                                       s.poll.DelFD(fd.fd, mode);
+                                       s.poll.DelFD(fd.sysfd, mode);
                                        fd.wdeadline = -1;
                                }
                                s.WakeFD(fd, mode);
@@ -254,7 +248,6 @@ func (s *pollServer) Run() {
                        for nn, _ := s.pr.Read(&scratch); nn > 0; {
                                nn, _ = s.pr.Read(&scratch)
                        }
-
                        // Read from channels
                        for fd, ok := <-s.cr; ok; fd, ok = <-s.cr {
                                s.AddFD(fd, 'r')
@@ -262,11 +255,6 @@ func (s *pollServer) Run() {
                        for fd, ok := <-s.cw; ok; fd, ok = <-s.cw {
                                s.AddFD(fd, 'w')
                        }
-                       for fd, ok := <-s.cc; ok; fd, ok = <-s.cc {
-                               fd.file = nil;
-                               fd.fd = -1;
-                               fd.cc <- fd;
-                       }
                } else {
                        netfd := s.LookupFD(fd, mode);
                        if netfd == nil {
@@ -294,12 +282,6 @@ func (s *pollServer) WaitWrite(fd *netFD) {
        <-fd.cw;
 }
 
-func (s *pollServer) WaitCloseAck(fd *netFD) {
-       s.cc <- fd;
-       s.Wakeup();
-       <-fd.cc;
-}
-
 // Network FD methods.
 // All the network FDs use a single pollServer.
 
@@ -319,7 +301,7 @@ func newFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD, err
                return nil, &OpError{"setnonblock", net, laddr, os.Errno(e)}
        }
        f = &netFD{
-               fd: fd,
+               sysfd: fd,
                family: family,
                proto: proto,
                net: net,
@@ -333,13 +315,37 @@ func newFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD, err
        if raddr != nil {
                rs = raddr.String()
        }
-       f.file = os.NewFile(fd, net+":"+ls+"->"+rs);
+       f.sysfile = os.NewFile(fd, net+":"+ls+"->"+rs);
        f.cr = make(chan *netFD, 1);
        f.cw = make(chan *netFD, 1);
-       f.cc = make(chan *netFD, 1);
        return f, nil;
 }
 
+// Add a reference to this fd.
+func (fd *netFD) incref() {
+       fd.sysmu.Lock();
+       fd.sysref++;
+       fd.sysmu.Unlock();
+}
+
+// Remove a reference to this FD and close if we've been asked to do so (and
+// there are no references left.
+func (fd *netFD) decref() {
+       fd.sysmu.Lock();
+       fd.sysref--;
+       if fd.closing && fd.sysref == 0 && fd.sysfd >= 0 {
+               // In case the user has set linger, switch to blocking mode so
+               // the close blocks.  As long as this doesn't happen often, we
+               // can handle the extra OS processes.  Otherwise we'll need to
+               // use the pollserver for Close too.  Sigh.
+               syscall.SetNonblock(fd.sysfd, false);
+               fd.sysfile.Close();
+               fd.sysfile = nil;
+               fd.sysfd = -1;
+       }
+       fd.sysmu.Unlock();
+}
+
 func isEAGAIN(e os.Error) bool {
        if e1, ok := e.(*os.PathError); ok {
                return e1.Error == os.EAGAIN
@@ -348,36 +354,32 @@ func isEAGAIN(e os.Error) bool {
 }
 
 func (fd *netFD) Close() os.Error {
-       if fd == nil || fd.file == nil {
+       if fd == nil || fd.sysfile == nil {
                return os.EINVAL
        }
 
-       // In case the user has set linger,
-       // switch to blocking mode so the close blocks.
-       // As long as this doesn't happen often,
-       // we can handle the extra OS processes.
-       // Otherwise we'll need to use the pollserver
-       // for Close too.  Sigh.
-       syscall.SetNonblock(fd.file.Fd(), false);
-
-       f := fd.file;
-       pollserver.WaitCloseAck(fd);
-       return f.Close();
+       fd.incref();
+       syscall.Shutdown(fd.sysfd, syscall.SHUT_RDWR);
+       fd.closing = true;
+       fd.decref();
+       return nil;
 }
 
 func (fd *netFD) Read(p []byte) (n int, err os.Error) {
-       if fd == nil || fd.file == nil {
+       if fd == nil || fd.sysfile == nil {
                return 0, os.EINVAL
        }
        fd.rio.Lock();
        defer fd.rio.Unlock();
+       fd.incref();
+       defer fd.decref();
        if fd.rdeadline_delta > 0 {
                fd.rdeadline = pollserver.Now() + fd.rdeadline_delta
        } else {
                fd.rdeadline = 0
        }
        for {
-               n, err = fd.file.Read(p);
+               n, err = fd.sysfile.Read(p);
                if isEAGAIN(err) && fd.rdeadline >= 0 {
                        pollserver.WaitRead(fd);
                        continue;
@@ -388,11 +390,13 @@ func (fd *netFD) Read(p []byte) (n int, err os.Error) {
 }
 
 func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) {
-       if fd == nil || fd.file == nil {
+       if fd == nil || fd.sysfile == nil {
                return 0, nil, os.EINVAL
        }
        fd.rio.Lock();
        defer fd.rio.Unlock();
+       fd.incref();
+       defer fd.decref();
        if fd.rdeadline_delta > 0 {
                fd.rdeadline = pollserver.Now() + fd.rdeadline_delta
        } else {
@@ -400,14 +404,14 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) {
        }
        for {
                var errno int;
-               n, sa, errno = syscall.Recvfrom(fd.fd, p, 0);
+               n, sa, errno = syscall.Recvfrom(fd.sysfd, p, 0);
                if errno == syscall.EAGAIN && fd.rdeadline >= 0 {
                        pollserver.WaitRead(fd);
                        continue;
                }
                if errno != 0 {
                        n = 0;
-                       err = &os.PathError{"recvfrom", fd.file.Name(), os.Errno(errno)};
+                       err = &os.PathError{"recvfrom", fd.sysfile.Name(), os.Errno(errno)};
                }
                break;
        }
@@ -415,11 +419,13 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) {
 }
 
 func (fd *netFD) Write(p []byte) (n int, err os.Error) {
-       if fd == nil || fd.file == nil {
+       if fd == nil || fd.sysfile == nil {
                return 0, os.EINVAL
        }
        fd.wio.Lock();
        defer fd.wio.Unlock();
+       fd.incref();
+       defer fd.decref();
        if fd.wdeadline_delta > 0 {
                fd.wdeadline = pollserver.Now() + fd.wdeadline_delta
        } else {
@@ -428,7 +434,7 @@ func (fd *netFD) Write(p []byte) (n int, err os.Error) {
        err = nil;
        nn := 0;
        for nn < len(p) {
-               n, err = fd.file.Write(p[nn:]);
+               n, err = fd.sysfile.Write(p[nn:]);
                if n > 0 {
                        nn += n
                }
@@ -447,11 +453,13 @@ func (fd *netFD) Write(p []byte) (n int, err os.Error) {
 }
 
 func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) {
-       if fd == nil || fd.file == nil {
+       if fd == nil || fd.sysfile == nil {
                return 0, os.EINVAL
        }
        fd.wio.Lock();
        defer fd.wio.Unlock();
+       fd.incref();
+       defer fd.decref();
        if fd.wdeadline_delta > 0 {
                fd.wdeadline = pollserver.Now() + fd.wdeadline_delta
        } else {
@@ -459,13 +467,13 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) {
        }
        err = nil;
        for {
-               errno := syscall.Sendto(fd.fd, p, 0, sa);
+               errno := syscall.Sendto(fd.sysfd, p, 0, sa);
                if errno == syscall.EAGAIN && fd.wdeadline >= 0 {
                        pollserver.WaitWrite(fd);
                        continue;
                }
                if errno != 0 {
-                       err = &os.PathError{"sendto", fd.file.Name(), os.Errno(errno)}
+                       err = &os.PathError{"sendto", fd.sysfile.Name(), os.Errno(errno)}
                }
                break;
        }
@@ -476,18 +484,21 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) {
 }
 
 func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.Error) {
-       if fd == nil || fd.file == nil {
+       if fd == nil || fd.sysfile == nil {
                return nil, os.EINVAL
        }
 
+       fd.incref();
+       defer fd.decref();
+
        // See ../syscall/exec.go for description of ForkLock.
        // It is okay to hold the lock across syscall.Accept
-       // because we have put fd.fd into non-blocking mode.
+       // because we have put fd.sysfd into non-blocking mode.
        syscall.ForkLock.RLock();
        var s, e int;
        var sa syscall.Sockaddr;
        for {
-               s, sa, e = syscall.Accept(fd.fd);
+               s, sa, e = syscall.Accept(fd.sysfd);
                if e != syscall.EAGAIN {
                        break
                }
index c670aa21e7e1b53dd46f47c48a595b8dc40f33ba..336c968664297cc5b3efd8d9bd39b1aeae2649a8 100644 (file)
@@ -75,11 +75,15 @@ func setsockoptNsec(fd, level, opt int, nsec int64) os.Error {
 }
 
 func setReadBuffer(fd *netFD, bytes int) os.Error {
-       return setsockoptInt(fd.fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes)
+       fd.incref();
+       defer fd.decref();
+       return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes);
 }
 
 func setWriteBuffer(fd *netFD, bytes int) os.Error {
-       return setsockoptInt(fd.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes)
+       fd.incref();
+       defer fd.decref();
+       return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes);
 }
 
 func setReadTimeout(fd *netFD, nsec int64) os.Error {
@@ -100,7 +104,9 @@ func setTimeout(fd *netFD, nsec int64) os.Error {
 }
 
 func setReuseAddr(fd *netFD, reuse bool) os.Error {
-       return setsockoptInt(fd.fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse))
+       fd.incref();
+       defer fd.decref();
+       return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse));
 }
 
 func bindToDevice(fd *netFD, dev string) os.Error {
@@ -109,11 +115,15 @@ func bindToDevice(fd *netFD, dev string) os.Error {
 }
 
 func setDontRoute(fd *netFD, dontroute bool) os.Error {
-       return setsockoptInt(fd.fd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute))
+       fd.incref();
+       defer fd.decref();
+       return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute));
 }
 
 func setKeepAlive(fd *netFD, keepalive bool) os.Error {
-       return setsockoptInt(fd.fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive))
+       fd.incref();
+       defer fd.decref();
+       return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive));
 }
 
 func setLinger(fd *netFD, sec int) os.Error {
@@ -125,7 +135,9 @@ func setLinger(fd *netFD, sec int) os.Error {
                l.Onoff = 0;
                l.Linger = 0;
        }
-       e := syscall.SetsockoptLinger(fd.fd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l);
+       fd.incref();
+       defer fd.decref();
+       e := syscall.SetsockoptLinger(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l);
        return os.NewSyscallError("setsockopt", e);
 }
 
index 2633196266c34c6eb0eb7660ecea32201b2c2415..680ed30213a1549f3abd23da51c4d753661713bf 100644 (file)
@@ -73,7 +73,7 @@ type TCPConn struct {
 
 func newTCPConn(fd *netFD) *TCPConn {
        c := &TCPConn{fd};
-       setsockoptInt(fd.fd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, 1);
+       setsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, 1);
        return c;
 }
 
@@ -234,9 +234,9 @@ func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err os.Error) {
        if err != nil {
                return nil, err
        }
-       errno := syscall.Listen(fd.fd, listenBacklog());
+       errno := syscall.Listen(fd.sysfd, listenBacklog());
        if errno != 0 {
-               syscall.Close(fd.fd);
+               syscall.Close(fd.sysfd);
                return nil, &OpError{"listen", "tcp", laddr, os.Errno(errno)};
        }
        l = new(TCPListener);
@@ -247,7 +247,7 @@ func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err os.Error) {
 // AcceptTCP accepts the next incoming call and returns the new connection
 // and the remote address.
 func (l *TCPListener) AcceptTCP() (c *TCPConn, err os.Error) {
-       if l == nil || l.fd == nil || l.fd.fd < 0 {
+       if l == nil || l.fd == nil || l.fd.sysfd < 0 {
                return nil, os.EINVAL
        }
        fd, err := l.fd.accept(sockaddrToTCP);
index a8b8ba3c922b2d0854f5e0ff76532c749f894e54..d74a380788379176124b11e1e52427d7e5a7a67e 100644 (file)
@@ -73,7 +73,7 @@ type UDPConn struct {
 
 func newUDPConn(fd *netFD) *UDPConn {
        c := &UDPConn{fd};
-       setsockoptInt(fd.fd, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1);
+       setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1);
        return c;
 }
 
index 6572e85dc713d55ed7e55cc5910934ff3655c054..4ac3be54aa9b9cd7426899341b742bb0cbbe211d 100644 (file)
@@ -332,9 +332,9 @@ func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err os.Error) {
                }
                return nil, e;
        }
-       e1 := syscall.Listen(fd.fd, 8); // listenBacklog());
+       e1 := syscall.Listen(fd.sysfd, 8);      // listenBacklog());
        if e1 != 0 {
-               syscall.Close(fd.fd);
+               syscall.Close(fd.sysfd);
                return nil, &OpError{"listen", "unix", laddr, os.Errno(e1)};
        }
        return &UnixListener{fd, laddr.Name}, nil;
@@ -343,7 +343,7 @@ func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err os.Error) {
 // AcceptUnix accepts the next incoming call and returns the new connection
 // and the remote address.
 func (l *UnixListener) AcceptUnix() (c *UnixConn, err os.Error) {
-       if l == nil || l.fd == nil || l.fd.fd < 0 {
+       if l == nil || l.fd == nil {
                return nil, os.EINVAL
        }
        fd, e := l.fd.accept(sockaddrToUnix);