]> Cypherpunks repositories - gostls13.git/commitdiff
net: implement ReadMsg/WriteMsg on windows
authorAman Gupta <aman@tmm1.net>
Tue, 7 Nov 2017 18:19:10 +0000 (10:19 -0800)
committerAlex Brainman <alex.brainman@gmail.com>
Fri, 10 Nov 2017 05:55:10 +0000 (05:55 +0000)
This means {Read,Write}Msg{UDP,IP} now work on windows.

Fixes #9252

Change-Id: Ifb105f9ad18d61289b22d7358a95faabe73d2d02
Reviewed-on: https://go-review.googlesource.com/76393
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Alex Brainman <alex.brainman@gmail.com>
src/go/build/deps_test.go
src/internal/poll/fd_windows.go
src/internal/syscall/windows/syscall_windows.go
src/net/fd_windows.go
src/net/iprawsock.go
src/net/platform_test.go
src/net/protoconn_test.go
src/net/udpsock.go
src/net/udpsock_test.go

index 0048469ef4c85c591bc53148483d6652f318d2de..66d4157d636f30fc343a1899563f8f14a32842bc 100644 (file)
@@ -154,7 +154,7 @@ var pkgDeps = map[string][]string{
                "syscall",
        },
 
-       "internal/poll": {"L0", "internal/race", "syscall", "time", "unicode/utf16", "unicode/utf8"},
+       "internal/poll": {"L0", "internal/race", "syscall", "time", "unicode/utf16", "unicode/utf8", "internal/syscall/windows"},
        "os":            {"L1", "os", "syscall", "time", "internal/poll", "internal/syscall/windows"},
        "path/filepath": {"L2", "os", "syscall", "internal/syscall/windows"},
        "io/ioutil":     {"L2", "os", "path/filepath", "time"},
