diff --git a/client/alloc_runner.go b/client/alloc_runner.go index 114d0c8d5..c434ef65f 100644 --- a/client/alloc_runner.go +++ b/client/alloc_runner.go @@ -102,11 +102,10 @@ func (r *AllocRunner) RestoreState() error { r.ctx = snap.Context // Restore the task runners - jobType := r.alloc.Job.Type var mErr multierror.Error for name := range r.taskStatus { task := &structs.Task{Name: name} - restartTracker := newRestartTracker(jobType, r.RestartPolicy) + restartTracker := newRestartTracker(r.alloc.Job.Type, r.RestartPolicy) tr := NewTaskRunner(r.logger, r.config, r.setTaskStatus, r.ctx, r.alloc.ID, task, restartTracker) r.tasks[name] = tr if err := tr.RestoreState(); err != nil { @@ -309,8 +308,7 @@ func (r *AllocRunner) Run() { // Merge in the task resources task.Resources = alloc.TaskResources[task.Name] - jobType := r.alloc.Job.Type - restartTracker := newRestartTracker(jobType, r.RestartPolicy) + restartTracker := newRestartTracker(r.alloc.Job.Type, r.RestartPolicy) tr := NewTaskRunner(r.logger, r.config, r.setTaskStatus, r.ctx, r.alloc.ID, task, restartTracker) r.tasks[task.Name] = tr go tr.Run() diff --git a/client/restarts.go b/client/restarts.go index b06b3f179..4141405f8 100644 --- a/client/restarts.go +++ b/client/restarts.go @@ -42,8 +42,8 @@ func (b *batchRestartTracker) increment() { } func (b *batchRestartTracker) nextRestart() (bool, time.Duration) { - defer b.increment() if b.count < b.maxAttempts { + b.increment() return true, b.delay } return false, 0 diff --git a/client/restarts_test.go b/client/restarts_test.go index a200d3beb..e27f10390 100644 --- a/client/restarts_test.go +++ b/client/restarts_test.go @@ -36,20 +36,21 @@ func TestTaskRunner_ServiceRestartCounter(t *testing.T) { } func TestTaskRunner_BatchRestartCounter(t *testing.T) { - rt := newRestartTracker(structs.JobTypeBatch, &structs.RestartPolicy{Attempts: 2, Interval: 1 * time.Second, Delay: 1 * time.Second}) - rt.nextRestart() - rt.nextRestart() - rt.nextRestart() - rt.nextRestart() - rt.nextRestart() + attempts := 2 + interval := 1 * time.Second + delay := 1 * time.Second + rt := newRestartTracker(structs.JobTypeBatch, &structs.RestartPolicy{Attempts: attempts, Interval: interval, Delay: delay}) + for i := 0; i < attempts; i++ { + shouldRestart, when := rt.nextRestart() + if !shouldRestart { + t.Fatalf("should restart returned %v, actual %v", shouldRestart, true) + } + if when != delay { + t.Fatalf("Delay should be %v, actual: %v", delay, when) + } + } actual, _ := rt.nextRestart() if actual { t.Fatalf("Expect %v, Actual: %v", false, actual) } - - time.Sleep(1 * time.Second) - actual, _ = rt.nextRestart() - if actual { - t.Fatalf("Expect %v, Actual: %v", false, actual) - } } diff --git a/client/task_runner.go b/client/task_runner.go index b868968df..a59c72fb8 100644 --- a/client/task_runner.go +++ b/client/task_runner.go @@ -195,15 +195,9 @@ func (r *TaskRunner) Run() { r.logger.Printf("[INFO] client: Restarting Task: %v", r.task.Name) r.setStatus(structs.AllocClientStatusPending, "Task Restarting") r.logger.Printf("[DEBUG] client: Sleeping for %v before restarting Task %v", when, r.task.Name) - ch := time.After(when) - L: - for { - select { - case <-ch: - break L - case <-r.destroyCh: - break L - } + select { + case <-time.After(when): + case <-r.destroyCh: } r.destroyLock.Lock() if r.destroy {