]> Cypherpunks repositories - gostls13.git/commitdiff
internal/poll: don't use stack-allocated WSAMsg parameters
authorqmuntal <quimmuntal@gmail.com>
Mon, 25 Aug 2025 07:57:49 +0000 (09:57 +0200)
committerQuim Muntal <quimmuntal@gmail.com>
Tue, 26 Aug 2025 10:49:59 +0000 (03:49 -0700)
WSAMsg parameters should be passed to Windows as heap pointers instead
of stack pointers. This is because Windows might access the memory
after the syscall returned in case of a non-blocking operation (which
is the common case), and if the WSAMsg is on the stack, the Go
runtime might have moved it around.

Use a sync.Pool to cache WSAMsg structures to avoid a heap allocation
every time a WSAMsg is needed.

Fixes #74933

Cq-Include-Trybots: luci.golang.try:x_net-gotip-windows-amd64
Change-Id: I075e2ceb25cd545224ab3a10d404340faf19fc01
Reviewed-on: https://go-review.googlesource.com/c/go/+/698797
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/internal/poll/fd_windows.go

index 88d0785efd983a1fd3414398674e50b6c687f1c0..18ccee3cee724695c04d16d4bd3848bc319651ff 100644 (file)
@@ -144,19 +144,47 @@ func (o *operation) ClearBufs() {
        o.bufs = o.bufs[:0]
 }
 
-func newWSAMsg(p []byte, oob []byte, flags int) windows.WSAMsg {
-       return windows.WSAMsg{
-               Buffers: &syscall.WSABuf{
-                       Len: uint32(len(p)),
-                       Buf: unsafe.SliceData(p),
-               },
-               BufferCount: 1,
-               Control: syscall.WSABuf{
-                       Len: uint32(len(oob)),
-                       Buf: unsafe.SliceData(oob),
-               },
-               Flags: uint32(flags),
+// wsaMsgPool is a pool of WSAMsg structures that can only hold a single WSABuf.
+var wsaMsgPool = sync.Pool{
+       New: func() any {
+               return &windows.WSAMsg{
+                       Buffers:     &syscall.WSABuf{},
+                       BufferCount: 1,
+               }
+       },
+}
+
+// newWSAMsg creates a new WSAMsg with the provided parameters.
+// Use [freeWSAMsg] to free it.
+func newWSAMsg(p []byte, oob []byte, flags int, rsa *syscall.RawSockaddrAny) *windows.WSAMsg {
+       // The returned object can't be allocated in the stack because it is accessed asynchronously
+       // by Windows in between several system calls. If the stack frame is moved while that happens,
+       // then Windows may access invalid memory.
+       // TODO(qmuntal): investigate using runtime.Pinner keeping this path allocation-free.
+
+       // Use a pool to reuse allocations.
+       msg := wsaMsgPool.Get().(*windows.WSAMsg)
+       msg.Buffers.Len = uint32(len(p))
+       msg.Buffers.Buf = unsafe.SliceData(p)
+       msg.Control = syscall.WSABuf{
+               Len: uint32(len(oob)),
+               Buf: unsafe.SliceData(oob),
+       }
+       msg.Flags = uint32(flags)
+       msg.Name = syscall.Pointer(unsafe.Pointer(rsa))
+       if rsa != nil {
+               msg.Namelen = int32(unsafe.Sizeof(*rsa))
+       } else {
+               msg.Namelen = 0
        }
+       return msg
+}
+
+func freeWSAMsg(msg *windows.WSAMsg) {
+       // Clear pointers to buffers so they can be released by garbage collector.
+       msg.Buffers.Len = 0
+       msg.Buffers.Buf = nil
+       wsaMsgPool.Put(msg)
 }
 
 // waitIO waits for the IO operation o to complete.
@@ -1297,11 +1325,10 @@ func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.S
        if o.rsa == nil {
                o.rsa = new(syscall.RawSockaddrAny)
        }
-       msg := newWSAMsg(p, oob, flags)
-       msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa))
-       msg.Namelen = int32(unsafe.Sizeof(*o.rsa))
+       msg := newWSAMsg(p, oob, flags, o.rsa)
+       defer freeWSAMsg(msg)
        n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) {
-               err = windows.WSARecvMsg(fd.Sysfd, &msg, &qty, &o.o, nil)
+               err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil)
                return qty, err
        })
        err = fd.eofError(n, err)
