]> Cypherpunks repositories - gostls13.git/commitdiff
net: implement windows timeout
authorWei Guangjing <vcc.163@gmail.com>
Wed, 19 Jan 2011 19:49:25 +0000 (14:49 -0500)
committerRuss Cox <rsc@golang.org>
Wed, 19 Jan 2011 19:49:25 +0000 (14:49 -0500)
R=brainman, rsc
CC=golang-dev
https://golang.org/cl/1731047

src/pkg/net/fd_windows.go
src/pkg/net/timeout_test.go
src/pkg/syscall/syscall_windows.go
src/pkg/syscall/zsyscall_windows_386.go

index 72685d612a17e27d9bafbd3961f3072aa8c75662..f3e5761c874049e2dd054256765603f61fd01f10 100644 (file)
@@ -6,14 +6,14 @@ package net
 
 import (
        "os"
+       "runtime"
        "sync"
        "syscall"
+       "time"
        "unsafe"
        "runtime"
 )
 
-// BUG(brainman): The Windows implementation does not implement SetTimeout.
-
 // IO completion result parameters.
 type ioResult struct {
        key   uint32
@@ -79,6 +79,8 @@ type ioPacket struct {
 
        // Link to the io owner.
        c chan *ioResult
+
+       w *syscall.WSABuf
 }
 
 func (s *pollServer) getCompletedIO() (ov *syscall.Overlapped, result *ioResult, err os.Error) {
@@ -126,6 +128,8 @@ func startServer() {
                panic("Start pollServer: " + err.String() + "\n")
        }
        pollserver = p
+
+       go timeoutIO()
 }
 
 var initErr os.Error
@@ -143,8 +147,8 @@ func newFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD, err
                sysfd:  fd,
                family: family,
                proto:  proto,
-               cr:     make(chan *ioResult),
-               cw:     make(chan *ioResult),
+               cr:     make(chan *ioResult, 1),
+               cw:     make(chan *ioResult, 1),
                net:    net,
                laddr:  laddr,
                raddr:  raddr,
@@ -199,6 +203,80 @@ func newWSABuf(p []byte) *syscall.WSABuf {
        return &syscall.WSABuf{uint32(len(p)), p0}
 }
 
+func waitPacket(fd *netFD, pckt *ioPacket, mode int) (r *ioResult) {
+       var delta int64
+       if mode == 'r' {
+               delta = fd.rdeadline_delta
+       }
+       if mode == 'w' {
+               delta = fd.wdeadline_delta
+       }
+       if delta <= 0 {
+               return <-pckt.c
+       }
+
+       select {
+       case r = <-pckt.c:
+       case <-time.After(delta):
+               a := &arg{f: cancel, fd: fd, pckt: pckt, c: make(chan int)}
+               ioChan <- a
+               <-a.c
+               r = <-pckt.c
+               if r.errno == 995 { // IO Canceled
+                       r.errno = syscall.EWOULDBLOCK
+               }
+       }
+       return r
+}
+
+const (
+       read = iota
+       readfrom
+       write
+       writeto
+       cancel
+)
+
+type arg struct {
+       f     int
+       fd    *netFD
+       pckt  *ioPacket
+       done  *uint32
+       flags *uint32
+       rsa   *syscall.RawSockaddrAny
+       size  *int32
+       sa    *syscall.Sockaddr
+       c     chan int
+}
+
+var ioChan chan *arg = make(chan *arg)
+
+func timeoutIO() {
+       // CancelIO only cancels all pending input and output (I/O) operations that are
+       // issued by the calling thread for the specified file, does not cancel I/O
+       // operations that other threads issue for a file handle. So we need do all timeout
+       // I/O in single OS thread.
+       runtime.LockOSThread()
+       defer runtime.UnlockOSThread()
+       for {
+               o := <-ioChan
+               var e int
+               switch o.f {
+               case read:
+                       e = syscall.WSARecv(uint32(o.fd.sysfd), o.pckt.w, 1, o.done, o.flags, &o.pckt.o, nil)
+               case readfrom:
+                       e = syscall.WSARecvFrom(uint32(o.fd.sysfd), o.pckt.w, 1, o.done, o.flags, o.rsa, o.size, &o.pckt.o, nil)
+               case write:
+                       e = syscall.WSASend(uint32(o.fd.sysfd), o.pckt.w, 1, o.done, uint32(0), &o.pckt.o, nil)
+               case writeto:
+                       e = syscall.WSASendto(uint32(o.fd.sysfd), o.pckt.w, 1, o.done, 0, *o.sa, &o.pckt.o, nil)
+               case cancel:
+                       _, e = syscall.CancelIo(uint32(o.fd.sysfd))
+               }
+               o.c <- e
+       }
+}
+
 func (fd *netFD) Read(p []byte) (n int, err os.Error) {
        if fd == nil {
                return 0, os.EINVAL
@@ -213,9 +291,17 @@ func (fd *netFD) Read(p []byte) (n int, err os.Error) {
        // Submit receive request.
        var pckt ioPacket
        pckt.c = fd.cr
+       pckt.w = newWSABuf(p)
        var done uint32
        flags := uint32(0)
-       e := syscall.WSARecv(uint32(fd.sysfd), newWSABuf(p), 1, &done, &flags, &pckt.o, nil)
+       var e int
+       if fd.rdeadline_delta > 0 {
+               a := &arg{f: read, fd: fd, pckt: &pckt, done: &done, flags: &flags, c: make(chan int)}
+               ioChan <- a
+               e = <-a.c
+       } else {
+               e = syscall.WSARecv(uint32(fd.sysfd), pckt.w, 1, &done, &flags, &pckt.o, nil)
+       }
        switch e {
        case 0:
                // IO completed immediately, but we need to get our completion message anyway.
@@ -225,7 +311,7 @@ func (fd *netFD) Read(p []byte) (n int, err os.Error) {
                return 0, &OpError{"WSARecv", fd.net, fd.laddr, os.Errno(e)}
        }
        // Wait for our request to complete.
-       r := <-pckt.c
+       r := waitPacket(fd, &pckt, 'r')
        if r.errno != 0 {
                err = &OpError{"WSARecv", fd.net, fd.laddr, os.Errno(r.errno)}
        }
@@ -253,11 +339,19 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) {
        // Submit receive request.
        var pckt ioPacket
        pckt.c = fd.cr
+       pckt.w = newWSABuf(p)
        var done uint32
        flags := uint32(0)
        var rsa syscall.RawSockaddrAny
        l := int32(unsafe.Sizeof(rsa))
-       e := syscall.WSARecvFrom(uint32(fd.sysfd), newWSABuf(p), 1, &done, &flags, &rsa, &l, &pckt.o, nil)
+       var e int
+       if fd.rdeadline_delta > 0 {
+               a := &arg{f: readfrom, fd: fd, pckt: &pckt, done: &done, flags: &flags, rsa: &rsa, size: &l, c: make(chan int)}
+               ioChan <- a
+               e = <-a.c
+       } else {
+               e = syscall.WSARecvFrom(uint32(fd.sysfd), pckt.w, 1, &done, &flags, &rsa, &l, &pckt.o, nil)
+       }
        switch e {
        case 0:
                // IO completed immediately, but we need to get our completion message anyway.
@@ -267,7 +361,7 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) {
                return 0, nil, &OpError{"WSARecvFrom", fd.net, fd.laddr, os.Errno(e)}
        }
        // Wait for our request to complete.
-       r := <-pckt.c
+       r := waitPacket(fd, &pckt, 'r')
        if r.errno != 0 {
                err = &OpError{"WSARecvFrom", fd.net, fd.laddr, os.Errno(r.errno)}
        }
@@ -290,8 +384,16 @@ func (fd *netFD) Write(p []byte) (n int, err os.Error) {
        // Submit send request.
        var pckt ioPacket
        pckt.c = fd.cw
+       pckt.w = newWSABuf(p)
        var done uint32
-       e := syscall.WSASend(uint32(fd.sysfd), newWSABuf(p), 1, &done, uint32(0), &pckt.o, nil)
+       var e int
+       if fd.wdeadline_delta > 0 {
+               a := &arg{f: write, fd: fd, pckt: &pckt, done: &done, c: make(chan int)}
+               ioChan <- a
+               e = <-a.c
+       } else {
+               e = syscall.WSASend(uint32(fd.sysfd), pckt.w, 1, &done, uint32(0), &pckt.o, nil)
+       }
        switch e {
        case 0:
                // IO completed immediately, but we need to get our completion message anyway.
@@ -301,7 +403,7 @@ func (fd *netFD) Write(p []byte) (n int, err os.Error) {
                return 0, &OpError{"WSASend", fd.net, fd.laddr, os.Errno(e)}
        }
        // Wait for our request to complete.
-       r := <-pckt.c
+       r := waitPacket(fd, &pckt, 'w')
        if r.errno != 0 {
                err = &OpError{"WSASend", fd.net, fd.laddr, os.Errno(r.errno)}
        }
@@ -326,8 +428,16 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) {
        // Submit send request.
        var pckt ioPacket
        pckt.c = fd.cw
+       pckt.w = newWSABuf(p)
        var done uint32
-       e := syscall.WSASendto(uint32(fd.sysfd), newWSABuf(p), 1, &done, 0, sa, &pckt.o, nil)
+       var e int
+       if fd.wdeadline_delta > 0 {
+               a := &arg{f: writeto, fd: fd, pckt: &pckt, done: &done, sa: &sa, c: make(chan int)}
+               ioChan <- a
+               e = <-a.c
+       } else {
+               e = syscall.WSASendto(uint32(fd.sysfd), pckt.w, 1, &done, 0, sa, &pckt.o, nil)
+       }
        switch e {
        case 0:
                // IO completed immediately, but we need to get our completion message anyway.
@@ -337,7 +447,7 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) {
                return 0, &OpError{"WSASendTo", fd.net, fd.laddr, os.Errno(e)}
        }
        // Wait for our request to complete.
-       r := <-pckt.c
+       r := waitPacket(fd, &pckt, 'w')
        if r.errno != 0 {
                err = &OpError{"WSASendTo", fd.net, fd.laddr, os.Errno(r.errno)}
        }
@@ -410,8 +520,8 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.
                sysfd:  s,
                family: fd.family,
                proto:  fd.proto,
-               cr:     make(chan *ioResult),
-               cw:     make(chan *ioResult),
+               cr:     make(chan *ioResult, 1),
+               cw:     make(chan *ioResult, 1),
                net:    fd.net,
                laddr:  laddr,
                raddr:  raddr,
index 092781685e1bb2706c099047d13f0240a0029b35..3594c0a350fff159f95fc9cb3d059a6064cd703d 100644 (file)
@@ -8,14 +8,9 @@ import (
        "os"
        "testing"
        "time"
-       "runtime"
 )
 
 func testTimeout(t *testing.T, network, addr string, readFrom bool) {
-       // Timeouts are not implemented on windows.
-       if runtime.GOOS == "windows" {
-               return
-       }
        fd, err := Dial(network, "", addr)
        if err != nil {
                t.Errorf("dial %s %s failed: %v", network, addr, err)
index 06dde518fdc5e6f9e91cb868967a64f976ac43ec..d3d22dba809ba1082fa70762f6cd0211ff51e4a4 100644 (file)
@@ -126,6 +126,7 @@ func getSysProcAddr(m uint32, pname string) uintptr {
 //sys  GetTimeZoneInformation(tzi *Timezoneinformation) (rc uint32, errno int) [failretval==0xffffffff]
 //sys  CreateIoCompletionPort(filehandle int32, cphandle int32, key uint32, threadcnt uint32) (handle int32, errno int)
 //sys  GetQueuedCompletionStatus(cphandle int32, qty *uint32, key *uint32, overlapped **Overlapped, timeout uint32) (ok bool, errno int)
+//sys  CancelIo(s uint32) (ok bool, errno int)
 //sys  CreateProcess(appName *int16, commandLine *uint16, procSecurity *int16, threadSecurity *int16, inheritHandles bool, creationFlags uint32, env *uint16, currentDir *uint16, startupInfo *StartupInfo, outProcInfo *ProcessInformation)  (ok bool, errno int) = CreateProcessW
 //sys  GetStartupInfo(startupInfo *StartupInfo)  (ok bool, errno int) = GetStartupInfoW
 //sys  GetCurrentProcess() (pseudoHandle int32, errno int)
index 29880f2b28bb644143030d8fdb087493dafa14b3..18e36a0226d18cbfce80f347ee9b2413d460e1d2 100644 (file)
@@ -43,6 +43,7 @@ var (
        procGetTimeZoneInformation     = getSysProcAddr(modkernel32, "GetTimeZoneInformation")
        procCreateIoCompletionPort     = getSysProcAddr(modkernel32, "CreateIoCompletionPort")
        procGetQueuedCompletionStatus  = getSysProcAddr(modkernel32, "GetQueuedCompletionStatus")
+       procCancelIo                   = getSysProcAddr(modkernel32, "CancelIo")
        procCreateProcessW             = getSysProcAddr(modkernel32, "CreateProcessW")
        procGetStartupInfoW            = getSysProcAddr(modkernel32, "GetStartupInfoW")
        procGetCurrentProcess          = getSysProcAddr(modkernel32, "GetCurrentProcess")
@@ -512,6 +513,21 @@ func GetQueuedCompletionStatus(cphandle int32, qty *uint32, key *uint32, overlap
        return
 }
 
+func CancelIo(s uint32) (ok bool, errno int) {
+       r0, _, e1 := Syscall(procCancelIo, uintptr(s), 0, 0)
+       ok = bool(r0 != 0)
+       if !ok {
+               if e1 != 0 {
+                       errno = int(e1)
+               } else {
+                       errno = EINVAL
+               }
+       } else {
+               errno = 0
+       }
+       return
+}
+
 func CreateProcess(appName *int16, commandLine *uint16, procSecurity *int16, threadSecurity *int16, inheritHandles bool, creationFlags uint32, env *uint16, currentDir *uint16, startupInfo *StartupInfo, outProcInfo *ProcessInformation) (ok bool, errno int) {
        var _p0 uint32
        if inheritHandles {