From 349c032369373555a0ccf7835184c633cc80b01b Mon Sep 17 00:00:00 2001 From: Luiz Aoqui Date: Mon, 16 Oct 2023 19:37:57 -0400 Subject: [PATCH] vault: update task runner vault hook to support workload identity (#18534) --- ci/test-core.json | 3 +- .../taskrunner/artifact_hook_test.go | 33 +- .../taskrunner/task_runner_hooks.go | 3 +- .../taskrunner/template/template_test.go | 14 +- .../taskrunner/template_hook_test.go | 2 +- .../allocrunner/taskrunner/testing/testing.go | 57 +- client/allocrunner/taskrunner/vault_hook.go | 90 ++- .../allocrunner/taskrunner/vault_hook_test.go | 560 +++++++++++++++++- client/vaultclient/vaultclient.go | 57 ++ client/vaultclient/vaultclient_test.go | 259 ++++++++ client/vaultclient/vaultclient_testing.go | 58 ++ client/widmgr/mock.go | 12 + nomad/server.go | 8 +- nomad/vault_noop.go | 61 ++ testutil/vault.go | 8 + 15 files changed, 1171 insertions(+), 54 deletions(-) create mode 100644 client/vaultclient/vaultclient_test.go create mode 100644 nomad/vault_noop.go diff --git a/ci/test-core.json b/ci/test-core.json index 07a512637..1c64ee7cf 100644 --- a/ci/test-core.json +++ b/ci/test-core.json @@ -16,6 +16,7 @@ "client/devicemanager/...", "client/dynamicplugins/...", "client/fingerprint/...", + "client/hoststats/...", "client/interfaces/...", "client/lib/...", "client/logmon/...", @@ -23,9 +24,9 @@ "client/servers/...", "client/serviceregistration/...", "client/state/...", - "client/hoststats/...", "client/structs/...", "client/taskenv/...", + "client/vaultclient/...", "client/widmgr/...", "command/agent/...", "command/raft_tools/...", diff --git a/client/allocrunner/taskrunner/artifact_hook_test.go b/client/allocrunner/taskrunner/artifact_hook_test.go index dfa38be16..944bd17bd 100644 --- a/client/allocrunner/taskrunner/artifact_hook_test.go +++ b/client/allocrunner/taskrunner/artifact_hook_test.go @@ -17,6 +17,7 @@ import ( "github.com/hashicorp/nomad/client/allocdir" "github.com/hashicorp/nomad/client/allocrunner/interfaces" 
"github.com/hashicorp/nomad/client/allocrunner/taskrunner/getter" + trtesting "github.com/hashicorp/nomad/client/allocrunner/taskrunner/testing" "github.com/hashicorp/nomad/client/taskenv" "github.com/hashicorp/nomad/client/testutil" "github.com/hashicorp/nomad/helper/testlog" @@ -28,20 +29,12 @@ import ( // Statically assert the artifact hook implements the expected interface var _ interfaces.TaskPrestartHook = (*artifactHook)(nil) -type mockEmitter struct { - events []*structs.TaskEvent -} - -func (m *mockEmitter) EmitEvent(ev *structs.TaskEvent) { - m.events = append(m.events, ev) -} - // TestTaskRunner_ArtifactHook_Recoverable asserts that failures to download // artifacts are a recoverable error. func TestTaskRunner_ArtifactHook_Recoverable(t *testing.T) { ci.Parallel(t) - me := &mockEmitter{} + me := &trtesting.MockEmitter{} sbox := getter.TestSandbox(t) artifactHook := newArtifactHook(me, sbox, testlog.HCLogger(t)) @@ -65,8 +58,8 @@ func TestTaskRunner_ArtifactHook_Recoverable(t *testing.T) { require.False(t, resp.Done) require.NotNil(t, err) require.True(t, structs.IsRecoverable(err)) - require.Len(t, me.events, 1) - require.Equal(t, structs.TaskDownloadingArtifacts, me.events[0].Type) + require.Len(t, me.Events(), 1) + require.Equal(t, structs.TaskDownloadingArtifacts, me.Events()[0].Type) } // TestTaskRunnerArtifactHook_PartialDone asserts that the artifact hook skips @@ -76,7 +69,7 @@ func TestTaskRunner_ArtifactHook_PartialDone(t *testing.T) { testutil.RequireRoot(t) ci.Parallel(t) - me := &mockEmitter{} + me := &trtesting.MockEmitter{} sbox := getter.TestSandbox(t) artifactHook := newArtifactHook(me, sbox, testlog.HCLogger(t)) @@ -121,8 +114,8 @@ func TestTaskRunner_ArtifactHook_PartialDone(t *testing.T) { require.True(t, structs.IsRecoverable(err)) require.Len(t, resp.State, 1) require.False(t, resp.Done) - require.Len(t, me.events, 1) - require.Equal(t, structs.TaskDownloadingArtifacts, me.events[0].Type) + require.Len(t, me.Events(), 1) + 
require.Equal(t, structs.TaskDownloadingArtifacts, me.Events()[0].Type) // Remove file1 from the server so it errors if its downloaded again. require.NoError(t, os.Remove(file1)) @@ -166,7 +159,7 @@ func TestTaskRunner_ArtifactHook_ConcurrentDownloadSuccess(t *testing.T) { ci.SkipTestWithoutRootAccess(t) ci.Parallel(t) - me := &mockEmitter{} + me := &trtesting.MockEmitter{} sbox := getter.TestSandbox(t) artifactHook := newArtifactHook(me, sbox, testlog.HCLogger(t)) @@ -231,8 +224,8 @@ func TestTaskRunner_ArtifactHook_ConcurrentDownloadSuccess(t *testing.T) { require.NoError(t, err) require.True(t, resp.Done) require.Len(t, resp.State, 7) - require.Len(t, me.events, 1) - require.Equal(t, structs.TaskDownloadingArtifacts, me.events[0].Type) + require.Len(t, me.Events(), 1) + require.Equal(t, structs.TaskDownloadingArtifacts, me.Events()[0].Type) // Assert all files downloaded properly files, err := filepath.Glob(filepath.Join(destdir, "*.txt")) @@ -254,7 +247,7 @@ func TestTaskRunner_ArtifactHook_ConcurrentDownloadSuccess(t *testing.T) { func TestTaskRunner_ArtifactHook_ConcurrentDownloadFailure(t *testing.T) { ci.Parallel(t) - me := &mockEmitter{} + me := &trtesting.MockEmitter{} sbox := getter.TestSandbox(t) artifactHook := newArtifactHook(me, sbox, testlog.HCLogger(t)) @@ -311,8 +304,8 @@ func TestTaskRunner_ArtifactHook_ConcurrentDownloadFailure(t *testing.T) { require.True(t, structs.IsRecoverable(err)) require.Len(t, resp.State, 3) require.False(t, resp.Done) - require.Len(t, me.events, 1) - require.Equal(t, structs.TaskDownloadingArtifacts, me.events[0].Type) + require.Len(t, me.Events(), 1) + require.Equal(t, structs.TaskDownloadingArtifacts, me.Events()[0].Type) // delete the downloaded files so that it'll error if it's downloaded again require.NoError(t, os.Remove(file1)) diff --git a/client/allocrunner/taskrunner/task_runner_hooks.go b/client/allocrunner/taskrunner/task_runner_hooks.go index 4dc61746e..33a688aea 100644 --- 
a/client/allocrunner/taskrunner/task_runner_hooks.go +++ b/client/allocrunner/taskrunner/task_runner_hooks.go @@ -98,7 +98,8 @@ func (tr *TaskRunner) initHooks() { updater: tr, logger: hookLogger, alloc: tr.Alloc(), - task: tr.taskName, + task: tr.Task(), + widmgr: tr.widmgr, })) } diff --git a/client/allocrunner/taskrunner/template/template_test.go b/client/allocrunner/taskrunner/template/template_test.go index cffafe07e..7e1daeca8 100644 --- a/client/allocrunner/taskrunner/template/template_test.go +++ b/client/allocrunner/taskrunner/template/template_test.go @@ -1130,7 +1130,7 @@ func TestTaskTemplateManager_Signal_Error(t *testing.T) { } require.NotNil(harness.mockHooks.KillEvent) - require.Contains(harness.mockHooks.KillEvent.DisplayMessage, "failed to send signals") + require.Contains(harness.mockHooks.KillEvent().DisplayMessage, "failed to send signals") } func TestTaskTemplateManager_ScriptExecution(t *testing.T) { @@ -1292,7 +1292,7 @@ BAR={{key "bar"}} } require.NotNil(harness.mockHooks.KillEvent) - require.Contains(harness.mockHooks.KillEvent.DisplayMessage, "task is being killed") + require.Contains(harness.mockHooks.KillEvent().DisplayMessage, "task is being killed") } func TestTaskTemplateManager_ChangeModeMixed(t *testing.T) { @@ -2098,11 +2098,11 @@ func TestTaskTemplateManager_BlockedEvents(t *testing.T) { // Check to see we got a correct message // assert that all 0-4 keys are missing - require.Len(harness.mockHooks.Events, 1) - t.Logf("first message: %v", harness.mockHooks.Events[0]) - missing, more := missingKeys(harness.mockHooks.Events[0]) + require.Len(harness.mockHooks.Events(), 1) + t.Logf("first message: %v", harness.mockHooks.Events()[0]) + missing, more := missingKeys(harness.mockHooks.Events()[0]) require.Equal(5, len(missing)+more) - require.Contains(harness.mockHooks.Events[0].DisplayMessage, "and 2 more") + require.Contains(harness.mockHooks.Events()[0].DisplayMessage, "and 2 more") // Write 0-2 keys to Consul for i := 0; i < 3; i++ 
{ @@ -2131,7 +2131,7 @@ WAIT_LOOP: } // Check to see we got a correct message - event := harness.mockHooks.Events[len(harness.mockHooks.Events)-1] + event := harness.mockHooks.Events()[len(harness.mockHooks.Events())-1] if !isExpectedFinalEvent(event) { t.Logf("received all events: %v", pretty.Sprint(harness.mockHooks.Events)) diff --git a/client/allocrunner/taskrunner/template_hook_test.go b/client/allocrunner/taskrunner/template_hook_test.go index c5b106659..1b098fcf9 100644 --- a/client/allocrunner/taskrunner/template_hook_test.go +++ b/client/allocrunner/taskrunner/template_hook_test.go @@ -48,7 +48,7 @@ func Test_templateHook_Prestart_ConsulWI(t *testing.T) { conf := &templateHookConfig{ logger: logger, lifecycle: taskHooks, - events: &mockEmitter{}, + events: &trtesting.MockEmitter{}, clientConfig: clientConfig, envBuilder: envBuilder, hookResources: hr, diff --git a/client/allocrunner/taskrunner/testing/testing.go b/client/allocrunner/taskrunner/testing/testing.go index f3084f01a..3760d7798 100644 --- a/client/allocrunner/taskrunner/testing/testing.go +++ b/client/allocrunner/taskrunner/testing/testing.go @@ -30,23 +30,24 @@ func (m *MockEmitter) Events() []*structs.TaskEvent { // MockTaskHooks is a mock of the TaskHooks interface useful for testing type MockTaskHooks struct { - Restarts int - RestartCh chan struct{} + lock sync.Mutex - SignalCh chan struct{} - signals []string - signalLock sync.Mutex + RestartCh chan struct{} + restarts int + + SignalCh chan struct{} + signals []string // SignalError is returned when Signal is called on the mock hook SignalError error UnblockCh chan struct{} - KillEvent *structs.TaskEvent KillCh chan *structs.TaskEvent + killEvent *structs.TaskEvent - Events []*structs.TaskEvent EmitEventCh chan *structs.TaskEvent + events []*structs.TaskEvent // HasHandle can be set to simulate restoring a task after client restart HasHandle bool @@ -62,7 +63,10 @@ func NewMockTaskHooks() *MockTaskHooks { } } func (m *MockTaskHooks) 
Restart(ctx context.Context, event *structs.TaskEvent, failure bool) error { - m.Restarts++ + m.lock.Lock() + defer m.lock.Unlock() + + m.restarts++ select { case m.RestartCh <- struct{}{}: default: @@ -71,9 +75,10 @@ func (m *MockTaskHooks) Restart(ctx context.Context, event *structs.TaskEvent, f } func (m *MockTaskHooks) Signal(event *structs.TaskEvent, s string) error { - m.signalLock.Lock() + m.lock.Lock() m.signals = append(m.signals, s) - m.signalLock.Unlock() + m.lock.Unlock() + select { case m.SignalCh <- struct{}{}: default: @@ -83,13 +88,16 @@ func (m *MockTaskHooks) Signal(event *structs.TaskEvent, s string) error { } func (m *MockTaskHooks) Signals() []string { - m.signalLock.Lock() - defer m.signalLock.Unlock() + m.lock.Lock() + defer m.lock.Unlock() return m.signals } func (m *MockTaskHooks) Kill(ctx context.Context, event *structs.TaskEvent) error { - m.KillEvent = event + m.lock.Lock() + defer m.lock.Unlock() + + m.killEvent = event select { case m.KillCh <- event: default: @@ -102,7 +110,10 @@ func (m *MockTaskHooks) IsRunning() bool { } func (m *MockTaskHooks) EmitEvent(event *structs.TaskEvent) { - m.Events = append(m.Events, event) + m.lock.Lock() + defer m.lock.Unlock() + + m.events = append(m.events, event) select { case m.EmitEventCh <- event: case <-m.EmitEventCh: @@ -111,3 +122,21 @@ func (m *MockTaskHooks) EmitEvent(event *structs.TaskEvent) { } func (m *MockTaskHooks) SetState(state string, event *structs.TaskEvent) {} + +func (m *MockTaskHooks) KillEvent() *structs.TaskEvent { + m.lock.Lock() + defer m.lock.Unlock() + return m.killEvent +} + +func (m *MockTaskHooks) Events() []*structs.TaskEvent { + m.lock.Lock() + defer m.lock.Unlock() + return m.events +} + +func (m *MockTaskHooks) Restarts() int { + m.lock.Lock() + defer m.lock.Unlock() + return m.restarts +} diff --git a/client/allocrunner/taskrunner/vault_hook.go b/client/allocrunner/taskrunner/vault_hook.go index c97316a6b..25851b16a 100644 --- 
a/client/allocrunner/taskrunner/vault_hook.go +++ b/client/allocrunner/taskrunner/vault_hook.go @@ -5,6 +5,7 @@ package taskrunner import ( "context" + "errors" "fmt" "os" "path" @@ -18,6 +19,7 @@ import ( "github.com/hashicorp/nomad/client/allocrunner/interfaces" ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces" "github.com/hashicorp/nomad/client/vaultclient" + "github.com/hashicorp/nomad/client/widmgr" "github.com/hashicorp/nomad/helper" "github.com/hashicorp/nomad/nomad/structs" ) @@ -40,6 +42,9 @@ type vaultTokenUpdateHandler interface { updatedVaultToken(token string) } +// deriveTokenFunc is the signature of a function used to derive Vault tokens. +type deriveTokenFunc func() (string, error) + func (tr *TaskRunner) updatedVaultToken(token string) { // Update the task runner and environment tr.setVaultToken(token) @@ -56,7 +61,8 @@ type vaultHookConfig struct { updater vaultTokenUpdateHandler logger log.Logger alloc *structs.Allocation - task string + task *structs.Task + widmgr widmgr.IdentityManager } type vaultHook struct { @@ -95,19 +101,27 @@ type vaultHook struct { // alloc is the allocation alloc *structs.Allocation - // taskName is the name of the task - taskName string + // task is the task to run. + task *structs.Task // firstRun stores whether it is the first run for the hook firstRun bool + // widmgr is used to access signed tokens for workload identities. + widmgr widmgr.IdentityManager + + // widName is the workload identity name to use to retrieve signed JWTs. + widName string + + // deriveTokenFunc is the function used to derive Vault tokens. 
+ deriveTokenFunc deriveTokenFunc + // future is used to wait on retrieving a Vault token future *tokenFuture } func newVaultHook(config *vaultHookConfig) *vaultHook { ctx, cancel := context.WithCancel(context.Background()) - h := &vaultHook{ vaultBlock: config.vaultBlock, clientFunc: config.clientFunc, @@ -115,13 +129,24 @@ func newVaultHook(config *vaultHookConfig) *vaultHook { lifecycle: config.lifecycle, updater: config.updater, alloc: config.alloc, - taskName: config.task, + task: config.task, firstRun: true, ctx: ctx, cancel: cancel, future: newTokenFuture(), + widmgr: config.widmgr, } h.logger = config.logger.Named(h.Name()) + + h.widName = config.task.Vault.IdentityName() + wid := config.task.GetIdentity(h.widName) + switch { + case wid != nil: + h.deriveTokenFunc = h.deriveVaultTokenJWT + default: + h.deriveTokenFunc = h.deriveVaultTokenLegacy + } + return h } @@ -320,13 +345,13 @@ OUTER: // deriveVaultToken derives the Vault token using exponential backoffs. It // returns the Vault token and whether the manager should exit. -func (h *vaultHook) deriveVaultToken() (token string, exit bool) { +func (h *vaultHook) deriveVaultToken() (string, bool) { var attempts uint64 var backoff time.Duration for { - tokens, err := h.client.DeriveToken(h.alloc, []string{h.taskName}) + token, err := h.deriveTokenFunc() if err == nil { - return tokens[h.taskName], false + return token, false } // Check if this is a server side error @@ -364,6 +389,55 @@ func (h *vaultHook) deriveVaultToken() (token string, exit bool) { } } +// deriveVaultTokenJWT returns a Vault ACL token using JWT auth login. +func (h *vaultHook) deriveVaultTokenJWT() (string, error) { + // Retrieve signed identity. 
+ signed, err := h.widmgr.Get(structs.WIHandle{ + IdentityName: h.widName, + WorkloadIdentifier: h.task.Name, + WorkloadType: structs.WorkloadTypeTask, + }) + if err != nil { + return "", structs.NewRecoverableError( + fmt.Errorf("failed to retrieve signed workload identity: %w", err), + true, + ) + } + if signed == nil { + return "", structs.NewRecoverableError( + errors.New("no signed workload identity available"), + false, + ) + } + + // Derive Vault token with signed identity. + token, err := h.client.DeriveTokenWithJWT(h.ctx, vaultclient.JWTLoginRequest{ + JWT: signed.JWT, + Role: h.vaultBlock.Role, + }) + if err != nil { + return "", structs.WrapRecoverable( + fmt.Sprintf("failed to derive Vault token for identity %s: %v", h.widName, err), + err, + ) + } + + return token, nil +} + +// deriveVaultTokenLegacy returns a Vault ACL token using the legacy flow where +// Nomad clients request Vault tokens from Nomad servers. +// +// Deprecated: This authentication flow will be removed in Nomad 1.9. 
+func (h *vaultHook) deriveVaultTokenLegacy() (string, error) { + tokens, err := h.client.DeriveToken(h.alloc, []string{h.task.Name}) + if err != nil { + return "", err + } + + return tokens[h.task.Name], nil +} + // writeToken writes the given token to disk func (h *vaultHook) writeToken(token string) error { // Handle upgrade path by first checking if the tasks private directory diff --git a/client/allocrunner/taskrunner/vault_hook_test.go b/client/allocrunner/taskrunner/vault_hook_test.go index a5aed075f..b3d4eb2a9 100644 --- a/client/allocrunner/taskrunner/vault_hook_test.go +++ b/client/allocrunner/taskrunner/vault_hook_test.go @@ -3,9 +3,567 @@ package taskrunner -import "github.com/hashicorp/nomad/client/allocrunner/interfaces" +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/client/allocdir" + "github.com/hashicorp/nomad/client/allocrunner/interfaces" + trtesting "github.com/hashicorp/nomad/client/allocrunner/taskrunner/testing" + cstate "github.com/hashicorp/nomad/client/state" + "github.com/hashicorp/nomad/client/taskenv" + "github.com/hashicorp/nomad/client/vaultclient" + "github.com/hashicorp/nomad/client/widmgr" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/shoenig/test/must" + "github.com/shoenig/test/wait" +) // Statically assert the stats hook implements the expected interfaces var _ interfaces.TaskPrestartHook = (*vaultHook)(nil) var _ interfaces.TaskStopHook = (*vaultHook)(nil) var _ interfaces.ShutdownHook = (*vaultHook)(nil) + +// vaultTokenUpdaterMock is a mock of the vaultTokenUpdateHandler interface. 
+type vaultTokenUpdaterMock struct { + currentToken string +} + +func (v *vaultTokenUpdaterMock) updatedVaultToken(token string) { + v.currentToken = token +} + +func setupTestVaultHook(t *testing.T, config *vaultHookConfig) *vaultHook { + t.Helper() + + if config == nil { + config = &vaultHookConfig{} + } + + job := mock.MinJob() + if config.alloc == nil { + config.alloc = mock.MinAlloc() + config.alloc.Job = job + } + if config.task == nil { + config.task = job.TaskGroups[0].Tasks[0] + config.task.Identities = []*structs.WorkloadIdentity{ + {Name: "vault_default"}, + } + config.task.Vault = &structs.Vault{ + Cluster: structs.VaultDefaultCluster, + } + + if config.vaultBlock != nil { + config.task.Identities[0].Name = config.vaultBlock.IdentityName() + config.task.Vault = config.vaultBlock + } + } + if config.vaultBlock == nil { + config.vaultBlock = config.task.Vault + } + if config.clientFunc == nil { + config.clientFunc = func(cluster string) (vaultclient.VaultClient, error) { + return vaultclient.NewMockVaultClient(cluster) + } + } + if config.logger == nil { + config.logger = testlog.HCLogger(t) + } + if config.events == nil { + config.events = &trtesting.MockEmitter{} + } + if config.lifecycle == nil { + config.lifecycle = trtesting.NewMockTaskHooks() + } + if config.updater == nil { + config.updater = &vaultTokenUpdaterMock{} + } + if config.widmgr == nil { + db := cstate.NewMemDB(config.logger) + signer := widmgr.NewMockWIDSigner(config.task.Identities) + + config.widmgr = widmgr.NewWIDMgr(signer, config.alloc, db, config.logger) + err := config.widmgr.Run() + must.NoError(t, err) + } + + return newVaultHook(config) +} + +func TestTaskRunner_VaultHook(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + task *structs.Task + expectLegacy bool + }{ + { + name: "legacy flow", + task: &structs.Task{ + Vault: &structs.Vault{ + Cluster: structs.VaultDefaultCluster, + }, + }, + expectLegacy: true, + }, + { + name: "jwt flow", + task: 
&structs.Task{ + Vault: &structs.Vault{ + Cluster: structs.VaultDefaultCluster, + }, + Identities: []*structs.WorkloadIdentity{ + {Name: "vault_default"}, + }, + }, + }, + { + name: "disable file", + task: &structs.Task{ + Vault: &structs.Vault{ + Cluster: structs.VaultDefaultCluster, + DisableFile: true, + }, + Identities: []*structs.WorkloadIdentity{ + {Name: "vault_default"}, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + alloc := mock.MinAlloc() + alloc.Job.TaskGroups[0].Tasks[0] = tc.task + + hook := setupTestVaultHook(t, &vaultHookConfig{ + task: tc.task, + alloc: alloc, + }) + + // Ensure Prestart() returns within a reasonable time. + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + t.Cleanup(cancel) + + req := &interfaces.TaskPrestartRequest{ + TaskEnv: taskenv.NewEmptyTaskEnv(), + TaskDir: &allocdir.TaskDir{ + SecretsDir: t.TempDir(), + PrivateDir: t.TempDir(), + }, + Task: tc.task, + } + var resp interfaces.TaskPrestartResponse + + err := hook.Prestart(ctx, req, &resp) + must.NoError(t, err) + must.NoError(t, ctx.Err()) + + // Token must have been derived. + var token string + client := hook.client.(*vaultclient.MockVaultClient) + if tc.expectLegacy { + tokens := client.LegacyTokens() + must.MapLen(t, 1, tokens) + token = tokens[tc.task.Name] + } else { + tokens := client.JWTTokens() + must.MapLen(t, 1, tokens) + + swid, err := hook.widmgr.Get(structs.WIHandle{ + IdentityName: tc.task.Vault.IdentityName(), + WorkloadIdentifier: tc.task.Name, + WorkloadType: structs.WorkloadTypeTask, + }) + must.NoError(t, err) + token = tokens[swid.JWT] + } + must.NotEq(t, "", token) + + // Token must be set in token updater. + updater := (hook.updater).(*vaultTokenUpdaterMock) + must.Eq(t, token, updater.currentToken) + + // Token must be written to disk. 
+ tokenFile, err := os.ReadFile(hook.privateDirTokenPath) + must.NoError(t, err) + must.Eq(t, updater.currentToken, string(tokenFile)) + + if !tc.task.Vault.DisableFile { + tokenFile, err := os.ReadFile(hook.secretsDirTokenPath) + must.NoError(t, err) + must.Eq(t, updater.currentToken, string(tokenFile)) + } else { + _, err = os.ReadFile(hook.secretsDirTokenPath) + must.ErrorIs(t, err, os.ErrNotExist) + } + + // Token must be set for renewal. + must.MapLen(t, 1, client.RenewTokens()) + must.NotNil(t, client.RenewTokens()[updater.currentToken]) + + // PrestartDone must be false so we can recover tokens. + // firstRun is used to prevent multiple executions. + must.False(t, resp.Done) + must.False(t, hook.firstRun) + + // Stop renewal when hook stops. + err = hook.Stop(ctx, nil, nil) + must.NoError(t, err) + must.Wait(t, wait.InitialSuccess( + wait.ErrorFunc(func() error { + tokens := client.StoppedTokens() + if len(tokens) != 1 { + return fmt.Errorf("expected stopped tokens to be %d, got %d", 1, len(tokens)) + } + got := tokens[0] + expect := updater.currentToken + if got != expect { + return fmt.Errorf("expected stopped token to be %s, got %s", expect, got) + } + return nil + }), + wait.Timeout(5*time.Second), + wait.Gap(100*time.Millisecond), + )) + }) + } +} + +func TestTaskRunner_VaultHook_recover(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + setupReq func() (*interfaces.TaskPrestartRequest, error) + }{ + { + name: "recover from secrets dir", + setupReq: func() (*interfaces.TaskPrestartRequest, error) { + // Write token to secrets dir. 
+ secretsDirPath := t.TempDir() + err := os.WriteFile(filepath.Join(secretsDirPath, vaultTokenFile), []byte("much secret"), 0666) + if err != nil { + return nil, err + } + + req := &interfaces.TaskPrestartRequest{ + TaskEnv: taskenv.NewEmptyTaskEnv(), + TaskDir: &allocdir.TaskDir{ + SecretsDir: secretsDirPath, + PrivateDir: t.TempDir(), + }, + } + return req, nil + }, + }, + { + name: "recover from private dir", + setupReq: func() (*interfaces.TaskPrestartRequest, error) { + // Write token to private dir. + privateDirPath := t.TempDir() + err := os.WriteFile(filepath.Join(privateDirPath, vaultTokenFile), []byte("much secret"), 0666) + if err != nil { + return nil, err + } + + req := &interfaces.TaskPrestartRequest{ + TaskEnv: taskenv.NewEmptyTaskEnv(), + TaskDir: &allocdir.TaskDir{ + SecretsDir: t.TempDir(), + PrivateDir: privateDirPath, + }, + } + return req, nil + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + hook := setupTestVaultHook(t, nil) + + req, err := tc.setupReq() + must.NoError(t, err) + req.Task = hook.task + + // Ensure Prestart() returns in a reasonable time. + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + t.Cleanup(cancel) + + var resp interfaces.TaskPrestartResponse + err = hook.Prestart(ctx, req, &resp) + must.NoError(t, err) + must.NoError(t, ctx.Err()) + + // Verify token was recovered and not derived. 
+ client := hook.client.(*vaultclient.MockVaultClient) + must.MapLen(t, 0, client.JWTTokens()) + must.MapLen(t, 0, client.LegacyTokens()) + }) + } +} + +func TestTaskRunner_VaultHook_deriveError(t *testing.T) { + ci.Parallel(t) + + t.Run("unrecoverable error", func(t *testing.T) { + vaultClient, _ := vaultclient.NewMockVaultClient("") + mockVaultClient := vaultClient.(*vaultclient.MockVaultClient) + + hook := setupTestVaultHook(t, &vaultHookConfig{ + clientFunc: func(string) (vaultclient.VaultClient, error) { + return mockVaultClient, nil + }, + }) + req := &interfaces.TaskPrestartRequest{ + TaskEnv: taskenv.NewEmptyTaskEnv(), + TaskDir: &allocdir.TaskDir{ + SecretsDir: t.TempDir(), + PrivateDir: t.TempDir(), + }, + Task: hook.task, + } + var resp interfaces.TaskPrestartResponse + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + // Set unrecoverable error. + mockVaultClient.SetDeriveTokenWithJWTFn(func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, error) { + // Cancel the context to simulate the task being killed. + cancel() + return "", structs.NewRecoverableError(errors.New("unrecoverable test error"), false) + }) + + err := hook.Prestart(ctx, req, &resp) + must.NoError(t, err) + + // Verify task is killed because of unrecoverable error. 
+ must.Wait(t, wait.InitialSuccess( + wait.ErrorFunc(func() error { + killEv := (hook.lifecycle.(*trtesting.MockTaskHooks)).KillEvent() + if killEv == nil { + return errors.New("missing kill event") + } + return nil + }), + wait.Timeout(5*time.Second), + wait.Gap(100*time.Millisecond), + )) + killEv := (hook.lifecycle.(*trtesting.MockTaskHooks)).KillEvent() + must.StrContains(t, killEv.DisplayMessage, "unrecoverable test error") + }) + + t.Run("recoverable error", func(t *testing.T) { + vaultClient, _ := vaultclient.NewMockVaultClient("") + mockVaultClient := vaultClient.(*vaultclient.MockVaultClient) + + hook := setupTestVaultHook(t, &vaultHookConfig{ + clientFunc: func(string) (vaultclient.VaultClient, error) { + return mockVaultClient, nil + }, + }) + req := &interfaces.TaskPrestartRequest{ + TaskEnv: taskenv.NewEmptyTaskEnv(), + TaskDir: &allocdir.TaskDir{ + SecretsDir: t.TempDir(), + PrivateDir: t.TempDir(), + }, + Task: hook.task, + } + var resp interfaces.TaskPrestartResponse + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + + // Set recoverable error. + mockVaultClient.SetDeriveTokenWithJWTFn(func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, error) { + return "", structs.NewRecoverableError(errors.New("recoverable test error"), true) + }) + + go func() { + // Wait a bit for the first error then fix token derivation. + time.Sleep(time.Second) + mockVaultClient.SetDeriveTokenWithJWTFn(func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, error) { + return "secret", nil + }) + + }() + err := hook.Prestart(ctx, req, &resp) + must.NoError(t, err) + must.NoError(t, ctx.Err()) + + // Verify retry happened and token was derived. 
+ updater := (hook.updater).(*vaultTokenUpdaterMock) + must.Eq(t, "secret", updater.currentToken) + }) + + t.Run("renew request failed", func(t *testing.T) { + vaultClient, _ := vaultclient.NewMockVaultClient("") + mockVaultClient := vaultClient.(*vaultclient.MockVaultClient) + + hook := setupTestVaultHook(t, &vaultHookConfig{ + clientFunc: func(string) (vaultclient.VaultClient, error) { + return mockVaultClient, nil + }, + }) + req := &interfaces.TaskPrestartRequest{ + TaskEnv: taskenv.NewEmptyTaskEnv(), + TaskDir: &allocdir.TaskDir{ + SecretsDir: t.TempDir(), + PrivateDir: t.TempDir(), + }, + Task: hook.task, + } + var resp interfaces.TaskPrestartResponse + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + + // Derive predictable token and fail renew request. + mockVaultClient.SetDeriveTokenWithJWTFn(func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, error) { + return "secret", nil + }) + mockVaultClient.SetRenewTokenError("secret", errors.New("test error")) + + go func() { + // Wait a bit for the renew error then fix token renewal. + time.Sleep(10 * time.Millisecond) + mockVaultClient.SetRenewTokenError("secret", nil) + + }() + err := hook.Prestart(ctx, req, &resp) + must.NoError(t, err) + must.NoError(t, ctx.Err()) + + // Verify retry happened and token was derived. 
+ updater := (hook.updater).(*vaultTokenUpdaterMock) + must.Eq(t, "secret", updater.currentToken) + }) +} + +func TestTaskRunner_VaultHook_tokenRenewalFail(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + vaultBlock *structs.Vault + verifyTaskLifecycle func(*trtesting.MockTaskHooks) error + }{ + { + name: "change mode signal", + vaultBlock: &structs.Vault{ + Cluster: structs.VaultDefaultCluster, + ChangeMode: structs.VaultChangeModeSignal, + ChangeSignal: "SIGTERM", + }, + verifyTaskLifecycle: func(h *trtesting.MockTaskHooks) error { + signals := h.Signals() + if len(signals) != 1 { + return fmt.Errorf("expected 1 signal, got %d", len(signals)) + } + if signals[0] != "SIGTERM" { + return fmt.Errorf("expected signal to be SIGTERM, got %s", signals[0]) + } + return nil + }, + }, + { + name: "change mode restart", + vaultBlock: &structs.Vault{ + Cluster: structs.VaultDefaultCluster, + ChangeMode: structs.VaultChangeModeRestart, + }, + verifyTaskLifecycle: func(h *trtesting.MockTaskHooks) error { + restarts := h.Restarts() + if restarts != 1 { + return fmt.Errorf("expected 1 restart, got %d", restarts) + } + return nil + }, + }, + { + name: "change mode noop", + vaultBlock: &structs.Vault{ + Cluster: structs.VaultDefaultCluster, + ChangeMode: structs.VaultChangeModeNoop, + }, + verifyTaskLifecycle: func(h *trtesting.MockTaskHooks) error { + restarts := h.Restarts() + if restarts != 0 { + return fmt.Errorf("expected 0 restarts, got %d", restarts) + } + + signals := h.Signals() + if len(signals) != 0 { + return fmt.Errorf("expected 0 signals, got %d", len(signals)) + } + + return nil + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + vaultClient, _ := vaultclient.NewMockVaultClient("") + mockVaultClient := vaultClient.(*vaultclient.MockVaultClient) + + hook := setupTestVaultHook(t, &vaultHookConfig{ + vaultBlock: tc.vaultBlock, + clientFunc: func(string) (vaultclient.VaultClient, error) { + return 
mockVaultClient, nil + }, + }) + + req := &interfaces.TaskPrestartRequest{ + TaskEnv: taskenv.NewEmptyTaskEnv(), + TaskDir: &allocdir.TaskDir{ + SecretsDir: t.TempDir(), + PrivateDir: t.TempDir(), + }, + Task: hook.task, + } + var resp interfaces.TaskPrestartResponse + + // Ensure Prestart() returns within a reasonable time. + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + t.Cleanup(cancel) + + err := hook.Prestart(ctx, req, &resp) + must.NoError(t, err) + + // Fetch derived token. + updater := (hook.updater).(*vaultTokenUpdaterMock) + token := updater.currentToken + must.NotEq(t, "", token) + + // Fetch renewal token error channel. + renewErrCh := mockVaultClient.RenewTokenErrCh(token) + must.NotNil(t, renewErrCh) + + // Emit renewal error. + renewErrCh <- errors.New("renew error") + + // Verify expected lifecycle events happen. + must.Wait(t, wait.InitialSuccess( + wait.ErrorFunc(func() error { + return tc.verifyTaskLifecycle((hook.lifecycle).(*trtesting.MockTaskHooks)) + }), + wait.Timeout(3*time.Second), + wait.Gap(100*time.Millisecond), + )) + }) + } +} diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index 2a2201733..1cb4081d8 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -5,6 +5,8 @@ package vaultclient import ( "container/heap" + "context" + "errors" "fmt" "math/rand" "strings" @@ -29,6 +31,18 @@ type VaultClientFunc func(string) (VaultClient, error) // wrapped tokens will be unwrapped using the vault API client. type TokenDeriverFunc func(*structs.Allocation, []string, *vaultapi.Client) (map[string]string, error) +// JWTLoginRequest is used to derive a Vault ACL token using a JWT login +// request. +type JWTLoginRequest struct { + // JWT is the signed JWT to be used for the login request. + JWT string + + // Role is Vault ACL role to use for the login request. 
If empty, the + // Nomad client's create_from_role value is used, or the Vault cluster + // default role. + Role string +} + // VaultClient is the interface which nomad client uses to interact with vault and // periodically renews the tokens and secrets. type VaultClient interface { @@ -43,6 +57,10 @@ type VaultClient interface { // returned. DeriveToken(*structs.Allocation, []string) (map[string]string, error) + // DeriveTokenWithJWT returns a Vault ACL token using the JWT login + // endpoint. + DeriveTokenWithJWT(context.Context, JWTLoginRequest) (string, error) + // GetConsulACL fetches the Consul ACL token required for the task GetConsulACL(string, string) (*vaultapi.Secret, error) @@ -261,6 +279,45 @@ func (c *Client) DeriveToken(alloc *structs.Allocation, taskNames []string) (map return tokens, nil } +// DeriveTokenWithJWT returns a Vault ACL token using the JWT login endpoint. +func (c *Client) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, error) { + if !c.Config.IsEnabled() { + return "", fmt.Errorf("vault client not enabled") + } + if !c.isRunning() { + return "", fmt.Errorf("vault client is not running") + } + + c.Lock.Lock() + defer c.unlockAndUnset() + + // Make sure the login request is not passing any token. + c.Vault.SetToken("") + + jwtLoginPath := fmt.Sprintf("auth/%s/login", c.Config.JWTAuthBackendPath) + s, err := c.Vault.Logical().WriteWithContext(ctx, jwtLoginPath, + map[string]any{ + "role": req.Role, + "jwt": req.JWT, + }, + ) + if err != nil { + return "", fmt.Errorf("failed to login with JWT: %v", err) + } + if s == nil { + return "", errors.New("JWT login returned an empty secret") + } + if s.Auth == nil { + return "", errors.New("JWT login did not return a token") + } + + for _, w := range s.Warnings { + c.logger.Warn("JWT login warning", "warning", w) + } + + return s.Auth.ClientToken, nil +} + // GetConsulACL creates a vault API client and reads from vault a consul ACL // token used by the task. 
func (c *Client) GetConsulACL(token, path string) (*vaultapi.Secret, error) { diff --git a/client/vaultclient/vaultclient_test.go b/client/vaultclient/vaultclient_test.go new file mode 100644 index 000000000..580a8683f --- /dev/null +++ b/client/vaultclient/vaultclient_test.go @@ -0,0 +1,259 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package vaultclient + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "text/template" + "time" + + josejwt "github.com/go-jose/go-jose/v3/jwt" + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/client/widmgr" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/testutil" + "github.com/hashicorp/vault/api" + "github.com/shoenig/test/must" +) + +const ( + jwtAuthMountPathTest = "jwt_test" + + jwtAuthConfigTemplate = ` +{ + "jwks_url": "<<.JWKSURL>>", + "jwt_supported_algs": ["EdDSA"], + "default_role": "nomad-workloads" +} +` + + widVaultPolicyTemplate = ` +path "secret/data/{{identity.entity.aliases.<<.JWTAuthAccessorID>>.metadata.nomad_namespace}}/{{identity.entity.aliases.<<.JWTAuthAccessorID>>.metadata.nomad_job_id}}/*" { + capabilities = ["read"] +} + +path "secret/data/{{identity.entity.aliases.<<.JWTAuthAccessorID>>.metadata.nomad_namespace}}/{{identity.entity.aliases.<<.JWTAuthAccessorID>>.metadata.nomad_job_id}}" { + capabilities = ["read"] +} + +path "secret/metadata/{{identity.entity.aliases.<<.JWTAuthAccessorID>>.metadata.nomad_namespace}}/*" { + capabilities = ["list"] +} + +path "secret/metadata/*" { + capabilities = ["list"] +} +` + + widVaultRole = ` +{ + "role_type": "jwt", + "bound_audiences": "vault.io", + "user_claim": "/nomad_job_id", + "user_claim_json_pointer": true, + "claim_mappings": { + "nomad_namespace": "nomad_namespace", + "nomad_job_id": "nomad_job_id" + }, + "token_ttl": 
"30m", + "token_type": "service", + "token_period": "72h", + "token_policies": ["nomad-workloads"] +} +` +) + +func renderVaultTemplate(tmplStr string, data any) ([]byte, error) { + var buf bytes.Buffer + tmpl, err := template.New("policy"). + Delims("<<", ">>"). + Parse(tmplStr) + if err != nil { + return nil, fmt.Errorf("failed to parse policy template: %w", err) + } + + err = tmpl.Execute(&buf, data) + if err != nil { + return nil, fmt.Errorf("failed to render policy template: %w", err) + } + + return buf.Bytes(), nil +} + +func setupVaultForWorkloadIdentity(v *testutil.TestVault, jwksURL string) error { + logical := v.Client.Logical() + sys := v.Client.Sys() + ctx := context.Background() + + // Enable JWT auth method. + err := sys.EnableAuthWithOptions(jwtAuthMountPathTest, &api.MountInput{ + Type: "jwt", + }) + if err != nil { + return fmt.Errorf("failed to enable JWT auth method: %w", err) + } + + secret, err := logical.Read(fmt.Sprintf("sys/auth/%s", jwtAuthMountPathTest)) + jwtAuthAccessor := secret.Data["accessor"].(string) + + // Write JWT auth method config. + jwtAuthConfigData := struct { + JWKSURL string + }{ + JWKSURL: jwksURL, + } + jwtAuthConfig, err := renderVaultTemplate(jwtAuthConfigTemplate, jwtAuthConfigData) + if err != nil { + return err + } + + _, err = logical.WriteBytesWithContext(ctx, fmt.Sprintf("auth/%s/config", jwtAuthMountPathTest), jwtAuthConfig) + if err != nil { + return fmt.Errorf("failed to write JWT auth method config: %w", err) + } + + // Write Nomad workload policy. 
+ data := struct { + JWTAuthAccessorID string + }{ + JWTAuthAccessorID: jwtAuthAccessor, + } + policy, err := renderVaultTemplate(widVaultPolicyTemplate, data) + if err != nil { + return err + } + + encoded := base64.StdEncoding.EncodeToString(policy) + policyReqBody := fmt.Sprintf(`{"policy": "%s"}`, encoded) + + policyPath := "sys/policies/acl/nomad-workloads" + _, err = logical.WriteBytesWithContext(ctx, policyPath, []byte(policyReqBody)) + if err != nil { + return fmt.Errorf("failed to write policy: %w", err) + } + + // Write Nomad workload role. + rolePath := fmt.Sprintf("auth/%s/role/nomad-workloads", jwtAuthMountPathTest) + _, err = logical.WriteBytesWithContext(ctx, rolePath, []byte(widVaultRole)) + if err != nil { + return fmt.Errorf("failed to write role: %w", err) + } + + return nil +} + +func TestVaultClient_DeriveTokenWithJWT(t *testing.T) { + ci.Parallel(t) + + // Create signer and signed identities. + alloc := mock.MinAlloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Identities = []*structs.WorkloadIdentity{ + { + Name: "vault_default", + Audience: []string{"vault.io"}, + TTL: time.Second, + }, + } + + signer := widmgr.NewMockWIDSigner(task.Identities) + signedWIDs, err := signer.SignIdentities(1, []*structs.WorkloadIdentityRequest{ + { + AllocID: alloc.ID, + WIHandle: structs.WIHandle{ + IdentityName: task.Identities[0].Name, + WorkloadIdentifier: task.Name, + WorkloadType: structs.WorkloadTypeTask, + }, + }, + }) + must.NoError(t, err) + must.Len(t, 1, signedWIDs) + + // Setup test JWKS server. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out, err := json.Marshal(signer.JSONWebKeySet()) + if err != nil { + t.Errorf("failed to generate JWKS json response: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + fmt.Fprintln(w, string(out)) + })) + defer ts.Close() + + // Start and configure Vault cluster for JWT authentication. 
+ v := testutil.NewTestVault(t) + defer v.Stop() + + err = setupVaultForWorkloadIdentity(v, ts.URL) + must.NoError(t, err) + + // Start Vault client. + logger := testlog.HCLogger(t) + v.Config.ConnectionRetryIntv = 100 * time.Millisecond + v.Config.JWTAuthBackendPath = jwtAuthMountPathTest + + c, err := NewVaultClient(v.Config, logger, nil) + must.NoError(t, err) + + c.Start() + defer c.Stop() + + // Derive Vault token using signed JWT. + jwtStr := signedWIDs[0].JWT + token, err := c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{ + JWT: jwtStr, + }) + must.NoError(t, err) + must.NotEq(t, "", token) + + // Verify token has expected properties. + v.Client.SetToken(token) + s, err := v.Client.Logical().Read("auth/token/lookup-self") + must.NoError(t, err) + + jwt, err := josejwt.ParseSigned(jwtStr) + must.NoError(t, err) + + claims := make(map[string]any) + err = jwt.UnsafeClaimsWithoutVerification(&claims) + must.NoError(t, err) + + must.Eq(t, "service", s.Data["type"].(string)) + must.True(t, s.Data["renewable"].(bool)) + must.SliceContains(t, s.Data["policies"].([]any), "nomad-workloads") + must.MapEq(t, map[string]any{ + "nomad_namespace": claims["nomad_namespace"], + "nomad_job_id": claims["nomad_job_id"], + "role": "nomad-workloads", + }, s.Data["meta"].(map[string]any)) + + // Verify token has the expected permissions. + pathAllowed := fmt.Sprintf("secret/data/%s/%s/a", claims["nomad_namespace"], claims["nomad_job_id"]) + pathDenied := "secret/data/denied" + + s, err = v.Client.Logical().Write("sys/capabilities-self", map[string]any{ + "paths": []string{pathAllowed, pathDenied}, + }) + must.NoError(t, err) + must.Eq(t, []any{"read"}, (s.Data[pathAllowed]).([]any)) + must.Eq(t, []any{"deny"}, (s.Data[pathDenied]).([]any)) + + // Derive Vault token with non-existing role. 
+ token, err = c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{ + JWT: jwtStr, + Role: "test", + }) + must.ErrorContains(t, err, `role "test" could not be found`) +} diff --git a/client/vaultclient/vaultclient_testing.go b/client/vaultclient/vaultclient_testing.go index be9b79131..481cf031c 100644 --- a/client/vaultclient/vaultclient_testing.go +++ b/client/vaultclient/vaultclient_testing.go @@ -4,6 +4,7 @@ package vaultclient import ( + "context" "sync" "github.com/hashicorp/nomad/helper/uuid" @@ -14,6 +15,12 @@ import ( // MockVaultClient is used for testing the vaultclient integration and is safe // for concurrent access. type MockVaultClient struct { + // legacyTokens stores the tokens per task derived using the legacy flow. + legacyTokens map[string]string + + // jwtTokens stores the tokens derived using the JWT flow. + jwtTokens map[string]string + // stoppedTokens tracks the tokens that have stopped renewing stoppedTokens []string @@ -34,12 +41,33 @@ type MockVaultClient struct { // a token is generated and returned DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error) + // deriveTokenWithJWTFn allows the caller to control the DeriveTokenWithJWT + // function.
+ deriveTokenWithJWTFn func(context.Context, JWTLoginRequest) (string, error) + mu sync.Mutex } // NewMockVaultClient returns a MockVaultClient for testing func NewMockVaultClient(_ string) (VaultClient, error) { return &MockVaultClient{}, nil } +func (vc *MockVaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, error) { + vc.mu.Lock() + defer vc.mu.Unlock() + + if vc.deriveTokenWithJWTFn != nil { + return vc.deriveTokenWithJWTFn(ctx, req) + } + + if vc.jwtTokens == nil { + vc.jwtTokens = make(map[string]string) + } + + token := uuid.Generate() + vc.jwtTokens[req.JWT] = token + return token, nil +} + func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) { vc.mu.Lock() defer vc.mu.Unlock() @@ -59,6 +87,7 @@ func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (m tokens[task] = uuid.Generate() } + vc.legacyTokens = tokens return tokens, nil } @@ -120,6 +149,20 @@ func (vc *MockVaultClient) Stop() {} func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) { return nil, nil } +// LegacyTokens returns the tokens generated using the legacy flow. +func (vc *MockVaultClient) LegacyTokens() map[string]string { + vc.mu.Lock() + defer vc.mu.Unlock() + return vc.legacyTokens +} + +// JWTTokens returns the tokens generated using the JWT flow. +func (vc *MockVaultClient) JWTTokens() map[string]string { + vc.mu.Lock() + defer vc.mu.Unlock() + return vc.jwtTokens +} + // StoppedTokens tracks the tokens that have stopped renewing func (vc *MockVaultClient) StoppedTokens() []string { vc.mu.Lock() @@ -135,6 +178,14 @@ func (vc *MockVaultClient) RenewTokens() map[string]chan error { return vc.renewTokens } +// RenewTokenErrCh returns the error channel for the given token renewal +// process.
+func (vc *MockVaultClient) RenewTokenErrCh(token string) chan error { + vc.mu.Lock() + defer vc.mu.Unlock() + return vc.renewTokens[token] +} + // RenewTokenErrors is used to return an error when the RenewToken is called // with the given token func (vc *MockVaultClient) RenewTokenErrors() map[string]error { @@ -150,3 +201,10 @@ func (vc *MockVaultClient) DeriveTokenErrors() map[string]map[string]error { defer vc.mu.Unlock() return vc.deriveTokenErrors } + +// SetDeriveTokenWithJWTFn sets the function used to derive tokens using JWT. +func (vc *MockVaultClient) SetDeriveTokenWithJWTFn(f func(context.Context, JWTLoginRequest) (string, error)) { + vc.mu.Lock() + defer vc.mu.Unlock() + vc.deriveTokenWithJWTFn = f +} diff --git a/client/widmgr/mock.go b/client/widmgr/mock.go index 350fab20c..a23341360 100644 --- a/client/widmgr/mock.go +++ b/client/widmgr/mock.go @@ -59,6 +59,18 @@ func (m *MockWIDSigner) now() time.Time { return m.mockNow } +func (m *MockWIDSigner) JSONWebKeySet() *jose.JSONWebKeySet { + jwk := jose.JSONWebKey{ + Key: m.key.Public(), + KeyID: m.keyID, + Algorithm: "EdDSA", + Use: "sig", + } + return &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{jwk}, + } +} + func (m *MockWIDSigner) SignIdentities(minIndex uint64, req []*structs.WorkloadIdentityRequest) ([]*structs.SignedWorkloadIdentity, error) { swids := make([]*structs.SignedWorkloadIdentity, 0, len(req)) for _, idReq := range req { diff --git a/nomad/server.go b/nomad/server.go index 4f3c0b64a..17a8c112f 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -1150,8 +1150,14 @@ func (s *Server) setupConsul(consulConfigEntries consul.ConfigAPI, consulACLs co // setupVaultClient is used to set up the Vault API client. 
func (s *Server) setupVaultClient() error { + vconfig := s.config.VaultConfig + if vconfig != nil && vconfig.DefaultIdentity != nil { + s.vault = &NoopVault{} + return nil + } + delegate := s.entVaultDelegate() - v, err := NewVaultClient(s.config.VaultConfig, s.logger, s.purgeVaultAccessors, delegate) + v, err := NewVaultClient(vconfig, s.logger, s.purgeVaultAccessors, delegate) if err != nil { return err } diff --git a/nomad/vault_noop.go b/nomad/vault_noop.go new file mode 100644 index 000000000..c02f4a26a --- /dev/null +++ b/nomad/vault_noop.go @@ -0,0 +1,61 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package nomad + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/nomad/structs/config" + vapi "github.com/hashicorp/vault/api" +) + +type NoopVault struct { + l sync.Mutex + config *config.VaultConfig +} + +func (v *NoopVault) SetActive(_ bool) {} + +func (v *NoopVault) SetConfig(c *config.VaultConfig) error { + v.l.Lock() + defer v.l.Unlock() + + v.config = c + return nil +} + +func (v *NoopVault) GetConfig() *config.VaultConfig { + v.l.Lock() + defer v.l.Unlock() + + return v.config.Copy() +} + +func (v *NoopVault) CreateToken(_ context.Context, _ *structs.Allocation, _ string) (*vapi.Secret, error) { + return nil, errors.New("Vault client not able to create tokens") +} + +func (v *NoopVault) LookupToken(_ context.Context, _ string) (*vapi.Secret, error) { + return nil, errors.New("Vault client not able to lookup tokens") +} + +func (v *NoopVault) RevokeTokens(_ context.Context, _ []*structs.VaultAccessor, _ bool) error { + return errors.New("Vault client not able to revoke tokens") +} + +func (v *NoopVault) MarkForRevocation(accessors []*structs.VaultAccessor) error { + return errors.New("Vault client not able to revoke tokens") +} + +func (v *NoopVault) Stop() {} + +func (v *NoopVault) Running() bool { return true } + +func (v *NoopVault) Stats() 
map[string]string { return nil } + +func (v *NoopVault) EmitStats(_ time.Duration, _ <-chan struct{}) {} diff --git a/testutil/vault.go b/testutil/vault.go index a51764888..c5ff10523 100644 --- a/testutil/vault.go +++ b/testutil/vault.go @@ -42,6 +42,12 @@ type TestVault struct { } func NewTestVaultFromPath(t testing.T, binary string) *TestVault { + t.Helper() + + if _, err := exec.LookPath(binary); err != nil { + t.Skipf("Skipping test %s, Vault binary %q not found in path.", t.Name(), binary) + } + port := ci.PortAllocator.Grab(1)[0] token := uuid.Generate() bind := fmt.Sprintf("-dev-listen-address=127.0.0.1:%d", port) @@ -112,6 +118,8 @@ func NewTestVaultFromPath(t testing.T, binary string) *TestVault { // NewTestVault returns a new TestVault instance that is ready for API calls func NewTestVault(t testing.T) *TestVault { + t.Helper() + // Lookup vault from the path return NewTestVaultFromPath(t, "vault") }