diff --git a/client/allocrunner/alloc_runner.go b/client/allocrunner/alloc_runner.go index b8c838d76..a84fb312a 100644 --- a/client/allocrunner/alloc_runner.go +++ b/client/allocrunner/alloc_runner.go @@ -82,8 +82,8 @@ type allocRunner struct { // managing SI tokens sidsClient consul.ServiceIdentityAPI - // vaultClient is the used to manage Vault tokens - vaultClient vaultclient.VaultClient + // vaultClientFunc is used to get the client used to manage Vault tokens + vaultClientFunc vaultclient.VaultClientFunc // waitCh is closed when the Run loop has exited waitCh chan struct{} @@ -225,7 +225,7 @@ func NewAllocRunner(config *config.AllocRunnerConfig) (interfaces.AllocRunner, e consulClient: config.Consul, consulProxiesClient: config.ConsulProxies, sidsClient: config.ConsulSI, - vaultClient: config.Vault, + vaultClientFunc: config.VaultFunc, tasks: make(map[string]*taskrunner.TaskRunner, len(tg.Tasks)), waitCh: make(chan struct{}), destroyCh: make(chan struct{}), @@ -297,7 +297,7 @@ func (ar *allocRunner) initTaskRunners(tasks []*structs.Task) error { Consul: ar.consulClient, ConsulProxies: ar.consulProxiesClient, ConsulSI: ar.sidsClient, - Vault: ar.vaultClient, + VaultFunc: ar.vaultClientFunc, DeviceStatsReporter: ar.deviceStatsReporter, CSIManager: ar.csiManager, DeviceManager: ar.devicemanager, diff --git a/client/allocrunner/taskrunner/sids_hook_test.go b/client/allocrunner/taskrunner/sids_hook_test.go index e33fc63e3..1111f8314 100644 --- a/client/allocrunner/taskrunner/sids_hook_test.go +++ b/client/allocrunner/taskrunner/sids_hook_test.go @@ -268,7 +268,7 @@ func TestTaskRunner_DeriveSIToken_UnWritableTokenFile(t *testing.T) { "run_for": "0s", } - trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) defer cleanup() // make the si_token file un-writable, triggering a failure after a diff --git a/client/allocrunner/taskrunner/task_runner.go b/client/allocrunner/taskrunner/task_runner.go index cf70f5ace..53c3bf393 100644 --- a/client/allocrunner/taskrunner/task_runner.go +++ b/client/allocrunner/taskrunner/task_runner.go @@ -193,8 +193,9 @@ type TaskRunner struct { // service identity tokens siClient consul.ServiceIdentityAPI - // vaultClient is the client to use to derive and renew Vault tokens - vaultClient vaultclient.VaultClient + // vaultClientFunc is the function to get a client to use to derive and + // renew Vault tokens + vaultClientFunc vaultclient.VaultClientFunc // vaultToken is the current Vault token. It should be accessed with the // getter. @@ -290,8 +291,8 @@ type Config struct { // DynamicRegistry is where dynamic plugins should be registered. DynamicRegistry dynamicplugins.Registry - // Vault is the client to use to derive and renew Vault tokens - Vault vaultclient.VaultClient + // VaultFunc is function to get the client to use to derive and renew Vault tokens + VaultFunc vaultclient.VaultClientFunc // StateDB is used to store and restore state. StateDB cstate.StateDB @@ -379,7 +380,7 @@ func NewTaskRunner(config *Config) (*TaskRunner, error) { consulServiceClient: config.Consul, consulProxiesClient: config.ConsulProxies, siClient: config.ConsulSI, - vaultClient: config.Vault, + vaultClientFunc: config.VaultFunc, state: tstate, localState: state.NewLocalState(), allocHookResources: config.AllocHookResources, diff --git a/client/allocrunner/taskrunner/task_runner_hooks.go b/client/allocrunner/taskrunner/task_runner_hooks.go index 4a54c3b2b..9e849ec39 100644 --- a/client/allocrunner/taskrunner/task_runner_hooks.go +++ b/client/allocrunner/taskrunner/task_runner_hooks.go @@ -89,10 +89,10 @@ func (tr *TaskRunner) initHooks() { } // If Vault is enabled, add the hook - if task.Vault != nil { + if task.Vault != nil && tr.vaultClientFunc != nil { tr.runnerHooks = append(tr.runnerHooks, newVaultHook(&vaultHookConfig{ vaultBlock: task.Vault, - client: tr.vaultClient, + clientFunc: tr.vaultClientFunc, events: tr, lifecycle: tr, updater: tr, diff --git a/client/allocrunner/taskrunner/task_runner_linux_test.go b/client/allocrunner/taskrunner/task_runner_linux_test.go index 44d701729..b425e426a 100644 --- a/client/allocrunner/taskrunner/task_runner_linux_test.go +++ b/client/allocrunner/taskrunner/task_runner_linux_test.go @@ -32,12 +32,22 @@ func TestTaskRunner_DisableFileForVaultToken_UpgradePath(t *testing.T) { Policies: []string{"default"}, } - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + // Setup a test Vault client. + token := "1234" + handler := func(*structs.Allocation, []string) (map[string]string, error) { + return map[string]string{task.Name: token}, nil + } + vc, err := vaultclient.NewMockVaultClient("default") + must.NoError(t, err) + vaultClient := vc.(*vaultclient.MockVaultClient) + vaultClient.DeriveTokenFn = handler + + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient) defer cleanup() // Remove private dir and write the Vault token to the secrets dir to // simulate an old task. - err := conf.TaskDir.Build(false, nil) + err = conf.TaskDir.Build(false, nil) must.NoError(t, err) err = syscall.Unmount(conf.TaskDir.PrivateDir, 0) @@ -45,18 +55,10 @@ func TestTaskRunner_DisableFileForVaultToken_UpgradePath(t *testing.T) { err = os.Remove(conf.TaskDir.PrivateDir) must.NoError(t, err) - token := "1234" tokenPath := filepath.Join(conf.TaskDir.SecretsDir, vaultTokenFile) err = os.WriteFile(tokenPath, []byte(token), 0666) must.NoError(t, err) - // Setup a test Vault client. - handler := func(*structs.Allocation, []string) (map[string]string, error) { - return map[string]string{task.Name: token}, nil - } - vaultClient := conf.Vault.(*vaultclient.MockVaultClient) - vaultClient.DeriveTokenFn = handler - // Start task runner and wait for task to finish. tr, err := NewTaskRunner(conf) must.NoError(t, err) diff --git a/client/allocrunner/taskrunner/task_runner_test.go b/client/allocrunner/taskrunner/task_runner_test.go index a85545939..b32a82741 100644 --- a/client/allocrunner/taskrunner/task_runner_test.go +++ b/client/allocrunner/taskrunner/task_runner_test.go @@ -67,7 +67,7 @@ func (m *MockTaskStateUpdater) TaskStateUpdated() { // testTaskRunnerConfig returns a taskrunner.Config for the given alloc+task // plus a cleanup func. -func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName string) (*Config, func()) { +func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName string, vault vaultclient.VaultClient) (*Config, func()) { logger := testlog.HCLogger(t) clientConf, cleanup := config.TestClientConfig(t) @@ -116,6 +116,11 @@ func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName stri nomadRegMock := regMock.NewServiceRegistrationHandler(logger) wrapperMock := wrapper.NewHandlerWrapper(logger, consulRegMock, nomadRegMock) + var vaultFunc vaultclient.VaultClientFunc + if vault != nil { + vaultFunc = func(_ string) (vaultclient.VaultClient, error) { return vault, nil } + } + conf := &Config{ Alloc: alloc, ClientConfig: clientConf, @@ -124,7 +129,7 @@ func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName stri Logger: clientConf.Logger, Consul: consulRegMock, ConsulSI: consulapi.NewMockServiceIdentitiesClient(), - Vault: vaultclient.NewMockVaultClient(), + VaultFunc: vaultFunc, StateDB: cstate.NoopDB{}, StateUpdater: NewMockTaskStateUpdater(), DeviceManager: devicemanager.NoopMockManager(), @@ -146,7 +151,7 @@ func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName stri // a cleanup function that ensures the runner is stopped and cleaned up. Tests // which need to change the Config *must* use testTaskRunnerConfig instead. func runTestTaskRunner(t *testing.T, alloc *structs.Allocation, taskName string) (*TaskRunner, *Config, func()) { - config, cleanup := testTaskRunnerConfig(t, alloc, taskName) + config, cleanup := testTaskRunnerConfig(t, alloc, taskName, nil) tr, err := NewTaskRunner(config) require.NoError(t, err) @@ -205,7 +210,7 @@ func TestTaskRunner_BuildTaskConfig_CPU_Memory(t *testing.T) { res.Memory.MemoryMB = c.memoryMB res.Memory.MemoryMaxMB = c.memoryMaxMB - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners defer cleanup() @@ -244,7 +249,7 @@ func TestTaskRunner_Stop_ExitCode(t *testing.T) { "NOMAD_TASK_NAME": task.Name, } - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) defer cleanup() // Run the first TaskRunner @@ -292,7 +297,7 @@ func TestTaskRunner_Restore_Running(t *testing.T) { task.Config = map[string]interface{}{ "run_for": "2s", } - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners defer cleanup() @@ -346,7 +351,7 @@ func TestTaskRunner_Restore_Dead(t *testing.T) { task.Config = map[string]interface{}{ "run_for": "2s", } - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners defer cleanup() @@ -430,7 +435,7 @@ func setupRestoreFailureTest(t *testing.T, alloc *structs.Allocation) (*TaskRunn "NOMAD_ALLOC_ID": alloc.ID, "NOMAD_TASK_NAME": task.Name, } - conf, cleanup1 := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup1 := testTaskRunnerConfig(t, alloc, task.Name, nil) conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between runs // Run the first TaskRunner @@ -583,7 +588,7 @@ func TestTaskRunner_Restore_System(t *testing.T) { "NOMAD_ALLOC_ID": alloc.ID, "NOMAD_TASK_NAME": task.Name, } - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) defer cleanup() conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between runs @@ -651,7 +656,7 @@ func TestTaskRunner_MarkFailedKill(t *testing.T) { // set up some taskrunner alloc := mock.MinAlloc() task := alloc.Job.TaskGroups[0].Tasks[0] - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) t.Cleanup(cleanup) tr, err := NewTaskRunner(conf) must.NoError(t, err) @@ -802,7 +807,7 @@ func TestTaskRunner_DevicePropogation(t *testing.T) { tRes := alloc.AllocatedResources.Tasks[task.Name] tRes.Devices = append(tRes.Devices, &structs.AllocatedDeviceResource{Type: "mock"}) - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners defer cleanup() @@ -887,7 +892,7 @@ func TestTaskRunner_Restore_HookEnv(t *testing.T) { alloc := mock.BatchAlloc() task := alloc.Job.TaskGroups[0].Tasks[0] - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between prestart calls defer cleanup() @@ -932,7 +937,7 @@ func TestTaskRunner_RecoverFromDriverExiting(t *testing.T) { "run_for": "5s", } - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between prestart calls defer cleanup() @@ -1310,7 +1315,7 @@ func TestTaskRunner_CheckWatcher_Restart(t *testing.T) { } task.Services[0].Provider = structs.ServiceProviderConsul - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) defer cleanup() // Replace mock Consul ServiceClient, with the real ServiceClient @@ -1406,7 +1411,7 @@ func TestTaskRunner_BlockForSIDSToken(t *testing.T) { "run_for": "0s", } - trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) defer cleanup() // set a consul token on the Nomad client's consul config, because that is @@ -1470,7 +1475,7 @@ func TestTaskRunner_DeriveSIToken_Retry(t *testing.T) { "run_for": "0s", } - trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) defer cleanup() // set a consul token on the Nomad client's consul config, because that is @@ -1530,7 +1535,7 @@ func TestTaskRunner_DeriveSIToken_Unrecoverable(t *testing.T) { "run_for": "0s", } - trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) defer cleanup() // set a consul token on the Nomad client's consul config, because that is @@ -1582,9 +1587,6 @@ func TestTaskRunner_BlockForVaultToken(t *testing.T) { } task.Vault = &structs.Vault{Policies: []string{"default"}} - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) - defer cleanup() - // Control when we get a Vault token token := "1234" waitCh := make(chan struct{}) @@ -1592,9 +1594,15 @@ func TestTaskRunner_BlockForVaultToken(t *testing.T) { <-waitCh return map[string]string{task.Name: token}, nil } - vaultClient := conf.Vault.(*vaultclient.MockVaultClient) + + vc, err := vaultclient.NewMockVaultClient("default") + must.NoError(t, err) + vaultClient := vc.(*vaultclient.MockVaultClient) vaultClient.DeriveTokenFn = handler + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient) + defer cleanup() + tr, err := NewTaskRunner(conf) require.NoError(t, err) defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup")) @@ -1671,17 +1679,19 @@ func TestTaskRunner_DisableFileForVaultToken(t *testing.T) { DisableFile: true, } - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) - defer cleanup() - // Setup a test Vault client token := "1234" handler := func(*structs.Allocation, []string) (map[string]string, error) { return map[string]string{task.Name: token}, nil } - vaultClient := conf.Vault.(*vaultclient.MockVaultClient) + vc, err := vaultclient.NewMockVaultClient("default") + must.NoError(t, err) + vaultClient := vc.(*vaultclient.MockVaultClient) vaultClient.DeriveTokenFn = handler + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient) + defer cleanup() + // Start task runner and wait for it to complete. tr, err := NewTaskRunner(conf) must.NoError(t, err) @@ -1716,9 +1726,6 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) { task := alloc.Job.TaskGroups[0].Tasks[0] task.Vault = &structs.Vault{Policies: []string{"default"}} - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) - defer cleanup() - // Fail on the first attempt to derive a vault token token := "1234" count := 0 @@ -1730,9 +1737,14 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) { count++ return nil, structs.NewRecoverableError(fmt.Errorf("Want a retry"), true) } - vaultClient := conf.Vault.(*vaultclient.MockVaultClient) + vc, err := vaultclient.NewMockVaultClient("default") + must.NoError(t, err) + vaultClient := vc.(*vaultclient.MockVaultClient) vaultClient.DeriveTokenFn = handler + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient) + defer cleanup() + tr, err := NewTaskRunner(conf) require.NoError(t, err) defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup")) @@ -1794,12 +1806,15 @@ func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) { } task.Vault = &structs.Vault{Policies: []string{"default"}} - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) - defer cleanup() - // Error the token derivation - vaultClient := conf.Vault.(*vaultclient.MockVaultClient) - vaultClient.SetDeriveTokenError(alloc.ID, []string{task.Name}, fmt.Errorf("Non recoverable")) + vc, err := vaultclient.NewMockVaultClient("default") + must.NoError(t, err) + vaultClient := vc.(*vaultclient.MockVaultClient) + vaultClient.SetDeriveTokenError( + alloc.ID, []string{task.Name}, fmt.Errorf("Non recoverable")) + + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient) + defer cleanup() tr, err := NewTaskRunner(conf) require.NoError(t, err) @@ -2003,7 +2018,7 @@ func TestTaskRunner_DriverNetwork(t *testing.T) { }, } - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) defer cleanup() // Use a mock agent to test for services @@ -2105,9 +2120,6 @@ func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) { // Use vault to block the start task.Vault = &structs.Vault{Policies: []string{"default"}} - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) - defer cleanup() - // Control when we get a Vault token waitCh := make(chan struct{}, 1) defer close(waitCh) @@ -2115,9 +2127,14 @@ func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) { <-waitCh return map[string]string{task.Name: "1234"}, nil } - vaultClient := conf.Vault.(*vaultclient.MockVaultClient) + vc, err := vaultclient.NewMockVaultClient("default") + must.NoError(t, err) + vaultClient := vc.(*vaultclient.MockVaultClient) vaultClient.DeriveTokenFn = handler + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient) + defer cleanup() + tr, err := NewTaskRunner(conf) require.NoError(t, err) defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup")) @@ -2215,7 +2232,7 @@ func TestTaskRunner_Template_Artifact(t *testing.T) { }, } - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) defer cleanup() tr, err := NewTaskRunner(conf) @@ -2265,7 +2282,7 @@ func TestTaskRunner_Template_BlockingPreStart(t *testing.T) { task.Vault = &structs.Vault{Policies: []string{"default"}} - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) defer cleanup() tr, err := NewTaskRunner(conf) @@ -2326,7 +2343,11 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) { } task.Vault = &structs.Vault{Policies: []string{"default"}} - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + vc, err := vaultclient.NewMockVaultClient("default") + must.NoError(t, err) + vaultClient := vc.(*vaultclient.MockVaultClient) + + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient) defer cleanup() tr, err := NewTaskRunner(conf) @@ -2348,8 +2369,7 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) { require.NoError(t, err) }) - vault := conf.Vault.(*vaultclient.MockVaultClient) - renewalCh, ok := vault.RenewTokens()[token] + renewalCh, ok := vaultClient.RenewTokens()[token] require.True(t, ok, "no renewal channel for token") renewalCh <- fmt.Errorf("Test killing") @@ -2374,11 +2394,11 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) { // Check the token was revoked testutil.WaitForResult(func() (bool, error) { - if len(vault.StoppedTokens()) != 1 { - return false, fmt.Errorf("Expected a stopped token: %v", vault.StoppedTokens()) + if len(vaultClient.StoppedTokens()) != 1 { + return false, fmt.Errorf("Expected a stopped token: %v", vaultClient.StoppedTokens()) } - if a := vault.StoppedTokens()[0]; a != token { + if a := vaultClient.StoppedTokens()[0]; a != token { return false, fmt.Errorf("got stopped token %q; want %q", a, token) } @@ -2404,7 +2424,11 @@ func TestTaskRunner_VaultManager_Restart(t *testing.T) { ChangeMode: structs.VaultChangeModeRestart, } - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + vc, err := vaultclient.NewMockVaultClient("default") + vaultClient := vc.(*vaultclient.MockVaultClient) + must.NoError(t, err) + + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient) defer cleanup() tr, err := NewTaskRunner(conf) @@ -2420,8 +2444,7 @@ func TestTaskRunner_VaultManager_Restart(t *testing.T) { require.NotEmpty(t, token) - vault := conf.Vault.(*vaultclient.MockVaultClient) - renewalCh, ok := vault.RenewTokens()[token] + renewalCh, ok := vaultClient.RenewTokens()[token] require.True(t, ok, "no renewal channel for token") renewalCh <- fmt.Errorf("Test killing") @@ -2477,8 +2500,11 @@ func TestTaskRunner_VaultManager_Signal(t *testing.T) { ChangeMode: structs.VaultChangeModeSignal, ChangeSignal: "SIGUSR1", } + vc, err := vaultclient.NewMockVaultClient("default") + must.NoError(t, err) + vaultClient := vc.(*vaultclient.MockVaultClient) - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient) defer cleanup() tr, err := NewTaskRunner(conf) @@ -2494,8 +2520,7 @@ func TestTaskRunner_VaultManager_Signal(t *testing.T) { require.NotEmpty(t, token) - vault := conf.Vault.(*vaultclient.MockVaultClient) - renewalCh, ok := vault.RenewTokens()[token] + renewalCh, ok := vaultClient.RenewTokens()[token] require.True(t, ok, "no renewal channel for token") renewalCh <- fmt.Errorf("Test killing") @@ -2548,7 +2573,7 @@ func TestTaskRunner_UnregisterConsul_Retries(t *testing.T) { "run_for": "1ns", } - conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) defer cleanup() tr, err := NewTaskRunner(conf) @@ -2612,7 +2637,7 @@ func TestTaskRunner_BaseLabels(t *testing.T) { "command": "whoami", } - config, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + config, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil) defer cleanup() tr, err := NewTaskRunner(config) diff --git a/client/allocrunner/taskrunner/vault_hook.go b/client/allocrunner/taskrunner/vault_hook.go index 7c43bee8e..c97316a6b 100644 --- a/client/allocrunner/taskrunner/vault_hook.go +++ b/client/allocrunner/taskrunner/vault_hook.go @@ -50,7 +50,7 @@ func (tr *TaskRunner) updatedVaultToken(token string) { type vaultHookConfig struct { vaultBlock *structs.Vault - client vaultclient.VaultClient + clientFunc vaultclient.VaultClientFunc events ti.EventEmitter lifecycle ti.TaskLifecycle updater vaultTokenUpdateHandler @@ -72,8 +72,10 @@ type vaultHook struct { // updater is used to update the Vault token updater vaultTokenUpdateHandler - // client is the Vault client to retrieve and renew the Vault token - client vaultclient.VaultClient + // client is the Vault client to retrieve and renew the Vault token, and + // clientFunc is the injected function that retrieves it + client vaultclient.VaultClient + clientFunc vaultclient.VaultClientFunc // logger is used to log logger log.Logger @@ -105,9 +107,10 @@ type vaultHook struct { func newVaultHook(config *vaultHookConfig) *vaultHook { ctx, cancel := context.WithCancel(context.Background()) + h := &vaultHook{ vaultBlock: config.vaultBlock, - client: config.client, + clientFunc: config.clientFunc, eventEmitter: config.events, lifecycle: config.lifecycle, updater: config.updater, @@ -135,6 +138,12 @@ func (h *vaultHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRe return nil } + vclient, err := h.clientFunc(h.vaultBlock.Cluster) + if err != nil { + return err + } + h.client = vclient + // Try to recover a token if it was previously written in the secrets // directory recoveredToken := "" diff --git a/client/allocrunner/testing.go b/client/allocrunner/testing.go index 615d9bb23..d7f104efa 100644 --- a/client/allocrunner/testing.go +++ b/client/allocrunner/testing.go @@ -86,7 +86,7 @@ func testAllocRunnerConfig(t *testing.T, alloc *structs.Allocation) (*config.All StateDB: stateDB, Consul: consulRegMock, ConsulSI: consul.NewMockServiceIdentitiesClient(), - Vault: vaultclient.NewMockVaultClient(), + VaultFunc: vaultclient.NewMockVaultClient, StateUpdater: &MockStateUpdater{}, PrevAllocWatcher: allocwatcher.NoopPrevAlloc{}, PrevAllocMigrator: allocwatcher.NoopPrevAlloc{}, diff --git a/client/client.go b/client/client.go index f2aa403e4..9bba11c58 100644 --- a/client/client.go +++ b/client/client.go @@ -264,8 +264,8 @@ type Client struct { // Service Identity tokens through Nomad Server. tokensClient consulApi.ServiceIdentityAPI - // vaultClient is used to interact with Vault for token and secret renewals - vaultClient vaultclient.VaultClient + // vaultClients is used to interact with Vault for token and secret renewals + vaultClients map[string]vaultclient.VaultClient // garbageCollector is used to garbage collect terminal allocations present // in the node automatically @@ -580,7 +580,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxie } // Setup the vault client for token and secret renewals - if err := c.setupVaultClient(); err != nil { + if err := c.setupVaultClients(); err != nil { return nil, fmt.Errorf("failed to setup vault client: %v", err) } @@ -868,8 +868,8 @@ func (c *Client) Shutdown() error { c.logger.Info("shutting down") // Stop renewing tokens and secrets - if c.vaultClient != nil { - c.vaultClient.Stop() + for _, vaultClient := range c.vaultClients { + vaultClient.Stop() } // Stop Garbage collector @@ -2761,7 +2761,7 @@ func (c *Client) newAllocRunnerConfig( ServiceRegWrapper: c.serviceRegWrapper, StateDB: c.stateDB, StateUpdater: c, - Vault: c.vaultClient, + VaultFunc: c.VaultClient, WIDMgr: c.widmgr, Wranglers: c.wranglers, Partitions: c.partitions, @@ -2776,26 +2776,43 @@ func (c *Client) setupConsulTokenClient() error { return nil } -// setupVaultClient creates an object to periodically renew tokens and secrets -// with vault. -func (c *Client) setupVaultClient() error { - var err error - c.vaultClient, err = vaultclient.NewVaultClient(c.GetConfig().VaultConfig, c.logger, c.deriveToken) - if err != nil { - return err +// setupVaultClients creates the objects that periodically renew tokens and +// secrets with vault. +func (c *Client) setupVaultClients() error { + + c.vaultClients = map[string]vaultclient.VaultClient{} + vaultConfigs := c.GetConfig().GetVaultConfigs(c.logger) + for _, vaultConfig := range vaultConfigs { + vaultClient, err := vaultclient.NewVaultClient(c.GetConfig().VaultConfig, c.logger, c.deriveToken) + if err != nil { + return err + } + if vaultClient == nil { + c.logger.Error("failed to create vault client", "name", vaultConfig.Name) + return fmt.Errorf("failed to create vault client for cluster %q", vaultConfig.Name) + } + c.vaultClients[vaultConfig.Name] = vaultClient + } - if c.vaultClient == nil { - c.logger.Error("failed to create vault client") - return fmt.Errorf("failed to create vault client") + // Start renewing tokens and secrets only once we've ensured we have created + // all the clients + for _, vaultClient := range c.vaultClients { + vaultClient.Start() } - // Start renewing tokens and secrets - c.vaultClient.Start() - return nil } +func (c *Client) VaultClient(cluster string) (vaultclient.VaultClient, error) { + vaultClient, ok := c.vaultClients[cluster] + if !ok { + return nil, fmt.Errorf("no Vault cluster named: %q", cluster) + } + + return vaultClient, nil +} + // setupNomadServiceRegistrationHandler sets up the registration handler to use // for native service discovery. func (c *Client) setupNomadServiceRegistrationHandler() { diff --git a/client/config/arconfig.go b/client/config/arconfig.go index 6e8828091..d47221a33 100644 --- a/client/config/arconfig.go +++ b/client/config/arconfig.go @@ -59,8 +59,9 @@ type AllocRunnerConfig struct { // ConsulSI is the Consul client used to manage service identity tokens. ConsulSI consul.ServiceIdentityAPI - // Vault is the Vault client to use to retrieve Vault tokens - Vault vaultclient.VaultClient + // VaultFunc is the function to get a Vault client to use to retrieve Vault + // tokens + VaultFunc vaultclient.VaultClientFunc // StateUpdater is used to emit updated task state StateUpdater interfaces.AllocStateHandler diff --git a/client/config/config_ce.go b/client/config/config_ce.go new file mode 100644 index 000000000..671655317 --- /dev/null +++ b/client/config/config_ce.go @@ -0,0 +1,25 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !ent + +package config + +import ( + "github.com/hashicorp/go-hclog" + structsc "github.com/hashicorp/nomad/nomad/structs/config" +) + +// GetVaultConfigs returns the set of Vault configurations available for this +// client. In Nomad CE we only use the default Vault. +func (c *Config) GetVaultConfigs(logger hclog.Logger) map[string]*structsc.VaultConfig { + if c.VaultConfig == nil || !c.VaultConfig.IsEnabled() { + return nil + } + + if len(c.VaultConfigs) > 1 { + logger.Warn("multiple Vault configurations are only supported in Nomad Enterprise") + } + + return map[string]*structsc.VaultConfig{"default": c.VaultConfig} +} diff --git a/client/fingerprint/vault.go b/client/fingerprint/vault.go index 4a6b539b1..a46aea36f 100644 --- a/client/fingerprint/vault.go +++ b/client/fingerprint/vault.go @@ -41,7 +41,9 @@ func NewVaultFingerprint(logger log.Logger) Fingerprint { func (f *VaultFingerprint) Fingerprint(req *FingerprintRequest, resp *FingerprintResponse) error { var mErr *multierror.Error - for _, cfg := range f.vaultConfigs(req) { + vaultConfigs := req.Config.GetVaultConfigs(f.logger) + + for _, cfg := range vaultConfigs { err := f.fingerprintImpl(cfg, resp) if err != nil { mErr = multierror.Append(mErr, err) diff --git a/client/fingerprint/vault_ce.go b/client/fingerprint/vault_ce.go deleted file mode 100644 index b016ad1dc..000000000 --- a/client/fingerprint/vault_ce.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: BUSL-1.1 - -//go:build !ent - -package fingerprint - -import "github.com/hashicorp/nomad/nomad/structs/config" - -// vaultConfigs returns the set of Vault configurations the fingerprint needs to -// check. In Nomad CE we only check the default Vault. -func (f *VaultFingerprint) vaultConfigs(req *FingerprintRequest) map[string]*config.VaultConfig { - agentCfg := req.Config - if agentCfg.VaultConfig == nil || !agentCfg.VaultConfig.IsEnabled() { - return nil - } - - if len(req.Config.VaultConfigs) > 1 { - f.logger.Warn("multiple Vault configurations are only supported in Nomad Enterprise") - } - - return map[string]*config.VaultConfig{"default": agentCfg.VaultConfig} -} diff --git a/client/state/upgrade_int_test.go b/client/state/upgrade_int_test.go index b2f1f4185..12a4a7375 100644 --- a/client/state/upgrade_int_test.go +++ b/client/state/upgrade_int_test.go @@ -209,7 +209,7 @@ func checkUpgradedAlloc(t *testing.T, path string, db StateDB, alloc *structs.Al ClientConfig: clientConf, StateDB: db, Consul: regMock.NewServiceRegistrationHandler(clientConf.Logger), - Vault: vaultclient.NewMockVaultClient(), + VaultFunc: vaultclient.NewMockVaultClient, StateUpdater: &allocrunner.MockStateUpdater{}, PrevAllocWatcher: allocwatcher.NoopPrevAlloc{}, PrevAllocMigrator: allocwatcher.NoopPrevAlloc{}, diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index e6b4d9346..f055778d6 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -13,12 +13,17 @@ import ( metrics "github.com/armon/go-metrics" hclog "github.com/hashicorp/go-hclog" + "github.com/hashicorp/nomad/helper/useragent" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs/config" vaultapi "github.com/hashicorp/vault/api" ) +// VaultClientFunc is the interface of a function that retreives the VaultClient +// by cluster name. This function is injected into the allocrunner/taskrunner +type VaultClientFunc func(string) (VaultClient, error) + // TokenDeriverFunc takes in an allocation and a set of tasks and derives a // wrapped token for all the tasks, from the nomad server. All the derived // wrapped tokens will be unwrapped using the vault API client. diff --git a/client/vaultclient/vaultclient_testing.go b/client/vaultclient/vaultclient_testing.go index 9a9f8b7f6..be9b79131 100644 --- a/client/vaultclient/vaultclient_testing.go +++ b/client/vaultclient/vaultclient_testing.go @@ -38,7 +38,7 @@ type MockVaultClient struct { } // NewMockVaultClient returns a MockVaultClient for testing -func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} } +func NewMockVaultClient(_ string) (VaultClient, error) { return &MockVaultClient{}, nil } func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) { vc.mu.Lock() diff --git a/command/agent/consul/int_test.go b/command/agent/consul/int_test.go index 27e5d945f..2d4b177cf 100644 --- a/command/agent/consul/int_test.go +++ b/command/agent/consul/int_test.go @@ -27,6 +27,7 @@ import ( "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" + "github.com/shoenig/test/must" "github.com/stretchr/testify/require" ) @@ -137,7 +138,8 @@ func TestConsul_Integration(t *testing.T) { r.NoError(allocDir.Destroy()) }) taskDir := allocDir.NewTaskDir(task.Name) - vclient := vaultclient.NewMockVaultClient() + vclient, err := vaultclient.NewMockVaultClient("default") + must.NoError(t, err) consulClient, err := consulapi.NewClient(consulConfig) r.Nil(err) @@ -163,7 +165,7 @@ func TestConsul_Integration(t *testing.T) { Task: task, TaskDir: taskDir, Logger: logger, - Vault: vclient, + VaultFunc: func(string) (vaultclient.VaultClient, error) { return vclient, nil }, StateDB: state.NoopDB{}, StateUpdater: logUpdate, DeviceManager: devicemanager.NoopMockManager(),