From 15fbe3480b1c44113e9cdb26008da9f66d4e57b2 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Tue, 16 Sep 2025 09:14:47 +0200 Subject: [PATCH] internal/poll: simplify WriteMsg and ReadMsg on Windows Let newWSAMsg retrieve the socket from the sync pool for unconnected sockets instead of forcing the caller to do it. Change-Id: I9b3d30bf3f5be187cbde5735d519b3b14f7b3008 Reviewed-on: https://go-review.googlesource.com/c/go/+/704116 Reviewed-by: Damien Neil LUCI-TryBot-Result: Go LUCI Reviewed-by: Junyang Shao --- src/internal/poll/fd_windows.go | 61 +++++++------------ .../syscall/windows/syscall_windows.go | 2 +- 2 files changed, 23 insertions(+), 40 deletions(-) diff --git a/src/internal/poll/fd_windows.go b/src/internal/poll/fd_windows.go index 4d466e5b64..dd9845d1b2 100644 --- a/src/internal/poll/fd_windows.go +++ b/src/internal/poll/fd_windows.go @@ -168,7 +168,7 @@ var wsaMsgPool = sync.Pool{ // newWSAMsg creates a new WSAMsg with the provided parameters. // Use [freeWSAMsg] to free it. -func newWSAMsg(p []byte, oob []byte, flags int, rsa *syscall.RawSockaddrAny) *windows.WSAMsg { +func newWSAMsg(p []byte, oob []byte, flags int, unconnected bool) *windows.WSAMsg { // The returned object can't be allocated in the stack because it is accessed asynchronously // by Windows in between several system calls. If the stack frame is moved while that happens, // then Windows may access invalid memory. @@ -183,11 +183,9 @@ func newWSAMsg(p []byte, oob []byte, flags int, rsa *syscall.RawSockaddrAny) *wi Buf: unsafe.SliceData(oob), } msg.Flags = uint32(flags) - msg.Name = syscall.Pointer(unsafe.Pointer(rsa)) - if rsa != nil { - msg.Namelen = int32(unsafe.Sizeof(*rsa)) - } else { - msg.Namelen = 0 + if unconnected { + msg.Name = wsaRsaPool.Get().(*syscall.RawSockaddrAny) + msg.Namelen = int32(unsafe.Sizeof(syscall.RawSockaddrAny{})) } return msg } @@ -198,6 +196,12 @@ func freeWSAMsg(msg *windows.WSAMsg) { msg.Buffers.Buf = nil msg.Control.Len = 0 msg.Control.Buf = nil + if msg.Name != nil { + *msg.Name = syscall.RawSockaddrAny{} + wsaRsaPool.Put(msg.Name) + msg.Name = nil + msg.Namelen = 0 + } wsaMsgPool.Put(msg) } @@ -1355,9 +1359,7 @@ func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.S p = p[:maxRW] } - rsa := wsaRsaPool.Get().(*syscall.RawSockaddrAny) - defer wsaRsaPool.Put(rsa) - msg := newWSAMsg(p, oob, flags, rsa) + msg := newWSAMsg(p, oob, flags, true) defer freeWSAMsg(msg) n, err := fd.execIO(&fd.rop, func(o *operation) (qty uint32, err error) { err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil) @@ -1366,7 +1368,7 @@ func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.S err = fd.eofError(n, err) var sa syscall.Sockaddr if err == nil { - sa, err = rsa.Sockaddr() + sa, err = msg.Name.Sockaddr() } return n, int(msg.Control.Len), int(msg.Flags), sa, err } @@ -1382,9 +1384,7 @@ func (fd *FD) ReadMsgInet4(p []byte, oob []byte, flags int, sa4 *syscall.Sockadd p = p[:maxRW] } - rsa := wsaRsaPool.Get().(*syscall.RawSockaddrAny) - defer wsaRsaPool.Put(rsa) - msg := newWSAMsg(p, oob, flags, rsa) + msg := newWSAMsg(p, oob, flags, true) defer freeWSAMsg(msg) n, err := fd.execIO(&fd.rop, func(o *operation) (qty uint32, err error) { err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil) @@ -1392,7 +1392,7 @@ func (fd *FD) ReadMsgInet4(p []byte, oob []byte, flags int, sa4 *syscall.Sockadd }) err = fd.eofError(n, err) if err == nil { - rawToSockaddrInet4(rsa, sa4) + rawToSockaddrInet4(msg.Name, sa4) } return n, int(msg.Control.Len), int(msg.Flags), err } @@ -1408,9 +1408,7 @@ func (fd *FD) ReadMsgInet6(p []byte, oob []byte, flags int, sa6 *syscall.Sockadd p = p[:maxRW] } - rsa := wsaRsaPool.Get().(*syscall.RawSockaddrAny) - defer wsaRsaPool.Put(rsa) - msg := newWSAMsg(p, oob, flags, rsa) + msg := newWSAMsg(p, oob, flags, true) defer freeWSAMsg(msg) n, err := fd.execIO(&fd.rop, func(o *operation) (qty uint32, err error) { err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil) @@ -1418,7 +1416,7 @@ func (fd *FD) ReadMsgInet6(p []byte, oob []byte, flags int, sa6 *syscall.Sockadd }) err = fd.eofError(n, err) if err == nil { - rawToSockaddrInet6(rsa, sa6) + rawToSockaddrInet6(msg.Name, sa6) } return n, int(msg.Control.Len), int(msg.Flags), err } @@ -1434,16 +1432,11 @@ func (fd *FD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (int, int, err } defer fd.writeUnlock() - var rsa *syscall.RawSockaddrAny - if sa != nil { - rsa = wsaRsaPool.Get().(*syscall.RawSockaddrAny) - defer wsaRsaPool.Put(rsa) - } - msg := newWSAMsg(p, oob, 0, rsa) + msg := newWSAMsg(p, oob, 0, sa != nil) defer freeWSAMsg(msg) if sa != nil { var err error - msg.Namelen, err = sockaddrToRaw(rsa, sa) + msg.Namelen, err = sockaddrToRaw(msg.Name, sa) if err != nil { return 0, 0, err } @@ -1466,15 +1459,10 @@ func (fd *FD) WriteMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (in } defer fd.writeUnlock() - var rsa *syscall.RawSockaddrAny - if sa != nil { - rsa = wsaRsaPool.Get().(*syscall.RawSockaddrAny) - defer wsaRsaPool.Put(rsa) - } - msg := newWSAMsg(p, oob, 0, rsa) + msg := newWSAMsg(p, oob, 0, sa != nil) defer freeWSAMsg(msg) if sa != nil { - msg.Namelen = sockaddrInet4ToRaw(rsa, sa) + msg.Namelen = sockaddrInet4ToRaw(msg.Name, sa) } n, err := fd.execIO(&fd.wop, func(o *operation) (qty uint32, err error) { err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil) @@ -1494,15 +1482,10 @@ func (fd *FD) WriteMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (in } defer fd.writeUnlock() - var rsa *syscall.RawSockaddrAny - if sa != nil { - rsa = wsaRsaPool.Get().(*syscall.RawSockaddrAny) - defer wsaRsaPool.Put(rsa) - } - msg := newWSAMsg(p, oob, 0, rsa) + msg := newWSAMsg(p, oob, 0, sa != nil) defer freeWSAMsg(msg) if sa != nil { - msg.Namelen = sockaddrInet6ToRaw(rsa, sa) + msg.Namelen = sockaddrInet6ToRaw(msg.Name, sa) } n, err := fd.execIO(&fd.wop, func(o *operation) (qty uint32, err error) { err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil) diff --git a/src/internal/syscall/windows/syscall_windows.go b/src/internal/syscall/windows/syscall_windows.go index b60648ea29..fb3b66540f 100644 --- a/src/internal/syscall/windows/syscall_windows.go +++ b/src/internal/syscall/windows/syscall_windows.go @@ -261,7 +261,7 @@ var sendRecvMsgFunc struct { } type WSAMsg struct { - Name syscall.Pointer + Name *syscall.RawSockaddrAny Namelen int32 Buffers *syscall.WSABuf BufferCount uint32 -- 2.52.0