From 9f64add14c38e11f7452262ff032d6a40ea9e669 Mon Sep 17 00:00:00 2001 From: Michael Schurter Date: Thu, 20 Sep 2018 15:44:27 -0700 Subject: [PATCH] tr: fix shutdown/destroy/WaitResult handling Multiple receivers raced for the WaitResult when killing tasks which could lead to a deadlock if the "wrong" receiver won. Wrap handlers in an ugly little proxy to avoid this. At first I wanted to push this into drivers, but the result is tied to the TR's handle lifecycle -- not the lifecycle of an alloc or task. --- .../allocrunnerv2/taskrunner/handleproxy.go | 57 ++++++++++++++ .../taskrunner/handleproxy_test.go | 76 +++++++++++++++++++ client/allocrunnerv2/taskrunner/lifecycle.go | 28 ++++--- .../allocrunnerv2/taskrunner/task_runner.go | 69 ++++++++++------- .../taskrunner/task_runner_getters.go | 9 ++- .../taskrunner/task_runner_hooks.go | 2 +- 6 files changed, 194 insertions(+), 47 deletions(-) create mode 100644 client/allocrunnerv2/taskrunner/handleproxy.go create mode 100644 client/allocrunnerv2/taskrunner/handleproxy_test.go diff --git a/client/allocrunnerv2/taskrunner/handleproxy.go b/client/allocrunnerv2/taskrunner/handleproxy.go new file mode 100644 index 000000000..88e4ab7d2 --- /dev/null +++ b/client/allocrunnerv2/taskrunner/handleproxy.go @@ -0,0 +1,57 @@ +package taskrunner + +import ( + "context" + "sync" + + "github.com/hashicorp/nomad/client/driver/structs" +) + +// handleResult multiplexes a single WaitResult to multiple waiters. Useful +// because DriverHandle.WaitCh is closed after it returns a single WaitResult. 
+type handleResult struct { + doneCh <-chan struct{} + + result *structs.WaitResult + mu sync.RWMutex +} + +func newHandleResult(waitCh <-chan *structs.WaitResult) *handleResult { + doneCh := make(chan struct{}) + + h := &handleResult{ + doneCh: doneCh, + } + + go func() { + // Wait for result + res := <-waitCh + + // Set result + h.mu.Lock() + h.result = res + h.mu.Unlock() + + // Notify waiters + close(doneCh) + + }() + + return h +} + +// Wait blocks until a task's result is available or the passed-in context is +// canceled. Safe for concurrent callers. +func (h *handleResult) Wait(ctx context.Context) *structs.WaitResult { + // Block until done or canceled + select { + case <-h.doneCh: + case <-ctx.Done(): + } + + h.mu.RLock() + res := h.result + h.mu.RUnlock() + + return res +} diff --git a/client/allocrunnerv2/taskrunner/handleproxy_test.go b/client/allocrunnerv2/taskrunner/handleproxy_test.go new file mode 100644 index 000000000..5231d5297 --- /dev/null +++ b/client/allocrunnerv2/taskrunner/handleproxy_test.go @@ -0,0 +1,76 @@ +package taskrunner + +import ( + "context" + "testing" + "time" + + "github.com/hashicorp/nomad/client/driver/structs" + "github.com/stretchr/testify/require" +) + +// TestHandleResult_Wait_Result asserts multiple waiters on a handleResult all +// receive the wait result. 
+func TestHandleResult_Wait_Result(t *testing.T) { + t.Parallel() + + waitCh := make(chan *structs.WaitResult) + h := newHandleResult(waitCh) + + outCh1 := make(chan *structs.WaitResult) + outCh2 := make(chan *structs.WaitResult) + + // Create two receivers + go func() { + outCh1 <- h.Wait(context.Background()) + }() + go func() { + outCh2 <- h.Wait(context.Background()) + }() + + // Send a single result + go func() { + waitCh <- &structs.WaitResult{ExitCode: 1} + }() + + // Assert both receivers got the result + assert := func(outCh chan *structs.WaitResult) { + select { + case result := <-outCh: + require.NotNil(t, result) + require.Equal(t, 1, result.ExitCode) + case <-time.After(time.Second): + t.Fatalf("timeout waiting for result") + } + } + + assert(outCh1) + assert(outCh2) +} + +// TestHandleResult_Wait_Cancel asserts that canceling the context unblocks the +// waiter. +func TestHandleResult_Wait_Cancel(t *testing.T) { + t.Parallel() + + waitCh := make(chan *structs.WaitResult) + h := newHandleResult(waitCh) + + ctx, cancel := context.WithCancel(context.Background()) + outCh := make(chan *structs.WaitResult) + + go func() { + outCh <- h.Wait(ctx) + }() + + // Canceling the context should unblock the Wait + cancel() + + // Assert the result is nil + select { + case result := <-outCh: + require.Nil(t, result) + case <-time.After(time.Second): + t.Fatalf("timeout waiting for result") + } +} diff --git a/client/allocrunnerv2/taskrunner/lifecycle.go b/client/allocrunnerv2/taskrunner/lifecycle.go index 848929636..4c9eabe10 100644 --- a/client/allocrunnerv2/taskrunner/lifecycle.go +++ b/client/allocrunnerv2/taskrunner/lifecycle.go @@ -7,9 +7,11 @@ import ( "github.com/hashicorp/nomad/nomad/structs" ) +// Restart a task. Returns immediately if no task is running. Blocks until +// existing task exits or passed-in context is canceled. 
func (tr *TaskRunner) Restart(ctx context.Context, event *structs.TaskEvent, failure bool) error { // Grab the handle - handle := tr.getDriverHandle() + handle, result := tr.getDriverHandle() // Check it is running if handle == nil { @@ -29,19 +31,14 @@ func (tr *TaskRunner) Restart(ctx context.Context, event *structs.TaskEvent, fai tr.logger.Error("failed to kill task. Resources may have been leaked", "error", err) } - // Drain the wait channel or wait for the request context to be cancelled - select { - case <-handle.WaitCh(): - case <-ctx.Done(): - return ctx.Err() - } - + // Drain the wait channel or wait for the request context to be canceled + result.Wait(ctx) return nil } func (tr *TaskRunner) Signal(event *structs.TaskEvent, s os.Signal) error { // Grab the handle - handle := tr.getDriverHandle() + handle, _ := tr.getDriverHandle() // Check it is running if handle == nil { @@ -58,8 +55,12 @@ func (tr *TaskRunner) Signal(event *structs.TaskEvent, s os.Signal) error { // Kill a task. Blocks until task exits or context is canceled. State is set to // dead. func (tr *TaskRunner) Kill(ctx context.Context, event *structs.TaskEvent) error { + // Cancel the task runner to break out of restart delay or the main run + // loop. + tr.ctxCancel() + // Grab the handle - handle := tr.getDriverHandle() + handle, result := tr.getDriverHandle() // Check if the handle is running if handle == nil { @@ -82,11 +83,8 @@ func (tr *TaskRunner) Kill(ctx context.Context, event *structs.TaskEvent) error tr.logger.Error("failed to kill task. Resources may have been leaked", "error", destroyErr) } - // Drain the wait channel or wait for the request context to be cancelled - select { - case <-handle.WaitCh(): - case <-ctx.Done(): - } + // Block until task has exited. + result.Wait(ctx) // Store that the task has been destroyed and any associated error. 
tr.UpdateState(structs.TaskStateDead, structs.NewTaskEvent(structs.TaskKilled).SetKillError(destroyErr)) diff --git a/client/allocrunnerv2/taskrunner/task_runner.go b/client/allocrunnerv2/taskrunner/task_runner.go index 4b5bed7be..84578621b 100644 --- a/client/allocrunnerv2/taskrunner/task_runner.go +++ b/client/allocrunnerv2/taskrunner/task_runner.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/nomad/client/consul" "github.com/hashicorp/nomad/client/driver" "github.com/hashicorp/nomad/client/driver/env" + dstructs "github.com/hashicorp/nomad/client/driver/structs" cstate "github.com/hashicorp/nomad/client/state" cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/client/vaultclient" @@ -72,8 +73,8 @@ type TaskRunner struct { // unnecessary writes persistedHash []byte - // ctx is the task runner's context and is done whe the task runner - // should exit. Shutdown hooks are run. + // ctx is the task runner's context representing the task's lifecycle. + // Canceling the context will cause the task to be destroyed. ctx context.Context // ctxCancel is used to exit the task runner's Run loop without @@ -95,8 +96,9 @@ type TaskRunner struct { // driver is the driver for the task. 
driver driver.Driver - handle driver.DriverHandle // the handle to the running driver - handleLock sync.Mutex + handle driver.DriverHandle // the handle to the running driver + handleResult *handleResult // proxy for handle results + handleLock sync.Mutex // task is the task being run task *structs.Task @@ -259,7 +261,7 @@ func (tr *TaskRunner) initLabels() { func (tr *TaskRunner) Run() { defer close(tr.waitCh) - var handle driver.DriverHandle + var waitRes *dstructs.WaitResult // Updates are handled asynchronously with the other hooks but each // triggered update - whether due to alloc updates or a new vault token @@ -291,42 +293,48 @@ MAIN: tr.logger.Error("poststart failed", "error", err) } - // Grab the handle - handle = tr.getDriverHandle() + // Grab the result proxy and wait for task to exit + { + _, result := tr.getDriverHandle() - select { - case waitRes := <-handle.WaitCh(): - // Clear the handle - tr.clearDriverHandle() - - // Store the wait result on the restart tracker - tr.restartTracker.SetWaitResult(waitRes) - case <-tr.ctx.Done(): - tr.logger.Debug("task killed") + // Do *not* use tr.ctx here as it would cause Wait() to + // unblock before the task exits when Kill() is called. 
+ waitRes = result.Wait(context.Background()) } + // Clear the handle + tr.clearDriverHandle() + + // Store the wait result on the restart tracker + tr.restartTracker.SetWaitResult(waitRes) + if err := tr.exited(); err != nil { tr.logger.Error("exited hooks failed", "error", err) } RESTART: - // Actually restart by sleeping and also watching for destroy events - restart, restartWait := tr.shouldRestart() + restart, restartDelay := tr.shouldRestart() if !restart { break MAIN } - deadline := time.Now().Add(restartWait) - timer := time.NewTimer(restartWait) - for time.Now().Before(deadline) { - select { - case <-timer.C: - case <-tr.ctx.Done(): - tr.logger.Debug("task runner cancelled") - break MAIN - } + // Actually restart by sleeping and also watching for destroy events + select { + case <-time.After(restartDelay): + case <-tr.ctx.Done(): + tr.logger.Trace("task killed between restarts", "delay", restartDelay) + break MAIN } - timer.Stop() + } + + // If task terminated, update server. All other exit conditions (eg + // killed or out of restarts) will perform their own server updates. + if waitRes != nil { + event := structs.NewTaskEvent(structs.TaskTerminated). + SetExitCode(waitRes.ExitCode). + SetSignal(waitRes.Signal). + SetExitMessage(waitRes.Err) + tr.UpdateState(structs.TaskStateDead, event) } // Run the stop hooks @@ -361,13 +369,16 @@ func (tr *TaskRunner) handleUpdates() { } } +// shouldRestart determines whether the task should be restarted and updates +// the task state unless the task is killed or terminated. func (tr *TaskRunner) shouldRestart() (bool, time.Duration) { // Determine if we should restart state, when := tr.restartTracker.GetState() reason := tr.restartTracker.GetReason() switch state { case structs.TaskKilled: - // The task was killed. Nothing to do + // Never restart an explicitly killed task. Kill method handles + // updating the server. 
return false, 0 case structs.TaskNotRestarting, structs.TaskTerminated: tr.logger.Info("not restarting task", "reason", reason) diff --git a/client/allocrunnerv2/taskrunner/task_runner_getters.go b/client/allocrunnerv2/taskrunner/task_runner_getters.go index 8a40c65bc..abce06593 100644 --- a/client/allocrunnerv2/taskrunner/task_runner_getters.go +++ b/client/allocrunnerv2/taskrunner/task_runner_getters.go @@ -49,20 +49,25 @@ func (tr *TaskRunner) setVaultToken(token string) { tr.envBuilder.SetVaultToken(token, tr.task.Vault.Env) } -func (tr *TaskRunner) getDriverHandle() driver.DriverHandle { +// getDriverHandle returns a driver handle and its result proxy. Use the +// result proxy instead of the handle's WaitCh. +func (tr *TaskRunner) getDriverHandle() (driver.DriverHandle, *handleResult) { tr.handleLock.Lock() defer tr.handleLock.Unlock() - return tr.handle + return tr.handle, tr.handleResult } +// setDriverHandle sets the driver handle and creates a new result proxy. func (tr *TaskRunner) setDriverHandle(handle driver.DriverHandle) { tr.handleLock.Lock() defer tr.handleLock.Unlock() tr.handle = handle + tr.handleResult = newHandleResult(handle.WaitCh()) } func (tr *TaskRunner) clearDriverHandle() { tr.handleLock.Lock() defer tr.handleLock.Unlock() tr.handle = nil + tr.handleResult = nil } diff --git a/client/allocrunnerv2/taskrunner/task_runner_hooks.go b/client/allocrunnerv2/taskrunner/task_runner_hooks.go index 4e3d550f7..14caba832 100644 --- a/client/allocrunnerv2/taskrunner/task_runner_hooks.go +++ b/client/allocrunnerv2/taskrunner/task_runner_hooks.go @@ -168,7 +168,7 @@ func (tr *TaskRunner) poststart() error { }() } - handle := tr.getDriverHandle() + handle, _ := tr.getDriverHandle() net := handle.Network() var merr multierror.Error