]> Cypherpunks repositories - gostls13.git/commitdiff
os/user: faster user lookup on Windows
authorAlexey Borzenkov <snaury@gmail.com>
Wed, 15 May 2013 03:24:54 +0000 (13:24 +1000)
committerAlex Brainman <alex.brainman@gmail.com>
Wed, 15 May 2013 03:24:54 +0000 (13:24 +1000)
Trying to lookup user's display name with directory services can
take several seconds when user's computer is not in a domain.
As a workaround, check if computer is joined in a domain first,
and don't use directory services if it is not.
Additionally, don't leak tokens in user.Current().
Fixes #5298.

R=golang-dev, bradfitz, alex.brainman, lucio.dere
CC=golang-dev
https://golang.org/cl/8541047

src/pkg/os/user/lookup_windows.go
src/pkg/syscall/security_windows.go
src/pkg/syscall/zsyscall_windows_386.go
src/pkg/syscall/zsyscall_windows_amd64.go

index a0a8a4ec10f3fcf4c1c37b323a83cc45a58eaa43..99c325ff010c3dc18c7b828e88b094b308c4c40c 100644 (file)
@@ -10,37 +10,63 @@ import (
        "unsafe"
 )
 
-func lookupFullName(domain, username, domainAndUser string) (string, error) {
-       // try domain controller first
-       name, e := syscall.TranslateAccountName(domainAndUser,
+func isDomainJoined() (bool, error) {
+       var domain *uint16
+       var status uint32
+       err := syscall.NetGetJoinInformation(nil, &domain, &status)
+       if err != nil {
+               return false, err
+       }
+       syscall.NetApiBufferFree((*byte)(unsafe.Pointer(domain)))
+       return status == syscall.NetSetupDomainName, nil
+}
+
+func lookupFullNameDomain(domainAndUser string) (string, error) {
+       return syscall.TranslateAccountName(domainAndUser,
                syscall.NameSamCompatible, syscall.NameDisplay, 50)
+}
+
+func lookupFullNameServer(servername, username string) (string, error) {
+       s, e := syscall.UTF16PtrFromString(servername)
        if e != nil {
-               // domain lookup failed, perhaps this pc is not part of domain
-               d, e := syscall.UTF16PtrFromString(domain)
-               if e != nil {
-                       return "", e
-               }
-               u, e := syscall.UTF16PtrFromString(username)
-               if e != nil {
-                       return "", e
-               }
-               var p *byte
-               e = syscall.NetUserGetInfo(d, u, 10, &p)
-               if e != nil {
-                       // path executed when a domain user is disconnected from the domain
-                       // pretend username is fullname
-                       return username, nil
-               }
-               defer syscall.NetApiBufferFree(p)
-               i := (*syscall.UserInfo10)(unsafe.Pointer(p))
-               if i.FullName == nil {
-                       return "", nil
-               }
-               name = syscall.UTF16ToString((*[1024]uint16)(unsafe.Pointer(i.FullName))[:])
+               return "", e
        }
+       u, e := syscall.UTF16PtrFromString(username)
+       if e != nil {
+               return "", e
+       }
+       var p *byte
+       e = syscall.NetUserGetInfo(s, u, 10, &p)
+       if e != nil {
+               return "", e
+       }
+       defer syscall.NetApiBufferFree(p)
+       i := (*syscall.UserInfo10)(unsafe.Pointer(p))
+       if i.FullName == nil {
+               return "", nil
+       }
+       name := syscall.UTF16ToString((*[1024]uint16)(unsafe.Pointer(i.FullName))[:])
        return name, nil
 }
 
+func lookupFullName(domain, username, domainAndUser string) (string, error) {
+       joined, err := isDomainJoined()
+       if err == nil && joined {
+               name, err := lookupFullNameDomain(domainAndUser)
+               if err == nil {
+                       return name, nil
+               }
+       }
+       name, err := lookupFullNameServer(domain, username)
+       if err == nil {
+               return name, nil
+       }
+       // domain worked neigher as a domain nor as a server
+       // could be domain server unavailable
+       // pretend username is fullname
+       return username, nil
+}
+
 func newUser(usid *syscall.SID, gid, dir string) (*User, error) {
        username, domain, t, e := usid.LookupAccount("")
        if e != nil {
@@ -73,6 +99,7 @@ func current() (*User, error) {
        if e != nil {
                return nil, e
        }
+       defer t.Close()
        u, e := t.GetTokenUser()
        if e != nil {
                return nil, e
index 017b270146f04aa9e592cb0ae3a15d45e6b51e6a..b22ecf578e0a08b779ea3e34a19a5c8002ffe055 100644 (file)
@@ -58,6 +58,14 @@ func TranslateAccountName(username string, from, to uint32, initSize int) (strin
        return UTF16ToString(b), nil
 }
 
+const (
+       // do not reorder
+       NetSetupUnknownStatus = iota
+       NetSetupUnjoined
+       NetSetupWorkgroupName
+       NetSetupDomainName
+)
+
 type UserInfo10 struct {
        Name       *uint16
        Comment    *uint16
@@ -66,6 +74,7 @@ type UserInfo10 struct {
 }
 
 //sys  NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **byte) (neterr error) = netapi32.NetUserGetInfo
+//sys  NetGetJoinInformation(server *uint16, name **uint16, bufType *uint32) (neterr error) = netapi32.NetGetJoinInformation
 //sys  NetApiBufferFree(buf *byte) (neterr error) = netapi32.NetApiBufferFree
 
 const (
index e5c48488baf8ed92aa3e436ab2e8d3eefec39696..838812a62086926ff535ace8063314fb6b761df1 100644 (file)
@@ -140,6 +140,7 @@ var (
        procTranslateNameW                   = modsecur32.NewProc("TranslateNameW")
        procGetUserNameExW                   = modsecur32.NewProc("GetUserNameExW")
        procNetUserGetInfo                   = modnetapi32.NewProc("NetUserGetInfo")
+       procNetGetJoinInformation            = modnetapi32.NewProc("NetGetJoinInformation")
        procNetApiBufferFree                 = modnetapi32.NewProc("NetApiBufferFree")
        procLookupAccountSidW                = modadvapi32.NewProc("LookupAccountSidW")
        procLookupAccountNameW               = modadvapi32.NewProc("LookupAccountNameW")
@@ -1613,6 +1614,14 @@ func NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **by
        return
 }
 
+func NetGetJoinInformation(server *uint16, name **uint16, bufType *uint32) (neterr error) {
+       r0, _, _ := Syscall(procNetGetJoinInformation.Addr(), 3, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(bufType)))
+       if r0 != 0 {
+               neterr = Errno(r0)
+       }
+       return
+}
+
 func NetApiBufferFree(buf *byte) (neterr error) {
        r0, _, _ := Syscall(procNetApiBufferFree.Addr(), 1, uintptr(unsafe.Pointer(buf)), 0, 0)
        if r0 != 0 {
index 465b509ae7ae9eae509f68be54ba4ae39e56dbb8..1b403be495591f347e751da2fead46301e22db95 100644 (file)
@@ -140,6 +140,7 @@ var (
        procTranslateNameW                   = modsecur32.NewProc("TranslateNameW")
        procGetUserNameExW                   = modsecur32.NewProc("GetUserNameExW")
        procNetUserGetInfo                   = modnetapi32.NewProc("NetUserGetInfo")
+       procNetGetJoinInformation            = modnetapi32.NewProc("NetGetJoinInformation")
        procNetApiBufferFree                 = modnetapi32.NewProc("NetApiBufferFree")
        procLookupAccountSidW                = modadvapi32.NewProc("LookupAccountSidW")
        procLookupAccountNameW               = modadvapi32.NewProc("LookupAccountNameW")
@@ -1613,6 +1614,14 @@ func NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **by
        return
 }
 
+func NetGetJoinInformation(server *uint16, name **uint16, bufType *uint32) (neterr error) {
+       r0, _, _ := Syscall(procNetGetJoinInformation.Addr(), 3, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(bufType)))
+       if r0 != 0 {
+               neterr = Errno(r0)
+       }
+       return
+}
+
 func NetApiBufferFree(buf *byte) (neterr error) {
        r0, _, _ := Syscall(procNetApiBufferFree.Addr(), 1, uintptr(unsafe.Pointer(buf)), 0, 0)
        if r0 != 0 {