}
// ReadMsg wraps the recvmsg network call.
-func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, error) {
+func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.Sockaddr, error) {
if err := fd.readLock(); err != nil {
return 0, 0, 0, nil, err
}
return 0, 0, 0, nil, err
}
for {
- n, oobn, flags, sa, err := syscall.Recvmsg(fd.Sysfd, p, oob, 0)
+ n, oobn, sysflags, sa, err := syscall.Recvmsg(fd.Sysfd, p, oob, flags)
if err != nil {
if err == syscall.EINTR {
continue
}
}
err = fd.eofError(n, err)
- return n, oobn, flags, sa, err
+ return n, oobn, sysflags, sa, err
}
}
}
// ReadMsg wraps the WSARecvMsg network call.
-func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, error) {
+func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.Sockaddr, error) {
if err := fd.readLock(); err != nil {
return 0, 0, 0, nil, err
}
o.rsa = new(syscall.RawSockaddrAny)
o.msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa))
o.msg.Namelen = int32(unsafe.Sizeof(*o.rsa))
+ o.msg.Flags = uint32(flags)
n, err := execIO(o, func(o *operation) error {
return windows.WSARecvMsg(o.fd.Sysfd, &o.msg, &o.qty, &o.o, nil)
})
return n, sa, wrapSyscallError(readFromSyscallName, err)
}
-func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
- n, oobn, flags, sa, err = fd.pfd.ReadMsg(p, oob)
+func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
+ n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags)
runtime.KeepAlive(fd)
- return n, oobn, flags, sa, wrapSyscallError(readMsgSyscallName, err)
+ return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err)
}
func (fd *netFD) Write(p []byte) (nn int, err error) {
func (c *IPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) {
var sa syscall.Sockaddr
- n, oobn, flags, sa, err = c.fd.readMsg(b, oob)
+ n, oobn, flags, sa, err = c.fd.readMsg(b, oob, 0)
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
addr = &IPAddr{IP: sa.Addr[0:]}
return 0, nil, syscall.ENOSYS
}
-func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
+func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
return 0, 0, 0, nil, syscall.ENOSYS
}
func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) {
var sa syscall.Sockaddr
- n, oobn, flags, sa, err = c.fd.readMsg(b, oob)
+ n, oobn, flags, sa, err = c.fd.readMsg(b, oob, 0)
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port}
func (c *UnixConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) {
var sa syscall.Sockaddr
- n, oobn, flags, sa, err = c.fd.readMsg(b, oob)
+ n, oobn, flags, sa, err = c.fd.readMsg(b, oob, readMsgFlags)
+ if oobn > 0 {
+ setReadMsgCloseOnExec(oob[:oobn])
+ }
+
switch sa := sa.(type) {
case *syscall.SockaddrUnix:
if sa.Name != "" {
--- /dev/null
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build linux
+// +build linux
+
+package net
+
+import (
+ "syscall"
+)
+
+const readMsgFlags = syscall.MSG_CMSG_CLOEXEC
+
+func setReadMsgCloseOnExec(oob []byte) {
+}
--- /dev/null
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build (js && wasm) || windows
+// +build js,wasm windows
+
+package net
+
+const readMsgFlags = 0
+
+func setReadMsgCloseOnExec(oob []byte) {
+}
--- /dev/null
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build aix || darwin || dragonfly || freebsd || netbsd || openbsd || solaris
+// +build aix darwin dragonfly freebsd netbsd openbsd solaris
+
+package net
+
+import (
+ "syscall"
+)
+
+const readMsgFlags = 0
+
+func setReadMsgCloseOnExec(oob []byte) {
+ scms, err := syscall.ParseSocketControlMessage(oob)
+ if err != nil {
+ return
+ }
+
+ for _, scm := range scms {
+ if scm.Header.Level == syscall.SOL_SOCKET && scm.Header.Type == syscall.SCM_RIGHTS {
+ fds, err := syscall.ParseUnixRights(&scm)
+ if err != nil {
+ continue
+ }
+ for _, fd := range fds {
+ syscall.CloseOnExec(fd)
+ }
+ }
+ }
+}
--- /dev/null
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
+// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
+
+package net
+
+import (
+ "os"
+ "syscall"
+ "testing"
+ "time"
+)
+
+func TestUnixConnReadMsgUnixSCMRightsCloseOnExec(t *testing.T) {
+ if !testableNetwork("unix") {
+ t.Skip("not unix system")
+ }
+
+ scmFile, err := os.Open(os.DevNull)
+ if err != nil {
+ t.Fatalf("file open: %v", err)
+ }
+ defer scmFile.Close()
+
+ rights := syscall.UnixRights(int(scmFile.Fd()))
+ fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatalf("Socketpair: %v", err)
+ }
+
+ writeFile := os.NewFile(uintptr(fds[0]), "write-socket")
+ defer writeFile.Close()
+ readFile := os.NewFile(uintptr(fds[1]), "read-socket")
+ defer readFile.Close()
+
+ cw, err := FileConn(writeFile)
+ if err != nil {
+ t.Fatalf("FileConn: %v", err)
+ }
+ defer cw.Close()
+ cr, err := FileConn(readFile)
+ if err != nil {
+ t.Fatalf("FileConn: %v", err)
+ }
+ defer cr.Close()
+
+ ucw, ok := cw.(*UnixConn)
+ if !ok {
+ t.Fatalf("got %T; want UnixConn", cw)
+ }
+ ucr, ok := cr.(*UnixConn)
+ if !ok {
+ t.Fatalf("got %T; want UnixConn", cr)
+ }
+
+ oob := make([]byte, syscall.CmsgSpace(4))
+ err = ucw.SetWriteDeadline(time.Now().Add(5 * time.Second))
+ if err != nil {
+ t.Fatalf("Can't set unix connection timeout: %v", err)
+ }
+ _, _, err = ucw.WriteMsgUnix(nil, rights, nil)
+ if err != nil {
+ t.Fatalf("UnixConn readMsg: %v", err)
+ }
+ err = ucr.SetReadDeadline(time.Now().Add(5 * time.Second))
+ if err != nil {
+ t.Fatalf("Can't set unix connection timeout: %v", err)
+ }
+ _, oobn, _, _, err := ucr.ReadMsgUnix(nil, oob)
+ if err != nil {
+ t.Fatalf("UnixConn readMsg: %v", err)
+ }
+
+ scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
+ if err != nil {
+ t.Fatalf("ParseSocketControlMessage: %v", err)
+ }
+ if len(scms) != 1 {
+ t.Fatalf("got scms = %#v; expected 1 SocketControlMessage", scms)
+ }
+ scm := scms[0]
+ gotFds, err := syscall.ParseUnixRights(&scm)
+ if err != nil {
+ t.Fatalf("syscall.ParseUnixRights: %v", err)
+ }
+ if len(gotFds) != 1 {
+ t.Fatalf("got FDs %#v: wanted only 1 fd", gotFds)
+ }
+ defer func() {
+ if err := syscall.Close(int(gotFds[0])); err != nil {
+ t.Fatalf("fail to close gotFds: %v", err)
+ }
+ }()
+
+ flags, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(gotFds[0]), uintptr(syscall.F_GETFD), 0)
+ if errno != 0 {
+ t.Fatalf("Can't get flags of fd:%#v, with err:%v", gotFds[0], errno)
+ }
+ if flags&syscall.FD_CLOEXEC == 0 {
+ t.Fatalf("got flags %#x, want %#x (FD_CLOEXEC) set", flags, syscall.FD_CLOEXEC)
+ }
+}
if err != nil {
t.Fatalf("ReadMsgUnix: %v", err)
}
- if flags != 0 {
- t.Fatalf("ReadMsgUnix flags = 0x%x, want 0", flags)
+ if flags != syscall.MSG_CMSG_CLOEXEC {
+ t.Fatalf("ReadMsgUnix flags = %#x, want %#x (MSG_CMSG_CLOEXEC)", flags, syscall.MSG_CMSG_CLOEXEC)
}
if n != tt.dataLen {
t.Fatalf("ReadMsgUnix n = %d, want %d", n, tt.dataLen)