diff --git a/drivers/mock/driver.go b/drivers/mock/driver.go index 390523dff..8acb0a1a8 100644 --- a/drivers/mock/driver.go +++ b/drivers/mock/driver.go @@ -266,6 +266,15 @@ type TaskConfig struct { type MockTaskState struct { StartedAt time.Time + + // these are not strictly "state" but because there's no external + // reattachment we need somewhere to stash this config so we can properly + // restore mock tasks + Command Command + ExecCommand *Command + PluginExitAfter time.Duration + KillAfter time.Duration + ProcState drivers.TaskState } func (d *Driver) PluginInfo() (*base.PluginInfoResponse, error) { @@ -358,21 +367,39 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error { return fmt.Errorf("failed to decode task state from handle: %v", err) } - driverCfg, err := parseDriverConfig(handle.Config) - if err != nil { - d.logger.Error("failed to parse driver config from handle", "error", err, "task_id", handle.Config.ID, "config", hclog.Fmt("%+v", handle.Config)) - return fmt.Errorf("failed to parse driver config from handle: %v", err) + taskState.Command.parseDurations() + if taskState.ExecCommand != nil { + taskState.ExecCommand.parseDurations() } - // Remove the plugin exit time if set - driverCfg.pluginExitAfterDuration = 0 - // Correct the run_for time based on how long it has already been running now := time.Now() - driverCfg.runForDuration = driverCfg.runForDuration - now.Sub(taskState.StartedAt) + if !taskState.StartedAt.IsZero() { + taskState.Command.runForDuration = taskState.Command.runForDuration - now.Sub(taskState.StartedAt) + + if taskState.ExecCommand != nil { + taskState.ExecCommand.runForDuration = taskState.ExecCommand.runForDuration - now.Sub(taskState.StartedAt) + } + } + + // Recreate the taskHandle. Because there's no real running process, we'll + // assume we're still running if we've recovered it at all. + killCtx, killCancel := context.WithCancel(context.Background()) + h := &taskHandle{ + logger: d.logger.With("task_name", handle.Config.Name), + pluginExitAfter: taskState.PluginExitAfter, + killAfter: taskState.KillAfter, + waitCh: make(chan any), + taskConfig: handle.Config, + command: taskState.Command, + execCommand: taskState.ExecCommand, + procState: drivers.TaskStateRunning, + startedAt: taskState.StartedAt, + kill: killCancel, + killCh: killCtx.Done(), + Recovered: true, + } - h := newTaskHandle(handle.Config, driverCfg, d.logger) - h.Recovered = true d.tasks.Set(handle.Config.ID, h) go h.run() return nil @@ -423,23 +450,6 @@ func parseDriverConfig(cfg *drivers.TaskConfig) (*TaskConfig, error) { return &driverConfig, nil } -func newTaskHandle(cfg *drivers.TaskConfig, driverConfig *TaskConfig, logger hclog.Logger) *taskHandle { - killCtx, killCancel := context.WithCancel(context.Background()) - h := &taskHandle{ - taskConfig: cfg, - command: driverConfig.Command, - execCommand: driverConfig.ExecCommand, - pluginExitAfter: driverConfig.pluginExitAfterDuration, - killAfter: driverConfig.killAfterDuration, - logger: logger.With("task_name", cfg.Name), - waitCh: make(chan interface{}), - killCh: killCtx.Done(), - kill: killCancel, - startedAt: time.Now(), - } - return h -} - func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drivers.DriverNetwork, error) { driverConfig, err := parseDriverConfig(cfg) if err != nil { @@ -477,9 +487,26 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive net.PortMap = map[string]int{parts[0]: port} } - h := newTaskHandle(cfg, driverConfig, d.logger) + killCtx, killCancel := context.WithCancel(context.Background()) + h := &taskHandle{ + taskConfig: cfg, + command: driverConfig.Command, + execCommand: driverConfig.ExecCommand, + pluginExitAfter: driverConfig.pluginExitAfterDuration, + killAfter: driverConfig.killAfterDuration, + logger: d.logger.With("task_name", cfg.Name), + waitCh: make(chan interface{}), + killCh: killCtx.Done(), + kill: killCancel, + startedAt: time.Now(), + } + driverState := MockTaskState{ - StartedAt: h.startedAt, + StartedAt: h.startedAt, + Command: driverConfig.Command, + ExecCommand: driverConfig.ExecCommand, + PluginExitAfter: driverConfig.pluginExitAfterDuration, + KillAfter: driverConfig.killAfterDuration, } handle := drivers.NewTaskHandle(taskHandleVersion) handle.Config = cfg diff --git a/drivers/mock/driver_test.go b/drivers/mock/driver_test.go new file mode 100644 index 000000000..f12155705 --- /dev/null +++ b/drivers/mock/driver_test.go @@ -0,0 +1,177 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package mock + +import ( + "context" + "os" + "sync" + "testing" + "time" + + hclog "github.com/hashicorp/go-hclog" + "github.com/shoenig/test/must" + "github.com/shoenig/test/wait" + + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/client/allocdir" + "github.com/hashicorp/nomad/client/config" + "github.com/hashicorp/nomad/client/taskenv" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/helper/testtask" + "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/structs" + basePlug "github.com/hashicorp/nomad/plugins/base" + "github.com/hashicorp/nomad/plugins/drivers" + dtestutil "github.com/hashicorp/nomad/plugins/drivers/testutils" +) + +func TestMockDriver_StartWaitRecoverWaitStop(t *testing.T) { + ci.Parallel(t) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + logger := testlog.HCLogger(t) + d := NewMockDriver(ctx, logger).(*Driver) + harness := dtestutil.NewDriverHarness(t, d) + defer harness.Kill() + + var data []byte + must.NoError(t, basePlug.MsgPackEncode(&data, &Config{})) + bconfig := &basePlug.Config{PluginConfig: data} + must.NoError(t, harness.SetConfig(bconfig)) + + task := &drivers.TaskConfig{ + AllocID: uuid.Generate(), + ID: uuid.Generate(), + Name: "sleep", + Env: map[string]string{}, + } + tc := &TaskConfig{ + Command: Command{ + RunFor: "10s", + runForDuration: time.Second * 10, + }, + PluginExitAfter: "30s", + pluginExitAfterDuration: time.Second * 30, + } + must.NoError(t, task.EncodeConcreteDriverConfig(&tc)) + + testtask.SetTaskConfigEnv(task) + cleanup := mkTestAllocDir(t, harness, logger, task) + t.Cleanup(cleanup) + + handle, _, err := harness.StartTask(task) + must.NoError(t, err) + + ch, err := harness.WaitTask(context.Background(), task.ID) + must.NoError(t, err) + + var waitDone bool + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + result := <-ch + must.Error(t, result.Err) + waitDone = true + }() + + originalStatus, err := d.InspectTask(task.ID) + must.NoError(t, err) + + d.tasks.Delete(task.ID) + + wg.Wait() + must.True(t, waitDone) + _, err = d.InspectTask(task.ID) + must.Eq(t, drivers.ErrTaskNotFound, err) + + err = d.RecoverTask(handle) + must.NoError(t, err) + + // need to make sure the task is left running and doesn't just immediately + // exit after we recover it + must.Wait(t, wait.ContinualSuccess( + wait.BoolFunc(func() bool { + status, err := d.InspectTask(task.ID) + must.NoError(t, err) + return status.State == "running" + }), + wait.Timeout(1*time.Second), + wait.Gap(100*time.Millisecond), + )) + + status, err := d.InspectTask(task.ID) + must.NoError(t, err) + must.Eq(t, originalStatus, status) + + ch, err = harness.WaitTask(context.Background(), task.ID) + must.NoError(t, err) + + wg.Add(1) + waitDone = false + go func() { + defer wg.Done() + result := <-ch + must.NoError(t, result.Err) + must.Zero(t, result.ExitCode) + waitDone = true + }() + + time.Sleep(300 * time.Millisecond) + must.NoError(t, d.StopTask(task.ID, 0, "SIGKILL")) + wg.Wait() + must.NoError(t, d.DestroyTask(task.ID, false)) + must.True(t, waitDone) +} + +func mkTestAllocDir(t *testing.T, h *dtestutil.DriverHarness, logger hclog.Logger, tc *drivers.TaskConfig) func() { + dir, err := os.MkdirTemp("", "nomad_driver_harness-") + must.NoError(t, err) + + allocDir := allocdir.NewAllocDir(logger, dir, tc.AllocID) + must.NoError(t, allocDir.Build()) + + tc.AllocDir = allocDir.AllocDir + + taskDir := allocDir.NewTaskDir(tc.Name) + must.NoError(t, taskDir.Build(false, ci.TinyChroot)) + + task := &structs.Task{ + Name: tc.Name, + Env: tc.Env, + } + + // no logging + tc.StdoutPath = os.DevNull + tc.StderrPath = os.DevNull + + // Create the mock allocation + alloc := mock.Alloc() + alloc.ID = tc.AllocID + if tc.Resources != nil { + alloc.AllocatedResources.Tasks[task.Name] = tc.Resources.NomadResources + } + + taskBuilder := taskenv.NewBuilder(mock.Node(), alloc, task, "global") + dtestutil.SetEnvvars(taskBuilder, drivers.FSIsolationNone, taskDir, config.DefaultConfig()) + + taskEnv := taskBuilder.Build() + if tc.Env == nil { + tc.Env = taskEnv.Map() + } else { + for k, v := range taskEnv.Map() { + if _, ok := tc.Env[k]; !ok { + tc.Env[k] = v + } + } + } + + return func() { + allocDir.Destroy() + } +}