]> Cypherpunks repositories - gostls13.git/commitdiff
net: pass MSG_CMSG_CLOEXEC flag in ReadMsgUnix
authorHowJMay <vulxj0j8j8@gmail.com>
Mon, 19 Apr 2021 18:06:54 +0000 (18:06 +0000)
committerIan Lance Taylor <iant@golang.org>
Mon, 19 Apr 2021 21:27:43 +0000 (21:27 +0000)
As mentioned in #42765, calling "recvmsg" syscall on Linux should come
with "MSG_CMSG_CLOEXEC" flag.

For other systems which not supports "MSG_CMSG_CLOEXEC". ReadMsgUnix()
would check the header. If the header type is "syscall.SCM_RIGHTS",
then ReadMsgUnix() would parse the SocketControlMessage and call each
fd with "syscall.CloseOnExec"

Fixes #42765

Change-Id: I74347db72b465685d7684bf0f32415d285845ebb
GitHub-Last-Rev: ca59e2c9e0e8de1ae590e9b6dc165cb768a574f5
GitHub-Pull-Request: golang/go#42768
Reviewed-on: https://go-review.googlesource.com/c/go/+/272226
Trust: Emmanuel Odeke <emmanuel@orijtech.com>
Run-TryBot: Emmanuel Odeke <emmanuel@orijtech.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Tobias Klauser <tobias.klauser@gmail.com>
12 files changed:
src/internal/poll/fd_unix.go
src/internal/poll/fd_windows.go
src/net/fd_posix.go
src/net/iprawsock_posix.go
src/net/net_fake.go
src/net/udpsock_posix.go
src/net/unixsock_posix.go
src/net/unixsock_readmsg_linux.go [new file with mode: 0644]
src/net/unixsock_readmsg_other.go [new file with mode: 0644]
src/net/unixsock_readmsg_posix.go [new file with mode: 0644]
src/net/unixsock_readmsg_test.go [new file with mode: 0644]
src/syscall/creds_test.go

index fe8a5c8ec008b494d0e0d765806351b9a3f0e81c..3b17cd22b03b8f85d0f9b6eb250d3379d59b9cfa 100644 (file)
@@ -231,7 +231,7 @@ func (fd *FD) ReadFrom(p []byte) (int, syscall.Sockaddr, error) {
 }
 
 // ReadMsg wraps the recvmsg network call.
-func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, error) {
+func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.Sockaddr, error) {
        if err := fd.readLock(); err != nil {
                return 0, 0, 0, nil, err
        }
@@ -240,7 +240,7 @@ func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, er
                return 0, 0, 0, nil, err
        }
        for {
-               n, oobn, flags, sa, err := syscall.Recvmsg(fd.Sysfd, p, oob, 0)
+               n, oobn, sysflags, sa, err := syscall.Recvmsg(fd.Sysfd, p, oob, flags)
                if err != nil {
                        if err == syscall.EINTR {
                                continue
@@ -253,7 +253,7 @@ func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, er
                        }
                }
                err = fd.eofError(n, err)
-               return n, oobn, flags, sa, err
+               return n, oobn, sysflags, sa, err
        }
 }
 
index d8c834f92993fdfd070a92d497c1a5e44f5950a9..4a5169527c42b0592d13e6c75df3fc9773647347 100644 (file)
@@ -1013,7 +1013,7 @@ func sockaddrToRaw(sa syscall.Sockaddr) (unsafe.Pointer, int32, error) {
 }
 
 // ReadMsg wraps the WSARecvMsg network call.
-func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, error) {
+func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.Sockaddr, error) {
        if err := fd.readLock(); err != nil {
                return 0, 0, 0, nil, err
        }
@@ -1028,6 +1028,7 @@ func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, er
        o.rsa = new(syscall.RawSockaddrAny)
        o.msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa))
        o.msg.Namelen = int32(unsafe.Sizeof(*o.rsa))
+       o.msg.Flags = uint32(flags)
        n, err := execIO(o, func(o *operation) error {
                return windows.WSARecvMsg(o.fd.Sysfd, &o.msg, &o.qty, &o.o, nil)
        })
index 2945e46a4890ab513ed9d27e43983da068dfbd6f..4703ff33a10c71ef22df5317f8b413cfe186e0d6 100644 (file)
@@ -64,10 +64,10 @@ func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
        return n, sa, wrapSyscallError(readFromSyscallName, err)
 }
 
-func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
-       n, oobn, flags, sa, err = fd.pfd.ReadMsg(p, oob)
+func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
+       n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags)
        runtime.KeepAlive(fd)
