diff --git a/client/alloc_runner.go b/client/alloc_runner.go index 3129b0fb3..114d0c8d5 100644 --- a/client/alloc_runner.go +++ b/client/alloc_runner.go @@ -106,7 +106,8 @@ func (r *AllocRunner) RestoreState() error { var mErr multierror.Error for name := range r.taskStatus { task := &structs.Task{Name: name} - tr := NewTaskRunner(r.logger, r.config, r.setTaskStatus, r.ctx, r.alloc.ID, task, jobType, r.RestartPolicy) + restartTracker := newRestartTracker(jobType, r.RestartPolicy) + tr := NewTaskRunner(r.logger, r.config, r.setTaskStatus, r.ctx, r.alloc.ID, task, restartTracker) r.tasks[name] = tr if err := tr.RestoreState(); err != nil { r.logger.Printf("[ERR] client: failed to restore state for alloc %s task '%s': %v", r.alloc.ID, name, err) @@ -309,7 +310,8 @@ func (r *AllocRunner) Run() { // Merge in the task resources task.Resources = alloc.TaskResources[task.Name] jobType := r.alloc.Job.Type - tr := NewTaskRunner(r.logger, r.config, r.setTaskStatus, r.ctx, r.alloc.ID, task, jobType, r.RestartPolicy) + restartTracker := newRestartTracker(jobType, r.RestartPolicy) + tr := NewTaskRunner(r.logger, r.config, r.setTaskStatus, r.ctx, r.alloc.ID, task, restartTracker) r.tasks[task.Name] = tr go tr.Run() } diff --git a/client/restarts.go b/client/restarts.go index a18518473..4004f82f8 100644 --- a/client/restarts.go +++ b/client/restarts.go @@ -11,7 +11,6 @@ import ( // will be restarted only upto maxAttempts times type restartTracker interface { nextRestart() (bool, time.Duration) - increment() } func newRestartTracker(jobType string, restartPolicy *structs.RestartPolicy) restartTracker { @@ -38,11 +37,8 @@ type batchRestartTracker struct { count int } -func (b *batchRestartTracker) increment() { - b.count = b.count + 1 -} - func (b *batchRestartTracker) nextRestart() (bool, time.Duration) { + b.count += 1 if b.count < b.maxAttempts { return true, b.delay } @@ -58,24 +54,22 @@ type serviceRestartTracker struct { startTime time.Time } -func (c *serviceRestartTracker) 
increment() { - if c.count <= c.maxAttempts { - c.count = c.count + 1 - } -} - -func (c *serviceRestartTracker) nextRestart() (bool, time.Duration) { - windowEndTime := c.startTime.Add(c.interval) +func (s *serviceRestartTracker) nextRestart() (bool, time.Duration) { + s.count += 1 + windowEndTime := s.startTime.Add(s.interval) now := time.Now() + // If the restart window is over, reset the count and restart after the delay if now.After(windowEndTime) { - c.count = 0 - c.startTime = time.Now() - return true, c.delay + s.count = 0 + s.startTime = time.Now() + return true, s.delay } - if c.count < c.maxAttempts { - return true, c.delay + // If we are within the restart window and didn't exhaust all retries + if s.count < s.maxAttempts { + return true, s.delay } + // If we exhausted all the retries and are still within the time window, wait for the window to end return true, windowEndTime.Sub(now) } diff --git a/client/restarts_test.go b/client/restarts_test.go index 8015afd6e..952d33649 100644 --- a/client/restarts_test.go +++ b/client/restarts_test.go @@ -8,11 +8,11 @@ import ( func TestTaskRunner_ServiceRestartCounter(t *testing.T) { rt := newRestartTracker(structs.JobTypeService, &structs.RestartPolicy{Attempts: 2, Interval: 2 * time.Minute, Delay: 1 * time.Second}) - rt.increment() - rt.increment() - rt.increment() - rt.increment() - rt.increment() + rt.nextRestart() + rt.nextRestart() + rt.nextRestart() + rt.nextRestart() + rt.nextRestart() actual, _ := rt.nextRestart() if !actual { t.Fatalf("Expect %v, Actual: %v", true, actual) @@ -21,11 +21,11 @@ func TestTaskRunner_ServiceRestartCounter(t *testing.T) { func TestTaskRunner_BatchRestartCounter(t *testing.T) { rt := newRestartTracker(structs.JobTypeBatch, &structs.RestartPolicy{Attempts: 2, Interval: 1 * time.Second, Delay: 1 * time.Second}) - rt.increment() - rt.increment() - rt.increment() - rt.increment() - rt.increment() + rt.nextRestart() + rt.nextRestart() + rt.nextRestart() + rt.nextRestart() + rt.nextRestart() actual, _ := rt.nextRestart() if 
actual { t.Fatalf("Expect %v, Actual: %v", false, actual) diff --git a/client/task_runner.go b/client/task_runner.go index 21649d6c1..ae97fb3c7 100644 --- a/client/task_runner.go +++ b/client/task_runner.go @@ -24,10 +24,9 @@ type TaskRunner struct { allocID string restartTracker restartTracker - task *structs.Task - restartPolicy *structs.RestartPolicy - updateCh chan *structs.Task - handle driver.DriverHandle + task *structs.Task + updateCh chan *structs.Task + handle driver.DriverHandle destroy bool destroyCh chan struct{} @@ -47,19 +46,16 @@ type TaskStateUpdater func(taskName, status, desc string) // NewTaskRunner is used to create a new task context func NewTaskRunner(logger *log.Logger, config *config.Config, updater TaskStateUpdater, ctx *driver.ExecContext, - allocID string, task *structs.Task, taskType string, - restartPolicy *structs.RestartPolicy) *TaskRunner { + allocID string, task *structs.Task, restartTracker restartTracker) *TaskRunner { - rt := newRestartTracker(taskType, restartPolicy) tc := &TaskRunner{ config: config, updater: updater, logger: logger, - restartTracker: rt, + restartTracker: restartTracker, ctx: ctx, allocID: allocID, task: task, - restartPolicy: restartPolicy, updateCh: make(chan *structs.Task, 8), destroyCh: make(chan struct{}), waitCh: make(chan struct{}), @@ -189,7 +185,6 @@ func (r *TaskRunner) Run() { for err != nil { r.logger.Printf("[ERR] client: failed to complete task '%s' for alloc '%s': %v", r.task.Name, r.allocID, err) - r.restartTracker.increment() shouldRestart, when := r.restartTracker.nextRestart() if !shouldRestart { r.logger.Printf("[INFO] Not restarting") @@ -198,6 +193,7 @@ func (r *TaskRunner) Run() { } r.logger.Printf("[INFO] Restarting Task: %v", r.task.Name) + r.setStatus(structs.AllocClientStatusPending, "Task Restarting") r.logger.Printf("[DEBUG] Sleeping for %v before restarting Task %v", when, r.task.Name) ch := time.After(when) L: diff --git a/client/task_runner_test.go b/client/task_runner_test.go 
index 3d5199670..7a7242e7b 100644 --- a/client/task_runner_test.go +++ b/client/task_runner_test.go @@ -53,7 +53,8 @@ func testTaskRunner() (*MockTaskStateUpdater, *TaskRunner) { ctx := driver.NewExecContext(allocDir) rp := structs.NewRestartPolicy(structs.JobTypeService) - tr := NewTaskRunner(logger, conf, upd.Update, ctx, alloc.ID, task, structs.JobTypeService, rp) + restartTracker := newRestartTracker(structs.JobTypeService, rp) + tr := NewTaskRunner(logger, conf, upd.Update, ctx, alloc.ID, task, restartTracker) return upd, tr }