diff --git a/client/allocrunner/alloc_runner.go b/client/allocrunner/alloc_runner.go index 9f7da0bb7..3907c2156 100644 --- a/client/allocrunner/alloc_runner.go +++ b/client/allocrunner/alloc_runner.go @@ -15,6 +15,7 @@ import ( "github.com/hashicorp/nomad/client/allocwatcher" "github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/consul" + "github.com/hashicorp/nomad/client/devicemanager" cinterfaces "github.com/hashicorp/nomad/client/interfaces" cstate "github.com/hashicorp/nomad/client/state" cstructs "github.com/hashicorp/nomad/client/structs" @@ -104,6 +105,10 @@ type allocRunner struct { // pluginSingletonLoader is a plugin loader that will returns singleton // instances of the plugins. pluginSingletonLoader loader.PluginCatalog + + // devicemanager is used to mount devices as well as lookup device + // statistics + devicemanager devicemanager.Manager } // NewAllocRunner returns a new allocation runner. @@ -130,6 +135,7 @@ func NewAllocRunner(config *Config) (*allocRunner, error) { deviceStatsReporter: config.DeviceStatsReporter, prevAllocWatcher: config.PrevAllocWatcher, pluginSingletonLoader: config.PluginSingletonLoader, + devicemanager: config.DeviceManager, } // Create the logger based on the allocation ID @@ -167,6 +173,7 @@ func (ar *allocRunner) initTaskRunners(tasks []*structs.Task) error { Vault: ar.vaultClient, PluginSingletonLoader: ar.pluginSingletonLoader, DeviceStatsReporter: ar.deviceStatsReporter, + DeviceManager: ar.devicemanager, } // Create, but do not Run, the task runner diff --git a/client/allocrunner/alloc_runner_test.go b/client/allocrunner/alloc_runner_test.go index 31d44c119..0f1c6897c 100644 --- a/client/allocrunner/alloc_runner_test.go +++ b/client/allocrunner/alloc_runner_test.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/nomad/client/allocwatcher" "github.com/hashicorp/nomad/client/config" consulapi "github.com/hashicorp/nomad/client/consul" + "github.com/hashicorp/nomad/client/devicemanager" 
"github.com/hashicorp/nomad/client/state" "github.com/hashicorp/nomad/client/vaultclient" "github.com/hashicorp/nomad/nomad/mock" @@ -69,6 +70,7 @@ func testAllocRunnerConfig(t *testing.T, alloc *structs.Allocation) (*Config, fu StateUpdater: &MockStateUpdater{}, PrevAllocWatcher: allocwatcher.NoopPrevAlloc{}, PluginSingletonLoader: singleton.NewSingletonLoader(clientConf.Logger, pluginLoader), + DeviceManager: devicemanager.NoopMockManager(), } return conf, cleanup } diff --git a/client/allocrunner/config.go b/client/allocrunner/config.go index a63b85573..5912f6f80 100644 --- a/client/allocrunner/config.go +++ b/client/allocrunner/config.go @@ -5,6 +5,7 @@ import ( "github.com/hashicorp/nomad/client/allocwatcher" clientconfig "github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/consul" + "github.com/hashicorp/nomad/client/devicemanager" "github.com/hashicorp/nomad/client/interfaces" cstate "github.com/hashicorp/nomad/client/state" "github.com/hashicorp/nomad/client/vaultclient" @@ -48,4 +49,8 @@ type Config struct { // PluginSingletonLoader is a plugin loader that will returns singleton // instances of the plugins. 
PluginSingletonLoader loader.PluginCatalog + + // DeviceManager is used to mount devices as well as lookup device + // statistics + DeviceManager devicemanager.Manager } diff --git a/client/allocrunner/interfaces/task_lifecycle.go b/client/allocrunner/interfaces/task_lifecycle.go index 808afd37a..c34b583a1 100644 --- a/client/allocrunner/interfaces/task_lifecycle.go +++ b/client/allocrunner/interfaces/task_lifecycle.go @@ -8,6 +8,7 @@ import ( "github.com/hashicorp/nomad/client/driver/env" cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/drivers" ) /* @@ -47,6 +48,9 @@ type TaskPrestartRequest struct { // Task is the task to run Task *structs.Task + // TaskResources is the resources assigned to the task + TaskResources *structs.AllocatedTaskResources + // Vault token may optionally be set if a Vault token is available VaultToken string @@ -61,6 +65,12 @@ type TaskPrestartResponse struct { // Env is the environment variables to set for the task Env map[string]string + // Mounts is the set of host volumes to mount into the task + Mounts []*drivers.MountConfig + + // Devices are the set of devices to mount into the task + Devices []*drivers.DeviceConfig + // HookData allows the hook to emit data to be passed in the next time it is // run HookData map[string]string diff --git a/client/allocrunner/taskrunner/device_hook.go b/client/allocrunner/taskrunner/device_hook.go new file mode 100644 index 000000000..3d55aba07 --- /dev/null +++ b/client/allocrunner/taskrunner/device_hook.go @@ -0,0 +1,101 @@ +package taskrunner + +import ( + "context" + "fmt" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/nomad/client/allocrunner/interfaces" + "github.com/hashicorp/nomad/client/devicemanager" + "github.com/hashicorp/nomad/plugins/device" + "github.com/hashicorp/nomad/plugins/drivers" +) + +// deviceHook is used to retrieve device mounting information.
+type deviceHook struct { + logger log.Logger + dm devicemanager.Manager +} + +// newDeviceHook returns a device hook that reserves devices through dm and +// logs under a logger named after the hook. +func newDeviceHook(dm devicemanager.Manager, logger log.Logger) *deviceHook { + h := &deviceHook{ + dm: dm, + } + h.logger = logger.Named(h.Name()) + return h +} + +// Name returns the name of the hook. +func (*deviceHook) Name() string { + return "devices" +} + +// Prestart reserves every device allocated to the task and copies the +// resulting environment variables, mounts and devices onto the response. +// Nothing is written to the response if any reservation fails. +func (h *deviceHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error { + //TODO Can the nil check be removed once the TODO in NewTaskRunner + // where this is set is addressed? + if req.TaskResources == nil || len(req.TaskResources.Devices) == 0 { + resp.Done = true + return nil + } + + // Capture the responses + var reservations []*device.ContainerReservation + for _, dev := range req.TaskResources.Devices { + // Ask the device manager for the reservation information + res, err := h.dm.Reserve(dev) + if err != nil { + return fmt.Errorf("failed to reserve device %s: %v", dev.ID(), err) + } + + // A manager may return no reservation and no error; there is + // nothing to pass on to the driver in that case. + if res != nil { + reservations = append(reservations, res) + } + } + + // Build the response + for _, res := range reservations { + for k, v := range res.Envs { + if resp.Env == nil { + resp.Env = make(map[string]string) + } + + resp.Env[k] = v + } + + for _, m := range res.Mounts { + resp.Mounts = append(resp.Mounts, convertMount(m)) + } + + for _, d := range res.Devices { + resp.Devices = append(resp.Devices, convertDevice(d)) + } + } + + resp.Done = true + return nil +} + +// convertMount translates a device plugin mount into the driver mount config. +func convertMount(in *device.Mount) *drivers.MountConfig { + return &drivers.MountConfig{ + TaskPath: in.TaskPath, + HostPath: in.HostPath, + Readonly: in.ReadOnly, + } +} + +// convertDevice translates a device plugin spec into the driver device config. +func convertDevice(in *device.DeviceSpec) *drivers.DeviceConfig { + return &drivers.DeviceConfig{ + TaskPath: in.TaskPath, + HostPath: in.HostPath, + Permissions: in.CgroupPerms, + } +} diff --git a/client/allocrunner/taskrunner/device_hook_test.go b/client/allocrunner/taskrunner/device_hook_test.go new file mode 100644 index 000000000..9d9d6d7b3 --- /dev/null +++
b/client/allocrunner/taskrunner/device_hook_test.go @@ -0,0 +1,131 @@ +package taskrunner + +import ( + "context" + "fmt" + "testing" + + "github.com/hashicorp/nomad/client/allocrunner/interfaces" + "github.com/hashicorp/nomad/client/devicemanager" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/device" + "github.com/hashicorp/nomad/plugins/drivers" + "github.com/stretchr/testify/require" +) + +func TestDeviceHook_CorrectDevice(t *testing.T) { + t.Parallel() + require := require.New(t) + + dm := devicemanager.NoopMockManager() + l := testlog.HCLogger(t) + h := newDeviceHook(dm, l) + + reqDev := &structs.AllocatedDeviceResource{ + Vendor: "foo", + Type: "bar", + Name: "baz", + DeviceIDs: []string{"123"}, + } + + // Build the hook request + req := &interfaces.TaskPrestartRequest{ + TaskResources: &structs.AllocatedTaskResources{ + Devices: []*structs.AllocatedDeviceResource{ + reqDev, + }, + }, + } + + // Setup the device manager to return a response + dm.ReserveF = func(d *structs.AllocatedDeviceResource) (*device.ContainerReservation, error) { + if d.Vendor != reqDev.Vendor || d.Type != reqDev.Type || + d.Name != reqDev.Name || len(d.DeviceIDs) != 1 || d.DeviceIDs[0] != reqDev.DeviceIDs[0] { + return nil, fmt.Errorf("unexpected request: %+v", d) + } + + res := &device.ContainerReservation{ + Envs: map[string]string{ + "123": "456", + }, + Mounts: []*device.Mount{ + { + ReadOnly: true, + TaskPath: "foo", + HostPath: "bar", + }, + }, + Devices: []*device.DeviceSpec{ + { + TaskPath: "foo", + HostPath: "bar", + CgroupPerms: "123", + }, + }, + } + return res, nil + } + + var resp interfaces.TaskPrestartResponse + err := h.Prestart(context.Background(), req, &resp) + require.NoError(err) + require.NotNil(resp) + + expEnv := map[string]string{ + "123": "456", + } + require.EqualValues(expEnv, resp.Env) + + expMounts := []*drivers.MountConfig{ + { + Readonly: true, + TaskPath: "foo", + 
HostPath: "bar", + }, + } + require.EqualValues(expMounts, resp.Mounts) + + expDevices := []*drivers.DeviceConfig{ + { + TaskPath: "foo", + HostPath: "bar", + Permissions: "123", + }, + } + require.EqualValues(expDevices, resp.Devices) +} + +func TestDeviceHook_IncorrectDevice(t *testing.T) { + t.Parallel() + require := require.New(t) + + dm := devicemanager.NoopMockManager() + l := testlog.HCLogger(t) + h := newDeviceHook(dm, l) + + reqDev := &structs.AllocatedDeviceResource{ + Vendor: "foo", + Type: "bar", + Name: "baz", + DeviceIDs: []string{"123"}, + } + + // Build the hook request + req := &interfaces.TaskPrestartRequest{ + TaskResources: &structs.AllocatedTaskResources{ + Devices: []*structs.AllocatedDeviceResource{ + reqDev, + }, + }, + } + + // Setup the device manager to return a response + dm.ReserveF = func(d *structs.AllocatedDeviceResource) (*device.ContainerReservation, error) { + return nil, fmt.Errorf("bad request") + } + + var resp interfaces.TaskPrestartResponse + err := h.Prestart(context.Background(), req, &resp) + require.Error(err) +} diff --git a/client/allocrunner/taskrunner/state/state.go b/client/allocrunner/taskrunner/state/state.go index 83481c1ac..738080ba8 100644 --- a/client/allocrunner/taskrunner/state/state.go +++ b/client/allocrunner/taskrunner/state/state.go @@ -63,7 +63,13 @@ type HookState struct { // Prestart is true if the hook has run Prestart successfully and does // not need to run again PrestartDone bool - Data map[string]string + + // Data allows hooks to persist arbitrary state. + Data map[string]string + + // Environment variables set by the hook that will continue to be set + // even if PrestartDone=true. 
+ Env map[string]string } func (h *HookState) Copy() *HookState { diff --git a/client/allocrunner/taskrunner/task_runner.go b/client/allocrunner/taskrunner/task_runner.go index 415c5df21..74a636eca 100644 --- a/client/allocrunner/taskrunner/task_runner.go +++ b/client/allocrunner/taskrunner/task_runner.go @@ -19,6 +19,7 @@ import ( "github.com/hashicorp/nomad/client/allocrunner/taskrunner/state" "github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/consul" + "github.com/hashicorp/nomad/client/devicemanager" "github.com/hashicorp/nomad/client/driver/env" cinterfaces "github.com/hashicorp/nomad/client/interfaces" cstate "github.com/hashicorp/nomad/client/state" @@ -53,11 +54,12 @@ const ( ) type TaskRunner struct { - // allocID, taskName, and taskLeader are immutable so these fields may + // allocID, taskName, taskLeader, and taskResources are immutable so these fields may // be accessed without locks - allocID string - taskName string - taskLeader bool + allocID string + taskName string + taskLeader bool + taskResources *structs.AllocatedTaskResources alloc *structs.Allocation allocLock sync.Mutex @@ -140,6 +142,9 @@ type TaskRunner struct { // transistions. runnerHooks []interfaces.TaskHook + // hookResources captures the resources provided by hooks + hookResources *hookResources + // consulClient is the client used by the consul service hook for // registering services and checks consulClient consul.ConsulServiceAPI @@ -171,6 +176,10 @@ type TaskRunner struct { // PluginSingletonLoader is a plugin loader that will returns singleton // instances of the plugins. pluginSingletonLoader loader.PluginCatalog + + // devicemanager is used to mount devices as well as lookup device + // statistics + devicemanager devicemanager.Manager } type Config struct { @@ -196,6 +205,10 @@ type Config struct { // PluginSingletonLoader is a plugin loader that will returns singleton // instances of the plugins. 
PluginSingletonLoader loader.PluginCatalog + + // DeviceManager is used to mount devices as well as lookup device + // statistics + DeviceManager devicemanager.Manager } func NewTaskRunner(config *Config) (*TaskRunner, error) { @@ -242,11 +255,23 @@ func NewTaskRunner(config *Config) (*TaskRunner, error) { triggerUpdateCh: make(chan struct{}, triggerUpdateChCap), waitCh: make(chan struct{}), pluginSingletonLoader: config.PluginSingletonLoader, + devicemanager: config.DeviceManager, } // Create the logger based on the allocation ID tr.logger = config.Logger.Named("task_runner").With("task", config.Task.Name) + // Pull out the task's resources + ares := tr.alloc.AllocatedResources + if ares != nil { + if tres, ok := ares.Tasks[tr.taskName]; ok { + tr.taskResources = tres + } + + // TODO in the else case should we do a migration from resources as an + // upgrade path + } + // Build the restart tracker. tg := tr.alloc.Job.LookupTaskGroup(tr.alloc.TaskGroup) if tg == nil { @@ -528,7 +553,6 @@ func (tr *TaskRunner) runDriver() error { return fmt.Errorf("failed to encode driver config: %v", err) } - //TODO mounts and devices //XXX Evaluate and encode driver config // If there's already a task handle (eg from a Restore) there's nothing @@ -688,6 +712,8 @@ func (tr *TaskRunner) buildTaskConfig() *drivers.TaskConfig { PercentTicks: float64(task.Resources.CPU) / float64(tr.clientConfig.Node.NodeResources.Cpu.CpuShares), }, }, + Devices: tr.hookResources.getDevices(), + Mounts: tr.hookResources.getMounts(), Env: tr.envBuilder.Build().Map(), User: task.User, AllocDir: tr.taskDir.AllocDir, diff --git a/client/allocrunner/taskrunner/task_runner_hooks.go b/client/allocrunner/taskrunner/task_runner_hooks.go index 4850d5e4c..58ae5e6c6 100644 --- a/client/allocrunner/taskrunner/task_runner_hooks.go +++ b/client/allocrunner/taskrunner/task_runner_hooks.go @@ -3,14 +3,47 @@ package taskrunner import ( "context" "fmt" + "sync" "time" multierror "github.com/hashicorp/go-multierror" 
"github.com/hashicorp/nomad/client/allocrunner/interfaces" "github.com/hashicorp/nomad/client/allocrunner/taskrunner/state" "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/drivers" ) +// hookResources captures the resources for the task provided by hooks. +type hookResources struct { + Devices []*drivers.DeviceConfig + Mounts []*drivers.MountConfig + sync.RWMutex +} + +func (h *hookResources) setDevices(d []*drivers.DeviceConfig) { + h.Lock() + h.Devices = d + h.Unlock() +} + +func (h *hookResources) getDevices() []*drivers.DeviceConfig { + h.RLock() + defer h.RUnlock() + return h.Devices +} + +func (h *hookResources) setMounts(m []*drivers.MountConfig) { + h.Lock() + h.Mounts = m + h.Unlock() +} + +func (h *hookResources) getMounts() []*drivers.MountConfig { + h.RLock() + defer h.RUnlock() + return h.Mounts +} + // initHooks intializes the tasks hooks. func (tr *TaskRunner) initHooks() { hookLogger := tr.logger.Named("task_hook") @@ -18,6 +51,9 @@ func (tr *TaskRunner) initHooks() { tr.logmonHookConfig = newLogMonHookConfig(task.Name, tr.taskDir.LogDir) + // Add the hook resources + tr.hookResources = &hookResources{} + // Create the task directory hook. This is run first to ensure the // directory path exists for other hooks. 
tr.runnerHooks = []interfaces.TaskHook{ @@ -27,6 +63,7 @@ func (tr *TaskRunner) initHooks() { newDispatchHook(tr.Alloc(), hookLogger), newArtifactHook(tr, hookLogger), newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger), + newDeviceHook(tr.devicemanager, hookLogger), } // If Vault is enabled, add the hook @@ -93,11 +130,13 @@ func (tr *TaskRunner) prestart() error { } name := pre.Name() + // Build the request req := interfaces.TaskPrestartRequest{ - Task: tr.Task(), - TaskDir: tr.taskDir, - TaskEnv: tr.envBuilder.Build(), + Task: tr.Task(), + TaskDir: tr.taskDir, + TaskEnv: tr.envBuilder.Build(), + TaskResources: tr.taskResources, } var origHookState *state.HookState @@ -106,9 +145,17 @@ func (tr *TaskRunner) prestart() error { origHookState = tr.localState.Hooks[name] } tr.stateLock.RUnlock() - if origHookState != nil && origHookState.PrestartDone { - tr.logger.Trace("skipping done prestart hook", "name", pre.Name()) - continue + + if origHookState != nil { + if origHookState.PrestartDone { + tr.logger.Trace("skipping done prestart hook", "name", pre.Name()) + // Always set env vars from hooks + tr.envBuilder.SetHookEnv(name, origHookState.Env) + continue + } + + // Give the hook its old data + req.HookData = origHookState.Data } req.VaultToken = tr.getVaultToken() @@ -131,6 +178,7 @@ func (tr *TaskRunner) prestart() error { hookState := &state.HookState{ Data: resp.HookData, PrestartDone: resp.Done, + Env: resp.Env, } // Store and persist local state if the hook state has changed @@ -146,8 +194,14 @@ func (tr *TaskRunner) prestart() error { } // Store the environment variables returned by the hook - if len(resp.Env) != 0 { - tr.envBuilder.SetGenericEnv(resp.Env) + tr.envBuilder.SetHookEnv(name, resp.Env) + + // Store the resources + if len(resp.Devices) != 0 { + tr.hookResources.setDevices(resp.Devices) + } + if len(resp.Mounts) != 0 { + tr.hookResources.setMounts(resp.Mounts) } if tr.logger.IsTrace() { diff --git
a/client/allocrunner/taskrunner/task_runner_test.go b/client/allocrunner/taskrunner/task_runner_test.go index df4581092..19a42aa71 100644 --- a/client/allocrunner/taskrunner/task_runner_test.go +++ b/client/allocrunner/taskrunner/task_runner_test.go @@ -8,14 +8,17 @@ import ( "time" "github.com/hashicorp/nomad/client/allocdir" + "github.com/hashicorp/nomad/client/allocrunner/interfaces" "github.com/hashicorp/nomad/client/config" consulapi "github.com/hashicorp/nomad/client/consul" + "github.com/hashicorp/nomad/client/devicemanager" cstate "github.com/hashicorp/nomad/client/state" "github.com/hashicorp/nomad/client/vaultclient" mockdriver "github.com/hashicorp/nomad/drivers/mock" "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/device" "github.com/hashicorp/nomad/plugins/shared/catalog" "github.com/hashicorp/nomad/plugins/shared/singleton" "github.com/hashicorp/nomad/testutil" @@ -92,6 +95,7 @@ func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName stri StateDB: cstate.NoopDB{}, StateUpdater: NewMockTaskStateUpdater(), PluginSingletonLoader: singleton.NewSingletonLoader(logger, pluginLoader), + DeviceManager: devicemanager.NoopMockManager(), } return conf, trCleanup } @@ -105,12 +109,11 @@ func TestTaskRunner_Restore_Running(t *testing.T) { alloc := mock.BatchAlloc() alloc.Job.TaskGroups[0].Count = 1 task := alloc.Job.TaskGroups[0].Tasks[0] - task.Name = "testtask" task.Driver = "mock_driver" task.Config = map[string]interface{}{ "run_for": "2s", } - conf, cleanup := testTaskRunnerConfig(t, alloc, "testtask") + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) conf.StateDB = cstate.NewMemDB() // "persist" state between task runners defer cleanup() @@ -166,7 +169,6 @@ func TestTaskRunner_TaskEnv(t *testing.T) { "common_user": "somebody", } task := alloc.Job.TaskGroups[0].Tasks[0] - task.Name = "testtask_taskenv" task.Driver 
= "mock_driver" task.Meta = map[string]string{ "foo": "bar", @@ -209,3 +211,141 @@ func TestTaskRunner_TaskEnv(t *testing.T) { require.NotNil(mockCfg) assert.Equal(t, "global bar somebody", mockCfg.StdoutString) } + +// Test that devices get sent to the driver +func TestTaskRunner_DevicePropogation(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Create a mock alloc that has a gpu + alloc := mock.BatchAlloc() + alloc.Job.TaskGroups[0].Count = 1 + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Driver = "mock_driver" + task.Config = map[string]interface{}{ + "run_for": "100ms", + } + tRes := alloc.AllocatedResources.Tasks[task.Name] + tRes.Devices = append(tRes.Devices, &structs.AllocatedDeviceResource{Type: "mock"}) + + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf.StateDB = cstate.NewMemDB() // "persist" state between task runners + defer cleanup() + + // Setup the devicemanager + dm, ok := conf.DeviceManager.(*devicemanager.MockManager) + require.True(ok) + + dm.ReserveF = func(d *structs.AllocatedDeviceResource) (*device.ContainerReservation, error) { + res := &device.ContainerReservation{ + Envs: map[string]string{ + "ABC": "123", + }, + Mounts: []*device.Mount{ + { + ReadOnly: true, + TaskPath: "foo", + HostPath: "bar", + }, + }, + Devices: []*device.DeviceSpec{ + { + TaskPath: "foo", + HostPath: "bar", + CgroupPerms: "123", + }, + }, + } + return res, nil + } + + // Run the TaskRunner + tr, err := NewTaskRunner(conf) + require.NoError(err) + go tr.Run() + defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup")) + + // Wait for task to complete + select { + case <-tr.WaitCh(): + case <-time.After(3 * time.Second): + } + + // Get the mock driver plugin + driverPlugin, err := conf.PluginSingletonLoader.Dispense( + mockdriver.PluginID.Name, + mockdriver.PluginID.PluginType, + nil, + conf.Logger, + ) + require.NoError(err) + mockDriver := driverPlugin.Plugin().(*mockdriver.Driver) + + // Assert its config has been 
properly interpolated + driverCfg, _ := mockDriver.GetTaskConfig() + require.NotNil(driverCfg) + require.Len(driverCfg.Devices, 1) + require.Equal(driverCfg.Devices[0].Permissions, "123") + require.Len(driverCfg.Mounts, 1) + require.Equal(driverCfg.Mounts[0].TaskPath, "foo") + require.Contains(driverCfg.Env, "ABC") +} + +// mockEnvHook is a test hook that sets an env var and done=true. It counts +// its calls so tests can assert it is not run more than once. +type mockEnvHook struct { + called int +} + +func (*mockEnvHook) Name() string { + return "mock_env_hook" +} + +func (h *mockEnvHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error { + h.called++ + + resp.Done = true + resp.Env = map[string]string{ + "mock_hook": "1", + } + + return nil +} + +// TestTaskRunner_Restore_HookEnv asserts that re-running prestart hooks with +// hook environments set restores the environment without re-running done +// hooks. +func TestTaskRunner_Restore_HookEnv(t *testing.T) { + t.Parallel() + require := require.New(t) + + alloc := mock.BatchAlloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf.StateDB = cstate.NewMemDB() // "persist" state between prestart calls + defer cleanup() + + tr, err := NewTaskRunner(conf) + require.NoError(err) + + // Override the default hooks to only run the mock hook + mockHook := &mockEnvHook{} + tr.runnerHooks = []interfaces.TaskHook{mockHook} + + // Manually run prestart hooks + require.NoError(tr.prestart()) + + // Assert the hook was called exactly once + require.Equal(1, mockHook.called) + + // Re-running prestart hooks should *not* call done mock hook + require.NoError(tr.prestart()) + + // Assert the hook was not called again + require.Equal(1, mockHook.called) + + // Assert the env is still set + env := tr.envBuilder.Build().All() + require.Contains(env, "mock_hook") + require.Equal("1", env["mock_hook"]) +} diff --git a/client/allocrunner/taskrunner/validate_hook_test.go
b/client/allocrunner/taskrunner/validate_hook_test.go index 823f6cdcb..c4850bd04 100644 --- a/client/allocrunner/taskrunner/validate_hook_test.go +++ b/client/allocrunner/taskrunner/validate_hook_test.go @@ -52,12 +52,12 @@ func TestTaskRunner_Validate_ServiceName(t *testing.T) { require.NoError(t, validateTask(task, builder.Build(), conf)) // Add an env var that should validate - builder.SetGenericEnv(map[string]string{"FOO": "bar"}) + builder.SetHookEnv("test", map[string]string{"FOO": "bar"}) task.Services[0].Name = "${FOO}" require.NoError(t, validateTask(task, builder.Build(), conf)) // Add an env var that should *not* validate - builder.SetGenericEnv(map[string]string{"BAD": "invalid/in/consul"}) + builder.SetHookEnv("test", map[string]string{"BAD": "invalid/in/consul"}) task.Services[0].Name = "${BAD}" require.Error(t, validateTask(task, builder.Build(), conf)) } diff --git a/client/client.go b/client/client.go index 07c6ee4b7..1f570a3c6 100644 --- a/client/client.go +++ b/client/client.go @@ -874,6 +874,7 @@ func (c *Client) restoreState() error { PrevAllocWatcher: prevAllocWatcher, PluginLoader: c.config.PluginLoader, PluginSingletonLoader: c.config.PluginSingletonLoader, + DeviceManager: c.devicemanager, } c.configLock.RUnlock() @@ -2054,6 +2055,7 @@ func (c *Client) addAlloc(alloc *structs.Allocation, migrateToken string) error PrevAllocWatcher: prevAllocWatcher, PluginLoader: c.config.PluginLoader, PluginSingletonLoader: c.config.PluginSingletonLoader, + DeviceManager: c.devicemanager, } c.configLock.RUnlock() diff --git a/client/devicemanager/testing.go b/client/devicemanager/testing.go new file mode 100644 index 000000000..3b9f9833b --- /dev/null +++ b/client/devicemanager/testing.go @@ -0,0 +1,48 @@ +package devicemanager + +import ( + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/device" +) + +type ReserveFn func(d *structs.AllocatedDeviceResource) (*device.ContainerReservation, error) +type AllStatsFn func() 
[]*device.DeviceGroupStats +type DeviceStatsFn func(d *structs.AllocatedDeviceResource) (*device.DeviceGroupStats, error) + +func NoopReserve(*structs.AllocatedDeviceResource) (*device.ContainerReservation, error) { + return nil, nil +} + +func NoopAllStats() []*device.DeviceGroupStats { + return nil +} + +func NoopDeviceStats(*structs.AllocatedDeviceResource) (*device.DeviceGroupStats, error) { + return nil, nil +} + +func NoopMockManager() *MockManager { + return &MockManager{ + ReserveF: NoopReserve, + AllStatsF: NoopAllStats, + DeviceStatsF: NoopDeviceStats, + } +} + +type MockManager struct { + ReserveF ReserveFn + AllStatsF AllStatsFn + DeviceStatsF DeviceStatsFn +} + +func (m *MockManager) Run() {} +func (m *MockManager) Shutdown() {} +func (m *MockManager) AllStats() []*device.DeviceGroupStats { return m.AllStatsF() } + +func (m *MockManager) Reserve(d *structs.AllocatedDeviceResource) (*device.ContainerReservation, error) { + return m.ReserveF(d) +} + +func (m *MockManager) DeviceStats(d *structs.AllocatedDeviceResource) (*device.DeviceGroupStats, error) { + return m.DeviceStatsF(d) +} diff --git a/client/driver/env/env.go b/client/driver/env/env.go index b9acc5d01..e29c3c8b3 100644 --- a/client/driver/env/env.go +++ b/client/driver/env/env.go @@ -297,23 +297,31 @@ type Builder struct { // and affect network env vars. networks []*structs.NetworkResource + // hookEnvs are env vars set by hooks and stored by hook name to + // support adding/removing vars from multiple hooks (eg HookA adds A:1, + // HookB adds A:2, HookA removes A, A should equal 2) + hookEnvs map[string]map[string]string + + // hookNames is a slice of hooks in hookEnvs to apply hookEnvs in the + // order the hooks are run. + hookNames []string + mu *sync.RWMutex } // NewBuilder creates a new task environment builder. 
func NewBuilder(node *structs.Node, alloc *structs.Allocation, task *structs.Task, region string) *Builder { - b := &Builder{ - region: region, - mu: &sync.RWMutex{}, - } + b := NewEmptyBuilder() + b.region = region return b.setTask(task).setAlloc(alloc).setNode(node) } // NewEmptyBuilder creates a new environment builder. func NewEmptyBuilder() *Builder { return &Builder{ - mu: &sync.RWMutex{}, - envvars: make(map[string]string), + mu: &sync.RWMutex{}, + hookEnvs: map[string]map[string]string{}, + envvars: make(map[string]string), } } @@ -406,7 +414,14 @@ func (b *Builder) Build() *TaskEnv { envMap[k] = hargs.ReplaceEnv(v, nodeAttrs, envMap) } - // Copy template env vars third as they override task env vars + // Copy hook env vars in the order the hooks were run + for _, h := range b.hookNames { + for k, v := range b.hookEnvs[h] { + envMap[k] = hargs.ReplaceEnv(v, nodeAttrs, envMap) + } + } + + // Copy template env vars as they override task env vars for k, v := range b.templateEnv { envMap[k] = v } @@ -428,12 +443,17 @@ func (b *Builder) UpdateTask(alloc *structs.Allocation, task *structs.Task) *Bui return b.setTask(task).setAlloc(alloc) } -func (b *Builder) SetGenericEnv(envs map[string]string) *Builder { +// SetHookEnv sets environment variables from a hook. Variables are +// Last-Write-Wins, so if a hook writes a variable that's also written by a +// later hook, the later hooks value always gets used. 
+func (b *Builder) SetHookEnv(hook string, envs map[string]string) *Builder { b.mu.Lock() defer b.mu.Unlock() - for k, v := range envs { - b.envvars[k] = v + + if _, exists := b.hookEnvs[hook]; !exists { + b.hookNames = append(b.hookNames, hook) } + b.hookEnvs[hook] = envs return b } diff --git a/client/driver/env/env_test.go b/client/driver/env/env_test.go index 9fb27a1c1..9d035d194 100644 --- a/client/driver/env/env_test.go +++ b/client/driver/env/env_test.go @@ -14,6 +14,7 @@ import ( cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -511,6 +512,42 @@ func TestEnvironment_Envvars(t *testing.T) { } } +// TestEnvironment_HookVars asserts hook env vars are LWW and deletes of later +// writes allow earlier hook's values to be visible. +func TestEnvironment_HookVars(t *testing.T) { + n := mock.Node() + a := mock.Alloc() + builder := NewBuilder(n, a, a.Job.TaskGroups[0].Tasks[0], "global") + + // Add vars from two hooks and assert the second one wins on + // conflicting keys. + builder.SetHookEnv("hookA", map[string]string{ + "foo": "bar", + "baz": "quux", + }) + builder.SetHookEnv("hookB", map[string]string{ + "foo": "123", + "hookB": "wins", + }) + + { + out := builder.Build().All() + assert.Equal(t, "123", out["foo"]) + assert.Equal(t, "quux", out["baz"]) + assert.Equal(t, "wins", out["hookB"]) + } + + // Assert that overwriting hookB's vars with nil allows hookA's original + // value to be used.
+ builder.SetHookEnv("hookB", nil) + { + out := builder.Build().All() + assert.Equal(t, "bar", out["foo"]) + assert.Equal(t, "quux", out["baz"]) + assert.NotContains(t, out, "hookB") + } +} + func TestEnvironment_Interpolate(t *testing.T) { n := mock.Node() n.Attributes["arch"] = "x86" diff --git a/command/agent/consul/int_test.go b/command/agent/consul/int_test.go index 65dab80fa..da6431ffc 100644 --- a/command/agent/consul/int_test.go +++ b/command/agent/consul/int_test.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/nomad/client/allocdir" "github.com/hashicorp/nomad/client/allocrunner/taskrunner" "github.com/hashicorp/nomad/client/config" + "github.com/hashicorp/nomad/client/devicemanager" "github.com/hashicorp/nomad/client/state" "github.com/hashicorp/nomad/client/vaultclient" "github.com/hashicorp/nomad/command/agent/consul" @@ -157,6 +158,7 @@ func TestConsul_Integration(t *testing.T) { StateDB: state.NoopDB{}, StateUpdater: logUpdate, PluginSingletonLoader: singleton.NewSingletonLoader(logger, pluginLoader), + DeviceManager: devicemanager.NoopMockManager(), } tr, err := taskrunner.NewTaskRunner(config) diff --git a/nomad/mock/mock.go b/nomad/mock/mock.go index 73880e782..017efb584 100644 --- a/nomad/mock/mock.go +++ b/nomad/mock/mock.go @@ -273,7 +273,7 @@ func BatchJob() *structs.Job { Datacenters: []string{"dc1"}, TaskGroups: []*structs.TaskGroup{ { - Name: "worker", + Name: "web", Count: 10, EphemeralDisk: &structs.EphemeralDisk{ SizeMB: 150, @@ -292,7 +292,7 @@ func BatchJob() *structs.Job { }, Tasks: []*structs.Task{ { - Name: "worker", + Name: "web", Driver: "mock_driver", Config: map[string]interface{}{ "run_for": "500ms", @@ -508,7 +508,7 @@ func BatchAlloc() *structs.Allocation { EvalID: uuid.Generate(), NodeID: "12345678-abcd-efab-cdef-123456789abc", Namespace: structs.DefaultNamespace, - TaskGroup: "worker", + TaskGroup: "web", // TODO Remove once clientv2 gets merged Resources: &structs.Resources{ diff --git a/nomad/structs/node_class.go 
b/nomad/structs/node_class.go index eef2db8f0..fbeb93966 100644 --- a/nomad/structs/node_class.go +++ b/nomad/structs/node_class.go @@ -42,7 +42,7 @@ func (n *Node) ComputeClass() error { // included in the computed node class. func (n Node) HashInclude(field string, v interface{}) (bool, error) { switch field { - case "Datacenter", "Attributes", "Meta", "NodeClass": + case "Datacenter", "Attributes", "Meta", "NodeClass", "NodeResources": return true, nil default: return false, nil @@ -65,6 +65,44 @@ func (n Node) HashIncludeMap(field string, k, v interface{}) (bool, error) { } } +// HashInclude is used to restrict the node resource fields included in the +// computed node class to the node's devices. +func (n NodeResources) HashInclude(field string, v interface{}) (bool, error) { + switch field { + case "Devices": + return true, nil + default: + return false, nil + } +} + +// HashInclude is used to restrict the device fields included in the computed +// node class to the vendor, type, name and attributes. +func (n NodeDeviceResource) HashInclude(field string, v interface{}) (bool, error) { + switch field { + case "Vendor", "Type", "Name", "Attributes": + return true, nil + default: + return false, nil + } +} + +// HashIncludeMap is used to blacklist uniquely identifying device attribute +// keys from being included in the computed node class. +func (n NodeDeviceResource) HashIncludeMap(field string, k, v interface{}) (bool, error) { + key, ok := k.(string) + if !ok { + return false, fmt.Errorf("map key %v not a string", k) + } + + switch field { + case "Attributes": + return !IsUniqueNamespace(key), nil + default: + return false, fmt.Errorf("unexpected map field: %v", field) + } +} + // EscapedConstraints takes a set of constraints and returns the set that // escapes computed node classes.
func EscapedConstraints(constraints []*Constraint) []*Constraint { diff --git a/nomad/structs/node_class_test.go b/nomad/structs/node_class_test.go index 131312c9b..1c7ffc3f6 100644 --- a/nomad/structs/node_class_test.go +++ b/nomad/structs/node_class_test.go @@ -5,8 +5,11 @@ import ( "testing" "github.com/hashicorp/nomad/helper/uuid" + psstructs "github.com/hashicorp/nomad/plugins/shared/structs" + "github.com/stretchr/testify/require" ) +// TODO Test func testNode() *Node { return &Node{ ID: uuid.Generate(), @@ -49,61 +52,73 @@ func testNode() *Node { } func TestNode_ComputedClass(t *testing.T) { + require := require.New(t) + // Create a node and gets it computed class n := testNode() - if err := n.ComputeClass(); err != nil { - t.Fatalf("ComputeClass() failed: %v", err) - } - if n.ComputedClass == "" { - t.Fatal("ComputeClass() didn't set computed class") - } + require.NoError(n.ComputeClass()) + require.NotEmpty(n.ComputedClass) old := n.ComputedClass // Compute again to ensure determinism - if err := n.ComputeClass(); err != nil { - t.Fatalf("ComputeClass() failed: %v", err) - } - if old != n.ComputedClass { - t.Fatalf("ComputeClass() should have returned same class; got %v; want %v", n.ComputedClass, old) - } + require.NoError(n.ComputeClass()) + require.Equal(n.ComputedClass, old) // Modify a field and compute the class again. 
n.Datacenter = "New DC" - if err := n.ComputeClass(); err != nil { - t.Fatalf("ComputeClass() failed: %v", err) - } - if n.ComputedClass == "" { - t.Fatal("ComputeClass() didn't set computed class") - } + require.NoError(n.ComputeClass()) + require.NotEqual(n.ComputedClass, old) + old = n.ComputedClass - if old == n.ComputedClass { - t.Fatal("ComputeClass() returned same computed class") - } + // Add a device + n.NodeResources.Devices = append(n.NodeResources.Devices, &NodeDeviceResource{ + Vendor: "foo", + Type: "gpu", + Name: "bam", + }) + require.NoError(n.ComputeClass()) + require.NotEqual(n.ComputedClass, old) } func TestNode_ComputedClass_Ignore(t *testing.T) { + require := require.New(t) + // Create a node and gets it computed class n := testNode() - if err := n.ComputeClass(); err != nil { - t.Fatalf("ComputeClass() failed: %v", err) - } - if n.ComputedClass == "" { - t.Fatal("ComputeClass() didn't set computed class") - } + require.NoError(n.ComputeClass()) + require.NotEmpty(n.ComputedClass) old := n.ComputedClass // Modify an ignored field and compute the class again. 
n.ID = "New ID" - if err := n.ComputeClass(); err != nil { - t.Fatalf("ComputeClass() failed: %v", err) - } - if n.ComputedClass == "" { - t.Fatal("ComputeClass() didn't set computed class") - } + require.NoError(n.ComputeClass()) + require.NotEmpty(n.ComputedClass) + require.Equal(n.ComputedClass, old) - if old != n.ComputedClass { - t.Fatal("ComputeClass() should have ignored field") +} + +func TestNode_ComputedClass_Device_Attr(t *testing.T) { + require := require.New(t) + + // Create a node and get its computed class + n := testNode() + d := &NodeDeviceResource{ + Vendor: "foo", + Type: "gpu", + Name: "bam", + Attributes: map[string]*psstructs.Attribute{ + "foo": psstructs.NewBoolAttribute(true), + }, } + n.NodeResources.Devices = append(n.NodeResources.Devices, d) + require.NoError(n.ComputeClass()) + require.NotEmpty(n.ComputedClass) + old := n.ComputedClass + + // Add a unique-namespaced attribute key; the class must not change + d.Attributes["unique.bar"] = psstructs.NewBoolAttribute(false) + require.NoError(n.ComputeClass()) + require.Equal(n.ComputedClass, old) } func TestNode_ComputedClass_Attr(t *testing.T) { diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index 751e6b36b..09daf9895 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -2385,6 +2385,14 @@ type DeviceIdTuple struct { Name string } +func (d *DeviceIdTuple) String() string { + if d == nil { + return "" + } + + return fmt.Sprintf("%s/%s/%s", d.Vendor, d.Type, d.Name) +} + // Matches returns if this Device ID is a superset of the passed ID. 
func (id *DeviceIdTuple) Matches(other *DeviceIdTuple) bool { if other == nil { diff --git a/plugins/drivers/driver.go b/plugins/drivers/driver.go index a560a3344..7762c4039 100644 --- a/plugins/drivers/driver.go +++ b/plugins/drivers/driver.go @@ -108,8 +108,8 @@ type TaskConfig struct { Name string Env map[string]string Resources *Resources - Devices []DeviceConfig - Mounts []MountConfig + Devices []*DeviceConfig + Mounts []*MountConfig User string AllocDir string rawDriverConfig []byte diff --git a/plugins/drivers/utils.go b/plugins/drivers/utils.go index 2ab16674f..090ff014a 100644 --- a/plugins/drivers/utils.go +++ b/plugins/drivers/utils.go @@ -57,8 +57,8 @@ func taskConfigFromProto(pb *proto.TaskConfig) *TaskConfig { Env: pb.Env, rawDriverConfig: pb.MsgpackDriverConfig, Resources: resourcesFromProto(pb.Resources), - Devices: []DeviceConfig{}, //TODO - Mounts: []MountConfig{}, //TODO + Devices: devicesFromProto(pb.Devices), + Mounts: mountsFromProto(pb.Mounts), User: pb.User, AllocDir: pb.AllocDir, StdoutPath: pb.StdoutPath, @@ -78,8 +78,8 @@ func taskConfigToProto(cfg *TaskConfig) *proto.TaskConfig { Name: cfg.Name, Env: cfg.Env, Resources: resourcesToProto(cfg.Resources), - Mounts: []*proto.Mount{}, - Devices: []*proto.Device{}, + Devices: devicesToProto(cfg.Devices), + Mounts: mountsToProto(cfg.Mounts), User: cfg.User, AllocDir: cfg.AllocDir, MsgpackDriverConfig: cfg.rawDriverConfig, @@ -195,6 +195,106 @@ func resourcesToProto(r *Resources) *proto.Resources { return &pb } +func devicesFromProto(devices []*proto.Device) []*DeviceConfig { + if devices == nil { + return nil + } + + out := make([]*DeviceConfig, len(devices)) + for i, d := range devices { + out[i] = deviceFromProto(d) + } + + return out +} + +func deviceFromProto(device *proto.Device) *DeviceConfig { + if device == nil { + return nil + } + + return &DeviceConfig{ + TaskPath: device.TaskPath, + HostPath: device.HostPath, + Permissions: device.CgroupPermissions, + } +} + +func 
mountsFromProto(mounts []*proto.Mount) []*MountConfig { + if mounts == nil { + return nil + } + + out := make([]*MountConfig, len(mounts)) + for i, m := range mounts { + out[i] = mountFromProto(m) + } + + return out +} + +func mountFromProto(mount *proto.Mount) *MountConfig { + if mount == nil { + return nil + } + + return &MountConfig{ + TaskPath: mount.TaskPath, + HostPath: mount.HostPath, + Readonly: mount.Readonly, + } +} + +func devicesToProto(devices []*DeviceConfig) []*proto.Device { + if devices == nil { + return nil + } + + out := make([]*proto.Device, len(devices)) + for i, d := range devices { + out[i] = deviceToProto(d) + } + + return out +} + +func deviceToProto(device *DeviceConfig) *proto.Device { + if device == nil { + return nil + } + + return &proto.Device{ + TaskPath: device.TaskPath, + HostPath: device.HostPath, + CgroupPermissions: device.Permissions, + } +} + +func mountsToProto(mounts []*MountConfig) []*proto.Mount { + if mounts == nil { + return nil + } + + out := make([]*proto.Mount, len(mounts)) + for i, m := range mounts { + out[i] = mountToProto(m) + } + + return out +} + +func mountToProto(mount *MountConfig) *proto.Mount { + if mount == nil { + return nil + } + + return &proto.Mount{ + TaskPath: mount.TaskPath, + HostPath: mount.HostPath, + Readonly: mount.Readonly, + } +} + func taskHandleFromProto(pb *proto.TaskHandle) *TaskHandle { if pb == nil { return &TaskHandle{}