test: add tests for validateBounds

This commit is contained in:
Juanadelacuesta
2024-10-31 14:51:41 +01:00
parent d0b015ec01
commit 80e398bbf7
4 changed files with 51 additions and 35 deletions

View File

@@ -16,7 +16,7 @@ import (
var (
ErrInvalidBound = errors.New("range bound not valid")
ErrEmptyRange = errors.New("range value cannot be empty")
//ErrEmptyRange = errors.New("range value cannot be empty")
ErrInvalidRange = errors.New("lower bound cannot be greater than upper bound")
)
@@ -29,7 +29,7 @@ type (
UserID uint64
)
type validator struct {
type Validator struct {
// DeniedHostUids configures which host uids are disallowed
deniedUIDs *idset.Set[UserID]
@@ -40,7 +40,7 @@ type validator struct {
logger hclog.Logger
}
func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*validator, error) {
func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*Validator, error) {
valLogger := logger.Named("id_validator")
err := validateIDRange("deniedHostUIDs", deniedHostUIDs)
@@ -55,7 +55,7 @@ func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*
}
valLogger.Debug("group range configured", "denied range", deniedHostGIDs)
v := &validator{
v := &Validator{
deniedUIDs: idset.Parse[UserID](deniedHostUIDs),
deniedGIDs: idset.Parse[GroupID](deniedHostGIDs),
logger: valLogger,
@@ -66,7 +66,7 @@ func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*
// HasValidIDs is used when running a task to ensure the
// given user is in the ID range defined in the task config
func (v *validator) HasValidIDs(userName string) error {
func (v *Validator) HasValidIDs(userName string) error {
user, err := users.Lookup(userName)
if err != nil {
return fmt.Errorf("failed to identify user %q: %w", userName, err)
@@ -82,7 +82,7 @@ func (v *validator) HasValidIDs(userName string) error {
return fmt.Errorf("running as uid %d is disallowed", uid)
}
gids, err := getGroupID(user)
gids, err := getGroupsID(user)
if err != nil {
return fmt.Errorf("validator: %w", err)
}
@@ -122,9 +122,6 @@ func validateBounds(boundsString string) error {
uidDenyRangeParts := strings.Split(boundsString, "-")
switch len(uidDenyRangeParts) {
case 0:
return ErrEmptyRange
case 1:
disallowedIdStr := uidDenyRangeParts[0]
if _, err := strconv.ParseUint(disallowedIdStr, 10, 32); err != nil {

View File

@@ -15,6 +15,6 @@ func getUserID(*user.User) (UserID, error) {
}
// noop
func getGroupID(*user.User) ([]GroupID, error) {
func getGroupsID(*user.User) ([]GroupID, error) {
return []GroupID{}, nil
}

View File

@@ -8,6 +8,7 @@ package validators
import (
"fmt"
"os/user"
"strconv"
"testing"
"github.com/hashicorp/go-hclog"
@@ -50,47 +51,45 @@ func Test_IDRangeValid(t *testing.T) {
}
func Test_HasValidIds(t *testing.T) {
var validRange = "1-100"
var validRangeSingle = "1"
user, err := user.Current()
must.NoError(t, err)
userID, err := strconv.ParseUint(user.Uid, 10, 32)
groupID, err := strconv.ParseUint(user.Gid, 10, 32)
must.NoError(t, err)
userNotIncluded := fmt.Sprintf("%d-%d", userID+1, userID+11)
userIncluded := fmt.Sprintf("%d-%d", userID, userID+11)
userNotIncludedSingle := fmt.Sprintf("%d", userID+1)
groupNotIncluded := fmt.Sprintf("%d-%d", groupID+1, groupID+11)
groupIncluded := fmt.Sprintf("%d-%d", groupID, groupID+11)
groupNotIncludedSingle := fmt.Sprintf("%d", groupID+1)
emptyRanges := ""
validRangesList := fmt.Sprintf("%s,%s", validRange, validRangeSingle)
userDeniedRangesList := fmt.Sprintf("%s,%s", userNotIncluded, userNotIncludedSingle)
groupDeniedRangesList := fmt.Sprintf("%s,%s", groupNotIncluded, groupNotIncludedSingle)
testCases := []struct {
name string
uidRanges string
gidRanges string
uid string
gid string
expectedErr string
}{
{name: "no-ranges-are-valid", uidRanges: validRangesList, gidRanges: emptyRanges},
{name: "uid-and-gid-outside-of-ranges-valid", uidRanges: validRangesList, gidRanges: validRangesList},
{name: "uid-in-one-of-ranges-is-invalid", uidRanges: validRangesList, gidRanges: validRangesList, uid: "50", expectedErr: "running as uid 50 is disallowed"},
{name: "gid-in-one-of-ranges-is-invalid", uidRanges: validRangesList, gidRanges: validRangesList, gid: "50", expectedErr: "running as gid 50 is disallowed"},
{name: "string-uid-throws-error", uid: "banana", expectedErr: "unable to convert userid banana to integer"},
{name: "user_not_in_denied_ranges", uidRanges: userDeniedRangesList, gidRanges: emptyRanges},
{name: "user_and group_not_in_denied_ranges", uidRanges: userDeniedRangesList, gidRanges: groupDeniedRangesList},
{name: "uid_in_one_of_ranges_is_invalid", uidRanges: userIncluded, gidRanges: groupDeniedRangesList, expectedErr: fmt.Sprintf("running as uid %s is disallowed", user.Uid)},
{name: "gid-in-one-of-ranges-is-invalid", uidRanges: userDeniedRangesList, gidRanges: groupIncluded, expectedErr: fmt.Sprintf("running as gid %s is disallowed", user.Gid)},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
user := &user.User{
Uid: "200",
Gid: "200",
}
if tc.uid != "" {
user.Uid = tc.uid
}
if tc.gid != "" {
user.Gid = tc.gid
}
v, err := NewValidator(hclog.NewNullLogger(), tc.uidRanges, tc.gidRanges)
must.NoError(t, err)
err = v.HasValidIDs(user)
err = v.HasValidIDs(user.Username)
if tc.expectedErr == "" {
must.NoError(t, err)
@@ -101,3 +100,23 @@ func Test_HasValidIds(t *testing.T) {
})
}
}
func Test_ValidateBounds(t *testing.T) {
testCases := []struct {
name string
bounds string
expectedErr error
}{
{name: "invalid_bound", bounds: "banana", expectedErr: ErrInvalidBound},
{name: "invalid_lower_bound", bounds: "banana-10", expectedErr: ErrInvalidBound},
{name: "invalid_upper_bound", bounds: "10-banana", expectedErr: ErrInvalidBound},
{name: "lower_bigger_than_upper", bounds: "10-1", expectedErr: ErrInvalidRange},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := validateBounds(tc.bounds)
must.ErrorIs(t, err, tc.expectedErr)
})
}
}

View File

@@ -20,7 +20,7 @@ func getUserID(user *user.User) (UserID, error) {
return UserID(id), nil
}
func getGroupID(user *user.User) ([]GroupID, error) {
func getGroupsID(user *user.User) ([]GroupID, error) {
gidStrings, err := user.GroupIds()
if err != nil {
return []GroupID{}, fmt.Errorf("unable to lookup user's group membership: %w", err)