]> Cypherpunks repositories - gostls13.git/commitdiff
os/user: fix Current().GroupIds() for AD joined users on Windows
authorqmuntal <quimmuntal@gmail.com>
Fri, 6 Sep 2024 13:51:44 +0000 (15:51 +0200)
committerQuim Muntal <quimmuntal@gmail.com>
Mon, 16 Sep 2024 20:48:56 +0000 (20:48 +0000)
This CL special-case User.GroupIds to get the group IDs from the user's
token when the user is the current user.

This approach is more efficient than calling NetUserGetLocalGroups.
It is also more reliable for users joined to an Active Directory domain,
where NetUserGetLocalGroups is likely to fail.

Updates #26041.
Fixes #62712.

Cq-Include-Trybots: luci.golang.try:gotip-windows-arm64
Change-Id: If7c30287192872077b98a514bd6346dbd1a64fb4
Reviewed-on: https://go-review.googlesource.com/c/go/+/611116
Reviewed-by: Carlos Amedee <carlos@golang.org>
Reviewed-by: Alex Brainman <alex.brainman@gmail.com>
Reviewed-by: Tim King <taking@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/internal/syscall/windows/security_windows.go
src/os/user/lookup_windows.go

index e528744caad1bd85f33937d811b1542e6cca72d4..aed04c61c49269e8cc5710a245c7bac27211a52c 100644 (file)
@@ -175,3 +175,38 @@ func GetUserName(format uint32) (string, error) {
                }
        }
 }
+
+// getTokenInfo retrieves a specified type of information about an access token.
+func getTokenInfo(t syscall.Token, class uint32, initSize int) (unsafe.Pointer, error) {
+       n := uint32(initSize)
+       for {
+               b := make([]byte, n)
+               e := syscall.GetTokenInformation(t, class, &b[0], uint32(len(b)), &n)
+               if e == nil {
+                       return unsafe.Pointer(&b[0]), nil
+               }
+               if e != syscall.ERROR_INSUFFICIENT_BUFFER {
+                       return nil, e
+               }
+               if n <= uint32(len(b)) {
+                       return nil, e
+               }
+       }
+}
+
+type TOKEN_GROUPS struct {
+       GroupCount uint32
+       Groups     [1]SID_AND_ATTRIBUTES
+}
+
+func (g *TOKEN_GROUPS) AllGroups() []SID_AND_ATTRIBUTES {
+       return (*[(1 << 28) - 1]SID_AND_ATTRIBUTES)(unsafe.Pointer(&g.Groups[0]))[:g.GroupCount:g.GroupCount]
+}
+
+func GetTokenGroups(t syscall.Token) (*TOKEN_GROUPS, error) {
+       i, e := getTokenInfo(t, syscall.TokenGroups, 50)
+       if e != nil {
+               return nil, e
+       }
+       return (*TOKEN_GROUPS)(i), nil
+}
index 804fc64cc1bda9b749c20807248def6e71ad0a7d..5d990600657f43201d638daff46d08b50522233a 100644 (file)
@@ -434,17 +434,45 @@ func lookupGroupId(gid string) (*Group, error) {
 }
 
 func listGroups(user *User) ([]string, error) {
-       sid, err := syscall.StringToSid(user.Uid)
-       if err != nil {
-               return nil, err
-       }
-       username, domain, err := lookupUsernameAndDomain(sid)
-       if err != nil {
-               return nil, err
-       }
-       sids, err := listGroupsForUsernameAndDomain(username, domain)
-       if err != nil {
-               return nil, err
+       var sids []string
+       if u, err := Current(); err == nil && u.Uid == user.Uid {
+               // It is faster and more reliable to get the groups
+               // of the current user from the current process token.
+               err := runAsProcessOwner(func() error {
+                       t, err := syscall.OpenCurrentProcessToken()
+                       if err != nil {
+                               return err
+                       }
+                       defer t.Close()
+                       groups, err := windows.GetTokenGroups(t)
+                       if err != nil {
+                               return err
+                       }
+                       for _, g := range groups.AllGroups() {
+                               sid, err := g.Sid.String()
+                               if err != nil {
+                                       return err
+                               }
+                               sids = append(sids, sid)
+                       }
+                       return nil
+               })
+               if err != nil {
+                       return nil, err
+               }
+       } else {
+               sid, err := syscall.StringToSid(user.Uid)
+               if err != nil {
+                       return nil, err
+               }
+               username, domain, err := lookupUsernameAndDomain(sid)
+               if err != nil {
+                       return nil, err
+               }
+               sids, err = listGroupsForUsernameAndDomain(username, domain)
+               if err != nil {
+                       return nil, err
+               }
        }
        // Add the primary group of the user to the list if it is not already there.
        // This is done only to comply with the POSIX concept of a primary group.