From c89fd0eb089db13f82d3262edd2eedb10c85533e Mon Sep 17 00:00:00 2001 From: Alex Dadgar Date: Wed, 14 Sep 2016 15:04:25 -0700 Subject: [PATCH] Clean up vault client --- client/alloc_runner.go | 82 +++++++++++----------- client/alloc_runner_test.go | 13 ++-- client/client.go | 10 --- client/task_runner.go | 2 +- client/vaultclient/vaultclient.go | 96 +++++++------------------- client/vaultclient/vaultclient_test.go | 43 ++---------- 6 files changed, 84 insertions(+), 162 deletions(-) diff --git a/client/alloc_runner.go b/client/alloc_runner.go index a657ef785..dda23cdef 100644 --- a/client/alloc_runner.go +++ b/client/alloc_runner.go @@ -69,6 +69,7 @@ type AllocRunner struct { updateCh chan *structs.Allocation vaultClient vaultclient.VaultClient + vaultTokens map[string]vaultToken destroy bool destroyCh chan struct{} @@ -141,7 +142,7 @@ func (r *AllocRunner) RestoreState() error { } // Recover the Vault tokens - tokens, vaultErr := r.recoverVaultTokens() + vaultErr := r.recoverVaultTokens() // Restore the task runners var mErr multierror.Error @@ -154,7 +155,7 @@ func (r *AllocRunner) RestoreState() error { task) r.tasks[name] = tr - if vt, ok := tokens[name]; ok { + if vt, ok := r.vaultTokens[name]; ok { tr.SetVaultToken(vt.token, vt.renewalCh) } @@ -357,17 +358,26 @@ func (r *AllocRunner) setTaskState(taskName, state string, event *structs.TaskEv taskState.State = state r.appendTaskEvent(taskState, event) - // If the task failed, we should kill all the other tasks in the task group. - if state == structs.TaskStateDead && taskState.Failed() { - var destroyingTasks []string - for task, tr := range r.tasks { - if task != taskName { - destroyingTasks = append(destroyingTasks, task) - tr.Destroy(structs.NewTaskEvent(structs.TaskSiblingFailed).SetFailedSibling(taskName)) + if state == structs.TaskStateDead { + // If the task has a Vault token, stop renewing it + if vt, ok := r.vaultTokens[taskName]; ok { + if err := r.vaultClient.StopRenewToken(vt.token); err != nil { + r.logger.Printf("[ERR] client: stopping token renewal for task %q failed: %v", taskName, err) } } - if len(destroyingTasks) > 0 { - r.logger.Printf("[DEBUG] client: task %q failed, destroying other tasks in task group: %v", taskName, destroyingTasks) + + // If the task failed, we should kill all the other tasks in the task group. + if taskState.Failed() { + var destroyingTasks []string + for task, tr := range r.tasks { + if task != taskName { + destroyingTasks = append(destroyingTasks, task) + tr.Destroy(structs.NewTaskEvent(structs.TaskSiblingFailed).SetFailedSibling(taskName)) + } + } + if len(destroyingTasks) > 0 { + r.logger.Printf("[DEBUG] client: task %q failed, destroying other tasks in task group: %v", taskName, destroyingTasks) + } } } @@ -433,7 +443,7 @@ func (r *AllocRunner) Run() { } // Request Vault tokens for the tasks that require them - tokens, err := r.deriveVaultTokens() + err := r.deriveVaultTokens() if err != nil { msg := fmt.Sprintf("failed to derive Vault token for allocation %q: %v", r.alloc.ID, err) r.logger.Printf("[ERR] client: %s", msg) @@ -454,7 +464,7 @@ func (r *AllocRunner) Run() { tr.MarkReceived() // If the task has a vault token set it before running - if vt, ok := tokens[task.Name]; ok { + if vt, ok := r.vaultTokens[task.Name]; ok { tr.SetVaultToken(vt.token, vt.renewalCh) } @@ -537,19 +547,14 @@ type vaultToken struct { // tasks to their respective vault token and renewal channel. This must be // called after the allocation directory is created as the vault tokens are // written to disk. -func (r *AllocRunner) deriveVaultTokens() (map[string]vaultToken, error) { +func (r *AllocRunner) deriveVaultTokens() error { required, err := r.tasksRequiringVaultTokens() if err != nil { - return nil, err + return err } if len(required) == 0 { - return nil, nil - } - - // TODO Remove once the vault client isn't nil - if r.vaultClient == nil { - return nil, fmt.Errorf("Requesting Vault tokens when not enabled on the client") + return nil } renewingTokens := make(map[string]vaultToken, len(required)) @@ -557,7 +562,7 @@ func (r *AllocRunner) deriveVaultTokens() (map[string]vaultToken, error) { // Get the tokens tokens, err := r.vaultClient.DeriveToken(r.Alloc(), required) if err != nil { - return nil, fmt.Errorf("failed to derive Vault tokens: %v", err) + return fmt.Errorf("failed to derive Vault tokens: %v", err) } // Persist the tokens to the appropriate secret directories @@ -565,17 +570,17 @@ func (r *AllocRunner) deriveVaultTokens() (map[string]vaultToken, error) { for task, token := range tokens { secretDir, err := adir.GetSecretDir(task) if err != nil { - return nil, fmt.Errorf("failed to determine task %s secret dir in alloc %q: %v", task, r.alloc.ID, err) + return fmt.Errorf("failed to determine task %s secret dir in alloc %q: %v", task, r.alloc.ID, err) } // Write the token to the file system tokenPath := filepath.Join(secretDir, vaultTokenFile) if err := ioutil.WriteFile(tokenPath, []byte(token), 0777); err != nil { - return nil, fmt.Errorf("failed to save Vault tokens to secret dir for task %q in alloc %q: %v", task, r.alloc.ID, err) + return fmt.Errorf("failed to save Vault tokens to secret dir for task %q in alloc %q: %v", task, r.alloc.ID, err) } // Start renewing the token - err, renewCh := r.vaultClient.RenewToken(token, 10) + renewCh, err := r.vaultClient.RenewToken(token, 10) if err != nil { var mErr multierror.Error errMsg := fmt.Errorf("failed to renew Vault token for task %q in alloc %q: %v", task, r.alloc.ID, err) @@ -588,12 +593,13 @@ func (r *AllocRunner) deriveVaultTokens() (map[string]vaultToken, error) { } } - return nil, mErr.ErrorOrNil() + return mErr.ErrorOrNil() } renewingTokens[task] = vaultToken{token: token, renewalCh: renewCh} } - return renewingTokens, nil + r.vaultTokens = renewingTokens + return nil } func (r *AllocRunner) tasksRequiringVaultTokens() ([]string, error) { @@ -617,19 +623,14 @@ func (r *AllocRunner) tasksRequiringVaultTokens() ([]string, error) { // recoverVaultTokens reads the Vault tokens for the tasks that have Vault // tokens off disk. If there is an error, it is returned, otherwise token // renewal is started. -func (r *AllocRunner) recoverVaultTokens() (map[string]vaultToken, error) { - // TODO remove once the vault client is never nil - if r.vaultClient == nil { - return nil, nil - } - +func (r *AllocRunner) recoverVaultTokens() error { required, err := r.tasksRequiringVaultTokens() if err != nil { - return nil, err + return err } if len(required) == 0 { - return nil, nil + return nil } // Read the tokens and start renewing them @@ -638,18 +639,18 @@ func (r *AllocRunner) recoverVaultTokens() (map[string]vaultToken, error) { for _, task := range required { secretDir, err := adir.GetSecretDir(task) if err != nil { - return nil, fmt.Errorf("failed to determine task %s secret dir in alloc %q: %v", task, r.alloc.ID, err) + return fmt.Errorf("failed to determine task %s secret dir in alloc %q: %v", task, r.alloc.ID, err) } // Write the token to the file system tokenPath := filepath.Join(secretDir, vaultTokenFile) data, err := ioutil.ReadFile(tokenPath) if err != nil { - return nil, fmt.Errorf("failed to read token for task %q in alloc %q: %v", task, r.alloc.ID, err) + return fmt.Errorf("failed to read token for task %q in alloc %q: %v", task, r.alloc.ID, err) } token := string(data) - err, renewCh := r.vaultClient.RenewToken(token, 10) + renewCh, err := r.vaultClient.RenewToken(token, 10) if err != nil { var mErr multierror.Error errMsg := fmt.Errorf("failed to renew Vault token for task %q in alloc %q: %v", task, r.alloc.ID, err) @@ -662,13 +663,14 @@ func (r *AllocRunner) recoverVaultTokens() (map[string]vaultToken, error) { } } - return nil, mErr.ErrorOrNil() + return mErr.ErrorOrNil() } renewingTokens[task] = vaultToken{token: token, renewalCh: renewCh} } - return renewingTokens, nil + r.vaultTokens = renewingTokens + return nil } // checkResources monitors and enforces alloc resource usage. It returns an diff --git a/client/alloc_runner_test.go b/client/alloc_runner_test.go index 7e72fae2a..201ac3892 100644 --- a/client/alloc_runner_test.go +++ b/client/alloc_runner_test.go @@ -13,6 +13,7 @@ import ( "github.com/hashicorp/nomad/client/config" ctestutil "github.com/hashicorp/nomad/client/testutil" + "github.com/hashicorp/nomad/client/vaultclient" ) type MockAllocStateUpdater struct { @@ -35,7 +36,8 @@ func testAllocRunnerFromAlloc(alloc *structs.Allocation, restarts bool) (*MockAl *alloc.Job.LookupTaskGroup(alloc.TaskGroup).RestartPolicy = structs.RestartPolicy{Attempts: 0} alloc.Job.Type = structs.JobTypeBatch } - ar := NewAllocRunner(logger, conf, upd.Update, alloc) + vclient, _ := vaultclient.NewVaultClient(conf.VaultConfig, logger, nil) + ar := NewAllocRunner(logger, conf, upd.Update, alloc, vclient) return upd, ar } @@ -413,7 +415,7 @@ func TestAllocRunner_SaveRestoreState(t *testing.T) { // Create a new alloc runner ar2 := NewAllocRunner(ar.logger, ar.config, upd.Update, - &structs.Allocation{ID: ar.alloc.ID}) + &structs.Allocation{ID: ar.alloc.ID}, ar.vaultClient) err = ar2.RestoreState() if err != nil { t.Fatalf("err: %v", err) @@ -485,7 +487,7 @@ func TestAllocRunner_SaveRestoreState_TerminalAlloc(t *testing.T) { // Create a new alloc runner ar2 := NewAllocRunner(ar.logger, ar.config, upd.Update, - &structs.Allocation{ID: ar.alloc.ID}) + &structs.Allocation{ID: ar.alloc.ID}, ar.vaultClient) ar2.logger = prefixedTestLogger("ar2: ") err = ar2.RestoreState() if err != nil { @@ -576,7 +578,10 @@ func TestAllocRunner_TaskFailed_KillTG(t *testing.T) { if state1.State != structs.TaskStateDead { return false, fmt.Errorf("got state %v; want %v", state1.State, structs.TaskStateDead) } - if lastE := state1.Events[len(state1.Events)-1]; lastE.Type != structs.TaskSiblingFailed { + if len(state1.Events) < 3 { + return false, fmt.Errorf("Unexpected number of events") + } + if lastE := state1.Events[len(state1.Events)-3]; lastE.Type != structs.TaskSiblingFailed { return false, fmt.Errorf("got last event %v; want %v", lastE.Type, structs.TaskSiblingFailed) } diff --git a/client/client.go b/client/client.go index 202ca0ba0..c3d894db6 100644 --- a/client/client.go +++ b/client/client.go @@ -1293,16 +1293,6 @@ func (c *Client) addAlloc(alloc *structs.Allocation) error { // setupVaultClient creates an object to periodically renew tokens and secrets // with vault. func (c *Client) setupVaultClient() error { - // TODO Want the vault client to always be valid. Should just return an - // error if it is not enabled - if c.config.VaultConfig == nil { - return fmt.Errorf("nil vault config") - } - - if !c.config.VaultConfig.Enabled { - return nil - } - var err error if c.vaultClient, err = vaultclient.NewVaultClient(c.config.VaultConfig, c.logger, c.deriveToken); err != nil { diff --git a/client/task_runner.go b/client/task_runner.go index 906beff18..a8938be2c 100644 --- a/client/task_runner.go +++ b/client/task_runner.go @@ -410,7 +410,7 @@ func (r *TaskRunner) run() { case <-r.destroyCh: // Store the task event that provides context on the task destroy. if r.destroyEvent.Type != structs.TaskKilled { - r.setState(structs.TaskStateDead, r.destroyEvent) + r.setState(structs.TaskStateRunning, r.destroyEvent) } // Mark that we received the kill event diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index fc6d89996..16e90722b 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -38,7 +38,7 @@ type VaultClient interface { // RenewToken renews a token with the given increment and adds it to // the min-heap for periodic renewal. - RenewToken(string, int) (error, <-chan error) + RenewToken(string, int) (<-chan error, error) // StopRenewToken removes the token from the min-heap, stopping its // renewal. @@ -46,7 +46,7 @@ type VaultClient interface { // RenewLease renews a vault secret's lease and adds the lease // identifier to the min-heap for periodic renewal. - RenewLease(string, int) (error, <-chan error) + RenewLease(string, int) (<-chan error, error) // StopRenewLease removes a secret's lease ID from the min-heap, // stopping its renewal. @@ -65,10 +65,6 @@ type vaultClient struct { // running indicates if the renewal loop is active or not running bool - // connEstablished marks whether the connection to vault was successful - // or not - connEstablished bool - // tokenData is the data of the passed VaultClient token token *tokenData @@ -145,10 +141,6 @@ func NewVaultClient(config *config.VaultConfig, logger *log.Logger, tokenDeriver return nil, fmt.Errorf("nil vault config") } - if !config.Enabled { - return nil, nil - } - if logger == nil { return nil, fmt.Errorf("nil logger") } @@ -163,6 +155,10 @@ func NewVaultClient(config *config.VaultConfig, logger *log.Logger, tokenDeriver tokenDeriver: tokenDeriver, } + if !config.Enabled { + return c, nil + } + // Get the Vault API configuration apiConf, err := config.ApiConfig() if err != nil { @@ -208,52 +204,7 @@ func (c *vaultClient) Start() { return } - c.logger.Printf("[DEBUG] client.vault: establishing connection to vault") - go c.establishConnection() -} - -// ConnectionEstablished indicates whether VaultClient successfully established -// connection to vault or not -func (c *vaultClient) ConnectionEstablished() bool { - c.lock.RLock() - defer c.lock.RUnlock() - return c.connEstablished -} - -// establishConnection is used to make first contact with Vault. This should be -// called in a go-routine since the connection is retried till the Vault Client -// is stopped or the connection is successfully made at which point the renew -// loop is started. -func (c *vaultClient) establishConnection() { - // Create the retry timer and set initial duration to zero so it fires - // immediately - retryTimer := time.NewTimer(0) - -OUTER: - for { - select { - case <-c.stopCh: - return - case <-retryTimer.C: - // Ensure the API is reachable - if _, err := c.client.Sys().InitStatus(); err != nil { - c.logger.Printf("[WARN] client.vault: failed to contact Vault API. Retrying in %v: %v", - c.config.ConnectionRetryIntv, err) - retryTimer.Reset(c.config.ConnectionRetryIntv) - continue OUTER - } - - break OUTER - } - } - - c.lock.Lock() - c.connEstablished = true - c.lock.Unlock() - - // Begin the renewal loop go c.run() - c.logger.Printf("[DEBUG] client.vault: started") } // Stops the renewal loop of vault client @@ -274,6 +225,9 @@ func (c *vaultClient) Stop() { // The return value is a map containing all the unwrapped tokens indexed by the // task name. func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) (map[string]string, error) { + if !c.config.Enabled { + return nil, fmt.Errorf("vault client not enabled") + } if !c.running { return nil, fmt.Errorf("vault client is not running") } @@ -284,6 +238,9 @@ func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) // GetConsulACL creates a vault API client and reads from vault a consul ACL // token used by the task. func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error) { + if !c.config.Enabled { + return nil, fmt.Errorf("vault client not enabled") + } if token == "" { return nil, fmt.Errorf("missing token") } @@ -291,10 +248,6 @@ func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error) return nil, fmt.Errorf("missing consul ACL token vault path") } - if !c.ConnectionEstablished() { - return nil, fmt.Errorf("connection with vault is not yet established") - } - c.lock.Lock() defer c.lock.Unlock() @@ -315,14 +268,14 @@ func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error) // the caller be notified of a renewal failure asynchronously for appropriate // actions to be taken. The caller of this function need not have to close the // error channel. -func (c *vaultClient) RenewToken(token string, increment int) (error, <-chan error) { +func (c *vaultClient) RenewToken(token string, increment int) (<-chan error, error) { if token == "" { err := fmt.Errorf("missing token") - return err, nil + return nil, err } if increment < 1 { err := fmt.Errorf("increment cannot be less than 1") - return err, nil + return nil, err } // Create a buffered error channel @@ -341,10 +294,10 @@ func (c *vaultClient) RenewToken(token string, increment int) (error, <-chan err // error channel. if err := c.renew(renewalReq); err != nil { c.logger.Printf("[ERR] client.vault: renewal of token failed: %v", err) - return err, nil + return nil, err } - return nil, errCh + return errCh, nil } // RenewLease renews the supplied lease identifier for a supplied duration (in @@ -354,17 +307,15 @@ func (c *vaultClient) RenewToken(token string, increment int) (error, <-chan err // This helps the caller be notified of a renewal failure asynchronously for // appropriate actions to be taken. The caller of this function need not have // to close the error channel. -func (c *vaultClient) RenewLease(leaseId string, increment int) (error, <-chan error) { - c.logger.Printf("[DEBUG] client.vault: renewing lease %q", leaseId) - +func (c *vaultClient) RenewLease(leaseId string, increment int) (<-chan error, error) { if leaseId == "" { err := fmt.Errorf("missing lease ID") - return err, nil + return nil, err } if increment < 1 { err := fmt.Errorf("increment cannot be less than 1") - return err, nil + return nil, err } // Create a buffered error channel @@ -380,10 +331,10 @@ func (c *vaultClient) RenewLease(leaseId string, increment int) (error, <-chan e // Renew the secret and send any error to the dedicated error channel if err := c.renew(renewalReq); err != nil { c.logger.Printf("[ERR] client.vault: renewal of lease failed: %v", err) - return err, nil + return nil, err } - return nil, errCh + return errCh, nil } // renew is a common method to handle renewal of both tokens and secret leases. @@ -395,6 +346,9 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { c.lock.Lock() defer c.lock.Unlock() + if !c.config.Enabled { + return fmt.Errorf("vault client not enabled") + } if !c.running { return fmt.Errorf("vault client is not running") } diff --git a/client/vaultclient/vaultclient_test.go b/client/vaultclient/vaultclient_test.go index 3ff1b128b..e55fd35bf 100644 --- a/client/vaultclient/vaultclient_test.go +++ b/client/vaultclient/vaultclient_test.go @@ -11,38 +11,6 @@ import ( vaultapi "github.com/hashicorp/vault/api" ) -func TestVaultClient_EstablishConnection(t *testing.T) { - v := testutil.NewTestVault(t) - - logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) - v.Config.ConnectionRetryIntv = 100 * time.Millisecond - v.Config.TaskTokenTTL = "10s" - c, err := NewVaultClient(v.Config, logger, nil) - if err != nil { - t.Fatalf("failed to build vault client: %v", err) - } - - c.Start() - defer c.Stop() - - // Sleep a little while and check that no connection has been established. - time.Sleep(100 * time.Duration(testutil.TestMultiplier()) * time.Millisecond) - - if c.ConnectionEstablished() { - t.Fatalf("ConnectionEstablished() returned true before Vault server started") - } - - // Start Vault - v.Start() - defer v.Stop() - - testutil.WaitForResult(func() (bool, error) { - return c.ConnectionEstablished(), nil - }, func(err error) { - t.Fatalf("Connection not established") - }) -} - func TestVaultClient_TokenRenewals(t *testing.T) { v := testutil.NewTestVault(t).Start() defer v.Stop() @@ -89,12 +57,15 @@ func TestVaultClient_TokenRenewals(t *testing.T) { tokens[i] = secret.Auth.ClientToken - errCh := c.RenewToken(tokens[i], secret.Auth.LeaseDuration) + errCh, err := c.RenewToken(tokens[i], secret.Auth.LeaseDuration) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + go func(errCh <-chan error) { - var err error for { select { - case err = <-errCh: + case err := <-errCh: t.Fatalf("error while renewing the token: %v", err) } } @@ -105,7 +76,7 @@ func TestVaultClient_TokenRenewals(t *testing.T) { t.Fatalf("bad: heap length: expected: %d, actual: %d", num, c.heap.Length()) } - time.Sleep(5 * time.Second) + time.Sleep(time.Duration(5*testutil.TestMultiplier()) * time.Second) for i := 0; i < num; i++ { if err := c.StopRenewToken(tokens[i]); err != nil {