mirror of
https://github.com/kemko/nomad.git
synced 2026-01-01 16:05:42 +03:00
func: move the user lookup into the validation, it's used everywhere the function is called
This commit is contained in:
@@ -7,7 +7,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
@@ -24,7 +23,6 @@ import (
|
||||
"github.com/hashicorp/nomad/drivers/shared/validators"
|
||||
"github.com/hashicorp/nomad/helper/pluginutils/loader"
|
||||
"github.com/hashicorp/nomad/helper/pointer"
|
||||
"github.com/hashicorp/nomad/helper/users"
|
||||
"github.com/hashicorp/nomad/plugins/base"
|
||||
"github.com/hashicorp/nomad/plugins/drivers"
|
||||
"github.com/hashicorp/nomad/plugins/drivers/fsisolation"
|
||||
@@ -253,7 +251,7 @@ type TaskState struct {
|
||||
}
|
||||
|
||||
type UserIDValidator interface {
|
||||
HasValidIDs(user *user.User) error
|
||||
HasValidIDs(userName string) error
|
||||
}
|
||||
|
||||
// NewExecDriver returns a new DrivePlugin implementation
|
||||
@@ -450,15 +448,6 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Driver) validateUserIds(cfg *drivers.TaskConfig) error {
|
||||
user, err := users.Lookup(cfg.User)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to identify user %q: %w", cfg.User, err)
|
||||
}
|
||||
|
||||
return d.userIDValidator.HasValidIDs(user)
|
||||
}
|
||||
|
||||
func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drivers.DriverNetwork, error) {
|
||||
if _, ok := d.tasks.Get(cfg.ID); ok {
|
||||
return nil, nil, fmt.Errorf("task with ID %q already started", cfg.ID)
|
||||
@@ -476,9 +465,10 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
|
||||
if cfg.User == "" {
|
||||
cfg.User = "nobody"
|
||||
}
|
||||
|
||||
d.logger.Debug("setting up user", "user", cfg.User)
|
||||
|
||||
if err := d.validateUserIds(cfg); err != nil {
|
||||
if err := d.userIDValidator.HasValidIDs(cfg.User); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed host user validation: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -39,7 +38,7 @@ import (
|
||||
|
||||
type mockIDValidator struct{}
|
||||
|
||||
func (mv *mockIDValidator) HasValidIDs(user *user.User) error {
|
||||
func (mv *mockIDValidator) HasValidIDs(userName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -112,7 +111,7 @@ var (
|
||||
)
|
||||
|
||||
type UserIDValidator interface {
|
||||
HasValidIDs(user *user.User) error
|
||||
HasValidIDs(userName string) error
|
||||
}
|
||||
|
||||
// Driver is a privileged version of the exec driver. It provides no
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -79,7 +78,7 @@ var (
|
||||
|
||||
type mockIDValidator struct{}
|
||||
|
||||
func (mv *mockIDValidator) HasValidIDs(user *user.User) error {
|
||||
func (mv *mockIDValidator) HasValidIDs(userName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ package rawexec
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/user"
|
||||
|
||||
"github.com/hashicorp/nomad/helper/users"
|
||||
"github.com/hashicorp/nomad/plugins/drivers"
|
||||
@@ -15,22 +14,17 @@ import (
|
||||
|
||||
func (d *Driver) Validate(cfg drivers.TaskConfig) error {
|
||||
usernameToLookup := cfg.User
|
||||
var user *user.User
|
||||
var err error
|
||||
|
||||
// Uses the current user of the client agent process
|
||||
// if no override is given (differs from exec)
|
||||
if usernameToLookup == "" {
|
||||
user, err = users.Current()
|
||||
user, err := users.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current user: %w", err)
|
||||
}
|
||||
} else {
|
||||
user, err = users.Lookup(usernameToLookup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to identify user %q: %w", usernameToLookup, err)
|
||||
}
|
||||
|
||||
usernameToLookup = user.Username
|
||||
}
|
||||
|
||||
return d.userIDValidator.HasValidIDs(user)
|
||||
return d.userIDValidator.HasValidIDs(usernameToLookup)
|
||||
}
|
||||
|
||||
@@ -6,13 +6,12 @@ package validators
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/nomad/client/lib/idset"
|
||||
"github.com/hashicorp/nomad/client/lib/numalib/hw"
|
||||
"github.com/hashicorp/nomad/helper/users"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -21,12 +20,21 @@ var (
|
||||
ErrInvalidRange = errors.New("lower bound cannot be greater than upper bound")
|
||||
)
|
||||
|
||||
type (
|
||||
|
||||
// A GroupID (GID) represents a unique numerical value assigned to each user group.
|
||||
GroupID uint64
|
||||
|
||||
// A UserID represents a unique numerical value assigned to each user account.
|
||||
UserID uint64
|
||||
)
|
||||
|
||||
type validator struct {
|
||||
// DeniedHostUids configures which host uids are disallowed
|
||||
deniedUIDs *idset.Set[hw.UserID]
|
||||
deniedUIDs *idset.Set[UserID]
|
||||
|
||||
// DeniedHostGids configures which host gids are disallowed
|
||||
deniedGIDs *idset.Set[hw.GroupID]
|
||||
deniedGIDs *idset.Set[GroupID]
|
||||
|
||||
// logger will log to the Nomad agent
|
||||
logger hclog.Logger
|
||||
@@ -48,8 +56,8 @@ func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*
|
||||
valLogger.Debug("group range configured", "denied range", deniedHostGIDs)
|
||||
|
||||
v := &validator{
|
||||
deniedUIDs: idset.Parse[hw.UserID](deniedHostUIDs),
|
||||
deniedGIDs: idset.Parse[hw.GroupID](deniedHostGIDs),
|
||||
deniedUIDs: idset.Parse[UserID](deniedHostUIDs),
|
||||
deniedGIDs: idset.Parse[GroupID](deniedHostGIDs),
|
||||
logger: valLogger,
|
||||
}
|
||||
|
||||
@@ -58,7 +66,12 @@ func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*
|
||||
|
||||
// HasValidIDs is used when running a task to ensure the
|
||||
// given user is in the ID range defined in the task config
|
||||
func (v *validator) HasValidIDs(user *user.User) error {
|
||||
func (v *validator) HasValidIDs(userName string) error {
|
||||
user, err := users.Lookup(userName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to identify user %q: %w", userName, err)
|
||||
}
|
||||
|
||||
uid, err := getUserID(user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validator: %w", err)
|
||||
|
||||
@@ -7,16 +7,14 @@ package validators
|
||||
|
||||
import (
|
||||
"os/user"
|
||||
|
||||
"github.com/hashicorp/nomad/client/lib/numalib/hw"
|
||||
)
|
||||
|
||||
// noop
|
||||
func getUserID(user *user.User) (hw.UserID, error) {
|
||||
func getUserID(*user.User) (UserID, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// noop
|
||||
func getGroupID(user *user.User) ([]hw.GroupID, error) {
|
||||
return []hw.GroupID{}, nil
|
||||
func getGroupID(*user.User) ([]GroupID, error) {
|
||||
return []GroupID{}, nil
|
||||
}
|
||||
|
||||
@@ -9,34 +9,32 @@ import (
|
||||
"fmt"
|
||||
"os/user"
|
||||
"strconv"
|
||||
|
||||
"github.com/hashicorp/nomad/client/lib/numalib/hw"
|
||||
)
|
||||
|
||||
func getUserID(user *user.User) (hw.UserID, error) {
|
||||
func getUserID(user *user.User) (UserID, error) {
|
||||
id, err := strconv.ParseUint(user.Uid, 10, 32)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("unable to convert userid %s to integer", user.Uid)
|
||||
}
|
||||
|
||||
return hw.UserID(id), nil
|
||||
return UserID(id), nil
|
||||
}
|
||||
|
||||
func getGroupID(user *user.User) ([]hw.GroupID, error) {
|
||||
func getGroupID(user *user.User) ([]GroupID, error) {
|
||||
gidStrings, err := user.GroupIds()
|
||||
if err != nil {
|
||||
return []hw.GroupID{}, fmt.Errorf("unable to lookup user's group membership: %w", err)
|
||||
return []GroupID{}, fmt.Errorf("unable to lookup user's group membership: %w", err)
|
||||
}
|
||||
|
||||
gids := make([]hw.GroupID, len(gidStrings))
|
||||
gids := make([]GroupID, len(gidStrings))
|
||||
|
||||
for _, gidString := range gidStrings {
|
||||
u, err := strconv.ParseUint(gidString, 10, 32)
|
||||
if err != nil {
|
||||
return []hw.GroupID{}, fmt.Errorf("unable to convert user's group %q to integer: %w", gidString, err)
|
||||
return []GroupID{}, fmt.Errorf("unable to convert user's group %q to integer: %w", gidString, err)
|
||||
}
|
||||
|
||||
gids = append(gids, hw.GroupID(u))
|
||||
gids = append(gids, GroupID(u))
|
||||
}
|
||||
|
||||
return gids, nil
|
||||
|
||||
Reference in New Issue
Block a user