From ae288a3ee657fa49652368bdd6bc270e790c82bb Mon Sep 17 00:00:00 2001 From: Alex Dadgar Date: Tue, 18 Oct 2016 11:22:16 -0700 Subject: [PATCH] Tests --- client/restarts.go | 31 +- client/task_runner.go | 13 +- client/task_runner_test.go | 531 +++++++++++++++++++++- client/vaultclient/vaultclient_testing.go | 6 + 4 files changed, 549 insertions(+), 32 deletions(-) diff --git a/client/restarts.go b/client/restarts.go index 1c6a6b843..e80d80172 100644 --- a/client/restarts.go +++ b/client/restarts.go @@ -101,6 +101,19 @@ func (r *RestartTracker) GetState() (string, time.Duration) { r.lock.Lock() defer r.lock.Unlock() + // Clear out the existing state + defer func() { + r.startErr = nil + r.waitRes = nil + r.restartTriggered = false + }() + + // Hot path if a restart was triggered + if r.restartTriggered { + r.reason = "" + return structs.TaskRestarting, 0 + } + // Hot path if no attempts are expected if r.policy.Attempts == 0 { r.reason = ReasonNoRestartsAllowed @@ -121,25 +134,13 @@ func (r *RestartTracker) GetState() (string, time.Duration) { r.startTime = now } - var state string - var dur time.Duration if r.startErr != nil { - state, dur = r.handleStartError() + return r.handleStartError() } else if r.waitRes != nil { - state, dur = r.handleWaitResult() - } else if r.restartTriggered { - state, dur = structs.TaskRestarting, 0 - r.reason = "" - } else { - state, dur = "", 0 + return r.handleWaitResult() } - // Clear out the existing state - r.startErr = nil - r.waitRes = nil - r.restartTriggered = false - - return state, dur + return "", 0 } // handleStartError returns the new state and potential wait duration for diff --git a/client/task_runner.go b/client/task_runner.go index 8385bec02..62c5fbd51 100644 --- a/client/task_runner.go +++ b/client/task_runner.go @@ -378,6 +378,7 @@ func (r *TaskRunner) Run() { // NewTaskRunner if r.task.Vault != nil { // Start the go-routine to get a Vault token + r.vaultFuture.Clear() go r.vaultManager(r.recoveredVaultToken) } @@ 
-578,12 +579,11 @@ func (r *TaskRunner) deriveVaultToken() (string, bool) { for { tokens, err := r.vaultClient.DeriveToken(r.alloc, []string{r.task.Name}) if err != nil { - r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v", r.task.Name, r.alloc.ID, err) - backoff := (1 << (2 * uint64(attempts))) * vaultBackoffBaseline if backoff > vaultBackoffLimit { backoff = vaultBackoffLimit } + r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v; retrying in %v", r.task.Name, r.alloc.ID, err, backoff) attempts++ @@ -591,11 +591,11 @@ func (r *TaskRunner) deriveVaultToken() (string, bool) { select { case <-r.waitCh: return "", false - case <-time.After(backoff * time.Second): + case <-time.After(backoff): } + } else { + return tokens[r.task.Name], true } - - return tokens[r.task.Name], true } } @@ -646,7 +646,9 @@ func (r *TaskRunner) prestart(resultCh chan bool) { if r.task.Vault != nil { // Wait for the token + r.logger.Printf("[DEBUG] client: waiting for Vault token for task %v in alloc %q", r.task.Name, r.alloc.ID) tokenCh := r.vaultFuture.Wait() + r.logger.Printf("[DEBUG] client: retrieved Vault token for task %v in alloc %q", r.task.Name, r.alloc.ID) select { case <-tokenCh: @@ -1174,6 +1176,7 @@ func (r *TaskRunner) UnblockStart(source string) { } r.logger.Printf("[DEBUG] client: unblocking task %v for alloc %q: %v", r.task.Name, r.alloc.ID, source) + r.unblocked = true close(r.unblockCh) } diff --git a/client/task_runner_test.go b/client/task_runner_test.go index ea0265796..819c35799 100644 --- a/client/task_runner_test.go +++ b/client/task_runner_test.go @@ -2,6 +2,7 @@ package client import ( "fmt" + "io/ioutil" "log" "net/http" "net/http/httptest" @@ -221,24 +222,65 @@ func TestTaskRunner_Update(t *testing.T) { } func TestTaskRunner_SaveRestoreState(t *testing.T) { - ctestutil.ExecCompatible(t) - upd, tr := testTaskRunner(false) + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] 
+	task.Driver = "mock_driver"
+	task.Config = map[string]interface{}{
+		"exit_code": "0",
+		"run_for":   "5s",
+	}
 
-	// Change command to ensure we run for a bit
-	tr.task.Config["command"] = "/bin/sleep"
-	tr.task.Config["args"] = []string{"10"}
+	// Give it a Vault token
+	task.Vault = &structs.Vault{Policies: []string{"default"}}
+
+	upd, tr := testTaskRunnerFromAlloc(false, alloc)
+	tr.MarkReceived()
 	go tr.Run()
 	defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled))
 
-	// Snapshot state
-	time.Sleep(2 * time.Second)
+	// Wait for the task to be running and then snapshot the state
+	testutil.WaitForResult(func() (bool, error) {
+		if l := len(upd.events); l != 2 {
+			return false, fmt.Errorf("Expect two events; got %v", l)
+		}
+
+		if upd.events[0].Type != structs.TaskReceived {
+			return false, fmt.Errorf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
+		}
+
+		if upd.events[1].Type != structs.TaskStarted {
+			return false, fmt.Errorf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted)
+		}
+
+		return true, nil
+	}, func(err error) {
+		t.Fatalf("err: %v", err)
+	})
+
 	if err := tr.SaveState(); err != nil {
 		t.Fatalf("err: %v", err)
 	}
 
