Merge pull request #4828 from hashicorp/b-restore

Implement client agent restarting
Michael Schurter
2018-11-05 18:50:15 -06:00
committed by GitHub
20 changed files with 515 additions and 136 deletions

View File

@@ -157,7 +157,7 @@ func (ar *allocRunner) initTaskRunners(tasks []*structs.Task) error {
StateDB: ar.stateDB,
StateUpdater: ar,
Consul: ar.consulClient,
VaultClient: ar.vaultClient,
Vault: ar.vaultClient,
PluginSingletonLoader: ar.pluginSingletonLoader,
}
@@ -181,17 +181,61 @@ func (ar *allocRunner) Run() {
ar.destroyedLock.Lock()
defer ar.destroyedLock.Unlock()
// Run should not be called after Destroy is called. This is a
// programming error.
if ar.destroyed {
// Run should not be called after Destroy is called. This is a
// programming error.
ar.logger.Error("alloc destroyed; cannot run")
return
}
ar.runLaunched = true
// If an alloc should not be run, ensure any restored task handles are
// destroyed and exit to wait for the AR to be GC'd by the client.
if !ar.shouldRun() {
ar.logger.Debug("not running terminal alloc")
// Cleanup and sync state
states := ar.killTasks()
// Get the client allocation
calloc := ar.clientAlloc(states)
// Update the server
ar.stateUpdater.AllocStateUpdated(calloc)
// Broadcast client alloc to listeners
ar.allocBroadcaster.Send(calloc)
return
}
// Run! (and mark as having been run to ensure Destroy cleans up properly)
ar.runLaunched = true
go ar.runImpl()
}
// shouldRun returns true if the alloc is in a state in which the alloc runner
// should run it.
func (ar *allocRunner) shouldRun() bool {
// Do not run allocs that are terminal
if ar.Alloc().TerminalStatus() {
ar.logger.Trace("alloc terminal; not running",
"desired_status", ar.Alloc().DesiredStatus,
"client_status", ar.Alloc().ClientStatus,
)
return false
}
// It's possible that the alloc local state was marked terminal before
// the server copy of the alloc (checked above) was marked as terminal,
// so check the local state as well.
switch clientStatus := ar.AllocState().ClientStatus; clientStatus {
case structs.AllocClientStatusComplete, structs.AllocClientStatusFailed, structs.AllocClientStatusLost:
ar.logger.Trace("alloc terminal; updating server and not running", "status", clientStatus)
return false
}
return true
}
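The double check here matters after a client restart: the restored local alloc state and the server's copy of the alloc can briefly disagree about which one is terminal. A minimal standalone sketch of the same gate, using plain string constants in place of the structs.AllocClientStatus* values:

package main

import "fmt"

// Stand-ins for the nomad/structs client status constants referenced in
// shouldRun; the values here are illustrative only.
const (
	statusComplete = "complete"
	statusFailed   = "failed"
	statusLost     = "lost"
	statusRunning  = "running"
)

// shouldRun mirrors the gating logic above: skip allocs the server already
// considers terminal, and also skip allocs whose restored local client
// status is terminal, since the two views can disagree right after a restart.
func shouldRun(serverTerminal bool, localClientStatus string) bool {
	if serverTerminal {
		return false
	}
	switch localClientStatus {
	case statusComplete, statusFailed, statusLost:
		return false
	}
	return true
}

func main() {
	fmt.Println(shouldRun(false, statusRunning))  // true: run it
	fmt.Println(shouldRun(false, statusComplete)) // false: local state already terminal
	fmt.Println(shouldRun(true, statusRunning))   // false: server marked it terminal
}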
func (ar *allocRunner) runImpl() {
// Close the wait channel on return
defer close(ar.waitCh)
@@ -354,7 +398,7 @@ func (ar *allocRunner) handleTaskStateUpdates() {
ar.logger.Debug("task failure, destroying all tasks", "failed_task", killTask)
}
ar.killTasks()
states = ar.killTasks()
}
// Get the client allocation
@@ -369,8 +413,12 @@ func (ar *allocRunner) handleTaskStateUpdates() {
}
// killTasks kills all task runners, leader (if there is one) first. Errors are
// logged except taskrunner.ErrTaskNotRunning which is ignored.
func (ar *allocRunner) killTasks() {
// logged except taskrunner.ErrTaskNotRunning which is ignored. Task states
// after Kill has been called are returned.
func (ar *allocRunner) killTasks() map[string]*structs.TaskState {
var mu sync.Mutex
states := make(map[string]*structs.TaskState, len(ar.tasks))
// Kill leader first, synchronously
for name, tr := range ar.tasks {
if !tr.IsLeader() {
@@ -381,6 +429,9 @@ func (ar *allocRunner) killTasks() {
if err != nil && err != taskrunner.ErrTaskNotRunning {
ar.logger.Warn("error stopping leader task", "error", err, "task_name", name)
}
state := tr.TaskState()
states[name] = state
break
}
@@ -398,9 +449,16 @@ func (ar *allocRunner) killTasks() {
if err != nil && err != taskrunner.ErrTaskNotRunning {
ar.logger.Warn("error stopping task", "error", err, "task_name", name)
}
state := tr.TaskState()
mu.Lock()
states[name] = state
mu.Unlock()
}(name, tr)
}
wg.Wait()
return states
}
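The ordering is the interesting part of killTasks: the leader task, if any, is killed synchronously first so dependent tasks see it stop, then the remaining tasks are killed in parallel with a mutex guarding the shared state map. A self-contained sketch of that pattern with a toy task type (the names are illustrative, not the allocrunner API):

package main

import (
	"fmt"
	"sync"
)

type task struct {
	name   string
	leader bool
}

// Kill stands in for the task runner's Kill + TaskState calls.
func (t *task) Kill() string { return "dead" }

func killTasks(tasks []*task) map[string]string {
	states := make(map[string]string, len(tasks))
	var mu sync.Mutex

	// Kill the leader first, synchronously, so followers (eg sidecars
	// watching the main process) can react to it stopping.
	for _, t := range tasks {
		if t.leader {
			states[t.name] = t.Kill()
			break
		}
	}

	// Kill the rest concurrently, collecting their final states under a lock.
	var wg sync.WaitGroup
	for _, t := range tasks {
		if t.leader {
			continue
		}
		wg.Add(1)
		go func(t *task) {
			defer wg.Done()
			s := t.Kill()
			mu.Lock()
			states[t.name] = s
			mu.Unlock()
		}(t)
	}
	wg.Wait()
	return states
}

func main() {
	fmt.Println(killTasks([]*task{{name: "web", leader: true}, {name: "sidecar"}}))
}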
// clientAlloc takes in the task states and returns an Allocation populated
@@ -510,6 +568,12 @@ func (ar *allocRunner) AllocState() *state.State {
}
}
// Generate alloc to get other state fields
alloc := ar.clientAlloc(state.TaskStates)
state.ClientStatus = alloc.ClientStatus
state.ClientDescription = alloc.ClientDescription
state.DeploymentStatus = alloc.DeploymentStatus
return state
}
@@ -563,8 +627,11 @@ func (ar *allocRunner) Destroy() {
}
defer ar.destroyedLock.Unlock()
// Stop any running tasks
ar.killTasks()
// Stop any running tasks and persist states in case the client is
// shut down before Destroy finishes.
states := ar.killTasks()
calloc := ar.clientAlloc(states)
ar.stateUpdater.AllocStateUpdated(calloc)
// Wait for tasks to exit and postrun hooks to finish (if they ran at all)
if ar.runLaunched {

View File

@@ -11,7 +11,6 @@ import (
consulapi "github.com/hashicorp/nomad/client/consul"
"github.com/hashicorp/nomad/client/state"
"github.com/hashicorp/nomad/client/vaultclient"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/shared/catalog"
@@ -57,20 +56,19 @@ func (m *MockStateUpdater) Reset() {
// testAllocRunnerConfig returns a new allocrunner.Config with mocks and noop
// versions of dependencies along with a cleanup func.
func testAllocRunnerConfig(t *testing.T, alloc *structs.Allocation) (*Config, func()) {
logger := testlog.HCLogger(t)
pluginLoader := catalog.TestPluginLoader(t)
clientConf, cleanup := config.TestClientConfig(t)
conf := &Config{
// Copy the alloc in case the caller edits and reuses it
Alloc: alloc.Copy(),
Logger: logger,
Logger: clientConf.Logger,
ClientConfig: clientConf,
StateDB: state.NoopDB{},
Consul: consulapi.NewMockConsulServiceClient(t, logger),
Consul: consulapi.NewMockConsulServiceClient(t, clientConf.Logger),
Vault: vaultclient.NewMockVaultClient(),
StateUpdater: &MockStateUpdater{},
PrevAllocWatcher: allocwatcher.NoopPrevAlloc{},
PluginSingletonLoader: singleton.NewSingletonLoader(logger, pluginLoader),
PluginSingletonLoader: singleton.NewSingletonLoader(clientConf.Logger, pluginLoader),
}
return conf, cleanup
}

View File

@@ -109,8 +109,8 @@ type TaskKillResponse struct{}
type TaskKillHook interface {
TaskHook
// Kill is called when a task is going to be killed.
Kill(context.Context, *TaskKillRequest, *TaskKillResponse) error
// Killing is called when a task is going to be Killed or Restarted.
Killing(context.Context, *TaskKillRequest, *TaskKillResponse) error
}
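With the rename, hooks opt into pre-kill work by implementing Killing, which now also fires for in-place restarts. A minimal sketch of such a hook against the interfaces package (the noopCleanupHook name and behavior are made up for illustration, and it assumes TaskHook only requires Name):

package taskrunner

import (
	"context"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
)

// noopCleanupHook is a hypothetical hook that runs just before a task is
// killed or restarted, via the renamed Killing callback.
type noopCleanupHook struct {
	logger log.Logger
}

func (*noopCleanupHook) Name() string { return "noop-cleanup" }

// Killing satisfies interfaces.TaskKillHook; it is invoked for both kills
// and in-place restarts, so it should be safe to run more than once.
func (h *noopCleanupHook) Killing(ctx context.Context, req *interfaces.TaskKillRequest, resp *interfaces.TaskKillResponse) error {
	h.logger.Debug("running pre-kill cleanup")
	return nil
}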
type TaskExitedRequest struct{}

View File

@@ -7,7 +7,13 @@ import (
)
type TaskLifecycle interface {
// Restart a task in place. If failure=false then the restart does not
// count as an attempt in the restart policy.
Restart(ctx context.Context, event *structs.TaskEvent, failure bool) error
// Sends a signal to a task.
Signal(event *structs.TaskEvent, signal string) error
// Kill a task permanently.
Kill(ctx context.Context, event *structs.TaskEvent) error
}
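TaskLifecycle is the surface hooks and other client components use to drive a task from the outside. A hedged sketch of a caller using the new failure flag, assuming the interface lives in the taskrunner interfaces package as laid out in this tree:

package example // illustration only

import (
	"context"

	"github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
	"github.com/hashicorp/nomad/nomad/structs"
)

// restartInPlace asks the lifecycle to restart the task without charging the
// restart policy (failure=false), which is what a health-driven or
// operator-driven restart typically wants.
func restartInPlace(ctx context.Context, lc interfaces.TaskLifecycle) error {
	return lc.Restart(ctx, structs.NewTaskEvent(structs.TaskRestartSignal), false)
}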

View File

@@ -12,6 +12,7 @@ import (
func (tr *TaskRunner) Restart(ctx context.Context, event *structs.TaskEvent, failure bool) error {
// Grab the handle
handle := tr.getDriverHandle()
// Check it is running
if handle == nil {
return ErrTaskNotRunning
@@ -20,12 +21,14 @@ func (tr *TaskRunner) Restart(ctx context.Context, event *structs.TaskEvent, fai
// Emit the event since it may take a long time to kill
tr.EmitEvent(event)
// Run the hooks prior to restarting the task
tr.killing()
// Tell the restart tracker that a restart triggered the exit
tr.restartTracker.SetRestartTriggered(failure)
// Kill the task using an exponential backoff in case of failures.
destroySuccess, err := tr.handleDestroy(handle)
if !destroySuccess {
if err := tr.killTask(handle); err != nil {
// We couldn't successfully destroy the resource created.
tr.logger.Error("failed to kill task. Resources may have been leaked", "error", err)
}
@@ -36,7 +39,10 @@ func (tr *TaskRunner) Restart(ctx context.Context, event *structs.TaskEvent, fai
return err
}
<-waitCh
select {
case <-waitCh:
case <-ctx.Done():
}
return nil
}
@@ -61,7 +67,7 @@ func (tr *TaskRunner) Signal(event *structs.TaskEvent, s string) error {
func (tr *TaskRunner) Kill(ctx context.Context, event *structs.TaskEvent) error {
// Cancel the task runner to break out of restart delay or the main run
// loop.
tr.ctxCancel()
tr.killCtxCancel()
// Grab the handle
handle := tr.getDriverHandle()
@@ -75,16 +81,17 @@ func (tr *TaskRunner) Kill(ctx context.Context, event *structs.TaskEvent) error
tr.EmitEvent(event)
// Run the hooks prior to killing the task
tr.kill()
tr.killing()
// Tell the restart tracker that the task has been killed
// Tell the restart tracker that the task has been killed so it doesn't
// attempt to restart it.
tr.restartTracker.SetKilled()
// Kill the task using an exponential backoff in case of failures.
destroySuccess, destroyErr := tr.handleDestroy(handle)
if !destroySuccess {
killErr := tr.killTask(handle)
if killErr != nil {
// We couldn't successfully destroy the resource created.
tr.logger.Error("failed to kill task. Resources may have been leaked", "error", destroyErr)
tr.logger.Error("failed to kill task. Resources may have been leaked", "error", killErr)
}
// Block until task has exited.
@@ -100,13 +107,16 @@ func (tr *TaskRunner) Kill(ctx context.Context, event *structs.TaskEvent) error
return err
}
<-waitCh
select {
case <-waitCh:
case <-ctx.Done():
}
// Store that the task has been destroyed and any associated error.
tr.UpdateState(structs.TaskStateDead, structs.NewTaskEvent(structs.TaskKilled).SetKillError(destroyErr))
tr.UpdateState(structs.TaskStateDead, structs.NewTaskEvent(structs.TaskKilled).SetKillError(killErr))
if destroyErr != nil {
return destroyErr
if killErr != nil {
return killErr
} else if err := ctx.Err(); err != nil {
return err
}
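Because Kill and Restart now select on the caller's context as well as the task's wait channel, a caller can bound how long it blocks; if the context expires first, Kill returns ctx.Err() while teardown keeps progressing under killCtx. A hypothetical helper, written as if it sat in the taskrunner package:

package taskrunner // illustration only

import (
	"context"
	"time"

	"github.com/hashicorp/nomad/nomad/structs"
)

// killWithTimeout bounds how long the caller waits for Kill; the kill itself
// continues in the background under the task's killCtx.
func killWithTimeout(tr *TaskRunner, d time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), d)
	defer cancel()

	err := tr.Kill(ctx, structs.NewTaskEvent(structs.TaskKilling))
	if err == context.DeadlineExceeded {
		// The caller stopped waiting; teardown continues in the background.
		return nil
	}
	return err
}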

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"sync"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
@@ -34,6 +35,7 @@ type serviceHook struct {
logger log.Logger
// The following fields may be updated
delay time.Duration
driverExec tinterfaces.ScriptExecutor
driverNet *cstructs.DriverNetwork
canary bool
@@ -53,6 +55,7 @@ func newServiceHook(c serviceHookConfig) *serviceHook {
taskName: c.task.Name,
services: c.task.Services,
restarter: c.restarter,
delay: c.task.ShutdownDelay,
}
if res := c.alloc.TaskResources[c.task.Name]; res != nil {
@@ -111,6 +114,7 @@ func (h *serviceHook) Update(ctx context.Context, req *interfaces.TaskUpdateRequ
}
// Update service hook fields
h.delay = task.ShutdownDelay
h.taskEnv = req.TaskEnv
h.services = task.Services
h.networks = networks
@@ -122,10 +126,35 @@ func (h *serviceHook) Update(ctx context.Context, req *interfaces.TaskUpdateRequ
return h.consul.UpdateTask(oldTaskServices, newTaskServices)
}
func (h *serviceHook) Exited(context.Context, *interfaces.TaskExitedRequest, *interfaces.TaskExitedResponse) error {
func (h *serviceHook) Killing(ctx context.Context, req *interfaces.TaskKillRequest, resp *interfaces.TaskKillResponse) error {
h.mu.Lock()
defer h.mu.Unlock()
// Deregister before killing task
h.deregister()
// If there's no shutdown delay, exit early
if h.delay == 0 {
return nil
}
h.logger.Debug("waiting before killing task", "shutdown_delay", h.delay)
select {
case <-ctx.Done():
case <-time.After(h.delay):
}
return nil
}
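Killing folds the old shutdown-delay hook into the service hook: deregister from Consul first, then give in-flight traffic ShutdownDelay to drain, cutting the wait short if the kill context is canceled. The wait itself, extracted into a runnable standalone sketch:

package main

import (
	"context"
	"fmt"
	"time"
)

// waitShutdownDelay blocks for the configured shutdown delay, but returns
// early if the kill context is canceled, mirroring the Killing hook above.
func waitShutdownDelay(ctx context.Context, delay time.Duration) {
	if delay == 0 {
		return
	}
	select {
	case <-ctx.Done():
	case <-time.After(delay):
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	start := time.Now()
	waitShutdownDelay(ctx, 100*time.Millisecond)
	fmt.Println("waited", time.Since(start).Round(10*time.Millisecond))
}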
func (h *serviceHook) Exited(context.Context, *interfaces.TaskExitedRequest, *interfaces.TaskExitedResponse) error {
h.mu.Lock()
defer h.mu.Unlock()
h.deregister()
return nil
}
// deregister services from Consul.
func (h *serviceHook) deregister() {
taskServices := h.getTaskServices()
h.consul.RemoveTask(taskServices)
@@ -134,7 +163,6 @@ func (h *serviceHook) Exited(context.Context, *interfaces.TaskExitedRequest, *in
taskServices.Canary = !taskServices.Canary
h.consul.RemoveTask(taskServices)
return nil
}
func (h *serviceHook) getTaskServices() *agentconsul.TaskServices {

View File

@@ -1,36 +0,0 @@
package taskrunner
import (
"context"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
)
// shutdownDelayHook delays shutting down a task between deregistering it from
// Consul and actually killing it.
type shutdownDelayHook struct {
delay time.Duration
logger log.Logger
}
func newShutdownDelayHook(delay time.Duration, logger log.Logger) *shutdownDelayHook {
h := &shutdownDelayHook{
delay: delay,
}
h.logger = logger.Named(h.Name())
return h
}
func (*shutdownDelayHook) Name() string {
return "shutdown-delay"
}
func (h *shutdownDelayHook) Kill(ctx context.Context, req *interfaces.TaskKillRequest, resp *interfaces.TaskKillResponse) error {
select {
case <-ctx.Done():
case <-time.After(h.delay):
}
return nil
}

View File

@@ -78,12 +78,18 @@ type TaskRunner struct {
// stateDB is for persisting localState and taskState
stateDB cstate.StateDB
// ctx is the task runner's context representing the task's lifecycle.
// Canceling the context will cause the task to be destroyed.
// killCtx is the task runner's context representing the task's lifecycle.
// The context is canceled when the task is killed.
killCtx context.Context
// killCtxCancel is called when killing a task.
killCtxCancel context.CancelFunc
// ctx is used to exit the TaskRunner *without* affecting task state.
ctx context.Context
// ctxCancel is used to exit the task runner's Run loop without
// stopping the task. Shutdown hooks are run.
// ctxCancel causes the TaskRunner to exit immediately without
// affecting task state. Useful for testing or graceful agent shutdown.
ctxCancel context.CancelFunc
// Logger is the logger for the task runner.
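The split gives the runner two independent exits: canceling killCtx tears the task down (kill hooks run, final state is recorded), while canceling ctx abandons supervision and leaves the task running so a restarted agent can restore it. A small standalone sketch of that shape, with illustrative names:

package main

import (
	"context"
	"fmt"
	"time"
)

// run loops "supervising" a fake task. killCtx stops the task itself;
// shutdownCtx abandons supervision without touching the task, mirroring the
// killCtx/ctx split in TaskRunner.
func run(killCtx, shutdownCtx context.Context) {
	for i := 0; ; i++ {
		select {
		case <-killCtx.Done():
			fmt.Println("task killed; recording final state")
			return
		case <-shutdownCtx.Done():
			fmt.Println("agent shutting down; task left running for later restore")
			return
		case <-time.After(50 * time.Millisecond):
			fmt.Println("task still running, tick", i)
		}
	}
}

func main() {
	killCtx, killCancel := context.WithCancel(context.Background())
	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
	defer killCancel()

	go func() {
		time.Sleep(120 * time.Millisecond)
		shutdownCancel() // simulate the client agent being restarted
	}()
	run(killCtx, shutdownCtx)
}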
@@ -168,8 +174,8 @@ type Config struct {
TaskDir *allocdir.TaskDir
Logger log.Logger
// VaultClient is the client to use to derive and renew Vault tokens
VaultClient vaultclient.VaultClient
// Vault is the client to use to derive and renew Vault tokens
Vault vaultclient.VaultClient
// StateDB is used to store and restore state.
StateDB cstate.StateDB
@@ -183,9 +189,12 @@ type Config struct {
}
func NewTaskRunner(config *Config) (*TaskRunner, error) {
// Create a context for the runner
// Create a context for causing the runner to exit
trCtx, trCancel := context.WithCancel(context.Background())
// Create a context for killing the runner
killCtx, killCancel := context.WithCancel(context.Background())
// Initialize the environment builder
envBuilder := env.NewBuilder(
config.ClientConfig.Node,
@@ -210,11 +219,13 @@ func NewTaskRunner(config *Config) (*TaskRunner, error) {
taskLeader: config.Task.Leader,
envBuilder: envBuilder,
consulClient: config.Consul,
vaultClient: config.VaultClient,
vaultClient: config.Vault,
state: tstate,
localState: state.NewLocalState(),
stateDB: config.StateDB,
stateUpdater: config.StateUpdater,
killCtx: killCtx,
killCtxCancel: killCancel,
ctx: trCtx,
ctxCancel: trCancel,
triggerUpdateCh: make(chan struct{}, triggerUpdateChCap),
@@ -299,7 +310,16 @@ func (tr *TaskRunner) Run() {
go tr.handleUpdates()
MAIN:
for tr.ctx.Err() == nil {
for {
select {
case <-tr.killCtx.Done():
break MAIN
case <-tr.ctx.Done():
// TaskRunner was told to exit immediately
return
default:
}
// Run the prestart hooks
if err := tr.prestart(); err != nil {
tr.logger.Error("prestart failed", "error", err)
@@ -307,8 +327,13 @@ MAIN:
goto RESTART
}
if tr.ctx.Err() != nil {
select {
case <-tr.killCtx.Done():
break MAIN
case <-tr.ctx.Done():
// TaskRunner was told to exit immediately
return
default:
}
// Run the task
@@ -327,12 +352,19 @@ MAIN:
{
handle := tr.getDriverHandle()
// Do *not* use tr.ctx here as it would cause Wait() to
// unblock before the task exits when Kill() is called.
// Do *not* use tr.killCtx here as it would cause
// Wait() to unblock before the task exits when Kill()
// is called.
if resultCh, err := handle.WaitCh(context.Background()); err != nil {
tr.logger.Error("wait task failed", "error", err)
} else {
result = <-resultCh
select {
case result = <-resultCh:
// WaitCh returned a result
case <-tr.ctx.Done():
// TaskRunner was told to exit immediately
return
}
}
}
@@ -355,9 +387,12 @@ MAIN:
// Actually restart by sleeping and also watching for destroy events
select {
case <-time.After(restartDelay):
case <-tr.ctx.Done():
case <-tr.killCtx.Done():
tr.logger.Trace("task killed between restarts", "delay", restartDelay)
break MAIN
case <-tr.ctx.Done():
// TaskRunner was told to exit immediately
return
}
}
@@ -444,7 +479,20 @@ func (tr *TaskRunner) runDriver() error {
//TODO mounts and devices
//XXX Evaluate and encode driver config
// Start the job
// If there's already a task handle (eg from a Restore) there's nothing
// to do except update state.
if tr.getDriverHandle() != nil {
// Ensure running state is persisted but do *not* append a new
// task event as restoring is a client event and not relevant
// to a task's lifecycle.
if err := tr.updateStateImpl(structs.TaskStateRunning); err != nil {
//TODO return error and destroy task to avoid an orphaned task?
tr.logger.Warn("error persisting task state", "error", err)
}
return nil
}
// Start the job if there's no existing handle (or if RecoverTask failed)
handle, net, err := tr.driver.StartTask(taskConfig)
if err != nil {
return fmt.Errorf("driver start failed: %v", err)
@@ -452,9 +500,18 @@ func (tr *TaskRunner) runDriver() error {
tr.localStateLock.Lock()
tr.localState.TaskHandle = handle
tr.localState.DriverNetwork = net
if err := tr.stateDB.PutTaskRunnerLocalState(tr.allocID, tr.taskName, tr.localState); err != nil {
//TODO Nomad will be unable to restore this task; try to kill
// it now and fail? In general we prefer to leave running
// tasks running even if the agent encounters an error.
tr.logger.Warn("error persisting local task state; may be unable to restore after a Nomad restart",
"error", err, "task_id", handle.Config.ID)
}
tr.localStateLock.Unlock()
tr.setDriverHandle(NewDriverHandle(tr.driver, taskConfig.ID, tr.Task(), net))
// Emit an event that we started
tr.UpdateState(structs.TaskStateRunning, structs.NewTaskEvent(structs.TaskStarted))
return nil
@@ -525,17 +582,17 @@ func (tr *TaskRunner) initDriver() error {
return nil
}
// handleDestroy kills the task handle. In the case that killing fails,
// handleDestroy will retry with an exponential backoff and will give up at a
// given limit. It returns whether the task was destroyed and the error
// associated with the last kill attempt.
func (tr *TaskRunner) handleDestroy(handle *DriverHandle) (destroyed bool, err error) {
// killTask kills the task handle. In the case that killing fails,
// killTask will retry with an exponential backoff and will give up at a
// given limit. Returns an error if the task could not be killed.
func (tr *TaskRunner) killTask(handle *DriverHandle) error {
// Cap the number of times we attempt to kill the task.
var err error
for i := 0; i < killFailureLimit; i++ {
if err = handle.Kill(); err != nil {
if err == drivers.ErrTaskNotFound {
tr.logger.Warn("couldn't find task to kill", "task_id", handle.ID())
return true, nil
return nil
}
// Calculate the new backoff
backoff := (1 << (2 * uint64(i))) * killBackoffBaseline
@@ -547,10 +604,10 @@ func (tr *TaskRunner) handleDestroy(handle *DriverHandle) (destroyed bool, err e
time.Sleep(backoff)
} else {
// Kill was successful
return true, nil
return nil
}
}
return
return err
}
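The retry delay grows as (1 << (2*i)) * baseline, i.e. 1x, 4x, 16x, ... the baseline, clamped to a limit; the baseline, limit, and attempt cap are unexported constants not shown in this hunk, so the values below are stand-ins chosen for illustration:

package main

import (
	"fmt"
	"time"
)

// Illustrative stand-ins for the unexported taskrunner constants.
const (
	killBackoffBaseline = 5 * time.Second
	killBackoffLimit    = 2 * time.Minute
	killFailureLimit    = 5
)

func main() {
	// Same formula as killTask: backoff = (1 << (2*i)) * baseline, capped.
	for i := 0; i < killFailureLimit; i++ {
		backoff := (1 << (2 * uint64(i))) * killBackoffBaseline
		if backoff > killBackoffLimit {
			backoff = killBackoffLimit
		}
		fmt.Printf("attempt %d: wait %s before retrying kill\n", i+1, backoff)
	}
}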
// persistLocalState persists local state to disk synchronously.
@@ -591,39 +648,84 @@ func (tr *TaskRunner) Restore() error {
ls.Canonicalize()
tr.localState = ls
}
if ts != nil {
ts.Canonicalize()
tr.state = ts
}
// If a TaskHandle was persisted, ensure it is valid or destroy it.
if taskHandle := tr.localState.TaskHandle; taskHandle != nil {
//TODO if RecoverTask returned the DriverNetwork we wouldn't
// have to persist it at all!
tr.restoreHandle(taskHandle, tr.localState.DriverNetwork)
}
return nil
}
// restoreHandle ensures a TaskHandle is valid by calling Driver.RecoverTask
// and sets the driver handle. If the TaskHandle is not valid, DestroyTask is
// called.
func (tr *TaskRunner) restoreHandle(taskHandle *drivers.TaskHandle, net *cstructs.DriverNetwork) {
// Ensure handle is well-formed
if taskHandle.Config == nil {
return
}
if err := tr.driver.RecoverTask(taskHandle); err != nil {
tr.logger.Error("error recovering task; destroying and restarting",
"error", err, "task_id", taskHandle.Config.ID)
// Try to cleanup any existing task state in the plugin before restarting
if err := tr.driver.DestroyTask(taskHandle.Config.ID, true); err != nil {
// Ignore ErrTaskNotFound errors as ideally
// this task has already been stopped and
// therefore doesn't exist.
if err != drivers.ErrTaskNotFound {
tr.logger.Warn("error destroying unrecoverable task",
"error", err, "task_id", taskHandle.Config.ID)
}
}
return
}
// Update driver handle on task runner
tr.setDriverHandle(NewDriverHandle(tr.driver, taskHandle.Config.ID, tr.Task(), net))
return
}
// UpdateState sets the task runner's allocation state and triggers a server
// update.
func (tr *TaskRunner) UpdateState(state string, event *structs.TaskEvent) {
tr.stateLock.Lock()
defer tr.stateLock.Unlock()
tr.logger.Trace("setting task state", "state", state, "event", event.Type)
// Update the local state
tr.setStateLocal(state, event)
// Append the event
tr.appendEvent(event)
// Update the state
if err := tr.updateStateImpl(state); err != nil {
// Only log the error as persistence errors should not
// affect task state.
tr.logger.Error("error persisting task state", "error", err, "event", event, "state", state)
}
// Notify the alloc runner of the transition
tr.stateUpdater.TaskStateUpdated()
}
// setStateLocal updates the local in-memory state, persists a copy to disk and returns a
// copy of the task's state.
func (tr *TaskRunner) setStateLocal(state string, event *structs.TaskEvent) {
tr.stateLock.Lock()
defer tr.stateLock.Unlock()
// updateStateImpl updates the in-memory task state and persists to disk.
func (tr *TaskRunner) updateStateImpl(state string) error {
// Update the task state
oldState := tr.state.State
taskState := tr.state
taskState.State = state
// Append the event
tr.appendEvent(event)
// Handle the state transition.
switch state {
case structs.TaskStateRunning:
@@ -662,11 +764,7 @@ func (tr *TaskRunner) setStateLocal(state string, event *structs.TaskEvent) {
}
// Persist the state and event
if err := tr.stateDB.PutTaskState(tr.allocID, tr.taskName, taskState); err != nil {
// Only a warning because the next event/state-transition will
// try to persist it again.
tr.logger.Error("error persisting task state", "error", err, "event", event, "state", state)
}
return tr.stateDB.PutTaskState(tr.allocID, tr.taskName, taskState)
}
// EmitEvent appends a new TaskEvent to this task's TaskState. The actual

View File

@@ -26,7 +26,6 @@ func (tr *TaskRunner) initHooks() {
newLogMonHook(tr.logmonHookConfig, hookLogger),
newDispatchHook(tr.Alloc(), hookLogger),
newArtifactHook(tr, hookLogger),
newShutdownDelayHook(task.ShutdownDelay, hookLogger),
newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger),
}
@@ -123,7 +122,7 @@ func (tr *TaskRunner) prestart() error {
// Run the prestart hook
var resp interfaces.TaskPrestartResponse
if err := pre.Prestart(tr.ctx, &req, &resp); err != nil {
if err := pre.Prestart(tr.killCtx, &req, &resp); err != nil {
return structs.WrapRecoverable(fmt.Sprintf("prestart hook %q failed: %v", name, err), err)
}
@@ -195,7 +194,7 @@ func (tr *TaskRunner) poststart() error {
TaskEnv: tr.envBuilder.Build(),
}
var resp interfaces.TaskPoststartResponse
if err := post.Poststart(tr.ctx, &req, &resp); err != nil {
if err := post.Poststart(tr.killCtx, &req, &resp); err != nil {
merr.Errors = append(merr.Errors, fmt.Errorf("poststart hook %q failed: %v", name, err))
}
@@ -237,7 +236,7 @@ func (tr *TaskRunner) exited() error {
req := interfaces.TaskExitedRequest{}
var resp interfaces.TaskExitedResponse
if err := post.Exited(tr.ctx, &req, &resp); err != nil {
if err := post.Exited(tr.killCtx, &req, &resp); err != nil {
merr.Errors = append(merr.Errors, fmt.Errorf("exited hook %q failed: %v", name, err))
}
@@ -280,7 +279,7 @@ func (tr *TaskRunner) stop() error {
req := interfaces.TaskStopRequest{}
var resp interfaces.TaskStopResponse
if err := post.Stop(tr.ctx, &req, &resp); err != nil {
if err := post.Stop(tr.killCtx, &req, &resp); err != nil {
merr.Errors = append(merr.Errors, fmt.Errorf("stop hook %q failed: %v", name, err))
}
@@ -336,7 +335,7 @@ func (tr *TaskRunner) updateHooks() {
// Run the update hook
var resp interfaces.TaskUpdateResponse
if err := upd.Update(tr.ctx, &req, &resp); err != nil {
if err := upd.Update(tr.killCtx, &req, &resp); err != nil {
tr.logger.Error("update hook failed", "name", name, "error", err)
}
@@ -349,8 +348,8 @@ func (tr *TaskRunner) updateHooks() {
}
}
// kill is used to run the runners kill hooks.
func (tr *TaskRunner) kill() {
// killing is used to run the runners kill hooks.
func (tr *TaskRunner) killing() {
if tr.logger.IsTrace() {
start := time.Now()
tr.logger.Trace("running kill hooks", "start", start)
@@ -378,7 +377,7 @@ func (tr *TaskRunner) kill() {
// Run the update hook
req := interfaces.TaskKillRequest{}
var resp interfaces.TaskKillResponse
if err := upd.Kill(context.Background(), &req, &resp); err != nil {
if err := upd.Killing(context.Background(), &req, &resp); err != nil {
tr.logger.Error("kill hook failed", "name", name, "error", err)
}

View File

@@ -0,0 +1,157 @@
package taskrunner
import (
"context"
"fmt"
"path/filepath"
"testing"
"time"
"github.com/hashicorp/nomad/client/allocdir"
"github.com/hashicorp/nomad/client/config"
consulapi "github.com/hashicorp/nomad/client/consul"
cstate "github.com/hashicorp/nomad/client/state"
"github.com/hashicorp/nomad/client/vaultclient"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/shared/catalog"
"github.com/hashicorp/nomad/plugins/shared/singleton"
"github.com/hashicorp/nomad/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type MockTaskStateUpdater struct {
ch chan struct{}
}
func NewMockTaskStateUpdater() *MockTaskStateUpdater {
return &MockTaskStateUpdater{
ch: make(chan struct{}, 1),
}
}
func (m *MockTaskStateUpdater) TaskStateUpdated() {
select {
case m.ch <- struct{}{}:
default:
}
}
// testTaskRunnerConfig returns a taskrunner.Config for the given alloc+task
// plus a cleanup func.
func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName string) (*Config, func()) {
logger := testlog.HCLogger(t)
pluginLoader := catalog.TestPluginLoader(t)
clientConf, cleanup := config.TestClientConfig(t)
// Find the task
var thisTask *structs.Task
for _, tg := range alloc.Job.TaskGroups {
for _, task := range tg.Tasks {
if task.Name == taskName {
if thisTask != nil {
cleanup()
t.Fatalf("multiple tasks named %q; cannot use this helper", taskName)
}
thisTask = task
}
}
}
if thisTask == nil {
cleanup()
t.Fatalf("could not find task %q", taskName)
}
// Create the alloc dir + task dir
allocPath := filepath.Join(clientConf.AllocDir, alloc.ID)
allocDir := allocdir.NewAllocDir(logger, allocPath)
if err := allocDir.Build(); err != nil {
cleanup()
t.Fatalf("error building alloc dir: %v", err)
}
taskDir := allocDir.NewTaskDir(taskName)
trCleanup := func() {
if err := allocDir.Destroy(); err != nil {
t.Logf("error destroying alloc dir: %v", err)
}
cleanup()
}
conf := &Config{
Alloc: alloc,
ClientConfig: clientConf,
Consul: consulapi.NewMockConsulServiceClient(t, logger),
Task: thisTask,
TaskDir: taskDir,
Logger: clientConf.Logger,
Vault: vaultclient.NewMockVaultClient(),
StateDB: cstate.NoopDB{},
StateUpdater: NewMockTaskStateUpdater(),
PluginSingletonLoader: singleton.NewSingletonLoader(logger, pluginLoader),
}
return conf, trCleanup
}
// TestTaskRunner_Restore_Running asserts restoring a running task does not
// rerun the task.
func TestTaskRunner_Restore_Running(t *testing.T) {
t.Parallel()
require := require.New(t)
alloc := mock.BatchAlloc()
alloc.Job.TaskGroups[0].Count = 1
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Name = "testtask"
task.Driver = "mock_driver"
task.Config = map[string]interface{}{
"run_for": 2 * time.Second,
}
conf, cleanup := testTaskRunnerConfig(t, alloc, "testtask")
conf.StateDB = cstate.NewMemDB() // "persist" state between task runners
defer cleanup()
// Run the first TaskRunner
origTR, err := NewTaskRunner(conf)
require.NoError(err)
go origTR.Run()
defer origTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
// Wait for it to be running
testutil.WaitForResult(func() (bool, error) {
ts := origTR.TaskState()
return ts.State == structs.TaskStateRunning, fmt.Errorf("%v", ts.State)
}, func(err error) {
t.Fatalf("expected running; got: %v", err)
})
// Cause TR to exit without shutting down task
origTR.ctxCancel()
<-origTR.WaitCh()
// Start a new TaskRunner and make sure it does not rerun the task
newTR, err := NewTaskRunner(conf)
require.NoError(err)
// Do the Restore
require.NoError(newTR.Restore())
go newTR.Run()
defer newTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
// Wait for new task runner to exit when the process does
<-newTR.WaitCh()
// Assert that the process was only started once
started := 0
state := newTR.TaskState()
require.Equal(structs.TaskStateDead, state.State)
for _, ev := range state.Events {
if ev.Type == structs.TaskStarted {
started++
}
}
assert.Equal(t, 1, started)
}

View File

@@ -6,6 +6,7 @@ import (
"path/filepath"
"github.com/hashicorp/nomad/helper"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/mitchellh/go-testing-interface"
)
@@ -14,6 +15,7 @@ import (
// a cleanup func to remove the state and alloc dirs when finished.
func TestClientConfig(t testing.T) (*Config, func()) {
conf := DefaultConfig()
conf.Logger = testlog.HCLogger(t)
// Create a tempdir to hold state and alloc subdirs
parent, err := ioutil.TempDir("", "nomadtest")

View File

@@ -20,6 +20,8 @@ import (
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/shared/catalog"
"github.com/hashicorp/nomad/plugins/shared/singleton"
"github.com/stretchr/testify/require"
)
@@ -143,16 +145,18 @@ func TestConsul_Integration(t *testing.T) {
}()
// Build the config
pluginLoader := catalog.TestPluginLoader(t)
config := &taskrunner.Config{
Alloc: alloc,
ClientConfig: conf,
Consul: serviceClient,
Task: task,
TaskDir: taskDir,
Logger: logger,
VaultClient: vclient,
StateDB: state.NoopDB{},
StateUpdater: logUpdate,
Alloc: alloc,
ClientConfig: conf,
Consul: serviceClient,
Task: task,
TaskDir: taskDir,
Logger: logger,
Vault: vclient,
StateDB: state.NoopDB{},
StateUpdater: logUpdate,
PluginSingletonLoader: singleton.NewSingletonLoader(logger, pluginLoader),
}
tr, err := taskrunner.NewTaskRunner(config)

View File

@@ -273,9 +273,23 @@ func (d *Driver) buildFingerprint() *drivers.Fingerprint {
}
}
func (d *Driver) RecoverTask(*drivers.TaskHandle) error {
//TODO is there anything to do here?
return nil
func (d *Driver) RecoverTask(h *drivers.TaskHandle) error {
if h == nil {
return fmt.Errorf("handle cannot be nil")
}
if _, ok := d.tasks.Get(h.Config.ID); ok {
d.logger.Debug("nothing to recover; task already exists",
"task_id", h.Config.ID,
"task_name", h.Config.Name,
)
return nil
}
// Recovering a task requires the task to be running external to the
// plugin. Since the mock_driver runs all tasks in process it cannot
// recover tasks.
return fmt.Errorf("%s cannot recover tasks", pluginName)
}
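Every driver's RecoverTask now follows the same shape: reject a nil handle, short-circuit if the task is already tracked, and only then decode persisted driver state and reattach. A toy, self-contained version of that skeleton (the handle and task-store types stand in for the real drivers package types):

package main

import (
	"errors"
	"fmt"
)

// handle is a toy stand-in for *drivers.TaskHandle.
type handle struct{ id string }

type driver struct {
	tasks map[string]struct{} // stand-in for the driver's task store
}

// RecoverTask follows the shape shared by the drivers in this change:
// reject nil handles, treat already-tracked tasks as a no-op, and only then
// attempt to reattach to the externally running task.
func (d *driver) RecoverTask(h *handle) error {
	if h == nil {
		return errors.New("handle cannot be nil")
	}
	if _, ok := d.tasks[h.id]; ok {
		// Nothing to recover; RecoverTask must be safe to call repeatedly.
		return nil
	}
	// Real drivers decode handle.DriverState here and reattach to the task
	// (eg reconnect to a plugin executor); this toy driver cannot.
	return fmt.Errorf("toy driver cannot recover task %s", h.id)
}

func main() {
	d := &driver{tasks: map[string]struct{}{"abc123": {}}}
	fmt.Println(d.RecoverTask(&handle{id: "abc123"})) // <nil>: already tracked
	fmt.Println(d.RecoverTask(&handle{id: "zzz999"})) // error: would need real reattach
	fmt.Println(d.RecoverTask(nil))                   // error: nil handle
}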
func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstructs.DriverNetwork, error) {

View File

@@ -244,6 +244,15 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error {
return fmt.Errorf("error: handle cannot be nil")
}
// If already attached to handle there's nothing to recover.
if _, ok := d.tasks.Get(handle.Config.ID); ok {
d.logger.Trace("nothing to recover; task already exists",
"task_id", handle.Config.ID,
"task_name", handle.Config.Name,
)
return nil
}
var taskState TaskState
if err := handle.GetDriverState(&taskState); err != nil {
d.logger.Error("failed to decode taskConfig state from handle", "error", err, "task_id", handle.Config.ID)

View File

@@ -242,9 +242,19 @@ func (d *Driver) buildFingerprint() *drivers.Fingerprint {
func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error {
if handle == nil {
return fmt.Errorf("error: handle cannot be nil")
return fmt.Errorf("handle cannot be nil")
}
// If already attached to handle there's nothing to recover.
if _, ok := d.tasks.Get(handle.Config.ID); ok {
d.logger.Trace("nothing to recover; task already exists",
"task_id", handle.Config.ID,
"task_name", handle.Config.Name,
)
return nil
}
// Handle doesn't already exist, try to reattach
var taskState TaskState
if err := handle.GetDriverState(&taskState); err != nil {
d.logger.Error("failed to decode task state from handle", "error", err, "task_id", handle.Config.ID)
@@ -261,6 +271,7 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error {
Reattach: plugRC,
}
// Create client for reattached executor
exec, pluginClient, err := utils.CreateExecutorWithConfig(pluginConfig, os.Stderr)
if err != nil {
d.logger.Error("failed to reattach to executor", "error", err, "task_id", handle.Config.ID)

View File

@@ -317,6 +317,15 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error {
return fmt.Errorf("error: handle cannot be nil")
}
// If already attached to handle there's nothing to recover.
if _, ok := d.tasks.Get(handle.Config.ID); ok {
d.logger.Trace("nothing to recover; task already exists",
"task_id", handle.Config.ID,
"task_name", handle.Config.Name,
)
return nil
}
var taskState TaskState
if err := handle.GetDriverState(&taskState); err != nil {
d.logger.Error("failed to decode taskConfig state from handle", "error", err, "task_id", handle.Config.ID)

View File

@@ -46,7 +46,7 @@ type DriverPlugin interface {
// DriverPlugin interface.
type DriverSignalTaskNotSupported struct{}
func (_ DriverSignalTaskNotSupported) SignalTask(taskID, signal string) error {
func (DriverSignalTaskNotSupported) SignalTask(taskID, signal string) error {
return fmt.Errorf("SignalTask is not supported by this driver")
}

View File

@@ -103,7 +103,7 @@ func TestBaseDriver_RecoverTask(t *testing.T) {
defer harness.Kill()
handle := &TaskHandle{
driverState: buf.Bytes(),
DriverState: buf.Bytes(),
}
err := harness.RecoverTask(handle)
require.NoError(err)

View File

@@ -11,7 +11,7 @@ type TaskHandle struct {
Driver string
Config *TaskConfig
State TaskState
driverState []byte
DriverState []byte
}
func NewTaskHandle(driver string) *TaskHandle {
@@ -19,12 +19,12 @@ func NewTaskHandle(driver string) *TaskHandle {
}
func (h *TaskHandle) SetDriverState(v interface{}) error {
h.driverState = []byte{}
return base.MsgPackEncode(&h.driverState, v)
h.DriverState = []byte{}
return base.MsgPackEncode(&h.DriverState, v)
}
func (h *TaskHandle) GetDriverState(v interface{}) error {
return base.MsgPackDecode(h.driverState, v)
return base.MsgPackDecode(h.DriverState, v)
}
@@ -34,7 +34,10 @@ func (h *TaskHandle) Copy() *TaskHandle {
}
handle := new(TaskHandle)
*handle = *h
handle.Driver = h.Driver
handle.Config = h.Config.Copy()
handle.State = h.State
handle.DriverState = make([]byte, len(h.DriverState))
copy(handle.DriverState, h.DriverState)
return handle
}
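Exporting DriverState is what lets the client persist the full TaskHandle and hand it back to the driver after an agent restart; SetDriverState and GetDriverState round-trip a driver-private struct through msgpack. A hedged usage sketch, assuming the plugins/drivers import path and using a made-up execTaskState type:

package main

import (
	"fmt"

	"github.com/hashicorp/nomad/plugins/drivers"
)

// execTaskState is a hypothetical per-driver state blob; real drivers define
// their own TaskState with whatever they need to reattach (PIDs, reattach
// configs, start time, ...).
type execTaskState struct {
	Pid       int
	StartedAt int64
}

func main() {
	h := drivers.NewTaskHandle("exec")

	// On StartTask: stash the driver-private state inside the handle so the
	// client can persist the whole handle (DriverState is now exported).
	if err := h.SetDriverState(&execTaskState{Pid: 4242, StartedAt: 1541472000}); err != nil {
		panic(err)
	}

	// On RecoverTask after an agent restart: decode it back out.
	var st execTaskState
	if err := h.Copy().GetDriverState(&st); err != nil {
		panic(err)
	}
	fmt.Printf("recovered pid=%d started=%d\n", st.Pid, st.StartedAt)
}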

View File

@@ -194,7 +194,7 @@ func taskHandleFromProto(pb *proto.TaskHandle) *TaskHandle {
return &TaskHandle{
Config: taskConfigFromProto(pb.Config),
State: taskStateFromProtoMap[pb.State],
driverState: pb.DriverState,
DriverState: pb.DriverState,
}
}
@@ -202,7 +202,7 @@ func taskHandleToProto(handle *TaskHandle) *proto.TaskHandle {
return &proto.TaskHandle{
Config: taskConfigToProto(handle.Config),
State: taskStateToProtoMap[handle.State],
DriverState: handle.driverState,
DriverState: handle.DriverState,
}
}