From e9e1631b8c1ee578542efcdb55b63037eb90bd37 Mon Sep 17 00:00:00 2001 From: Michael Smithhisler Date: Thu, 14 Aug 2025 09:35:11 -0400 Subject: [PATCH] test: add task validation when using vault secret provider (#26517) --- .../taskrunner/secrets/nomad_provider.go | 2 + .../taskrunner/secrets/vault_provider.go | 2 + client/allocrunner/taskrunner/secrets_hook.go | 4 +- nomad/structs/structs.go | 7 +++ nomad/structs/structs_test.go | 50 +++++++++++++++++++ 5 files changed, 63 insertions(+), 2 deletions(-) diff --git a/client/allocrunner/taskrunner/secrets/nomad_provider.go b/client/allocrunner/taskrunner/secrets/nomad_provider.go index 056bdcf07..056e3de57 100644 --- a/client/allocrunner/taskrunner/secrets/nomad_provider.go +++ b/client/allocrunner/taskrunner/secrets/nomad_provider.go @@ -13,6 +13,8 @@ import ( "github.com/mitchellh/mapstructure" ) +const SecretProviderNomad = "nomad" + type nomadProviderConfig struct { Namespace string `mapstructure:"namespace"` } diff --git a/client/allocrunner/taskrunner/secrets/vault_provider.go b/client/allocrunner/taskrunner/secrets/vault_provider.go index 079a05ed2..3ad29760e 100644 --- a/client/allocrunner/taskrunner/secrets/vault_provider.go +++ b/client/allocrunner/taskrunner/secrets/vault_provider.go @@ -14,6 +14,8 @@ import ( ) const ( + SecretProviderVault = "vault" + VAULT_KV = "kv" VAULT_KV_V2 = "kv_v2" ) diff --git a/client/allocrunner/taskrunner/secrets_hook.go b/client/allocrunner/taskrunner/secrets_hook.go index f3e204e40..616ff0172 100644 --- a/client/allocrunner/taskrunner/secrets_hook.go +++ b/client/allocrunner/taskrunner/secrets_hook.go @@ -185,13 +185,13 @@ func (h *secretsHook) buildSecretProviders(secretDir string) ([]TemplateProvider tmplFile := fmt.Sprintf("temp-%d", idx) switch s.Provider { - case "nomad": + case secrets.SecretProviderNomad: if p, err := secrets.NewNomadProvider(s, secretDir, tmplFile, h.nomadNamespace); err != nil { multierror.Append(mErr, err) } else { tmplProvider = append(tmplProvider, p) } - case "vault": + case secrets.SecretProviderVault: if p, err := secrets.NewVaultProvider(s, secretDir, tmplFile); err != nil { multierror.Append(mErr, err) } else { diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index 549ad9315..9ec055644 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -219,6 +219,9 @@ const ( RateMetricRead = "read" RateMetricList = "list" RateMetricWrite = "write" + + // Vault secret provider used in task validation + SecretProviderVault = "vault" ) var ( @@ -8329,6 +8332,10 @@ func (t *Task) Validate(jobType string, tg *TaskGroup) error { secrets[s.Name] = true } + if s.Provider == SecretProviderVault && t.Vault == nil { + mErr.Errors = append(mErr.Errors, fmt.Errorf("Secret %q has provider \"vault\" but no vault block", s.Name)) + } + if err := s.Validate(); err != nil { mErr.Errors = append(mErr.Errors, fmt.Errorf("Secret %q is invalid: %w", s.Name, err)) } diff --git a/nomad/structs/structs_test.go b/nomad/structs/structs_test.go index b30589ef6..8ccf1b779 100644 --- a/nomad/structs/structs_test.go +++ b/nomad/structs/structs_test.go @@ -6459,6 +6459,56 @@ func TestVault_Canonicalize(t *testing.T) { require.Equal(t, VaultChangeModeRestart, v.ChangeMode) } +func TestTask_Validate_Secret(t *testing.T) { + cases := []struct { + name string + task *Task + expErr bool + }{ + { + name: "errors with vault provider and no vault block", + task: &Task{ + Secrets: []*Secret{ + { + Name: "test", + Provider: "vault", + }, + }, + }, + expErr: true, + }, + { + name: "succeeds with vault provider and vault block", + task: &Task{ + Vault: &Vault{}, + Secrets: []*Secret{ + { + Name: "test", + Provider: "vault", + }, + }, + }, + expErr: false, + }, + } + + vaultProviderErr := "has provider \"vault\" but no vault block" + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := tc.task.Validate(JobTypeService, &TaskGroup{}) + + // Validate will return errors here, we just want to validate + // it contains the above vaultProviderErr or not + if tc.expErr { + must.ErrorContains(t, err, vaultProviderErr) + } else { + // no ErrorNotContains so use string matching + must.StrNotContains(t, err.Error(), vaultProviderErr) + } + }) + } +} + func TestSecrets_Copy(t *testing.T) { ci.Parallel(t) s := &Secret{