mirror of
https://github.com/kemko/nomad.git
synced 2026-01-06 18:35:44 +03:00
vault: select Vault API client by cluster name (#18533)
Nomad Enterprise will support configuring multiple Vault clients. Instead of having a single Vault client field in the Nomad client, we'll have a function that callers can parameterize by the Vault cluster name that returns the correctly configured Vault API client wrapper.
This commit is contained in:
@@ -82,8 +82,8 @@ type allocRunner struct {
|
||||
// managing SI tokens
|
||||
sidsClient consul.ServiceIdentityAPI
|
||||
|
||||
// vaultClient is the used to manage Vault tokens
|
||||
vaultClient vaultclient.VaultClient
|
||||
// vaultClientFunc is used to get the client used to manage Vault tokens
|
||||
vaultClientFunc vaultclient.VaultClientFunc
|
||||
|
||||
// waitCh is closed when the Run loop has exited
|
||||
waitCh chan struct{}
|
||||
@@ -225,7 +225,7 @@ func NewAllocRunner(config *config.AllocRunnerConfig) (interfaces.AllocRunner, e
|
||||
consulClient: config.Consul,
|
||||
consulProxiesClient: config.ConsulProxies,
|
||||
sidsClient: config.ConsulSI,
|
||||
vaultClient: config.Vault,
|
||||
vaultClientFunc: config.VaultFunc,
|
||||
tasks: make(map[string]*taskrunner.TaskRunner, len(tg.Tasks)),
|
||||
waitCh: make(chan struct{}),
|
||||
destroyCh: make(chan struct{}),
|
||||
@@ -297,7 +297,7 @@ func (ar *allocRunner) initTaskRunners(tasks []*structs.Task) error {
|
||||
Consul: ar.consulClient,
|
||||
ConsulProxies: ar.consulProxiesClient,
|
||||
ConsulSI: ar.sidsClient,
|
||||
Vault: ar.vaultClient,
|
||||
VaultFunc: ar.vaultClientFunc,
|
||||
DeviceStatsReporter: ar.deviceStatsReporter,
|
||||
CSIManager: ar.csiManager,
|
||||
DeviceManager: ar.devicemanager,
|
||||
|
||||
@@ -268,7 +268,7 @@ func TestTaskRunner_DeriveSIToken_UnWritableTokenFile(t *testing.T) {
|
||||
"run_for": "0s",
|
||||
}
|
||||
|
||||
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
defer cleanup()
|
||||
|
||||
// make the si_token file un-writable, triggering a failure after a
|
||||
|
||||
@@ -193,8 +193,9 @@ type TaskRunner struct {
|
||||
// service identity tokens
|
||||
siClient consul.ServiceIdentityAPI
|
||||
|
||||
// vaultClient is the client to use to derive and renew Vault tokens
|
||||
vaultClient vaultclient.VaultClient
|
||||
// vaultClientFunc is the function to get a client to use to derive and
|
||||
// renew Vault tokens
|
||||
vaultClientFunc vaultclient.VaultClientFunc
|
||||
|
||||
// vaultToken is the current Vault token. It should be accessed with the
|
||||
// getter.
|
||||
@@ -290,8 +291,8 @@ type Config struct {
|
||||
// DynamicRegistry is where dynamic plugins should be registered.
|
||||
DynamicRegistry dynamicplugins.Registry
|
||||
|
||||
// Vault is the client to use to derive and renew Vault tokens
|
||||
Vault vaultclient.VaultClient
|
||||
// VaultFunc is function to get the client to use to derive and renew Vault tokens
|
||||
VaultFunc vaultclient.VaultClientFunc
|
||||
|
||||
// StateDB is used to store and restore state.
|
||||
StateDB cstate.StateDB
|
||||
@@ -379,7 +380,7 @@ func NewTaskRunner(config *Config) (*TaskRunner, error) {
|
||||
consulServiceClient: config.Consul,
|
||||
consulProxiesClient: config.ConsulProxies,
|
||||
siClient: config.ConsulSI,
|
||||
vaultClient: config.Vault,
|
||||
vaultClientFunc: config.VaultFunc,
|
||||
state: tstate,
|
||||
localState: state.NewLocalState(),
|
||||
allocHookResources: config.AllocHookResources,
|
||||
|
||||
@@ -89,10 +89,10 @@ func (tr *TaskRunner) initHooks() {
|
||||
}
|
||||
|
||||
// If Vault is enabled, add the hook
|
||||
if task.Vault != nil {
|
||||
if task.Vault != nil && tr.vaultClientFunc != nil {
|
||||
tr.runnerHooks = append(tr.runnerHooks, newVaultHook(&vaultHookConfig{
|
||||
vaultBlock: task.Vault,
|
||||
client: tr.vaultClient,
|
||||
clientFunc: tr.vaultClientFunc,
|
||||
events: tr,
|
||||
lifecycle: tr,
|
||||
updater: tr,
|
||||
|
||||
@@ -32,12 +32,22 @@ func TestTaskRunner_DisableFileForVaultToken_UpgradePath(t *testing.T) {
|
||||
Policies: []string{"default"},
|
||||
}
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
// Setup a test Vault client.
|
||||
token := "1234"
|
||||
handler := func(*structs.Allocation, []string) (map[string]string, error) {
|
||||
return map[string]string{task.Name: token}, nil
|
||||
}
|
||||
vc, err := vaultclient.NewMockVaultClient("default")
|
||||
must.NoError(t, err)
|
||||
vaultClient := vc.(*vaultclient.MockVaultClient)
|
||||
vaultClient.DeriveTokenFn = handler
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
|
||||
defer cleanup()
|
||||
|
||||
// Remove private dir and write the Vault token to the secrets dir to
|
||||
// simulate an old task.
|
||||
err := conf.TaskDir.Build(false, nil)
|
||||
err = conf.TaskDir.Build(false, nil)
|
||||
must.NoError(t, err)
|
||||
|
||||
err = syscall.Unmount(conf.TaskDir.PrivateDir, 0)
|
||||
@@ -45,18 +55,10 @@ func TestTaskRunner_DisableFileForVaultToken_UpgradePath(t *testing.T) {
|
||||
err = os.Remove(conf.TaskDir.PrivateDir)
|
||||
must.NoError(t, err)
|
||||
|
||||
token := "1234"
|
||||
tokenPath := filepath.Join(conf.TaskDir.SecretsDir, vaultTokenFile)
|
||||
err = os.WriteFile(tokenPath, []byte(token), 0666)
|
||||
must.NoError(t, err)
|
||||
|
||||
// Setup a test Vault client.
|
||||
handler := func(*structs.Allocation, []string) (map[string]string, error) {
|
||||
return map[string]string{task.Name: token}, nil
|
||||
}
|
||||
vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
|
||||
vaultClient.DeriveTokenFn = handler
|
||||
|
||||
// Start task runner and wait for task to finish.
|
||||
tr, err := NewTaskRunner(conf)
|
||||
must.NoError(t, err)
|
||||
|
||||
@@ -67,7 +67,7 @@ func (m *MockTaskStateUpdater) TaskStateUpdated() {
|
||||
|
||||
// testTaskRunnerConfig returns a taskrunner.Config for the given alloc+task
|
||||
// plus a cleanup func.
|
||||
func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName string) (*Config, func()) {
|
||||
func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName string, vault vaultclient.VaultClient) (*Config, func()) {
|
||||
logger := testlog.HCLogger(t)
|
||||
clientConf, cleanup := config.TestClientConfig(t)
|
||||
|
||||
@@ -116,6 +116,11 @@ func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName stri
|
||||
nomadRegMock := regMock.NewServiceRegistrationHandler(logger)
|
||||
wrapperMock := wrapper.NewHandlerWrapper(logger, consulRegMock, nomadRegMock)
|
||||
|
||||
var vaultFunc vaultclient.VaultClientFunc
|
||||
if vault != nil {
|
||||
vaultFunc = func(_ string) (vaultclient.VaultClient, error) { return vault, nil }
|
||||
}
|
||||
|
||||
conf := &Config{
|
||||
Alloc: alloc,
|
||||
ClientConfig: clientConf,
|
||||
@@ -124,7 +129,7 @@ func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName stri
|
||||
Logger: clientConf.Logger,
|
||||
Consul: consulRegMock,
|
||||
ConsulSI: consulapi.NewMockServiceIdentitiesClient(),
|
||||
Vault: vaultclient.NewMockVaultClient(),
|
||||
VaultFunc: vaultFunc,
|
||||
StateDB: cstate.NoopDB{},
|
||||
StateUpdater: NewMockTaskStateUpdater(),
|
||||
DeviceManager: devicemanager.NoopMockManager(),
|
||||
@@ -146,7 +151,7 @@ func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName stri
|
||||
// a cleanup function that ensures the runner is stopped and cleaned up. Tests
|
||||
// which need to change the Config *must* use testTaskRunnerConfig instead.
|
||||
func runTestTaskRunner(t *testing.T, alloc *structs.Allocation, taskName string) (*TaskRunner, *Config, func()) {
|
||||
config, cleanup := testTaskRunnerConfig(t, alloc, taskName)
|
||||
config, cleanup := testTaskRunnerConfig(t, alloc, taskName, nil)
|
||||
|
||||
tr, err := NewTaskRunner(config)
|
||||
require.NoError(t, err)
|
||||
@@ -205,7 +210,7 @@ func TestTaskRunner_BuildTaskConfig_CPU_Memory(t *testing.T) {
|
||||
res.Memory.MemoryMB = c.memoryMB
|
||||
res.Memory.MemoryMaxMB = c.memoryMaxMB
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
|
||||
defer cleanup()
|
||||
|
||||
@@ -244,7 +249,7 @@ func TestTaskRunner_Stop_ExitCode(t *testing.T) {
|
||||
"NOMAD_TASK_NAME": task.Name,
|
||||
}
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
defer cleanup()
|
||||
|
||||
// Run the first TaskRunner
|
||||
@@ -292,7 +297,7 @@ func TestTaskRunner_Restore_Running(t *testing.T) {
|
||||
task.Config = map[string]interface{}{
|
||||
"run_for": "2s",
|
||||
}
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
|
||||
defer cleanup()
|
||||
|
||||
@@ -346,7 +351,7 @@ func TestTaskRunner_Restore_Dead(t *testing.T) {
|
||||
task.Config = map[string]interface{}{
|
||||
"run_for": "2s",
|
||||
}
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
|
||||
defer cleanup()
|
||||
|
||||
@@ -430,7 +435,7 @@ func setupRestoreFailureTest(t *testing.T, alloc *structs.Allocation) (*TaskRunn
|
||||
"NOMAD_ALLOC_ID": alloc.ID,
|
||||
"NOMAD_TASK_NAME": task.Name,
|
||||
}
|
||||
conf, cleanup1 := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup1 := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between runs
|
||||
|
||||
// Run the first TaskRunner
|
||||
@@ -583,7 +588,7 @@ func TestTaskRunner_Restore_System(t *testing.T) {
|
||||
"NOMAD_ALLOC_ID": alloc.ID,
|
||||
"NOMAD_TASK_NAME": task.Name,
|
||||
}
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
defer cleanup()
|
||||
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between runs
|
||||
|
||||
@@ -651,7 +656,7 @@ func TestTaskRunner_MarkFailedKill(t *testing.T) {
|
||||
// set up some taskrunner
|
||||
alloc := mock.MinAlloc()
|
||||
task := alloc.Job.TaskGroups[0].Tasks[0]
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
t.Cleanup(cleanup)
|
||||
tr, err := NewTaskRunner(conf)
|
||||
must.NoError(t, err)
|
||||
@@ -802,7 +807,7 @@ func TestTaskRunner_DevicePropogation(t *testing.T) {
|
||||
tRes := alloc.AllocatedResources.Tasks[task.Name]
|
||||
tRes.Devices = append(tRes.Devices, &structs.AllocatedDeviceResource{Type: "mock"})
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
|
||||
defer cleanup()
|
||||
|
||||
@@ -887,7 +892,7 @@ func TestTaskRunner_Restore_HookEnv(t *testing.T) {
|
||||
|
||||
alloc := mock.BatchAlloc()
|
||||
task := alloc.Job.TaskGroups[0].Tasks[0]
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between prestart calls
|
||||
defer cleanup()
|
||||
|
||||
@@ -932,7 +937,7 @@ func TestTaskRunner_RecoverFromDriverExiting(t *testing.T) {
|
||||
"run_for": "5s",
|
||||
}
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between prestart calls
|
||||
defer cleanup()
|
||||
|
||||
@@ -1310,7 +1315,7 @@ func TestTaskRunner_CheckWatcher_Restart(t *testing.T) {
|
||||
}
|
||||
task.Services[0].Provider = structs.ServiceProviderConsul
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
defer cleanup()
|
||||
|
||||
// Replace mock Consul ServiceClient, with the real ServiceClient
|
||||
@@ -1406,7 +1411,7 @@ func TestTaskRunner_BlockForSIDSToken(t *testing.T) {
|
||||
"run_for": "0s",
|
||||
}
|
||||
|
||||
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
defer cleanup()
|
||||
|
||||
// set a consul token on the Nomad client's consul config, because that is
|
||||
@@ -1470,7 +1475,7 @@ func TestTaskRunner_DeriveSIToken_Retry(t *testing.T) {
|
||||
"run_for": "0s",
|
||||
}
|
||||
|
||||
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
defer cleanup()
|
||||
|
||||
// set a consul token on the Nomad client's consul config, because that is
|
||||
@@ -1530,7 +1535,7 @@ func TestTaskRunner_DeriveSIToken_Unrecoverable(t *testing.T) {
|
||||
"run_for": "0s",
|
||||
}
|
||||
|
||||
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
defer cleanup()
|
||||
|
||||
// set a consul token on the Nomad client's consul config, because that is
|
||||
@@ -1582,9 +1587,6 @@ func TestTaskRunner_BlockForVaultToken(t *testing.T) {
|
||||
}
|
||||
task.Vault = &structs.Vault{Policies: []string{"default"}}
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
defer cleanup()
|
||||
|
||||
// Control when we get a Vault token
|
||||
token := "1234"
|
||||
waitCh := make(chan struct{})
|
||||
@@ -1592,9 +1594,15 @@ func TestTaskRunner_BlockForVaultToken(t *testing.T) {
|
||||
<-waitCh
|
||||
return map[string]string{task.Name: token}, nil
|
||||
}
|
||||
vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
|
||||
|
||||
vc, err := vaultclient.NewMockVaultClient("default")
|
||||
must.NoError(t, err)
|
||||
vaultClient := vc.(*vaultclient.MockVaultClient)
|
||||
vaultClient.DeriveTokenFn = handler
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
|
||||
defer cleanup()
|
||||
|
||||
tr, err := NewTaskRunner(conf)
|
||||
require.NoError(t, err)
|
||||
defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
|
||||
@@ -1671,17 +1679,19 @@ func TestTaskRunner_DisableFileForVaultToken(t *testing.T) {
|
||||
DisableFile: true,
|
||||
}
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
defer cleanup()
|
||||
|
||||
// Setup a test Vault client
|
||||
token := "1234"
|
||||
handler := func(*structs.Allocation, []string) (map[string]string, error) {
|
||||
return map[string]string{task.Name: token}, nil
|
||||
}
|
||||
vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
|
||||
vc, err := vaultclient.NewMockVaultClient("default")
|
||||
must.NoError(t, err)
|
||||
vaultClient := vc.(*vaultclient.MockVaultClient)
|
||||
vaultClient.DeriveTokenFn = handler
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
|
||||
defer cleanup()
|
||||
|
||||
// Start task runner and wait for it to complete.
|
||||
tr, err := NewTaskRunner(conf)
|
||||
must.NoError(t, err)
|
||||
@@ -1716,9 +1726,6 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
|
||||
task := alloc.Job.TaskGroups[0].Tasks[0]
|
||||
task.Vault = &structs.Vault{Policies: []string{"default"}}
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
defer cleanup()
|
||||
|
||||
// Fail on the first attempt to derive a vault token
|
||||
token := "1234"
|
||||
count := 0
|
||||
@@ -1730,9 +1737,14 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
|
||||
count++
|
||||
return nil, structs.NewRecoverableError(fmt.Errorf("Want a retry"), true)
|
||||
}
|
||||
vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
|
||||
vc, err := vaultclient.NewMockVaultClient("default")
|
||||
must.NoError(t, err)
|
||||
vaultClient := vc.(*vaultclient.MockVaultClient)
|
||||
vaultClient.DeriveTokenFn = handler
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
|
||||
defer cleanup()
|
||||
|
||||
tr, err := NewTaskRunner(conf)
|
||||
require.NoError(t, err)
|
||||
defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
|
||||
@@ -1794,12 +1806,15 @@ func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) {
|
||||
}
|
||||
task.Vault = &structs.Vault{Policies: []string{"default"}}
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
defer cleanup()
|
||||
|
||||
// Error the token derivation
|
||||
vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
|
||||
vaultClient.SetDeriveTokenError(alloc.ID, []string{task.Name}, fmt.Errorf("Non recoverable"))
|
||||
vc, err := vaultclient.NewMockVaultClient("default")
|
||||
must.NoError(t, err)
|
||||
vaultClient := vc.(*vaultclient.MockVaultClient)
|
||||
vaultClient.SetDeriveTokenError(
|
||||
alloc.ID, []string{task.Name}, fmt.Errorf("Non recoverable"))
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
|
||||
defer cleanup()
|
||||
|
||||
tr, err := NewTaskRunner(conf)
|
||||
require.NoError(t, err)
|
||||
@@ -2003,7 +2018,7 @@ func TestTaskRunner_DriverNetwork(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
defer cleanup()
|
||||
|
||||
// Use a mock agent to test for services
|
||||
@@ -2105,9 +2120,6 @@ func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) {
|
||||
// Use vault to block the start
|
||||
task.Vault = &structs.Vault{Policies: []string{"default"}}
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
defer cleanup()
|
||||
|
||||
// Control when we get a Vault token
|
||||
waitCh := make(chan struct{}, 1)
|
||||
defer close(waitCh)
|
||||
@@ -2115,9 +2127,14 @@ func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) {
|
||||
<-waitCh
|
||||
return map[string]string{task.Name: "1234"}, nil
|
||||
}
|
||||
vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
|
||||
vc, err := vaultclient.NewMockVaultClient("default")
|
||||
must.NoError(t, err)
|
||||
vaultClient := vc.(*vaultclient.MockVaultClient)
|
||||
vaultClient.DeriveTokenFn = handler
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
|
||||
defer cleanup()
|
||||
|
||||
tr, err := NewTaskRunner(conf)
|
||||
require.NoError(t, err)
|
||||
defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
|
||||
@@ -2215,7 +2232,7 @@ func TestTaskRunner_Template_Artifact(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
defer cleanup()
|
||||
|
||||
tr, err := NewTaskRunner(conf)
|
||||
@@ -2265,7 +2282,7 @@ func TestTaskRunner_Template_BlockingPreStart(t *testing.T) {
|
||||
|
||||
task.Vault = &structs.Vault{Policies: []string{"default"}}
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
defer cleanup()
|
||||
|
||||
tr, err := NewTaskRunner(conf)
|
||||
@@ -2326,7 +2343,11 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) {
|
||||
}
|
||||
task.Vault = &structs.Vault{Policies: []string{"default"}}
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
vc, err := vaultclient.NewMockVaultClient("default")
|
||||
must.NoError(t, err)
|
||||
vaultClient := vc.(*vaultclient.MockVaultClient)
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
|
||||
defer cleanup()
|
||||
|
||||
tr, err := NewTaskRunner(conf)
|
||||
@@ -2348,8 +2369,7 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
vault := conf.Vault.(*vaultclient.MockVaultClient)
|
||||
renewalCh, ok := vault.RenewTokens()[token]
|
||||
renewalCh, ok := vaultClient.RenewTokens()[token]
|
||||
require.True(t, ok, "no renewal channel for token")
|
||||
|
||||
renewalCh <- fmt.Errorf("Test killing")
|
||||
@@ -2374,11 +2394,11 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) {
|
||||
|
||||
// Check the token was revoked
|
||||
testutil.WaitForResult(func() (bool, error) {
|
||||
if len(vault.StoppedTokens()) != 1 {
|
||||
return false, fmt.Errorf("Expected a stopped token: %v", vault.StoppedTokens())
|
||||
if len(vaultClient.StoppedTokens()) != 1 {
|
||||
return false, fmt.Errorf("Expected a stopped token: %v", vaultClient.StoppedTokens())
|
||||
}
|
||||
|
||||
if a := vault.StoppedTokens()[0]; a != token {
|
||||
if a := vaultClient.StoppedTokens()[0]; a != token {
|
||||
return false, fmt.Errorf("got stopped token %q; want %q", a, token)
|
||||
}
|
||||
|
||||
@@ -2404,7 +2424,11 @@ func TestTaskRunner_VaultManager_Restart(t *testing.T) {
|
||||
ChangeMode: structs.VaultChangeModeRestart,
|
||||
}
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
vc, err := vaultclient.NewMockVaultClient("default")
|
||||
vaultClient := vc.(*vaultclient.MockVaultClient)
|
||||
must.NoError(t, err)
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
|
||||
defer cleanup()
|
||||
|
||||
tr, err := NewTaskRunner(conf)
|
||||
@@ -2420,8 +2444,7 @@ func TestTaskRunner_VaultManager_Restart(t *testing.T) {
|
||||
|
||||
require.NotEmpty(t, token)
|
||||
|
||||
vault := conf.Vault.(*vaultclient.MockVaultClient)
|
||||
renewalCh, ok := vault.RenewTokens()[token]
|
||||
renewalCh, ok := vaultClient.RenewTokens()[token]
|
||||
require.True(t, ok, "no renewal channel for token")
|
||||
|
||||
renewalCh <- fmt.Errorf("Test killing")
|
||||
@@ -2477,8 +2500,11 @@ func TestTaskRunner_VaultManager_Signal(t *testing.T) {
|
||||
ChangeMode: structs.VaultChangeModeSignal,
|
||||
ChangeSignal: "SIGUSR1",
|
||||
}
|
||||
vc, err := vaultclient.NewMockVaultClient("default")
|
||||
must.NoError(t, err)
|
||||
vaultClient := vc.(*vaultclient.MockVaultClient)
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
|
||||
defer cleanup()
|
||||
|
||||
tr, err := NewTaskRunner(conf)
|
||||
@@ -2494,8 +2520,7 @@ func TestTaskRunner_VaultManager_Signal(t *testing.T) {
|
||||
|
||||
require.NotEmpty(t, token)
|
||||
|
||||
vault := conf.Vault.(*vaultclient.MockVaultClient)
|
||||
renewalCh, ok := vault.RenewTokens()[token]
|
||||
renewalCh, ok := vaultClient.RenewTokens()[token]
|
||||
require.True(t, ok, "no renewal channel for token")
|
||||
|
||||
renewalCh <- fmt.Errorf("Test killing")
|
||||
@@ -2548,7 +2573,7 @@ func TestTaskRunner_UnregisterConsul_Retries(t *testing.T) {
|
||||
"run_for": "1ns",
|
||||
}
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
defer cleanup()
|
||||
|
||||
tr, err := NewTaskRunner(conf)
|
||||
@@ -2612,7 +2637,7 @@ func TestTaskRunner_BaseLabels(t *testing.T) {
|
||||
"command": "whoami",
|
||||
}
|
||||
|
||||
config, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
config, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
|
||||
defer cleanup()
|
||||
|
||||
tr, err := NewTaskRunner(config)
|
||||
|
||||
@@ -50,7 +50,7 @@ func (tr *TaskRunner) updatedVaultToken(token string) {
|
||||
|
||||
type vaultHookConfig struct {
|
||||
vaultBlock *structs.Vault
|
||||
client vaultclient.VaultClient
|
||||
clientFunc vaultclient.VaultClientFunc
|
||||
events ti.EventEmitter
|
||||
lifecycle ti.TaskLifecycle
|
||||
updater vaultTokenUpdateHandler
|
||||
@@ -72,8 +72,10 @@ type vaultHook struct {
|
||||
// updater is used to update the Vault token
|
||||
updater vaultTokenUpdateHandler
|
||||
|
||||
// client is the Vault client to retrieve and renew the Vault token
|
||||
client vaultclient.VaultClient
|
||||
// client is the Vault client to retrieve and renew the Vault token, and
|
||||
// clientFunc is the injected function that retrieves it
|
||||
client vaultclient.VaultClient
|
||||
clientFunc vaultclient.VaultClientFunc
|
||||
|
||||
// logger is used to log
|
||||
logger log.Logger
|
||||
@@ -105,9 +107,10 @@ type vaultHook struct {
|
||||
|
||||
func newVaultHook(config *vaultHookConfig) *vaultHook {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
h := &vaultHook{
|
||||
vaultBlock: config.vaultBlock,
|
||||
client: config.client,
|
||||
clientFunc: config.clientFunc,
|
||||
eventEmitter: config.events,
|
||||
lifecycle: config.lifecycle,
|
||||
updater: config.updater,
|
||||
@@ -135,6 +138,12 @@ func (h *vaultHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRe
|
||||
return nil
|
||||
}
|
||||
|
||||
vclient, err := h.clientFunc(h.vaultBlock.Cluster)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.client = vclient
|
||||
|
||||
// Try to recover a token if it was previously written in the secrets
|
||||
// directory
|
||||
recoveredToken := ""
|
||||
|
||||
@@ -86,7 +86,7 @@ func testAllocRunnerConfig(t *testing.T, alloc *structs.Allocation) (*config.All
|
||||
StateDB: stateDB,
|
||||
Consul: consulRegMock,
|
||||
ConsulSI: consul.NewMockServiceIdentitiesClient(),
|
||||
Vault: vaultclient.NewMockVaultClient(),
|
||||
VaultFunc: vaultclient.NewMockVaultClient,
|
||||
StateUpdater: &MockStateUpdater{},
|
||||
PrevAllocWatcher: allocwatcher.NoopPrevAlloc{},
|
||||
PrevAllocMigrator: allocwatcher.NoopPrevAlloc{},
|
||||
|
||||
@@ -264,8 +264,8 @@ type Client struct {
|
||||
// Service Identity tokens through Nomad Server.
|
||||
tokensClient consulApi.ServiceIdentityAPI
|
||||
|
||||
// vaultClient is used to interact with Vault for token and secret renewals
|
||||
vaultClient vaultclient.VaultClient
|
||||
// vaultClients is used to interact with Vault for token and secret renewals
|
||||
vaultClients map[string]vaultclient.VaultClient
|
||||
|
||||
// garbageCollector is used to garbage collect terminal allocations present
|
||||
// in the node automatically
|
||||
@@ -580,7 +580,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxie
|
||||
}
|
||||
|
||||
// Setup the vault client for token and secret renewals
|
||||
if err := c.setupVaultClient(); err != nil {
|
||||
if err := c.setupVaultClients(); err != nil {
|
||||
return nil, fmt.Errorf("failed to setup vault client: %v", err)
|
||||
}
|
||||
|
||||
@@ -868,8 +868,8 @@ func (c *Client) Shutdown() error {
|
||||
c.logger.Info("shutting down")
|
||||
|
||||
// Stop renewing tokens and secrets
|
||||
if c.vaultClient != nil {
|
||||
c.vaultClient.Stop()
|
||||
for _, vaultClient := range c.vaultClients {
|
||||
vaultClient.Stop()
|
||||
}
|
||||
|
||||
// Stop Garbage collector
|
||||
@@ -2761,7 +2761,7 @@ func (c *Client) newAllocRunnerConfig(
|
||||
ServiceRegWrapper: c.serviceRegWrapper,
|
||||
StateDB: c.stateDB,
|
||||
StateUpdater: c,
|
||||
Vault: c.vaultClient,
|
||||
VaultFunc: c.VaultClient,
|
||||
WIDMgr: c.widmgr,
|
||||
Wranglers: c.wranglers,
|
||||
Partitions: c.partitions,
|
||||
@@ -2776,26 +2776,43 @@ func (c *Client) setupConsulTokenClient() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupVaultClient creates an object to periodically renew tokens and secrets
|
||||
// with vault.
|
||||
func (c *Client) setupVaultClient() error {
|
||||
var err error
|
||||
c.vaultClient, err = vaultclient.NewVaultClient(c.GetConfig().VaultConfig, c.logger, c.deriveToken)
|
||||
if err != nil {
|
||||
return err
|
||||
// setupVaultClients creates the objects that periodically renew tokens and
|
||||
// secrets with vault.
|
||||
func (c *Client) setupVaultClients() error {
|
||||
|
||||
c.vaultClients = map[string]vaultclient.VaultClient{}
|
||||
vaultConfigs := c.GetConfig().GetVaultConfigs(c.logger)
|
||||
for _, vaultConfig := range vaultConfigs {
|
||||
vaultClient, err := vaultclient.NewVaultClient(c.GetConfig().VaultConfig, c.logger, c.deriveToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if vaultClient == nil {
|
||||
c.logger.Error("failed to create vault client", "name", vaultConfig.Name)
|
||||
return fmt.Errorf("failed to create vault client for cluster %q", vaultConfig.Name)
|
||||
}
|
||||
c.vaultClients[vaultConfig.Name] = vaultClient
|
||||
|
||||
}
|
||||
|
||||
if c.vaultClient == nil {
|
||||
c.logger.Error("failed to create vault client")
|
||||
return fmt.Errorf("failed to create vault client")
|
||||
// Start renewing tokens and secrets only once we've ensured we have created
|
||||
// all the clients
|
||||
for _, vaultClient := range c.vaultClients {
|
||||
vaultClient.Start()
|
||||
}
|
||||
|
||||
// Start renewing tokens and secrets
|
||||
c.vaultClient.Start()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) VaultClient(cluster string) (vaultclient.VaultClient, error) {
|
||||
vaultClient, ok := c.vaultClients[cluster]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no Vault cluster named: %q", cluster)
|
||||
}
|
||||
|
||||
return vaultClient, nil
|
||||
}
|
||||
|
||||
// setupNomadServiceRegistrationHandler sets up the registration handler to use
|
||||
// for native service discovery.
|
||||
func (c *Client) setupNomadServiceRegistrationHandler() {
|
||||
|
||||
@@ -59,8 +59,9 @@ type AllocRunnerConfig struct {
|
||||
// ConsulSI is the Consul client used to manage service identity tokens.
|
||||
ConsulSI consul.ServiceIdentityAPI
|
||||
|
||||
// Vault is the Vault client to use to retrieve Vault tokens
|
||||
Vault vaultclient.VaultClient
|
||||
// VaultFunc is the function to get a Vault client to use to retrieve Vault
|
||||
// tokens
|
||||
VaultFunc vaultclient.VaultClientFunc
|
||||
|
||||
// StateUpdater is used to emit updated task state
|
||||
StateUpdater interfaces.AllocStateHandler
|
||||
|
||||
25
client/config/config_ce.go
Normal file
25
client/config/config_ce.go
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
//go:build !ent
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/go-hclog"
|
||||
structsc "github.com/hashicorp/nomad/nomad/structs/config"
|
||||
)
|
||||
|
||||
// GetVaultConfigs returns the set of Vault configurations available for this
|
||||
// client. In Nomad CE we only use the default Vault.
|
||||
func (c *Config) GetVaultConfigs(logger hclog.Logger) map[string]*structsc.VaultConfig {
|
||||
if c.VaultConfig == nil || !c.VaultConfig.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(c.VaultConfigs) > 1 {
|
||||
logger.Warn("multiple Vault configurations are only supported in Nomad Enterprise")
|
||||
}
|
||||
|
||||
return map[string]*structsc.VaultConfig{"default": c.VaultConfig}
|
||||
}
|
||||
@@ -41,7 +41,9 @@ func NewVaultFingerprint(logger log.Logger) Fingerprint {
|
||||
|
||||
func (f *VaultFingerprint) Fingerprint(req *FingerprintRequest, resp *FingerprintResponse) error {
|
||||
var mErr *multierror.Error
|
||||
for _, cfg := range f.vaultConfigs(req) {
|
||||
vaultConfigs := req.Config.GetVaultConfigs(f.logger)
|
||||
|
||||
for _, cfg := range vaultConfigs {
|
||||
err := f.fingerprintImpl(cfg, resp)
|
||||
if err != nil {
|
||||
mErr = multierror.Append(mErr, err)
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
//go:build !ent
|
||||
|
||||
package fingerprint
|
||||
|
||||
import "github.com/hashicorp/nomad/nomad/structs/config"
|
||||
|
||||
// vaultConfigs returns the set of Vault configurations the fingerprint needs to
|
||||
// check. In Nomad CE we only check the default Vault.
|
||||
func (f *VaultFingerprint) vaultConfigs(req *FingerprintRequest) map[string]*config.VaultConfig {
|
||||
agentCfg := req.Config
|
||||
if agentCfg.VaultConfig == nil || !agentCfg.VaultConfig.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(req.Config.VaultConfigs) > 1 {
|
||||
f.logger.Warn("multiple Vault configurations are only supported in Nomad Enterprise")
|
||||
}
|
||||
|
||||
return map[string]*config.VaultConfig{"default": agentCfg.VaultConfig}
|
||||
}
|
||||
@@ -209,7 +209,7 @@ func checkUpgradedAlloc(t *testing.T, path string, db StateDB, alloc *structs.Al
|
||||
ClientConfig: clientConf,
|
||||
StateDB: db,
|
||||
Consul: regMock.NewServiceRegistrationHandler(clientConf.Logger),
|
||||
Vault: vaultclient.NewMockVaultClient(),
|
||||
VaultFunc: vaultclient.NewMockVaultClient,
|
||||
StateUpdater: &allocrunner.MockStateUpdater{},
|
||||
PrevAllocWatcher: allocwatcher.NoopPrevAlloc{},
|
||||
PrevAllocMigrator: allocwatcher.NoopPrevAlloc{},
|
||||
|
||||
@@ -13,12 +13,17 @@ import (
|
||||
|
||||
metrics "github.com/armon/go-metrics"
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
|
||||
"github.com/hashicorp/nomad/helper/useragent"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/hashicorp/nomad/nomad/structs/config"
|
||||
vaultapi "github.com/hashicorp/vault/api"
|
||||
)
|
||||
|
||||
// VaultClientFunc is the interface of a function that retreives the VaultClient
|
||||
// by cluster name. This function is injected into the allocrunner/taskrunner
|
||||
type VaultClientFunc func(string) (VaultClient, error)
|
||||
|
||||
// TokenDeriverFunc takes in an allocation and a set of tasks and derives a
|
||||
// wrapped token for all the tasks, from the nomad server. All the derived
|
||||
// wrapped tokens will be unwrapped using the vault API client.
|
||||
|
||||
@@ -38,7 +38,7 @@ type MockVaultClient struct {
|
||||
}
|
||||
|
||||
// NewMockVaultClient returns a MockVaultClient for testing
|
||||
func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} }
|
||||
func NewMockVaultClient(_ string) (VaultClient, error) { return &MockVaultClient{}, nil }
|
||||
|
||||
func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) {
|
||||
vc.mu.Lock()
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/hashicorp/nomad/helper/testlog"
|
||||
"github.com/hashicorp/nomad/nomad/mock"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/shoenig/test/must"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -137,7 +138,8 @@ func TestConsul_Integration(t *testing.T) {
|
||||
r.NoError(allocDir.Destroy())
|
||||
})
|
||||
taskDir := allocDir.NewTaskDir(task.Name)
|
||||
vclient := vaultclient.NewMockVaultClient()
|
||||
vclient, err := vaultclient.NewMockVaultClient("default")
|
||||
must.NoError(t, err)
|
||||
consulClient, err := consulapi.NewClient(consulConfig)
|
||||
r.Nil(err)
|
||||
|
||||
@@ -163,7 +165,7 @@ func TestConsul_Integration(t *testing.T) {
|
||||
Task: task,
|
||||
TaskDir: taskDir,
|
||||
Logger: logger,
|
||||
Vault: vclient,
|
||||
VaultFunc: func(string) (vaultclient.VaultClient, error) { return vclient, nil },
|
||||
StateDB: state.NoopDB{},
|
||||
StateUpdater: logUpdate,
|
||||
DeviceManager: devicemanager.NoopMockManager(),
|
||||
|
||||
Reference in New Issue
Block a user