Clean up vault client

This commit is contained in:
Alex Dadgar
2016-09-14 15:04:25 -07:00
parent bec6adb2ee
commit c89fd0eb08
6 changed files with 84 additions and 162 deletions

View File

@@ -69,6 +69,7 @@ type AllocRunner struct {
updateCh chan *structs.Allocation
vaultClient vaultclient.VaultClient
vaultTokens map[string]vaultToken
destroy bool
destroyCh chan struct{}
@@ -141,7 +142,7 @@ func (r *AllocRunner) RestoreState() error {
}
// Recover the Vault tokens
tokens, vaultErr := r.recoverVaultTokens()
vaultErr := r.recoverVaultTokens()
// Restore the task runners
var mErr multierror.Error
@@ -154,7 +155,7 @@ func (r *AllocRunner) RestoreState() error {
task)
r.tasks[name] = tr
if vt, ok := tokens[name]; ok {
if vt, ok := r.vaultTokens[name]; ok {
tr.SetVaultToken(vt.token, vt.renewalCh)
}
@@ -357,17 +358,26 @@ func (r *AllocRunner) setTaskState(taskName, state string, event *structs.TaskEv
taskState.State = state
r.appendTaskEvent(taskState, event)
// If the task failed, we should kill all the other tasks in the task group.
if state == structs.TaskStateDead && taskState.Failed() {
var destroyingTasks []string
for task, tr := range r.tasks {
if task != taskName {
destroyingTasks = append(destroyingTasks, task)
tr.Destroy(structs.NewTaskEvent(structs.TaskSiblingFailed).SetFailedSibling(taskName))
if state == structs.TaskStateDead {
// If the task has a Vault token, stop renewing it
if vt, ok := r.vaultTokens[taskName]; ok {
if err := r.vaultClient.StopRenewToken(vt.token); err != nil {
r.logger.Printf("[ERR] client: stopping token renewal for task %q failed: %v", taskName, err)
}
}
if len(destroyingTasks) > 0 {
r.logger.Printf("[DEBUG] client: task %q failed, destroying other tasks in task group: %v", taskName, destroyingTasks)
// If the task failed, we should kill all the other tasks in the task group.
if taskState.Failed() {
var destroyingTasks []string
for task, tr := range r.tasks {
if task != taskName {
destroyingTasks = append(destroyingTasks, task)
tr.Destroy(structs.NewTaskEvent(structs.TaskSiblingFailed).SetFailedSibling(taskName))
}
}
if len(destroyingTasks) > 0 {
r.logger.Printf("[DEBUG] client: task %q failed, destroying other tasks in task group: %v", taskName, destroyingTasks)
}
}
}
@@ -433,7 +443,7 @@ func (r *AllocRunner) Run() {
}
// Request Vault tokens for the tasks that require them
tokens, err := r.deriveVaultTokens()
err := r.deriveVaultTokens()
if err != nil {
msg := fmt.Sprintf("failed to derive Vault token for allocation %q: %v", r.alloc.ID, err)
r.logger.Printf("[ERR] client: %s", msg)
@@ -454,7 +464,7 @@ func (r *AllocRunner) Run() {
tr.MarkReceived()
// If the task has a vault token set it before running
if vt, ok := tokens[task.Name]; ok {
if vt, ok := r.vaultTokens[task.Name]; ok {
tr.SetVaultToken(vt.token, vt.renewalCh)
}
@@ -537,19 +547,14 @@ type vaultToken struct {
// tasks to their respective vault token and renewal channel. This must be
// called after the allocation directory is created as the vault tokens are
// written to disk.
func (r *AllocRunner) deriveVaultTokens() (map[string]vaultToken, error) {
func (r *AllocRunner) deriveVaultTokens() error {
required, err := r.tasksRequiringVaultTokens()
if err != nil {
return nil, err
return err
}
if len(required) == 0 {
return nil, nil
}
// TODO Remove once the vault client isn't nil
if r.vaultClient == nil {
return nil, fmt.Errorf("Requesting Vault tokens when not enabled on the client")
return nil
}
renewingTokens := make(map[string]vaultToken, len(required))
@@ -557,7 +562,7 @@ func (r *AllocRunner) deriveVaultTokens() (map[string]vaultToken, error) {
// Get the tokens
tokens, err := r.vaultClient.DeriveToken(r.Alloc(), required)
if err != nil {
return nil, fmt.Errorf("failed to derive Vault tokens: %v", err)
return fmt.Errorf("failed to derive Vault tokens: %v", err)
}
// Persist the tokens to the appropriate secret directories
@@ -565,17 +570,17 @@ func (r *AllocRunner) deriveVaultTokens() (map[string]vaultToken, error) {
for task, token := range tokens {
secretDir, err := adir.GetSecretDir(task)
if err != nil {
return nil, fmt.Errorf("failed to determine task %s secret dir in alloc %q: %v", task, r.alloc.ID, err)
return fmt.Errorf("failed to determine task %s secret dir in alloc %q: %v", task, r.alloc.ID, err)
}
// Write the token to the file system
tokenPath := filepath.Join(secretDir, vaultTokenFile)
if err := ioutil.WriteFile(tokenPath, []byte(token), 0777); err != nil {
return nil, fmt.Errorf("failed to save Vault tokens to secret dir for task %q in alloc %q: %v", task, r.alloc.ID, err)
return fmt.Errorf("failed to save Vault tokens to secret dir for task %q in alloc %q: %v", task, r.alloc.ID, err)
}
// Start renewing the token
err, renewCh := r.vaultClient.RenewToken(token, 10)
renewCh, err := r.vaultClient.RenewToken(token, 10)
if err != nil {
var mErr multierror.Error
errMsg := fmt.Errorf("failed to renew Vault token for task %q in alloc %q: %v", task, r.alloc.ID, err)
@@ -588,12 +593,13 @@ func (r *AllocRunner) deriveVaultTokens() (map[string]vaultToken, error) {
}
}
return nil, mErr.ErrorOrNil()
return mErr.ErrorOrNil()
}
renewingTokens[task] = vaultToken{token: token, renewalCh: renewCh}
}
return renewingTokens, nil
r.vaultTokens = renewingTokens
return nil
}
func (r *AllocRunner) tasksRequiringVaultTokens() ([]string, error) {
@@ -617,19 +623,14 @@ func (r *AllocRunner) tasksRequiringVaultTokens() ([]string, error) {
// recoverVaultTokens reads the Vault tokens for the tasks that have Vault
// tokens off disk. If there is an error, it is returned, otherwise token
// renewal is started.
func (r *AllocRunner) recoverVaultTokens() (map[string]vaultToken, error) {
// TODO remove once the vault client is never nil
if r.vaultClient == nil {
return nil, nil
}
func (r *AllocRunner) recoverVaultTokens() error {
required, err := r.tasksRequiringVaultTokens()
if err != nil {
return nil, err
return err
}
if len(required) == 0 {
return nil, nil
return nil
}
// Read the tokens and start renewing them
@@ -638,18 +639,18 @@ func (r *AllocRunner) recoverVaultTokens() (map[string]vaultToken, error) {
for _, task := range required {
secretDir, err := adir.GetSecretDir(task)
if err != nil {
return nil, fmt.Errorf("failed to determine task %s secret dir in alloc %q: %v", task, r.alloc.ID, err)
return fmt.Errorf("failed to determine task %s secret dir in alloc %q: %v", task, r.alloc.ID, err)
}
// Read the token from the file system
tokenPath := filepath.Join(secretDir, vaultTokenFile)
data, err := ioutil.ReadFile(tokenPath)
if err != nil {
return nil, fmt.Errorf("failed to read token for task %q in alloc %q: %v", task, r.alloc.ID, err)
return fmt.Errorf("failed to read token for task %q in alloc %q: %v", task, r.alloc.ID, err)
}
token := string(data)
err, renewCh := r.vaultClient.RenewToken(token, 10)
renewCh, err := r.vaultClient.RenewToken(token, 10)
if err != nil {
var mErr multierror.Error
errMsg := fmt.Errorf("failed to renew Vault token for task %q in alloc %q: %v", task, r.alloc.ID, err)
@@ -662,13 +663,14 @@ func (r *AllocRunner) recoverVaultTokens() (map[string]vaultToken, error) {
}
}
return nil, mErr.ErrorOrNil()
return mErr.ErrorOrNil()
}
renewingTokens[task] = vaultToken{token: token, renewalCh: renewCh}
}
return renewingTokens, nil
r.vaultTokens = renewingTokens
return nil
}
// checkResources monitors and enforces alloc resource usage. It returns an

View File

@@ -13,6 +13,7 @@ import (
"github.com/hashicorp/nomad/client/config"
ctestutil "github.com/hashicorp/nomad/client/testutil"
"github.com/hashicorp/nomad/client/vaultclient"
)
type MockAllocStateUpdater struct {
@@ -35,7 +36,8 @@ func testAllocRunnerFromAlloc(alloc *structs.Allocation, restarts bool) (*MockAl
*alloc.Job.LookupTaskGroup(alloc.TaskGroup).RestartPolicy = structs.RestartPolicy{Attempts: 0}
alloc.Job.Type = structs.JobTypeBatch
}
ar := NewAllocRunner(logger, conf, upd.Update, alloc)
vclient, _ := vaultclient.NewVaultClient(conf.VaultConfig, logger, nil)
ar := NewAllocRunner(logger, conf, upd.Update, alloc, vclient)
return upd, ar
}
@@ -413,7 +415,7 @@ func TestAllocRunner_SaveRestoreState(t *testing.T) {
// Create a new alloc runner
ar2 := NewAllocRunner(ar.logger, ar.config, upd.Update,
&structs.Allocation{ID: ar.alloc.ID})
&structs.Allocation{ID: ar.alloc.ID}, ar.vaultClient)
err = ar2.RestoreState()
if err != nil {
t.Fatalf("err: %v", err)
@@ -485,7 +487,7 @@ func TestAllocRunner_SaveRestoreState_TerminalAlloc(t *testing.T) {
// Create a new alloc runner
ar2 := NewAllocRunner(ar.logger, ar.config, upd.Update,
&structs.Allocation{ID: ar.alloc.ID})
&structs.Allocation{ID: ar.alloc.ID}, ar.vaultClient)
ar2.logger = prefixedTestLogger("ar2: ")
err = ar2.RestoreState()
if err != nil {
@@ -576,7 +578,10 @@ func TestAllocRunner_TaskFailed_KillTG(t *testing.T) {
if state1.State != structs.TaskStateDead {
return false, fmt.Errorf("got state %v; want %v", state1.State, structs.TaskStateDead)
}
if lastE := state1.Events[len(state1.Events)-1]; lastE.Type != structs.TaskSiblingFailed {
if len(state1.Events) < 3 {
return false, fmt.Errorf("Unexpected number of events")
}
if lastE := state1.Events[len(state1.Events)-3]; lastE.Type != structs.TaskSiblingFailed {
return false, fmt.Errorf("got last event %v; want %v", lastE.Type, structs.TaskSiblingFailed)
}

View File

@@ -1293,16 +1293,6 @@ func (c *Client) addAlloc(alloc *structs.Allocation) error {
// setupVaultClient creates an object to periodically renew tokens and secrets
// with vault.
func (c *Client) setupVaultClient() error {
// TODO Want the vault client to always be valid. Should just return an
// error if it is not enabled
if c.config.VaultConfig == nil {
return fmt.Errorf("nil vault config")
}
if !c.config.VaultConfig.Enabled {
return nil
}
var err error
if c.vaultClient, err =
vaultclient.NewVaultClient(c.config.VaultConfig, c.logger, c.deriveToken); err != nil {

View File

@@ -410,7 +410,7 @@ func (r *TaskRunner) run() {
case <-r.destroyCh:
// Store the task event that provides context on the task destroy.
if r.destroyEvent.Type != structs.TaskKilled {
r.setState(structs.TaskStateDead, r.destroyEvent)
r.setState(structs.TaskStateRunning, r.destroyEvent)
}
// Mark that we received the kill event

View File

@@ -38,7 +38,7 @@ type VaultClient interface {
// RenewToken renews a token with the given increment and adds it to
// the min-heap for periodic renewal.
RenewToken(string, int) (error, <-chan error)
RenewToken(string, int) (<-chan error, error)
// StopRenewToken removes the token from the min-heap, stopping its
// renewal.
@@ -46,7 +46,7 @@ type VaultClient interface {
// RenewLease renews a vault secret's lease and adds the lease
// identifier to the min-heap for periodic renewal.
RenewLease(string, int) (error, <-chan error)
RenewLease(string, int) (<-chan error, error)
// StopRenewLease removes a secret's lease ID from the min-heap,
// stopping its renewal.
@@ -65,10 +65,6 @@ type vaultClient struct {
// running indicates if the renewal loop is active or not
running bool
// connEstablished marks whether the connection to vault was successful
// or not
connEstablished bool
// tokenData is the data of the passed VaultClient token
token *tokenData
@@ -145,10 +141,6 @@ func NewVaultClient(config *config.VaultConfig, logger *log.Logger, tokenDeriver
return nil, fmt.Errorf("nil vault config")
}
if !config.Enabled {
return nil, nil
}
if logger == nil {
return nil, fmt.Errorf("nil logger")
}
@@ -163,6 +155,10 @@ func NewVaultClient(config *config.VaultConfig, logger *log.Logger, tokenDeriver
tokenDeriver: tokenDeriver,
}
if !config.Enabled {
return c, nil
}
// Get the Vault API configuration
apiConf, err := config.ApiConfig()
if err != nil {
@@ -208,52 +204,7 @@ func (c *vaultClient) Start() {
return
}
c.logger.Printf("[DEBUG] client.vault: establishing connection to vault")
go c.establishConnection()
}
// ConnectionEstablished indicates whether VaultClient successfully established
// connection to vault or not
func (c *vaultClient) ConnectionEstablished() bool {
c.lock.RLock()
defer c.lock.RUnlock()
return c.connEstablished
}
// establishConnection is used to make first contact with Vault. This should be
// called in a go-routine since the connection is retried till the Vault Client
// is stopped or the connection is successfully made at which point the renew
// loop is started.
func (c *vaultClient) establishConnection() {
// Create the retry timer and set initial duration to zero so it fires
// immediately
retryTimer := time.NewTimer(0)
OUTER:
for {
select {
case <-c.stopCh:
return
case <-retryTimer.C:
// Ensure the API is reachable
if _, err := c.client.Sys().InitStatus(); err != nil {
c.logger.Printf("[WARN] client.vault: failed to contact Vault API. Retrying in %v: %v",
c.config.ConnectionRetryIntv, err)
retryTimer.Reset(c.config.ConnectionRetryIntv)
continue OUTER
}
break OUTER
}
}
c.lock.Lock()
c.connEstablished = true
c.lock.Unlock()
// Begin the renewal loop
go c.run()
c.logger.Printf("[DEBUG] client.vault: started")
}
// Stops the renewal loop of vault client
@@ -274,6 +225,9 @@ func (c *vaultClient) Stop() {
// The return value is a map containing all the unwrapped tokens indexed by the
// task name.
func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) (map[string]string, error) {
if !c.config.Enabled {
return nil, fmt.Errorf("vault client not enabled")
}
if !c.running {
return nil, fmt.Errorf("vault client is not running")
}
@@ -284,6 +238,9 @@ func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string)
// GetConsulACL creates a vault API client and reads from vault a consul ACL
// token used by the task.
func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error) {
if !c.config.Enabled {
return nil, fmt.Errorf("vault client not enabled")
}
if token == "" {
return nil, fmt.Errorf("missing token")
}
@@ -291,10 +248,6 @@ func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error)
return nil, fmt.Errorf("missing consul ACL token vault path")
}
if !c.ConnectionEstablished() {
return nil, fmt.Errorf("connection with vault is not yet established")
}
c.lock.Lock()
defer c.lock.Unlock()
@@ -315,14 +268,14 @@ func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error)
// the caller be notified of a renewal failure asynchronously for appropriate
// actions to be taken. The caller of this function does not need to close the
// error channel.
func (c *vaultClient) RenewToken(token string, increment int) (error, <-chan error) {
func (c *vaultClient) RenewToken(token string, increment int) (<-chan error, error) {
if token == "" {
err := fmt.Errorf("missing token")
return err, nil
return nil, err
}
if increment < 1 {
err := fmt.Errorf("increment cannot be less than 1")
return err, nil
return nil, err
}
// Create a buffered error channel
@@ -341,10 +294,10 @@ func (c *vaultClient) RenewToken(token string, increment int) (error, <-chan err
// error channel.
if err := c.renew(renewalReq); err != nil {
c.logger.Printf("[ERR] client.vault: renewal of token failed: %v", err)
return err, nil
return nil, err
}
return nil, errCh
return errCh, nil
}
// RenewLease renews the supplied lease identifier for a supplied duration (in
@@ -354,17 +307,15 @@ func (c *vaultClient) RenewToken(token string, increment int) (error, <-chan err
// This helps the caller be notified of a renewal failure asynchronously for
// appropriate actions to be taken. The caller of this function does not need
// to close the error channel.
func (c *vaultClient) RenewLease(leaseId string, increment int) (error, <-chan error) {
c.logger.Printf("[DEBUG] client.vault: renewing lease %q", leaseId)
func (c *vaultClient) RenewLease(leaseId string, increment int) (<-chan error, error) {
if leaseId == "" {
err := fmt.Errorf("missing lease ID")
return err, nil
return nil, err
}
if increment < 1 {
err := fmt.Errorf("increment cannot be less than 1")
return err, nil
return nil, err
}
// Create a buffered error channel
@@ -380,10 +331,10 @@ func (c *vaultClient) RenewLease(leaseId string, increment int) (error, <-chan e
// Renew the secret and send any error to the dedicated error channel
if err := c.renew(renewalReq); err != nil {
c.logger.Printf("[ERR] client.vault: renewal of lease failed: %v", err)
return err, nil
return nil, err
}
return nil, errCh
return errCh, nil
}
// renew is a common method to handle renewal of both tokens and secret leases.
@@ -395,6 +346,9 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
c.lock.Lock()
defer c.lock.Unlock()
if !c.config.Enabled {
return fmt.Errorf("vault client not enabled")
}
if !c.running {
return fmt.Errorf("vault client is not running")
}

View File

@@ -11,38 +11,6 @@ import (
vaultapi "github.com/hashicorp/vault/api"
)
func TestVaultClient_EstablishConnection(t *testing.T) {
v := testutil.NewTestVault(t)
logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags)
v.Config.ConnectionRetryIntv = 100 * time.Millisecond
v.Config.TaskTokenTTL = "10s"
c, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
c.Start()
defer c.Stop()
// Sleep a little while and check that no connection has been established.
time.Sleep(100 * time.Duration(testutil.TestMultiplier()) * time.Millisecond)
if c.ConnectionEstablished() {
t.Fatalf("ConnectionEstablished() returned true before Vault server started")
}
// Start Vault
v.Start()
defer v.Stop()
testutil.WaitForResult(func() (bool, error) {
return c.ConnectionEstablished(), nil
}, func(err error) {
t.Fatalf("Connection not established")
})
}
func TestVaultClient_TokenRenewals(t *testing.T) {
v := testutil.NewTestVault(t).Start()
defer v.Stop()
@@ -89,12 +57,15 @@ func TestVaultClient_TokenRenewals(t *testing.T) {
tokens[i] = secret.Auth.ClientToken
errCh := c.RenewToken(tokens[i], secret.Auth.LeaseDuration)
errCh, err := c.RenewToken(tokens[i], secret.Auth.LeaseDuration)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
go func(errCh <-chan error) {
var err error
for {
select {
case err = <-errCh:
case err := <-errCh:
t.Fatalf("error while renewing the token: %v", err)
}
}
@@ -105,7 +76,7 @@ func TestVaultClient_TokenRenewals(t *testing.T) {
t.Fatalf("bad: heap length: expected: %d, actual: %d", num, c.heap.Length())
}
time.Sleep(5 * time.Second)
time.Sleep(time.Duration(5*testutil.TestMultiplier()) * time.Second)
for i := 0; i < num; i++ {
if err := c.StopRenewToken(tokens[i]); err != nil {