fix: update tests configuration

This commit is contained in:
Juanadelacuesta
2024-10-29 14:02:51 +01:00
parent 0cd1b5ff13
commit 0227788e22
11 changed files with 180 additions and 191 deletions

View File

@@ -327,9 +327,10 @@ func (d *Driver) ConfigSchema() (*hclspec.Spec, error) {
func (d *Driver) SetConfig(cfg *base.Config) error {
// unpack, validate, and set agent plugin config
var config *Config
var config Config
if len(cfg.PluginConfig) != 0 {
if err := base.MsgPackDecode(cfg.PluginConfig, config); err != nil {
if err := base.MsgPackDecode(cfg.PluginConfig, &config); err != nil {
return err
}
}
@@ -338,13 +339,16 @@ func (d *Driver) SetConfig(cfg *base.Config) error {
return err
}
idValidator, err := validators.NewValidator(d.logger, config.DeniedHostUidsStr, config.DeniedHostGidsStr)
if err != nil {
return fmt.Errorf("unable to start validator: %w", err)
if d.userIDValidator == nil {
idValidator, err := validators.NewValidator(d.logger, config.DeniedHostUidsStr, config.DeniedHostGidsStr)
if err != nil {
return fmt.Errorf("unable to start validator: %w", err)
}
d.userIDValidator = idValidator
}
d.userIDValidator = idValidator
d.config = config
d.config = &config
if cfg != nil && cfg.AgentConfig != nil {
d.nomadConfig = cfg.AgentConfig.Driver

View File

@@ -23,11 +23,9 @@ import (
"github.com/hashicorp/nomad/client/lib/numalib"
ctestutils "github.com/hashicorp/nomad/client/testutil"
"github.com/hashicorp/nomad/drivers/shared/executor"
"github.com/hashicorp/nomad/drivers/shared/validators"
"github.com/hashicorp/nomad/helper/pluginutils/hclutils"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/helper/testtask"
"github.com/hashicorp/nomad/helper/users"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/base"
@@ -834,37 +832,74 @@ func TestExecDriver_OOMKilled(t *testing.T) {
}
func TestDriver_Config_setDeniedIds(t *testing.T) {
ci.Parallel(t)
t.Run("denied_host_ids", func(t *testing.T) {
invalidUidRange := "invalid denied_host_uids"
invalidGidRange := "invalid denied_host_gids"
testCases := []struct {
name string
uidRanges string
gidRanges string
exError bool
}{
{
name: "empty_ranges",
uidRanges: "",
gidRanges: "",
exError: false,
},
{
name: "valid_ranges",
uidRanges: "1-10",
gidRanges: "1-10",
exError: false,
},
{
name: "empty_GID_invalid_UID_range",
uidRanges: "10-1",
gidRanges: "",
exError: true,
},
{
name: "empty_UID_invalid_GID_range",
uidRanges: "",
gidRanges: "10-1",
exError: true,
},
}
for _, tc := range []struct {
uidRanges string
gidRanges string
errorStr *string
}{
{uidRanges: "", gidRanges: "", errorStr: nil},
{uidRanges: "1-10", gidRanges: "1-10", errorStr: nil},
{uidRanges: "10-1", gidRanges: "", errorStr: &invalidUidRange},
{uidRanges: "", gidRanges: "10-1", errorStr: &invalidGidRange},
} {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := (&Config{
DefaultModePID: "private",
DefaultModeIPC: "private",
d := newExecDriverTest(t, ctx)
harness := dtestutil.NewDriverHarness(t, d)
defer harness.Kill()
config := &Config{
NoPivotRoot: false,
DefaultModePID: executor.IsolationModePrivate,
DefaultModeIPC: executor.IsolationModePrivate,
DeniedHostUidsStr: tc.uidRanges,
DeniedHostGidsStr: tc.gidRanges,
}).setDeniedIds()
if tc.errorStr == nil {
must.NoError(t, err)
} else {
must.ErrorContains(t, err, *tc.errorStr)
}
}
})
var data []byte
must.NoError(t, base.MsgPackEncode(&data, config))
baseConfig := &base.Config{
PluginConfig: data,
AgentConfig: &base.AgentConfig{
Driver: &base.ClientDriverConfig{
Topology: d.(*Driver).nomadConfig.Topology,
},
},
}
err := harness.SetConfig(baseConfig)
must.Eq(t, err != nil, tc.exError)
})
}
}
func TestDriver_Config_validate(t *testing.T) {
@@ -908,45 +943,6 @@ func TestDriver_Config_validate(t *testing.T) {
})
}
func TestDriver_TaskConfig_validateUserIds(t *testing.T) {
ci.Parallel(t)
current, err := users.Current()
require.NoError(t, err)
currentUid := os.Getuid()
nobodyUid, _, _, err := users.LookupUnix("nobody")
require.NoError(t, err)
allowAll := []validators.IDRange{}
denyCurrent := []validators.IDRange{{Lower: uint64(currentUid), Upper: uint64(currentUid)}}
denyNobody := []validators.IDRange{{Lower: uint64(nobodyUid), Upper: uint64(nobodyUid)}}
configAllowCurrent := Config{DeniedHostUids: allowAll}
configDenyCurrent := Config{DeniedHostUids: denyCurrent}
configDenyAnonymous := Config{DeniedHostUids: denyNobody}
driverConfigNoUserSpecified := drivers.TaskConfig{User: "nobody"}
driverConfigSpecifyCurrent := drivers.TaskConfig{User: current.Name}
currentUserErrStr := fmt.Sprintf("running as uid %d is disallowed", currentUid)
anonUserErrStr := fmt.Sprintf("running as uid %d is disallowed", nobodyUid)
for _, tc := range []struct {
config Config
driverConfig drivers.TaskConfig
expectedErr string
}{
{config: configAllowCurrent, driverConfig: driverConfigSpecifyCurrent, expectedErr: ""},
{config: configDenyCurrent, driverConfig: driverConfigNoUserSpecified, expectedErr: ""},
{config: configDenyCurrent, driverConfig: driverConfigSpecifyCurrent, expectedErr: currentUserErrStr},
{config: configDenyAnonymous, driverConfig: driverConfigNoUserSpecified, expectedErr: anonUserErrStr},
} {
err := (&TaskConfig{}).validateUserIds(&tc.driverConfig, &tc.config)
if tc.expectedErr == "" {
must.NoError(t, err)
} else {
must.ErrorContains(t, err, tc.expectedErr)
}
}
}
func TestDriver_TaskConfig_validate(t *testing.T) {
ci.Parallel(t)

View File

@@ -153,11 +153,13 @@ type Config struct {
DeniedHostUidsStr string `codec:"denied_host_uids"`
DeniedHostGidsStr string `codec:"denied_host_gids"`
// DeniedHostUids configures which host uids are disallowed
DeniedHostUids []validators.IDRange
/*
// DeniedHostUids configures which host uids are disallowed
DeniedHostUids []validators.IDRange
// DeniedHostGids configures which host gids are disallowed
DeniedHostGids []validators.IDRange
// DeniedHostGids configures which host gids are disallowed
DeniedHostGids []validators.IDRange
*/
}
// TaskConfig is the driver configuration of a task within a job
@@ -373,7 +375,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
return nil, nil, fmt.Errorf("oom_score_adj must not be negative")
}
if err := d.Validate(*d.config, *cfg); err != nil {
if err := d.Validate(*cfg); err != nil {
return nil, nil, fmt.Errorf("failed driver config validation: %v", err)
}

View File

@@ -130,7 +130,9 @@ func TestRawExecDriver_SetConfig(t *testing.T) {
bconfig.PluginConfig = data
err := harness.SetConfig(bconfig)
must.Error(t, err)
must.ErrorContains(t, err, "invalid range \"100-1\", lower bound cannot be greater than upper bound")
fmt.Println("el error ", err)
must.ErrorContains(t, err, "invalid range deniedHostUIDs \"100-1\": lower bound cannot be greater than upper bound")
}
func TestRawExecDriver_Fingerprint(t *testing.T) {

View File

@@ -13,12 +13,12 @@ import (
"github.com/hashicorp/nomad/plugins/drivers"
)
func (d *Driver) Validate(driverCofig Config, cfg drivers.TaskConfig) error {
func (d *Driver) Validate(cfg drivers.TaskConfig) error {
usernameToLookup := cfg.User
var user *user.User
var err error
// Uses the current user of the cleint agent process
// Uses the current user of the client agent process
// if no override is given (differs from exec)
if usernameToLookup == "" {
user, err = users.Current()

View File

@@ -22,7 +22,6 @@ import (
"github.com/hashicorp/nomad/ci"
clienttestutil "github.com/hashicorp/nomad/client/testutil"
"github.com/hashicorp/nomad/drivers/shared/validators"
"github.com/hashicorp/nomad/helper/testtask"
"github.com/hashicorp/nomad/helper/users"
"github.com/hashicorp/nomad/helper/uuid"
@@ -544,36 +543,67 @@ func TestRawExecUnixDriver_StartWaitRecoverWaitStop(t *testing.T) {
wg.Wait()
require.NoError(d.DestroyTask(task.ID, false))
require.True(waitDone)
}
}
func TestRawExec_Validate(t *testing.T) {
ci.Parallel(t)
current, err := users.Current()
must.NoError(t, err)
currentUid, err := strconv.ParseUint(current.Uid, 10, 32)
must.NoError(t, err)
currentUserErrStr := fmt.Sprintf("running as uid %d is disallowed", currentUid)
currentUserErrStr := fmt.Sprintf("running as uid %s is disallowed", current.Uid)
allowAll := ""
denyCurrent := current.Uid
configAllowCurrent := Config{DeniedHostUidsStr: allowAll}
configDenyCurrent := Config{DeniedHostUidsStr: denyCurrent}
allowAll := []validators.IDRange{}
denyCurrent := []validators.IDRange{{Lower: currentUid, Upper: currentUid}}
configAllowCurrent := Config{DeniedHostUids: allowAll}
configDenyCurrent := Config{DeniedHostUids: denyCurrent}
driverConfigNoUserSpecified := drivers.TaskConfig{}
driverConfigSpecifyCurrent := drivers.TaskConfig{User: current.Name}
driverTaskConfig := drivers.TaskConfig{User: current.Name}
for _, tc := range []struct {
config Config
driverConfig drivers.TaskConfig
exp error
}{
{config: configAllowCurrent, driverConfig: driverConfigSpecifyCurrent, exp: nil},
{config: configDenyCurrent, driverConfig: driverConfigNoUserSpecified, exp: errors.New(currentUserErrStr)},
{config: configDenyCurrent, driverConfig: driverConfigSpecifyCurrent, exp: errors.New(currentUserErrStr)},
{
config: configAllowCurrent,
driverConfig: driverTaskConfig,
exp: nil,
},
{
config: configDenyCurrent,
driverConfig: driverConfigNoUserSpecified,
exp: errors.New(currentUserErrStr),
},
{
config: configDenyCurrent,
driverConfig: driverTaskConfig,
exp: errors.New(currentUserErrStr),
},
} {
must.Eq(t, tc.exp, (&TaskConfig{}).Validate(tc.config, tc.driverConfig))
d := newEnabledRawExecDriver(t)
harness := dtestutil.NewDriverHarness(t, d)
defer harness.Kill()
config := tc.config
var data []byte
must.NoError(t, base.MsgPackEncode(&data, config))
bconfig := &base.Config{
PluginConfig: data,
AgentConfig: &base.AgentConfig{
Driver: &base.ClientDriverConfig{
Topology: d.nomadConfig.Topology,
},
},
}
must.NoError(t, harness.SetConfig(bconfig))
must.Eq(t, tc.exp, d.Validate(tc.driverConfig))
}
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/hashicorp/nomad/plugins/drivers"
)
func (d *Driver) Validate(driverCofig Config, cfg drivers.TaskConfig) error {
func (d *Driver) Validate(cfg drivers.TaskConfig) error {
// This is a noop on windows since the uid and gid cannot be checked against a range easily
// We could eventually extend this functionality to check for individual users IDs strings
// but that is not currently supported. See driverValidators.HasValidIds for

View File

@@ -4,14 +4,23 @@
package validators
import (
"errors"
"fmt"
"os/user"
"strconv"
"strings"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/client/lib/idset"
"github.com/hashicorp/nomad/client/lib/numalib/hw"
)
var (
ErrInvalidBound = errors.New("range bound not valid")
ErrEmptyRange = errors.New("range value cannot be empty")
ErrInvalidRange = errors.New("lower bound cannot be greater than upper bound")
)
type validator struct {
// DeniedHostUids configures which host uids are disallowed
deniedUIDs *idset.Set[hw.UserID]
@@ -23,20 +32,20 @@ type validator struct {
logger hclog.Logger
}
// IDRange defines a range of uids or gids (to eventually restrict)
type IDRange struct {
Lower uint64 `codec:"from"`
Upper uint64 `codec:"to"`
}
func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*validator, error) {
// TODO: Validate set, idset assumes its valid
dHostUID := idset.Parse[hw.UserID](deniedHostUIDs)
dHostGID := idset.Parse[hw.GroupID](deniedHostGIDs)
err := validateIDRange("deniedHostUIDs", deniedHostUIDs)
if err != nil {
return nil, err
}
err = validateIDRange("deniedHostGIDs", deniedHostGIDs)
if err != nil {
return nil, err
}
v := &validator{
deniedUIDs: dHostUID,
deniedGIDs: dHostGID,
deniedUIDs: idset.Parse[hw.UserID](deniedHostUIDs),
deniedGIDs: idset.Parse[hw.GroupID](deniedHostGIDs),
logger: logger,
}
@@ -71,67 +80,57 @@ func (v *validator) HasValidIDs(user *user.User) error {
return nil
}
/* // ParseIdRange is used to ensure that the configuration for ID ranges is valid.
func ParseIdRange(rangeType string, deniedRanges string) ([]IDRange, error) {
var idRanges []IDRange
// ParseIdRange is used to ensure that the configuration for ID ranges is valid.
func validateIDRange(rangeType string, deniedRanges string) error {
parts := strings.Split(deniedRanges, ",")
// exit early if empty string
if len(parts) == 1 && parts[0] == "" {
return idRanges, nil
return nil
}
for _, rangeStr := range parts {
idRange, err := parseRangeString(rangeStr)
err := validateBounds(rangeStr)
if err != nil {
return nil, fmt.Errorf("invalid %s: %w", rangeType, err)
return fmt.Errorf("invalid range %s \"%s\": %w", rangeType, rangeStr, err)
}
idRanges = append(idRanges, *idRange)
}
return idRanges, nil
return nil
}
func parseRangeString(boundsString string) (*IDRange, error) {
func validateBounds(boundsString string) error {
uidDenyRangeParts := strings.Split(boundsString, "-")
var idRange IDRange
switch len(uidDenyRangeParts) {
case 0:
return nil, fmt.Errorf("range value cannot be empty")
return ErrEmptyRange
case 1:
disallowedIdStr := uidDenyRangeParts[0]
disallowedIdInt, err := strconv.ParseUint(disallowedIdStr, 10, 32)
if err != nil {
return nil, fmt.Errorf("range bound not valid, invalid bound: %q ", disallowedIdInt)
if _, err := strconv.ParseUint(disallowedIdStr, 10, 32); err != nil {
return ErrInvalidBound
}
idRange.Lower = disallowedIdInt
idRange.Upper = disallowedIdInt
case 2:
lowerBoundStr := uidDenyRangeParts[0]
upperBoundStr := uidDenyRangeParts[1]
lowerBoundInt, err := strconv.ParseUint(lowerBoundStr, 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid bound: %q", lowerBoundStr)
return ErrInvalidBound
}
upperBoundInt, err := strconv.ParseUint(upperBoundStr, 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid bound: %q", upperBoundStr)
return ErrInvalidBound
}
if lowerBoundInt > upperBoundInt {
return nil, fmt.Errorf("invalid range %q, lower bound cannot be greater than upper bound", boundsString)
return ErrInvalidRange
}
idRange.Lower = lowerBoundInt
idRange.Upper = upperBoundInt
}
return &idRange, nil
return nil
}
*/

View File

@@ -6,9 +6,11 @@
package validators
import (
"fmt"
"os/user"
"testing"
"github.com/hashicorp/go-hclog"
"github.com/shoenig/test/must"
)
@@ -36,7 +38,7 @@ func Test_IDRangeValid(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := ParseIdRange("uid", tc.idRange)
err := validateIDRange("uid", tc.idRange)
if tc.expectedErr == "" {
must.NoError(t, err)
} else {
@@ -48,23 +50,17 @@ func Test_IDRangeValid(t *testing.T) {
}
func Test_HasValidIds(t *testing.T) {
var validRange = IDRange{
Lower: 1,
Upper: 100,
}
var validRange = "1-100"
var validRangeSingle = IDRange{
Lower: 1,
Upper: 1,
}
var validRangeSingle = "1"
emptyRanges := []IDRange{}
validRangesList := []IDRange{validRange, validRangeSingle}
emptyRanges := ""
validRangesList := fmt.Sprintf("%s,%s", validRange, validRangeSingle)
testCases := []struct {
name string
uidRanges []IDRange
gidRanges []IDRange
uidRanges string
gidRanges string
uid string
gid string
expectedErr string
@@ -91,7 +87,11 @@ func Test_HasValidIds(t *testing.T) {
user.Gid = tc.gid
}
err := HasValidIds(user, tc.uidRanges, tc.gidRanges)
v, err := NewValidator(hclog.NewNullLogger(), tc.uidRanges, tc.gidRanges)
must.NoError(t, err)
err = v.HasValidIDs(user)
if tc.expectedErr == "" {
must.NoError(t, err)
} else {

View File

@@ -41,47 +41,3 @@ func getGroupID(user *user.User) ([]hw.GroupID, error) {
return gids, nil
}
// HasValidIds is used when running a task to ensure the
// given user is in the ID range defined in the task config
func HasValidIds(user *user.User, deniedHostUIDs, deniedHostGIDs []IDRange) error {
uid, err := strconv.ParseUint(user.Uid, 10, 32)
if err != nil {
return fmt.Errorf("unable to convert userid %s to integer", user.Uid)
}
// check uids
for _, uidRange := range deniedHostUIDs {
if uid >= uidRange.Lower && uid <= uidRange.Upper {
return fmt.Errorf("running as uid %d is disallowed", uid)
}
}
// check gids
gidStrings, err := user.GroupIds()
if err != nil {
return fmt.Errorf("unable to lookup user's group membership: %w", err)
}
gids := make([]uint64, len(gidStrings))
for _, gidString := range gidStrings {
u, err := strconv.ParseUint(gidString, 10, 32)
if err != nil {
return fmt.Errorf("unable to convert user's group %q to integer: %w", gidString, err)
}
gids = append(gids, u)
}
for _, gidRange := range deniedHostGIDs {
for _, gid := range gids {
if gid >= gidRange.Lower && gid <= gidRange.Upper {
return fmt.Errorf("running as gid %d is disallowed", gid)
}
}
}
return nil
}