tr: fix shutdown/destroy/WaitResult handling

Multiple receivers raced for the WaitResult when killing tasks, which
could lead to a deadlock if the "wrong" receiver won.

Wrap handles in an ugly little proxy to avoid this. At first I wanted
to push this into the drivers, but the result is tied to the task
runner's handle lifecycle -- not the lifecycle of an alloc or task.
Michael Schurter
2018-09-20 15:44:27 -07:00
parent 13f47aa521
commit 9f64add14c
6 changed files with 194 additions and 47 deletions
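
For context, the race being fixed is the classic one-shot channel problem: DriverHandle.WaitCh delivers a single result and is then closed, so when the run loop and a kill/restart path both receive from it, only one of them sees the real exit result. A minimal, self-contained Go sketch of that behavior (WaitResult here is a stand-in, not the Nomad struct):

package main

import (
	"fmt"
	"sync"
)

// WaitResult is a stand-in for the driver's wait result type.
type WaitResult struct{ ExitCode int }

func main() {
	// waitCh mimics DriverHandle.WaitCh: it delivers exactly one
	// result and is then closed.
	waitCh := make(chan *WaitResult)
	go func() {
		waitCh <- &WaitResult{ExitCode: 1}
		close(waitCh)
	}()

	var wg sync.WaitGroup
	for _, name := range []string{"run loop", "kill path"} {
		wg.Add(1)
		go func(name string) {
			defer wg.Done()
			// Only one receiver observes the real result; the other
			// gets nil once the channel is closed.
			fmt.Println(name, "saw:", <-waitCh)
		}(name)
	}
	wg.Wait()
}

The handleResult proxy added below avoids this by receiving from WaitCh exactly once, caching the value, and broadcasting completion by closing a separate done channel, so any number of Wait calls observe the same result.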

View File

@@ -0,0 +1,57 @@
package taskrunner
import (
"context"
"sync"
"github.com/hashicorp/nomad/client/driver/structs"
)
// handleResult multiplexes a single WaitResult to multiple waiters. Useful
// because DriverHandle.WaitCh is closed after it returns a single WaitResult.
type handleResult struct {
doneCh <-chan struct{}
result *structs.WaitResult
mu sync.RWMutex
}
func newHandleResult(waitCh <-chan *structs.WaitResult) *handleResult {
doneCh := make(chan struct{})
h := &handleResult{
doneCh: doneCh,
}
go func() {
// Wait for result
res := <-waitCh
// Set result
h.mu.Lock()
h.result = res
h.mu.Unlock()
// Notify waiters
close(doneCh)
}()
return h
}
// Wait blocks until a task's result is available or the passed-in context is
// canceled. Safe for concurrent callers.
func (h *handleResult) Wait(ctx context.Context) *structs.WaitResult {
// Block until done or canceled
select {
case <-h.doneCh:
case <-ctx.Done():
}
h.mu.RLock()
res := h.result
h.mu.RUnlock()
return res
}

View File

@@ -0,0 +1,76 @@
package taskrunner
import (
"context"
"testing"
"time"
"github.com/hashicorp/nomad/client/driver/structs"
"github.com/stretchr/testify/require"
)
// TestHandleResult_Wait_Result asserts multiple waiters on a handleResult all
// receive the wait result.
func TestHandleResult_Wait_Result(t *testing.T) {
t.Parallel()
waitCh := make(chan *structs.WaitResult)
h := newHandleResult(waitCh)
outCh1 := make(chan *structs.WaitResult)
outCh2 := make(chan *structs.WaitResult)
// Create two receivers
go func() {
outCh1 <- h.Wait(context.Background())
}()
go func() {
outCh2 <- h.Wait(context.Background())
}()
// Send a single result
go func() {
waitCh <- &structs.WaitResult{ExitCode: 1}
}()
// Assert both receivers got the result
assert := func(outCh chan *structs.WaitResult) {
select {
case result := <-outCh:
require.NotNil(t, result)
require.Equal(t, 1, result.ExitCode)
case <-time.After(time.Second):
t.Fatalf("timeout waiting for result")
}
}
assert(outCh1)
assert(outCh2)
}
// TestHandleResult_Wait_Cancel asserts that canceling the context unblocks the
// waiter.
func TestHandleResult_Wait_Cancel(t *testing.T) {
t.Parallel()
waitCh := make(chan *structs.WaitResult)
h := newHandleResult(waitCh)
ctx, cancel := context.WithCancel(context.Background())
outCh := make(chan *structs.WaitResult)
go func() {
outCh <- h.Wait(ctx)
}()
// Canceling the context should unblock the Wait
cancel()
// Assert the result is nil
select {
case result := <-outCh:
require.Nil(t, result)
case <-time.After(time.Second):
t.Fatalf("timeout waiting for result")
}
}

View File

@@ -7,9 +7,11 @@ import (
"github.com/hashicorp/nomad/nomad/structs"
)
// Restart a task. Returns immediately if no task is running. Blocks until
// existing task exits or passed-in context is canceled.
func (tr *TaskRunner) Restart(ctx context.Context, event *structs.TaskEvent, failure bool) error {
// Grab the handle
handle := tr.getDriverHandle()
handle, result := tr.getDriverHandle()
// Check it is running
if handle == nil {
@@ -29,19 +31,14 @@ func (tr *TaskRunner) Restart(ctx context.Context, event *structs.TaskEvent, fai
tr.logger.Error("failed to kill task. Resources may have been leaked", "error", err)
}
// Drain the wait channel or wait for the request context to be cancelled
select {
case <-handle.WaitCh():
case <-ctx.Done():
return ctx.Err()
}
// Drain the wait channel or wait for the request context to be canceled
result.Wait(ctx)
return nil
}
func (tr *TaskRunner) Signal(event *structs.TaskEvent, s os.Signal) error {
// Grab the handle
handle := tr.getDriverHandle()
handle, _ := tr.getDriverHandle()
// Check it is running
if handle == nil {
@@ -58,8 +55,12 @@ func (tr *TaskRunner) Signal(event *structs.TaskEvent, s os.Signal) error {
// Kill a task. Blocks until task exits or context is canceled. State is set to
// dead.
func (tr *TaskRunner) Kill(ctx context.Context, event *structs.TaskEvent) error {
// Cancel the task runner to break out of restart delay or the main run
// loop.
tr.ctxCancel()
// Grab the handle
handle := tr.getDriverHandle()
handle, result := tr.getDriverHandle()
// Check if the handle is running
if handle == nil {
@@ -82,11 +83,8 @@ func (tr *TaskRunner) Kill(ctx context.Context, event *structs.TaskEvent) error
tr.logger.Error("failed to kill task. Resources may have been leaked", "error", destroyErr)
}
// Drain the wait channel or wait for the request context to be cancelled
select {
case <-handle.WaitCh():
case <-ctx.Done():
}
// Block until task has exited.
result.Wait(ctx)
// Store that the task has been destroyed and any associated error.
tr.UpdateState(structs.TaskStateDead, structs.NewTaskEvent(structs.TaskKilled).SetKillError(destroyErr))

View File

@@ -16,6 +16,7 @@ import (
"github.com/hashicorp/nomad/client/consul"
"github.com/hashicorp/nomad/client/driver"
"github.com/hashicorp/nomad/client/driver/env"
dstructs "github.com/hashicorp/nomad/client/driver/structs"
cstate "github.com/hashicorp/nomad/client/state"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/client/vaultclient"
@@ -72,8 +73,8 @@ type TaskRunner struct {
// unnecessary writes
persistedHash []byte
// ctx is the task runner's context and is done when the task runner
// should exit. Shutdown hooks are run.
// ctx is the task runner's context representing the tasks's lifecycle.
// Canceling the context will cause the task to be destroyed.
ctx context.Context
// ctxCancel is used to exit the task runner's Run loop without
@@ -95,8 +96,9 @@ type TaskRunner struct {
// driver is the driver for the task.
driver driver.Driver
handle driver.DriverHandle // the handle to the running driver
handleLock sync.Mutex
handle driver.DriverHandle // the handle to the running driver
handleResult *handleResult // proxy for handle results
handleLock sync.Mutex
// task is the task being run
task *structs.Task
@@ -259,7 +261,7 @@ func (tr *TaskRunner) initLabels() {
func (tr *TaskRunner) Run() {
defer close(tr.waitCh)
var handle driver.DriverHandle
var waitRes *dstructs.WaitResult
// Updates are handled asynchronously with the other hooks but each
// triggered update - whether due to alloc updates or a new vault token
@@ -291,42 +293,48 @@ MAIN:
tr.logger.Error("poststart failed", "error", err)
}
// Grab the handle
handle = tr.getDriverHandle()
// Grab the result proxy and wait for task to exit
{
_, result := tr.getDriverHandle()
select {
case waitRes := <-handle.WaitCh():
// Clear the handle
tr.clearDriverHandle()
// Store the wait result on the restart tracker
tr.restartTracker.SetWaitResult(waitRes)
case <-tr.ctx.Done():
tr.logger.Debug("task killed")
// Do *not* use tr.ctx here as it would cause Wait() to
// unblock before the task exits when Kill() is called.
waitRes = result.Wait(context.Background())
}
// Clear the handle
tr.clearDriverHandle()
// Store the wait result on the restart tracker
tr.restartTracker.SetWaitResult(waitRes)
if err := tr.exited(); err != nil {
tr.logger.Error("exited hooks failed", "error", err)
}
RESTART:
// Actually restart by sleeping and also watching for destroy events
restart, restartWait := tr.shouldRestart()
restart, restartDelay := tr.shouldRestart()
if !restart {
break MAIN
}
deadline := time.Now().Add(restartWait)
timer := time.NewTimer(restartWait)
for time.Now().Before(deadline) {
select {
case <-timer.C:
case <-tr.ctx.Done():
tr.logger.Debug("task runner cancelled")
break MAIN
}
// Actually restart by sleeping and also watching for destroy events
select {
case <-time.After(restartDelay):
case <-tr.ctx.Done():
tr.logger.Trace("task killed between restarts", "delay", restartDelay)
break MAIN
}
timer.Stop()
}
// If task terminated, update server. All other exit conditions (eg
// killed or out of restarts) will perform their own server updates.
if waitRes != nil {
event := structs.NewTaskEvent(structs.TaskTerminated).
SetExitCode(waitRes.ExitCode).
SetSignal(waitRes.Signal).
SetExitMessage(waitRes.Err)
tr.UpdateState(structs.TaskStateDead, event)
}
// Run the stop hooks
@@ -361,13 +369,16 @@ func (tr *TaskRunner) handleUpdates() {
}
}
// shouldRestart determines whether the task should be restarted and updates
// the task state unless the task is killed or terminated.
func (tr *TaskRunner) shouldRestart() (bool, time.Duration) {
// Determine if we should restart
state, when := tr.restartTracker.GetState()
reason := tr.restartTracker.GetReason()
switch state {
case structs.TaskKilled:
// The task was killed. Nothing to do
// Never restart an explicitly killed task. Kill method handles
// updating the server.
return false, 0
case structs.TaskNotRestarting, structs.TaskTerminated:
tr.logger.Info("not restarting task", "reason", reason)

View File

@@ -49,20 +49,25 @@ func (tr *TaskRunner) setVaultToken(token string) {
tr.envBuilder.SetVaultToken(token, tr.task.Vault.Env)
}
func (tr *TaskRunner) getDriverHandle() driver.DriverHandle {
// getDriverHandle returns a driver handle and its result proxy. Use the
// result proxy instead of the handle's WaitCh.
func (tr *TaskRunner) getDriverHandle() (driver.DriverHandle, *handleResult) {
tr.handleLock.Lock()
defer tr.handleLock.Unlock()
return tr.handle
return tr.handle, tr.handleResult
}
// setDriverHandle sets the driver handle and creates a new result proxy.
func (tr *TaskRunner) setDriverHandle(handle driver.DriverHandle) {
tr.handleLock.Lock()
defer tr.handleLock.Unlock()
tr.handle = handle
tr.handleResult = newHandleResult(handle.WaitCh())
}
func (tr *TaskRunner) clearDriverHandle() {
tr.handleLock.Lock()
defer tr.handleLock.Unlock()
tr.handle = nil
tr.handleResult = nil
}

View File

@@ -168,7 +168,7 @@ func (tr *TaskRunner) poststart() error {
}()
}
handle := tr.getDriverHandle()
handle, _ := tr.getDriverHandle()
net := handle.Network()
var merr multierror.Error