Files
nomad/client/allocrunner/taskrunner/vault_hook_test.go
Tim Gross 26004c5407 vault: set renew increment to lease duration (#26041)
When we renew Vault tokens, we use the lease duration to determine how often to
renew. But we also set an `increment` value which is never updated from the
initial 30s. For periodic tokens this is not a problem because the `increment`
field is ignored on renewal. But for non-periodic tokens this prevents the token
TTL from being properly incremented. This behavior has been in place since the
initial Vault client implementation in #1606 but before the switch to workload
identity most (all?) tokens being created were periodic tokens so this was never
detected.

Fix this bug by updating the request's `increment` field to the lease duration
on each renewal.

Also switch out a `time.After` call in backoff of the derive token caller with a
safe timer so that we don't have to spawn a new goroutine per loop, and have
tighter control over when that's GC'd.

Ref: https://github.com/hashicorp/nomad/pull/1606
Ref: https://github.com/hashicorp/nomad/issues/25812
2025-06-13 13:50:54 -04:00

693 lines
18 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package taskrunner
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"testing"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client/allocdir"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
trtesting "github.com/hashicorp/nomad/client/allocrunner/taskrunner/testing"
cstate "github.com/hashicorp/nomad/client/state"
"github.com/hashicorp/nomad/client/taskenv"
"github.com/hashicorp/nomad/client/vaultclient"
"github.com/hashicorp/nomad/client/widmgr"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
sconfig "github.com/hashicorp/nomad/nomad/structs/config"
"github.com/shoenig/test/must"
"github.com/shoenig/test/wait"
)
// Compile-time checks that vaultHook satisfies each task runner hook
// interface it is expected to implement.
var (
	_ interfaces.TaskPrestartHook = (*vaultHook)(nil)
	_ interfaces.TaskStopHook     = (*vaultHook)(nil)
	_ interfaces.ShutdownHook     = (*vaultHook)(nil)
)
// vaultTokenUpdaterMock is a test double for the vaultTokenUpdateHandler
// interface that remembers the most recent token handed to it.
type vaultTokenUpdaterMock struct {
	currentToken string
}

