diff --git a/client/allocrunner/interfaces/task_lifecycle.go b/client/allocrunner/interfaces/task_lifecycle.go index 680c0d9c2..ee99a507b 100644 --- a/client/allocrunner/interfaces/task_lifecycle.go +++ b/client/allocrunner/interfaces/task_lifecycle.go @@ -89,7 +89,7 @@ type TaskPrestartHook interface { // Prestart is called before the task is started including after every // restart. Prestart is not called if the allocation is terminal. // - // The context is cancelled if the task is killed. + // The context is cancelled if the task is killed or shut down. Prestart(context.Context, *TaskPrestartRequest, *TaskPrestartResponse) error } diff --git a/client/allocrunner/taskrunner/task_runner_hooks.go b/client/allocrunner/taskrunner/task_runner_hooks.go index 25d1b59bc..374f29f42 100644 --- a/client/allocrunner/taskrunner/task_runner_hooks.go +++ b/client/allocrunner/taskrunner/task_runner_hooks.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/LK4D4/joincontext" multierror "github.com/hashicorp/go-multierror" "github.com/hashicorp/nomad/client/allocrunner/interfaces" "github.com/hashicorp/nomad/client/allocrunner/taskrunner/state" @@ -192,8 +193,13 @@ func (tr *TaskRunner) prestart() error { } // Run the prestart hook + // use a joined context to allow any blocking pre-start hooks + // to be canceled by either killCtx or shutdownCtx; the cancel func is + // deferred so the joined context is released when prestart returns + joinedCtx, joinedCancel := joincontext.Join(tr.killCtx, tr.shutdownCtx) + defer joinedCancel() var resp interfaces.TaskPrestartResponse - if err := pre.Prestart(tr.killCtx, &req, &resp); err != nil { + if err := pre.Prestart(joinedCtx, &req, &resp); err != nil { tr.emitHookError(err, name) return structs.WrapRecoverable(fmt.Sprintf("prestart hook %q failed: %v", name, err), err) } diff --git a/client/allocrunner/taskrunner/task_runner_test.go b/client/allocrunner/taskrunner/task_runner_test.go index 8a6ab8cf5..25124c3fd 100644 --- a/client/allocrunner/taskrunner/task_runner_test.go +++ b/client/allocrunner/taskrunner/task_runner_test.go @@ -1742,6 +1742,69 @@ func 
TestTaskRunner_Template_Artifact(t *testing.T) { require.NoErrorf(t, err, "%v not rendered", f2) } +// TestTaskRunner_Template_BlockingPreStart asserts that a template +// that fails to render in PreStart can gracefully be shut down by +// either killCtx or shutdownCtx +func TestTaskRunner_Template_BlockingPreStart(t *testing.T) { + t.Parallel() + + alloc := mock.BatchAlloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Templates = []*structs.Template{ + { + EmbeddedTmpl: `{{ with secret "foo/secret" }}{{ .Data.certificate }}{{ end }}`, + DestPath: "local/test", + ChangeMode: structs.TemplateChangeModeNoop, + }, + } + + task.Vault = &structs.Vault{Policies: []string{"default"}} + + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + defer cleanup() + + tr, err := NewTaskRunner(conf) + require.NoError(t, err) + go tr.Run() + defer tr.Shutdown() + + testutil.WaitForResult(func() (bool, error) { + ts := tr.TaskState() + + if len(ts.Events) == 0 { + return false, fmt.Errorf("no events yet") + } + + for _, e := range ts.Events { + if e.Type == "Template" && strings.Contains(e.DisplayMessage, "vault.read(foo/secret)") { + return true, nil + } + } + + return false, fmt.Errorf("no missing vault secret template event yet: %#v", ts.Events) + + }, func(err error) { + require.NoError(t, err) + }) + + shutdown := func() <-chan bool { + finished := make(chan bool, 1) // buffered: the send must not block if the select below has already timed out + go func() { + tr.Shutdown() + finished <- true + }() + + return finished + } + + select { + case <-shutdown(): + // it shut down like it should have + case <-time.After(10 * time.Second): + require.Fail(t, "timeout shutting down task") + } +} + // TestTaskRunner_Template_NewVaultToken asserts that a new vault token is // created when rendering template and that it is revoked on alloc completion func TestTaskRunner_Template_NewVaultToken(t *testing.T) {