+	// Read the token from the file system
+	secretDir, err := tr.ctx.AllocDir.GetSecretDir(task.Name)
+	if err != nil {
+		t.Fatalf("failed to determine task %s secret dir: %v", task.Name, err)
+	}
+
+	tokenPath := filepath.Join(secretDir, vaultTokenFile)
+	data, err := ioutil.ReadFile(tokenPath)
+	if err != nil {
+		t.Fatalf("Failed to read file: %v", err)
+	}
+	token := string(data)
+	if len(token) == 0 {
+		t.Fatalf("Token not written to disk")
+	}
+
 	// Create a new task runner
 	tr2 := NewTaskRunner(tr.logger, tr.config, upd.Update, tr.ctx, tr.alloc,
 		&structs.Task{Name: tr.task.Name}, tr.vaultClient)
+	tr2.restartTracker = noRestartsTracker()
 	if err := tr2.RestoreState(); err != nil {
 		t.Fatalf("err: %v", err)
 	}
@@ -246,11 +288,16 @@ func TestTaskRunner_SaveRestoreState(t *testing.T) {
 	defer 
tr2.Destroy(structs.NewTaskEvent(structs.TaskKilled)) // Destroy and wait - testutil.WaitForResult(func() (bool, error) { - return tr2.handle != nil, fmt.Errorf("RestoreState() didn't open handle") - }, func(err error) { - t.Fatalf("err: %v", err) - }) + select { + case <-tr2.WaitCh(): + case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second): + t.Fatalf("timeout") + } + + // Check that we recovered the token + if act := tr2.vaultFuture.Get(); act != token { + t.Fatalf("Vault token not properly recovered") + } } func TestTaskRunner_Download_List(t *testing.T) { @@ -558,3 +605,463 @@ func TestTaskRunner_SignalFailure(t *testing.T) { t.Fatalf("Didn't receive error") } } + +func TestTaskRunner_BlockForVault(t *testing.T) { + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Driver = "mock_driver" + task.Config = map[string]interface{}{ + "exit_code": "0", + "run_for": "1s", + } + task.Vault = &structs.Vault{Policies: []string{"default"}} + + upd, tr := testTaskRunnerFromAlloc(false, alloc) + tr.MarkReceived() + defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled)) + defer tr.ctx.AllocDir.Destroy() + + // Control when we get a Vault token + token := "1234" + waitCh := make(chan struct{}) + handler := func(*structs.Allocation, []string) (map[string]string, error) { + <-waitCh + return map[string]string{task.Name: token}, nil + } + tr.vaultClient.(*vaultclient.MockVaultClient).DeriveTokenFn = handler + + go tr.Run() + + select { + case <-tr.WaitCh(): + t.Fatalf("premature exit") + case <-time.After(1 * time.Second): + } + + if len(upd.events) != 1 { + t.Fatalf("should have 1 updates: %#v", upd.events) + } + + if upd.state != structs.TaskStatePending { + t.Fatalf("TaskState %v; want %v", upd.state, structs.TaskStatePending) + } + + if upd.events[0].Type != structs.TaskReceived { + t.Fatalf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived) + } + + // Unblock + close(waitCh) + + select { + case 
<-tr.WaitCh():
+	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
+		t.Fatalf("timeout")
+	}
+
+	if len(upd.events) != 3 {
+		t.Fatalf("should have 3 updates: %#v", upd.events)
+	}
+
+	if upd.state != structs.TaskStateDead {
+		t.Fatalf("TaskState %v; want %v", upd.state, structs.TaskStateDead)
+	}
+
+	if upd.events[0].Type != structs.TaskReceived {
+		t.Fatalf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
+	}
+
+	if upd.events[1].Type != structs.TaskStarted {
+		t.Fatalf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted)
+	}
+
+	if upd.events[2].Type != structs.TaskTerminated {
+		t.Fatalf("Third Event was %v; want %v", upd.events[2].Type, structs.TaskTerminated)
+	}
+
+	// Check that the token is on disk
+	secretDir, err := tr.ctx.AllocDir.GetSecretDir(task.Name)
+	if err != nil {
+		t.Fatalf("failed to determine task %s secret dir: %v", task.Name, err)
+	}
+
+	// Read the token from the file system
+	tokenPath := filepath.Join(secretDir, vaultTokenFile)
+	data, err := ioutil.ReadFile(tokenPath)
+	if err != nil {
+		t.Fatalf("Failed to read file: %v", err)
+	}
+
+	if act := string(data); act != token {
+		t.Fatalf("Token didn't get written to disk properly, got %q; want %q", act, token)
+	}
+}
+
+func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
+	alloc := mock.Alloc()
+	task := alloc.Job.TaskGroups[0].Tasks[0]
+	task.Driver = "mock_driver"
+	task.Config = map[string]interface{}{
+		"exit_code": "0",
+		"run_for":   "1s",
+	}
+	task.Vault = &structs.Vault{Policies: []string{"default"}}
+
+	upd, tr := testTaskRunnerFromAlloc(false, alloc)
+	tr.MarkReceived()
+	defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled))
+	defer tr.ctx.AllocDir.Destroy()
+
+	// Control when we get a Vault token
+	token := "1234"
+	count := 0
+	handler := func(*structs.Allocation, []string) (map[string]string, error) {
+		if count > 0 {
+			return map[string]string{task.Name: token}, nil
+		}
+
+		count++
+		return nil, 
fmt.Errorf("Want a retry")
+	}
+	tr.vaultClient.(*vaultclient.MockVaultClient).DeriveTokenFn = handler
+	go tr.Run()
+
+	select {
+	case <-tr.WaitCh():
+	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
+		t.Fatalf("timeout")
+	}
+
+	if len(upd.events) != 3 {
+		t.Fatalf("should have 3 updates: %#v", upd.events)
+	}
+
+	if upd.state != structs.TaskStateDead {
+		t.Fatalf("TaskState %v; want %v", upd.state, structs.TaskStateDead)
+	}
+
+	if upd.events[0].Type != structs.TaskReceived {
+		t.Fatalf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
+	}
+
+	if upd.events[1].Type != structs.TaskStarted {
+		t.Fatalf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted)
+	}
+
+	if upd.events[2].Type != structs.TaskTerminated {
+		t.Fatalf("Third Event was %v; want %v", upd.events[2].Type, structs.TaskTerminated)
+	}
+
+	// Check that the token is on disk
+	secretDir, err := tr.ctx.AllocDir.GetSecretDir(task.Name)
+	if err != nil {
+		t.Fatalf("failed to determine task %s secret dir: %v", task.Name, err)
+	}
+
+	// Read the token from the file system
+	tokenPath := filepath.Join(secretDir, vaultTokenFile)
+	data, err := ioutil.ReadFile(tokenPath)
+	if err != nil {
+		t.Fatalf("Failed to read file: %v", err)
+	}
+
+	if act := string(data); act != token {
+		t.Fatalf("Token didn't get written to disk properly, got %q; want %q", act, token)
+	}
+}
+
+func TestTaskRunner_Template_Block(t *testing.T) {
+	alloc := mock.Alloc()
+	task := alloc.Job.TaskGroups[0].Tasks[0]
+	task.Driver = "mock_driver"
+	task.Config = map[string]interface{}{
+		"exit_code": "0",
+		"run_for":   "1s",
+	}
+	task.Templates = []*structs.Template{
+		{
+			EmbeddedTmpl: "{{key \"foo\"}}",
+			DestPath:     "local/test",
+			ChangeMode:   structs.TemplateChangeModeNoop,
+		},
+	}
+
+	upd, tr := testTaskRunnerFromAlloc(false, alloc)
+	tr.MarkReceived()
+	defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled))
+	defer tr.ctx.AllocDir.Destroy()
+
+	go tr.Run()
+
+	select 
{ + case <-tr.WaitCh(): + t.Fatalf("premature exit") + case <-time.After(1 * time.Second): + } + + if len(upd.events) != 1 { + t.Fatalf("should have 1 updates: %#v", upd.events) + } + + if upd.state != structs.TaskStatePending { + t.Fatalf("TaskState %v; want %v", upd.state, structs.TaskStatePending) + } + + if upd.events[0].Type != structs.TaskReceived { + t.Fatalf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived) + } + + // Unblock + tr.UnblockStart("test") + + select { + case <-tr.WaitCh(): + case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second): + t.Fatalf("timeout") + } + + if len(upd.events) != 3 { + t.Fatalf("should have 3 updates: %#v", upd.events) + } + + if upd.state != structs.TaskStateDead { + t.Fatalf("TaskState %v; want %v", upd.state, structs.TaskStateDead) + } + + if upd.events[0].Type != structs.TaskReceived { + t.Fatalf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived) + } + + if upd.events[1].Type != structs.TaskStarted { + t.Fatalf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted) + } + + if upd.events[2].Type != structs.TaskTerminated { + t.Fatalf("Third Event was %v; want %v", upd.events[2].Type, structs.TaskTerminated) + } +} + +func TestTaskRunner_Template_NewVaultToken(t *testing.T) { + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Driver = "mock_driver" + task.Config = map[string]interface{}{ + "exit_code": "0", + "run_for": "1s", + } + task.Templates = []*structs.Template{ + { + EmbeddedTmpl: "{{key \"foo\"}}", + DestPath: "local/test", + ChangeMode: structs.TemplateChangeModeNoop, + }, + } + task.Vault = &structs.Vault{Policies: []string{"default"}} + + _, tr := testTaskRunnerFromAlloc(false, alloc) + tr.MarkReceived() + defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled)) + defer tr.ctx.AllocDir.Destroy() + go tr.Run() + + // Wait for a Vault token + var token string + testutil.WaitForResult(func() (bool, error) { 
+ if token = tr.vaultFuture.Get(); token == "" { + return false, fmt.Errorf("No Vault token") + } + + return true, nil + }, func(err error) { + t.Fatalf("err: %v", err) + }) + + // Error the token renewal + vc := tr.vaultClient.(*vaultclient.MockVaultClient) + renewalCh, ok := vc.RenewTokens[token] + if !ok { + t.Fatalf("no renewal channel") + } + + originalManager := tr.templateManager + + renewalCh <- fmt.Errorf("Test killing") + close(renewalCh) + + // Wait for a new Vault token + var token2 string + testutil.WaitForResult(func() (bool, error) { + if token2 = tr.vaultFuture.Get(); token2 == "" || token2 == token { + return false, fmt.Errorf("No new Vault token") + } + + if originalManager == tr.templateManager { + return false, fmt.Errorf("Template manager not updated") + } + + return true, nil + }, func(err error) { + t.Fatalf("err: %v", err) + }) +} + +func TestTaskRunner_VaultManager_Restart(t *testing.T) { + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Driver = "mock_driver" + task.Config = map[string]interface{}{ + "exit_code": "0", + "run_for": "10s", + } + task.Vault = &structs.Vault{ + Policies: []string{"default"}, + ChangeMode: structs.VaultChangeModeRestart, + } + + upd, tr := testTaskRunnerFromAlloc(false, alloc) + tr.MarkReceived() + defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled)) + defer tr.ctx.AllocDir.Destroy() + go tr.Run() + + // Wait for the task to start + testutil.WaitForResult(func() (bool, error) { + if l := len(upd.events); l != 2 { + return false, fmt.Errorf("Expect two events; got %v", l) + } + + if upd.events[0].Type != structs.TaskReceived { + return false, fmt.Errorf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived) + } + + if upd.events[1].Type != structs.TaskStarted { + return false, fmt.Errorf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted) + } + + return true, nil + }, func(err error) { + t.Fatalf("err: %v", err) + }) + + // Error the token 
renewal + vc := tr.vaultClient.(*vaultclient.MockVaultClient) + renewalCh, ok := vc.RenewTokens[tr.vaultFuture.Get()] + if !ok { + t.Fatalf("no renewal channel") + } + + renewalCh <- fmt.Errorf("Test killing") + close(renewalCh) + + // Ensure a restart + testutil.WaitForResult(func() (bool, error) { + if l := len(upd.events); l != 7 { + return false, fmt.Errorf("Expect seven events; got %#v", upd.events) + } + + if upd.events[0].Type != structs.TaskReceived { + return false, fmt.Errorf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived) + } + + if upd.events[1].Type != structs.TaskStarted { + return false, fmt.Errorf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted) + } + + if upd.events[2].Type != structs.TaskRestartSignal { + return false, fmt.Errorf("Third Event was %v; want %v", upd.events[2].Type, structs.TaskRestartSignal) + } + + if upd.events[3].Type != structs.TaskKilling { + return false, fmt.Errorf("Fourth Event was %v; want %v", upd.events[3].Type, structs.TaskKilling) + } + + if upd.events[4].Type != structs.TaskKilled { + return false, fmt.Errorf("Fifth Event was %v; want %v", upd.events[4].Type, structs.TaskKilled) + } + + if upd.events[5].Type != structs.TaskRestarting { + return false, fmt.Errorf("Sixth Event was %v; want %v", upd.events[5].Type, structs.TaskRestarting) + } + + if upd.events[6].Type != structs.TaskStarted { + return false, fmt.Errorf("Seventh Event was %v; want %v", upd.events[6].Type, structs.TaskStarted) + } + + return true, nil + }, func(err error) { + t.Fatalf("err: %v", err) + }) +} + +func TestTaskRunner_VaultManager_Signal(t *testing.T) { + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Driver = "mock_driver" + task.Config = map[string]interface{}{ + "exit_code": "0", + "run_for": "10s", + } + task.Vault = &structs.Vault{ + Policies: []string{"default"}, + ChangeMode: structs.VaultChangeModeSignal, + ChangeSignal: "SIGUSR1", + } + + upd, tr := 
testTaskRunnerFromAlloc(false, alloc) + tr.MarkReceived() + defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled)) + defer tr.ctx.AllocDir.Destroy() + go tr.Run() + + // Wait for the task to start + testutil.WaitForResult(func() (bool, error) { + if l := len(upd.events); l != 2 { + return false, fmt.Errorf("Expect two events; got %v", l) + } + + if upd.events[0].Type != structs.TaskReceived { + return false, fmt.Errorf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived) + } + + if upd.events[1].Type != structs.TaskStarted { + return false, fmt.Errorf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted) + } + + return true, nil + }, func(err error) { + t.Fatalf("err: %v", err) + }) + + // Error the token renewal + vc := tr.vaultClient.(*vaultclient.MockVaultClient) + renewalCh, ok := vc.RenewTokens[tr.vaultFuture.Get()] + if !ok { + t.Fatalf("no renewal channel") + } + + renewalCh <- fmt.Errorf("Test killing") + close(renewalCh) + + // Ensure a restart + testutil.WaitForResult(func() (bool, error) { + if l := len(upd.events); l != 3 { + return false, fmt.Errorf("Expect three events; got %#v", upd.events) + } + + if upd.events[0].Type != structs.TaskReceived { + return false, fmt.Errorf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived) + } + + if upd.events[1].Type != structs.TaskStarted { + return false, fmt.Errorf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted) + } + + if upd.events[2].Type != structs.TaskSignaling { + return false, fmt.Errorf("Third Event was %v; want %v", upd.events[2].Type, structs.TaskSignaling) + } + + return true, nil + }, func(err error) { + t.Fatalf("err: %v", err) + }) +} diff --git a/client/vaultclient/vaultclient_testing.go b/client/vaultclient/vaultclient_testing.go index 7f9310068..fc935b184 100644 --- a/client/vaultclient/vaultclient_testing.go +++ b/client/vaultclient/vaultclient_testing.go @@ -21,12 +21,18 @@ type MockVaultClient struct { // 
DeriveTokenErrors maps an allocation ID and tasks to an error when the // token is derived DeriveTokenErrors map[string]map[string]error + + DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error) } // NewMockVaultClient returns a MockVaultClient for testing func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} } func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) { + if vc.DeriveTokenFn != nil { + return vc.DeriveTokenFn(a, tasks) + } + tokens := make(map[string]string, len(tasks)) for _, task := range tasks { if tasks, ok := vc.DeriveTokenErrors[a.ID]; ok {