]> Cypherpunks repositories - gostls13.git/commitdiff
os: make MkdirAll support volume names
authorqmuntal <quimmuntal@gmail.com>
Tue, 8 Aug 2023 13:31:43 +0000 (15:31 +0200)
committerQuim Muntal <quimmuntal@gmail.com>
Wed, 9 Aug 2023 15:15:57 +0000 (15:15 +0000)
MkdirAll fails to create directories under root paths using volume
names (e.g. //?/Volume{GUID}/foo). This is because fixRootDirectory
only handle extended length paths using drive letters (e.g. //?/C:/foo).

This CL fixes that issue by also detecting volume names without path
separator.

Updates #22230
Fixes #39785

Change-Id: I813fdc0b968ce71a4297f69245b935558e6cd789
Reviewed-on: https://go-review.googlesource.com/c/go/+/517015
Run-TryBot: Quim Muntal <quimmuntal@gmail.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Bryan Mills <bcmills@google.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
src/internal/syscall/windows/syscall_windows.go
src/internal/syscall/windows/zsyscall_windows.go
src/os/path.go
src/os/path_plan9.go
src/os/path_unix.go
src/os/path_windows.go
src/os/path_windows_test.go

index 924dd1e121241cd7de1a8aad39dbae30d00956c6..ad36bd48a688bda045aeba5271b76d29801e3b76 100644 (file)
@@ -398,6 +398,7 @@ type FILE_ID_BOTH_DIR_INFO struct {
 }
 
 //sys  GetVolumeInformationByHandle(file syscall.Handle, volumeNameBuffer *uint16, volumeNameSize uint32, volumeNameSerialNumber *uint32, maximumComponentLength *uint32, fileSystemFlags *uint32, fileSystemNameBuffer *uint16, fileSystemNameSize uint32) (err error) = GetVolumeInformationByHandleW
+//sys  GetVolumeNameForVolumeMountPoint(volumeMountPoint *uint16, volumeName *uint16, bufferlength uint32) (err error) = GetVolumeNameForVolumeMountPointW
 
 //sys  RtlLookupFunctionEntry(pc uintptr, baseAddress *uintptr, table *byte) (ret uintptr) = kernel32.RtlLookupFunctionEntry
 //sys  RtlVirtualUnwind(handlerType uint32, baseAddress uintptr, pc uintptr, entry uintptr, ctxt uintptr, data *uintptr, frame *uintptr, ctxptrs *byte) (ret uintptr) = kernel32.RtlVirtualUnwind
index fb87bd03a27bb1bdb995837c659f390e5f2b951e..e3f6d8d2a2208cd7a791fbb776914632fbf2103f 100644 (file)
@@ -45,46 +45,47 @@ var (
        moduserenv  = syscall.NewLazyDLL(sysdll.Add("userenv.dll"))
        modws2_32   = syscall.NewLazyDLL(sysdll.Add("ws2_32.dll"))
 
-       procAdjustTokenPrivileges         = modadvapi32.NewProc("AdjustTokenPrivileges")
-       procDuplicateTokenEx              = modadvapi32.NewProc("DuplicateTokenEx")
-       procImpersonateSelf               = modadvapi32.NewProc("ImpersonateSelf")
-       procLookupPrivilegeValueW         = modadvapi32.NewProc("LookupPrivilegeValueW")
-       procOpenSCManagerW                = modadvapi32.NewProc("OpenSCManagerW")
-       procOpenServiceW                  = modadvapi32.NewProc("OpenServiceW")
-       procOpenThreadToken               = modadvapi32.NewProc("OpenThreadToken")
-       procQueryServiceStatus            = modadvapi32.NewProc("QueryServiceStatus")
-       procRevertToSelf                  = modadvapi32.NewProc("RevertToSelf")
-       procSetTokenInformation           = modadvapi32.NewProc("SetTokenInformation")
-       procSystemFunction036             = modadvapi32.NewProc("SystemFunction036")
-       procGetAdaptersAddresses          = modiphlpapi.NewProc("GetAdaptersAddresses")
-       procCreateEventW                  = modkernel32.NewProc("CreateEventW")
-       procGetACP                        = modkernel32.NewProc("GetACP")
-       procGetComputerNameExW            = modkernel32.NewProc("GetComputerNameExW")
-       procGetConsoleCP                  = modkernel32.NewProc("GetConsoleCP")
-       procGetCurrentThread              = modkernel32.NewProc("GetCurrentThread")
-       procGetFileInformationByHandleEx  = modkernel32.NewProc("GetFileInformationByHandleEx")
-       procGetFinalPathNameByHandleW     = modkernel32.NewProc("GetFinalPathNameByHandleW")
-       procGetModuleFileNameW            = modkernel32.NewProc("GetModuleFileNameW")
-       procGetTempPath2W                 = modkernel32.NewProc("GetTempPath2W")
-       procGetVolumeInformationByHandleW = modkernel32.NewProc("GetVolumeInformationByHandleW")
-       procLockFileEx                    = modkernel32.NewProc("LockFileEx")
-       procModule32FirstW                = modkernel32.NewProc("Module32FirstW")
-       procModule32NextW                 = modkernel32.NewProc("Module32NextW")
-       procMoveFileExW                   = modkernel32.NewProc("MoveFileExW")
-       procMultiByteToWideChar           = modkernel32.NewProc("MultiByteToWideChar")
-       procRtlLookupFunctionEntry        = modkernel32.NewProc("RtlLookupFunctionEntry")
-       procRtlVirtualUnwind              = modkernel32.NewProc("RtlVirtualUnwind")
-       procSetFileInformationByHandle    = modkernel32.NewProc("SetFileInformationByHandle")
-       procUnlockFileEx                  = modkernel32.NewProc("UnlockFileEx")
-       procVirtualQuery                  = modkernel32.NewProc("VirtualQuery")
-       procNetShareAdd                   = modnetapi32.NewProc("NetShareAdd")
-       procNetShareDel                   = modnetapi32.NewProc("NetShareDel")
-       procNetUserGetLocalGroups         = modnetapi32.NewProc("NetUserGetLocalGroups")
-       procGetProcessMemoryInfo          = modpsapi.NewProc("GetProcessMemoryInfo")
-       procCreateEnvironmentBlock        = moduserenv.NewProc("CreateEnvironmentBlock")
-       procDestroyEnvironmentBlock       = moduserenv.NewProc("DestroyEnvironmentBlock")
-       procGetProfilesDirectoryW         = moduserenv.NewProc("GetProfilesDirectoryW")
-       procWSASocketW                    = modws2_32.NewProc("WSASocketW")
+       procAdjustTokenPrivileges             = modadvapi32.NewProc("AdjustTokenPrivileges")
+       procDuplicateTokenEx                  = modadvapi32.NewProc("DuplicateTokenEx")
+       procImpersonateSelf                   = modadvapi32.NewProc("ImpersonateSelf")
+       procLookupPrivilegeValueW             = modadvapi32.NewProc("LookupPrivilegeValueW")
+       procOpenSCManagerW                    = modadvapi32.NewProc("OpenSCManagerW")
+       procOpenServiceW                      = modadvapi32.NewProc("OpenServiceW")
+       procOpenThreadToken                   = modadvapi32.NewProc("OpenThreadToken")
+       procQueryServiceStatus                = modadvapi32.NewProc("QueryServiceStatus")
+       procRevertToSelf                      = modadvapi32.NewProc("RevertToSelf")
+       procSetTokenInformation               = modadvapi32.NewProc("SetTokenInformation")
+       procSystemFunction036                 = modadvapi32.NewProc("SystemFunction036")
+       procGetAdaptersAddresses              = modiphlpapi.NewProc("GetAdaptersAddresses")
+       procCreateEventW                      = modkernel32.NewProc("CreateEventW")
+       procGetACP                            = modkernel32.NewProc("GetACP")
+       procGetComputerNameExW                = modkernel32.NewProc("GetComputerNameExW")
+       procGetConsoleCP                      = modkernel32.NewProc("GetConsoleCP")
+       procGetCurrentThread                  = modkernel32.NewProc("GetCurrentThread")
+       procGetFileInformationByHandleEx      = modkernel32.NewProc("GetFileInformationByHandleEx")
+       procGetFinalPathNameByHandleW         = modkernel32.NewProc("GetFinalPathNameByHandleW")
+       procGetModuleFileNameW                = modkernel32.NewProc("GetModuleFileNameW")
+       procGetTempPath2W                     = modkernel32.NewProc("GetTempPath2W")
+       procGetVolumeInformationByHandleW     = modkernel32.NewProc("GetVolumeInformationByHandleW")
+       procGetVolumeNameForVolumeMountPointW = modkernel32.NewProc("GetVolumeNameForVolumeMountPointW")
+       procLockFileEx                        = modkernel32.NewProc("LockFileEx")
+       procModule32FirstW                    = modkernel32.NewProc("Module32FirstW")
+       procModule32NextW                     = modkernel32.NewProc("Module32NextW")
+       procMoveFileExW                       = modkernel32.NewProc("MoveFileExW")
+       procMultiByteToWideChar               = modkernel32.NewProc("MultiByteToWideChar")
+       procRtlLookupFunctionEntry            = modkernel32.NewProc("RtlLookupFunctionEntry")
+       procRtlVirtualUnwind                  = modkernel32.NewProc("RtlVirtualUnwind")
+       procSetFileInformationByHandle        = modkernel32.NewProc("SetFileInformationByHandle")
+       procUnlockFileEx                      = modkernel32.NewProc("UnlockFileEx")
+       procVirtualQuery                      = modkernel32.NewProc("VirtualQuery")
+       procNetShareAdd                       = modnetapi32.NewProc("NetShareAdd")
+       procNetShareDel                       = modnetapi32.NewProc("NetShareDel")
+       procNetUserGetLocalGroups             = modnetapi32.NewProc("NetUserGetLocalGroups")
+       procGetProcessMemoryInfo              = modpsapi.NewProc("GetProcessMemoryInfo")
+       procCreateEnvironmentBlock            = moduserenv.NewProc("CreateEnvironmentBlock")
+       procDestroyEnvironmentBlock           = moduserenv.NewProc("DestroyEnvironmentBlock")
+       procGetProfilesDirectoryW             = moduserenv.NewProc("GetProfilesDirectoryW")
+       procWSASocketW                        = modws2_32.NewProc("WSASocketW")
 )
 
 func adjustTokenPrivileges(token syscall.Token, disableAllPrivileges bool, newstate *TOKEN_PRIVILEGES, buflen uint32, prevstate *TOKEN_PRIVILEGES, returnlen *uint32) (ret uint32, err error) {
@@ -279,6 +280,14 @@ func GetVolumeInformationByHandle(file syscall.Handle, volumeNameBuffer *uint16,
        return
 }
 
+func GetVolumeNameForVolumeMountPoint(volumeMountPoint *uint16, volumeName *uint16, bufferlength uint32) (err error) {
+       r1, _, e1 := syscall.Syscall(procGetVolumeNameForVolumeMountPointW.Addr(), 3, uintptr(unsafe.Pointer(volumeMountPoint)), uintptr(unsafe.Pointer(volumeName)), uintptr(bufferlength))
+       if r1 == 0 {
+               err = errnoErr(e1)
+       }
+       return
+}
+
 func LockFileEx(file syscall.Handle, flags uint32, reserved uint32, bytesLow uint32, bytesHigh uint32, overlapped *syscall.Overlapped) (err error) {
        r1, _, e1 := syscall.Syscall6(procLockFileEx.Addr(), 6, uintptr(file), uintptr(flags), uintptr(reserved), uintptr(bytesLow), uintptr(bytesHigh), uintptr(unsafe.Pointer(overlapped)))
        if r1 == 0 {
index df87887b9be1aa1301cbc5765f5d2e91ae5337f9..6ac4cbe20f78d29eb1ce4d98e669303160921205 100644 (file)
@@ -26,19 +26,25 @@ func MkdirAll(path string, perm FileMode) error {
        }
 
        // Slow path: make sure parent exists and then call Mkdir for path.
-       i := len(path)
-       for i > 0 && IsPathSeparator(path[i-1]) { // Skip trailing path separator.
+
+       // Extract the parent folder from path by first removing any trailing
+       // path separator and then scanning backward until finding a path
+       // separator or reaching the beginning of the string.
+       i := len(path) - 1
+       for i >= 0 && IsPathSeparator(path[i]) {
                i--
        }
-
-       j := i
-       for j > 0 && !IsPathSeparator(path[j-1]) { // Scan backward over element.
-               j--
+       for i >= 0 && !IsPathSeparator(path[i]) {
+               i--
+       }
+       if i < 0 {
+               i = 0
        }
 
-       if j > 1 {
-               // Create parent.
-               err = MkdirAll(fixRootDirectory(path[:j-1]), perm)
+       // If there is a parent directory, and it is not the volume name,
+       // recurse to ensure parent directory exists.
+       if parent := path[:i]; len(parent) > len(volumeName(path)) {
+               err = MkdirAll(parent, perm)
                if err != nil {
                        return err
                }
index a54b4b98f1220a97cdb54c99515a8d5685aa4abb..f1c9dbc048c1f5ff4f64f708f2473b71452fc807 100644 (file)
@@ -14,6 +14,6 @@ func IsPathSeparator(c uint8) bool {
        return PathSeparator == c
 }
 
-func fixRootDirectory(p string) string {
-       return p
+func volumeName(p string) string {
+       return ""
 }
index c975cdb11e00ac8ca942ff023d28e9e647dc97f0..1c80fa91f8be4c973f797806308b29fb3cb157f9 100644 (file)
@@ -70,6 +70,6 @@ func splitPath(path string) (string, string) {
        return dirname, basename
 }
 
-func fixRootDirectory(p string) string {
-       return p
+func volumeName(p string) string {
+       return ""
 }
index 3356908a3609ec9591758a76e7afa36633bea26b..ec9a87274de6d882d84efd88e2b4397bdd9f8a53 100644 (file)
@@ -214,14 +214,3 @@ func fixLongPath(path string) string {
        }
        return string(pathbuf[:w])
 }
-
-// fixRootDirectory fixes a reference to a drive's root directory to
-// have the required trailing slash.
-func fixRootDirectory(p string) string {
-       if len(p) == len(`\\?\c:`) {
-               if IsPathSeparator(p[0]) && IsPathSeparator(p[1]) && p[2] == '?' && IsPathSeparator(p[3]) && p[5] == ':' {
-                       return p + `\`
-               }
-       }
-       return p
-}
index 2506b4f0d8dca760a0c9ee868b82dca49eb21fe2..4e5e501d1f124cfc06bd3bc39c546576e3a68ec6 100644 (file)
@@ -5,7 +5,11 @@
 package os_test
 
 import (
+       "fmt"
+       "internal/syscall/windows"
+       "internal/testenv"
        "os"
+       "path/filepath"
        "strings"
        "syscall"
        "testing"
@@ -106,3 +110,48 @@ func TestOpenRootSlash(t *testing.T) {
                dir.Close()
        }
 }
+
+func testMkdirAllAtRoot(t *testing.T, root string) {
+       // Create a unique-enough directory name in root.
+       base := fmt.Sprintf("%s-%d", t.Name(), os.Getpid())
+       path := filepath.Join(root, base)
+       if err := os.MkdirAll(path, 0777); err != nil {
+               t.Fatalf("MkdirAll(%q) failed: %v", path, err)
+       }
+       // Clean up
+       if err := os.RemoveAll(path); err != nil {
+               t.Fatal(err)
+       }
+}
+
+func TestMkdirAllExtendedLengthAtRoot(t *testing.T) {
+       if testenv.Builder() == "" {
+               t.Skipf("skipping non-hermetic test outside of Go builders")
+       }
+
+       const prefix = `\\?\`
+       vol := filepath.VolumeName(t.TempDir()) + `\`
+       if len(vol) < 4 || vol[:4] != prefix {
+               vol = prefix + vol
+       }
+       testMkdirAllAtRoot(t, vol)
+}
+
+func TestMkdirAllVolumeNameAtRoot(t *testing.T) {
+       if testenv.Builder() == "" {
+               t.Skipf("skipping non-hermetic test outside of Go builders")
+       }
+
+       vol, err := syscall.UTF16PtrFromString(filepath.VolumeName(t.TempDir()) + `\`)
+       if err != nil {
+               t.Fatal(err)
+       }
+       const maxVolNameLen = 50
+       var buf [maxVolNameLen]uint16
+       err = windows.GetVolumeNameForVolumeMountPoint(vol, &buf[0], maxVolNameLen)
+       if err != nil {
+               t.Fatal(err)
+       }
+       volName := syscall.UTF16ToString(buf[:])
+       testMkdirAllAtRoot(t, volName)
+}