index 67a4c506f52135ec2b8203c1eba848fca1de47ea..187908bc8341218fdb747912a243cc7732ba7ff7 100644 (file)
@@ -7,6 +7,7 @@ package poll
 import (
        "errors"
        "internal/race"
+       "internal/syscall/windows"
        "io"
        "runtime"
        "sync"
@@ -92,6 +93,7 @@ type operation struct {
        fd     *FD
        errc   chan error
        buf    syscall.WSABuf
+       msg    windows.WSAMsg
        sa     syscall.Sockaddr
        rsa    *syscall.RawSockaddrAny
        rsan   int32
@@ -132,6 +134,22 @@ func (o *operation) ClearBufs() {
        o.bufs = o.bufs[:0]
 }
 
+func (o *operation) InitMsg(p []byte, oob []byte) {
+       o.InitBuf(p)
+       o.msg.Buffers = &o.buf
+       o.msg.BufferCount = 1
+
+       o.msg.Name = nil
+       o.msg.Namelen = 0
+
+       o.msg.Flags = 0
+       o.msg.Control.Len = uint32(len(oob))
+       o.msg.Control.Buf = nil
+       if len(oob) != 0 {
+               o.msg.Control.Buf = &oob[0]
+       }
+}
+
 // ioSrv executes net IO requests.
 type ioSrv struct {
        req chan ioSrvReq
@@ -898,3 +916,77 @@ func (fd *FD) RawRead(f func(uintptr) bool) error {
 func (fd *FD) RawWrite(f func(uintptr) bool) error {
        return errors.New("not implemented")
 }
+
+func sockaddrToRaw(sa syscall.Sockaddr) (unsafe.Pointer, int32, error) {
+       switch sa := sa.(type) {
+       case *syscall.SockaddrInet4:
+               var raw syscall.RawSockaddrInet4
+               raw.Family = syscall.AF_INET
+               p := (*[2]byte)(unsafe.Pointer(&raw.Port))
+               p[0] = byte(sa.Port >> 8)
+               p[1] = byte(sa.Port)
+               for i := 0; i < len(sa.Addr); i++ {
+                       raw.Addr[i] = sa.Addr[i]
+               }
+               return unsafe.Pointer(&raw), int32(unsafe.Sizeof(raw)), nil
+       case *syscall.SockaddrInet6:
+               var raw syscall.RawSockaddrInet6
+               raw.Family = syscall.AF_INET6
+               p := (*[2]byte)(unsafe.Pointer(&raw.Port))
+               p[0] = byte(sa.Port >> 8)
+               p[1] = byte(sa.Port)
+               raw.Scope_id = sa.ZoneId
+               for i := 0; i < len(sa.Addr); i++ {
+                       raw.Addr[i] = sa.Addr[i]
+               }
+               return unsafe.Pointer(&raw), int32(unsafe.Sizeof(raw)), nil
+       default:
+               return nil, 0, syscall.EWINDOWS
+       }
+}
+
+// ReadMsg wraps the WSARecvMsg network call.
+func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, error) {
+       if err := fd.readLock(); err != nil {
+               return 0, 0, 0, nil, err
+       }
+       defer fd.readUnlock()
+
+       o := &fd.rop
+       o.InitMsg(p, oob)
+       o.rsa = new(syscall.RawSockaddrAny)
+       o.msg.Name = o.rsa
+       o.msg.Namelen = int32(unsafe.Sizeof(*o.rsa))
+       n, err := rsrv.ExecIO(o, func(o *operation) error {
+               return windows.WSARecvMsg(o.fd.Sysfd, &o.msg, &o.qty, &o.o, nil)
+       })
+       err = fd.eofError(n, err)
+       var sa syscall.Sockaddr
+       if err == nil {
+               sa, err = o.rsa.Sockaddr()
+       }
+       return n, int(o.msg.Control.Len), int(o.msg.Flags), sa, err
+}
+
+// WriteMsg wraps the WSASendMsg network call.
+func (fd *FD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (int, int, error) {
+       if err := fd.writeLock(); err != nil {
+               return 0, 0, err
+       }
+       defer fd.writeUnlock()
+
+       o := &fd.wop
+       o.InitMsg(p, oob)
+       if sa != nil {
+               rsa, len, err := sockaddrToRaw(sa)
+               if err != nil {
+                       return 0, 0, err
+               }
+               o.msg.Name = (*syscall.RawSockaddrAny)(rsa)
+               o.msg.Namelen = len
+       }
+       n, err := wsrv.ExecIO(o, func(o *operation) error {
+               return windows.WSASendMsg(o.fd.Sysfd, &o.msg, 0, &o.qty, &o.o, nil)
+       })
+       return n, int(o.msg.Control.Len), err
+}
index 3c14691e1dd1adf8d734dd5f15cdc433e7242d5d..b531f89b62a285142cfda925d6cd7a1f3bf3ae63 100644 (file)
@@ -4,7 +4,11 @@
 
 package windows
 
-import "syscall"
+import (
+       "sync"
+       "syscall"
+       "unsafe"
+)
 
 const (
        ERROR_SHARING_VIOLATION      syscall.Errno = 32
@@ -115,10 +119,109 @@ const (
 const (
        WSA_FLAG_OVERLAPPED        = 0x01
        WSA_FLAG_NO_HANDLE_INHERIT = 0x80
+
+       WSAEMSGSIZE syscall.Errno = 10040
+
+       MSG_TRUNC  = 0x0100
+       MSG_CTRUNC = 0x0200
+
+       socket_error = uintptr(^uint32(0))
 )
 
+var WSAID_WSASENDMSG = syscall.GUID{
+       Data1: 0xa441e712,
+       Data2: 0x754f,
+       Data3: 0x43ca,
+       Data4: [8]byte{0x84, 0xa7, 0x0d, 0xee, 0x44, 0xcf, 0x60, 0x6d},
+}
+
+var WSAID_WSARECVMSG = syscall.GUID{
+       Data1: 0xf689d7c8,
+       Data2: 0x6f1f,
+       Data3: 0x436b,
+       Data4: [8]byte{0x8a, 0x53, 0xe5, 0x4f, 0xe3, 0x51, 0xc3, 0x22},
+}
+
+var sendRecvMsgFunc struct {
+       once     sync.Once
+       sendAddr uintptr
+       recvAddr uintptr
+       err      error
+}
+
+type WSAMsg struct {
+       Name        *syscall.RawSockaddrAny
+       Namelen     int32
+       Buffers     *syscall.WSABuf
+       BufferCount uint32
+       Control     syscall.WSABuf
+       Flags       uint32
+}
+
 //sys  WSASocket(af int32, typ int32, protocol int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = ws2_32.WSASocketW
 
+func loadWSASendRecvMsg() error {
+       sendRecvMsgFunc.once.Do(func() {
+               var s syscall.Handle
+               s, sendRecvMsgFunc.err = syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
+               if sendRecvMsgFunc.err != nil {
+                       return
+               }
+               defer syscall.CloseHandle(s)
+               var n uint32
+               sendRecvMsgFunc.err = syscall.WSAIoctl(s,
+                       syscall.SIO_GET_EXTENSION_FUNCTION_POINTER,
+                       (*byte)(unsafe.Pointer(&WSAID_WSARECVMSG)),
+                       uint32(unsafe.Sizeof(WSAID_WSARECVMSG)),
+                       (*byte)(unsafe.Pointer(&sendRecvMsgFunc.recvAddr)),
+                       uint32(unsafe.Sizeof(sendRecvMsgFunc.recvAddr)),
+                       &n, nil, 0)
+               if sendRecvMsgFunc.err != nil {
+                       return
+               }
+               sendRecvMsgFunc.err = syscall.WSAIoctl(s,
+                       syscall.SIO_GET_EXTENSION_FUNCTION_POINTER,
+                       (*byte)(unsafe.Pointer(&WSAID_WSASENDMSG)),
+                       uint32(unsafe.Sizeof(WSAID_WSASENDMSG)),
+                       (*byte)(unsafe.Pointer(&sendRecvMsgFunc.sendAddr)),
+                       uint32(unsafe.Sizeof(sendRecvMsgFunc.sendAddr)),
+                       &n, nil, 0)
+       })
+       return sendRecvMsgFunc.err
+}
+
+func WSASendMsg(fd syscall.Handle, msg *WSAMsg, flags uint32, bytesSent *uint32, overlapped *syscall.Overlapped, croutine *byte) error {
+       err := loadWSASendRecvMsg()
+       if err != nil {
+               return err
+       }
+       r1, _, e1 := syscall.Syscall6(sendRecvMsgFunc.sendAddr, 6, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(flags), uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(overlapped)), uintptr(unsafe.Pointer(croutine)))
+       if r1 == socket_error {
+               if e1 != 0 {
+                       err = errnoErr(e1)
+               } else {
+                       err = syscall.EINVAL
+               }
+       }
+       return err
+}
+
+func WSARecvMsg(fd syscall.Handle, msg *WSAMsg, bytesReceived *uint32, overlapped *syscall.Overlapped, croutine *byte) error {
+       err := loadWSASendRecvMsg()
+       if err != nil {
+               return err
+       }
+       r1, _, e1 := syscall.Syscall6(sendRecvMsgFunc.recvAddr, 5, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(unsafe.Pointer(bytesReceived)), uintptr(unsafe.Pointer(overlapped)), uintptr(unsafe.Pointer(croutine)), 0)
+       if r1 == socket_error {
+               if e1 != 0 {
+                       err = errnoErr(e1)
+               } else {
+                       err = syscall.EINVAL
+               }
+       }
+       return err
+}
+
 const (
        ComputerNameNetBIOS                   = 0
        ComputerNameDnsHostname               = 1
index 563558dc52e9905702795f4ed6945954ae0727ba..e5f8da156a24c5bb8c79a1ab98eb42bc99330391 100644 (file)
@@ -223,17 +223,21 @@ func (fd *netFD) accept() (*netFD, error) {
        return netfd, nil
 }
 
+func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
+       n, oobn, flags, sa, err = fd.pfd.ReadMsg(p, oob)
+       runtime.KeepAlive(fd)
+       return n, oobn, flags, sa, wrapSyscallError("wsarecvmsg", err)
+}
+
+func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
+       n, oobn, err = fd.pfd.WriteMsg(p, oob, sa)
+       runtime.KeepAlive(fd)
+       return n, oobn, wrapSyscallError("wsasendmsg", err)
+}
+
 // Unimplemented functions.
 
 func (fd *netFD) dup() (*os.File, error) {
        // TODO: Implement this
        return nil, syscall.EWINDOWS
 }
-
-func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
-       return 0, 0, 0, nil, syscall.EWINDOWS
-}
-
-func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
-       return 0, 0, syscall.EWINDOWS
-}
index c4b54f00c4e6e39fb476c30e2f5775d9afd73c90..72cbc3943376f820b1055dcade5baef39280652d 100644 (file)
@@ -21,7 +21,7 @@ import (
 // change the behavior of these methods; use Read or ReadMsgIP
 // instead.
 
-// BUG(mikio): On NaCl, Plan 9 and Windows, the ReadMsgIP and
+// BUG(mikio): On NaCl and Plan 9, the ReadMsgIP and
 // WriteMsgIP methods of IPConn are not implemented.
 
 // BUG(mikio): On Windows, the File method of IPConn is not
index 5841ca35a00db4b85dcf01507f2da32a53272945..8b2b7c264bbc0f02023285338da6bea5549ccb87 100644 (file)
@@ -149,12 +149,18 @@ func testableListenArgs(network, address, client string) bool {
        return true
 }
 
-var condFatalf = func() func(*testing.T, string, ...interface{}) {
-       // A few APIs, File, Read/WriteMsg{UDP,IP}, are not
-       // implemented yet on both Plan 9 and Windows.
+func condFatalf(t *testing.T, api string, format string, args ...interface{}) {
+       // A few APIs like File and Read/WriteMsg{UDP,IP} are not
+       // fully implemented yet on Plan 9 and Windows.
        switch runtime.GOOS {
-       case "plan9", "windows":
-               return (*testing.T).Logf
+       case "windows":
+               if api == "file" {
+                       t.Logf(format, args...)
+                       return
+               }
+       case "plan9":
+               t.Logf(format, args...)
+               return
        }
-       return (*testing.T).Fatalf
-}()
+       t.Fatalf(format, args...)
+}
index 23589d3ca871b6a5001736471ef8b09068e7a861..d89c463011d6cc2262eb60734797c3464acf7166 100644 (file)
@@ -54,7 +54,7 @@ func TestTCPListenerSpecificMethods(t *testing.T) {
        }
 
        if f, err := ln.File(); err != nil {
-               condFatalf(t, "%v", err)
+               condFatalf(t, "file", "%v", err)
        } else {
                f.Close()
        }
@@ -139,14 +139,14 @@ func TestUDPConnSpecificMethods(t *testing.T) {
                t.Fatal(err)
        }
        if _, _, err := c.WriteMsgUDP(wb, nil, c.LocalAddr().(*UDPAddr)); err != nil {
-               condFatalf(t, "%v", err)
+               condFatalf(t, "udp", "%v", err)
        }
        if _, _, _, _, err := c.ReadMsgUDP(rb, nil); err != nil {
-               condFatalf(t, "%v", err)
+               condFatalf(t, "udp", "%v", err)
        }
 
        if f, err := c.File(); err != nil {
-               condFatalf(t, "%v", err)
+               condFatalf(t, "file", "%v", err)
        } else {
                f.Close()
        }
@@ -184,7 +184,7 @@ func TestIPConnSpecificMethods(t *testing.T) {
        c.SetWriteBuffer(2048)
 
        if f, err := c.File(); err != nil {
-               condFatalf(t, "%v", err)
+               condFatalf(t, "file", "%v", err)
        } else {
                f.Close()
        }
index 2c0f74fdabdd293599463cc9c494d8e1bc586d7c..158265f06f850523a8741f84fd9717fc8af33543 100644 (file)
@@ -9,7 +9,7 @@ import (
        "syscall"
 )
 
-// BUG(mikio): On NaCl, Plan 9 and Windows, the ReadMsgUDP and
+// BUG(mikio): On NaCl and Plan 9, the ReadMsgUDP and
 // WriteMsgUDP methods of UDPConn are not implemented.
 
 // BUG(mikio): On Windows, the File method of UDPConn is not
index 6d4974e3e49b93a729df649fdff404b2f83bae29..4ae014c01d9b69539fd12fe1bdb7363384130130 100644 (file)
@@ -161,7 +161,7 @@ func testWriteToConn(t *testing.T, raddr string) {
        }
        _, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, nil)
        switch runtime.GOOS {
-       case "nacl", "windows": // see golang.org/issue/9252
+       case "nacl": // see golang.org/issue/9252
                t.Skipf("not implemented yet on %s", runtime.GOOS)
        default:
                if err != nil {
@@ -204,7 +204,7 @@ func testWriteToPacketConn(t *testing.T, raddr string) {
        }
        _, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, ra)
        switch runtime.GOOS {
-       case "nacl", "windows": // see golang.org/issue/9252
+       case "nacl": // see golang.org/issue/9252
                t.Skipf("not implemented yet on %s", runtime.GOOS)
        default:
                if err != nil {