mirror of
https://github.com/kemko/nomad.git
synced 2026-01-01 16:05:42 +03:00
When we renew Vault tokens, we use the lease duration to determine how often to renew. But we also set an `increment` value which is never updated from the initial 30s. For periodic tokens this is not a problem because the `increment` field is ignored on renewal. But for non-periodic tokens this prevents the token TTL from being properly incremented. This behavior has been in place since the initial Vault client implementation in #1606 but before the switch to workload identity most (all?) tokens being created were periodic tokens so this was never detected. Fix this bug by updating the request's `increment` field to the lease duration on each renewal. Also switch out a `time.After` call in backoff of the derive token caller with a safe timer so that we don't have to spawn a new goroutine per loop, and have tighter control over when that's GC'd. Ref: https://github.com/hashicorp/nomad/pull/1606 Ref: https://github.com/hashicorp/nomad/issues/25812
693 lines
18 KiB
Go
693 lines
18 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package taskrunner
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/nomad/ci"
|
|
"github.com/hashicorp/nomad/client/allocdir"
|
|
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
|
|
trtesting "github.com/hashicorp/nomad/client/allocrunner/taskrunner/testing"
|
|
cstate "github.com/hashicorp/nomad/client/state"
|
|
"github.com/hashicorp/nomad/client/taskenv"
|
|
"github.com/hashicorp/nomad/client/vaultclient"
|
|
"github.com/hashicorp/nomad/client/widmgr"
|
|
"github.com/hashicorp/nomad/helper/testlog"
|
|
"github.com/hashicorp/nomad/nomad/mock"
|
|
"github.com/hashicorp/nomad/nomad/structs"
|
|
sconfig "github.com/hashicorp/nomad/nomad/structs/config"
|
|
"github.com/shoenig/test/must"
|
|
"github.com/shoenig/test/wait"
|
|
)
|
|
|
|
// Statically assert the stats hook implements the expected interfaces
|
|
var _ interfaces.TaskPrestartHook = (*vaultHook)(nil)
|
|
var _ interfaces.TaskStopHook = (*vaultHook)(nil)
|
|
var _ interfaces.ShutdownHook = (*vaultHook)(nil)
|
|
|
|
// vaultTokenUpdaterMock is a mock of the vaultTokenUpdateHandler interface.
|
|
type vaultTokenUpdaterMock struct {
|
|
currentToken string
|
|
}
|
|
|
|
func (v *vaultTokenUpdaterMock) updatedVaultToken(token string) {
|
|
v.currentToken = token
|
|
}
|
|
|
|
func setupTestVaultHook(t *testing.T, config *vaultHookConfig) *vaultHook {
|
|
t.Helper()
|
|
|
|
if config == nil {
|
|
config = &vaultHookConfig{}
|
|
}
|
|
|
|
job := mock.MinJob()
|
|
if config.alloc == nil {
|
|
config.alloc = mock.MinAlloc()
|
|
config.alloc.Job = job
|
|
}
|
|
if config.task == nil {
|
|
config.task = job.TaskGroups[0].Tasks[0]
|
|
config.task.Identities = []*structs.WorkloadIdentity{
|
|
{Name: "vault_default"},
|
|
}
|
|
config.task.Vault = &structs.Vault{
|
|
Cluster: structs.VaultDefaultCluster,
|
|
}
|
|
|
|
if config.vaultBlock != nil {
|
|
config.task.Identities[0].Name = config.vaultBlock.IdentityName()
|
|
config.task.Vault = config.vaultBlock
|
|
}
|
|
}
|
|
if config.vaultBlock == nil {
|
|
config.vaultBlock = config.task.Vault
|
|
}
|
|
if config.vaultConfigsFunc == nil {
|
|
config.vaultConfigsFunc = func(hclog.Logger) map[string]*sconfig.VaultConfig {
|
|
return map[string]*sconfig.VaultConfig{
|
|
"default": sconfig.DefaultVaultConfig(),
|
|
}
|
|
}
|
|
}
|
|
if config.clientFunc == nil {
|
|
config.clientFunc = func(cluster string) (vaultclient.VaultClient, error) {
|
|
return vaultclient.NewMockVaultClient(cluster)
|
|
}
|
|
}
|
|
if config.logger == nil {
|
|
config.logger = testlog.HCLogger(t)
|
|
}
|
|
if config.events == nil {
|
|
config.events = &trtesting.MockEmitter{}
|
|
}
|
|
if config.lifecycle == nil {
|
|
config.lifecycle = trtesting.NewMockTaskHooks()
|
|
}
|
|
if config.updater == nil {
|
|
config.updater = &vaultTokenUpdaterMock{}
|
|
}
|
|
if config.widmgr == nil {
|
|
db := cstate.NewMemDB(config.logger)
|
|
signer := widmgr.NewMockWIDSigner(config.task.Identities)
|
|
allocEnv := taskenv.NewBuilder(mock.Node(), config.alloc, nil, "global").Build()
|
|
config.widmgr = widmgr.NewWIDMgr(signer, config.alloc, db, config.logger, allocEnv)
|
|
err := config.widmgr.Run()
|
|
must.NoError(t, err)
|
|
}
|
|
|
|
return newVaultHook(config)
|
|
}
|
|
|
|
func TestTaskRunner_VaultHook(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
testCases := []struct {
|
|
name string
|
|
task *structs.Task
|
|
configs map[string]*sconfig.VaultConfig
|
|
configNonrenewable bool
|
|
expectRole string
|
|
expectNoRenew bool
|
|
}{
|
|
{
|
|
name: "jwt flow",
|
|
task: &structs.Task{
|
|
Vault: &structs.Vault{
|
|
Cluster: structs.VaultDefaultCluster,
|
|
},
|
|
Identities: []*structs.WorkloadIdentity{
|
|
{Name: "vault_default"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "jwt flow with role",
|
|
task: &structs.Task{
|
|
Vault: &structs.Vault{
|
|
Cluster: structs.VaultDefaultCluster,
|
|
Role: "task-role",
|
|
},
|
|
Identities: []*structs.WorkloadIdentity{
|
|
{Name: "vault_default"},
|
|
},
|
|
},
|
|
configs: map[string]*sconfig.VaultConfig{
|
|
"default": {
|
|
Role: "client-role",
|
|
},
|
|
},
|
|
expectRole: "task-role",
|
|
},
|
|
{
|
|
name: "jwt flow with role from client",
|
|
task: &structs.Task{
|
|
Vault: &structs.Vault{
|
|
Cluster: structs.VaultDefaultCluster,
|
|
},
|
|
Identities: []*structs.WorkloadIdentity{
|
|
{Name: "vault_default"},
|
|
},
|
|
},
|
|
configs: map[string]*sconfig.VaultConfig{
|
|
"default": {
|
|
Role: "client-role",
|
|
},
|
|
},
|
|
expectRole: "client-role",
|
|
},
|
|
{
|
|
name: "jwt flow with role from client and non-default cluster",
|
|
task: &structs.Task{
|
|
Vault: &structs.Vault{
|
|
Cluster: "prod",
|
|
},
|
|
Identities: []*structs.WorkloadIdentity{
|
|
{Name: "vault_prod"},
|
|
},
|
|
},
|
|
configs: map[string]*sconfig.VaultConfig{
|
|
"default": {
|
|
Role: "client-role",
|
|
},
|
|
"prod": {
|
|
Role: "client-prod-role",
|
|
},
|
|
},
|
|
expectRole: "client-prod-role",
|
|
},
|
|
{
|
|
name: "disable file",
|
|
task: &structs.Task{
|
|
Vault: &structs.Vault{
|
|
Cluster: structs.VaultDefaultCluster,
|
|
DisableFile: true,
|
|
},
|
|
Identities: []*structs.WorkloadIdentity{
|
|
{Name: "vault_default"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "job requests no renewal",
|
|
task: &structs.Task{
|
|
Vault: &structs.Vault{
|
|
Cluster: structs.VaultDefaultCluster,
|
|
AllowTokenExpiration: true,
|
|
},
|
|
Identities: []*structs.WorkloadIdentity{
|
|
{Name: "vault_default"},
|
|
},
|
|
},
|
|
expectNoRenew: true,
|
|
},
|
|
{
|
|
name: "tokens are not renewable",
|
|
task: &structs.Task{
|
|
Vault: &structs.Vault{
|
|
Cluster: structs.VaultDefaultCluster,
|
|
},
|
|
Identities: []*structs.WorkloadIdentity{
|
|
{Name: "vault_default"},
|
|
},
|
|
},
|
|
configNonrenewable: true,
|
|
expectNoRenew: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
alloc := mock.MinAlloc()
|
|
alloc.Job.TaskGroups[0].Tasks[0] = tc.task
|
|
|
|
hookConfig := &vaultHookConfig{
|
|
task: tc.task,
|
|
alloc: alloc,
|
|
vaultConfigsFunc: func(hclog.Logger) map[string]*sconfig.VaultConfig {
|
|
if tc.configs != nil {
|
|
return tc.configs
|
|
}
|
|
return map[string]*sconfig.VaultConfig{
|
|
"default": sconfig.DefaultVaultConfig(),
|
|
}
|
|
},
|
|
}
|
|
|
|
if tc.configNonrenewable {
|
|
hookConfig.clientFunc = func(cluster string) (vaultclient.VaultClient, error) {
|
|
client := &vaultclient.MockVaultClient{}
|
|
client.SetRenewable(false)
|
|
return client, nil
|
|
}
|
|
}
|
|
|
|
hook := setupTestVaultHook(t, hookConfig)
|
|
|
|
// Ensure Prestart() returns within a reasonable time.
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
t.Cleanup(cancel)
|
|
|
|
req := &interfaces.TaskPrestartRequest{
|
|
TaskEnv: taskenv.NewEmptyTaskEnv(),
|
|
TaskDir: &allocdir.TaskDir{
|
|
SecretsDir: t.TempDir(),
|
|
PrivateDir: t.TempDir(),
|
|
},
|
|
Task: tc.task,
|
|
}
|
|
var resp interfaces.TaskPrestartResponse
|
|
|
|
err := hook.Prestart(ctx, req, &resp)
|
|
must.NoError(t, err)
|
|
must.NoError(t, ctx.Err())
|
|
|
|
// Token must have been derived.
|
|
var token string
|
|
client := hook.client.(*vaultclient.MockVaultClient)
|
|
|
|
tokens := client.JWTTokens()
|
|
must.MapLen(t, 1, tokens)
|
|
|
|
swid, err := hook.widmgr.Get(structs.WIHandle{
|
|
IdentityName: tc.task.Vault.IdentityName(),
|
|
WorkloadIdentifier: tc.task.Name,
|
|
WorkloadType: structs.WorkloadTypeTask,
|
|
})
|
|
must.NoError(t, err)
|
|
token = tokens[swid.JWT]
|
|
|
|
must.NotEq(t, "", token)
|
|
|
|
// Token must be derived with correct role.
|
|
//
|
|
// MockVaultClient generates random UUIDv4 tokens, but append the
|
|
// role when requested.
|
|
if tc.expectRole != "" {
|
|
must.StrHasSuffix(t, tc.expectRole, token)
|
|
} else {
|
|
must.UUIDv4(t, token)
|
|
}
|
|
|
|
// Token must be set in token updater.
|
|
updater := (hook.updater).(*vaultTokenUpdaterMock)
|
|
must.Eq(t, token, updater.currentToken)
|
|
|
|
// Token must be written to disk.
|
|
tokenFile, err := os.ReadFile(hook.privateDirTokenPath)
|
|
must.NoError(t, err)
|
|
must.Eq(t, updater.currentToken, string(tokenFile))
|
|
|
|
if !tc.task.Vault.DisableFile {
|
|
tokenFile, err := os.ReadFile(hook.secretsDirTokenPath)
|
|
must.NoError(t, err)
|
|
must.Eq(t, updater.currentToken, string(tokenFile))
|
|
} else {
|
|
_, err = os.ReadFile(hook.secretsDirTokenPath)
|
|
must.ErrorIs(t, err, os.ErrNotExist)
|
|
}
|
|
|
|
// Token must be set for renewal.
|
|
if tc.expectNoRenew {
|
|
must.MapEmpty(t, client.RenewTokens())
|
|
} else {
|
|
must.MapLen(t, 1, client.RenewTokens())
|
|
must.NotNil(t, client.RenewTokens()[updater.currentToken])
|
|
}
|
|
|
|
// PrestartDone must be false so we can recover tokens.
|
|
// firstRun is used to prevent multiple executions.
|
|
must.False(t, resp.Done)
|
|
must.False(t, hook.firstRun)
|
|
|
|
// Stop renewal when hook stops.
|
|
err = hook.Stop(ctx, nil, nil)
|
|
must.NoError(t, err)
|
|
must.Wait(t, wait.InitialSuccess(
|
|
wait.ErrorFunc(func() error {
|
|
tokens := client.StoppedTokens()
|
|
|
|
if tc.expectNoRenew {
|
|
if len(tokens) != 0 {
|
|
return fmt.Errorf("expected no stopped tokens when renewal is disabled, got %d", len(tokens))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
if len(tokens) != 1 {
|
|
return fmt.Errorf("expected stopped tokens to be %d, got %d", 1, len(tokens))
|
|
}
|
|
got := tokens[0]
|
|
expect := updater.currentToken
|
|
if got != expect {
|
|
return fmt.Errorf("expected stopped token to be %s, got %s", expect, got)
|
|
}
|
|
return nil
|
|
}),
|
|
wait.Timeout(5*time.Second),
|
|
wait.Gap(100*time.Millisecond),
|
|
))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTaskRunner_VaultHook_recover(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
testCases := []struct {
|
|
name string
|
|
setupReq func() (*interfaces.TaskPrestartRequest, error)
|
|
}{
|
|
{
|
|
name: "recover from secrets dir",
|
|
setupReq: func() (*interfaces.TaskPrestartRequest, error) {
|
|
// Write token to secrets dir.
|
|
secretsDirPath := t.TempDir()
|
|
err := os.WriteFile(filepath.Join(secretsDirPath, vaultTokenFile), []byte("much secret"), 0666)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req := &interfaces.TaskPrestartRequest{
|
|
TaskEnv: taskenv.NewEmptyTaskEnv(),
|
|
TaskDir: &allocdir.TaskDir{
|
|
SecretsDir: secretsDirPath,
|
|
PrivateDir: t.TempDir(),
|
|
},
|
|
}
|
|
return req, nil
|
|
},
|
|
},
|
|
{
|
|
name: "recover from private dir",
|
|
setupReq: func() (*interfaces.TaskPrestartRequest, error) {
|
|
// Write token to private dir.
|
|
privateDirPath := t.TempDir()
|
|
err := os.WriteFile(filepath.Join(privateDirPath, vaultTokenFile), []byte("much secret"), 0666)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req := &interfaces.TaskPrestartRequest{
|
|
TaskEnv: taskenv.NewEmptyTaskEnv(),
|
|
TaskDir: &allocdir.TaskDir{
|
|
SecretsDir: t.TempDir(),
|
|
PrivateDir: privateDirPath,
|
|
},
|
|
}
|
|
return req, nil
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
hook := setupTestVaultHook(t, nil)
|
|
|
|
req, err := tc.setupReq()
|
|
must.NoError(t, err)
|
|
req.Task = hook.task
|
|
|
|
// Ensure Prestart() returns in a reasonable time.
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
t.Cleanup(cancel)
|
|
|
|
var resp interfaces.TaskPrestartResponse
|
|
err = hook.Prestart(ctx, req, &resp)
|
|
must.NoError(t, err)
|
|
must.NoError(t, ctx.Err())
|
|
|
|
// Verify token was recovered and not derived.
|
|
client := hook.client.(*vaultclient.MockVaultClient)
|
|
must.MapLen(t, 0, client.JWTTokens())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTaskRunner_VaultHook_deriveError(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
t.Run("unrecoverable error", func(t *testing.T) {
|
|
vaultClient, _ := vaultclient.NewMockVaultClient("")
|
|
mockVaultClient := vaultClient.(*vaultclient.MockVaultClient)
|
|
|
|
hook := setupTestVaultHook(t, &vaultHookConfig{
|
|
clientFunc: func(string) (vaultclient.VaultClient, error) {
|
|
return mockVaultClient, nil
|
|
},
|
|
})
|
|
req := &interfaces.TaskPrestartRequest{
|
|
TaskEnv: taskenv.NewEmptyTaskEnv(),
|
|
TaskDir: &allocdir.TaskDir{
|
|
SecretsDir: t.TempDir(),
|
|
PrivateDir: t.TempDir(),
|
|
},
|
|
Task: hook.task,
|
|
}
|
|
var resp interfaces.TaskPrestartResponse
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
t.Cleanup(cancel)
|
|
|
|
// Set unrecoverable error.
|
|
mockVaultClient.SetDeriveTokenWithJWTFn(
|
|
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
|
|
// Cancel the context to simulate the task being killed.
|
|
cancel()
|
|
return "", false, 0, structs.NewRecoverableError(errors.New("unrecoverable test error"), false)
|
|
})
|
|
|
|
err := hook.Prestart(ctx, req, &resp)
|
|
must.NoError(t, err)
|
|
|
|
// Verify task is killed because of unrecoverable error.
|
|
must.Wait(t, wait.InitialSuccess(
|
|
wait.ErrorFunc(func() error {
|
|
killEv := (hook.lifecycle.(*trtesting.MockTaskHooks)).KillEvent()
|
|
if killEv == nil {
|
|
return errors.New("missing kill event")
|
|
}
|
|
return nil
|
|
}),
|
|
wait.Timeout(5*time.Second),
|
|
wait.Gap(100*time.Millisecond),
|
|
))
|
|
killEv := (hook.lifecycle.(*trtesting.MockTaskHooks)).KillEvent()
|
|
must.StrContains(t, killEv.DisplayMessage, "unrecoverable test error")
|
|
})
|
|
|
|
t.Run("recoverable error", func(t *testing.T) {
|
|
vaultClient, _ := vaultclient.NewMockVaultClient("")
|
|
mockVaultClient := vaultClient.(*vaultclient.MockVaultClient)
|
|
|
|
hook := setupTestVaultHook(t, &vaultHookConfig{
|
|
clientFunc: func(string) (vaultclient.VaultClient, error) {
|
|
return mockVaultClient, nil
|
|
},
|
|
})
|
|
req := &interfaces.TaskPrestartRequest{
|
|
TaskEnv: taskenv.NewEmptyTaskEnv(),
|
|
TaskDir: &allocdir.TaskDir{
|
|
SecretsDir: t.TempDir(),
|
|
PrivateDir: t.TempDir(),
|
|
},
|
|
Task: hook.task,
|
|
}
|
|
var resp interfaces.TaskPrestartResponse
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
t.Cleanup(cancel)
|
|
|
|
// Set recoverable error.
|
|
mockVaultClient.SetDeriveTokenWithJWTFn(
|
|
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
|
|
return "", false, 0, structs.NewRecoverableError(errors.New("recoverable test error"), true)
|
|
})
|
|
|
|
go func() {
|
|
// Wait a bit for the first error then fix token renewal.
|
|
time.Sleep(time.Second)
|
|
mockVaultClient.SetDeriveTokenWithJWTFn(
|
|
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
|
|
return "secret", true, 30, nil
|
|
})
|
|
|
|
}()
|
|
err := hook.Prestart(ctx, req, &resp)
|
|
must.NoError(t, err)
|
|
must.NoError(t, ctx.Err())
|
|
|
|
// Verify retry happened and token was derived.
|
|
updater := (hook.updater).(*vaultTokenUpdaterMock)
|
|
must.Eq(t, "secret", updater.currentToken)
|
|
})
|
|
|
|
t.Run("renew request failed", func(t *testing.T) {
|
|
vaultClient, _ := vaultclient.NewMockVaultClient("")
|
|
mockVaultClient := vaultClient.(*vaultclient.MockVaultClient)
|
|
|
|
hook := setupTestVaultHook(t, &vaultHookConfig{
|
|
clientFunc: func(string) (vaultclient.VaultClient, error) {
|
|
return mockVaultClient, nil
|
|
},
|
|
})
|
|
req := &interfaces.TaskPrestartRequest{
|
|
TaskEnv: taskenv.NewEmptyTaskEnv(),
|
|
TaskDir: &allocdir.TaskDir{
|
|
SecretsDir: t.TempDir(),
|
|
PrivateDir: t.TempDir(),
|
|
},
|
|
Task: hook.task,
|
|
}
|
|
var resp interfaces.TaskPrestartResponse
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
t.Cleanup(cancel)
|
|
|
|
// Derive predictable token and fail renew request.
|
|
mockVaultClient.SetDeriveTokenWithJWTFn(
|
|
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
|
|
return "secret", true, 30, nil
|
|
})
|
|
mockVaultClient.SetRenewTokenError("secret", errors.New("test error"))
|
|
|
|
go func() {
|
|
// Wait a bit for the renew error then fix token renewal.
|
|
time.Sleep(10 * time.Millisecond)
|
|
mockVaultClient.SetRenewTokenError("secret", nil)
|
|
|
|
}()
|
|
err := hook.Prestart(ctx, req, &resp)
|
|
must.NoError(t, err)
|
|
must.NoError(t, ctx.Err())
|
|
|
|
// Verify retry happened and token was derived.
|
|
updater := (hook.updater).(*vaultTokenUpdaterMock)
|
|
must.Eq(t, "secret", updater.currentToken)
|
|
})
|
|
}
|
|
|
|
func TestTaskRunner_VaultHook_tokenRenewalFail(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
testCases := []struct {
|
|
name string
|
|
vaultBlock *structs.Vault
|
|
verifyTaskLifecycle func(*trtesting.MockTaskHooks) error
|
|
}{
|
|
{
|
|
name: "change mode signal",
|
|
vaultBlock: &structs.Vault{
|
|
Cluster: structs.VaultDefaultCluster,
|
|
ChangeMode: structs.VaultChangeModeSignal,
|
|
ChangeSignal: "SIGTERM",
|
|
},
|
|
verifyTaskLifecycle: func(h *trtesting.MockTaskHooks) error {
|
|
signals := h.Signals()
|
|
if len(signals) != 1 {
|
|
return fmt.Errorf("expected 1 signal, got %d", len(signals))
|
|
}
|
|
if signals[0] != "SIGTERM" {
|
|
return fmt.Errorf("expected signal to be SIGTERM, got %s", signals[0])
|
|
}
|
|
return nil
|
|
},
|
|
},
|
|
{
|
|
name: "change mode restart",
|
|
vaultBlock: &structs.Vault{
|
|
Cluster: structs.VaultDefaultCluster,
|
|
ChangeMode: structs.VaultChangeModeRestart,
|
|
},
|
|
verifyTaskLifecycle: func(h *trtesting.MockTaskHooks) error {
|
|
restarts := h.Restarts()
|
|
if restarts != 1 {
|
|
return fmt.Errorf("expected 1 restart, got %d", restarts)
|
|
}
|
|
return nil
|
|
},
|
|
},
|
|
{
|
|
name: "change mode noop",
|
|
vaultBlock: &structs.Vault{
|
|
Cluster: structs.VaultDefaultCluster,
|
|
ChangeMode: structs.VaultChangeModeNoop,
|
|
},
|
|
verifyTaskLifecycle: func(h *trtesting.MockTaskHooks) error {
|
|
restarts := h.Restarts()
|
|
if restarts != 0 {
|
|
return fmt.Errorf("expected 0 restarts, got %d", restarts)
|
|
}
|
|
|
|
signals := h.Signals()
|
|
if len(signals) != 0 {
|
|
return fmt.Errorf("expected 0 signals, got %d", len(signals))
|
|
}
|
|
|
|
return nil
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
vaultClient, _ := vaultclient.NewMockVaultClient("")
|
|
mockVaultClient := vaultClient.(*vaultclient.MockVaultClient)
|
|
|
|
hook := setupTestVaultHook(t, &vaultHookConfig{
|
|
vaultBlock: tc.vaultBlock,
|
|
clientFunc: func(string) (vaultclient.VaultClient, error) {
|
|
return mockVaultClient, nil
|
|
},
|
|
})
|
|
|
|
req := &interfaces.TaskPrestartRequest{
|
|
TaskEnv: taskenv.NewEmptyTaskEnv(),
|
|
TaskDir: &allocdir.TaskDir{
|
|
SecretsDir: t.TempDir(),
|
|
PrivateDir: t.TempDir(),
|
|
},
|
|
Task: hook.task,
|
|
}
|
|
var resp interfaces.TaskPrestartResponse
|
|
|
|
// Ensure Prestart() returns within a reasonable time.
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
t.Cleanup(cancel)
|
|
|
|
err := hook.Prestart(ctx, req, &resp)
|
|
must.NoError(t, err)
|
|
|
|
// Fetch derived token.
|
|
updater := (hook.updater).(*vaultTokenUpdaterMock)
|
|
token := updater.currentToken
|
|
must.NotEq(t, "", token)
|
|
|
|
// Fetch renewal token error channel.
|
|
renewErrCh := mockVaultClient.RenewTokenErrCh(token)
|
|
must.NotNil(t, renewErrCh)
|
|
|
|
// Emit renewal error.
|
|
renewErrCh <- errors.New("renew error")
|
|
|
|
// Verify expected lifecycle events happen.
|
|
must.Wait(t, wait.InitialSuccess(
|
|
wait.ErrorFunc(func() error {
|
|
return tc.verifyTaskLifecycle((hook.lifecycle).(*trtesting.MockTaskHooks))
|
|
}),
|
|
wait.Timeout(3*time.Second),
|
|
wait.Gap(100*time.Millisecond),
|
|
))
|
|
})
|
|
}
|
|
}
|