]> Cypherpunks repositories - gostls13.git/commitdiff
internal/poll: simplify WriteMsg and ReadMsg on Windows
authorqmuntal <quimmuntal@gmail.com>
Tue, 16 Sep 2025 07:14:47 +0000 (09:14 +0200)
committerQuim Muntal <quimmuntal@gmail.com>
Fri, 26 Sep 2025 18:37:26 +0000 (11:37 -0700)
Let newWSAMsg retrieve the socket from the sync pool for unconnected
sockets instead of forcing the caller to do it.

Change-Id: I9b3d30bf3f5be187cbde5735d519b3b14f7b3008
Reviewed-on: https://go-review.googlesource.com/c/go/+/704116
Reviewed-by: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Junyang Shao <shaojunyang@google.com>
src/internal/poll/fd_windows.go
src/internal/syscall/windows/syscall_windows.go

index 4d466e5b64c9d07e58bc81d12e62a3e5b7b7b2c7..dd9845d1b223c30dec3cde12eb88f0df84eaba47 100644 (file)
@@ -168,7 +168,7 @@ var wsaMsgPool = sync.Pool{
 
 // newWSAMsg creates a new WSAMsg with the provided parameters.
 // Use [freeWSAMsg] to free it.
-func newWSAMsg(p []byte, oob []byte, flags int, rsa *syscall.RawSockaddrAny) *windows.WSAMsg {
+func newWSAMsg(p []byte, oob []byte, flags int, unconnected bool) *windows.WSAMsg {
        // The returned object can't be allocated in the stack because it is accessed asynchronously
        // by Windows in between several system calls. If the stack frame is moved while that happens,
        // then Windows may access invalid memory.
@@ -183,11 +183,9 @@ func newWSAMsg(p []byte, oob []byte, flags int, rsa *syscall.RawSockaddrAny) *wi
                Buf: unsafe.SliceData(oob),
        }
        msg.Flags = uint32(flags)
-       msg.Name = syscall.Pointer(unsafe.Pointer(rsa))
-       if rsa != nil {
-               msg.Namelen = int32(unsafe.Sizeof(*rsa))
-       } else {
-               msg.Namelen = 0
+       if unconnected {
+               msg.Name = wsaRsaPool.Get().(*syscall.RawSockaddrAny)
+               msg.Namelen = int32(unsafe.Sizeof(syscall.RawSockaddrAny{}))
        }
        return msg
 }
@@ -198,6 +196,12 @@ func freeWSAMsg(msg *windows.WSAMsg) {
        msg.Buffers.Buf = nil
        msg.Control.Len = 0
        msg.Control.Buf = nil
+       if msg.Name != nil {
+               *msg.Name = syscall.RawSockaddrAny{}
+               wsaRsaPool.Put(msg.Name)
+               msg.Name = nil
+               msg.Namelen = 0
+       }
        wsaMsgPool.Put(msg)
 }
 
@@ -1355,9 +1359,7 @@ func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.S
                p = p[:maxRW]
        }
 
-       rsa := wsaRsaPool.Get().(*syscall.RawSockaddrAny)
-       defer wsaRsaPool.Put(rsa)
-       msg := newWSAMsg(p, oob, flags, rsa)
+       msg := newWSAMsg(p, oob, flags, true)
        defer freeWSAMsg(msg)
        n, err := fd.execIO(&fd.rop, func(o *operation) (qty uint32, err error) {
                err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil)
@@ -1366,7 +1368,7 @@ func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.S
        err = fd.eofError(n, err)
        var sa syscall.Sockaddr
        if err == nil {
-               sa, err = rsa.Sockaddr()
+               sa, err = msg.Name.Sockaddr()
        }
        return n, int(msg.Control.Len), int(msg.Flags), sa, err
 }
@@ -1382,9 +1384,7 @@ func (fd *FD) ReadMsgInet4(p []byte, oob []byte, flags int, sa4 *syscall.Sockadd
                p = p[:maxRW]
        }
 
-       rsa := wsaRsaPool.Get().(*syscall.RawSockaddrAny)
-       defer wsaRsaPool.Put(rsa)
-       msg := newWSAMsg(p, oob, flags, rsa)
+       msg := newWSAMsg(p, oob, flags, true)
        defer freeWSAMsg(msg)
        n, err := fd.execIO(&fd.rop, func(o *operation) (qty uint32, err error) {
                err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil)
@@ -1392,7 +1392,7 @@ func (fd *FD) ReadMsgInet4(p []byte, oob []byte, flags int, sa4 *syscall.Sockadd
        })
        err = fd.eofError(n, err)
        if err == nil {
-               rawToSockaddrInet4(rsa, sa4)
+               rawToSockaddrInet4(msg.Name, sa4)
        }
        return n, int(msg.Control.Len), int(msg.Flags), err
 }
