From 8752bb0a65a0e8be0701825396ede7ba0c874f13 Mon Sep 17 00:00:00 2001 From: Juanadelacuesta <8647634+Juanadelacuesta@users.noreply.github.com> Date: Thu, 31 Oct 2024 10:33:24 +0100 Subject: [PATCH] func: move the user lookup into the validation, it's used everywhere the function is called --- drivers/exec/driver.go | 16 +++-------- drivers/exec/driver_test.go | 3 +-- drivers/rawexec/driver.go | 3 +-- drivers/rawexec/driver_test.go | 3 +-- drivers/rawexec/driver_unix.go | 14 +++------- drivers/shared/validators/validators.go | 27 ++++++++++++++----- .../shared/validators/validators_default.go | 8 +++--- drivers/shared/validators/validators_unix.go | 16 +++++------ 8 files changed, 40 insertions(+), 50 deletions(-) diff --git a/drivers/exec/driver.go b/drivers/exec/driver.go index 3ae3a1f64..caad707d9 100644 --- a/drivers/exec/driver.go +++ b/drivers/exec/driver.go @@ -7,7 +7,6 @@ import ( "context" "fmt" "os" - "os/user" "path/filepath" "runtime" "sync" @@ -24,7 +23,6 @@ import ( "github.com/hashicorp/nomad/drivers/shared/validators" "github.com/hashicorp/nomad/helper/pluginutils/loader" "github.com/hashicorp/nomad/helper/pointer" - "github.com/hashicorp/nomad/helper/users" "github.com/hashicorp/nomad/plugins/base" "github.com/hashicorp/nomad/plugins/drivers" "github.com/hashicorp/nomad/plugins/drivers/fsisolation" @@ -253,7 +251,7 @@ type TaskState struct { } type UserIDValidator interface { - HasValidIDs(user *user.User) error + HasValidIDs(userName string) error } // NewExecDriver returns a new DrivePlugin implementation @@ -450,15 +448,6 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error { return nil } -func (d *Driver) validateUserIds(cfg *drivers.TaskConfig) error { - user, err := users.Lookup(cfg.User) - if err != nil { - return fmt.Errorf("failed to identify user %q: %w", cfg.User, err) - } - - return d.userIDValidator.HasValidIDs(user) -} - func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drivers.DriverNetwork, error) { if _, ok := d.tasks.Get(cfg.ID); ok { return nil, nil, fmt.Errorf("task with ID %q already started", cfg.ID) @@ -476,9 +465,10 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive if cfg.User == "" { cfg.User = "nobody" } + d.logger.Debug("setting up user", "user", cfg.User) - if err := d.validateUserIds(cfg); err != nil { + if err := d.userIDValidator.HasValidIDs(cfg.User); err != nil { return nil, nil, fmt.Errorf("failed host user validation: %v", err) } diff --git a/drivers/exec/driver_test.go b/drivers/exec/driver_test.go index 66af7965c..cd275bc87 100644 --- a/drivers/exec/driver_test.go +++ b/drivers/exec/driver_test.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "os" - "os/user" "path/filepath" "runtime" "strconv" @@ -39,7 +38,7 @@ import ( type mockIDValidator struct{} -func (mv *mockIDValidator) HasValidIDs(user *user.User) error { +func (mv *mockIDValidator) HasValidIDs(userName string) error { return nil } diff --git a/drivers/rawexec/driver.go b/drivers/rawexec/driver.go index 52a6108d7..39493e169 100644 --- a/drivers/rawexec/driver.go +++ b/drivers/rawexec/driver.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "os" - "os/user" "path/filepath" "strconv" "time" @@ -112,7 +111,7 @@ var ( ) type UserIDValidator interface { - HasValidIDs(user *user.User) error + HasValidIDs(userName string) error } // Driver is a privileged version of the exec driver. It provides no diff --git a/drivers/rawexec/driver_test.go b/drivers/rawexec/driver_test.go index 506c0669d..693eac2d3 100644 --- a/drivers/rawexec/driver_test.go +++ b/drivers/rawexec/driver_test.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "os" - "os/user" "path/filepath" "runtime" "strconv" @@ -79,7 +78,7 @@ var ( type mockIDValidator struct{} -func (mv *mockIDValidator) HasValidIDs(user *user.User) error { +func (mv *mockIDValidator) HasValidIDs(userName string) error { return nil } diff --git a/drivers/rawexec/driver_unix.go b/drivers/rawexec/driver_unix.go index 7b3ebb2d4..04ef54173 100644 --- a/drivers/rawexec/driver_unix.go +++ b/drivers/rawexec/driver_unix.go @@ -7,7 +7,6 @@ package rawexec import ( "fmt" - "os/user" "github.com/hashicorp/nomad/helper/users" "github.com/hashicorp/nomad/plugins/drivers" @@ -15,22 +14,17 @@ import ( func (d *Driver) Validate(cfg drivers.TaskConfig) error { usernameToLookup := cfg.User - var user *user.User - var err error // Uses the current user of the client agent process // if no override is given (differs from exec) if usernameToLookup == "" { - user, err = users.Current() + user, err := users.Current() if err != nil { return fmt.Errorf("failed to get current user: %w", err) } - } else { - user, err = users.Lookup(usernameToLookup) - if err != nil { - return fmt.Errorf("failed to identify user %q: %w", usernameToLookup, err) - } + + usernameToLookup = user.Username } - return d.userIDValidator.HasValidIDs(user) + return d.userIDValidator.HasValidIDs(usernameToLookup) } diff --git a/drivers/shared/validators/validators.go b/drivers/shared/validators/validators.go index 2bbf4a8cb..edbbf04b7 100644 --- a/drivers/shared/validators/validators.go +++ b/drivers/shared/validators/validators.go @@ -6,13 +6,12 @@ package validators import ( "errors" "fmt" - "os/user" "strconv" "strings" "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/client/lib/idset" - "github.com/hashicorp/nomad/client/lib/numalib/hw" + "github.com/hashicorp/nomad/helper/users" ) var ( @@ -21,12 +20,21 @@ var ( ErrInvalidRange = errors.New("lower bound cannot be greater than upper bound") ) +type ( + + // A GroupID (GID) represents a unique numerical value assigned to each user group. + GroupID uint64 + + // A UserID represents a unique numerical value assigned to each user account. + UserID uint64 +) + type validator struct { // DeniedHostUids configures which host uids are disallowed - deniedUIDs *idset.Set[hw.UserID] + deniedUIDs *idset.Set[UserID] // DeniedHostGids configures which host gids are disallowed - deniedGIDs *idset.Set[hw.GroupID] + deniedGIDs *idset.Set[GroupID] // logger will log to the Nomad agent logger hclog.Logger @@ -48,8 +56,8 @@ func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (* valLogger.Debug("group range configured", "denied range", deniedHostGIDs) v := &validator{ - deniedUIDs: idset.Parse[hw.UserID](deniedHostUIDs), - deniedGIDs: idset.Parse[hw.GroupID](deniedHostGIDs), + deniedUIDs: idset.Parse[UserID](deniedHostUIDs), + deniedGIDs: idset.Parse[GroupID](deniedHostGIDs), logger: valLogger, } @@ -58,7 +66,12 @@ func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (* // HasValidIDs is used when running a task to ensure the // given user is in the ID range defined in the task config -func (v *validator) HasValidIDs(user *user.User) error { +func (v *validator) HasValidIDs(userName string) error { + user, err := users.Lookup(userName) + if err != nil { + return fmt.Errorf("failed to identify user %q: %w", userName, err) + } + uid, err := getUserID(user) if err != nil { return fmt.Errorf("validator: %w", err) diff --git a/drivers/shared/validators/validators_default.go b/drivers/shared/validators/validators_default.go index 69dea4e9d..1f9adad46 100644 --- a/drivers/shared/validators/validators_default.go +++ b/drivers/shared/validators/validators_default.go @@ -7,16 +7,14 @@ package validators import ( "os/user" - - "github.com/hashicorp/nomad/client/lib/numalib/hw" ) // noop -func getUserID(user *user.User) (hw.UserID, error) { +func getUserID(*user.User) (UserID, error) { return 0, nil } // noop -func getGroupID(user *user.User) ([]hw.GroupID, error) { - return []hw.GroupID{}, nil +func getGroupID(*user.User) ([]GroupID, error) { + return []GroupID{}, nil } diff --git a/drivers/shared/validators/validators_unix.go b/drivers/shared/validators/validators_unix.go index 381857739..1d7aa597a 100644 --- a/drivers/shared/validators/validators_unix.go +++ b/drivers/shared/validators/validators_unix.go @@ -9,34 +9,32 @@ import ( "fmt" "os/user" "strconv" - - "github.com/hashicorp/nomad/client/lib/numalib/hw" ) -func getUserID(user *user.User) (hw.UserID, error) { +func getUserID(user *user.User) (UserID, error) { id, err := strconv.ParseUint(user.Uid, 10, 32) if err != nil { return 0, fmt.Errorf("unable to convert userid %s to integer", user.Uid) } - return hw.UserID(id), nil + return UserID(id), nil } -func getGroupID(user *user.User) ([]hw.GroupID, error) { +func getGroupID(user *user.User) ([]GroupID, error) { gidStrings, err := user.GroupIds() if err != nil { - return []hw.GroupID{}, fmt.Errorf("unable to lookup user's group membership: %w", err) + return []GroupID{}, fmt.Errorf("unable to lookup user's group membership: %w", err) } - gids := make([]hw.GroupID, len(gidStrings)) + gids := make([]GroupID, len(gidStrings)) for _, gidString := range gidStrings { u, err := strconv.ParseUint(gidString, 10, 32) if err != nil { - return []hw.GroupID{}, fmt.Errorf("unable to convert user's group %q to integer: %w", gidString, err) + return []GroupID{}, fmt.Errorf("unable to convert user's group %q to integer: %w", gidString, err) } - gids = append(gids, hw.GroupID(u)) + gids = append(gids, GroupID(u)) } return gids, nil