// updatedVaultToken records token as the current token, replacing whatever
// value was stored before.
func (m *vaultTokenUpdaterMock) updatedVaultToken(token string) {
	m.currentToken = token
}
// setupTestVaultHook returns a vaultHook built from config, filling in every
// field the caller left unset with a test default. A nil config is treated as
// an empty one.
//
// Defaults applied:
//   - alloc: a minimal mock allocation pointing at a minimal mock job
//   - task: the first task of that mock job, given a single workload identity
//     and a vault block (config.vaultBlock when provided, otherwise one for
//     the default cluster)
//   - vaultBlock: the vault block of the resolved task
//   - vaultConfigsFunc: a single "default" entry with the default Vault config
//   - clientFunc: a new mock Vault client per cluster
//   - logger, events, lifecycle, updater: test/mock implementations
//   - widmgr: a workload identity manager backed by a mock signer and an
//     in-memory state DB; it is started here and the test fails if Run errors
func setupTestVaultHook(t *testing.T, config *vaultHookConfig) *vaultHook {
	t.Helper()

	if config == nil {
		config = &vaultHookConfig{}
	}

	minJob := mock.MinJob()
	if config.alloc == nil {
		config.alloc = mock.MinAlloc()
		config.alloc.Job = minJob
	}

	if config.task == nil {
		// Decide on the vault block and identity name first, then attach both
		// to the mock job's first task.
		vaultBlock := &structs.Vault{
			Cluster: structs.VaultDefaultCluster,
		}
		identityName := "vault_default"
		if config.vaultBlock != nil {
			vaultBlock = config.vaultBlock
			identityName = vaultBlock.IdentityName()
		}

		task := minJob.TaskGroups[0].Tasks[0]
		task.Identities = []*structs.WorkloadIdentity{
			{Name: identityName},
		}
		task.Vault = vaultBlock
		config.task = task
	}

	if config.vaultBlock == nil {
		config.vaultBlock = config.task.Vault
	}

	if config.vaultConfigsFunc == nil {
		config.vaultConfigsFunc = func(hclog.Logger) map[string]*sconfig.VaultConfig {
			defaults := map[string]*sconfig.VaultConfig{}
			defaults["default"] = sconfig.DefaultVaultConfig()
			return defaults
		}
	}

	if config.clientFunc == nil {
		config.clientFunc = func(cluster string) (vaultclient.VaultClient, error) {
			return vaultclient.NewMockVaultClient(cluster)
		}
	}

	if config.logger == nil {
		config.logger = testlog.HCLogger(t)
	}

	if config.events == nil {
		config.events = &trtesting.MockEmitter{}
	}

	if config.lifecycle == nil {
		config.lifecycle = trtesting.NewMockTaskHooks()
	}

	if config.updater == nil {
		config.updater = &vaultTokenUpdaterMock{}
	}

	if config.widmgr == nil {
		// The manager depends on the logger, task and alloc resolved above.
		stateDB := cstate.NewMemDB(config.logger)
		widSigner := widmgr.NewMockWIDSigner(config.task.Identities)
		env := taskenv.NewBuilder(mock.Node(), config.alloc, nil, "global").Build()
		config.widmgr = widmgr.NewWIDMgr(widSigner, config.alloc, stateDB, config.logger, env)
		must.NoError(t, config.widmgr.Run())
	}

	return newVaultHook(config)
}
// TestTaskRunner_VaultHook runs the vault hook's Prestart and Stop against a
// mock Vault client for several task and client configurations. For each case
// it checks that exactly one token was derived for the task's workload
// identity, that the token carries the expected role suffix, that it reached
// the token updater and the token files on disk, and that renewal was started
// and later stopped (or never started, when the case expects no renewal).
func TestTaskRunner_VaultHook(t *testing.T) {
	ci.Parallel(t)

	// newTask builds a task with one workload identity and the given vault
	// block. Every call returns freshly allocated values.
	newTask := func(identity string, vault *structs.Vault) *structs.Task {
		return &structs.Task{
			Vault:      vault,
			Identities: []*structs.WorkloadIdentity{{Name: identity}},
		}
	}

	testCases := []struct {
		name               string
		task               *structs.Task
		configs            map[string]*sconfig.VaultConfig
		configNonrenewable bool
		expectRole         string
		expectNoRenew      bool
	}{
		{
			name: "jwt flow",
			task: newTask("vault_default", &structs.Vault{
				Cluster: structs.VaultDefaultCluster,
			}),
		},
		{
			name: "jwt flow with role",
			task: newTask("vault_default", &structs.Vault{
				Cluster: structs.VaultDefaultCluster,
				Role:    "task-role",
			}),
			configs: map[string]*sconfig.VaultConfig{
				"default": {Role: "client-role"},
			},
			expectRole: "task-role",
		},
		{
			name: "jwt flow with role from client",
			task: newTask("vault_default", &structs.Vault{
				Cluster: structs.VaultDefaultCluster,
			}),
			configs: map[string]*sconfig.VaultConfig{
				"default": {Role: "client-role"},
			},
			expectRole: "client-role",
		},
		{
			name: "jwt flow with role from client and non-default cluster",
			task: newTask("vault_prod", &structs.Vault{
				Cluster: "prod",
			}),
			configs: map[string]*sconfig.VaultConfig{
				"default": {Role: "client-role"},
				"prod":    {Role: "client-prod-role"},
			},
			expectRole: "client-prod-role",
		},
		{
			name: "disable file",
			task: newTask("vault_default", &structs.Vault{
				Cluster:     structs.VaultDefaultCluster,
				DisableFile: true,
			}),
		},
		{
			name: "job requests no renewal",
			task: newTask("vault_default", &structs.Vault{
				Cluster:              structs.VaultDefaultCluster,
				AllowTokenExpiration: true,
			}),
			expectNoRenew: true,
		},
		{
			name: "tokens are not renewable",
			task: newTask("vault_default", &structs.Vault{
				Cluster: structs.VaultDefaultCluster,
			}),
			configNonrenewable: true,
			expectNoRenew:      true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			alloc := mock.MinAlloc()
			alloc.Job.TaskGroups[0].Tasks[0] = tc.task

			hookConfig := &vaultHookConfig{task: tc.task, alloc: alloc}
			hookConfig.vaultConfigsFunc = func(hclog.Logger) map[string]*sconfig.VaultConfig {
				if tc.configs == nil {
					return map[string]*sconfig.VaultConfig{
						"default": sconfig.DefaultVaultConfig(),
					}
				}
				return tc.configs
			}
			if tc.configNonrenewable {
				hookConfig.clientFunc = func(string) (vaultclient.VaultClient, error) {
					nonRenewable := &vaultclient.MockVaultClient{}
					nonRenewable.SetRenewable(false)
					return nonRenewable, nil
				}
			}

			hook := setupTestVaultHook(t, hookConfig)

			// Bound Prestart() so a stuck hook fails the test instead of
			// hanging it.
			ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
			t.Cleanup(cancel)

			var resp interfaces.TaskPrestartResponse
			req := &interfaces.TaskPrestartRequest{
				TaskEnv: taskenv.NewEmptyTaskEnv(),
				TaskDir: &allocdir.TaskDir{
					SecretsDir: t.TempDir(),
					PrivateDir: t.TempDir(),
				},
				Task: tc.task,
			}
			must.NoError(t, hook.Prestart(ctx, req, &resp))
			must.NoError(t, ctx.Err())

			// Exactly one token must have been derived, keyed by the signed
			// identity of this task.
			mockClient := hook.client.(*vaultclient.MockVaultClient)
			derived := mockClient.JWTTokens()
			must.MapLen(t, 1, derived)

			signedID, err := hook.widmgr.Get(structs.WIHandle{
				IdentityName:       tc.task.Vault.IdentityName(),
				WorkloadIdentifier: tc.task.Name,
				WorkloadType:       structs.WorkloadTypeTask,
			})
			must.NoError(t, err)

			token := derived[signedID.JWT]
			must.NotEq(t, "", token)

			// The mock client hands out random UUIDv4 tokens and appends the
			// role when one was requested, so the suffix reveals which role
			// was used.
			if tc.expectRole == "" {
				must.UUIDv4(t, token)
			} else {
				must.StrHasSuffix(t, tc.expectRole, token)
			}

			// The token updater must have received the derived token.
			updater := hook.updater.(*vaultTokenUpdaterMock)
			must.Eq(t, token, updater.currentToken)

			// The private dir always gets a copy of the token; the secrets
			// dir only gets one when the file is not disabled.
			privateCopy, err := os.ReadFile(hook.privateDirTokenPath)
			must.NoError(t, err)
			must.Eq(t, updater.currentToken, string(privateCopy))

			secretsCopy, err := os.ReadFile(hook.secretsDirTokenPath)
			if tc.task.Vault.DisableFile {
				must.ErrorIs(t, err, os.ErrNotExist)
			} else {
				must.NoError(t, err)
				must.Eq(t, updater.currentToken, string(secretsCopy))
			}

			// Renewal must be registered for the token unless the case
			// expects none.
			renewals := mockClient.RenewTokens()
			if tc.expectNoRenew {
				must.MapEmpty(t, renewals)
			} else {
				must.MapLen(t, 1, renewals)
				must.NotNil(t, renewals[updater.currentToken])
			}

			// Done stays false so tokens can be recovered later; firstRun is
			// what guards against running the prestart logic twice.
			must.False(t, resp.Done)
			must.False(t, hook.firstRun)

			// Stopping the hook must stop renewal of the token.
			must.NoError(t, hook.Stop(ctx, nil, nil))

			checkStopped := func() error {
				stopped := mockClient.StoppedTokens()
				if tc.expectNoRenew {
					if n := len(stopped); n != 0 {
						return fmt.Errorf("expected no stopped tokens when renewal is disabled, got %d", n)
					}
					return nil
				}

				if n := len(stopped); n != 1 {
					return fmt.Errorf("expected stopped tokens to be %d, got %d", 1, n)
				}
				if want, got := updater.currentToken, stopped[0]; got != want {
					return fmt.Errorf("expected stopped token to be %s, got %s", want, got)
				}
				return nil
			}
			must.Wait(t, wait.InitialSuccess(
				wait.ErrorFunc(checkStopped),
				wait.Timeout(5*time.Second),
				wait.Gap(100*time.Millisecond),
			))
		})
	}
}
// TestTaskRunner_VaultHook_recover writes a token file into either the
// secrets directory or the private directory before calling Prestart, and
// asserts that the mock Vault client recorded no JWT-derived tokens
// afterwards, i.e. the hook did not derive a new token.
func TestTaskRunner_VaultHook_recover(t *testing.T) {
	ci.Parallel(t)

	testCases := []struct {
		name string
		// setupReq receives the subtest's *testing.T so that temporary
		// directories are owned by, and cleaned up with, the subtest that
		// uses them rather than the parent test.
		setupReq func(t *testing.T) (*interfaces.TaskPrestartRequest, error)
	}{
		{
			name: "recover from secrets dir",
			setupReq: func(t *testing.T) (*interfaces.TaskPrestartRequest, error) {
				// Write token to secrets dir.
				secretsDirPath := t.TempDir()
				err := os.WriteFile(filepath.Join(secretsDirPath, vaultTokenFile), []byte("much secret"), 0666)
				if err != nil {
					return nil, err
				}

				req := &interfaces.TaskPrestartRequest{
					TaskEnv: taskenv.NewEmptyTaskEnv(),
					TaskDir: &allocdir.TaskDir{
						SecretsDir: secretsDirPath,
						PrivateDir: t.TempDir(),
					},
				}
				return req, nil
			},
		},
		{
			name: "recover from private dir",
			setupReq: func(t *testing.T) (*interfaces.TaskPrestartRequest, error) {
				// Write token to private dir.
				privateDirPath := t.TempDir()
				err := os.WriteFile(filepath.Join(privateDirPath, vaultTokenFile), []byte("much secret"), 0666)
				if err != nil {
					return nil, err
				}

				req := &interfaces.TaskPrestartRequest{
					TaskEnv: taskenv.NewEmptyTaskEnv(),
					TaskDir: &allocdir.TaskDir{
						SecretsDir: t.TempDir(),
						PrivateDir: privateDirPath,
					},
				}
				return req, nil
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			hook := setupTestVaultHook(t, nil)

			req, err := tc.setupReq(t)
			must.NoError(t, err)
			req.Task = hook.task

			// Ensure Prestart() returns in a reasonable time.
			ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
			t.Cleanup(cancel)

			var resp interfaces.TaskPrestartResponse
			err = hook.Prestart(ctx, req, &resp)
			must.NoError(t, err)
			must.NoError(t, ctx.Err())

			// Verify no token was derived, which means the one on disk was
			// used instead.
			client := hook.client.(*vaultclient.MockVaultClient)
			must.MapLen(t, 0, client.JWTTokens())
		})
	}
}
// TestTaskRunner_VaultHook_deriveError covers failures while obtaining or
// renewing a token during Prestart:
//
//   - "unrecoverable error": the derive function returns a non-recoverable
//     error; the test expects a kill event whose display message contains the
//     error text.
//   - "recoverable error": the derive function returns a recoverable error
//     and is replaced by a succeeding one after a second; the test expects
//     the updater to end up holding the token "secret".
//   - "renew request failed": the derive function succeeds but the renew
//     request is configured to fail and is cleared shortly after; the test
//     expects the updater to end up holding the token "secret".
func TestTaskRunner_VaultHook_deriveError(t *testing.T) {
	ci.Parallel(t)

	// setup returns a fresh mock Vault client, a hook wired to that client,
	// and a prestart request pointing at empty temporary directories.
	setup := func(t *testing.T) (*vaultclient.MockVaultClient, *vaultHook, *interfaces.TaskPrestartRequest) {
		t.Helper()

		// Assert on the constructor error rather than discarding it, so a
		// failure is reported clearly instead of surfacing as a panic in the
		// type assertion below.
		vaultClient, err := vaultclient.NewMockVaultClient("")
		must.NoError(t, err)
		mockVaultClient := vaultClient.(*vaultclient.MockVaultClient)

		hook := setupTestVaultHook(t, &vaultHookConfig{
			clientFunc: func(string) (vaultclient.VaultClient, error) {
				return mockVaultClient, nil
			},
		})

		req := &interfaces.TaskPrestartRequest{
			TaskEnv: taskenv.NewEmptyTaskEnv(),
			TaskDir: &allocdir.TaskDir{
				SecretsDir: t.TempDir(),
				PrivateDir: t.TempDir(),
			},
			Task: hook.task,
		}
		return mockVaultClient, hook, req
	}

	t.Run("unrecoverable error", func(t *testing.T) {
		mockVaultClient, hook, req := setup(t)

		var resp interfaces.TaskPrestartResponse
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		t.Cleanup(cancel)

		// Set unrecoverable error.
		mockVaultClient.SetDeriveTokenWithJWTFn(
			func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
				// Cancel the context to simulate the task being killed.
				cancel()
				return "", false, 0, structs.NewRecoverableError(errors.New("unrecoverable test error"), false)
			})

		err := hook.Prestart(ctx, req, &resp)
		must.NoError(t, err)

		// Verify task is killed because of unrecoverable error.
		lifecycle := hook.lifecycle.(*trtesting.MockTaskHooks)
		must.Wait(t, wait.InitialSuccess(
			wait.ErrorFunc(func() error {
				if lifecycle.KillEvent() == nil {
					return errors.New("missing kill event")
				}
				return nil
			}),
			wait.Timeout(5*time.Second),
			wait.Gap(100*time.Millisecond),
		))

		killEv := lifecycle.KillEvent()
		must.StrContains(t, killEv.DisplayMessage, "unrecoverable test error")
	})

	t.Run("recoverable error", func(t *testing.T) {
		mockVaultClient, hook, req := setup(t)

		var resp interfaces.TaskPrestartResponse
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		t.Cleanup(cancel)

		// Set recoverable error.
		mockVaultClient.SetDeriveTokenWithJWTFn(
			func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
				return "", false, 0, structs.NewRecoverableError(errors.New("recoverable test error"), true)
			})

		go func() {
			// Wait a bit for the first error then make token derivation
			// succeed.
			time.Sleep(time.Second)
			mockVaultClient.SetDeriveTokenWithJWTFn(
				func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
					return "secret", true, 30, nil
				})
		}()

		err := hook.Prestart(ctx, req, &resp)
		must.NoError(t, err)
		must.NoError(t, ctx.Err())

		// Verify retry happened and token was derived.
		updater := (hook.updater).(*vaultTokenUpdaterMock)
		must.Eq(t, "secret", updater.currentToken)
	})

	t.Run("renew request failed", func(t *testing.T) {
		mockVaultClient, hook, req := setup(t)

		var resp interfaces.TaskPrestartResponse
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		t.Cleanup(cancel)

		// Derive predictable token and fail renew request.
		mockVaultClient.SetDeriveTokenWithJWTFn(
			func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
				return "secret", true, 30, nil
			})
		mockVaultClient.SetRenewTokenError("secret", errors.New("test error"))

		go func() {
			// Wait a bit for the renew error then fix token renewal.
			time.Sleep(10 * time.Millisecond)
			mockVaultClient.SetRenewTokenError("secret", nil)
		}()

		err := hook.Prestart(ctx, req, &resp)
		must.NoError(t, err)
		must.NoError(t, ctx.Err())

		// Verify retry happened and token was derived.
		updater := (hook.updater).(*vaultTokenUpdaterMock)
		must.Eq(t, "secret", updater.currentToken)
	})
}
// TestTaskRunner_VaultHook_tokenRenewalFail injects an error on the renewal
// error channel of a derived token and checks the task lifecycle calls that
// follow for each vault change mode: one SIGTERM signal for "signal", one
// restart for "restart", and neither for "noop".
func TestTaskRunner_VaultHook_tokenRenewalFail(t *testing.T) {
	ci.Parallel(t)

	testCases := []struct {
		name                string
		vaultBlock          *structs.Vault
		verifyTaskLifecycle func(*trtesting.MockTaskHooks) error
	}{
		{
			name: "change mode signal",
			vaultBlock: &structs.Vault{
				Cluster:      structs.VaultDefaultCluster,
				ChangeMode:   structs.VaultChangeModeSignal,
				ChangeSignal: "SIGTERM",
			},
			verifyTaskLifecycle: func(h *trtesting.MockTaskHooks) error {
				signals := h.Signals()
				if len(signals) != 1 {
					return fmt.Errorf("expected 1 signal, got %d", len(signals))
				}

				if signals[0] != "SIGTERM" {
					return fmt.Errorf("expected signal to be SIGTERM, got %s", signals[0])
				}
				return nil
			},
		},
		{
			name: "change mode restart",
			vaultBlock: &structs.Vault{
				Cluster:    structs.VaultDefaultCluster,
				ChangeMode: structs.VaultChangeModeRestart,
			},
			verifyTaskLifecycle: func(h *trtesting.MockTaskHooks) error {
				restarts := h.Restarts()
				if restarts != 1 {
					return fmt.Errorf("expected 1 restart, got %d", restarts)
				}
				return nil
			},
		},
		{
			name: "change mode noop",
			vaultBlock: &structs.Vault{
				Cluster:    structs.VaultDefaultCluster,
				ChangeMode: structs.VaultChangeModeNoop,
			},
			verifyTaskLifecycle: func(h *trtesting.MockTaskHooks) error {
				restarts := h.Restarts()
				if restarts != 0 {
					return fmt.Errorf("expected 0 restarts, got %d", restarts)
				}

				signals := h.Signals()
				if len(signals) != 0 {
					return fmt.Errorf("expected 0 signals, got %d", len(signals))
				}
				return nil
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Assert on the constructor error rather than discarding it, so a
			// failure is reported clearly instead of surfacing as a panic in
			// the type assertion below.
			vaultClient, err := vaultclient.NewMockVaultClient("")
			must.NoError(t, err)
			mockVaultClient := vaultClient.(*vaultclient.MockVaultClient)

			hook := setupTestVaultHook(t, &vaultHookConfig{
				vaultBlock: tc.vaultBlock,
				clientFunc: func(string) (vaultclient.VaultClient, error) {
					return mockVaultClient, nil
				},
			})

			req := &interfaces.TaskPrestartRequest{
				TaskEnv: taskenv.NewEmptyTaskEnv(),
				TaskDir: &allocdir.TaskDir{
					SecretsDir: t.TempDir(),
					PrivateDir: t.TempDir(),
				},
				Task: hook.task,
			}
			var resp interfaces.TaskPrestartResponse

			// Ensure Prestart() returns within a reasonable time.
			ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
			t.Cleanup(cancel)

			err = hook.Prestart(ctx, req, &resp)
			must.NoError(t, err)

			// Fetch derived token.
			updater := (hook.updater).(*vaultTokenUpdaterMock)
			token := updater.currentToken
			must.NotEq(t, "", token)

			// Fetch renewal token error channel.
			renewErrCh := mockVaultClient.RenewTokenErrCh(token)
			must.NotNil(t, renewErrCh)

			// Emit renewal error. The send is bounded so that a hook which is
			// not receiving from the channel fails the test instead of
			// blocking it until the global test timeout.
			select {
			case renewErrCh <- errors.New("renew error"):
			case <-time.After(3 * time.Second):
				t.Fatal("timed out sending renewal error to the hook")
			}

			// Verify expected lifecycle events happen.
			must.Wait(t, wait.InitialSuccess(
				wait.ErrorFunc(func() error {
					return tc.verifyTaskLifecycle((hook.lifecycle).(*trtesting.MockTaskHooks))
				}),
				wait.Timeout(3*time.Second),
				wait.Gap(100*time.Millisecond),
			))
		})
	}
}