@@ -1408,9 +1408,7 @@ func (fd *FD) ReadMsgInet6(p []byte, oob []byte, flags int, sa6 *syscall.Sockadd
                p = p[:maxRW]
        }
 
-       rsa := wsaRsaPool.Get().(*syscall.RawSockaddrAny)
-       defer wsaRsaPool.Put(rsa)
-       msg := newWSAMsg(p, oob, flags, rsa)
+       msg := newWSAMsg(p, oob, flags, true)
        defer freeWSAMsg(msg)
        n, err := fd.execIO(&fd.rop, func(o *operation) (qty uint32, err error) {
                err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil)
@@ -1418,7 +1416,7 @@ func (fd *FD) ReadMsgInet6(p []byte, oob []byte, flags int, sa6 *syscall.Sockadd
        })
        err = fd.eofError(n, err)
        if err == nil {
-               rawToSockaddrInet6(rsa, sa6)
+               rawToSockaddrInet6(msg.Name, sa6)
        }
        return n, int(msg.Control.Len), int(msg.Flags), err
 }
@@ -1434,16 +1432,11 @@ func (fd *FD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (int, int, err
        }
        defer fd.writeUnlock()
 
-       var rsa *syscall.RawSockaddrAny
-       if sa != nil {
-               rsa = wsaRsaPool.Get().(*syscall.RawSockaddrAny)
-               defer wsaRsaPool.Put(rsa)
-       }
-       msg := newWSAMsg(p, oob, 0, rsa)
+       msg := newWSAMsg(p, oob, 0, sa != nil)
        defer freeWSAMsg(msg)
        if sa != nil {
                var err error
-               msg.Namelen, err = sockaddrToRaw(rsa, sa)
+               msg.Namelen, err = sockaddrToRaw(msg.Name, sa)
                if err != nil {
                        return 0, 0, err
                }
@@ -1466,15 +1459,10 @@ func (fd *FD) WriteMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (in
        }
        defer fd.writeUnlock()
 
-       var rsa *syscall.RawSockaddrAny
-       if sa != nil {
-               rsa = wsaRsaPool.Get().(*syscall.RawSockaddrAny)
-               defer wsaRsaPool.Put(rsa)
-       }
-       msg := newWSAMsg(p, oob, 0, rsa)
+       msg := newWSAMsg(p, oob, 0, sa != nil)
        defer freeWSAMsg(msg)
        if sa != nil {
-               msg.Namelen = sockaddrInet4ToRaw(rsa, sa)
+               msg.Namelen = sockaddrInet4ToRaw(msg.Name, sa)
        }
        n, err := fd.execIO(&fd.wop, func(o *operation) (qty uint32, err error) {
                err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil)
@@ -1494,15 +1482,10 @@ func (fd *FD) WriteMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (in
        }
        defer fd.writeUnlock()
 
-       var rsa *syscall.RawSockaddrAny
-       if sa != nil {
-               rsa = wsaRsaPool.Get().(*syscall.RawSockaddrAny)
-               defer wsaRsaPool.Put(rsa)
-       }
-       msg := newWSAMsg(p, oob, 0, rsa)
+       msg := newWSAMsg(p, oob, 0, sa != nil)
        defer freeWSAMsg(msg)
        if sa != nil {
-               msg.Namelen = sockaddrInet6ToRaw(rsa, sa)
+               msg.Namelen = sockaddrInet6ToRaw(msg.Name, sa)
        }
        n, err := fd.execIO(&fd.wop, func(o *operation) (qty uint32, err error) {
                err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil)
index b60648ea299126e3bec95c924c9bebefaf6b0d34..fb3b66540f1c9d0d2decedce0f70175002467bf4 100644 (file)
@@ -261,7 +261,7 @@ var sendRecvMsgFunc struct {
 }
 
 type WSAMsg struct {
-       Name        syscall.Pointer
+       Name        *syscall.RawSockaddrAny
        Namelen     int32
        Buffers     *syscall.WSABuf
        BufferCount uint32