From 0227788e220e4a9f881340ffc7dee1fd3f22203a Mon Sep 17 00:00:00 2001 From: Juanadelacuesta <8647634+Juanadelacuesta@users.noreply.github.com> Date: Tue, 29 Oct 2024 14:02:51 +0100 Subject: [PATCH] fix: update tests configuration --- drivers/exec/driver.go | 18 ++- drivers/exec/driver_test.go | 126 +++++++++--------- drivers/rawexec/driver.go | 12 +- drivers/rawexec/driver_test.go | 4 +- drivers/rawexec/driver_unix.go | 4 +- drivers/rawexec/driver_unix_test.go | 60 ++++++--- drivers/rawexec/driver_windows.go | 2 +- drivers/shared/validators/validators.go | 73 +++++----- ...dator_default.go => validators_default.go} | 0 ...dators_unix_test.go => validators_test.go} | 28 ++-- drivers/shared/validators/validators_unix.go | 44 ------ 11 files changed, 180 insertions(+), 191 deletions(-) rename drivers/shared/validators/{validator_default.go => validators_default.go} (100%) rename drivers/shared/validators/{validators_unix_test.go => validators_test.go} (85%) diff --git a/drivers/exec/driver.go b/drivers/exec/driver.go index b17a19c25..96c5b90a0 100644 --- a/drivers/exec/driver.go +++ b/drivers/exec/driver.go @@ -327,9 +327,10 @@ func (d *Driver) ConfigSchema() (*hclspec.Spec, error) { func (d *Driver) SetConfig(cfg *base.Config) error { // unpack, validate, and set agent plugin config - var config *Config + var config Config + if len(cfg.PluginConfig) != 0 { - if err := base.MsgPackDecode(cfg.PluginConfig, config); err != nil { + if err := base.MsgPackDecode(cfg.PluginConfig, &config); err != nil { return err } } @@ -338,13 +339,16 @@ func (d *Driver) SetConfig(cfg *base.Config) error { return err } - idValidator, err := validators.NewValidator(d.logger, config.DeniedHostUidsStr, config.DeniedHostGidsStr) - if err != nil { - return fmt.Errorf("unable to start validator: %w", err) + if d.userIDValidator == nil { + idValidator, err := validators.NewValidator(d.logger, config.DeniedHostUidsStr, config.DeniedHostGidsStr) + if err != nil { + return fmt.Errorf("unable to start validator: %w", err) + } + + d.userIDValidator = idValidator } - d.userIDValidator = idValidator - d.config = config + d.config = &config if cfg != nil && cfg.AgentConfig != nil { d.nomadConfig = cfg.AgentConfig.Driver diff --git a/drivers/exec/driver_test.go b/drivers/exec/driver_test.go index 6fe7c7bce..8cdab3a5d 100644 --- a/drivers/exec/driver_test.go +++ b/drivers/exec/driver_test.go @@ -23,11 +23,9 @@ import ( "github.com/hashicorp/nomad/client/lib/numalib" ctestutils "github.com/hashicorp/nomad/client/testutil" "github.com/hashicorp/nomad/drivers/shared/executor" - "github.com/hashicorp/nomad/drivers/shared/validators" "github.com/hashicorp/nomad/helper/pluginutils/hclutils" "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/helper/testtask" - "github.com/hashicorp/nomad/helper/users" "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/plugins/base" @@ -834,37 +832,74 @@ func TestExecDriver_OOMKilled(t *testing.T) { } func TestDriver_Config_setDeniedIds(t *testing.T) { + ci.Parallel(t) - t.Run("denied_host_ids", func(t *testing.T) { - invalidUidRange := "invalid denied_host_uids" - invalidGidRange := "invalid denied_host_gids" + testCases := []struct { + name string + uidRanges string + gidRanges string + exError bool + }{ + { + name: "empty_ranges", + uidRanges: "", + gidRanges: "", + exError: false, + }, + { + name: "valid_ranges", + uidRanges: "1-10", + gidRanges: "1-10", + exError: false, + }, + { + name: "empty_GID_invalid_UID_range", + uidRanges: "10-1", + gidRanges: "", + exError: true, + }, + { + name: "empty_UID_invalid_GID_range", + uidRanges: "", + gidRanges: "10-1", + exError: true, + }, + } - for _, tc := range []struct { - uidRanges string - gidRanges string - errorStr *string - }{ - {uidRanges: "", gidRanges: "", errorStr: nil}, - {uidRanges: "1-10", gidRanges: "1-10", errorStr: nil}, - {uidRanges: "10-1", gidRanges: "", errorStr: &invalidUidRange}, - {uidRanges: "", gidRanges: "10-1", errorStr: &invalidGidRange}, - } { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - err := (&Config{ - DefaultModePID: "private", - DefaultModeIPC: "private", + d := newExecDriverTest(t, ctx) + harness := dtestutil.NewDriverHarness(t, d) + defer harness.Kill() + + config := &Config{ + NoPivotRoot: false, + DefaultModePID: executor.IsolationModePrivate, + DefaultModeIPC: executor.IsolationModePrivate, DeniedHostUidsStr: tc.uidRanges, DeniedHostGidsStr: tc.gidRanges, - }).setDeniedIds() - - if tc.errorStr == nil { - must.NoError(t, err) - } else { - must.ErrorContains(t, err, *tc.errorStr) } - } - }) + + var data []byte + must.NoError(t, base.MsgPackEncode(&data, config)) + + baseConfig := &base.Config{ + PluginConfig: data, + AgentConfig: &base.AgentConfig{ + Driver: &base.ClientDriverConfig{ + Topology: d.(*Driver).nomadConfig.Topology, + }, + }, + } + + err := harness.SetConfig(baseConfig) + must.Eq(t, err != nil, tc.exError) + }) + } } func TestDriver_Config_validate(t *testing.T) { @@ -908,45 +943,6 @@ func TestDriver_Config_validate(t *testing.T) { }) } -func TestDriver_TaskConfig_validateUserIds(t *testing.T) { - ci.Parallel(t) - - current, err := users.Current() - require.NoError(t, err) - currentUid := os.Getuid() - nobodyUid, _, _, err := users.LookupUnix("nobody") - require.NoError(t, err) - - allowAll := []validators.IDRange{} - denyCurrent := []validators.IDRange{{Lower: uint64(currentUid), Upper: uint64(currentUid)}} - denyNobody := []validators.IDRange{{Lower: uint64(nobodyUid), Upper: uint64(nobodyUid)}} - configAllowCurrent := Config{DeniedHostUids: allowAll} - configDenyCurrent := Config{DeniedHostUids: denyCurrent} - configDenyAnonymous := Config{DeniedHostUids: denyNobody} - driverConfigNoUserSpecified := drivers.TaskConfig{User: "nobody"} - driverConfigSpecifyCurrent := drivers.TaskConfig{User: current.Name} - currentUserErrStr := fmt.Sprintf("running as uid %d is disallowed", currentUid) - anonUserErrStr := fmt.Sprintf("running as uid %d is disallowed", nobodyUid) - - for _, tc := range []struct { - config Config - driverConfig drivers.TaskConfig - expectedErr string - }{ - {config: configAllowCurrent, driverConfig: driverConfigSpecifyCurrent, expectedErr: ""}, - {config: configDenyCurrent, driverConfig: driverConfigNoUserSpecified, expectedErr: ""}, - {config: configDenyCurrent, driverConfig: driverConfigSpecifyCurrent, expectedErr: currentUserErrStr}, - {config: configDenyAnonymous, driverConfig: driverConfigNoUserSpecified, expectedErr: anonUserErrStr}, - } { - err := (&TaskConfig{}).validateUserIds(&tc.driverConfig, &tc.config) - if tc.expectedErr == "" { - must.NoError(t, err) - } else { - must.ErrorContains(t, err, tc.expectedErr) - } - } -} - func TestDriver_TaskConfig_validate(t *testing.T) { ci.Parallel(t) diff --git a/drivers/rawexec/driver.go b/drivers/rawexec/driver.go index 48bff2de2..fb3dee8f1 100644 --- a/drivers/rawexec/driver.go +++ b/drivers/rawexec/driver.go @@ -153,11 +153,13 @@ type Config struct { DeniedHostUidsStr string `codec:"denied_host_uids"` DeniedHostGidsStr string `codec:"denied_host_gids"` - // DeniedHostUids configures which host uids are disallowed - DeniedHostUids []validators.IDRange + /* + // DeniedHostUids configures which host uids are disallowed + DeniedHostUids []validators.IDRange - // DeniedHostGids configures which host gids are disallowed - DeniedHostGids []validators.IDRange + // DeniedHostGids configures which host gids are disallowed + DeniedHostGids []validators.IDRange + */ } // TaskConfig is the driver configuration of a task within a job @@ -373,7 +375,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive return nil, nil, fmt.Errorf("oom_score_adj must not be negative") } - if err := d.Validate(*d.config, *cfg); err != nil { + if err := d.Validate(*cfg); err != nil { return nil, nil, fmt.Errorf("failed driver config validation: %v", err) } diff --git a/drivers/rawexec/driver_test.go b/drivers/rawexec/driver_test.go index 267827f1d..0186bcc1f 100644 --- a/drivers/rawexec/driver_test.go +++ b/drivers/rawexec/driver_test.go @@ -130,7 +130,9 @@ func TestRawExecDriver_SetConfig(t *testing.T) { bconfig.PluginConfig = data err := harness.SetConfig(bconfig) must.Error(t, err) - must.ErrorContains(t, err, "invalid range \"100-1\", lower bound cannot be greater than upper bound") + + fmt.Println("el error ", err) + must.ErrorContains(t, err, "invalid range deniedHostUIDs \"100-1\": lower bound cannot be greater than upper bound") } func TestRawExecDriver_Fingerprint(t *testing.T) { diff --git a/drivers/rawexec/driver_unix.go b/drivers/rawexec/driver_unix.go index abb77a279..7b3ebb2d4 100644 --- a/drivers/rawexec/driver_unix.go +++ b/drivers/rawexec/driver_unix.go @@ -13,12 +13,12 @@ import ( "github.com/hashicorp/nomad/plugins/drivers" ) -func (d *Driver) Validate(driverCofig Config, cfg drivers.TaskConfig) error { +func (d *Driver) Validate(cfg drivers.TaskConfig) error { usernameToLookup := cfg.User var user *user.User var err error - // Uses the current user of the cleint agent process + // Uses the current user of the client agent process // if no override is given (differs from exec) if usernameToLookup == "" { user, err = users.Current() diff --git a/drivers/rawexec/driver_unix_test.go b/drivers/rawexec/driver_unix_test.go index a5a178003..123083366 100644 --- a/drivers/rawexec/driver_unix_test.go +++ b/drivers/rawexec/driver_unix_test.go @@ -22,7 +22,6 @@ import ( "github.com/hashicorp/nomad/ci" clienttestutil "github.com/hashicorp/nomad/client/testutil" - "github.com/hashicorp/nomad/drivers/shared/validators" "github.com/hashicorp/nomad/helper/testtask" "github.com/hashicorp/nomad/helper/users" "github.com/hashicorp/nomad/helper/uuid" @@ -544,36 +543,67 @@ func TestRawExecUnixDriver_StartWaitRecoverWaitStop(t *testing.T) { wg.Wait() require.NoError(d.DestroyTask(task.ID, false)) require.True(waitDone) - -} +} func TestRawExec_Validate(t *testing.T) { ci.Parallel(t) current, err := users.Current() must.NoError(t, err) - currentUid, err := strconv.ParseUint(current.Uid, 10, 32) - must.NoError(t, err) - currentUserErrStr := fmt.Sprintf("running as uid %d is disallowed", currentUid) + currentUserErrStr := fmt.Sprintf("running as uid %s is disallowed", current.Uid) + + allowAll := "" + denyCurrent := current.Uid + + configAllowCurrent := Config{DeniedHostUidsStr: allowAll} + configDenyCurrent := Config{DeniedHostUidsStr: denyCurrent} - allowAll := []validators.IDRange{} - denyCurrent := []validators.IDRange{{Lower: currentUid, Upper: currentUid}} - configAllowCurrent := Config{DeniedHostUids: allowAll} - configDenyCurrent := Config{DeniedHostUids: denyCurrent} driverConfigNoUserSpecified := drivers.TaskConfig{} - driverConfigSpecifyCurrent := drivers.TaskConfig{User: current.Name} + driverTaskConfig := drivers.TaskConfig{User: current.Name} for _, tc := range []struct { config Config driverConfig drivers.TaskConfig exp error }{ - {config: configAllowCurrent, driverConfig: driverConfigSpecifyCurrent, exp: nil}, - {config: configDenyCurrent, driverConfig: driverConfigNoUserSpecified, exp: errors.New(currentUserErrStr)}, - {config: configDenyCurrent, driverConfig: driverConfigSpecifyCurrent, exp: errors.New(currentUserErrStr)}, + { + config: configAllowCurrent, + driverConfig: driverTaskConfig, + exp: nil, + }, + { + config: configDenyCurrent, + driverConfig: driverConfigNoUserSpecified, + exp: errors.New(currentUserErrStr), + }, + { + config: configDenyCurrent, + driverConfig: driverTaskConfig, + exp: errors.New(currentUserErrStr), + }, } { - must.Eq(t, tc.exp, (&TaskConfig{}).Validate(tc.config, tc.driverConfig)) + + d := newEnabledRawExecDriver(t) + harness := dtestutil.NewDriverHarness(t, d) + defer harness.Kill() + + config := tc.config + + var data []byte + + must.NoError(t, base.MsgPackEncode(&data, config)) + bconfig := &base.Config{ + PluginConfig: data, + AgentConfig: &base.AgentConfig{ + Driver: &base.ClientDriverConfig{ + Topology: d.nomadConfig.Topology, + }, + }, + } + + must.NoError(t, harness.SetConfig(bconfig)) + must.Eq(t, tc.exp, d.Validate(tc.driverConfig)) } } diff --git a/drivers/rawexec/driver_windows.go b/drivers/rawexec/driver_windows.go index 43a7ac48b..f64cbb8f3 100644 --- a/drivers/rawexec/driver_windows.go +++ b/drivers/rawexec/driver_windows.go @@ -9,7 +9,7 @@ import ( "github.com/hashicorp/nomad/plugins/drivers" ) -func (d *Driver) Validate(driverCofig Config, cfg drivers.TaskConfig) error { +func (d *Driver) Validate(cfg drivers.TaskConfig) error { // This is a noop on windows since the uid and gid cannot be checked against a range easily // We could eventually extend this functionality to check for individual users IDs strings // but that is not currently supported. See driverValidators.HasValidIds for diff --git a/drivers/shared/validators/validators.go b/drivers/shared/validators/validators.go index d3e2e4887..b55690393 100644 --- a/drivers/shared/validators/validators.go +++ b/drivers/shared/validators/validators.go @@ -4,14 +4,23 @@ package validators import ( + "errors" "fmt" "os/user" + "strconv" + "strings" "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/client/lib/idset" "github.com/hashicorp/nomad/client/lib/numalib/hw" ) +var ( + ErrInvalidBound = errors.New("range bound not valid") + ErrEmptyRange = errors.New("range value cannot be empty") + ErrInvalidRange = errors.New("lower bound cannot be greater than upper bound") +) + type validator struct { // DeniedHostUids configures which host uids are disallowed deniedUIDs *idset.Set[hw.UserID] @@ -23,20 +32,20 @@ type validator struct { logger hclog.Logger } -// IDRange defines a range of uids or gids (to eventually restrict) -type IDRange struct { - Lower uint64 `codec:"from"` - Upper uint64 `codec:"to"` -} - func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*validator, error) { - // TODO: Validate set, idset assumes its valid - dHostUID := idset.Parse[hw.UserID](deniedHostUIDs) - dHostGID := idset.Parse[hw.GroupID](deniedHostGIDs) + err := validateIDRange("deniedHostUIDs", deniedHostUIDs) + if err != nil { + return nil, err + } + + err = validateIDRange("deniedHostGIDs", deniedHostGIDs) + if err != nil { + return nil, err + } v := &validator{ - deniedUIDs: dHostUID, - deniedGIDs: dHostGID, + deniedUIDs: idset.Parse[hw.UserID](deniedHostUIDs), + deniedGIDs: idset.Parse[hw.GroupID](deniedHostGIDs), logger: logger, } @@ -71,67 +80,57 @@ func (v *validator) HasValidIDs(user *user.User) error { return nil } -/* // ParseIdRange is used to ensure that the configuration for ID ranges is valid. -func ParseIdRange(rangeType string, deniedRanges string) ([]IDRange, error) { - var idRanges []IDRange +// ParseIdRange is used to ensure that the configuration for ID ranges is valid. +func validateIDRange(rangeType string, deniedRanges string) error { + parts := strings.Split(deniedRanges, ",") // exit early if empty string if len(parts) == 1 && parts[0] == "" { - return idRanges, nil + return nil } for _, rangeStr := range parts { - idRange, err := parseRangeString(rangeStr) + err := validateBounds(rangeStr) if err != nil { - return nil, fmt.Errorf("invalid %s: %w", rangeType, err) + return fmt.Errorf("invalid range %s \"%s\": %w", rangeType, rangeStr, err) } - - idRanges = append(idRanges, *idRange) } - return idRanges, nil + return nil } -func parseRangeString(boundsString string) (*IDRange, error) { +func validateBounds(boundsString string) error { uidDenyRangeParts := strings.Split(boundsString, "-") - var idRange IDRange - switch len(uidDenyRangeParts) { case 0: - return nil, fmt.Errorf("range value cannot be empty") + return ErrEmptyRange + case 1: disallowedIdStr := uidDenyRangeParts[0] - disallowedIdInt, err := strconv.ParseUint(disallowedIdStr, 10, 32) - if err != nil { - return nil, fmt.Errorf("range bound not valid, invalid bound: %q ", disallowedIdInt) + if _, err := strconv.ParseUint(disallowedIdStr, 10, 32); err != nil { + return ErrInvalidBound } - idRange.Lower = disallowedIdInt - idRange.Upper = disallowedIdInt case 2: lowerBoundStr := uidDenyRangeParts[0] upperBoundStr := uidDenyRangeParts[1] lowerBoundInt, err := strconv.ParseUint(lowerBoundStr, 10, 32) if err != nil { - return nil, fmt.Errorf("invalid bound: %q", lowerBoundStr) + return ErrInvalidBound } upperBoundInt, err := strconv.ParseUint(upperBoundStr, 10, 32) if err != nil { - return nil, fmt.Errorf("invalid bound: %q", upperBoundStr) + return ErrInvalidBound } if lowerBoundInt > upperBoundInt { - return nil, fmt.Errorf("invalid range %q, lower bound cannot be greater than upper bound", boundsString) + return ErrInvalidRange } - - idRange.Lower = lowerBoundInt - idRange.Upper = upperBoundInt } - return &idRange, nil + return nil } -*/ diff --git a/drivers/shared/validators/validator_default.go b/drivers/shared/validators/validators_default.go similarity index 100% rename from drivers/shared/validators/validator_default.go rename to drivers/shared/validators/validators_default.go diff --git a/drivers/shared/validators/validators_unix_test.go b/drivers/shared/validators/validators_test.go similarity index 85% rename from drivers/shared/validators/validators_unix_test.go rename to drivers/shared/validators/validators_test.go index 0dbfd8e72..bff9b60ed 100644 --- a/drivers/shared/validators/validators_unix_test.go +++ b/drivers/shared/validators/validators_test.go @@ -6,9 +6,11 @@ package validators import ( + "fmt" "os/user" "testing" + "github.com/hashicorp/go-hclog" "github.com/shoenig/test/must" ) @@ -36,7 +38,7 @@ func Test_IDRangeValid(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - _, err := ParseIdRange("uid", tc.idRange) + err := validateIDRange("uid", tc.idRange) if tc.expectedErr == "" { must.NoError(t, err) } else { @@ -48,23 +50,17 @@ func Test_IDRangeValid(t *testing.T) { } func Test_HasValidIds(t *testing.T) { - var validRange = IDRange{ - Lower: 1, - Upper: 100, - } + var validRange = "1-100" - var validRangeSingle = IDRange{ - Lower: 1, - Upper: 1, - } + var validRangeSingle = "1" - emptyRanges := []IDRange{} - validRangesList := []IDRange{validRange, validRangeSingle} + emptyRanges := "" + validRangesList := fmt.Sprintf("%s,%s", validRange, validRangeSingle) testCases := []struct { name string - uidRanges []IDRange - gidRanges []IDRange + uidRanges string + gidRanges string uid string gid string expectedErr string @@ -91,7 +87,11 @@ func Test_HasValidIds(t *testing.T) { user.Gid = tc.gid } - err := HasValidIds(user, tc.uidRanges, tc.gidRanges) + v, err := NewValidator(hclog.NewNullLogger(), tc.uidRanges, tc.gidRanges) + must.NoError(t, err) + + err = v.HasValidIDs(user) + if tc.expectedErr == "" { must.NoError(t, err) } else { diff --git a/drivers/shared/validators/validators_unix.go b/drivers/shared/validators/validators_unix.go index dc647568d..381857739 100644 --- a/drivers/shared/validators/validators_unix.go +++ b/drivers/shared/validators/validators_unix.go @@ -41,47 +41,3 @@ func getGroupID(user *user.User) ([]hw.GroupID, error) { return gids, nil } - -// HasValidIds is used when running a task to ensure the -// given user is in the ID range defined in the task config -func HasValidIds(user *user.User, deniedHostUIDs, deniedHostGIDs []IDRange) error { - uid, err := strconv.ParseUint(user.Uid, 10, 32) - if err != nil { - return fmt.Errorf("unable to convert userid %s to integer", user.Uid) - } - - // check uids - - for _, uidRange := range deniedHostUIDs { - if uid >= uidRange.Lower && uid <= uidRange.Upper { - return fmt.Errorf("running as uid %d is disallowed", uid) - } - } - - // check gids - - gidStrings, err := user.GroupIds() - if err != nil { - return fmt.Errorf("unable to lookup user's group membership: %w", err) - } - gids := make([]uint64, len(gidStrings)) - - for _, gidString := range gidStrings { - u, err := strconv.ParseUint(gidString, 10, 32) - if err != nil { - return fmt.Errorf("unable to convert user's group %q to integer: %w", gidString, err) - } - - gids = append(gids, u) - } - - for _, gidRange := range deniedHostGIDs { - for _, gid := range gids { - if gid >= gidRange.Lower && gid <= gidRange.Upper { - return fmt.Errorf("running as gid %d is disallowed", gid) - } - } - } - - return nil -}