From cd8784894d8923527ff9020b8b1a422de8a6311d Mon Sep 17 00:00:00 2001 From: Alex Dadgar Date: Thu, 15 Sep 2016 17:24:09 -0700 Subject: [PATCH] Alloc runner tests --- client/alloc_runner_test.go | 285 +++++++++++++++++++++- client/driver/mock_driver.go | 64 ++++- client/vaultclient/vaultclient_testing.go | 90 +++++++ 3 files changed, 415 insertions(+), 24 deletions(-) create mode 100644 client/vaultclient/vaultclient_testing.go diff --git a/client/alloc_runner_test.go b/client/alloc_runner_test.go index 2f21483fc..1d49e44de 100644 --- a/client/alloc_runner_test.go +++ b/client/alloc_runner_test.go @@ -3,7 +3,9 @@ package client import ( "bufio" "fmt" + "io/ioutil" "os" + "path/filepath" "testing" "time" @@ -36,7 +38,7 @@ func testAllocRunnerFromAlloc(alloc *structs.Allocation, restarts bool) (*MockAl *alloc.Job.LookupTaskGroup(alloc.TaskGroup).RestartPolicy = structs.RestartPolicy{Attempts: 0} alloc.Job.Type = structs.JobTypeBatch } - vclient, _ := vaultclient.NewVaultClient(conf.VaultConfig, logger, nil) + vclient := vaultclient.NewMockVaultClient() ar := NewAllocRunner(logger, conf, upd.Update, alloc, vclient) return upd, ar } @@ -392,13 +394,15 @@ func TestAllocRunner_Update(t *testing.T) { } func TestAllocRunner_SaveRestoreState(t *testing.T) { - ctestutil.ExecCompatible(t) - upd, ar := testAllocRunner(false) + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Driver = "mock_driver" + task.Config = map[string]interface{}{ + "exit_code": "0", + "run_for": "10s", + } - // Ensure task takes some time - task := ar.alloc.Job.TaskGroups[0].Tasks[0] - task.Config["command"] = "/bin/sleep" - task.Config["args"] = []string{"10"} + upd, ar := testAllocRunnerFromAlloc(alloc, false) go ar.Run() // Snapshot state @@ -422,21 +426,36 @@ func TestAllocRunner_SaveRestoreState(t *testing.T) { } go ar2.Run() + testutil.WaitForResult(func() (bool, error) { + if len(ar2.tasks) != 1 { + return false, fmt.Errorf("Incorrect number of tasks") + } + + if upd.Count == 0 { + return false, nil + } + + last := upd.Allocs[upd.Count-1] + return last.ClientStatus == structs.AllocClientStatusRunning, nil + }, func(err error) { + t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates) + }) + // Destroy and wait ar2.Destroy() start := time.Now() testutil.WaitForResult(func() (bool, error) { - if upd.Count == 0 { - return false, nil + alloc := ar2.Alloc() + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("Bad client status; got %v; want %v", alloc.ClientStatus, structs.AllocClientStatusComplete) } - last := upd.Allocs[upd.Count-1] - return last.ClientStatus != structs.AllocClientStatusPending, nil + return true, nil }, func(err error) { t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates) }) - if time.Since(start) > time.Duration(testutil.TestMultiplier()*15)*time.Second { + if time.Since(start) > time.Duration(testutil.TestMultiplier()*5)*time.Second { t.Fatalf("took too long to terminate") } } @@ -599,3 +618,245 @@ func TestAllocRunner_TaskFailed_KillTG(t *testing.T) { t.Fatalf("err: %v", err) }) } + +func TestAllocRunner_SimpleRun_VaultToken(t *testing.T) { + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Driver = "mock_driver" + task.Config = map[string]interface{}{"exit_code": "0"} + task.Vault = &structs.Vault{ + Policies: []string{"default"}, + } + + upd, ar := testAllocRunnerFromAlloc(alloc, false) + go ar.Run() + defer ar.Destroy() + + testutil.WaitForResult(func() (bool, error) { + if upd.Count == 0 { + return false, fmt.Errorf("No updates") + } + last := upd.Allocs[upd.Count-1] + if last.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("got status %v; want %v", last.ClientStatus, structs.AllocClientStatusComplete) + } + return true, nil + }, func(err error) { + t.Fatalf("err: %v", err) + }) + + tr, ok := ar.tasks[task.Name] + if !ok { + t.Fatalf("No task runner made") + } + + // Check that the task runner was given the token + token := tr.vaultToken + if token == "" || tr.vaultRenewalCh == nil { + t.Fatalf("Vault token not set properly") + } + + // Check that it was written to disk + secretDir, err := ar.ctx.AllocDir.GetSecretDir(task.Name) + if err != nil { + t.Fatalf("bad: %v", err) + } + + tokenPath := filepath.Join(secretDir, vaultTokenFile) + data, err := ioutil.ReadFile(tokenPath) + if err != nil { + t.Fatalf("token not written to disk: %v", err) + } + + if string(data) != token { + t.Fatalf("Bad token written to disk") + } + + // Check that we stopped renewing the token + mockVC := ar.vaultClient.(*vaultclient.MockVaultClient) + if len(mockVC.StoppedTokens) != 1 || mockVC.StoppedTokens[0] != token { + t.Fatalf("We didn't stop renewing the token") + } +} + +func TestAllocRunner_SaveRestoreState_VaultTokens_Valid(t *testing.T) { + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Driver = "mock_driver" + task.Config = map[string]interface{}{ + "exit_code": "0", + "run_for": "10s", + } + task.Vault = &structs.Vault{ + Policies: []string{"default"}, + } + + upd, ar := testAllocRunnerFromAlloc(alloc, false) + go ar.Run() + + // Snapshot state + var token string + testutil.WaitForResult(func() (bool, error) { + if len(ar.tasks) != 1 { + return false, fmt.Errorf("Task not started") + } + + tr, ok := ar.tasks[task.Name] + if !ok { + return false, fmt.Errorf("Incorrect task runner") + } + + if tr.vaultToken == "" { + return false, fmt.Errorf("Bad token") + } + + token = tr.vaultToken + return true, nil + }, func(err error) { + t.Fatalf("task never started: %v", err) + }) + + err := ar.SaveState() + if err != nil { + t.Fatalf("err: %v", err) + } + + // Create a new alloc runner + ar2 := NewAllocRunner(ar.logger, ar.config, upd.Update, + &structs.Allocation{ID: ar.alloc.ID}, ar.vaultClient) + err = ar2.RestoreState() + if err != nil { + t.Fatalf("err: %v", err) + } + go ar2.Run() + + testutil.WaitForResult(func() (bool, error) { + if len(ar2.tasks) != 1 { + return false, fmt.Errorf("Incorrect number of tasks") + } + + tr, ok := ar2.tasks[task.Name] + if !ok { + return false, fmt.Errorf("Incorrect task runner") + } + + if tr.vaultToken != token { + return false, fmt.Errorf("Got token %q; want %q", tr.vaultToken, token) + } + + if upd.Count == 0 { + return false, nil + } + + last := upd.Allocs[upd.Count-1] + return last.ClientStatus == structs.AllocClientStatusRunning, nil + }, func(err error) { + t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates) + }) + + // Destroy and wait + ar2.Destroy() + start := time.Now() + + testutil.WaitForResult(func() (bool, error) { + alloc := ar2.Alloc() + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("Bad client status; got %v; want %v", alloc.ClientStatus, structs.AllocClientStatusComplete) + } + return true, nil + }, func(err error) { + t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates) + }) + + if time.Since(start) > time.Duration(testutil.TestMultiplier()*5)*time.Second { + t.Fatalf("took too long to terminate") + } +} + +func TestAllocRunner_SaveRestoreState_VaultTokens_Invalid(t *testing.T) { + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Driver = "mock_driver" + task.Config = map[string]interface{}{ + "exit_code": "0", + "run_for": "10s", + } + task.Vault = &structs.Vault{ + Policies: []string{"default"}, + } + + upd, ar := testAllocRunnerFromAlloc(alloc, false) + go ar.Run() + + // Snapshot state + var token string + testutil.WaitForResult(func() (bool, error) { + if len(ar.tasks) != 1 { + return false, fmt.Errorf("Task not started") + } + + tr, ok := ar.tasks[task.Name] + if !ok { + return false, fmt.Errorf("Incorrect task runner") + } + + if tr.vaultToken == "" { + return false, fmt.Errorf("Bad token") + } + + token = tr.vaultToken + return true, nil + }, func(err error) { + t.Fatalf("task never started: %v", err) + }) + + err := ar.SaveState() + if err != nil { + t.Fatalf("err: %v", err) + } + + // Create a new alloc runner + ar2 := NewAllocRunner(ar.logger, ar.config, upd.Update, + &structs.Allocation{ID: ar.alloc.ID}, ar.vaultClient) + + // Invalidate the token + mockVC := ar2.vaultClient.(*vaultclient.MockVaultClient) + renewErr := fmt.Errorf("Test disallowing renewal") + mockVC.SetRenewTokenError(token, renewErr) + + // Restore and run + err = ar2.RestoreState() + if err != nil { + t.Fatalf("err: %v", err) + } + go ar2.Run() + + testutil.WaitForResult(func() (bool, error) { + if upd.Count == 0 { + return false, nil + } + + last := upd.Allocs[upd.Count-1] + return last.ClientStatus == structs.AllocClientStatusFailed, nil + }, func(err error) { + t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates) + }) + + // Destroy and wait + ar2.Destroy() + start := time.Now() + + testutil.WaitForResult(func() (bool, error) { + alloc := ar2.Alloc() + if alloc.ClientStatus != structs.AllocClientStatusFailed { + return false, fmt.Errorf("Bad client status; got %v; want %v", alloc.ClientStatus, structs.AllocClientStatusFailed) + } + return true, nil + }, func(err error) { + t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates) + }) + + if time.Since(start) > time.Duration(testutil.TestMultiplier()*5)*time.Second { + t.Fatalf("took too long to terminate") + } +} diff --git a/client/driver/mock_driver.go b/client/driver/mock_driver.go index 4e46a3997..5ff2219be 100644 --- a/client/driver/mock_driver.go +++ b/client/driver/mock_driver.go @@ -3,7 +3,9 @@ package driver import ( + "encoding/json" "errors" + "fmt" "log" "time" @@ -90,19 +92,11 @@ func (m *MockDriver) Start(ctx *ExecContext, task *structs.Task) (DriverHandle, return &h, nil } -// TODO implement Open when we need it. -// Open re-connects the driver to the running task -func (m *MockDriver) Open(ctx *ExecContext, handleID string) (DriverHandle, error) { - return nil, nil -} - -// TODO implement Open when we need it. // Validate validates the mock driver configuration func (m *MockDriver) Validate(map[string]interface{}) error { return nil } -// TODO implement Open when we need it. // Fingerprint fingerprints a node and returns if MockDriver is enabled func (m *MockDriver) Fingerprint(cfg *config.Config, node *structs.Node) (bool, error) { node.Attributes["driver.mock_driver"] = "1" @@ -123,12 +117,58 @@ type mockDriverHandle struct { doneCh chan struct{} } -// TODO Implement when we need it. -func (h *mockDriverHandle) ID() string { - return "" +type mockDriverID struct { + TaskName string + RunFor time.Duration + KillAfter time.Duration + KillTimeout time.Duration + ExitCode int + ExitSignal int + ExitErr error +} + +func (h *mockDriverHandle) ID() string { + id := mockDriverID{ + TaskName: h.taskName, + RunFor: h.runFor, + KillAfter: h.killAfter, + KillTimeout: h.killAfter, + ExitCode: h.exitCode, + ExitSignal: h.exitSignal, + ExitErr: h.exitErr, + } + + data, err := json.Marshal(id) + if err != nil { + h.logger.Printf("[ERR] driver.mock_driver: failed to marshal ID to JSON: %s", err) + } + return string(data) +} + +// Open re-connects the driver to the running task +func (m *MockDriver) Open(ctx *ExecContext, handleID string) (DriverHandle, error) { + id := &mockDriverID{} + if err := json.Unmarshal([]byte(handleID), id); err != nil { + return nil, fmt.Errorf("Failed to parse handle '%s': %v", handleID, err) + } + + h := mockDriverHandle{ + taskName: id.TaskName, + runFor: id.RunFor, + killAfter: id.KillAfter, + killTimeout: id.KillTimeout, + exitCode: id.ExitCode, + exitSignal: id.ExitSignal, + exitErr: id.ExitErr, + logger: m.logger, + doneCh: make(chan struct{}), + waitCh: make(chan *dstructs.WaitResult, 1), + } + + go h.run() + return &h, nil } -// TODO Implement when we need it. func (h *mockDriverHandle) WaitCh() chan *dstructs.WaitResult { return h.waitCh } diff --git a/client/vaultclient/vaultclient_testing.go b/client/vaultclient/vaultclient_testing.go new file mode 100644 index 000000000..7f9310068 --- /dev/null +++ b/client/vaultclient/vaultclient_testing.go @@ -0,0 +1,90 @@ +package vaultclient + +import ( + "github.com/hashicorp/nomad/nomad/structs" + vaultapi "github.com/hashicorp/vault/api" +) + +// MockVaultClient is used for testing the vaultclient integration +type MockVaultClient struct { + // StoppedTokens tracks the tokens that have stopped renewing + StoppedTokens []string + + // RenewTokens are the tokens that have been renewed and their error + // channels + RenewTokens map[string]chan error + + // RenewTokenErrors is used to return an error when the RenewToken is called + // with the given token + RenewTokenErrors map[string]error + + // DeriveTokenErrors maps an allocation ID and tasks to an error when the + // token is derived + DeriveTokenErrors map[string]map[string]error +} + +// NewMockVaultClient returns a MockVaultClient for testing +func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} } + +func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) { + tokens := make(map[string]string, len(tasks)) + for _, task := range tasks { + if tasks, ok := vc.DeriveTokenErrors[a.ID]; ok { + if err, ok := tasks[task]; ok { + return nil, err + } + } + + tokens[task] = structs.GenerateUUID() + } + + return tokens, nil +} + +func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) { + if vc.DeriveTokenErrors == nil { + vc.DeriveTokenErrors = make(map[string]map[string]error, 10) + } + + if _, ok := vc.RenewTokenErrors[allocID]; !ok { + vc.DeriveTokenErrors[allocID] = make(map[string]error, 10) + } + + for _, task := range tasks { + vc.DeriveTokenErrors[allocID][task] = err + } +} + +func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) { + if err, ok := vc.RenewTokenErrors[token]; ok { + return nil, err + } + + renewCh := make(chan error) + if vc.RenewTokens == nil { + vc.RenewTokens = make(map[string]chan error, 10) + } + vc.RenewTokens[token] = renewCh + return renewCh, nil +} + +func (vc *MockVaultClient) SetRenewTokenError(token string, err error) { + if vc.RenewTokenErrors == nil { + vc.RenewTokenErrors = make(map[string]error, 10) + } + + vc.RenewTokenErrors[token] = err +} + +func (vc *MockVaultClient) StopRenewToken(token string) error { + vc.StoppedTokens = append(vc.StoppedTokens, token) + return nil +} + +func (vc *MockVaultClient) RenewLease(leaseId string, interval int) (<-chan error, error) { + return nil, nil +} +func (vc *MockVaultClient) StopRenewLease(leaseId string) error { return nil } +func (vc *MockVaultClient) Start() {} +func (vc *MockVaultClient) Stop() {} +func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) { return nil, nil }