diff --git a/drivers/exec/driver.go b/drivers/exec/driver.go index 7d39c9fe4..7f17ff04f 100644 --- a/drivers/exec/driver.go +++ b/drivers/exec/driver.go @@ -265,7 +265,12 @@ func (tc *TaskConfig) validate() error { func (tc *TaskConfig) validateUserIds(cfg *drivers.TaskConfig, driverConfig *Config) error { usernameToLookup := getUsername(cfg) - return validators.HasValidIds(users.Lookup, usernameToLookup, driverConfig.DeniedHostUids, driverConfig.DeniedHostGids) + user, err := users.Lookup(usernameToLookup) + if err != nil { + return fmt.Errorf("failed to identify user %q: %w", usernameToLookup, err) + } + + return validators.HasValidIds(user, driverConfig.DeniedHostUids, driverConfig.DeniedHostGids) } // TaskState is the state which is encoded in the handle returned in diff --git a/drivers/rawexec/driver_unix.go b/drivers/rawexec/driver_unix.go index 406616165..c5d36f0d7 100644 --- a/drivers/rawexec/driver_unix.go +++ b/drivers/rawexec/driver_unix.go @@ -6,6 +6,8 @@ package rawexec import ( + "fmt" + "github.com/hashicorp/nomad/drivers/shared/validators" "github.com/hashicorp/nomad/helper/users" "github.com/hashicorp/nomad/plugins/drivers" @@ -25,5 +27,10 @@ func (tc *TaskConfig) Validate(driverCofig Config, cfg drivers.TaskConfig) error usernameToLookup = current.Name } - return validators.HasValidIds(users.Lookup, usernameToLookup, driverCofig.DeniedHostUids, driverCofig.DeniedHostGids) + user, err := users.Lookup(usernameToLookup) + if err != nil { + return fmt.Errorf("failed to identify user %q: %w", usernameToLookup, err) + } + + return validators.HasValidIds(user, driverCofig.DeniedHostUids, driverCofig.DeniedHostGids) } diff --git a/drivers/shared/validators/validators.go b/drivers/shared/validators/validators.go index e8842ecca..f48ab759c 100644 --- a/drivers/shared/validators/validators.go +++ b/drivers/shared/validators/validators.go @@ -1,8 +1,6 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 -//go:build !windows - package validators import ( @@ -36,21 +34,12 @@ func ParseIdRange(rangeType string, deniedRanges string) ([]structs.IDRange, err return idRanges, nil } -type userLookupFn func(string) (*user.User, error) - // HasValidIds is used when running a task to ensure the // given user is in the ID range defined in the task config -func HasValidIds(userLookupFn userLookupFn, usernameToLookup string, deniedHostUIDs, deniedHostGIDs []structs.IDRange) error { - - // look up user on host given username - - u, err := userLookupFn(usernameToLookup) +func HasValidIds(user *user.User, deniedHostUIDs, deniedHostGIDs []structs.IDRange) error { + uid, err := strconv.ParseUint(user.Uid, 10, 32) if err != nil { - return fmt.Errorf("failed to identify user %q: %w", usernameToLookup, err) - } - uid, err := strconv.ParseUint(u.Uid, 10, 32) - if err != nil { - return fmt.Errorf("unable to convert userid %s to integer", u.Uid) + return fmt.Errorf("unable to convert userid %s to integer", user.Uid) } // check uids @@ -63,7 +52,7 @@ func HasValidIds(userLookupFn userLookupFn, usernameToLookup string, deniedHostU // check gids - gidStrings, err := u.GroupIds() + gidStrings, err := user.GroupIds() if err != nil { return fmt.Errorf("unable to lookup user's group membership: %w", err) } diff --git a/drivers/shared/validators/validators_test.go b/drivers/shared/validators/validators_test.go index bd7b8b08c..2a5e4813b 100644 --- a/drivers/shared/validators/validators_test.go +++ b/drivers/shared/validators/validators_test.go @@ -1,8 +1,6 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 -//go:build !windows - package validators import ( @@ -63,13 +61,12 @@ func Test_HasValidIds(t *testing.T) { validRangesList := []structs.IDRange{validRange, validRangeSingle} testCases := []struct { - name string - uidRanges []structs.IDRange - gidRanges []structs.IDRange - uid string - gid string - expectedErr string - userLookupFunc userLookupFn + name string + uidRanges []structs.IDRange + gidRanges []structs.IDRange + uid string + gid string + expectedErr string }{ {name: "no-ranges-are-valid", uidRanges: validRangesList, gidRanges: emptyRanges}, {name: "uid-and-gid-outside-of-ranges-valid", uidRanges: validRangesList, gidRanges: validRangesList}, @@ -80,24 +77,20 @@ func Test_HasValidIds(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - defaultUserToReturn := &user.User{ + user := &user.User{ Uid: "200", Gid: "200", } if tc.uid != "" { - defaultUserToReturn.Uid = tc.uid + user.Uid = tc.uid } if tc.gid != "" { - defaultUserToReturn.Gid = tc.gid + user.Gid = tc.gid } - getUserFn := func(username string) (*user.User, error) { - return defaultUserToReturn, nil - } - - err := HasValidIds(getUserFn, "username", tc.uidRanges, tc.gidRanges) + err := HasValidIds(user, tc.uidRanges, tc.gidRanges) if tc.expectedErr == "" { must.NoError(t, err) } else {