From 1596d71255f3c19a27600e751592079dae46bf40 Mon Sep 17 00:00:00 2001 From: Ananth Bhaskararaman Date: Wed, 22 Mar 2023 05:37:43 +0000 Subject: [PATCH] os/user: lookup Linux users and groups via systemd userdb Fetch usernames and groups via systemd userdb if available. Otherwise fall back to parsing /etc/passwd, etc. Fixes #38810 Co-authored-by: Michael Stapelberg Change-Id: Iff6ffc54feec6b6cec241b89e362c2285c8c0454 GitHub-Last-Rev: 1a627cc9a18063f5d274bb96113947cd4d952e5a GitHub-Pull-Request: golang/go#57458 Reviewed-on: https://go-review.googlesource.com/c/go/+/459455 TryBot-Result: Gopher Robot Run-TryBot: Ian Lance Taylor Reviewed-by: Ian Lance Taylor Reviewed-by: Heschi Kreinick Auto-Submit: Ian Lance Taylor --- src/os/user/listgroups_unix.go | 9 + src/os/user/lookup_unix.go | 30 + src/os/user/user.go | 4 + src/os/user/userdbclient.go | 22 + src/os/user/userdbclient_linux.go | 772 +++++++++++++++++++++++++ src/os/user/userdbclient_linux_test.go | 504 ++++++++++++++++ src/os/user/userdbclient_stub.go | 29 + 7 files changed, 1370 insertions(+) create mode 100644 src/os/user/userdbclient.go create mode 100644 src/os/user/userdbclient_linux.go create mode 100644 src/os/user/userdbclient_linux_test.go create mode 100644 src/os/user/userdbclient_stub.go diff --git a/src/os/user/listgroups_unix.go b/src/os/user/listgroups_unix.go index ef366fa280..b620ad3652 100644 --- a/src/os/user/listgroups_unix.go +++ b/src/os/user/listgroups_unix.go @@ -9,11 +9,13 @@ package user import ( "bufio" "bytes" + "context" "errors" "fmt" "io" "os" "strconv" + "time" ) func listGroupsFromReader(u *User, r io.Reader) ([]string, error) { @@ -99,6 +101,13 @@ func listGroupsFromReader(u *User, r io.Reader) ([]string, error) { } func listGroups(u *User) ([]string, error) { + if defaultUserdbClient.isUsable() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + if ids, ok, err := defaultUserdbClient.lookupGroupIds(ctx, u.Username); ok { + return ids, err + } + } f, err := os.Open(groupFile) if err != nil { return nil, err diff --git a/src/os/user/lookup_unix.go b/src/os/user/lookup_unix.go index 608d9b2140..0ee2ad35ef 100644 --- a/src/os/user/lookup_unix.go +++ b/src/os/user/lookup_unix.go @@ -9,11 +9,13 @@ package user import ( "bufio" "bytes" + "context" "errors" "io" "os" "strconv" "strings" + "time" ) // lineFunc returns a value, an error, or (nil, nil) to skip the row. @@ -198,6 +200,13 @@ func findUsername(name string, r io.Reader) (*User, error) { } func lookupGroup(groupname string) (*Group, error) { + if defaultUserdbClient.isUsable() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if g, ok, err := defaultUserdbClient.lookupGroup(ctx, groupname); ok { + return g, err + } + } f, err := os.Open(groupFile) if err != nil { return nil, err @@ -207,6 +216,13 @@ func lookupGroup(groupname string) (*Group, error) { } func lookupGroupId(id string) (*Group, error) { + if defaultUserdbClient.isUsable() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if g, ok, err := defaultUserdbClient.lookupGroupId(ctx, id); ok { + return g, err + } + } f, err := os.Open(groupFile) if err != nil { return nil, err @@ -216,6 +232,13 @@ func lookupGroupId(id string) (*Group, error) { } func lookupUser(username string) (*User, error) { + if defaultUserdbClient.isUsable() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if u, ok, err := defaultUserdbClient.lookupUser(ctx, username); ok { + return u, err + } + } f, err := os.Open(userFile) if err != nil { return nil, err @@ -225,6 +248,13 @@ func lookupUser(username string) (*User, error) { } func lookupUserId(uid string) (*User, error) { + if defaultUserdbClient.isUsable() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if u, ok, err := defaultUserdbClient.lookupUserId(ctx, uid); ok { + return u, err + } + } f, err := os.Open(userFile) if err != nil { return nil, err diff --git a/src/os/user/user.go b/src/os/user/user.go index 0307d2ad6a..4cf5b7c515 100644 --- a/src/os/user/user.go +++ b/src/os/user/user.go @@ -11,6 +11,10 @@ One is written in pure Go and parses /etc/passwd and /etc/group. The other is cgo-based and relies on the standard C library (libc) routines such as getpwuid_r, getgrnam_r, and getgrouplist. +For Linux, the pure Go implementation queries the systemd-userdb service first. +If the service is not available, it falls back to parsing /etc/passwd and +/etc/group. + When cgo is available, and the required routines are implemented in libc for a particular platform, cgo-based (libc-backed) code is used. This can be overridden by using osusergo build tag, which enforces diff --git a/src/os/user/userdbclient.go b/src/os/user/userdbclient.go new file mode 100644 index 0000000000..b0f3895ed4 --- /dev/null +++ b/src/os/user/userdbclient.go @@ -0,0 +1,22 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package user + +// userdbClient queries the io.systemd.UserDatabase VARLINK interface provided by +// systemd-userdbd.service(8) on Linux for obtaining full user/group details +// even when cgo is not available. +// VARLINK protocol: https://varlink.org +// Systemd userdb VARLINK interface https://systemd.io/USER_GROUP_API +// dir contains multiple varlink service sockets implementing the userdb interface. +type userdbClient struct { + dir string +} + +// IsUsable checks if the client can be used to make queries. +func (cl userdbClient) isUsable() bool { + return len(cl.dir) != 0 +} + +var defaultUserdbClient userdbClient diff --git a/src/os/user/userdbclient_linux.go b/src/os/user/userdbclient_linux.go new file mode 100644 index 0000000000..e585b7f3c3 --- /dev/null +++ b/src/os/user/userdbclient_linux.go @@ -0,0 +1,772 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux + +package user + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "io/fs" + "os" + "strconv" + "strings" + "sync" + "syscall" + "unicode/utf16" + "unicode/utf8" +) + +const ( + // Well known multiplexer service. + svcMultiplexer = "io.systemd.Multiplexer" + + userdbNamespace = "io.systemd.UserDatabase" + + // io.systemd.UserDatabase VARLINK interface methods. + mGetGroupRecord = userdbNamespace + ".GetGroupRecord" + mGetUserRecord = userdbNamespace + ".GetUserRecord" + mGetMemberships = userdbNamespace + ".GetMemberships" + + // io.systemd.UserDatabase VARLINK interface errors. + errNoRecordFound = userdbNamespace + ".NoRecordFound" + errServiceNotAvailable = userdbNamespace + ".ServiceNotAvailable" +) + +func init() { + defaultUserdbClient.dir = "/run/systemd/userdb" +} + +// userdbCall represents a VARLINK service call sent to systemd-userdb. +// method is the VARLINK method to call. +// parameters are the VARLINK parameters to pass. +// more indicates if more responses are expected. +// fastest indicates if only the fastest response should be returned. +type userdbCall struct { + method string + parameters callParameters + more bool + fastest bool +} + +func (u userdbCall) marshalJSON(service string) ([]byte, error) { + params, err := u.parameters.marshalJSON(service) + if err != nil { + return nil, err + } + var data bytes.Buffer + data.WriteString(`{"method":"`) + data.WriteString(u.method) + data.WriteString(`","parameters":`) + data.Write(params) + if u.more { + data.WriteString(`,"more":true`) + } + data.WriteString(`}`) + return data.Bytes(), nil +} + +type callParameters struct { + uid *int64 + userName string + gid *int64 + groupName string +} + +func (c callParameters) marshalJSON(service string) ([]byte, error) { + var data bytes.Buffer + data.WriteString(`{"service":"`) + data.WriteString(service) + data.WriteString(`"`) + if c.uid != nil { + data.WriteString(`,"uid":`) + data.WriteString(strconv.FormatInt(*c.uid, 10)) + } + if c.userName != "" { + data.WriteString(`,"userName":"`) + data.WriteString(c.userName) + data.WriteString(`"`) + } + if c.gid != nil { + data.WriteString(`,"gid":`) + data.WriteString(strconv.FormatInt(*c.gid, 10)) + } + if c.groupName != "" { + data.WriteString(`,"groupName":"`) + data.WriteString(c.groupName) + data.WriteString(`"`) + } + data.WriteString(`}`) + return data.Bytes(), nil +} + +type userdbReply struct { + continues bool + errorStr string +} + +func (u *userdbReply) unmarshalJSON(data []byte) error { + var ( + kContinues = []byte(`"continues"`) + kError = []byte(`"error"`) + ) + if i := bytes.Index(data, kContinues); i != -1 { + continues, err := parseJSONBoolean(data[i+len(kContinues):]) + if err != nil { + return err + } + u.continues = continues + } + if i := bytes.Index(data, kError); i != -1 { + errStr, err := parseJSONString(data[i+len(kError):]) + if err != nil { + return err + } + u.errorStr = errStr + } + return nil +} + +// response is the parsed reply from a method call to systemd-userdb. +// data is one or more VARLINK response parameters separated by 0. +// handled indicates if the call was handled by systemd-userdb. +// err is any error encountered. +type response struct { + data []byte + handled bool + err error +} + +// querySocket calls the io.systemd.UserDatabase VARLINK interface at sock with request. +// Multiple replies can be read by setting more to true in the request. +// Reply parameters are accumulated separated by 0, if there are many. +// Replies with io.systemd.UserDatabase.NoRecordFound errors are skipped. +// Other UserDatabase errors are returned as is. +// If the socket does not exist, or if the io.systemd.UserDatabase.ServiceNotAvailable +// error is seen in a response, the query is considered unhandled. +func querySocket(ctx context.Context, sock string, request []byte) response { + sockFd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) + if err != nil { + return response{err: err} + } + defer syscall.Close(sockFd) + if err := syscall.Connect(sockFd, &syscall.SockaddrUnix{Name: sock}); err != nil { + if errors.Is(err, os.ErrNotExist) { + return response{err: err} + } + return response{handled: true, err: err} + } + + // Null terminate request. + if request[len(request)-1] != 0 { + request = append(request, 0) + } + + // Write request to socket. + written := 0 + for written < len(request) { + if ctx.Err() != nil { + return response{handled: true, err: ctx.Err()} + } + if n, err := syscall.Write(sockFd, request[written:]); err != nil { + return response{handled: true, err: err} + } else { + written += n + } + } + + // Read response. + var resp bytes.Buffer + for { + if ctx.Err() != nil { + return response{handled: true, err: ctx.Err()} + } + buf := make([]byte, 4096) + if n, err := syscall.Read(sockFd, buf); err != nil { + return response{handled: true, err: err} + } else if n > 0 { + resp.Write(buf[:n]) + if buf[n-1] == 0 { + break + } + } else { + // EOF + break + } + } + + if resp.Len() == 0 { + return response{handled: true} + } + + buf := resp.Bytes() + // Remove trailing 0. + buf = buf[:len(buf)-1] + // Split into VARLINK messages. + msgs := bytes.Split(buf, []byte{0}) + + // Parse VARLINK messages. + for _, m := range msgs { + var resp userdbReply + if err := resp.unmarshalJSON(m); err != nil { + return response{handled: true, err: err} + } + // Handle VARLINK message errors. + switch e := resp.errorStr; e { + case "": + case errNoRecordFound: // Ignore not found error. + continue + case errServiceNotAvailable: + return response{} + default: + return response{handled: true, err: errors.New(e)} + } + if !resp.continues { + break + } + } + return response{data: buf, handled: true, err: ctx.Err()} +} + +// queryMany calls the io.systemd.UserDatabase VARLINK interface on many services at once. +// ss is a slice of userdb services to call. Each service must have a socket in cl.dir. +// c is sent to all services in ss. If c.fastest is true, only the fastest reply is read. +// Otherwise all replies are aggregated. um is called with aggregated reply parameters. +// queryMany returns the first error encountered. The first result is false if no userdb +// socket is available or if all requests time out. +func (cl userdbClient) queryMany(ctx context.Context, ss []string, c *userdbCall, um jsonUnmarshaler) (bool, error) { + responseCh := make(chan response, len(ss)) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Query all services in parallel. + var workers sync.WaitGroup + for _, svc := range ss { + data, err := c.marshalJSON(svc) + if err != nil { + return true, err + } + // Spawn worker to query service. + workers.Add(1) + go func(sock string, data []byte) { + defer workers.Done() + responseCh <- querySocket(ctx, sock, data) + }(cl.dir+"/"+svc, data) + } + + go func() { + // Clean up workers. + workers.Wait() + close(responseCh) + }() + + var result bytes.Buffer + var notOk int +RecvResponses: + for { + select { + case resp, ok := <-responseCh: + if !ok { + // Responses channel is closed so stop reading. + break RecvResponses + } + if resp.err != nil { + // querySocket only returns unrecoverable errors, + // so return the first one received. + return true, resp.err + } + if !resp.handled { + notOk++ + continue + } + + first := result.Len() == 0 + result.Write(resp.data) + if first && c.fastest { + // Return the fastest response. + break RecvResponses + } + case <-ctx.Done(): + // If requests time out, userdb is unavailable. + return ctx.Err() != context.DeadlineExceeded, nil + } + } + // If all sockets are not ok, userdb is unavailable. + if notOk == len(ss) { + return false, nil + } + return true, um.unmarshalJSON(result.Bytes()) +} + +// services enumerates userdb service sockets in dir. +// If ok is false, io.systemd.UserDatabase service does not exist. +func (cl userdbClient) services() (s []string, ok bool, err error) { + var entries []fs.DirEntry + if entries, err = os.ReadDir(cl.dir); err != nil { + ok = !os.IsNotExist(err) + return + } + ok = true + for _, ent := range entries { + s = append(s, ent.Name()) + } + return +} + +// query looks up users/groups on the io.systemd.UserDatabase VARLINK interface. +// If the multiplexer service is available, the call is sent only to it. +// Otherwise, the call is sent simultaneously to all UserDatabase services in cl.dir. +// The fastest reply is read and parsed. All other requests are cancelled. +// If the service is unavailable, the first result is false. +// The service is considered unavailable if the requests time-out as well. +func (cl userdbClient) query(ctx context.Context, call *userdbCall, um jsonUnmarshaler) (bool, error) { + services := []string{svcMultiplexer} + if _, err := os.Stat(cl.dir + "/" + svcMultiplexer); err != nil { + // No mux service so call all available services. + var ok bool + if services, ok, err = cl.services(); !ok || err != nil { + return ok, err + } + } + call.fastest = true + if ok, err := cl.queryMany(ctx, services, call, um); !ok || err != nil { + return ok, err + } + return true, nil +} + +type jsonUnmarshaler interface { + unmarshalJSON([]byte) error +} + +func isSpace(c byte) bool { + return c == ' ' || c == '\t' || c == '\r' || c == '\n' +} + +// findElementStart returns a slice of r that starts at the next JSON element. +// It skips over valid JSON space characters and checks for the colon separator. +func findElementStart(r []byte) ([]byte, error) { + var idx int + var b byte + colon := byte(':') + var seenColon bool + for idx, b = range r { + if isSpace(b) { + continue + } + if !seenColon && b == colon { + seenColon = true + continue + } + // Spotted colon and b is not a space, so value starts here. + if seenColon { + break + } + return nil, errors.New("expected colon, got invalid character: " + string(b)) + } + if !seenColon { + return nil, errors.New("expected colon, got end of input") + } + return r[idx:], nil +} + +// parseJSONString reads a JSON string from r. +func parseJSONString(r []byte) (string, error) { + r, err := findElementStart(r) + if err != nil { + return "", err + } + // Smallest valid string is `""`. + if l := len(r); l < 2 { + return "", errors.New("unexpected end of input") + } else if l == 2 { + if bytes.Equal(r, []byte(`""`)) { + return "", nil + } + return "", errors.New("invalid string") + } + + if c := r[0]; c != '"' { + return "", errors.New(`expected " got ` + string(c)) + } + // Advance over opening quote. + r = r[1:] + + var value strings.Builder + var inEsc bool + var inUEsc bool + var strEnds bool + reader := bytes.NewReader(r) + for { + if value.Len() > 4096 { + return "", errors.New("string too large") + } + + // Parse unicode escape sequences. + if inUEsc { + maybeRune := make([]byte, 4) + n, err := reader.Read(maybeRune) + if err != nil || n != 4 { + return "", fmt.Errorf("invalid unicode escape sequence \\u%s", string(maybeRune)) + } + prn, err := strconv.ParseUint(string(maybeRune), 16, 32) + if err != nil { + return "", fmt.Errorf("invalid unicode escape sequence \\u%s", string(maybeRune)) + } + rn := rune(prn) + if !utf16.IsSurrogate(rn) { + value.WriteRune(rn) + inUEsc = false + continue + } + // rn maybe a high surrogate; read the low surrogate. + maybeRune = make([]byte, 6) + n, err = reader.Read(maybeRune) + if err != nil || n != 6 || maybeRune[0] != '\\' || maybeRune[1] != 'u' { + // Not a valid UTF-16 surrogate pair. + if _, err := reader.Seek(int64(-n), io.SeekCurrent); err != nil { + return "", err + } + // Invalid low surrogate; write the replacement character. + value.WriteRune(utf8.RuneError) + } else { + rn1, err := strconv.ParseUint(string(maybeRune[2:]), 16, 32) + if err != nil { + return "", fmt.Errorf("invalid unicode escape sequence %s", string(maybeRune)) + } + // Check if rn and rn1 are valid UTF-16 surrogate pairs. + if dec := utf16.DecodeRune(rn, rune(rn1)); dec != utf8.RuneError { + n = utf8.EncodeRune(maybeRune, dec) + // Write the decoded rune. + value.Write(maybeRune[:n]) + } + } + inUEsc = false + continue + } + + if inEsc { + b, err := reader.ReadByte() + if err != nil { + return "", err + } + switch b { + case 'b': + value.WriteByte('\b') + case 'f': + value.WriteByte('\f') + case 'n': + value.WriteByte('\n') + case 'r': + value.WriteByte('\r') + case 't': + value.WriteByte('\t') + case 'u': + inUEsc = true + case '/': + value.WriteByte('/') + case '\\': + value.WriteByte('\\') + case '"': + value.WriteByte('"') + default: + return "", errors.New("unexpected character in escape sequence " + string(b)) + } + inEsc = false + continue + } else { + rn, _, err := reader.ReadRune() + if err != nil { + if err == io.EOF { + break + } + return "", err + } + if rn == '\\' { + inEsc = true + continue + } + if rn == '"' { + // String ends on un-escaped quote. + strEnds = true + break + } + value.WriteRune(rn) + } + } + if !strEnds { + return "", errors.New("unexpected end of input") + } + return value.String(), nil +} + +// parseJSONInt64 reads a 64 bit integer from r. +func parseJSONInt64(r []byte) (int64, error) { + r, err := findElementStart(r) + if err != nil { + return 0, err + } + var num strings.Builder + for _, b := range r { + // int64 max is 19 digits long. + if num.Len() == 20 { + return 0, errors.New("number too large") + } + if strings.ContainsRune("0123456789", rune(b)) { + num.WriteByte(b) + } else { + break + } + } + n, err := strconv.ParseInt(num.String(), 10, 64) + return int64(n), err +} + +// parseJSONBoolean reads a boolean from r. +func parseJSONBoolean(r []byte) (bool, error) { + r, err := findElementStart(r) + if err != nil { + return false, err + } + if bytes.HasPrefix(r, []byte("true")) { + return true, nil + } + if bytes.HasPrefix(r, []byte("false")) { + return false, nil + } + return false, errors.New("unable to parse boolean value") +} + +type groupRecord struct { + groupName string + gid int64 +} + +func (g *groupRecord) unmarshalJSON(data []byte) error { + var ( + kGroupName = []byte(`"groupName"`) + kGid = []byte(`"gid"`) + ) + if i := bytes.Index(data, kGroupName); i != -1 { + groupname, err := parseJSONString(data[i+len(kGroupName):]) + if err != nil { + return err + } + g.groupName = groupname + } + if i := bytes.Index(data, kGid); i != -1 { + gid, err := parseJSONInt64(data[i+len(kGid):]) + if err != nil { + return err + } + g.gid = gid + } + return nil +} + +// queryGroupDb queries the userdb interface for a gid, groupname, or both. +func (cl userdbClient) queryGroupDb(ctx context.Context, gid *int64, groupname string) (*Group, bool, error) { + group := groupRecord{} + request := userdbCall{ + method: mGetGroupRecord, + parameters: callParameters{gid: gid, groupName: groupname}, + } + if ok, err := cl.query(ctx, &request, &group); !ok || err != nil { + return nil, ok, fmt.Errorf("error querying systemd-userdb group record: %s", err) + } + return &Group{ + Name: group.groupName, + Gid: strconv.FormatInt(group.gid, 10), + }, true, nil +} + +type userRecord struct { + userName string + realName string + uid int64 + gid int64 + homeDirectory string +} + +func (u *userRecord) unmarshalJSON(data []byte) error { + var ( + kUserName = []byte(`"userName"`) + kRealName = []byte(`"realName"`) + kUid = []byte(`"uid"`) + kGid = []byte(`"gid"`) + kHomeDirectory = []byte(`"homeDirectory"`) + ) + if i := bytes.Index(data, kUserName); i != -1 { + username, err := parseJSONString(data[i+len(kUserName):]) + if err != nil { + return err + } + u.userName = username + } + if i := bytes.Index(data, kRealName); i != -1 { + realname, err := parseJSONString(data[i+len(kRealName):]) + if err != nil { + return err + } + u.realName = realname + } + if i := bytes.Index(data, kUid); i != -1 { + uid, err := parseJSONInt64(data[i+len(kUid):]) + if err != nil { + return err + } + u.uid = uid + } + if i := bytes.Index(data, kGid); i != -1 { + gid, err := parseJSONInt64(data[i+len(kGid):]) + if err != nil { + return err + } + u.gid = gid + } + if i := bytes.Index(data, kHomeDirectory); i != -1 { + homedir, err := parseJSONString(data[i+len(kHomeDirectory):]) + if err != nil { + return err + } + u.homeDirectory = homedir + } + return nil +} + +// queryUserDb queries the userdb interface for a uid, username, or both. +func (cl userdbClient) queryUserDb(ctx context.Context, uid *int64, username string) (*User, bool, error) { + user := userRecord{} + request := userdbCall{ + method: mGetUserRecord, + parameters: callParameters{ + uid: uid, + userName: username, + }, + } + if ok, err := cl.query(ctx, &request, &user); !ok || err != nil { + return nil, ok, fmt.Errorf("error querying systemd-userdb user record: %s", err) + } + return &User{ + Uid: strconv.FormatInt(user.uid, 10), + Gid: strconv.FormatInt(user.gid, 10), + Username: user.userName, + Name: user.realName, + HomeDir: user.homeDirectory, + }, true, nil +} + +func (cl userdbClient) lookupGroup(ctx context.Context, groupname string) (*Group, bool, error) { + return cl.queryGroupDb(ctx, nil, groupname) +} + +func (cl userdbClient) lookupGroupId(ctx context.Context, id string) (*Group, bool, error) { + gid, err := strconv.ParseInt(id, 10, 64) + if err != nil { + return nil, true, err + } + return cl.queryGroupDb(ctx, &gid, "") +} + +func (cl userdbClient) lookupUser(ctx context.Context, username string) (*User, bool, error) { + return cl.queryUserDb(ctx, nil, username) +} + +func (cl userdbClient) lookupUserId(ctx context.Context, id string) (*User, bool, error) { + uid, err := strconv.ParseInt(id, 10, 64) + if err != nil { + return nil, true, err + } + return cl.queryUserDb(ctx, &uid, "") +} + +type memberships struct { + // Keys are groupNames and values are sets of userNames. + groupUsers map[string]map[string]struct{} +} + +// unmarshalJSON expects many (userName, groupName) records separated by a null byte. +// This is used to build a membership map. +func (m *memberships) unmarshalJSON(data []byte) error { + if m.groupUsers == nil { + m.groupUsers = make(map[string]map[string]struct{}) + } + var ( + kUserName = []byte(`"userName"`) + kGroupName = []byte(`"groupName"`) + ) + // Split records by null terminator. + records := bytes.Split(data, []byte{byte(0)}) + for _, rec := range records { + if len(rec) == 0 { + continue + } + var groupName string + var userName string + var err error + if i := bytes.Index(rec, kGroupName); i != -1 { + if groupName, err = parseJSONString(rec[i+len(kGroupName):]); err != nil { + return err + } + } + if i := bytes.Index(rec, kUserName); i != -1 { + if userName, err = parseJSONString(rec[i+len(kUserName):]); err != nil { + return err + } + } + // Associate userName with groupName. + if groupName != "" && userName != "" { + if _, ok := m.groupUsers[groupName]; ok { + m.groupUsers[groupName][userName] = struct{}{} + } else { + m.groupUsers[groupName] = map[string]struct{}{userName: {}} + } + } + } + return nil +} + +func (cl userdbClient) lookupGroupIds(ctx context.Context, username string) ([]string, bool, error) { + services, ok, err := cl.services() + if !ok || err != nil { + return nil, ok, err + } + // Fetch group memberships for username. + var ms memberships + request := userdbCall{ + method: mGetMemberships, + parameters: callParameters{userName: username}, + more: true, + } + if ok, err := cl.queryMany(ctx, services, &request, &ms); !ok || err != nil { + return nil, ok, fmt.Errorf("error querying systemd-userdb memberships record: %s", err) + } + // Fetch user group gid. + var group groupRecord + request = userdbCall{ + method: mGetGroupRecord, + parameters: callParameters{groupName: username}, + } + if ok, err := cl.query(ctx, &request, &group); !ok || err != nil { + return nil, ok, err + } + gids := []string{strconv.FormatInt(group.gid, 10)} + + // Fetch group records for each group. + for g := range ms.groupUsers { + var group groupRecord + request.parameters.groupName = g + // Query group for gid. + if ok, err := cl.query(ctx, &request, &group); !ok || err != nil { + return nil, ok, fmt.Errorf("error querying systemd-userdb group record: %s", err) + } + gids = append(gids, strconv.FormatInt(group.gid, 10)) + } + return gids, true, nil +} diff --git a/src/os/user/userdbclient_linux_test.go b/src/os/user/userdbclient_linux_test.go new file mode 100644 index 0000000000..1b9a336f72 --- /dev/null +++ b/src/os/user/userdbclient_linux_test.go @@ -0,0 +1,504 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux + +package user + +import ( + "bytes" + "context" + "errors" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "syscall" + "testing" + "time" + "unicode/utf8" +) + +func TestQueryNoUserdb(t *testing.T) { + cl := &userdbClient{dir: "/non/existent"} + if _, ok, err := cl.lookupGroup(context.Background(), "stdlibcontrib"); ok { + t.Fatalf("should fail but lookup has been handled or error is nil: %v", err) + } +} + +type userdbTestData map[string]udbResponse + +type udbResponse struct { + data []byte + delay time.Duration +} + +func userdbServer(t *testing.T, sockFn string, data userdbTestData) { + ready := make(chan struct{}) + go func() { + if err := serveUserdb(ready, sockFn, data); err != nil { + t.Error(err) + } + }() + <-ready +} + +func (u userdbTestData) String() string { + var s strings.Builder + for k, v := range u { + s.WriteString("Request:\n") + s.WriteString(k) + s.WriteString("\nResponse:\n") + if v.delay > 0 { + s.WriteString("Delay: ") + s.WriteString(v.delay.String()) + s.WriteString("\n") + } + s.WriteString("Data:\n") + s.Write(v.data) + s.WriteString("\n") + } + return s.String() +} + +// serverUserdb is a simple userdb server that replies to VARLINK method calls. +// A message is sent on the ready channel when the server is ready to accept calls. +// The server will reply to each request in the data map. If a request is not +// found in the map, the server will return an error. +func serveUserdb(ready chan<- struct{}, sockFn string, data userdbTestData) error { + sockFd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) + if err != nil { + return err + } + defer syscall.Close(sockFd) + if err := syscall.Bind(sockFd, &syscall.SockaddrUnix{Name: sockFn}); err != nil { + return err + } + if err := syscall.Listen(sockFd, 1); err != nil { + return err + } + + // Send ready signal. + ready <- struct{}{} + + var srvGroup sync.WaitGroup + + srvErrs := make(chan error, len(data)) + for len(data) != 0 { + nfd, _, err := syscall.Accept(sockFd) + if err != nil { + syscall.Close(nfd) + return err + } + + // Read request. + buf := make([]byte, 4096) + n, err := syscall.Read(nfd, buf) + if err != nil { + syscall.Close(nfd) + return err + } + if n == 0 { + // Client went away. + continue + } + if buf[n-1] != 0 { + syscall.Close(nfd) + return errors.New("request not null terminated") + } + // Remove null terminator. + buf = buf[:n-1] + got := string(buf) + + // Fetch response for request. + response, ok := data[got] + if !ok { + syscall.Close(nfd) + msg := "unexpected request:\n" + got + "\n\ndata:\n" + data.String() + return errors.New(msg) + } + delete(data, got) + + srvGroup.Add(1) + go func() { + defer srvGroup.Done() + if err := serveClient(nfd, response); err != nil { + srvErrs <- err + } + }() + } + + srvGroup.Wait() + // Combine serve errors if any. + if len(srvErrs) > 0 { + var errs []error + for err := range srvErrs { + errs = append(errs, err) + } + return errors.Join(errs...) + } + + return nil +} + +func serveClient(fd int, response udbResponse) error { + defer syscall.Close(fd) + time.Sleep(response.delay) + data := response.data + if len(data) != 0 && data[len(data)-1] != 0 { + data = append(data, 0) + } + written := 0 + for written < len(data) { + if n, err := syscall.Write(fd, data[written:]); err != nil { + return err + } else { + written += n + } + } + return nil +} + +func TestSlowUserdbLookup(t *testing.T) { + tmpdir := t.TempDir() + data := userdbTestData{ + `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{ + delay: time.Hour, + }, + } + userdbServer(t, tmpdir+"/"+svcMultiplexer, data) + cl := &userdbClient{dir: tmpdir} + // Lookup should timeout. + ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond) + defer cancel() + if _, ok, _ := cl.lookupGroup(ctx, "stdlibcontrib"); ok { + t.Fatalf("lookup should not be handled but was") + } +} + +func TestFastestUserdbLookup(t *testing.T) { + tmpdir := t.TempDir() + fastData := userdbTestData{ + `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"fast","groupName":"stdlibcontrib"}}`: udbResponse{ + data: []byte( + `{"parameters":{"record":{"groupName":"stdlibcontrib","gid":181,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`, + ), + }, + } + slowData := userdbTestData{ + `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"slow","groupName":"stdlibcontrib"}}`: udbResponse{ + delay: 50 * time.Millisecond, + data: []byte( + `{"parameters":{"record":{"groupName":"stdlibcontrib","gid":182,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`, + ), + }, + } + userdbServer(t, tmpdir+"/"+"fast", fastData) + userdbServer(t, tmpdir+"/"+"slow", slowData) + cl := &userdbClient{dir: tmpdir} + group, ok, err := cl.lookupGroup(context.Background(), "stdlibcontrib") + if !ok { + t.Fatalf("lookup should be handled but was not") + } + if err != nil { + t.Fatalf("lookup should not fail but did: %v", err) + } + if group.Gid != "181" { + t.Fatalf("lookup should return group 181 but returned %s", group.Gid) + } +} + +func TestUserdbLookupGroup(t *testing.T) { + tmpdir := t.TempDir() + data := userdbTestData{ + `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{ + data: []byte( + `{"parameters":{"record":{"groupName":"stdlibcontrib","gid":181,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`, + ), + }, + } + userdbServer(t, tmpdir+"/"+svcMultiplexer, data) + + groupname := "stdlibcontrib" + want := &Group{ + Name: "stdlibcontrib", + Gid: "181", + } + cl := &userdbClient{dir: tmpdir} + got, ok, err := cl.lookupGroup(context.Background(), groupname) + if !ok { + t.Fatal("lookup should have been handled") + } + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("lookupGroup(%s) = %v, want %v", groupname, got, want) + } +} + +func TestUserdbLookupUser(t *testing.T) { + tmpdir := t.TempDir() + data := userdbTestData{ + `{"method":"io.systemd.UserDatabase.GetUserRecord","parameters":{"service":"io.systemd.Multiplexer","userName":"stdlibcontrib"}}`: udbResponse{ + data: []byte( + `{"parameters":{"record":{"userName":"stdlibcontrib","uid":181,"gid":181,"realName":"Stdlib Contrib","homeDirectory":"/home/stdlibcontrib","status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`, + ), + }, + } + userdbServer(t, tmpdir+"/"+svcMultiplexer, data) + + username := "stdlibcontrib" + want := &User{ + Uid: "181", + Gid: "181", + Username: "stdlibcontrib", + Name: "Stdlib Contrib", + HomeDir: "/home/stdlibcontrib", + } + cl := &userdbClient{dir: tmpdir} + got, ok, err := cl.lookupUser(context.Background(), username) + if !ok { + t.Fatal("lookup should have been handled") + } + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("lookupUser(%s) = %v, want %v", username, got, want) + } +} + +func TestUserdbLookupGroupIds(t *testing.T) { + tmpdir := t.TempDir() + data := userdbTestData{ + `{"method":"io.systemd.UserDatabase.GetMemberships","parameters":{"service":"io.systemd.Multiplexer","userName":"stdlibcontrib"},"more":true}`: udbResponse{ + data: []byte( + `{"parameters":{"userName":"stdlibcontrib","groupName":"stdlib"},"continues":true}` + "\x00" + `{"parameters":{"userName":"stdlibcontrib","groupName":"contrib"}}`, + ), + }, + // group records + `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{ + data: []byte( + `{"parameters":{"record":{"groupName":"stdlibcontrib","members":["stdlibcontrib"],"gid":181,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`, + ), + }, + `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlib"}}`: udbResponse{ + data: []byte( + `{"parameters":{"record":{"groupName":"stdlib","members":["stdlibcontrib"],"gid":182,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`, + ), + }, + `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"contrib"}}`: udbResponse{ + data: []byte( + `{"parameters":{"record":{"groupName":"contrib","members":["stdlibcontrib"],"gid":183,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`, + ), + }, + } + userdbServer(t, tmpdir+"/"+svcMultiplexer, data) + + username := "stdlibcontrib" + want := []string{"181", "182", "183"} + cl := &userdbClient{dir: tmpdir} + got, ok, err := cl.lookupGroupIds(context.Background(), username) + if !ok { + t.Fatal("lookup should have been handled") + } + if err != nil { + t.Fatal(err) + } + // Result order is not specified so sort it. + sort.Strings(got) + if !reflect.DeepEqual(got, want) { + t.Fatalf("lookupGroupIds(%s) = %v, want %v", username, got, want) + } +} + +var findElementStartTestCases = []struct { + in []byte + want []byte + err bool +}{ + {in: []byte(`:`), want: []byte(``)}, + {in: []byte(`: `), want: []byte(``)}, + {in: []byte(`:"foo"`), want: []byte(`"foo"`)}, + {in: []byte(` :"foo"`), want: []byte(`"foo"`)}, + {in: []byte(` 1231 :"foo"`), err: true}, + {in: []byte(``), err: true}, + {in: []byte(`"foo"`), err: true}, + {in: []byte(`foo`), err: true}, +} + +func TestFindElementStart(t *testing.T) { + for i, tc := range findElementStartTestCases { + t.Run("#"+strconv.Itoa(i), func(t *testing.T) { + got, err := findElementStart(tc.in) + if tc.err && err == nil { + t.Errorf("want err for findElementStart(%s), got nil", tc.in) + } + if !tc.err { + if err != nil { + t.Errorf("findElementStart(%s) unexpected error: %s", tc.in, err.Error()) + } + if !bytes.Contains(tc.in, got) { + t.Errorf("%s should contain %s but does not", tc.in, got) + } + } + }) + } +} + +func FuzzFindElementStart(f *testing.F) { + for _, tc := range findElementStartTestCases { + if !tc.err { + f.Add(tc.in) + } + } + f.Fuzz(func(t *testing.T, b []byte) { + if out, err := findElementStart(b); err == nil && !bytes.Contains(b, out) { + t.Errorf("%s, %v", out, err) + } + }) +} + +var parseJSONStringTestCases = []struct { + in []byte + want string + err bool +}{ + {in: []byte(`:""`)}, + {in: []byte(`:"\n"`), want: "\n"}, + {in: []byte(`: "\""`), want: "\""}, + {in: []byte(`:"\t \\"`), want: "\t \\"}, + {in: []byte(`:"\\\\"`), want: `\\`}, + {in: []byte(`::`), err: true}, + {in: []byte(`""`), err: true}, + {in: []byte(`"`), err: true}, + {in: []byte(":\"0\xE5"), err: true}, + {in: []byte{':', '"', 0xFE, 0xFE, 0xFF, 0xFF, '"'}, want: "\uFFFD\uFFFD\uFFFD\uFFFD"}, + {in: []byte(`:"\u0061a"`), want: "aa"}, + {in: []byte(`:"\u0159\u0170"`), want: "řŰ"}, + {in: []byte(`:"\uD800\uDC00"`), want: "\U00010000"}, + {in: []byte(`:"\uD800"`), want: "\uFFFD"}, + {in: []byte(`:"\u000"`), err: true}, + {in: []byte(`:"\u00MF"`), err: true}, + {in: []byte(`:"\uD800\uDC0"`), err: true}, +} + +func TestParseJSONString(t *testing.T) { + for i, tc := range parseJSONStringTestCases { + t.Run("#"+strconv.Itoa(i), func(t *testing.T) { + got, err := parseJSONString(tc.in) + if tc.err && err == nil { + t.Errorf("want err for parseJSONString(%s), got nil", tc.in) + } + if !tc.err { + if err != nil { + t.Errorf("parseJSONString(%s) unexpected error: %s", tc.in, err.Error()) + } + if tc.want != got { + t.Errorf("parseJSONString(%s) = %s, want %s", tc.in, got, tc.want) + } + } + }) + } +} + +func FuzzParseJSONString(f *testing.F) { + for _, tc := range parseJSONStringTestCases { + f.Add(tc.in) + } + f.Fuzz(func(t *testing.T, b []byte) { + if out, err := parseJSONString(b); err == nil && !utf8.ValidString(out) { + t.Errorf("parseJSONString(%s) = %s, invalid string", b, out) + } + }) +} + +var parseJSONInt64TestCases = []struct { + in []byte + want int64 + err bool +}{ + {in: []byte(":1235"), want: 1235}, + {in: []byte(": 123"), want: 123}, + {in: []byte(":0")}, + {in: []byte(":5012313123131231"), want: 5012313123131231}, + {in: []byte("1231"), err: true}, +} + +func TestParseJSONInt64(t *testing.T) { + for i, tc := range parseJSONInt64TestCases { + t.Run("#"+strconv.Itoa(i), func(t *testing.T) { + got, err := parseJSONInt64(tc.in) + if tc.err && err == nil { + t.Errorf("want err for parseJSONInt64(%s), got nil", tc.in) + } + if !tc.err { + if err != nil { + t.Errorf("parseJSONInt64(%s) unexpected error: %s", tc.in, err.Error()) + } + if tc.want != got { + t.Errorf("parseJSONInt64(%s) = %d, want %d", tc.in, got, tc.want) + } + } + }) + } +} + +func FuzzParseJSONInt64(f *testing.F) { + for _, tc := range parseJSONInt64TestCases { + f.Add(tc.in) + } + f.Fuzz(func(t *testing.T, b []byte) { + if out, err := parseJSONInt64(b); err == nil && + !bytes.Contains(b, []byte(strconv.FormatInt(out, 10))) { + t.Errorf("parseJSONInt64(%s) = %d, %v", b, out, err) + } + }) +} + +var parseJSONBooleanTestCases = []struct { + in []byte + want bool + err bool +}{ + {in: []byte(": true "), want: true}, + {in: []byte(":true "), want: true}, + {in: []byte(": false "), want: false}, + {in: []byte(":false "), want: false}, + {in: []byte("true"), err: true}, + {in: []byte("false"), err: true}, + {in: []byte("foo"), err: true}, +} + +func TestParseJSONBoolean(t *testing.T) { + for i, tc := range parseJSONBooleanTestCases { + t.Run("#"+strconv.Itoa(i), func(t *testing.T) { + got, err := parseJSONBoolean(tc.in) + if tc.err && err == nil { + t.Errorf("want err for parseJSONBoolean(%s), got nil", tc.in) + } + if !tc.err { + if err != nil { + t.Errorf("parseJSONBoolean(%s) unexpected error: %s", tc.in, err.Error()) + } + if tc.want != got { + t.Errorf("parseJSONBoolean(%s) = %t, want %t", tc.in, got, tc.want) + } + } + }) + } +} + +func FuzzParseJSONBoolean(f *testing.F) { + for _, tc := range parseJSONBooleanTestCases { + f.Add(tc.in) + } + f.Fuzz(func(t *testing.T, b []byte) { + if out, err := parseJSONBoolean(b); err == nil && !bytes.Contains(b, []byte(strconv.FormatBool(out))) { + t.Errorf("parseJSONBoolean(%s) = %t, %v", b, out, err) + } + }) +} diff --git a/src/os/user/userdbclient_stub.go b/src/os/user/userdbclient_stub.go new file mode 100644 index 0000000000..d31f065c3a --- /dev/null +++ b/src/os/user/userdbclient_stub.go @@ -0,0 +1,29 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !linux + +package user + +import "context" + +func (cl userdbClient) lookupGroup(_ context.Context, _ string) (*Group, bool, error) { + return nil, false, nil +} + +func (cl userdbClient) lookupGroupId(_ context.Context, _ string) (*Group, bool, error) { + return nil, false, nil +} + +func (cl userdbClient) lookupUser(_ context.Context, _ string) (*User, bool, error) { + return nil, false, nil +} + +func (cl userdbClient) lookupUserId(_ context.Context, _ string) (*User, bool, error) { + return nil, false, nil +} + +func (cl userdbClient) lookupGroupIds(_ context.Context, _ string) ([]string, bool, error) { + return nil, false, nil +} -- 2.48.1