moving user out of validators

This commit is contained in:
Mike Nomitch
2024-03-07 14:24:35 -08:00
committed by Juanadelacuesta
parent d8545fa262
commit 0fbf592131
4 changed files with 28 additions and 34 deletions

View File

@@ -265,7 +265,12 @@ func (tc *TaskConfig) validate() error {
func (tc *TaskConfig) validateUserIds(cfg *drivers.TaskConfig, driverConfig *Config) error {
usernameToLookup := getUsername(cfg)
return validators.HasValidIds(users.Lookup, usernameToLookup, driverConfig.DeniedHostUids, driverConfig.DeniedHostGids)
user, err := users.Lookup(usernameToLookup)
if err != nil {
return fmt.Errorf("failed to identify user %q: %w", usernameToLookup, err)
}
return validators.HasValidIds(user, driverConfig.DeniedHostUids, driverConfig.DeniedHostGids)
}
// TaskState is the state which is encoded in the handle returned in

View File

@@ -6,6 +6,8 @@
package rawexec
import (
"fmt"
"github.com/hashicorp/nomad/drivers/shared/validators"
"github.com/hashicorp/nomad/helper/users"
"github.com/hashicorp/nomad/plugins/drivers"
@@ -25,5 +27,10 @@ func (tc *TaskConfig) Validate(driverCofig Config, cfg drivers.TaskConfig) error
usernameToLookup = current.Name
}
return validators.HasValidIds(users.Lookup, usernameToLookup, driverCofig.DeniedHostUids, driverCofig.DeniedHostGids)
user, err := users.Lookup(usernameToLookup)
if err != nil {
return fmt.Errorf("failed to identify user %q: %w", usernameToLookup, err)
}
return validators.HasValidIds(user, driverCofig.DeniedHostUids, driverCofig.DeniedHostGids)
}

View File

@@ -1,8 +1,6 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
//go:build !windows
package validators
import (
@@ -36,21 +34,12 @@ func ParseIdRange(rangeType string, deniedRanges string) ([]structs.IDRange, err
return idRanges, nil
}
type userLookupFn func(string) (*user.User, error)
// HasValidIds is used when running a task to ensure the
// given user is in the ID range defined in the task config
func HasValidIds(userLookupFn userLookupFn, usernameToLookup string, deniedHostUIDs, deniedHostGIDs []structs.IDRange) error {
// look up user on host given username
u, err := userLookupFn(usernameToLookup)
func HasValidIds(user *user.User, deniedHostUIDs, deniedHostGIDs []structs.IDRange) error {
uid, err := strconv.ParseUint(user.Uid, 10, 32)
if err != nil {
return fmt.Errorf("failed to identify user %q: %w", usernameToLookup, err)
}
uid, err := strconv.ParseUint(u.Uid, 10, 32)
if err != nil {
return fmt.Errorf("unable to convert userid %s to integer", u.Uid)
return fmt.Errorf("unable to convert userid %s to integer", user.Uid)
}
// check uids
@@ -63,7 +52,7 @@ func HasValidIds(userLookupFn userLookupFn, usernameToLookup string, deniedHostU
// check gids
gidStrings, err := u.GroupIds()
gidStrings, err := user.GroupIds()
if err != nil {
return fmt.Errorf("unable to lookup user's group membership: %w", err)
}

View File

@@ -1,8 +1,6 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
//go:build !windows
package validators
import (
@@ -63,13 +61,12 @@ func Test_HasValidIds(t *testing.T) {
validRangesList := []structs.IDRange{validRange, validRangeSingle}
testCases := []struct {
name string
uidRanges []structs.IDRange
gidRanges []structs.IDRange
uid string
gid string
expectedErr string
userLookupFunc userLookupFn
name string
uidRanges []structs.IDRange
gidRanges []structs.IDRange
uid string
gid string
expectedErr string
}{
{name: "no-ranges-are-valid", uidRanges: validRangesList, gidRanges: emptyRanges},
{name: "uid-and-gid-outside-of-ranges-valid", uidRanges: validRangesList, gidRanges: validRangesList},
@@ -80,24 +77,20 @@ func Test_HasValidIds(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
defaultUserToReturn := &user.User{
user := &user.User{
Uid: "200",
Gid: "200",
}
if tc.uid != "" {
defaultUserToReturn.Uid = tc.uid
user.Uid = tc.uid
}
if tc.gid != "" {
defaultUserToReturn.Gid = tc.gid
user.Gid = tc.gid
}
getUserFn := func(username string) (*user.User, error) {
return defaultUserToReturn, nil
}
err := HasValidIds(getUserFn, "username", tc.uidRanges, tc.gidRanges)
err := HasValidIds(user, tc.uidRanges, tc.gidRanges)
if tc.expectedErr == "" {
must.NoError(t, err)
} else {