diff --git a/client/allocrunner/alloc_runner.go b/client/allocrunner/alloc_runner.go
index 00515439a..398832e31 100644
--- a/client/allocrunner/alloc_runner.go
+++ b/client/allocrunner/alloc_runner.go
@@ -146,6 +146,8 @@ type allocRunner struct {
 	// servers have been contacted for the first time in case of a failed
 	// restore.
 	serversContactedCh chan struct{}
+
+	taskHookCoordinator *taskHookCoordinator
 }
 
 // NewAllocRunner returns a new allocation runner.
@@ -190,6 +192,8 @@ func NewAllocRunner(config *Config) (*allocRunner, error) {
 	// Create alloc dir
 	ar.allocDir = allocdir.NewAllocDir(ar.logger, filepath.Join(config.ClientConfig.AllocDir, alloc.ID))
 
+	ar.taskHookCoordinator = newTaskHookCoordinator(ar.logger, tg.Tasks)
+
 	// Initialize the runners hooks.
 	if err := ar.initRunnerHooks(config.ClientConfig); err != nil {
 		return nil, err
@@ -207,20 +211,21 @@ func NewAllocRunner(config *Config) (*allocRunner, error) {
 func (ar *allocRunner) initTaskRunners(tasks []*structs.Task) error {
 	for _, task := range tasks {
 		config := &taskrunner.Config{
-			Alloc:               ar.alloc,
-			ClientConfig:        ar.clientConfig,
-			Task:                task,
-			TaskDir:             ar.allocDir.NewTaskDir(task.Name),
-			Logger:              ar.logger,
-			StateDB:             ar.stateDB,
-			StateUpdater:        ar,
-			Consul:              ar.consulClient,
-			ConsulSI:            ar.sidsClient,
-			Vault:               ar.vaultClient,
-			DeviceStatsReporter: ar.deviceStatsReporter,
-			DeviceManager:       ar.devicemanager,
-			DriverManager:       ar.driverManager,
-			ServersContactedCh:  ar.serversContactedCh,
+			Alloc:                ar.alloc,
+			ClientConfig:         ar.clientConfig,
+			Task:                 task,
+			TaskDir:              ar.allocDir.NewTaskDir(task.Name),
+			Logger:               ar.logger,
+			StateDB:              ar.stateDB,
+			StateUpdater:         ar,
+			Consul:               ar.consulClient,
+			ConsulSI:             ar.sidsClient,
+			Vault:                ar.vaultClient,
+			DeviceStatsReporter:  ar.deviceStatsReporter,
+			DeviceManager:        ar.devicemanager,
+			DriverManager:        ar.driverManager,
+			ServersContactedCh:   ar.serversContactedCh,
+			StartConditionMetCtx: ar.taskHookCoordinator.startConditionForTask(task),
 		}
 
 		// Create, but do not Run, the task runner
@@ -488,6 +493,8 @@ func (ar *allocRunner) handleTaskStateUpdates() {
 			}
 		}
 
+		ar.taskHookCoordinator.taskStateUpdated(states)
+
 		// Get the client allocation
 		calloc := ar.clientAlloc(states)
 
diff --git a/client/allocrunner/task_hook_coordinator.go b/client/allocrunner/task_hook_coordinator.go
new file mode 100644
index 000000000..8a7b17542
--- /dev/null
+++ b/client/allocrunner/task_hook_coordinator.go
@@ -0,0 +1,103 @@
+package allocrunner
+
+import (
+	"context"
+	"fmt"
+
+	hclog "github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/nomad/nomad/structs"
+)
+
+// taskHookCoordinator helps coordinate when main tasks can launch, namely after
+// all prestart tasks have started and all block-until-completed prestart tasks have completed.
+type taskHookCoordinator struct {
+	logger hclog.Logger
+
+	closedCh chan struct{}
+
+	mainTaskCtx       context.Context
+	mainTaskCtxCancel func()
+
+	prestartTasksUntilRunning   map[string]struct{}
+	prestartTasksUntilCompleted map[string]struct{}
+}
+
+func newTaskHookCoordinator(logger hclog.Logger, tasks []*structs.Task) *taskHookCoordinator {
+	closedCh := make(chan struct{})
+	close(closedCh)
+
+	mainTaskCtx, cancelFn := context.WithCancel(context.Background())
+
+	c := &taskHookCoordinator{
+		logger:                      logger,
+		closedCh:                    closedCh,
+		mainTaskCtx:                 mainTaskCtx,
+		mainTaskCtxCancel:           cancelFn,
+		prestartTasksUntilRunning:   map[string]struct{}{},
+		prestartTasksUntilCompleted: map[string]struct{}{},
+	}
+	c.setTasks(tasks)
+	return c
+}
+
+func (c *taskHookCoordinator) setTasks(tasks []*structs.Task) {
+	for _, task := range tasks {
+		if task.Lifecycle == nil || task.Lifecycle.Hook != structs.TaskLifecycleHookPrestart {
+			// not a prestart lifecycle task; nothing to coordinate
+			continue
+		}
+
+		// track the condition this prestart task must satisfy
+		switch task.Lifecycle.BlockUntil {
+		case "", structs.TaskLifecycleBlockUntilRunning:
+			c.prestartTasksUntilRunning[task.Name] = struct{}{}
+		case structs.TaskLifecycleBlockUntilCompleted:
+			c.prestartTasksUntilCompleted[task.Name] = struct{}{}
+		default:
+			panic(fmt.Sprintf("unexpected block until value: %v", task.Lifecycle.BlockUntil))
+		}
+	}
+
+	if len(c.prestartTasksUntilRunning)+len(c.prestartTasksUntilCompleted) == 0 {
+		c.mainTaskCtxCancel()
+	}
+}
+
+func (c *taskHookCoordinator) startConditionForTask(task *structs.Task) <-chan struct{} {
+	if task.Lifecycle != nil && task.Lifecycle.Hook == structs.TaskLifecycleHookPrestart {
+		return c.closedCh
+	}
+
+	return c.mainTaskCtx.Done()
+
+}
+
+func (c *taskHookCoordinator) taskStateUpdated(states map[string]*structs.TaskState) {
+	if c.mainTaskCtx.Err() != nil {
+		// main tasks are already allowed to start; nothing to do
+		return
+	}
+
+	for task := range c.prestartTasksUntilRunning {
+		st := states[task]
+		if st == nil || st.StartedAt.IsZero() {
+			continue
+		}
+
+		delete(c.prestartTasksUntilRunning, task)
+	}
+
+	for task := range c.prestartTasksUntilCompleted {
+		st := states[task]
+		if st == nil || !st.Successful() {
+			continue
+		}
+
+		delete(c.prestartTasksUntilCompleted, task)
+	}
+
+	// all prestart conditions are met; unblock the main tasks
+	if len(c.prestartTasksUntilRunning)+len(c.prestartTasksUntilCompleted) == 0 {
+		c.mainTaskCtxCancel()
+	}
+}
diff --git a/client/allocrunner/task_hook_coordinator_test.go b/client/allocrunner/task_hook_coordinator_test.go
new file mode 100644
index 000000000..bcc68af2d
--- /dev/null
+++ b/client/allocrunner/task_hook_coordinator_test.go
@@ -0,0 +1,26 @@
+package allocrunner
+
+import (
+	"testing"
+
+	"github.com/hashicorp/nomad/helper/testlog"
+	"github.com/hashicorp/nomad/nomad/mock"
+	"github.com/stretchr/testify/require"
+)
+
+func TestTaskHookCoordinator_OnlyMainApp(t *testing.T) {
+	alloc := mock.Alloc()
+	tasks := alloc.Job.TaskGroups[0].Tasks
+	logger := testlog.HCLogger(t)
+
+	coord := newTaskHookCoordinator(logger, tasks)
+
+	ch := coord.startConditionForTask(tasks[0])
+
+	select {
+	case _, ok := <-ch:
+		require.False(t, ok)
+	default:
+		require.Fail(t, "channel wasn't closed")
+	}
+}
diff --git a/client/allocrunner/taskrunner/task_runner.go b/client/allocrunner/taskrunner/task_runner.go
index 2e4e09445..e98015b1e 100644
--- a/client/allocrunner/taskrunner/task_runner.go
+++ b/client/allocrunner/taskrunner/task_runner.go
@@ -202,6 +202,9 @@ type TaskRunner struct {
 	// GetClientAllocs has been called in case of a failed restore.
 	serversContactedCh <-chan struct{}
 
+	// startConditionMetCtx is closed when the task runner may start the task
+	startConditionMetCtx <-chan struct{}
+
 	// waitOnServers defaults to false but will be set true if a restore
 	// fails and the Run method should wait until serversContactedCh is
 	// closed.
@@ -247,6 +250,9 @@ type Config struct {
 	// ServersContactedCh is closed when the first GetClientAllocs call to
 	// servers succeeds and allocs are synced.
 	ServersContactedCh chan struct{}
+
+	// StartConditionMetCtx is closed when the task runner may start the task
+	StartConditionMetCtx <-chan struct{}
 }
 
 func NewTaskRunner(config *Config) (*TaskRunner, error) {
@@ -271,32 +277,33 @@ func NewTaskRunner(config *Config) (*TaskRunner, error) {
 	}
 
 	tr := &TaskRunner{
-		alloc:               config.Alloc,
-		allocID:             config.Alloc.ID,
-		clientConfig:        config.ClientConfig,
-		task:                config.Task,
-		taskDir:             config.TaskDir,
-		taskName:            config.Task.Name,
-		taskLeader:          config.Task.Leader,
-		envBuilder:          envBuilder,
-		consulClient:        config.Consul,
-		siClient:            config.ConsulSI,
-		vaultClient:         config.Vault,
-		state:               tstate,
-		localState:          state.NewLocalState(),
-		stateDB:             config.StateDB,
-		stateUpdater:        config.StateUpdater,
-		deviceStatsReporter: config.DeviceStatsReporter,
-		killCtx:             killCtx,
-		killCtxCancel:       killCancel,
-		shutdownCtx:         trCtx,
-		shutdownCtxCancel:   trCancel,
-		triggerUpdateCh:     make(chan struct{}, triggerUpdateChCap),
-		waitCh:              make(chan struct{}),
-		devicemanager:       config.DeviceManager,
-		driverManager:       config.DriverManager,
-		maxEvents:           defaultMaxEvents,
-		serversContactedCh:  config.ServersContactedCh,
+		alloc:                config.Alloc,
+		allocID:              config.Alloc.ID,
+		clientConfig:         config.ClientConfig,
+		task:                 config.Task,
+		taskDir:              config.TaskDir,
+		taskName:             config.Task.Name,
+		taskLeader:           config.Task.Leader,
+		envBuilder:           envBuilder,
+		consulClient:         config.Consul,
+		siClient:             config.ConsulSI,
+		vaultClient:          config.Vault,
+		state:                tstate,
+		localState:           state.NewLocalState(),
+		stateDB:              config.StateDB,
+		stateUpdater:         config.StateUpdater,
+		deviceStatsReporter:  config.DeviceStatsReporter,
+		killCtx:              killCtx,
+		killCtxCancel:        killCancel,
+		shutdownCtx:          trCtx,
+		shutdownCtxCancel:    trCancel,
+		triggerUpdateCh:      make(chan struct{}, triggerUpdateChCap),
+		waitCh:               make(chan struct{}),
+		devicemanager:        config.DeviceManager,
+		driverManager:        config.DriverManager,
+		maxEvents:            defaultMaxEvents,
+		serversContactedCh:   config.ServersContactedCh,
+		startConditionMetCtx: config.StartConditionMetCtx,
 	}
 
 	// Create the logger based on the allocation ID
@@ -454,6 +461,15 @@ func (tr *TaskRunner) Run() {
 		}
 	}
 
+	select {
+	case <-tr.startConditionMetCtx:
+		// start condition met; proceed to the main loop
+	case <-tr.killCtx.Done():
+		return
+	case <-tr.shutdownCtx.Done():
+		return
+	}
+
 MAIN:
 	for !tr.Alloc().TerminalStatus() {
 		select {
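As a rough illustration of the intended behaviour, the sketch below extends the coordinator test with a case where a block-until-completed prestart task gates a main task: the main task's start channel stays open until taskStateUpdated reports the prestart task as successfully completed. It is not part of the patch above. It assumes the lifecycle configuration referenced as task.Lifecycle is a structs.TaskLifecycle type with Hook and BlockUntil fields (the exact type name is not shown in this diff), and that a successful structs.TaskState is one whose State is TaskStateDead with Failed left false.

package allocrunner

import (
	"testing"
	"time"

	"github.com/hashicorp/nomad/helper/testlog"
	"github.com/hashicorp/nomad/nomad/mock"
	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/stretchr/testify/require"
)

// Hypothetical test sketch; structs.TaskLifecycle is an assumed name for the
// lifecycle config type referenced by task.Lifecycle in the coordinator.
func TestTaskHookCoordinator_PrestartBlocksMain(t *testing.T) {
	alloc := mock.Alloc()
	mainTask := alloc.Job.TaskGroups[0].Tasks[0]

	// Derive a prestart task that must complete before main tasks may start.
	initTask := mainTask.Copy()
	initTask.Name = "init"
	initTask.Lifecycle = &structs.TaskLifecycle{ // assumed type name
		Hook:       structs.TaskLifecycleHookPrestart,
		BlockUntil: structs.TaskLifecycleBlockUntilCompleted,
	}

	coord := newTaskHookCoordinator(testlog.HCLogger(t), []*structs.Task{initTask, mainTask})
	mainCh := coord.startConditionForTask(mainTask)

	// The main task must remain blocked while the prestart task is pending.
	select {
	case <-mainCh:
		require.Fail(t, "main task unblocked before prestart task completed")
	default:
	}

	// Report the prestart task as having completed successfully
	// (dead and not failed), which should release the main task.
	coord.taskStateUpdated(map[string]*structs.TaskState{
		initTask.Name: {
			State:     structs.TaskStateDead,
			Failed:    false,
			StartedAt: time.Now(),
		},
	})

	select {
	case <-mainCh:
	default:
		require.Fail(t, "main task still blocked after prestart task completed")
	}
}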