diff --git a/client/allocrunner/taskrunner/task_runner_hooks.go b/client/allocrunner/taskrunner/task_runner_hooks.go index c92107adc..31b0248ed 100644 --- a/client/allocrunner/taskrunner/task_runner_hooks.go +++ b/client/allocrunner/taskrunner/task_runner_hooks.go @@ -91,15 +91,16 @@ func (tr *TaskRunner) initHooks() { // If Vault is enabled, add the hook if task.Vault != nil && tr.vaultClientFunc != nil { tr.runnerHooks = append(tr.runnerHooks, newVaultHook(&vaultHookConfig{ - vaultBlock: task.Vault, - clientFunc: tr.vaultClientFunc, - events: tr, - lifecycle: tr, - updater: tr, - logger: hookLogger, - alloc: tr.Alloc(), - task: tr.Task(), - widmgr: tr.widmgr, + vaultBlock: task.Vault, + vaultConfigsFunc: tr.clientConfig.GetVaultConfigs, + clientFunc: tr.vaultClientFunc, + events: tr, + lifecycle: tr, + updater: tr, + logger: hookLogger, + alloc: tr.Alloc(), + task: tr.Task(), + widmgr: tr.widmgr, })) } diff --git a/client/allocrunner/taskrunner/task_runner_linux_test.go b/client/allocrunner/taskrunner/task_runner_linux_test.go index 4f4ec8e29..3f8148ea8 100644 --- a/client/allocrunner/taskrunner/task_runner_linux_test.go +++ b/client/allocrunner/taskrunner/task_runner_linux_test.go @@ -29,6 +29,7 @@ func TestTaskRunner_DisableFileForVaultToken_UpgradePath(t *testing.T) { "run_for": "0s", } task.Vault = &structs.Vault{ + Cluster: structs.VaultDefaultCluster, Policies: []string{"default"}, } diff --git a/client/allocrunner/taskrunner/task_runner_test.go b/client/allocrunner/taskrunner/task_runner_test.go index 9cc269499..290bde161 100644 --- a/client/allocrunner/taskrunner/task_runner_test.go +++ b/client/allocrunner/taskrunner/task_runner_test.go @@ -1602,7 +1602,10 @@ func TestTaskRunner_BlockForVaultToken(t *testing.T) { task.Config = map[string]interface{}{ "run_for": "0s", } - task.Vault = &structs.Vault{Policies: []string{"default"}} + task.Vault = &structs.Vault{ + Cluster: structs.VaultDefaultCluster, + Policies: []string{"default"}, + } // Control when we get a Vault token token := "1234" @@ -1692,6 +1695,7 @@ func TestTaskRunner_DisableFileForVaultToken(t *testing.T) { "run_for": "0s", } task.Vault = &structs.Vault{ + Cluster: structs.VaultDefaultCluster, Policies: []string{"default"}, DisableFile: true, } @@ -1741,7 +1745,10 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) { ci.Parallel(t) alloc := mock.BatchAlloc() task := alloc.Job.TaskGroups[0].Tasks[0] - task.Vault = &structs.Vault{Policies: []string{"default"}} + task.Vault = &structs.Vault{ + Cluster: structs.VaultDefaultCluster, + Policies: []string{"default"}, + } // Fail on the first attempt to derive a vault token token := "1234" @@ -1821,7 +1828,10 @@ func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) { task.Config = map[string]interface{}{ "run_for": "0s", } - task.Vault = &structs.Vault{Policies: []string{"default"}} + task.Vault = &structs.Vault{ + Cluster: structs.VaultDefaultCluster, + Policies: []string{"default"}, + } // Error the token derivation vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster) @@ -2135,7 +2145,10 @@ func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) { } // Use vault to block the start - task.Vault = &structs.Vault{Policies: []string{"default"}} + task.Vault = &structs.Vault{ + Cluster: structs.VaultDefaultCluster, + Policies: []string{"default"}, + } // Control when we get a Vault token waitCh := make(chan struct{}, 1) @@ -2361,7 +2374,10 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) { ChangeMode: structs.TemplateChangeModeNoop, }, } - task.Vault = &structs.Vault{Policies: []string{"default"}} + task.Vault = &structs.Vault{ + Cluster: structs.VaultDefaultCluster, + Policies: []string{"default"}, + } vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster) must.NoError(t, err) @@ -2440,6 +2456,7 @@ func TestTaskRunner_VaultManager_Restart(t *testing.T) { "run_for": "10s", } task.Vault = &structs.Vault{ + Cluster: structs.VaultDefaultCluster, Policies: []string{"default"}, ChangeMode: structs.VaultChangeModeRestart, } @@ -2516,6 +2533,7 @@ func TestTaskRunner_VaultManager_Signal(t *testing.T) { "run_for": "10s", } task.Vault = &structs.Vault{ + Cluster: structs.VaultDefaultCluster, Policies: []string{"default"}, ChangeMode: structs.VaultChangeModeSignal, ChangeSignal: "SIGUSR1", diff --git a/client/allocrunner/taskrunner/vault_hook.go b/client/allocrunner/taskrunner/vault_hook.go index 25851b16a..2fd28b610 100644 --- a/client/allocrunner/taskrunner/vault_hook.go +++ b/client/allocrunner/taskrunner/vault_hook.go @@ -14,6 +14,7 @@ import ( "time" "github.com/hashicorp/consul-template/signals" + "github.com/hashicorp/go-hclog" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/client/allocrunner/interfaces" @@ -22,6 +23,7 @@ import ( "github.com/hashicorp/nomad/client/widmgr" "github.com/hashicorp/nomad/helper" "github.com/hashicorp/nomad/nomad/structs" + sconfig "github.com/hashicorp/nomad/nomad/structs/config" ) const ( @@ -54,21 +56,26 @@ func (tr *TaskRunner) updatedVaultToken(token string) { } type vaultHookConfig struct { - vaultBlock *structs.Vault - clientFunc vaultclient.VaultClientFunc - events ti.EventEmitter - lifecycle ti.TaskLifecycle - updater vaultTokenUpdateHandler - logger log.Logger - alloc *structs.Allocation - task *structs.Task - widmgr widmgr.IdentityManager + vaultBlock *structs.Vault + vaultConfigsFunc func(hclog.Logger) map[string]*sconfig.VaultConfig + clientFunc vaultclient.VaultClientFunc + events ti.EventEmitter + lifecycle ti.TaskLifecycle + updater vaultTokenUpdateHandler + logger log.Logger + alloc *structs.Allocation + task *structs.Task + widmgr widmgr.IdentityManager } type vaultHook struct { // vaultBlock is the vault block for the task vaultBlock *structs.Vault + // vaultConfig is the Nomad client configuration for Vault. + vaultConfig *sconfig.VaultConfig + vaultConfigsFunc func(hclog.Logger) map[string]*sconfig.VaultConfig + // eventEmitter is used to emit events to the task eventEmitter ti.EventEmitter @@ -123,18 +130,19 @@ type vaultHook struct { func newVaultHook(config *vaultHookConfig) *vaultHook { ctx, cancel := context.WithCancel(context.Background()) h := &vaultHook{ - vaultBlock: config.vaultBlock, - clientFunc: config.clientFunc, - eventEmitter: config.events, - lifecycle: config.lifecycle, - updater: config.updater, - alloc: config.alloc, - task: config.task, - firstRun: true, - ctx: ctx, - cancel: cancel, - future: newTokenFuture(), - widmgr: config.widmgr, + vaultBlock: config.vaultBlock, + vaultConfigsFunc: config.vaultConfigsFunc, + clientFunc: config.clientFunc, + eventEmitter: config.events, + lifecycle: config.lifecycle, + updater: config.updater, + alloc: config.alloc, + task: config.task, + firstRun: true, + ctx: ctx, + cancel: cancel, + future: newTokenFuture(), + widmgr: config.widmgr, } h.logger = config.logger.Named(h.Name()) @@ -163,12 +171,18 @@ func (h *vaultHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRe return nil } - vclient, err := h.clientFunc(h.vaultBlock.Cluster) + cluster := h.vaultBlock.Cluster + vclient, err := h.clientFunc(cluster) if err != nil { return err } h.client = vclient + h.vaultConfig = h.vaultConfigsFunc(h.logger)[cluster] + if h.vaultConfig == nil { + return fmt.Errorf("No client configuration found for Vault cluster %s", cluster) + } + // Try to recover a token if it was previously written in the secrets // directory recoveredToken := "" @@ -410,10 +424,15 @@ func (h *vaultHook) deriveVaultTokenJWT() (string, error) { ) } + role := h.vaultConfig.Role + if h.vaultBlock.Role != "" { + role = h.vaultBlock.Role + } + // Derive Vault token with signed identity. token, err := h.client.DeriveTokenWithJWT(h.ctx, vaultclient.JWTLoginRequest{ JWT: signed.JWT, - Role: h.vaultBlock.Role, + Role: role, }) if err != nil { return "", structs.WrapRecoverable( diff --git a/client/allocrunner/taskrunner/vault_hook_test.go b/client/allocrunner/taskrunner/vault_hook_test.go index b3d4eb2a9..942907190 100644 --- a/client/allocrunner/taskrunner/vault_hook_test.go +++ b/client/allocrunner/taskrunner/vault_hook_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/client/allocdir" "github.com/hashicorp/nomad/client/allocrunner/interfaces" @@ -23,6 +24,7 @@ import ( "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" + sconfig "github.com/hashicorp/nomad/nomad/structs/config" "github.com/shoenig/test/must" "github.com/shoenig/test/wait" ) @@ -70,6 +72,13 @@ func setupTestVaultHook(t *testing.T, config *vaultHookConfig) *vaultHook { if config.vaultBlock == nil { config.vaultBlock = config.task.Vault } + if config.vaultConfigsFunc == nil { + config.vaultConfigsFunc = func(hclog.Logger) map[string]*sconfig.VaultConfig { + return map[string]*sconfig.VaultConfig{ + "default": sconfig.DefaultVaultConfig(), + } + } + } if config.clientFunc == nil { config.clientFunc = func(cluster string) (vaultclient.VaultClient, error) { return vaultclient.NewMockVaultClient(cluster) @@ -105,6 +114,8 @@ func TestTaskRunner_VaultHook(t *testing.T) { testCases := []struct { name string task *structs.Task + configs map[string]*sconfig.VaultConfig + expectRole string expectLegacy bool }{ { @@ -127,6 +138,61 @@ func TestTaskRunner_VaultHook(t *testing.T) { }, }, }, + { + name: "jwt flow with role", + task: &structs.Task{ + Vault: &structs.Vault{ + Cluster: structs.VaultDefaultCluster, + Role: "task-role", + }, + Identities: []*structs.WorkloadIdentity{ + {Name: "vault_default"}, + }, + }, + configs: map[string]*sconfig.VaultConfig{ + "default": { + Role: "client-role", + }, + }, + expectRole: "task-role", + }, + { + name: "jwt flow with role from client", + task: &structs.Task{ + Vault: &structs.Vault{ + Cluster: structs.VaultDefaultCluster, + }, + Identities: []*structs.WorkloadIdentity{ + {Name: "vault_default"}, + }, + }, + configs: map[string]*sconfig.VaultConfig{ + "default": { + Role: "client-role", + }, + }, + expectRole: "client-role", + }, + { + name: "jwt flow with role from client and non-default cluster", + task: &structs.Task{ + Vault: &structs.Vault{ + Cluster: "prod", + }, + Identities: []*structs.WorkloadIdentity{ + {Name: "vault_prod"}, + }, + }, + configs: map[string]*sconfig.VaultConfig{ + "default": { + Role: "client-role", + }, + "prod": { + Role: "client-prod-role", + }, + }, + expectRole: "client-prod-role", + }, { name: "disable file", task: &structs.Task{ @@ -149,6 +215,14 @@ func TestTaskRunner_VaultHook(t *testing.T) { hook := setupTestVaultHook(t, &vaultHookConfig{ task: tc.task, alloc: alloc, + vaultConfigsFunc: func(hclog.Logger) map[string]*sconfig.VaultConfig { + if tc.configs != nil { + return tc.configs + } + return map[string]*sconfig.VaultConfig{ + "default": sconfig.DefaultVaultConfig(), + } + }, }) // Ensure Prestart() returns within a reasonable time. @@ -190,6 +264,16 @@ func TestTaskRunner_VaultHook(t *testing.T) { } must.NotEq(t, "", token) + // Token must be derived with correct role. + // + // MockVaultClient generates random UUIDv4 tokens, but append the + // role when requested. + if tc.expectRole != "" { + must.StrHasSuffix(t, tc.expectRole, token) + } else { + must.UUIDv4(t, token) + } + // Token must be set in token updater. updater := (hook.updater).(*vaultTokenUpdaterMock) must.Eq(t, token, updater.currentToken) diff --git a/client/config/config.go b/client/config/config.go index 76faa4c7b..7fcbe10a9 100644 --- a/client/config/config.go +++ b/client/config/config.go @@ -179,6 +179,8 @@ type Config struct { ConsulConfigs map[string]*structsc.ConsulConfig // VaultConfig is this Agent's default Vault configuration + // + // Deprecated: use GetVaultConfigs() instead. VaultConfig *structsc.VaultConfig // VaultConfigs is a map of Vault configurations, here to support features diff --git a/client/vaultclient/vaultclient_testing.go b/client/vaultclient/vaultclient_testing.go index 481cf031c..b54d0a798 100644 --- a/client/vaultclient/vaultclient_testing.go +++ b/client/vaultclient/vaultclient_testing.go @@ -5,6 +5,7 @@ package vaultclient import ( "context" + "fmt" "sync" "github.com/hashicorp/nomad/helper/uuid" @@ -64,6 +65,9 @@ func (vc *MockVaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginR } token := uuid.Generate() + if req.Role != "" { + token = fmt.Sprintf("%s-%s", token, req.Role) + } vc.jwtTokens[req.JWT] = token return token, nil }