mirror of
https://github.com/kemko/nomad.git
synced 2026-01-01 16:05:42 +03:00
test: add tests for validateBounds
This commit is contained in:
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
var (
|
||||
ErrInvalidBound = errors.New("range bound not valid")
|
||||
ErrEmptyRange = errors.New("range value cannot be empty")
|
||||
//ErrEmptyRange = errors.New("range value cannot be empty")
|
||||
ErrInvalidRange = errors.New("lower bound cannot be greater than upper bound")
|
||||
)
|
||||
|
||||
@@ -29,7 +29,7 @@ type (
|
||||
UserID uint64
|
||||
)
|
||||
|
||||
type validator struct {
|
||||
type Validator struct {
|
||||
// DeniedHostUids configures which host uids are disallowed
|
||||
deniedUIDs *idset.Set[UserID]
|
||||
|
||||
@@ -40,7 +40,7 @@ type validator struct {
|
||||
logger hclog.Logger
|
||||
}
|
||||
|
||||
func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*validator, error) {
|
||||
func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*Validator, error) {
|
||||
valLogger := logger.Named("id_validator")
|
||||
|
||||
err := validateIDRange("deniedHostUIDs", deniedHostUIDs)
|
||||
@@ -55,7 +55,7 @@ func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*
|
||||
}
|
||||
valLogger.Debug("group range configured", "denied range", deniedHostGIDs)
|
||||
|
||||
v := &validator{
|
||||
v := &Validator{
|
||||
deniedUIDs: idset.Parse[UserID](deniedHostUIDs),
|
||||
deniedGIDs: idset.Parse[GroupID](deniedHostGIDs),
|
||||
logger: valLogger,
|
||||
@@ -66,7 +66,7 @@ func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*
|
||||
|
||||
// HasValidIDs is used when running a task to ensure the
|
||||
// given user is in the ID range defined in the task config
|
||||
func (v *validator) HasValidIDs(userName string) error {
|
||||
func (v *Validator) HasValidIDs(userName string) error {
|
||||
user, err := users.Lookup(userName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to identify user %q: %w", userName, err)
|
||||
@@ -82,7 +82,7 @@ func (v *validator) HasValidIDs(userName string) error {
|
||||
return fmt.Errorf("running as uid %d is disallowed", uid)
|
||||
}
|
||||
|
||||
gids, err := getGroupID(user)
|
||||
gids, err := getGroupsID(user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validator: %w", err)
|
||||
}
|
||||
@@ -122,9 +122,6 @@ func validateBounds(boundsString string) error {
|
||||
uidDenyRangeParts := strings.Split(boundsString, "-")
|
||||
|
||||
switch len(uidDenyRangeParts) {
|
||||
case 0:
|
||||
return ErrEmptyRange
|
||||
|
||||
case 1:
|
||||
disallowedIdStr := uidDenyRangeParts[0]
|
||||
if _, err := strconv.ParseUint(disallowedIdStr, 10, 32); err != nil {
|
||||
|
||||
@@ -15,6 +15,6 @@ func getUserID(*user.User) (UserID, error) {
|
||||
}
|
||||
|
||||
// noop
|
||||
func getGroupID(*user.User) ([]GroupID, error) {
|
||||
func getGroupsID(*user.User) ([]GroupID, error) {
|
||||
return []GroupID{}, nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ package validators
|
||||
import (
|
||||
"fmt"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
@@ -50,47 +51,45 @@ func Test_IDRangeValid(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_HasValidIds(t *testing.T) {
|
||||
var validRange = "1-100"
|
||||
|
||||
var validRangeSingle = "1"
|
||||
user, err := user.Current()
|
||||
must.NoError(t, err)
|
||||
|
||||
userID, err := strconv.ParseUint(user.Uid, 10, 32)
|
||||
groupID, err := strconv.ParseUint(user.Gid, 10, 32)
|
||||
must.NoError(t, err)
|
||||
|
||||
userNotIncluded := fmt.Sprintf("%d-%d", userID+1, userID+11)
|
||||
userIncluded := fmt.Sprintf("%d-%d", userID, userID+11)
|
||||
userNotIncludedSingle := fmt.Sprintf("%d", userID+1)
|
||||
|
||||
groupNotIncluded := fmt.Sprintf("%d-%d", groupID+1, groupID+11)
|
||||
groupIncluded := fmt.Sprintf("%d-%d", groupID, groupID+11)
|
||||
groupNotIncludedSingle := fmt.Sprintf("%d", groupID+1)
|
||||
|
||||
emptyRanges := ""
|
||||
validRangesList := fmt.Sprintf("%s,%s", validRange, validRangeSingle)
|
||||
|
||||
userDeniedRangesList := fmt.Sprintf("%s,%s", userNotIncluded, userNotIncludedSingle)
|
||||
groupDeniedRangesList := fmt.Sprintf("%s,%s", groupNotIncluded, groupNotIncludedSingle)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
uidRanges string
|
||||
gidRanges string
|
||||
uid string
|
||||
gid string
|
||||
expectedErr string
|
||||
}{
|
||||
{name: "no-ranges-are-valid", uidRanges: validRangesList, gidRanges: emptyRanges},
|
||||
{name: "uid-and-gid-outside-of-ranges-valid", uidRanges: validRangesList, gidRanges: validRangesList},
|
||||
{name: "uid-in-one-of-ranges-is-invalid", uidRanges: validRangesList, gidRanges: validRangesList, uid: "50", expectedErr: "running as uid 50 is disallowed"},
|
||||
{name: "gid-in-one-of-ranges-is-invalid", uidRanges: validRangesList, gidRanges: validRangesList, gid: "50", expectedErr: "running as gid 50 is disallowed"},
|
||||
{name: "string-uid-throws-error", uid: "banana", expectedErr: "unable to convert userid banana to integer"},
|
||||
{name: "user_not_in_denied_ranges", uidRanges: userDeniedRangesList, gidRanges: emptyRanges},
|
||||
{name: "user_and group_not_in_denied_ranges", uidRanges: userDeniedRangesList, gidRanges: groupDeniedRangesList},
|
||||
{name: "uid_in_one_of_ranges_is_invalid", uidRanges: userIncluded, gidRanges: groupDeniedRangesList, expectedErr: fmt.Sprintf("running as uid %s is disallowed", user.Uid)},
|
||||
{name: "gid-in-one-of-ranges-is-invalid", uidRanges: userDeniedRangesList, gidRanges: groupIncluded, expectedErr: fmt.Sprintf("running as gid %s is disallowed", user.Gid)},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
user := &user.User{
|
||||
Uid: "200",
|
||||
Gid: "200",
|
||||
}
|
||||
|
||||
if tc.uid != "" {
|
||||
user.Uid = tc.uid
|
||||
}
|
||||
|
||||
if tc.gid != "" {
|
||||
user.Gid = tc.gid
|
||||
}
|
||||
|
||||
v, err := NewValidator(hclog.NewNullLogger(), tc.uidRanges, tc.gidRanges)
|
||||
must.NoError(t, err)
|
||||
|
||||
err = v.HasValidIDs(user)
|
||||
err = v.HasValidIDs(user.Username)
|
||||
|
||||
if tc.expectedErr == "" {
|
||||
must.NoError(t, err)
|
||||
@@ -101,3 +100,23 @@ func Test_HasValidIds(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ValidateBounds(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
bounds string
|
||||
expectedErr error
|
||||
}{
|
||||
{name: "invalid_bound", bounds: "banana", expectedErr: ErrInvalidBound},
|
||||
{name: "invalid_lower_bound", bounds: "banana-10", expectedErr: ErrInvalidBound},
|
||||
{name: "invalid_upper_bound", bounds: "10-banana", expectedErr: ErrInvalidBound},
|
||||
{name: "lower_bigger_than_upper", bounds: "10-1", expectedErr: ErrInvalidRange},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := validateBounds(tc.bounds)
|
||||
must.ErrorIs(t, err, tc.expectedErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ func getUserID(user *user.User) (UserID, error) {
|
||||
return UserID(id), nil
|
||||
}
|
||||
|
||||
func getGroupID(user *user.User) ([]GroupID, error) {
|
||||
func getGroupsID(user *user.User) ([]GroupID, error) {
|
||||
gidStrings, err := user.GroupIds()
|
||||
if err != nil {
|
||||
return []GroupID{}, fmt.Errorf("unable to lookup user's group membership: %w", err)
|
||||
|
||||
Reference in New Issue
Block a user