@@ -1327,11 +1354,10 @@ func (fd *FD) ReadMsgInet4(p []byte, oob []byte, flags int, sa4 *syscall.Sockadd
        if o.rsa == nil {
                o.rsa = new(syscall.RawSockaddrAny)
        }
-       msg := newWSAMsg(p, oob, flags)
-       msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa))
-       msg.Namelen = int32(unsafe.Sizeof(*o.rsa))
+       msg := newWSAMsg(p, oob, flags, o.rsa)
+       defer freeWSAMsg(msg)
        n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) {
-               err = windows.WSARecvMsg(fd.Sysfd, &msg, &qty, &o.o, nil)
+               err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil)
                return qty, err
        })
        err = fd.eofError(n, err)
@@ -1356,11 +1382,10 @@ func (fd *FD) ReadMsgInet6(p []byte, oob []byte, flags int, sa6 *syscall.Sockadd
        if o.rsa == nil {
                o.rsa = new(syscall.RawSockaddrAny)
        }
-       msg := newWSAMsg(p, oob, flags)
-       msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa))
-       msg.Namelen = int32(unsafe.Sizeof(*o.rsa))
+       msg := newWSAMsg(p, oob, flags, o.rsa)
+       defer freeWSAMsg(msg)
        n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) {
-               err = windows.WSARecvMsg(fd.Sysfd, &msg, &qty, &o.o, nil)
+               err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil)
                return qty, err
        })
        err = fd.eofError(n, err)
@@ -1382,20 +1407,20 @@ func (fd *FD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (int, int, err
        defer fd.writeUnlock()
 
        o := &fd.wop
-       msg := newWSAMsg(p, oob, 0)
+       if sa != nil && o.rsa == nil {
+               o.rsa = new(syscall.RawSockaddrAny)
+       }
+       msg := newWSAMsg(p, oob, 0, o.rsa)
+       defer freeWSAMsg(msg)
        if sa != nil {
-               if o.rsa == nil {
-                       o.rsa = new(syscall.RawSockaddrAny)
-               }
-               len, err := sockaddrToRaw(o.rsa, sa)
+               var err error
+               msg.Namelen, err = sockaddrToRaw(o.rsa, sa)
                if err != nil {
                        return 0, 0, err
                }
-               msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa))
-               msg.Namelen = len
        }
        n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) {
-               err = windows.WSASendMsg(fd.Sysfd, &msg, 0, nil, &o.o, nil)
+               err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil)
                return qty, err
        })
        return n, int(msg.Control.Len), err
@@ -1413,16 +1438,16 @@ func (fd *FD) WriteMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (in
        defer fd.writeUnlock()
 
        o := &fd.wop
-       msg := newWSAMsg(p, oob, 0)
+       if sa != nil && o.rsa == nil {
+               o.rsa = new(syscall.RawSockaddrAny)
+       }
+       msg := newWSAMsg(p, oob, 0, o.rsa)
+       defer freeWSAMsg(msg)
        if sa != nil {
-               if o.rsa == nil {
-                       o.rsa = new(syscall.RawSockaddrAny)
-               }
-               msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa))
                msg.Namelen = sockaddrInet4ToRaw(o.rsa, sa)
        }
        n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) {
-               err = windows.WSASendMsg(fd.Sysfd, &msg, 0, nil, &o.o, nil)
+               err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil)
                return qty, err
        })
        return n, int(msg.Control.Len), err
@@ -1440,16 +1465,16 @@ func (fd *FD) WriteMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (in
        defer fd.writeUnlock()
 
        o := &fd.wop
-       msg := newWSAMsg(p, oob, 0)
+       if sa != nil && o.rsa == nil {
+               o.rsa = new(syscall.RawSockaddrAny)
+       }
+       msg := newWSAMsg(p, oob, 0, o.rsa)
+       defer freeWSAMsg(msg)
        if sa != nil {
-               if o.rsa == nil {
-                       o.rsa = new(syscall.RawSockaddrAny)
-               }
-               msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa))
                msg.Namelen = sockaddrInet6ToRaw(o.rsa, sa)
        }
        n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) {
-               err = windows.WSASendMsg(fd.Sysfd, &msg, 0, nil, &o.o, nil)
+               err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil)
                return qty, err
        })
        return n, int(msg.Control.Len), err