From 80e398bbf7e360ef2d141429e48df6cef8996ca1 Mon Sep 17 00:00:00 2001 From: Juanadelacuesta <8647634+Juanadelacuesta@users.noreply.github.com> Date: Thu, 31 Oct 2024 14:51:41 +0100 Subject: [PATCH] test: add tests for validateBounds --- drivers/shared/validators/validators.go | 15 ++--- .../shared/validators/validators_default.go | 2 +- drivers/shared/validators/validators_test.go | 67 ++++++++++++------- drivers/shared/validators/validators_unix.go | 2 +- 4 files changed, 51 insertions(+), 35 deletions(-) diff --git a/drivers/shared/validators/validators.go b/drivers/shared/validators/validators.go index edbbf04b7..ca2518fd9 100644 --- a/drivers/shared/validators/validators.go +++ b/drivers/shared/validators/validators.go @@ -16,7 +16,7 @@ import ( var ( ErrInvalidBound = errors.New("range bound not valid") - ErrEmptyRange = errors.New("range value cannot be empty") + //ErrEmptyRange = errors.New("range value cannot be empty") ErrInvalidRange = errors.New("lower bound cannot be greater than upper bound") ) @@ -29,7 +29,7 @@ type ( UserID uint64 ) -type validator struct { +type Validator struct { // DeniedHostUids configures which host uids are disallowed deniedUIDs *idset.Set[UserID] @@ -40,7 +40,7 @@ type validator struct { logger hclog.Logger } -func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*validator, error) { +func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*Validator, error) { valLogger := logger.Named("id_validator") err := validateIDRange("deniedHostUIDs", deniedHostUIDs) @@ -55,7 +55,7 @@ func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (* } valLogger.Debug("group range configured", "denied range", deniedHostGIDs) - v := &validator{ + v := &Validator{ deniedUIDs: idset.Parse[UserID](deniedHostUIDs), deniedGIDs: idset.Parse[GroupID](deniedHostGIDs), logger: valLogger, @@ -66,7 +66,7 @@ func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (* // HasValidIDs is used when running a task to ensure the // given user is in the ID range defined in the task config -func (v *validator) HasValidIDs(userName string) error { +func (v *Validator) HasValidIDs(userName string) error { user, err := users.Lookup(userName) if err != nil { return fmt.Errorf("failed to identify user %q: %w", userName, err) @@ -82,7 +82,7 @@ func (v *validator) HasValidIDs(userName string) error { return fmt.Errorf("running as uid %d is disallowed", uid) } - gids, err := getGroupID(user) + gids, err := getGroupsID(user) if err != nil { return fmt.Errorf("validator: %w", err) } @@ -122,9 +122,6 @@ func validateBounds(boundsString string) error { uidDenyRangeParts := strings.Split(boundsString, "-") switch len(uidDenyRangeParts) { - case 0: - return ErrEmptyRange - case 1: disallowedIdStr := uidDenyRangeParts[0] if _, err := strconv.ParseUint(disallowedIdStr, 10, 32); err != nil { diff --git a/drivers/shared/validators/validators_default.go b/drivers/shared/validators/validators_default.go index 1f9adad46..b35cfeb07 100644 --- a/drivers/shared/validators/validators_default.go +++ b/drivers/shared/validators/validators_default.go @@ -15,6 +15,6 @@ func getUserID(*user.User) (UserID, error) { } // noop -func getGroupID(*user.User) ([]GroupID, error) { +func getGroupsID(*user.User) ([]GroupID, error) { return []GroupID{}, nil } diff --git a/drivers/shared/validators/validators_test.go b/drivers/shared/validators/validators_test.go index bff9b60ed..367358e2b 100644 --- a/drivers/shared/validators/validators_test.go +++ b/drivers/shared/validators/validators_test.go @@ -8,6 +8,7 @@ package validators import ( "fmt" "os/user" + "strconv" "testing" "github.com/hashicorp/go-hclog" @@ -50,47 +51,45 @@ func Test_IDRangeValid(t *testing.T) { } func Test_HasValidIds(t *testing.T) { - var validRange = "1-100" - var validRangeSingle = "1" + user, err := user.Current() + must.NoError(t, err) + + userID, err := strconv.ParseUint(user.Uid, 10, 32) + groupID, err := strconv.ParseUint(user.Gid, 10, 32) + must.NoError(t, err) + + userNotIncluded := fmt.Sprintf("%d-%d", userID+1, userID+11) + userIncluded := fmt.Sprintf("%d-%d", userID, userID+11) + userNotIncludedSingle := fmt.Sprintf("%d", userID+1) + + groupNotIncluded := fmt.Sprintf("%d-%d", groupID+1, groupID+11) + groupIncluded := fmt.Sprintf("%d-%d", groupID, groupID+11) + groupNotIncludedSingle := fmt.Sprintf("%d", groupID+1) emptyRanges := "" - validRangesList := fmt.Sprintf("%s,%s", validRange, validRangeSingle) + + userDeniedRangesList := fmt.Sprintf("%s,%s", userNotIncluded, userNotIncludedSingle) + groupDeniedRangesList := fmt.Sprintf("%s,%s", groupNotIncluded, groupNotIncludedSingle) testCases := []struct { name string uidRanges string gidRanges string - uid string - gid string expectedErr string }{ - {name: "no-ranges-are-valid", uidRanges: validRangesList, gidRanges: emptyRanges}, - {name: "uid-and-gid-outside-of-ranges-valid", uidRanges: validRangesList, gidRanges: validRangesList}, - {name: "uid-in-one-of-ranges-is-invalid", uidRanges: validRangesList, gidRanges: validRangesList, uid: "50", expectedErr: "running as uid 50 is disallowed"}, - {name: "gid-in-one-of-ranges-is-invalid", uidRanges: validRangesList, gidRanges: validRangesList, gid: "50", expectedErr: "running as gid 50 is disallowed"}, - {name: "string-uid-throws-error", uid: "banana", expectedErr: "unable to convert userid banana to integer"}, + {name: "user_not_in_denied_ranges", uidRanges: userDeniedRangesList, gidRanges: emptyRanges}, + {name: "user_and group_not_in_denied_ranges", uidRanges: userDeniedRangesList, gidRanges: groupDeniedRangesList}, + {name: "uid_in_one_of_ranges_is_invalid", uidRanges: userIncluded, gidRanges: groupDeniedRangesList, expectedErr: fmt.Sprintf("running as uid %s is disallowed", user.Uid)}, + {name: "gid-in-one-of-ranges-is-invalid", uidRanges: userDeniedRangesList, gidRanges: groupIncluded, expectedErr: fmt.Sprintf("running as gid %s is disallowed", user.Gid)}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - user := &user.User{ - Uid: "200", - Gid: "200", - } - - if tc.uid != "" { - user.Uid = tc.uid - } - - if tc.gid != "" { - user.Gid = tc.gid - } - v, err := NewValidator(hclog.NewNullLogger(), tc.uidRanges, tc.gidRanges) must.NoError(t, err) - err = v.HasValidIDs(user) + err = v.HasValidIDs(user.Username) if tc.expectedErr == "" { must.NoError(t, err) @@ -101,3 +100,23 @@ func Test_HasValidIds(t *testing.T) { }) } } + +func Test_ValidateBounds(t *testing.T) { + testCases := []struct { + name string + bounds string + expectedErr error + }{ + {name: "invalid_bound", bounds: "banana", expectedErr: ErrInvalidBound}, + {name: "invalid_lower_bound", bounds: "banana-10", expectedErr: ErrInvalidBound}, + {name: "invalid_upper_bound", bounds: "10-banana", expectedErr: ErrInvalidBound}, + {name: "lower_bigger_than_upper", bounds: "10-1", expectedErr: ErrInvalidRange}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validateBounds(tc.bounds) + must.ErrorIs(t, err, tc.expectedErr) + }) + } +} diff --git a/drivers/shared/validators/validators_unix.go b/drivers/shared/validators/validators_unix.go index 1d7aa597a..4469205e4 100644 --- a/drivers/shared/validators/validators_unix.go +++ b/drivers/shared/validators/validators_unix.go @@ -20,7 +20,7 @@ func getUserID(user *user.User) (UserID, error) { return UserID(id), nil } -func getGroupID(user *user.User) ([]GroupID, error) { +func getGroupsID(user *user.User) ([]GroupID, error) { gidStrings, err := user.GroupIds() if err != nil { return []GroupID{}, fmt.Errorf("unable to lookup user's group membership: %w", err)