From c107b5fd2116a7c8048e14712a560424efbd951a Mon Sep 17 00:00:00 2001
From: Tim Gross
Date: Thu, 27 Apr 2023 11:54:10 -0400
Subject: [PATCH] testing: improve fidelity of mock driver task restore (#16990)

While working on client status update improvements, I encountered problems
getting tests with the mock driver to restore correctly. Unlike typical
drivers, the mock driver has no external source of truth for whether the task
is running (e.g. API calls to `dockerd` or a running PID to look up), so it
reconstructs that information by re-parsing the original task config. But the
task runner only calls the encoding step for `StartTask`, not `RecoverTask`,
so the task config the mock driver gets back is missing data.

Update the mock driver to stash the "external" state in the task state we get
from the task runner, so that we no longer have to recover from the original
`TaskConfig`. This brings the mock driver closer to the behavior of the other
drivers.
---
 drivers/mock/driver.go      |  85 +++++++++++------
 drivers/mock/driver_test.go | 177 ++++++++++++++++++++++++++++++++++++
 2 files changed, 233 insertions(+), 29 deletions(-)
 create mode 100644 drivers/mock/driver_test.go

diff --git a/drivers/mock/driver.go b/drivers/mock/driver.go
index 390523dff..8acb0a1a8 100644
--- a/drivers/mock/driver.go
+++ b/drivers/mock/driver.go
@@ -266,6 +266,15 @@ type TaskConfig struct {
 
 type MockTaskState struct {
 	StartedAt time.Time
+
+	// these are not strictly "state" but because there's no external
+	// reattachment we need somewhere to stash this config so we can properly
+	// restore mock tasks
+	Command         Command
+	ExecCommand     *Command
+	PluginExitAfter time.Duration
+	KillAfter       time.Duration
+	ProcState       drivers.TaskState
 }
 
 func (d *Driver) PluginInfo() (*base.PluginInfoResponse, error) {
@@ -358,21 +367,39 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error {
 		return fmt.Errorf("failed to decode task state from handle: %v", err)
 	}
 
-	driverCfg, err := parseDriverConfig(handle.Config)
-	if err != nil {
-		d.logger.Error("failed to parse driver config from handle", "error", err, "task_id", handle.Config.ID, "config", hclog.Fmt("%+v", handle.Config))
-		return fmt.Errorf("failed to parse driver config from handle: %v", err)
+	taskState.Command.parseDurations()
+	if taskState.ExecCommand != nil {
+		taskState.ExecCommand.parseDurations()
 	}
 
-	// Remove the plugin exit time if set
-	driverCfg.pluginExitAfterDuration = 0
-
 	// Correct the run_for time based on how long it has already been running
 	now := time.Now()
-	driverCfg.runForDuration = driverCfg.runForDuration - now.Sub(taskState.StartedAt)
+	if !taskState.StartedAt.IsZero() {
+		taskState.Command.runForDuration = taskState.Command.runForDuration - now.Sub(taskState.StartedAt)
+
+		if taskState.ExecCommand != nil {
+			taskState.ExecCommand.runForDuration = taskState.ExecCommand.runForDuration - now.Sub(taskState.StartedAt)
+		}
+	}
+
+	// Recreate the taskHandle. Because there's no real running process, we'll
+	// assume we're still running if we've recovered it at all.
+	killCtx, killCancel := context.WithCancel(context.Background())
+	h := &taskHandle{
+		logger:          d.logger.With("task_name", handle.Config.Name),
+		pluginExitAfter: taskState.PluginExitAfter,
+		killAfter:       taskState.KillAfter,
+		waitCh:          make(chan any),
+		taskConfig:      handle.Config,
+		command:         taskState.Command,
+		execCommand:     taskState.ExecCommand,
+		procState:       drivers.TaskStateRunning,
+		startedAt:       taskState.StartedAt,
+		kill:            killCancel,
+		killCh:          killCtx.Done(),
+		Recovered:       true,
+	}
 
-	h := newTaskHandle(handle.Config, driverCfg, d.logger)
-	h.Recovered = true
 	d.tasks.Set(handle.Config.ID, h)
 	go h.run()
 	return nil
@@ -423,23 +450,6 @@ func parseDriverConfig(cfg *drivers.TaskConfig) (*TaskConfig, error) {
 	return &driverConfig, nil
 }
 
-func newTaskHandle(cfg *drivers.TaskConfig, driverConfig *TaskConfig, logger hclog.Logger) *taskHandle {
-	killCtx, killCancel := context.WithCancel(context.Background())
-	h := &taskHandle{
-		taskConfig:      cfg,
-		command:         driverConfig.Command,
-		execCommand:     driverConfig.ExecCommand,
-		pluginExitAfter: driverConfig.pluginExitAfterDuration,
-		killAfter:       driverConfig.killAfterDuration,
-		logger:          logger.With("task_name", cfg.Name),
-		waitCh:          make(chan interface{}),
-		killCh:          killCtx.Done(),
-		kill:            killCancel,
-		startedAt:       time.Now(),
-	}
-	return h
-}
-
 func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drivers.DriverNetwork, error) {
 	driverConfig, err := parseDriverConfig(cfg)
 	if err != nil {
@@ -477,9 +487,26 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
 		net.PortMap = map[string]int{parts[0]: port}
 	}
 
-	h := newTaskHandle(cfg, driverConfig, d.logger)
+	killCtx, killCancel := context.WithCancel(context.Background())
+	h := &taskHandle{
+		taskConfig:      cfg,
+		command:         driverConfig.Command,
+		execCommand:     driverConfig.ExecCommand,
+		pluginExitAfter: driverConfig.pluginExitAfterDuration,
+		killAfter:       driverConfig.killAfterDuration,
+		logger:          d.logger.With("task_name", cfg.Name),
+		waitCh:          make(chan interface{}),
+		killCh:          killCtx.Done(),
+		kill:            killCancel,
+		startedAt:       time.Now(),
+	}
+
 	driverState := MockTaskState{
-		StartedAt: h.startedAt,
+		StartedAt:       h.startedAt,
+		Command:         driverConfig.Command,
+		ExecCommand:     driverConfig.ExecCommand,
+		PluginExitAfter: driverConfig.pluginExitAfterDuration,
+		KillAfter:       driverConfig.killAfterDuration,
 	}
 	handle := drivers.NewTaskHandle(taskHandleVersion)
 	handle.Config = cfg
diff --git a/drivers/mock/driver_test.go b/drivers/mock/driver_test.go
new file mode 100644
index 000000000..f12155705
--- /dev/null
+++ b/drivers/mock/driver_test.go
@@ -0,0 +1,177 @@
+// Copyright (c) HashiCorp, Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+package mock
+
+import (
+	"context"
+	"os"
+	"sync"
+	"testing"
+	"time"
+
+	hclog "github.com/hashicorp/go-hclog"
+	"github.com/shoenig/test/must"
+	"github.com/shoenig/test/wait"
+
+	"github.com/hashicorp/nomad/ci"
+	"github.com/hashicorp/nomad/client/allocdir"
+	"github.com/hashicorp/nomad/client/config"
+	"github.com/hashicorp/nomad/client/taskenv"
+	"github.com/hashicorp/nomad/helper/testlog"
+	"github.com/hashicorp/nomad/helper/testtask"
+	"github.com/hashicorp/nomad/helper/uuid"
+	"github.com/hashicorp/nomad/nomad/mock"
+	"github.com/hashicorp/nomad/nomad/structs"
+	basePlug "github.com/hashicorp/nomad/plugins/base"
+	"github.com/hashicorp/nomad/plugins/drivers"
+	dtestutil "github.com/hashicorp/nomad/plugins/drivers/testutils"
+)
+
+func TestMockDriver_StartWaitRecoverWaitStop(t *testing.T) {
+	ci.Parallel(t)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	t.Cleanup(cancel)
+
+	logger := testlog.HCLogger(t)
+	d := NewMockDriver(ctx, logger).(*Driver)
+	harness := dtestutil.NewDriverHarness(t, d)
+	defer harness.Kill()
+
+	var data []byte
+	must.NoError(t, basePlug.MsgPackEncode(&data, &Config{}))
+	bconfig := &basePlug.Config{PluginConfig: data}
+	must.NoError(t, harness.SetConfig(bconfig))
+
+	task := &drivers.TaskConfig{
+		AllocID: uuid.Generate(),
+		ID:      uuid.Generate(),
+		Name:    "sleep",
+		Env:     map[string]string{},
+	}
+	tc := &TaskConfig{
+		Command: Command{
+			RunFor:         "10s",
+			runForDuration: time.Second * 10,
+		},
+		PluginExitAfter:         "30s",
+		pluginExitAfterDuration: time.Second * 30,
+	}
+	must.NoError(t, task.EncodeConcreteDriverConfig(&tc))
+
+	testtask.SetTaskConfigEnv(task)
+	cleanup := mkTestAllocDir(t, harness, logger, task)
+	t.Cleanup(cleanup)
+
+	handle, _, err := harness.StartTask(task)
+	must.NoError(t, err)
+
+	ch, err := harness.WaitTask(context.Background(), task.ID)
+	must.NoError(t, err)
+
+	var waitDone bool
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		result := <-ch
+		must.Error(t, result.Err)
+		waitDone = true
+	}()
+
+	originalStatus, err := d.InspectTask(task.ID)
+	must.NoError(t, err)
+
+	d.tasks.Delete(task.ID)
+
+	wg.Wait()
+	must.True(t, waitDone)
+	_, err = d.InspectTask(task.ID)
+	must.Eq(t, drivers.ErrTaskNotFound, err)
+
+	err = d.RecoverTask(handle)
+	must.NoError(t, err)
+
+	// need to make sure the task is left running and doesn't just immediately
+	// exit after we recover it
+	must.Wait(t, wait.ContinualSuccess(
+		wait.BoolFunc(func() bool {
+			status, err := d.InspectTask(task.ID)
+			must.NoError(t, err)
+			return status.State == "running"
+		}),
+		wait.Timeout(1*time.Second),
+		wait.Gap(100*time.Millisecond),
+	))
+
+	status, err := d.InspectTask(task.ID)
+	must.NoError(t, err)
+	must.Eq(t, originalStatus, status)
+
+	ch, err = harness.WaitTask(context.Background(), task.ID)
+	must.NoError(t, err)
+
+	wg.Add(1)
+	waitDone = false
+	go func() {
+		defer wg.Done()
+		result := <-ch
+		must.NoError(t, result.Err)
+		must.Zero(t, result.ExitCode)
+		waitDone = true
+	}()
+
+	time.Sleep(300 * time.Millisecond)
+	must.NoError(t, d.StopTask(task.ID, 0, "SIGKILL"))
+	wg.Wait()
+	must.NoError(t, d.DestroyTask(task.ID, false))
+	must.True(t, waitDone)
+}
+
+func mkTestAllocDir(t *testing.T, h *dtestutil.DriverHarness, logger hclog.Logger, tc *drivers.TaskConfig) func() {
+	dir, err := os.MkdirTemp("", "nomad_driver_harness-")
+	must.NoError(t, err)
+
+	allocDir := allocdir.NewAllocDir(logger, dir, tc.AllocID)
+	must.NoError(t, allocDir.Build())
+
+	tc.AllocDir = allocDir.AllocDir
+
+	taskDir := allocDir.NewTaskDir(tc.Name)
+	must.NoError(t, taskDir.Build(false, ci.TinyChroot))
+
+	task := &structs.Task{
+		Name: tc.Name,
+		Env:  tc.Env,
+	}
+
+	// no logging
+	tc.StdoutPath = os.DevNull
+	tc.StderrPath = os.DevNull
+
+	// Create the mock allocation
+	alloc := mock.Alloc()
+	alloc.ID = tc.AllocID
+	if tc.Resources != nil {
+		alloc.AllocatedResources.Tasks[task.Name] = tc.Resources.NomadResources
+	}
+
+	taskBuilder := taskenv.NewBuilder(mock.Node(), alloc, task, "global")
+	dtestutil.SetEnvvars(taskBuilder, drivers.FSIsolationNone, taskDir, config.DefaultConfig())
+
+	taskEnv := taskBuilder.Build()
+	if tc.Env == nil {
+		tc.Env = taskEnv.Map()
+	} else {
+		for k, v := range taskEnv.Map() {
+			if _, ok := tc.Env[k]; !ok {
+				tc.Env[k] = v
+			}
+		}
+	}
+
+	return func() {
+		allocDir.Destroy()
+	}
+}
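
Reviewer note (not part of the patch): the change relies on the task-handle
round-trip that Nomad task drivers use: whatever the driver stores with
`SetDriverState` at `StartTask` time is what `GetDriverState` hands back in
`RecoverTask` after a client restart. The sketch below shows that round-trip
in isolation, assuming only a start time and run duration need to survive;
`exampleState`, the literal handle version, and the `main` wrapper are
illustrative, not code from this patch.

```go
package main

import (
	"fmt"
	"time"

	"github.com/hashicorp/nomad/plugins/drivers"
)

// exampleState plays the role of MockTaskState: everything the driver needs
// to rebuild its in-memory handle without an external source of truth.
type exampleState struct {
	StartedAt time.Time
	RunFor    time.Duration
}

func main() {
	// StartTask side: encode the recovery data into the handle that the
	// task runner persists on the driver's behalf.
	handle := drivers.NewTaskHandle(1) // version 1 is an arbitrary example
	if err := handle.SetDriverState(&exampleState{
		StartedAt: time.Now(),
		RunFor:    10 * time.Second,
	}); err != nil {
		panic(err)
	}

	// RecoverTask side: decode the stashed state from the handle the task
	// runner gives back, instead of re-parsing the original task config.
	var restored exampleState
	if err := handle.GetDriverState(&restored); err != nil {
		panic(err)
	}
	fmt.Println("task originally started at", restored.StartedAt)
}
```

Because `RecoverTask` never sees the re-encoded driver config, anything not
captured this way (the run/exit durations, the exec command) previously had
to be guessed by re-parsing `handle.Config`, which is the path this patch
removes.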