vault: select Vault API client by cluster name (#18533)

Nomad Enterprise will support configuring multiple Vault clients. Instead of
having a single Vault client field in the Nomad client, we'll have a function
that callers can parameterize by the Vault cluster name that returns the
correctly configured Vault API client wrapper.
This commit is contained in:
Tim Gross
2023-09-19 14:35:01 -04:00
committed by GitHub
parent fcb9c4a39c
commit fdc6c2151d
17 changed files with 197 additions and 131 deletions

View File

@@ -82,8 +82,8 @@ type allocRunner struct {
// managing SI tokens
sidsClient consul.ServiceIdentityAPI
// vaultClient is the used to manage Vault tokens
vaultClient vaultclient.VaultClient
// vaultClientFunc is used to get the client used to manage Vault tokens
vaultClientFunc vaultclient.VaultClientFunc
// waitCh is closed when the Run loop has exited
waitCh chan struct{}
@@ -225,7 +225,7 @@ func NewAllocRunner(config *config.AllocRunnerConfig) (interfaces.AllocRunner, e
consulClient: config.Consul,
consulProxiesClient: config.ConsulProxies,
sidsClient: config.ConsulSI,
vaultClient: config.Vault,
vaultClientFunc: config.VaultFunc,
tasks: make(map[string]*taskrunner.TaskRunner, len(tg.Tasks)),
waitCh: make(chan struct{}),
destroyCh: make(chan struct{}),
@@ -297,7 +297,7 @@ func (ar *allocRunner) initTaskRunners(tasks []*structs.Task) error {
Consul: ar.consulClient,
ConsulProxies: ar.consulProxiesClient,
ConsulSI: ar.sidsClient,
Vault: ar.vaultClient,
VaultFunc: ar.vaultClientFunc,
DeviceStatsReporter: ar.deviceStatsReporter,
CSIManager: ar.csiManager,
DeviceManager: ar.devicemanager,

View File

@@ -268,7 +268,7 @@ func TestTaskRunner_DeriveSIToken_UnWritableTokenFile(t *testing.T) {
"run_for": "0s",
}
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
defer cleanup()
// make the si_token file un-writable, triggering a failure after a

View File

@@ -193,8 +193,9 @@ type TaskRunner struct {
// service identity tokens
siClient consul.ServiceIdentityAPI
// vaultClient is the client to use to derive and renew Vault tokens
vaultClient vaultclient.VaultClient
// vaultClientFunc is the function to get a client to use to derive and
// renew Vault tokens
vaultClientFunc vaultclient.VaultClientFunc
// vaultToken is the current Vault token. It should be accessed with the
// getter.
@@ -290,8 +291,8 @@ type Config struct {
// DynamicRegistry is where dynamic plugins should be registered.
DynamicRegistry dynamicplugins.Registry
// Vault is the client to use to derive and renew Vault tokens
Vault vaultclient.VaultClient
// VaultFunc is function to get the client to use to derive and renew Vault tokens
VaultFunc vaultclient.VaultClientFunc
// StateDB is used to store and restore state.
StateDB cstate.StateDB
@@ -379,7 +380,7 @@ func NewTaskRunner(config *Config) (*TaskRunner, error) {
consulServiceClient: config.Consul,
consulProxiesClient: config.ConsulProxies,
siClient: config.ConsulSI,
vaultClient: config.Vault,
vaultClientFunc: config.VaultFunc,
state: tstate,
localState: state.NewLocalState(),
allocHookResources: config.AllocHookResources,

View File

@@ -89,10 +89,10 @@ func (tr *TaskRunner) initHooks() {
}
// If Vault is enabled, add the hook
if task.Vault != nil {
if task.Vault != nil && tr.vaultClientFunc != nil {
tr.runnerHooks = append(tr.runnerHooks, newVaultHook(&vaultHookConfig{
vaultBlock: task.Vault,
client: tr.vaultClient,
clientFunc: tr.vaultClientFunc,
events: tr,
lifecycle: tr,
updater: tr,

View File

@@ -32,12 +32,22 @@ func TestTaskRunner_DisableFileForVaultToken_UpgradePath(t *testing.T) {
Policies: []string{"default"},
}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
// Setup a test Vault client.
token := "1234"
handler := func(*structs.Allocation, []string) (map[string]string, error) {
return map[string]string{task.Name: token}, nil
}
vc, err := vaultclient.NewMockVaultClient("default")
must.NoError(t, err)
vaultClient := vc.(*vaultclient.MockVaultClient)
vaultClient.DeriveTokenFn = handler
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
defer cleanup()
// Remove private dir and write the Vault token to the secrets dir to
// simulate an old task.
err := conf.TaskDir.Build(false, nil)
err = conf.TaskDir.Build(false, nil)
must.NoError(t, err)
err = syscall.Unmount(conf.TaskDir.PrivateDir, 0)
@@ -45,18 +55,10 @@ func TestTaskRunner_DisableFileForVaultToken_UpgradePath(t *testing.T) {
err = os.Remove(conf.TaskDir.PrivateDir)
must.NoError(t, err)
token := "1234"
tokenPath := filepath.Join(conf.TaskDir.SecretsDir, vaultTokenFile)
err = os.WriteFile(tokenPath, []byte(token), 0666)
must.NoError(t, err)
// Setup a test Vault client.
handler := func(*structs.Allocation, []string) (map[string]string, error) {
return map[string]string{task.Name: token}, nil
}
vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
vaultClient.DeriveTokenFn = handler
// Start task runner and wait for task to finish.
tr, err := NewTaskRunner(conf)
must.NoError(t, err)

View File

@@ -67,7 +67,7 @@ func (m *MockTaskStateUpdater) TaskStateUpdated() {
// testTaskRunnerConfig returns a taskrunner.Config for the given alloc+task
// plus a cleanup func.
func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName string) (*Config, func()) {
func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName string, vault vaultclient.VaultClient) (*Config, func()) {
logger := testlog.HCLogger(t)
clientConf, cleanup := config.TestClientConfig(t)
@@ -116,6 +116,11 @@ func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName stri
nomadRegMock := regMock.NewServiceRegistrationHandler(logger)
wrapperMock := wrapper.NewHandlerWrapper(logger, consulRegMock, nomadRegMock)
var vaultFunc vaultclient.VaultClientFunc
if vault != nil {
vaultFunc = func(_ string) (vaultclient.VaultClient, error) { return vault, nil }
}
conf := &Config{
Alloc: alloc,
ClientConfig: clientConf,
@@ -124,7 +129,7 @@ func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName stri
Logger: clientConf.Logger,
Consul: consulRegMock,
ConsulSI: consulapi.NewMockServiceIdentitiesClient(),
Vault: vaultclient.NewMockVaultClient(),
VaultFunc: vaultFunc,
StateDB: cstate.NoopDB{},
StateUpdater: NewMockTaskStateUpdater(),
DeviceManager: devicemanager.NoopMockManager(),
@@ -146,7 +151,7 @@ func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName stri
// a cleanup function that ensures the runner is stopped and cleaned up. Tests
// which need to change the Config *must* use testTaskRunnerConfig instead.
func runTestTaskRunner(t *testing.T, alloc *structs.Allocation, taskName string) (*TaskRunner, *Config, func()) {
config, cleanup := testTaskRunnerConfig(t, alloc, taskName)
config, cleanup := testTaskRunnerConfig(t, alloc, taskName, nil)
tr, err := NewTaskRunner(config)
require.NoError(t, err)
@@ -205,7 +210,7 @@ func TestTaskRunner_BuildTaskConfig_CPU_Memory(t *testing.T) {
res.Memory.MemoryMB = c.memoryMB
res.Memory.MemoryMaxMB = c.memoryMaxMB
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
defer cleanup()
@@ -244,7 +249,7 @@ func TestTaskRunner_Stop_ExitCode(t *testing.T) {
"NOMAD_TASK_NAME": task.Name,
}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
defer cleanup()
// Run the first TaskRunner
@@ -292,7 +297,7 @@ func TestTaskRunner_Restore_Running(t *testing.T) {
task.Config = map[string]interface{}{
"run_for": "2s",
}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
defer cleanup()
@@ -346,7 +351,7 @@ func TestTaskRunner_Restore_Dead(t *testing.T) {
task.Config = map[string]interface{}{
"run_for": "2s",
}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
defer cleanup()
@@ -430,7 +435,7 @@ func setupRestoreFailureTest(t *testing.T, alloc *structs.Allocation) (*TaskRunn
"NOMAD_ALLOC_ID": alloc.ID,
"NOMAD_TASK_NAME": task.Name,
}
conf, cleanup1 := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup1 := testTaskRunnerConfig(t, alloc, task.Name, nil)
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between runs
// Run the first TaskRunner
@@ -583,7 +588,7 @@ func TestTaskRunner_Restore_System(t *testing.T) {
"NOMAD_ALLOC_ID": alloc.ID,
"NOMAD_TASK_NAME": task.Name,
}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
defer cleanup()
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between runs
@@ -651,7 +656,7 @@ func TestTaskRunner_MarkFailedKill(t *testing.T) {
// set up some taskrunner
alloc := mock.MinAlloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
t.Cleanup(cleanup)
tr, err := NewTaskRunner(conf)
must.NoError(t, err)
@@ -802,7 +807,7 @@ func TestTaskRunner_DevicePropogation(t *testing.T) {
tRes := alloc.AllocatedResources.Tasks[task.Name]
tRes.Devices = append(tRes.Devices, &structs.AllocatedDeviceResource{Type: "mock"})
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
defer cleanup()
@@ -887,7 +892,7 @@ func TestTaskRunner_Restore_HookEnv(t *testing.T) {
alloc := mock.BatchAlloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between prestart calls
defer cleanup()
@@ -932,7 +937,7 @@ func TestTaskRunner_RecoverFromDriverExiting(t *testing.T) {
"run_for": "5s",
}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between prestart calls
defer cleanup()
@@ -1310,7 +1315,7 @@ func TestTaskRunner_CheckWatcher_Restart(t *testing.T) {
}
task.Services[0].Provider = structs.ServiceProviderConsul
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
defer cleanup()
// Replace mock Consul ServiceClient, with the real ServiceClient
@@ -1406,7 +1411,7 @@ func TestTaskRunner_BlockForSIDSToken(t *testing.T) {
"run_for": "0s",
}
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
defer cleanup()
// set a consul token on the Nomad client's consul config, because that is
@@ -1470,7 +1475,7 @@ func TestTaskRunner_DeriveSIToken_Retry(t *testing.T) {
"run_for": "0s",
}
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
defer cleanup()
// set a consul token on the Nomad client's consul config, because that is
@@ -1530,7 +1535,7 @@ func TestTaskRunner_DeriveSIToken_Unrecoverable(t *testing.T) {
"run_for": "0s",
}
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
defer cleanup()
// set a consul token on the Nomad client's consul config, because that is
@@ -1582,9 +1587,6 @@ func TestTaskRunner_BlockForVaultToken(t *testing.T) {
}
task.Vault = &structs.Vault{Policies: []string{"default"}}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
defer cleanup()
// Control when we get a Vault token
token := "1234"
waitCh := make(chan struct{})
@@ -1592,9 +1594,15 @@ func TestTaskRunner_BlockForVaultToken(t *testing.T) {
<-waitCh
return map[string]string{task.Name: token}, nil
}
vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
vc, err := vaultclient.NewMockVaultClient("default")
must.NoError(t, err)
vaultClient := vc.(*vaultclient.MockVaultClient)
vaultClient.DeriveTokenFn = handler
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
defer cleanup()
tr, err := NewTaskRunner(conf)
require.NoError(t, err)
defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
@@ -1671,17 +1679,19 @@ func TestTaskRunner_DisableFileForVaultToken(t *testing.T) {
DisableFile: true,
}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
defer cleanup()
// Setup a test Vault client
token := "1234"
handler := func(*structs.Allocation, []string) (map[string]string, error) {
return map[string]string{task.Name: token}, nil
}
vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
vc, err := vaultclient.NewMockVaultClient("default")
must.NoError(t, err)
vaultClient := vc.(*vaultclient.MockVaultClient)
vaultClient.DeriveTokenFn = handler
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
defer cleanup()
// Start task runner and wait for it to complete.
tr, err := NewTaskRunner(conf)
must.NoError(t, err)
@@ -1716,9 +1726,6 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Vault = &structs.Vault{Policies: []string{"default"}}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
defer cleanup()
// Fail on the first attempt to derive a vault token
token := "1234"
count := 0
@@ -1730,9 +1737,14 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
count++
return nil, structs.NewRecoverableError(fmt.Errorf("Want a retry"), true)
}
vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
vc, err := vaultclient.NewMockVaultClient("default")
must.NoError(t, err)
vaultClient := vc.(*vaultclient.MockVaultClient)
vaultClient.DeriveTokenFn = handler
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
defer cleanup()
tr, err := NewTaskRunner(conf)
require.NoError(t, err)
defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
@@ -1794,12 +1806,15 @@ func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) {
}
task.Vault = &structs.Vault{Policies: []string{"default"}}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
defer cleanup()
// Error the token derivation
vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
vaultClient.SetDeriveTokenError(alloc.ID, []string{task.Name}, fmt.Errorf("Non recoverable"))
vc, err := vaultclient.NewMockVaultClient("default")
must.NoError(t, err)
vaultClient := vc.(*vaultclient.MockVaultClient)
vaultClient.SetDeriveTokenError(
alloc.ID, []string{task.Name}, fmt.Errorf("Non recoverable"))
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
defer cleanup()
tr, err := NewTaskRunner(conf)
require.NoError(t, err)
@@ -2003,7 +2018,7 @@ func TestTaskRunner_DriverNetwork(t *testing.T) {
},
}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
defer cleanup()
// Use a mock agent to test for services
@@ -2105,9 +2120,6 @@ func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) {
// Use vault to block the start
task.Vault = &structs.Vault{Policies: []string{"default"}}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
defer cleanup()
// Control when we get a Vault token
waitCh := make(chan struct{}, 1)
defer close(waitCh)
@@ -2115,9 +2127,14 @@ func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) {
<-waitCh
return map[string]string{task.Name: "1234"}, nil
}
vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
vc, err := vaultclient.NewMockVaultClient("default")
must.NoError(t, err)
vaultClient := vc.(*vaultclient.MockVaultClient)
vaultClient.DeriveTokenFn = handler
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
defer cleanup()
tr, err := NewTaskRunner(conf)
require.NoError(t, err)
defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
@@ -2215,7 +2232,7 @@ func TestTaskRunner_Template_Artifact(t *testing.T) {
},
}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
defer cleanup()
tr, err := NewTaskRunner(conf)
@@ -2265,7 +2282,7 @@ func TestTaskRunner_Template_BlockingPreStart(t *testing.T) {
task.Vault = &structs.Vault{Policies: []string{"default"}}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
defer cleanup()
tr, err := NewTaskRunner(conf)
@@ -2326,7 +2343,11 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) {
}
task.Vault = &structs.Vault{Policies: []string{"default"}}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
vc, err := vaultclient.NewMockVaultClient("default")
must.NoError(t, err)
vaultClient := vc.(*vaultclient.MockVaultClient)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
defer cleanup()
tr, err := NewTaskRunner(conf)
@@ -2348,8 +2369,7 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) {
require.NoError(t, err)
})
vault := conf.Vault.(*vaultclient.MockVaultClient)
renewalCh, ok := vault.RenewTokens()[token]
renewalCh, ok := vaultClient.RenewTokens()[token]
require.True(t, ok, "no renewal channel for token")
renewalCh <- fmt.Errorf("Test killing")
@@ -2374,11 +2394,11 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) {
// Check the token was revoked
testutil.WaitForResult(func() (bool, error) {
if len(vault.StoppedTokens()) != 1 {
return false, fmt.Errorf("Expected a stopped token: %v", vault.StoppedTokens())
if len(vaultClient.StoppedTokens()) != 1 {
return false, fmt.Errorf("Expected a stopped token: %v", vaultClient.StoppedTokens())
}
if a := vault.StoppedTokens()[0]; a != token {
if a := vaultClient.StoppedTokens()[0]; a != token {
return false, fmt.Errorf("got stopped token %q; want %q", a, token)
}
@@ -2404,7 +2424,11 @@ func TestTaskRunner_VaultManager_Restart(t *testing.T) {
ChangeMode: structs.VaultChangeModeRestart,
}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
vc, err := vaultclient.NewMockVaultClient("default")
vaultClient := vc.(*vaultclient.MockVaultClient)
must.NoError(t, err)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
defer cleanup()
tr, err := NewTaskRunner(conf)
@@ -2420,8 +2444,7 @@ func TestTaskRunner_VaultManager_Restart(t *testing.T) {
require.NotEmpty(t, token)
vault := conf.Vault.(*vaultclient.MockVaultClient)
renewalCh, ok := vault.RenewTokens()[token]
renewalCh, ok := vaultClient.RenewTokens()[token]
require.True(t, ok, "no renewal channel for token")
renewalCh <- fmt.Errorf("Test killing")
@@ -2477,8 +2500,11 @@ func TestTaskRunner_VaultManager_Signal(t *testing.T) {
ChangeMode: structs.VaultChangeModeSignal,
ChangeSignal: "SIGUSR1",
}
vc, err := vaultclient.NewMockVaultClient("default")
must.NoError(t, err)
vaultClient := vc.(*vaultclient.MockVaultClient)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, vaultClient)
defer cleanup()
tr, err := NewTaskRunner(conf)
@@ -2494,8 +2520,7 @@ func TestTaskRunner_VaultManager_Signal(t *testing.T) {
require.NotEmpty(t, token)
vault := conf.Vault.(*vaultclient.MockVaultClient)
renewalCh, ok := vault.RenewTokens()[token]
renewalCh, ok := vaultClient.RenewTokens()[token]
require.True(t, ok, "no renewal channel for token")
renewalCh <- fmt.Errorf("Test killing")
@@ -2548,7 +2573,7 @@ func TestTaskRunner_UnregisterConsul_Retries(t *testing.T) {
"run_for": "1ns",
}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
defer cleanup()
tr, err := NewTaskRunner(conf)
@@ -2612,7 +2637,7 @@ func TestTaskRunner_BaseLabels(t *testing.T) {
"command": "whoami",
}
config, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
config, cleanup := testTaskRunnerConfig(t, alloc, task.Name, nil)
defer cleanup()
tr, err := NewTaskRunner(config)

View File

@@ -50,7 +50,7 @@ func (tr *TaskRunner) updatedVaultToken(token string) {
type vaultHookConfig struct {
vaultBlock *structs.Vault
client vaultclient.VaultClient
clientFunc vaultclient.VaultClientFunc
events ti.EventEmitter
lifecycle ti.TaskLifecycle
updater vaultTokenUpdateHandler
@@ -72,8 +72,10 @@ type vaultHook struct {
// updater is used to update the Vault token
updater vaultTokenUpdateHandler
// client is the Vault client to retrieve and renew the Vault token
client vaultclient.VaultClient
// client is the Vault client to retrieve and renew the Vault token, and
// clientFunc is the injected function that retrieves it
client vaultclient.VaultClient
clientFunc vaultclient.VaultClientFunc
// logger is used to log
logger log.Logger
@@ -105,9 +107,10 @@ type vaultHook struct {
func newVaultHook(config *vaultHookConfig) *vaultHook {
ctx, cancel := context.WithCancel(context.Background())
h := &vaultHook{
vaultBlock: config.vaultBlock,
client: config.client,
clientFunc: config.clientFunc,
eventEmitter: config.events,
lifecycle: config.lifecycle,
updater: config.updater,
@@ -135,6 +138,12 @@ func (h *vaultHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRe
return nil
}
vclient, err := h.clientFunc(h.vaultBlock.Cluster)
if err != nil {
return err
}
h.client = vclient
// Try to recover a token if it was previously written in the secrets
// directory
recoveredToken := ""

View File

@@ -86,7 +86,7 @@ func testAllocRunnerConfig(t *testing.T, alloc *structs.Allocation) (*config.All
StateDB: stateDB,
Consul: consulRegMock,
ConsulSI: consul.NewMockServiceIdentitiesClient(),
Vault: vaultclient.NewMockVaultClient(),
VaultFunc: vaultclient.NewMockVaultClient,
StateUpdater: &MockStateUpdater{},
PrevAllocWatcher: allocwatcher.NoopPrevAlloc{},
PrevAllocMigrator: allocwatcher.NoopPrevAlloc{},

View File

@@ -264,8 +264,8 @@ type Client struct {
// Service Identity tokens through Nomad Server.
tokensClient consulApi.ServiceIdentityAPI
// vaultClient is used to interact with Vault for token and secret renewals
vaultClient vaultclient.VaultClient
// vaultClients is used to interact with Vault for token and secret renewals
vaultClients map[string]vaultclient.VaultClient
// garbageCollector is used to garbage collect terminal allocations present
// in the node automatically
@@ -580,7 +580,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxie
}
// Setup the vault client for token and secret renewals
if err := c.setupVaultClient(); err != nil {
if err := c.setupVaultClients(); err != nil {
return nil, fmt.Errorf("failed to setup vault client: %v", err)
}
@@ -868,8 +868,8 @@ func (c *Client) Shutdown() error {
c.logger.Info("shutting down")
// Stop renewing tokens and secrets
if c.vaultClient != nil {
c.vaultClient.Stop()
for _, vaultClient := range c.vaultClients {
vaultClient.Stop()
}
// Stop Garbage collector
@@ -2761,7 +2761,7 @@ func (c *Client) newAllocRunnerConfig(
ServiceRegWrapper: c.serviceRegWrapper,
StateDB: c.stateDB,
StateUpdater: c,
Vault: c.vaultClient,
VaultFunc: c.VaultClient,
WIDMgr: c.widmgr,
Wranglers: c.wranglers,
Partitions: c.partitions,
@@ -2776,26 +2776,43 @@ func (c *Client) setupConsulTokenClient() error {
return nil
}
// setupVaultClient creates an object to periodically renew tokens and secrets
// with vault.
func (c *Client) setupVaultClient() error {
var err error
c.vaultClient, err = vaultclient.NewVaultClient(c.GetConfig().VaultConfig, c.logger, c.deriveToken)
if err != nil {
return err
// setupVaultClients creates the objects that periodically renew tokens and
// secrets with vault.
func (c *Client) setupVaultClients() error {
c.vaultClients = map[string]vaultclient.VaultClient{}
vaultConfigs := c.GetConfig().GetVaultConfigs(c.logger)
for _, vaultConfig := range vaultConfigs {
vaultClient, err := vaultclient.NewVaultClient(c.GetConfig().VaultConfig, c.logger, c.deriveToken)
if err != nil {
return err
}
if vaultClient == nil {
c.logger.Error("failed to create vault client", "name", vaultConfig.Name)
return fmt.Errorf("failed to create vault client for cluster %q", vaultConfig.Name)
}
c.vaultClients[vaultConfig.Name] = vaultClient
}
if c.vaultClient == nil {
c.logger.Error("failed to create vault client")
return fmt.Errorf("failed to create vault client")
// Start renewing tokens and secrets only once we've ensured we have created
// all the clients
for _, vaultClient := range c.vaultClients {
vaultClient.Start()
}
// Start renewing tokens and secrets
c.vaultClient.Start()
return nil
}
func (c *Client) VaultClient(cluster string) (vaultclient.VaultClient, error) {
vaultClient, ok := c.vaultClients[cluster]
if !ok {
return nil, fmt.Errorf("no Vault cluster named: %q", cluster)
}
return vaultClient, nil
}
// setupNomadServiceRegistrationHandler sets up the registration handler to use
// for native service discovery.
func (c *Client) setupNomadServiceRegistrationHandler() {

View File

@@ -59,8 +59,9 @@ type AllocRunnerConfig struct {
// ConsulSI is the Consul client used to manage service identity tokens.
ConsulSI consul.ServiceIdentityAPI
// Vault is the Vault client to use to retrieve Vault tokens
Vault vaultclient.VaultClient
// VaultFunc is the function to get a Vault client to use to retrieve Vault
// tokens
VaultFunc vaultclient.VaultClientFunc
// StateUpdater is used to emit updated task state
StateUpdater interfaces.AllocStateHandler

View File

@@ -0,0 +1,25 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !ent
package config
import (
"github.com/hashicorp/go-hclog"
structsc "github.com/hashicorp/nomad/nomad/structs/config"
)
// GetVaultConfigs returns the set of Vault configurations available for this
// client. In Nomad CE we only use the default Vault.
func (c *Config) GetVaultConfigs(logger hclog.Logger) map[string]*structsc.VaultConfig {
if c.VaultConfig == nil || !c.VaultConfig.IsEnabled() {
return nil
}
if len(c.VaultConfigs) > 1 {
logger.Warn("multiple Vault configurations are only supported in Nomad Enterprise")
}
return map[string]*structsc.VaultConfig{"default": c.VaultConfig}
}

View File

@@ -41,7 +41,9 @@ func NewVaultFingerprint(logger log.Logger) Fingerprint {
func (f *VaultFingerprint) Fingerprint(req *FingerprintRequest, resp *FingerprintResponse) error {
var mErr *multierror.Error
for _, cfg := range f.vaultConfigs(req) {
vaultConfigs := req.Config.GetVaultConfigs(f.logger)
for _, cfg := range vaultConfigs {
err := f.fingerprintImpl(cfg, resp)
if err != nil {
mErr = multierror.Append(mErr, err)

View File

@@ -1,23 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !ent
package fingerprint
import "github.com/hashicorp/nomad/nomad/structs/config"
// vaultConfigs returns the set of Vault configurations the fingerprint needs to
// check. In Nomad CE we only check the default Vault.
func (f *VaultFingerprint) vaultConfigs(req *FingerprintRequest) map[string]*config.VaultConfig {
agentCfg := req.Config
if agentCfg.VaultConfig == nil || !agentCfg.VaultConfig.IsEnabled() {
return nil
}
if len(req.Config.VaultConfigs) > 1 {
f.logger.Warn("multiple Vault configurations are only supported in Nomad Enterprise")
}
return map[string]*config.VaultConfig{"default": agentCfg.VaultConfig}
}

View File

@@ -209,7 +209,7 @@ func checkUpgradedAlloc(t *testing.T, path string, db StateDB, alloc *structs.Al
ClientConfig: clientConf,
StateDB: db,
Consul: regMock.NewServiceRegistrationHandler(clientConf.Logger),
Vault: vaultclient.NewMockVaultClient(),
VaultFunc: vaultclient.NewMockVaultClient,
StateUpdater: &allocrunner.MockStateUpdater{},
PrevAllocWatcher: allocwatcher.NoopPrevAlloc{},
PrevAllocMigrator: allocwatcher.NoopPrevAlloc{},

View File

@@ -13,12 +13,17 @@ import (
metrics "github.com/armon/go-metrics"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/helper/useragent"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/nomad/structs/config"
vaultapi "github.com/hashicorp/vault/api"
)
// VaultClientFunc is the interface of a function that retreives the VaultClient
// by cluster name. This function is injected into the allocrunner/taskrunner
type VaultClientFunc func(string) (VaultClient, error)
// TokenDeriverFunc takes in an allocation and a set of tasks and derives a
// wrapped token for all the tasks, from the nomad server. All the derived
// wrapped tokens will be unwrapped using the vault API client.

View File

@@ -38,7 +38,7 @@ type MockVaultClient struct {
}
// NewMockVaultClient returns a MockVaultClient for testing
func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} }
func NewMockVaultClient(_ string) (VaultClient, error) { return &MockVaultClient{}, nil }
func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) {
vc.mu.Lock()

View File

@@ -27,6 +27,7 @@ import (
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/shoenig/test/must"
"github.com/stretchr/testify/require"
)
@@ -137,7 +138,8 @@ func TestConsul_Integration(t *testing.T) {
r.NoError(allocDir.Destroy())
})
taskDir := allocDir.NewTaskDir(task.Name)
vclient := vaultclient.NewMockVaultClient()
vclient, err := vaultclient.NewMockVaultClient("default")
must.NoError(t, err)
consulClient, err := consulapi.NewClient(consulConfig)
r.Nil(err)
@@ -163,7 +165,7 @@ func TestConsul_Integration(t *testing.T) {
Task: task,
TaskDir: taskDir,
Logger: logger,
Vault: vclient,
VaultFunc: func(string) (vaultclient.VaultClient, error) { return vclient, nil },
StateDB: state.NoopDB{},
StateUpdater: logUpdate,
DeviceManager: devicemanager.NoopMockManager(),