]> Cypherpunks repositories - gostls13.git/commitdiff
runtime,internal/poll: move websocket handling out of the runtime on Windows
authorqmuntal <quimmuntal@gmail.com>
Fri, 19 Jan 2024 13:14:17 +0000 (14:14 +0100)
committerQuim Muntal <quimmuntal@gmail.com>
Tue, 23 Jan 2024 15:36:19 +0000 (15:36 +0000)
On Windows, the netpoll is currently coupled with the websocket usage
in the internal/poll package.

This CL moves the websocket handling out of the runtime and puts it into
the internal/poll package, which already contains most of the async I/O
logic for websockets.

This is a good refactor per se, as the Go runtime shouldn't know about
websockets. In addition, it will make it easier (in a future CL) to only
load ws2_32.dll when the Go program actually uses websockets.

Change-Id: Ic820872cf9bdbbf092505ed7f7504edb6687735e
Reviewed-on: https://go-review.googlesource.com/c/go/+/556936
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
src/internal/poll/fd_windows.go
src/internal/syscall/windows/syscall_windows.go
src/internal/syscall/windows/zsyscall_windows.go
src/runtime/netpoll_windows.go
src/runtime/os_windows.go

index 2095a6aa292700db883e7d973e76b9b330baf28d..fe23d4b3d74e3fe321184d7adce6a3c498ea8479 100644 (file)
@@ -71,8 +71,6 @@ type operation struct {
        // fields used by runtime.netpoll
        runtimeCtx uintptr
        mode       int32
-       errno      int32
-       qty        uint32
 
        // fields used only by net package
        fd     *FD
@@ -83,6 +81,7 @@ type operation struct {
        rsan   int32
        handle syscall.Handle
        flags  uint32
+       qty    uint32
        bufs   []syscall.WSABuf
 }
 
@@ -174,9 +173,9 @@ func execIO(o *operation, submit func(o *operation) error) (int, error) {
        // Wait for our request to complete.
        err = fd.pd.wait(int(o.mode), fd.isFile)
        if err == nil {
+               err = windows.WSAGetOverlappedResult(fd.Sysfd, &o.o, &o.qty, false, &o.flags)
                // All is good. Extract our IO results and return.
-               if o.errno != 0 {
-                       err = syscall.Errno(o.errno)
+               if err != nil {
                        // More data available. Return back the size of received data.
                        if err == syscall.ERROR_MORE_DATA || err == windows.WSAEMSGSIZE {
                                return int(o.qty), err
@@ -202,8 +201,8 @@ func execIO(o *operation, submit func(o *operation) error) (int, error) {
        }
        // Wait for cancellation to complete.
        fd.pd.waitCanceled(int(o.mode))
-       if o.errno != 0 {
-               err = syscall.Errno(o.errno)
+       err = windows.WSAGetOverlappedResult(fd.Sysfd, &o.o, &o.qty, false, &o.flags)
+       if err != nil {
                if err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled
                        err = netpollErr
                }
index d10e30cb6825fb9d29d47f40cb0e6f1a1ae33046..a02c96c8f03cff269810874f697e5556700c4b7e 100644 (file)
@@ -235,6 +235,7 @@ type WSAMsg struct {
 }
 
 //sys  WSASocket(af int32, typ int32, protocol int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = ws2_32.WSASocketW
+//sys  WSAGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
 
 func loadWSASendRecvMsg() error {
        sendRecvMsgFunc.once.Do(func() {
index 931f157cf166c267573aa8834bbcc72157da7352..7d3cd37b92bed1d707132fd12d8241df35af8d27 100644 (file)
@@ -86,6 +86,7 @@ var (
        procCreateEnvironmentBlock            = moduserenv.NewProc("CreateEnvironmentBlock")
        procDestroyEnvironmentBlock           = moduserenv.NewProc("DestroyEnvironmentBlock")
        procGetProfilesDirectoryW             = moduserenv.NewProc("GetProfilesDirectoryW")
+       procWSAGetOverlappedResult            = modws2_32.NewProc("WSAGetOverlappedResult")
        procWSASocketW                        = modws2_32.NewProc("WSASocketW")
 )
 
@@ -426,6 +427,18 @@ func GetProfilesDirectory(dir *uint16, dirLen *uint32) (err error) {
        return
 }
 
+func WSAGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
+       var _p0 uint32
+       if wait {
+               _p0 = 1
+       }
+       r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0)
+       if r1 == 0 {
+               err = errnoErr(e1)
+       }
+       return
+}
+
 func WSASocket(af int32, typ int32, protocol int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (handle syscall.Handle, err error) {
        r0, _, e1 := syscall.Syscall6(procWSASocketW.Addr(), 6, uintptr(af), uintptr(typ), uintptr(protocol), uintptr(unsafe.Pointer(protinfo)), uintptr(group), uintptr(flags))
        handle = syscall.Handle(r0)
index 484a9e85b2d61113dc45247424ef3e149a69f777..59377bc5885d5678acc4f6e9eec4b8fb051e5529 100644 (file)
@@ -19,10 +19,8 @@ type net_op struct {
        // used by windows
        o overlapped
        // used by netpoll
-       pd    *pollDesc
-       mode  int32
-       errno int32
-       qty   uint32
+       pd   *pollDesc
+       mode int32
 }
 
 type overlappedEntry struct {
@@ -86,7 +84,7 @@ func netpollBreak() {
 // delay > 0: block for up to that many nanoseconds
 func netpoll(delay int64) (gList, int32) {
        var entries [64]overlappedEntry
-       var wait, qty, flags, n, i uint32
+       var wait, n, i uint32
        var errno int32
        var op *net_op
        var toRun gList
@@ -131,12 +129,12 @@ func netpoll(delay int64) (gList, int32) {
        for i = 0; i < n; i++ {
                op = entries[i].op
                if op != nil && op.pd == entries[i].key {
-                       errno = 0
-                       qty = 0
-                       if stdcall5(_WSAGetOverlappedResult, op.pd.fd, uintptr(unsafe.Pointer(op)), uintptr(unsafe.Pointer(&qty)), 0, uintptr(unsafe.Pointer(&flags))) == 0 {
-                               errno = int32(getlasterror())
+                       mode := op.mode
+                       if mode != 'r' && mode != 'w' {
+                               println("runtime: GetQueuedCompletionStatusEx returned net_op with invalid mode=", mode)
+                               throw("runtime: netpoll failed")
                        }
-                       delta += handlecompletion(&toRun, op, errno, qty)
+                       delta += netpollready(&toRun, op.pd, mode)
                } else {
                        netpollWakeSig.Store(0)
                        if delay == 0 {
@@ -148,14 +146,3 @@ func netpoll(delay int64) (gList, int32) {
        }
        return toRun, delta
 }
-
-func handlecompletion(toRun *gList, op *net_op, errno int32, qty uint32) int32 {
-       mode := op.mode
-       if mode != 'r' && mode != 'w' {
-               println("runtime: GetQueuedCompletionStatusEx returned invalid mode=", mode)
-               throw("runtime: netpoll failed")
-       }
-       op.errno = errno
-       op.qty = qty
-       return netpollready(toRun, op.pd, mode)
-}
index cd0a3c260ec7f91303785743db0eeef436251d2b..7e9bbd04f2758807ac59a4e9537ea4654cab1a5c 100644 (file)
@@ -139,7 +139,6 @@ var (
        // These are from non-kernel32.dll, so we prefer to LoadLibraryEx them.
        _timeBeginPeriod,
        _timeEndPeriod,
-       _WSAGetOverlappedResult,
        _ stdFunction
 )
 
@@ -148,7 +147,6 @@ var (
        ntdlldll            = [...]uint16{'n', 't', 'd', 'l', 'l', '.', 'd', 'l', 'l', 0}
        powrprofdll         = [...]uint16{'p', 'o', 'w', 'r', 'p', 'r', 'o', 'f', '.', 'd', 'l', 'l', 0}
        winmmdll            = [...]uint16{'w', 'i', 'n', 'm', 'm', '.', 'd', 'l', 'l', 0}
-       ws2_32dll           = [...]uint16{'w', 's', '2', '_', '3', '2', '.', 'd', 'l', 'l', 0}
 )
 
 // Function to be called by windows CreateThread
@@ -256,15 +254,6 @@ func loadOptionalSyscalls() {
        }
        _RtlGetCurrentPeb = windowsFindfunc(n32, []byte("RtlGetCurrentPeb\000"))
        _RtlGetNtVersionNumbers = windowsFindfunc(n32, []byte("RtlGetNtVersionNumbers\000"))
-
-       ws232 := windowsLoadSystemLib(ws2_32dll[:])
-       if ws232 == 0 {
-               throw("ws2_32.dll not found")
-       }
-       _WSAGetOverlappedResult = windowsFindfunc(ws232, []byte("WSAGetOverlappedResult\000"))
-       if _WSAGetOverlappedResult == nil {
-               throw("WSAGetOverlappedResult not found")
-       }
 }
 
 func monitorSuspendResume() {