func: move the validation to a dependency and use id sets

This commit is contained in:
Juanadelacuesta
2024-10-28 18:59:51 +01:00
parent 65be613be9
commit 0cd1b5ff13
8 changed files with 219 additions and 74 deletions

View File

@@ -14,9 +14,15 @@ type (
// Must be an alias because go-msgpack cannot handle the real type.
NodeID = uint8
// A SocketID represents a physicsl CPU socket.
// A SocketID represents a physical CPU socket.
SocketID uint8
// A CoreID represents one logical (vCPU) core.
CoreID uint16
// A GroupID (GID) represents a unique numerical value assigned to each user group.
GroupID uint64
// A UserID represents a unique numerical value assigned to each user account.
UserID uint64
)

View File

@@ -7,6 +7,7 @@ import (
"context"
"fmt"
"os"
"os/user"
"path/filepath"
"runtime"
"sync"
@@ -114,38 +115,6 @@ var (
}
)
// Driver fork/execs tasks using many of the underlying OS's isolation
// features where configured.
type Driver struct {
// eventer is used to handle multiplexing of TaskEvents calls such that an
// event can be broadcast to all callers
eventer *eventer.Eventer
// config is the driver configuration set by the SetConfig RPC
config Config
// nomadConfig is the client config from nomad
nomadConfig *base.ClientDriverConfig
// tasks is the in memory datastore mapping taskIDs to driverHandles
tasks *taskStore
// ctx is the context for the driver. It is passed to other subsystems to
// coordinate shutdown
ctx context.Context
// logger will log to the Nomad agent
logger hclog.Logger
// A tri-state boolean to know if the fingerprinting has happened and
// whether it has been successful
fingerprintSuccess *bool
fingerprintLock sync.Mutex
// compute contains cpu compute information
compute cpustats.Compute
}
// Config is the driver configuration set by the SetConfig RPC call
type Config struct {
// NoPivotRoot disables the use of pivot_root, useful when the root partition
@@ -166,12 +135,6 @@ type Config struct {
DeniedHostUidsStr string `codec:"denied_host_uids"`
DeniedHostGidsStr string `codec:"denied_host_gids"`
// DeniedHostUids configures which host uids are disallowed
DeniedHostUids []validators.IDRange `codec:"-"`
// DeniedHostGids configures which host gids are disallowed
DeniedHostGids []validators.IDRange `codec:"-"`
}
func (c *Config) validate() error {
@@ -195,23 +158,6 @@ func (c *Config) validate() error {
return nil
}
func (c *Config) setDeniedIds() error {
deniedUidRanges, err := validators.ParseIdRange("denied_host_uids", c.DeniedHostUidsStr)
if err != nil {
return err
}
deniedGidRanges, err := validators.ParseIdRange("denied_host_gids", c.DeniedHostGidsStr)
if err != nil {
return err
}
c.DeniedHostUids = deniedUidRanges
c.DeniedHostGids = deniedGidRanges
return nil
}
// TaskConfig is the driver configuration of a task within a job
type TaskConfig struct {
// Command is the thing to exec.
@@ -268,7 +214,9 @@ func (tc *TaskConfig) validateUserIds(cfg *drivers.TaskConfig, driverConfig *Con
return fmt.Errorf("failed to identify user %q: %w", cfg.User, err)
}
return validators.HasValidIds(user, driverConfig.DeniedHostUids, driverConfig.DeniedHostGids)
fmt.Println(user)
//return validators.HasValidIds(user, driverConfig.DeniedHostUids, driverConfig.DeniedHostGids)
return nil
}
// TaskState is the state which is encoded in the handle returned in
@@ -281,6 +229,44 @@ type TaskState struct {
StartedAt time.Time
}
type UserIDValidator interface {
HasValidIDs(user *user.User) error
}
// Driver fork/execs tasks using many of the underlying OS's isolation
// features where configured.
type Driver struct {
// eventer is used to handle multiplexing of TaskEvents calls such that an
// event can be broadcast to all callers
eventer *eventer.Eventer
// config is the driver configuration set by the SetConfig RPC
config *Config
// nomadConfig is the client config from nomad
nomadConfig *base.ClientDriverConfig
// tasks is the in memory datastore mapping taskIDs to driverHandles
tasks *taskStore
// ctx is the context for the driver. It is passed to other subsystems to
// coordinate shutdown
ctx context.Context
// logger will log to the Nomad agent
logger hclog.Logger
// A tri-state boolean to know if the fingerprinting has happened and
// whether it has been successful
fingerprintSuccess *bool
fingerprintLock sync.Mutex
// compute contains cpu compute information
compute cpustats.Compute
userIDValidator UserIDValidator
}
// NewExecDriver returns a new DrivePlugin implementation
func NewExecDriver(ctx context.Context, logger hclog.Logger) drivers.DriverPlugin {
logger = logger.Named(pluginName)
@@ -322,21 +308,42 @@ func (d *Driver) ConfigSchema() (*hclspec.Spec, error) {
return configSpec, nil
}
/* func (d *Driver) setDeniedIds(conf *Config) error {
deniedUidRanges, err := validators.ParseIdRange("denied_host_uids", conf.DeniedHostUidsStr)
if err != nil {
return err
}
deniedGidRanges, err := validators.ParseIdRange("denied_host_gids", conf.DeniedHostGidsStr)
if err != nil {
return err
}
d.DeniedHostUids = deniedUidRanges
d.DeniedHostGids = deniedGidRanges
return nil
} */
func (d *Driver) SetConfig(cfg *base.Config) error {
// unpack, validate, and set agent plugin config
var config Config
var config *Config
if len(cfg.PluginConfig) != 0 {
if err := base.MsgPackDecode(cfg.PluginConfig, &config); err != nil {
if err := base.MsgPackDecode(cfg.PluginConfig, config); err != nil {
return err
}
}
if err := config.validate(); err != nil {
return err
}
if err := config.setDeniedIds(); err != nil {
return err
idValidator, err := validators.NewValidator(d.logger, config.DeniedHostUidsStr, config.DeniedHostGidsStr)
if err != nil {
return fmt.Errorf("unable to start validator: %w", err)
}
d.userIDValidator = idValidator
d.config = config
if cfg != nil && cfg.AgentConfig != nil {
@@ -467,6 +474,15 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error {
return nil
}
func (d *Driver) validateUserIds(cfg *drivers.TaskConfig) error {
user, err := users.Lookup(cfg.User)
if err != nil {
return fmt.Errorf("failed to identify user %q: %w", cfg.User, err)
}
return d.userIDValidator.HasValidIDs(user)
}
func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drivers.DriverNetwork, error) {
if _, ok := d.tasks.Get(cfg.ID); ok {
return nil, nil, fmt.Errorf("task with ID %q already started", cfg.ID)
@@ -486,7 +502,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
}
d.logger.Debug("setting up user", "user", cfg.User)
if err := driverConfig.validateUserIds(cfg, &d.config); err != nil {
if err := d.validateUserIds(cfg); err != nil {
return nil, nil, fmt.Errorf("failed host user validation: %v", err)
}

View File

@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"os"
"os/user"
"path/filepath"
"strconv"
"time"
@@ -110,6 +111,10 @@ var (
}
)
type UserIDValidator interface {
HasValidIDs(user *user.User) error
}
// Driver is a privileged version of the exec driver. It provides no
// resource isolation and just fork/execs. The Exec driver should be preferred
// and this should only be used when explicitly needed.
@@ -136,6 +141,8 @@ type Driver struct {
// compute contains cpu compute information
compute cpustats.Compute
userIDValidator UserIDValidator
}
// Config is the driver configuration set by the SetConfig RPC call
@@ -213,7 +220,7 @@ func (d *Driver) SetConfig(cfg *base.Config) error {
}
}
deniedUidRanges, err := validators.ParseIdRange("denied_host_uids", config.DeniedHostUidsStr)
/* deniedUidRanges, err := validators.ParseIdRange("denied_host_uids", config.DeniedHostUidsStr)
if err != nil {
return err
}
@@ -221,12 +228,19 @@ func (d *Driver) SetConfig(cfg *base.Config) error {
deniedGidRanges, err := validators.ParseIdRange("denied_host_gids", config.DeniedHostGidsStr)
if err != nil {
return err
} */
idValidator, err := validators.NewValidator(d.logger, config.DeniedHostUidsStr, config.DeniedHostGidsStr)
if err != nil {
return fmt.Errorf("unable to start validator: %w", err)
}
d.config = &config
d.config.DeniedHostUids = deniedUidRanges
d.config.DeniedHostGids = deniedGidRanges
d.userIDValidator = idValidator
d.config = &config
/* d.config.DeniedHostUids = deniedUidRanges
d.config.DeniedHostGids = deniedGidRanges
*/
if cfg.AgentConfig != nil {
d.nomadConfig = cfg.AgentConfig.Driver
d.compute = cfg.AgentConfig.Compute()
@@ -359,7 +373,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
return nil, nil, fmt.Errorf("oom_score_adj must not be negative")
}
if err := driverConfig.Validate(*d.config, *cfg); err != nil {
if err := d.Validate(*d.config, *cfg); err != nil {
return nil, nil, fmt.Errorf("failed driver config validation: %v", err)
}

View File

@@ -9,12 +9,11 @@ import (
"fmt"
"os/user"
"github.com/hashicorp/nomad/drivers/shared/validators"
"github.com/hashicorp/nomad/helper/users"
"github.com/hashicorp/nomad/plugins/drivers"
)
func (tc *TaskConfig) Validate(driverCofig Config, cfg drivers.TaskConfig) error {
func (d *Driver) Validate(driverCofig Config, cfg drivers.TaskConfig) error {
usernameToLookup := cfg.User
var user *user.User
var err error
@@ -33,5 +32,5 @@ func (tc *TaskConfig) Validate(driverCofig Config, cfg drivers.TaskConfig) error
}
}
return validators.HasValidIds(user, driverCofig.DeniedHostUids, driverCofig.DeniedHostGids)
return d.userIDValidator.HasValidIDs(user)
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/hashicorp/nomad/plugins/drivers"
)
func (tc *TaskConfig) Validate(driverCofig Config, cfg drivers.TaskConfig) error {
func (d *Driver) Validate(driverCofig Config, cfg drivers.TaskConfig) error {
// This is a noop on windows since the uid and gid cannot be checked against a range easily
// We could eventually extend this functionality to check for individual users IDs strings
// but that is not currently supported. See driverValidators.HasValidIds for

View File

@@ -0,0 +1,22 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
//go:build windows
package validators
import (
"os/user"
"github.com/hashicorp/nomad/client/lib/numalib/hw"
)
// noop
func getUserID(user *user.User) (hw.UserID, error) {
return 0, nil
}
// noop
func getGroupID(user *user.User) ([]hw.GroupID, error) {
return []hw.GroupID{}, nil
}

View File

@@ -5,17 +5,73 @@ package validators
import (
"fmt"
"strconv"
"strings"
"os/user"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/client/lib/idset"
"github.com/hashicorp/nomad/client/lib/numalib/hw"
)
type validator struct {
// DeniedHostUids configures which host uids are disallowed
deniedUIDs *idset.Set[hw.UserID]
// DeniedHostGids configures which host gids are disallowed
deniedGIDs *idset.Set[hw.GroupID]
// logger will log to the Nomad agent
logger hclog.Logger
}
// IDRange defines a range of uids or gids (to eventually restrict)
type IDRange struct {
Lower uint64 `codec:"from"`
Upper uint64 `codec:"to"`
}
// ParseIdRange is used to ensure that the configuration for ID ranges is valid.
func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*validator, error) {
// TODO: Validate set, idset assumes its valid
dHostUID := idset.Parse[hw.UserID](deniedHostUIDs)
dHostGID := idset.Parse[hw.GroupID](deniedHostGIDs)
v := &validator{
deniedUIDs: dHostUID,
deniedGIDs: dHostGID,
logger: logger,
}
return v, nil
}
// HasValidIds is used when running a task to ensure the
// given user is in the ID range defined in the task config
func (v *validator) HasValidIDs(user *user.User) error {
uid, err := getUserID(user)
if err != nil {
return fmt.Errorf("validator: %w", err)
}
// check uids
if v.deniedUIDs.Contains(uid) {
return fmt.Errorf("running as uid %d is disallowed", uid)
}
gids, err := getGroupID(user)
if err != nil {
return fmt.Errorf("validator: %w", err)
}
// check gids
for _, gid := range gids {
if v.deniedGIDs.Contains(gid) {
return fmt.Errorf("running as gid %d is disallowed", gid)
}
}
return nil
}
/* // ParseIdRange is used to ensure that the configuration for ID ranges is valid.
func ParseIdRange(rangeType string, deniedRanges string) ([]IDRange, error) {
var idRanges []IDRange
parts := strings.Split(deniedRanges, ",")
@@ -78,3 +134,4 @@ func parseRangeString(boundsString string) (*IDRange, error) {
return &idRange, nil
}
*/

View File

@@ -9,8 +9,39 @@ import (
"fmt"
"os/user"
"strconv"
"github.com/hashicorp/nomad/client/lib/numalib/hw"
)
func getUserID(user *user.User) (hw.UserID, error) {
id, err := strconv.ParseUint(user.Uid, 10, 32)
if err != nil {
return 0, fmt.Errorf("unable to convert userid %s to integer", user.Uid)
}
return hw.UserID(id), nil
}
func getGroupID(user *user.User) ([]hw.GroupID, error) {
gidStrings, err := user.GroupIds()
if err != nil {
return []hw.GroupID{}, fmt.Errorf("unable to lookup user's group membership: %w", err)
}
gids := make([]hw.GroupID, len(gidStrings))
for _, gidString := range gidStrings {
u, err := strconv.ParseUint(gidString, 10, 32)
if err != nil {
return []hw.GroupID{}, fmt.Errorf("unable to convert user's group %q to integer: %w", gidString, err)
}
gids = append(gids, hw.GroupID(u))
}
return gids, nil
}
// HasValidIds is used when running a task to ensure the
// given user is in the ID range defined in the task config
func HasValidIds(user *user.User, deniedHostUIDs, deniedHostGIDs []IDRange) error {