-       return n, oobn, flags, sa, wrapSyscallError(readMsgSyscallName, err)
+       return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err)
 }
 
 func (fd *netFD) Write(p []byte) (nn int, err error) {
index c1514f1698c5a8363423be4a7fa409c3d45ad5fc..b94eec0e182b3d006771e30238078fdb93dd9f44 100644 (file)
@@ -75,7 +75,7 @@ func stripIPv4Header(n int, b []byte) int {
 
 func (c *IPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) {
        var sa syscall.Sockaddr
-       n, oobn, flags, sa, err = c.fd.readMsg(b, oob)
+       n, oobn, flags, sa, err = c.fd.readMsg(b, oob, 0)
        switch sa := sa.(type) {
        case *syscall.SockaddrInet4:
                addr = &IPAddr{IP: sa.Addr[0:]}
index 49dc57c6ffeb55ea51576e7980499c2f4445cf6f..74fc1da6fd80af665696880153bfc49a508900d6 100644 (file)
@@ -268,7 +268,7 @@ func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
        return 0, nil, syscall.ENOSYS
 }
 
-func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
+func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
        return 0, 0, 0, nil, syscall.ENOSYS
 }
 
index 3b5346e5732cf87731d81eaea4e334de4c8ffd69..fcfb9c004cd7dce406c99d17fef0cbd0239d913d 100644 (file)
@@ -56,7 +56,7 @@ func (c *UDPConn) readFrom(b []byte, addr *UDPAddr) (int, *UDPAddr, error) {
 
 func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) {
        var sa syscall.Sockaddr
-       n, oobn, flags, sa, err = c.fd.readMsg(b, oob)
+       n, oobn, flags, sa, err = c.fd.readMsg(b, oob, 0)
        switch sa := sa.(type) {
        case *syscall.SockaddrInet4:
                addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port}
index 1d1f27449f56cf8aab903df94f7259b0aaff7f7c..0306b5989beabb737264864fbd3a0029a4e68788 100644 (file)
@@ -113,7 +113,11 @@ func (c *UnixConn) readFrom(b []byte) (int, *UnixAddr, error) {
 
 func (c *UnixConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) {
        var sa syscall.Sockaddr
-       n, oobn, flags, sa, err = c.fd.readMsg(b, oob)
+       n, oobn, flags, sa, err = c.fd.readMsg(b, oob, readMsgFlags)
+       if oobn > 0 {
+               setReadMsgCloseOnExec(oob[:oobn])
+       }
+
        switch sa := sa.(type) {
        case *syscall.SockaddrUnix:
                if sa.Name != "" {
diff --git a/src/net/unixsock_readmsg_linux.go b/src/net/unixsock_readmsg_linux.go
new file mode 100644 (file)
index 0000000..3296681
--- /dev/null
@@ -0,0 +1,17 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build linux
+// +build linux
+
+package net
+
+import (
+       "syscall"
+)
+
+const readMsgFlags = syscall.MSG_CMSG_CLOEXEC
+
+func setReadMsgCloseOnExec(oob []byte) {
+}
diff --git a/src/net/unixsock_readmsg_other.go b/src/net/unixsock_readmsg_other.go
new file mode 100644 (file)
index 0000000..c8db657
--- /dev/null
@@ -0,0 +1,13 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build (js && wasm) || windows
+// +build js,wasm windows
+
+package net
+
+const readMsgFlags = 0
+
+func setReadMsgCloseOnExec(oob []byte) {
+}
diff --git a/src/net/unixsock_readmsg_posix.go b/src/net/unixsock_readmsg_posix.go
new file mode 100644 (file)
index 0000000..07d7df5
--- /dev/null
@@ -0,0 +1,33 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build aix || darwin || dragonfly || freebsd || netbsd || openbsd || solaris
+// +build aix darwin dragonfly freebsd netbsd openbsd solaris
+
+package net
+
+import (
+       "syscall"
+)
+
+const readMsgFlags = 0
+
+func setReadMsgCloseOnExec(oob []byte) {
+       scms, err := syscall.ParseSocketControlMessage(oob)
+       if err != nil {
+               return
+       }
+
+       for _, scm := range scms {
+               if scm.Header.Level == syscall.SOL_SOCKET && scm.Header.Type == syscall.SCM_RIGHTS {
+                       fds, err := syscall.ParseUnixRights(&scm)
+                       if err != nil {
+                               continue
+                       }
+                       for _, fd := range fds {
+                               syscall.CloseOnExec(fd)
+                       }
+               }
+       }
+}
diff --git a/src/net/unixsock_readmsg_test.go b/src/net/unixsock_readmsg_test.go
new file mode 100644 (file)
index 0000000..4961ecb
--- /dev/null
@@ -0,0 +1,105 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
+// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
+
+package net
+
+import (
+       "os"
+       "syscall"
+       "testing"
+       "time"
+)
+
+func TestUnixConnReadMsgUnixSCMRightsCloseOnExec(t *testing.T) {
+       if !testableNetwork("unix") {
+               t.Skip("not unix system")
+       }
+
+       scmFile, err := os.Open(os.DevNull)
+       if err != nil {
+               t.Fatalf("file open: %v", err)
+       }
+       defer scmFile.Close()
+
+       rights := syscall.UnixRights(int(scmFile.Fd()))
+       fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0)
+       if err != nil {
+               t.Fatalf("Socketpair: %v", err)
+       }
+
+       writeFile := os.NewFile(uintptr(fds[0]), "write-socket")
+       defer writeFile.Close()
+       readFile := os.NewFile(uintptr(fds[1]), "read-socket")
+       defer readFile.Close()
+
+       cw, err := FileConn(writeFile)
+       if err != nil {
+               t.Fatalf("FileConn: %v", err)
+       }
+       defer cw.Close()
+       cr, err := FileConn(readFile)
+       if err != nil {
+               t.Fatalf("FileConn: %v", err)
+       }
+       defer cr.Close()
+
+       ucw, ok := cw.(*UnixConn)
+       if !ok {
+               t.Fatalf("got %T; want UnixConn", cw)
+       }
+       ucr, ok := cr.(*UnixConn)
+       if !ok {
+               t.Fatalf("got %T; want UnixConn", cr)
+       }
+
+       oob := make([]byte, syscall.CmsgSpace(4))
+       err = ucw.SetWriteDeadline(time.Now().Add(5 * time.Second))
+       if err != nil {
+               t.Fatalf("Can't set unix connection timeout: %v", err)
+       }
+       _, _, err = ucw.WriteMsgUnix(nil, rights, nil)
+       if err != nil {
+               t.Fatalf("UnixConn readMsg: %v", err)
+       }
+       err = ucr.SetReadDeadline(time.Now().Add(5 * time.Second))
+       if err != nil {
+               t.Fatalf("Can't set unix connection timeout: %v", err)
+       }
+       _, oobn, _, _, err := ucr.ReadMsgUnix(nil, oob)
+       if err != nil {
+               t.Fatalf("UnixConn readMsg: %v", err)
+       }
+
+       scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
+       if err != nil {
+               t.Fatalf("ParseSocketControlMessage: %v", err)
+       }
+       if len(scms) != 1 {
+               t.Fatalf("got scms = %#v; expected 1 SocketControlMessage", scms)
+       }
+       scm := scms[0]
+       gotFds, err := syscall.ParseUnixRights(&scm)
+       if err != nil {
+               t.Fatalf("syscall.ParseUnixRights: %v", err)
+       }
+       if len(gotFds) != 1 {
+               t.Fatalf("got FDs %#v: wanted only 1 fd", gotFds)
+       }
+       defer func() {
+               if err := syscall.Close(int(gotFds[0])); err != nil {
+                       t.Fatalf("fail to close gotFds: %v", err)
+               }
+       }()
+
+       flags, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(gotFds[0]), uintptr(syscall.F_GETFD), 0)
+       if errno != 0 {
+               t.Fatalf("Can't get flags of fd:%#v, with err:%v", gotFds[0], errno)
+       }
+       if flags&syscall.FD_CLOEXEC == 0 {
+               t.Fatalf("got flags %#x, want %#x (FD_CLOEXEC) set", flags, syscall.FD_CLOEXEC)
+       }
+}
index 736b497bc40302b9fab4155588a2aace177093b7..c1a8b516e81a408c6313afa7c7d47caf835a789c 100644 (file)
@@ -105,8 +105,8 @@ func TestSCMCredentials(t *testing.T) {
                if err != nil {
                        t.Fatalf("ReadMsgUnix: %v", err)
                }
-               if flags != 0 {
-                       t.Fatalf("ReadMsgUnix flags = 0x%x, want 0", flags)
+               if flags != syscall.MSG_CMSG_CLOEXEC {
+                       t.Fatalf("ReadMsgUnix flags = %#x, want %#x (MSG_CMSG_CLOEXEC)", flags, syscall.MSG_CMSG_CLOEXEC)
                }
                if n != tt.dataLen {
                        t.Fatalf("ReadMsgUnix n = %d, want %d", n, tt.dataLen)