]> Cypherpunks repositories - gostls13.git/commitdiff
os/user: add LookupGroup, LookupGroupId, and User.GroupIds functions
authorRoss Light <light@google.com>
Thu, 4 Feb 2016 23:39:00 +0000 (15:39 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 7 Mar 2016 20:11:19 +0000 (20:11 +0000)
As part of local testing with a large group member list, I discovered
that the lookup functions don't resize their buffer if they receive
ERANGE.  I fixed this as a side-effect of this CL.

Thanks to @andrenth for the original CL.

Fixes #2617

Change-Id: Ie6aae2fe0a89eae5cce85786869a8acaa665ffe9
Reviewed-on: https://go-review.googlesource.com/19235
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/os/user/lookup.go
src/os/user/lookup_android.go [new file with mode: 0644]
src/os/user/lookup_plan9.go
src/os/user/lookup_stubs.go
src/os/user/lookup_unix.go
src/os/user/lookup_windows.go
src/os/user/user.go
src/os/user/user_test.go

index 09f00c7bdb4ce534e5167df1c3eb1f446cfd03d8..3b4421badd5d336bfebe265f2a1607ddf4337c52 100644 (file)
@@ -12,11 +12,28 @@ func Current() (*User, error) {
 // Lookup looks up a user by username. If the user cannot be found, the
 // returned error is of type UnknownUserError.
 func Lookup(username string) (*User, error) {
-       return lookup(username)
+       return lookupUser(username)
 }
 
 // LookupId looks up a user by userid. If the user cannot be found, the
 // returned error is of type UnknownUserIdError.
 func LookupId(uid string) (*User, error) {
-       return lookupId(uid)
+       return lookupUserId(uid)
+}
+
+// LookupGroup looks up a group by name. If the group cannot be found, the
+// returned error is of type UnknownGroupError.
+func LookupGroup(name string) (*Group, error) {
+       return lookupGroup(name)
+}
+
+// LookupGroupId looks up a group by groupid. If the group cannot be found, the
+// returned error is of type UnknownGroupIdError.
+func LookupGroupId(gid string) (*Group, error) {
+       return lookupGroupId(gid)
+}
+
+// GroupIds returns the list of group IDs that the user is a member of.
+func (u *User) GroupIds() ([]string, error) {
+       return listGroups(u)
 }
diff --git a/src/os/user/lookup_android.go b/src/os/user/lookup_android.go
new file mode 100644 (file)
index 0000000..b1be3dc
--- /dev/null
@@ -0,0 +1,38 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build android
+
+package user
+
+import "errors"
+
+func init() {
+       userImplemented = false
+       groupImplemented = false
+}
+
+func current() (*User, error) {
+       return nil, errors.New("user: Current not implemented on android")
+}
+
+func lookupUser(string) (*User, error) {
+       return nil, errors.New("user: Lookup not implemented on android")
+}
+
+func lookupUserId(string) (*User, error) {
+       return nil, errors.New("user: LookupId not implemented on android")
+}
+
+func lookupGroup(string) (*Group, error) {
+       return nil, errors.New("user: LookupGroup not implemented on android")
+}
+
+func lookupGroupId(string) (*Group, error) {
+       return nil, errors.New("user: LookupGroupId not implemented on android")
+}
+
+func listGroups(*User) ([]string, error) {
+       return nil, errors.New("user: GroupIds not implemented on android")
+}
index f7ef3482b7eada90cb797d0bcc9c141302a9397f..ea3ce0bc7cfa52b40fd3b6b131ff7a242fe4d8be 100644 (file)
@@ -18,6 +18,10 @@ const (
        userFile = "/dev/user"
 )
 
+func init() {
+       groupImplemented = false
+}
+
 func current() (*User, error) {
        ubytes, err := ioutil.ReadFile(userFile)
        if err != nil {
@@ -37,10 +41,22 @@ func current() (*User, error) {
        return u, nil
 }
 
-func lookup(username string) (*User, error) {
+func lookupUser(username string) (*User, error) {
+       return nil, syscall.EPLAN9
+}
+
+func lookupUserId(uid string) (*User, error) {
+       return nil, syscall.EPLAN9
+}
+
+func lookupGroup(groupname string) (*Group, error) {
+       return nil, syscall.EPLAN9
+}
+
+func lookupGroupId(string) (*Group, error) {
        return nil, syscall.EPLAN9
 }
 
-func lookupId(uid string) (*User, error) {
+func listGroups(*User) ([]string, error) {
        return nil, syscall.EPLAN9
 }
index 4fb0e3c6edd865f90dc81d141480c960a2e7cfbe..92391d7074e3ebfc6c77dc618124fa9e886a0ed5 100644 (file)
@@ -2,27 +2,37 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// +build !cgo,!windows,!plan9 android
+// +build !cgo,!windows,!plan9,!android
 
 package user
 
-import (
-       "fmt"
-       "runtime"
-)
+import "errors"
 
 func init() {
-       implemented = false
+       userImplemented = false
+       groupImplemented = false
 }
 
 func current() (*User, error) {
-       return nil, fmt.Errorf("user: Current not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
+       return nil, errors.New("user: Current requires cgo")
 }
 
-func lookup(username string) (*User, error) {
-       return nil, fmt.Errorf("user: Lookup not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
+func lookupUser(username string) (*User, error) {
+       return nil, errors.New("user: Lookup requires cgo")
 }
 
-func lookupId(uid string) (*User, error) {
-       return nil, fmt.Errorf("user: LookupId not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
+func lookupUserId(uid string) (*User, error) {
+       return nil, errors.New("user: LookupId requires cgo")
+}
+
+func lookupGroup(groupname string) (*Group, error) {
+       return nil, errors.New("user: LookupGroup requires cgo")
+}
+
+func lookupGroupId(string) (*Group, error) {
+       return nil, errors.New("user: LookupGroupId requires cgo")
+}
+
+func listGroups(*User) ([]string, error) {
+       return nil, errors.New("user: GroupIds requires cgo")
 }
index 87ad1e7427f1780d8f9ba873e5fa16185d29bbff..52d57c3884dbb53d960afa8c68897b3a7121e7d7 100644 (file)
@@ -20,6 +20,7 @@ import (
 #include <unistd.h>
 #include <sys/types.h>
 #include <pwd.h>
+#include <grp.h>
 #include <stdlib.h>
 
 static int mygetpwuid_r(int uid, struct passwd *pwd,
@@ -31,76 +32,119 @@ static int mygetpwnam_r(const char *name, struct passwd *pwd,
        char *buf, size_t buflen, struct passwd **result) {
        return getpwnam_r(name, pwd, buf, buflen, result);
 }
+
+static int mygetgrgid_r(int gid, struct group *grp,
+       char *buf, size_t buflen, struct group **result) {
+ return getgrgid_r(gid, grp, buf, buflen, result);
+}
+
+static int mygetgrouplist(const char *user, gid_t group, gid_t *groups,
+       int *ngroups) {
+ return getgrouplist(user, group, groups, ngroups);
+}
 */
 import "C"
 
 func current() (*User, error) {
-       return lookupUnix(syscall.Getuid(), "", false)
+       return lookupUnixUid(syscall.Getuid())
 }
 
-func lookup(username string) (*User, error) {
-       return lookupUnix(-1, username, true)
+func lookupUser(username string) (*User, error) {
+       var pwd C.struct_passwd
+       var result *C.struct_passwd
+       nameC := C.CString(username)
+       defer C.free(unsafe.Pointer(nameC))
+
+       buf := alloc(userBuffer)
+       defer buf.free()
+
+       err := retryWithBuffer(buf, func() syscall.Errno {
+               // mygetpwnam_r is a wrapper around getpwnam_r to avoid
+               // passing a size_t to getpwnam_r, because for unknown
+               // reasons passing a size_t to getpwnam_r doesn't work on
+               // Solaris.
+               return syscall.Errno(C.mygetpwnam_r(nameC,
+                       &pwd,
+                       (*C.char)(buf.ptr),
+                       C.size_t(buf.size),
+                       &result))
+       })
+       if err != nil {
+               return nil, fmt.Errorf("user: lookup username %s: %v", username, err)
+       }
+       if result == nil {
+               return nil, UnknownUserError(username)
+       }
+       return buildUser(&pwd), err
 }
 
-func lookupId(uid string) (*User, error) {
+func lookupUserId(uid string) (*User, error) {
        i, e := strconv.Atoi(uid)
        if e != nil {
                return nil, e
        }
-       return lookupUnix(i, "", false)
+       return lookupUnixUid(i)
 }
 
-func lookupUnix(uid int, username string, lookupByName bool) (*User, error) {
+func lookupUnixUid(uid int) (*User, error) {
        var pwd C.struct_passwd
        var result *C.struct_passwd
 
-       bufSize := C.sysconf(C._SC_GETPW_R_SIZE_MAX)
-       if bufSize == -1 {
-               // DragonFly and FreeBSD do not have _SC_GETPW_R_SIZE_MAX.
-               // Additionally, not all Linux systems have it, either. For
-               // example, the musl libc returns -1.
-               bufSize = 1024
-       }
-       if bufSize <= 0 || bufSize > 1<<20 {
-               return nil, fmt.Errorf("user: unreasonable _SC_GETPW_R_SIZE_MAX of %d", bufSize)
-       }
-       buf := C.malloc(C.size_t(bufSize))
-       defer C.free(buf)
-       var rv C.int
-       if lookupByName {
-               nameC := C.CString(username)
-               defer C.free(unsafe.Pointer(nameC))
-               // mygetpwnam_r is a wrapper around getpwnam_r to avoid
-               // passing a size_t to getpwnam_r, because for unknown
-               // reasons passing a size_t to getpwnam_r doesn't work on
-               // Solaris.
-               rv = C.mygetpwnam_r(nameC,
-                       &pwd,
-                       (*C.char)(buf),
-                       C.size_t(bufSize),
-                       &result)
-               if rv != 0 {
-                       return nil, fmt.Errorf("user: lookup username %s: %s", username, syscall.Errno(rv))
-               }
-               if result == nil {
-                       return nil, UnknownUserError(username)
-               }
-       } else {
+       buf := alloc(userBuffer)
+       defer buf.free()
+
+       err := retryWithBuffer(buf, func() syscall.Errno {
                // mygetpwuid_r is a wrapper around getpwuid_r to
                // to avoid using uid_t because C.uid_t(uid) for
                // unknown reasons doesn't work on linux.
-               rv = C.mygetpwuid_r(C.int(uid),
+               return syscall.Errno(C.mygetpwuid_r(C.int(uid),
                        &pwd,
-                       (*C.char)(buf),
-                       C.size_t(bufSize),
-                       &result)
-               if rv != 0 {
-                       return nil, fmt.Errorf("user: lookup userid %d: %s", uid, syscall.Errno(rv))
+                       (*C.char)(buf.ptr),
+                       C.size_t(buf.size),
+                       &result))
+       })
+       if err != nil {
+               return nil, fmt.Errorf("user: lookup userid %d: %v", uid, err)
+       }
+       if result == nil {
+               return nil, UnknownUserIdError(uid)
+       }
+       return buildUser(&pwd), nil
+}
+
+func listGroups(u *User) ([]string, error) {
+       ug, err := strconv.Atoi(u.Gid)
+       if err != nil {
+               return nil, fmt.Errorf("user: list groups for %s: invalid gid %q", u.Username, u.Gid)
+       }
+       userGID := C.gid_t(ug)
+       nameC := C.CString(u.Username)
+       defer C.free(unsafe.Pointer(nameC))
+
+       n := C.int(256)
+       gidsC := make([]C.gid_t, n)
+       rv := C.mygetgrouplist(nameC, userGID, &gidsC[0], &n)
+       if rv == -1 {
+               // More than initial buffer, but now n contains the correct size.
+               const maxGroups = 2048
+               if n > maxGroups {
+                       return nil, fmt.Errorf("user: list groups for %s: member of more than %d groups", u.Username, maxGroups)
                }
-               if result == nil {
-                       return nil, UnknownUserIdError(uid)
+               gidsC = make([]C.gid_t, n)
+               rv := C.mygetgrouplist(nameC, userGID, &gidsC[0], &n)
+               if rv == -1 {
+                       return nil, fmt.Errorf("user: list groups for %s failed (changed groups?)", u.Username)
                }
        }
+       gidsC = gidsC[:n]
+       gids := make([]string, 0, n)
+       for _, g := range gidsC[:n] {
+               gids = append(gids, strconv.Itoa(int(g)))
+       }
+       return gids, nil
+}
+
+func buildUser(pwd *C.struct_passwd) *User {
        u := &User{
                Uid:      strconv.Itoa(int(pwd.pw_uid)),
                Gid:      strconv.Itoa(int(pwd.pw_gid)),
@@ -115,5 +159,145 @@ func lookupUnix(uid int, username string, lookupByName bool) (*User, error) {
        if i := strings.Index(u.Name, ","); i >= 0 {
                u.Name = u.Name[:i]
        }
-       return u, nil
+       return u
+}
+
+func currentGroup() (*Group, error) {
+       return lookupUnixGid(syscall.Getgid())
+}
+
+func lookupGroup(groupname string) (*Group, error) {
+       var grp C.struct_group
+       var result *C.struct_group
+
+       buf := alloc(groupBuffer)
+       defer buf.free()
+       cname := C.CString(groupname)
+       defer C.free(unsafe.Pointer(cname))
+
+       err := retryWithBuffer(buf, func() syscall.Errno {
+               return syscall.Errno(C.getgrnam_r(cname,
+                       &grp,
+                       (*C.char)(buf.ptr),
+                       C.size_t(buf.size),
+                       &result))
+       })
+       if err != nil {
+               return nil, fmt.Errorf("user: lookup groupname %s: %v", groupname, err)
+       }
+       if result == nil {
+               return nil, UnknownGroupError(groupname)
+       }
+       return buildGroup(&grp), nil
+}
+
+func lookupGroupId(gid string) (*Group, error) {
+       i, e := strconv.Atoi(gid)
+       if e != nil {
+               return nil, e
+       }
+       return lookupUnixGid(i)
+}
+
+func lookupUnixGid(gid int) (*Group, error) {
+       var grp C.struct_group
+       var result *C.struct_group
+
+       buf := alloc(groupBuffer)
+       defer buf.free()
+
+       err := retryWithBuffer(buf, func() syscall.Errno {
+               // mygetgrgid_r is a wrapper around getgrgid_r to
+               // to avoid using gid_t because C.gid_t(gid) for
+               // unknown reasons doesn't work on linux.
+               return syscall.Errno(C.mygetgrgid_r(C.int(gid),
+                       &grp,
+                       (*C.char)(buf.ptr),
+                       C.size_t(buf.size),
+                       &result))
+       })
+       if err != nil {
+               return nil, fmt.Errorf("user: lookup groupid %d: %v", gid, err)
+       }
+       if result == nil {
+               return nil, UnknownGroupIdError(gid)
+       }
+       return buildGroup(&grp), nil
+}
+
+func buildGroup(grp *C.struct_group) *Group {
+       g := &Group{
+               Gid:  strconv.Itoa(int(grp.gr_gid)),
+               Name: C.GoString(grp.gr_name),
+       }
+       return g
+}
+
+type bufferKind C.int
+
+const (
+       userBuffer  = bufferKind(C._SC_GETPW_R_SIZE_MAX)
+       groupBuffer = bufferKind(C._SC_GETGR_R_SIZE_MAX)
+)
+
+func (k bufferKind) initialSize() C.size_t {
+       sz := C.sysconf(C.int(k))
+       if sz == -1 {
+               // DragonFly and FreeBSD do not have _SC_GETPW_R_SIZE_MAX.
+               // Additionally, not all Linux systems have it, either. For
+               // example, the musl libc returns -1.
+               return 1024
+       }
+       if !isSizeReasonable(int64(sz)) {
+               // Truncate.  If this truly isn't enough, retryWithBuffer will error on the first run.
+               return maxBufferSize
+       }
+       return C.size_t(sz)
+}
+
+type memBuffer struct {
+       ptr  unsafe.Pointer
+       size C.size_t
+}
+
+func alloc(kind bufferKind) *memBuffer {
+       sz := kind.initialSize()
+       return &memBuffer{
+               ptr:  C.malloc(sz),
+               size: sz,
+       }
+}
+
+func (mb *memBuffer) resize(newSize C.size_t) {
+       mb.ptr = C.realloc(mb.ptr, newSize)
+       mb.size = newSize
+}
+
+func (mb *memBuffer) free() {
+       C.free(mb.ptr)
+}
+
+// retryWithBuffer repeatedly calls f(), increasing the size of the
+// buffer each time, until f succeeds, fails with a non-ERANGE error,
+// or the buffer exceeds a reasonable limit.
+func retryWithBuffer(buf *memBuffer, f func() syscall.Errno) error {
+       for {
+               errno := f()
+               if errno == 0 {
+                       return nil
+               } else if errno != syscall.ERANGE {
+                       return errno
+               }
+               newSize := buf.size * 2
+               if !isSizeReasonable(int64(newSize)) {
+                       return fmt.Errorf("internal buffer exceeds %d bytes", maxBufferSize)
+               }
+               buf.resize(newSize)
+       }
+}
+
+const maxBufferSize = 1 << 20
+
+func isSizeReasonable(sz int64) bool {
+       return sz > 0 && sz <= maxBufferSize
 }
index 9fb3c5546f27fe10e4fd17532c3548f55ec2e8b0..4e36a5c2bfabf0c271434e26b3518275b5eaf4de 100644 (file)
@@ -5,11 +5,16 @@
 package user
 
 import (
+       "errors"
        "fmt"
        "syscall"
        "unsafe"
 )
 
+func init() {
+       groupImplemented = false
+}
+
 func isDomainJoined() (bool, error) {
        var domain *uint16
        var status uint32
@@ -129,7 +134,7 @@ func newUserFromSid(usid *syscall.SID) (*User, error) {
        return newUser(usid, gid, dir)
 }
 
-func lookup(username string) (*User, error) {
+func lookupUser(username string) (*User, error) {
        sid, _, t, e := syscall.LookupSID("", username)
        if e != nil {
                return nil, e
@@ -140,10 +145,22 @@ func lookup(username string) (*User, error) {
        return newUserFromSid(sid)
 }
 
-func lookupId(uid string) (*User, error) {
+func lookupUserId(uid string) (*User, error) {
        sid, e := syscall.StringToSid(uid)
        if e != nil {
                return nil, e
        }
        return newUserFromSid(sid)
 }
+
+func lookupGroup(groupname string) (*Group, error) {
+       return nil, errors.New("user: LookupGroup not implemented on windows")
+}
+
+func lookupGroupId(string) (*Group, error) {
+       return nil, errors.New("user: LookupGroupId not implemented on windows")
+}
+
+func listGroups(*User) ([]string, error) {
+       return nil, errors.New("user: GroupIds not implemented on windows")
+}
index e8680fe5465541ab5ab205ac4db9ad57d419c056..7b44397afb324c54679150f2eca3f060655ff58a 100644 (file)
@@ -9,23 +9,35 @@ import (
        "strconv"
 )
 
-var implemented = true // set to false by lookup_stubs.go's init
+var (
+       userImplemented  = true // set to false by lookup_stubs.go's init
+       groupImplemented = true // set to false by lookup_stubs.go's init
+)
 
 // User represents a user account.
 //
-// On posix systems Uid and Gid contain a decimal number
+// On POSIX systems Uid and Gid contain a decimal number
 // representing uid and gid. On windows Uid and Gid
 // contain security identifier (SID) in a string format.
 // On Plan 9, Uid, Gid, Username, and Name will be the
 // contents of /dev/user.
 type User struct {
-       Uid      string // user id
-       Gid      string // primary group id
+       Uid      string // user ID
+       Gid      string // primary group ID
        Username string
        Name     string
        HomeDir  string
 }
 
+// Group represents a grouping of users.
+//
+// On POSIX systems Gid contains a decimal number
+// representing the group ID.
+type Group struct {
+       Gid  string // group ID
+       Name string // group name
+}
+
 // UnknownUserIdError is returned by LookupId when
 // a user cannot be found.
 type UnknownUserIdError int
@@ -41,3 +53,19 @@ type UnknownUserError string
 func (e UnknownUserError) Error() string {
        return "user: unknown user " + string(e)
 }
+
+// UnknownGroupIdError is returned by LookupGroupId when
+// a group cannot be found.
+type UnknownGroupIdError string
+
+func (e UnknownGroupIdError) Error() string {
+       return "group: unknown groupid " + string(e)
+}
+
+// UnknownGroupError is returned by LookupGroup when
+// a group cannot be found.
+type UnknownGroupError string
+
+func (e UnknownGroupError) Error() string {
+       return "group: unknown group " + string(e)
+}
index 9d9420e8090e3fba1020a81deca15b807af3af61..98f7e410a6180f4e322647344f3c90b5287e5354 100644 (file)
@@ -9,14 +9,14 @@ import (
        "testing"
 )
 
-func check(t *testing.T) {
-       if !implemented {
+func checkUser(t *testing.T) {
+       if !userImplemented {
                t.Skip("user: not implemented; skipping tests")
        }
 }
 
 func TestCurrent(t *testing.T) {
-       check(t)
+       checkUser(t)
 
        u, err := Current()
        if err != nil {
@@ -53,7 +53,7 @@ func compare(t *testing.T, want, got *User) {
 }
 
 func TestLookup(t *testing.T) {
-       check(t)
+       checkUser(t)
 
        if runtime.GOOS == "plan9" {
                t.Skipf("Lookup not implemented on %q", runtime.GOOS)
@@ -71,7 +71,7 @@ func TestLookup(t *testing.T) {
 }
 
 func TestLookupId(t *testing.T) {
-       check(t)
+       checkUser(t)
 
        if runtime.GOOS == "plan9" {
                t.Skipf("LookupId not implemented on %q", runtime.GOOS)
@@ -87,3 +87,57 @@ func TestLookupId(t *testing.T) {
        }
        compare(t, want, got)
 }
+
+func checkGroup(t *testing.T) {
+       if !groupImplemented {
+               t.Skip("user: group not implemented; skipping test")
+       }
+}
+
+func TestLookupGroup(t *testing.T) {
+       checkGroup(t)
+       user, err := Current()
+       if err != nil {
+               t.Fatalf("Current(): %v", err)
+       }
+
+       g1, err := LookupGroupId(user.Gid)
+       if err != nil {
+               t.Fatalf("LookupGroupId(%q): %v", user.Gid, err)
+       }
+       if g1.Gid != user.Gid {
+               t.Errorf("LookupGroupId(%q).Gid = %s; want %s", user.Gid, g1.Gid, user.Gid)
+       }
+
+       g2, err := LookupGroup(g1.Name)
+       if err != nil {
+               t.Fatalf("LookupGroup(%q): %v", g1.Name, err)
+       }
+       if g1.Gid != g2.Gid || g1.Name != g2.Name {
+               t.Errorf("LookupGroup(%q) = %+v; want %+v", g1.Name, g2, g1)
+       }
+}
+
+func TestGroupIds(t *testing.T) {
+       checkGroup(t)
+       user, err := Current()
+       if err != nil {
+               t.Fatalf("Current(): %v", err)
+       }
+       gids, err := user.GroupIds()
+       if err != nil {
+               t.Fatalf("%+v.GroupIds(): %v", user, err)
+       }
+       if !containsID(gids, user.Gid) {
+               t.Errorf("%+v.GroupIds() = %v; does not contain user GID %s", user, gids, user.Gid)
+       }
+}
+
+func containsID(ids []string, id string) bool {
+       for _, x := range ids {
+               if x == id {
+                       return true
+               }
+       }
+       return false
+}