func: move the user lookup into the validation, it's used everywhere the function is called

This commit is contained in:
Juanadelacuesta
2024-10-31 10:33:24 +01:00
parent 3449056cd6
commit 8752bb0a65
8 changed files with 40 additions and 50 deletions

View File

@@ -7,7 +7,6 @@ import (
"context"
"fmt"
"os"
"os/user"
"path/filepath"
"runtime"
"sync"
@@ -24,7 +23,6 @@ import (
"github.com/hashicorp/nomad/drivers/shared/validators"
"github.com/hashicorp/nomad/helper/pluginutils/loader"
"github.com/hashicorp/nomad/helper/pointer"
"github.com/hashicorp/nomad/helper/users"
"github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/drivers"
"github.com/hashicorp/nomad/plugins/drivers/fsisolation"
@@ -253,7 +251,7 @@ type TaskState struct {
}
type UserIDValidator interface {
HasValidIDs(user *user.User) error
HasValidIDs(userName string) error
}
// NewExecDriver returns a new DrivePlugin implementation
@@ -450,15 +448,6 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error {
return nil
}
func (d *Driver) validateUserIds(cfg *drivers.TaskConfig) error {
user, err := users.Lookup(cfg.User)
if err != nil {
return fmt.Errorf("failed to identify user %q: %w", cfg.User, err)
}
return d.userIDValidator.HasValidIDs(user)
}
func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drivers.DriverNetwork, error) {
if _, ok := d.tasks.Get(cfg.ID); ok {
return nil, nil, fmt.Errorf("task with ID %q already started", cfg.ID)
@@ -476,9 +465,10 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
if cfg.User == "" {
cfg.User = "nobody"
}
d.logger.Debug("setting up user", "user", cfg.User)
if err := d.validateUserIds(cfg); err != nil {
if err := d.userIDValidator.HasValidIDs(cfg.User); err != nil {
return nil, nil, fmt.Errorf("failed host user validation: %v", err)
}

View File

@@ -9,7 +9,6 @@ import (
"errors"
"fmt"
"os"
"os/user"
"path/filepath"
"runtime"
"strconv"
@@ -39,7 +38,7 @@ import (
type mockIDValidator struct{}
func (mv *mockIDValidator) HasValidIDs(user *user.User) error {
func (mv *mockIDValidator) HasValidIDs(userName string) error {
return nil
}

View File

@@ -8,7 +8,6 @@ import (
"errors"
"fmt"
"os"
"os/user"
"path/filepath"
"strconv"
"time"
@@ -112,7 +111,7 @@ var (
)
type UserIDValidator interface {
HasValidIDs(user *user.User) error
HasValidIDs(userName string) error
}
// Driver is a privileged version of the exec driver. It provides no

View File

@@ -9,7 +9,6 @@ import (
"errors"
"fmt"
"os"
"os/user"
"path/filepath"
"runtime"
"strconv"
@@ -79,7 +78,7 @@ var (
type mockIDValidator struct{}
func (mv *mockIDValidator) HasValidIDs(user *user.User) error {
func (mv *mockIDValidator) HasValidIDs(userName string) error {
return nil
}

View File

@@ -7,7 +7,6 @@ package rawexec
import (
"fmt"
"os/user"
"github.com/hashicorp/nomad/helper/users"
"github.com/hashicorp/nomad/plugins/drivers"
@@ -15,22 +14,17 @@ import (
func (d *Driver) Validate(cfg drivers.TaskConfig) error {
usernameToLookup := cfg.User
var user *user.User
var err error
// Uses the current user of the client agent process
// if no override is given (differs from exec)
if usernameToLookup == "" {
user, err = users.Current()
user, err := users.Current()
if err != nil {
return fmt.Errorf("failed to get current user: %w", err)
}
} else {
user, err = users.Lookup(usernameToLookup)
if err != nil {
return fmt.Errorf("failed to identify user %q: %w", usernameToLookup, err)
}
usernameToLookup = user.Username
}
return d.userIDValidator.HasValidIDs(user)
return d.userIDValidator.HasValidIDs(usernameToLookup)
}

View File

@@ -6,13 +6,12 @@ package validators
import (
"errors"
"fmt"
"os/user"
"strconv"
"strings"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/client/lib/idset"
"github.com/hashicorp/nomad/client/lib/numalib/hw"
"github.com/hashicorp/nomad/helper/users"
)
var (
@@ -21,12 +20,21 @@ var (
ErrInvalidRange = errors.New("lower bound cannot be greater than upper bound")
)
type (
// A GroupID (GID) represents a unique numerical value assigned to each user group.
GroupID uint64
// A UserID represents a unique numerical value assigned to each user account.
UserID uint64
)
type validator struct {
// DeniedHostUids configures which host uids are disallowed
deniedUIDs *idset.Set[hw.UserID]
deniedUIDs *idset.Set[UserID]
// DeniedHostGids configures which host gids are disallowed
deniedGIDs *idset.Set[hw.GroupID]
deniedGIDs *idset.Set[GroupID]
// logger will log to the Nomad agent
logger hclog.Logger
@@ -48,8 +56,8 @@ func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*
valLogger.Debug("group range configured", "denied range", deniedHostGIDs)
v := &validator{
deniedUIDs: idset.Parse[hw.UserID](deniedHostUIDs),
deniedGIDs: idset.Parse[hw.GroupID](deniedHostGIDs),
deniedUIDs: idset.Parse[UserID](deniedHostUIDs),
deniedGIDs: idset.Parse[GroupID](deniedHostGIDs),
logger: valLogger,
}
@@ -58,7 +66,12 @@ func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*
// HasValidIDs is used when running a task to ensure the
// given user is in the ID range defined in the task config
func (v *validator) HasValidIDs(user *user.User) error {
func (v *validator) HasValidIDs(userName string) error {
user, err := users.Lookup(userName)
if err != nil {
return fmt.Errorf("failed to identify user %q: %w", userName, err)
}
uid, err := getUserID(user)
if err != nil {
return fmt.Errorf("validator: %w", err)

View File

@@ -7,16 +7,14 @@ package validators
import (
"os/user"
"github.com/hashicorp/nomad/client/lib/numalib/hw"
)
// noop
func getUserID(user *user.User) (hw.UserID, error) {
func getUserID(*user.User) (UserID, error) {
return 0, nil
}
// noop
func getGroupID(user *user.User) ([]hw.GroupID, error) {
return []hw.GroupID{}, nil
func getGroupID(*user.User) ([]GroupID, error) {
return []GroupID{}, nil
}

View File

@@ -9,34 +9,32 @@ import (
"fmt"
"os/user"
"strconv"
"github.com/hashicorp/nomad/client/lib/numalib/hw"
)
func getUserID(user *user.User) (hw.UserID, error) {
func getUserID(user *user.User) (UserID, error) {
id, err := strconv.ParseUint(user.Uid, 10, 32)
if err != nil {
return 0, fmt.Errorf("unable to convert userid %s to integer", user.Uid)
}
return hw.UserID(id), nil
return UserID(id), nil
}
func getGroupID(user *user.User) ([]hw.GroupID, error) {
func getGroupID(user *user.User) ([]GroupID, error) {
gidStrings, err := user.GroupIds()
if err != nil {
return []hw.GroupID{}, fmt.Errorf("unable to lookup user's group membership: %w", err)
return []GroupID{}, fmt.Errorf("unable to lookup user's group membership: %w", err)
}
gids := make([]hw.GroupID, len(gidStrings))
gids := make([]GroupID, len(gidStrings))
for _, gidString := range gidStrings {
u, err := strconv.ParseUint(gidString, 10, 32)
if err != nil {
return []hw.GroupID{}, fmt.Errorf("unable to convert user's group %q to integer: %w", gidString, err)
return []GroupID{}, fmt.Errorf("unable to convert user's group %q to integer: %w", gidString, err)
}
gids = append(gids, hw.GroupID(u))
gids = append(gids, GroupID(u))
}
return gids, nil