diff --git a/client/task_runner.go b/client/task_runner.go index ed9e825c6..55176c0e7 100644 --- a/client/task_runner.go +++ b/client/task_runner.go @@ -553,9 +553,12 @@ func (f *tokenFuture) Get() string { // allows setting the initial Vault token. This is useful when the Vault token // is recovered off disk. func (r *TaskRunner) vaultManager(token string) { - // Always stop renewing the token. If token is empty or untracked, it is a - // no-op so this is always safe. - defer r.vaultClient.StopRenewToken(r.vaultFuture.Get()) + // Helper for stopping token renewal + stopRenewal := func() { + if err := r.vaultClient.StopRenewToken(r.vaultFuture.Get()); err != nil { + r.logger.Printf("[WARN] client: failed to stop token renewal for task %v in alloc %q: %v", r.task.Name, r.alloc.ID, err) + } + } // updatedToken lets us store state between loops. If true, a new token // has been retrieved and we need to apply the Vault change mode @@ -566,6 +569,7 @@ OUTER: // Check if we should exit select { case <-r.waitCh: + stopRenewal() return default: } @@ -643,12 +647,14 @@ OUTER: // Clear the token token = "" r.logger.Printf("[ERR] client: failed to renew Vault token for task %v on alloc %q: %v", r.task.Name, r.alloc.ID, err) + stopRenewal() // Check if we have to do anything if r.task.Vault.ChangeMode != structs.VaultChangeModeNoop { updatedToken = true } case <-r.waitCh: + stopRenewal() return } } diff --git a/client/task_runner_test.go b/client/task_runner_test.go index 8d2dab2bd..6503e88ba 100644 --- a/client/task_runner_test.go +++ b/client/task_runner_test.go @@ -876,6 +876,21 @@ func TestTaskRunner_BlockForVault(t *testing.T) { if act := string(data); act != token { t.Fatalf("Token didn't get written to disk properly, got %q; want %q", act, token) } + + // Check the token was revoked + m := ctx.tr.vaultClient.(*vaultclient.MockVaultClient) + testutil.WaitForResult(func() (bool, error) { + if len(m.StoppedTokens) != 1 { + return false, fmt.Errorf("Expected a stopped token: %v", m.StoppedTokens) + } + + if a := m.StoppedTokens[0]; a != token { + return false, fmt.Errorf("got stopped token %q; want %q", a, token) + } + return true, nil + }, func(err error) { + t.Fatalf("err: %v", err) + }) } func TestTaskRunner_DeriveToken_Retry(t *testing.T) { @@ -946,6 +961,21 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) { if act := string(data); act != token { t.Fatalf("Token didn't get written to disk properly, got %q; want %q", act, token) } + + // Check the token was revoked + m := ctx.tr.vaultClient.(*vaultclient.MockVaultClient) + testutil.WaitForResult(func() (bool, error) { + if len(m.StoppedTokens) != 1 { + return false, fmt.Errorf("Expected a stopped token: %v", m.StoppedTokens) + } + + if a := m.StoppedTokens[0]; a != token { + return false, fmt.Errorf("got stopped token %q; want %q", a, token) + } + return true, nil + }, func(err error) { + t.Fatalf("err: %v", err) + }) } func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) { @@ -1215,6 +1245,21 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) { }, func(err error) { t.Fatalf("err: %v", err) }) + + // Check the token was revoked + m := ctx.tr.vaultClient.(*vaultclient.MockVaultClient) + testutil.WaitForResult(func() (bool, error) { + if len(m.StoppedTokens) != 1 { + return false, fmt.Errorf("Expected a stopped token: %v", m.StoppedTokens) + } + + if a := m.StoppedTokens[0]; a != token { + return false, fmt.Errorf("got stopped token %q; want %q", a, token) + } + return true, nil + }, func(err error) { + t.Fatalf("err: %v", err) + }) } func TestTaskRunner_VaultManager_Restart(t *testing.T) { diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index 4cc5ce476..3b8273563 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -1123,7 +1123,7 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, if rerr, ok := createErr.(*structs.RecoverableError); ok { reply.Error = rerr - } else if err != nil { + } else { reply.Error = structs.NewRecoverableError(createErr, false).(*structs.RecoverableError) }