From ef24e40b392ca90a4fdee6216333f55c9dc55b45 Mon Sep 17 00:00:00 2001 From: Michael Schurter Date: Fri, 8 Sep 2023 14:50:34 -0700 Subject: [PATCH] identity: support jwt expiration and rotation (#18262) Implements expirations and renewals for alternate workload identity tokens. --- api/tasks.go | 11 +- .../allocrunner/taskrunner/identity_hook.go | 206 ++++++++-- .../taskrunner/identity_hook_test.go | 364 +++++++++++++++++- .../taskrunner/task_runner_test.go | 18 +- command/agent/job_endpoint.go | 2 + helper/retry_test.go | 15 + nomad/alloc_endpoint.go | 1 + nomad/structs/diff_test.go | 40 ++ nomad/structs/structs.go | 19 +- nomad/structs/workload_id.go | 35 +- nomad/structs/workload_id_test.go | 29 ++ testutil/file.go | 20 + 12 files changed, 706 insertions(+), 54 deletions(-) create mode 100644 testutil/file.go diff --git a/api/tasks.go b/api/tasks.go index 25bf56206..eff41aefd 100644 --- a/api/tasks.go +++ b/api/tasks.go @@ -1156,9 +1156,10 @@ func (t *TaskCSIPluginConfig) Canonicalize() { // WorkloadIdentity is the jobspec block which determines if and how a workload // identity is exposed to tasks. type WorkloadIdentity struct { - Name string `hcl:"name,optional"` - Audience []string `mapstructure:"aud" hcl:"aud,optional"` - Env bool `hcl:"env,optional"` - File bool `hcl:"file,optional"` - ServiceName string `hcl:"service_name,optional"` + Name string `hcl:"name,optional"` + Audience []string `mapstructure:"aud" hcl:"aud,optional"` + Env bool `hcl:"env,optional"` + File bool `hcl:"file,optional"` + ServiceName string `hcl:"service_name,optional"` + TTL time.Duration `mapstructure:"ttl" hcl:"ttl,optional"` } diff --git a/client/allocrunner/taskrunner/identity_hook.go b/client/allocrunner/taskrunner/identity_hook.go index be2f54c31..8de9344b6 100644 --- a/client/allocrunner/taskrunner/identity_hook.go +++ b/client/allocrunner/taskrunner/identity_hook.go @@ -7,10 +7,13 @@ import ( "context" "fmt" "path/filepath" + "time" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/client/allocrunner/interfaces" + "github.com/hashicorp/nomad/client/taskenv" + "github.com/hashicorp/nomad/helper" "github.com/hashicorp/nomad/helper/users" "github.com/hashicorp/nomad/nomad/structs" ) @@ -30,16 +33,45 @@ type IdentitySigner interface { SignIdentities(minIndex uint64, req []*structs.WorkloadIdentityRequest) ([]*structs.SignedWorkloadIdentity, error) } +// tokenSetter provides methods for exposing workload identities to other +// internal Nomad components. +type tokenSetter interface { + setNomadToken(token string) +} + type identityHook struct { - tr *TaskRunner - tokenDir string - logger log.Logger + alloc *structs.Allocation + task *structs.Task + tokenDir string + envBuilder *taskenv.Builder + ts tokenSetter + widmgr IdentitySigner + logger log.Logger + + // minWait is the minimum amount of time to wait before renewing. Settable to + // ease testing. + minWait time.Duration + + stopCtx context.Context + stop context.CancelFunc } func newIdentityHook(tr *TaskRunner, logger log.Logger) *identityHook { + // Create a context for the renew loop. This context will be canceled when + // the task is stopped or agent is shutting down, unlike Prestart's ctx which + // is not intended for use after Prestart is returns. + stopCtx, stop := context.WithCancel(context.Background()) + h := &identityHook{ - tr: tr, - tokenDir: tr.taskDir.SecretsDir, + alloc: tr.Alloc(), + task: tr.Task(), + tokenDir: tr.taskDir.SecretsDir, + envBuilder: tr.envBuilder, + ts: tr, + widmgr: tr.widmgr, + minWait: 10 * time.Second, + stopCtx: stopCtx, + stop: stop, } h.logger = logger.Named(h.Name()) return h @@ -49,19 +81,19 @@ func (*identityHook) Name() string { return "identity" } -func (h *identityHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error { +func (h *identityHook) Prestart(context.Context, *interfaces.TaskPrestartRequest, *interfaces.TaskPrestartResponse) error { // Handle default workload identity if err := h.setDefaultToken(); err != nil { return err } - signedWIDs, err := h.getIdentities(req.Alloc, req.Task) + signedWIDs, err := h.getIdentities() if err != nil { return fmt.Errorf("error fetching alternate identities: %w", err) } - for _, widspec := range req.Task.Identities { + for _, widspec := range h.task.Identities { signedWID := signedWIDs[widspec.Name] if signedWID == nil { // The only way to hit this should be a bug as it indicates the server @@ -74,27 +106,39 @@ func (h *identityHook) Prestart(ctx context.Context, req *interfaces.TaskPrestar } } + // Start token renewal loop + go h.renew(h.alloc.CreateIndex, signedWIDs) + return nil } +// Stop implements interfaces.TaskStopHook +func (h *identityHook) Stop(context.Context, *interfaces.TaskStopRequest, *interfaces.TaskStopResponse) error { + h.stop() + return nil +} + +// Shutdown implements interfaces.ShutdownHook +func (h *identityHook) Shutdown() { + h.stop() +} + // setDefaultToken adds the Nomad token to the task's environment and writes it to a // file if requested by the jobsepc. func (h *identityHook) setDefaultToken() error { - token := h.tr.alloc.SignedIdentities[h.tr.taskName] + token := h.alloc.SignedIdentities[h.task.Name] if token == "" { return nil } // Handle internal use and env var - h.tr.setNomadToken(token) - - task := h.tr.Task() + h.ts.setNomadToken(token) // Handle file writing - if id := task.Identity; id != nil && id.File { + if id := h.task.Identity; id != nil && id.File { // Write token as owner readable only tokenPath := filepath.Join(h.tokenDir, wiTokenFile) - if err := users.WriteFileFor(tokenPath, []byte(token), task.User); err != nil { + if err := users.WriteFileFor(tokenPath, []byte(token), h.task.User); err != nil { return fmt.Errorf("failed to write nomad token: %w", err) } } @@ -106,12 +150,12 @@ func (h *identityHook) setDefaultToken() error { // writes the token file as specified by the jobspec. func (h *identityHook) setAltToken(widspec *structs.WorkloadIdentity, rawJWT string) error { if widspec.Env { - h.tr.envBuilder.SetWorkloadToken(widspec.Name, rawJWT) + h.envBuilder.SetWorkloadToken(widspec.Name, rawJWT) } if widspec.File { tokenPath := filepath.Join(h.tokenDir, fmt.Sprintf("nomad_%s.jwt", widspec.Name)) - if err := users.WriteFileFor(tokenPath, []byte(rawJWT), h.tr.Task().User); err != nil { + if err := users.WriteFileFor(tokenPath, []byte(rawJWT), h.task.User); err != nil { return fmt.Errorf("failed to write token for identity %q: %w", widspec.Name, err) } } @@ -122,23 +166,23 @@ func (h *identityHook) setAltToken(widspec *structs.WorkloadIdentity, rawJWT str // getIdentities calls Alloc.SignIdentities to get all of the identities for // this workload signed. If there are no identities to be signed then (nil, // nil) is returned. -func (h *identityHook) getIdentities(alloc *structs.Allocation, task *structs.Task) (map[string]*structs.SignedWorkloadIdentity, error) { +func (h *identityHook) getIdentities() (map[string]*structs.SignedWorkloadIdentity, error) { - if len(task.Identities) == 0 { + if len(h.task.Identities) == 0 { return nil, nil } - req := make([]*structs.WorkloadIdentityRequest, len(task.Identities)) - for i, widspec := range task.Identities { + req := make([]*structs.WorkloadIdentityRequest, len(h.task.Identities)) + for i, widspec := range h.task.Identities { req[i] = &structs.WorkloadIdentityRequest{ - AllocID: alloc.ID, - TaskName: task.Name, + AllocID: h.alloc.ID, + TaskName: h.task.Name, IdentityName: widspec.Name, } } // Get signed workload identities - signedWIDs, err := h.tr.widmgr.SignIdentities(alloc.CreateIndex, req) + signedWIDs, err := h.widmgr.SignIdentities(h.alloc.CreateIndex, req) if err != nil { return nil, err } @@ -151,3 +195,119 @@ func (h *identityHook) getIdentities(alloc *structs.Allocation, task *structs.Ta return widMap, nil } + +// renew fetches new signed workload identity tokens before the existing tokens +// expire. +func (h *identityHook) renew(createIndex uint64, signedWIDs map[string]*structs.SignedWorkloadIdentity) { + wids := h.task.Identities + if len(wids) == 0 { + h.logger.Trace("no workload identities to renew") + return + } + + var reqs []*structs.WorkloadIdentityRequest + renewNow := false + minExp := time.Now().Add(30 * time.Hour) // set high default expiration + widMap := make(map[string]*structs.WorkloadIdentity, len(wids)) // Identity.Name -> Identity + + for _, wid := range wids { + if wid.TTL == 0 { + // No ttl, so no need to renew it + continue + } + + widMap[wid.Name] = wid + + reqs = append(reqs, &structs.WorkloadIdentityRequest{ + AllocID: h.alloc.ID, + TaskName: h.task.Name, + IdentityName: wid.Name, + }) + + sid, ok := signedWIDs[wid.Name] + if !ok { + // Missing a signature, treat this case as already expired so we get a + // token ASAP + h.logger.Trace("missing token for identity", "identity", wid.Name) + renewNow = true + continue + } + + if sid.Expiration.Before(minExp) { + minExp = sid.Expiration + } + } + + if len(reqs) == 0 { + h.logger.Trace("no workload identities expire") + return + } + + var wait time.Duration + if !renewNow { + wait = helper.ExpiryToRenewTime(minExp, time.Now, h.minWait) + } + + timer, timerStop := helper.NewStoppedTimer() + defer timerStop() + + var retry uint64 + + for err := h.stopCtx.Err(); err == nil; { + h.logger.Debug("waiting to renew identities", "num", len(reqs), "wait", wait) + timer.Reset(wait) + select { + case <-timer.C: + h.logger.Trace("getting new signed identities", "num", len(reqs)) + case <-h.stopCtx.Done(): + return + } + + // Renew all tokens together since its cheap + tokens, err := h.widmgr.SignIdentities(createIndex, reqs) + if err != nil { + retry++ + wait = helper.Backoff(h.minWait, time.Hour, retry) + helper.RandomStagger(h.minWait) + h.logger.Error("error renewing workload identities", "error", err, "next", wait) + continue + } + + if len(tokens) == 0 { + retry++ + wait = helper.Backoff(h.minWait, time.Hour, retry) + helper.RandomStagger(h.minWait) + h.logger.Error("error renewing workload identities", "error", "no tokens", "next", wait) + continue + } + + // Reset next expiration time + minExp = time.Time{} + + for _, token := range tokens { + widspec, ok := widMap[token.IdentityName] + if !ok { + // Bug: Every requested workload identity should either have a signed + // identity or rejection. + h.logger.Warn("bug: unexpected workload identity received", "identity", token.IdentityName) + continue + } + + if err := h.setAltToken(widspec, token.JWT); err != nil { + // Set minExp using retry's backoff logic + minExp = time.Now().Add(helper.Backoff(h.minWait, time.Hour, retry+1) + helper.RandomStagger(h.minWait)) + h.logger.Error("error setting new workload identity", "error", err, "identity", token.IdentityName) + continue + } + + // Set next expiration time + if minExp.IsZero() { + minExp = token.Expiration + } else if token.Expiration.Before(minExp) { + minExp = token.Expiration + } + } + + // Success! Set next renewal and reset retries + wait = helper.ExpiryToRenewTime(minExp, time.Now, h.minWait) + retry = 0 + } +} diff --git a/client/allocrunner/taskrunner/identity_hook_test.go b/client/allocrunner/taskrunner/identity_hook_test.go index f7bfce575..7f6a8dc1d 100644 --- a/client/allocrunner/taskrunner/identity_hook_test.go +++ b/client/allocrunner/taskrunner/identity_hook_test.go @@ -3,8 +3,370 @@ package taskrunner -import "github.com/hashicorp/nomad/client/allocrunner/interfaces" +import ( + "context" + "crypto/ed25519" + "fmt" + "path/filepath" + "slices" + "testing" + "time" + + "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3/jwt" + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/client/allocrunner/interfaces" + "github.com/hashicorp/nomad/client/taskenv" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/testutil" + "github.com/shoenig/test/must" +) var _ interfaces.TaskPrestartHook = (*identityHook)(nil) +var _ interfaces.TaskStopHook = (*identityHook)(nil) +var _ interfaces.ShutdownHook = (*identityHook)(nil) // See task_runner_test.go:TestTaskRunner_IdentityHook + +// MockWIDMgr allows TaskRunner unit tests to avoid having to setup a Server, +// Client, and Allocation. +type MockWIDMgr struct { + // wids maps identity names to workload identities. If wids is non-nil then + // SignIdentities will use it to find expirations or reject invalid identity + // names + wids map[string]*structs.WorkloadIdentity + + key ed25519.PrivateKey + keyID string +} + +func NewMockWIDMgr(wids []*structs.WorkloadIdentity) *MockWIDMgr { + _, privKey, err := ed25519.GenerateKey(nil) + if err != nil { + panic(err) + } + m := &MockWIDMgr{ + key: privKey, + keyID: uuid.Generate(), + } + + if wids != nil { + m.setWIDs(wids) + } + + return m +} + +// setWIDs is a test helper to use Task.Identities in the MockWIDMgr for +// sharing TTLs and validating names. +func (m *MockWIDMgr) setWIDs(wids []*structs.WorkloadIdentity) { + m.wids = make(map[string]*structs.WorkloadIdentity, len(wids)) + for _, wid := range wids { + m.wids[wid.Name] = wid + } +} + +func (m *MockWIDMgr) SignIdentities(minIndex uint64, req []*structs.WorkloadIdentityRequest) ([]*structs.SignedWorkloadIdentity, error) { + swids := make([]*structs.SignedWorkloadIdentity, 0, len(req)) + for _, idReq := range req { + // Set test values for default claims + claims := &structs.IdentityClaims{ + Namespace: "default", + JobID: "test", + AllocationID: idReq.AllocID, + TaskName: idReq.TaskName, + } + claims.ID = uuid.Generate() + + // If test has set workload identities. Lookup claims or reject unknown + // identity. + if m.wids != nil { + wid, ok := m.wids[idReq.IdentityName] + if !ok { + return nil, fmt.Errorf("unknown identity: %q", idReq.IdentityName) + } + + claims.Audience = slices.Clone(wid.Audience) + + if wid.TTL > 0 { + claims.Expiry = jwt.NewNumericDate(time.Now().Add(wid.TTL)) + } + } + + opts := (&jose.SignerOptions{}).WithHeader("kid", m.keyID).WithType("JWT") + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: m.key}, opts) + if err != nil { + return nil, fmt.Errorf("error creating signer: %w", err) + } + token, err := jwt.Signed(sig).Claims(claims).CompactSerialize() + if err != nil { + return nil, fmt.Errorf("error signing: %w", err) + } + + swid := &structs.SignedWorkloadIdentity{ + WorkloadIdentityRequest: *idReq, + JWT: token, + Expiration: claims.Expiry.Time(), + } + + swids = append(swids, swid) + } + return swids, nil +} + +// MockTokenSetter is a mock implementation of tokenSetter which is satisfied +// by TaskRunner at runtime. +type MockTokenSetter struct { + defaultToken string +} + +func (m *MockTokenSetter) setNomadToken(token string) { + m.defaultToken = token +} + +// TestIdentityHook_RenewAll asserts token renewal happens when expected. +func TestIdentityHook_RenewAll(t *testing.T) { + ci.Parallel(t) + + // TTL is used for expiration and the test will sleep this long before + // checking that tokens were rotated. Therefore the time must be long enough + // to generate new tokens. Since no Raft or IO (outside of potentially + // writing 1 token file) is performed, this should be relatively fast. + ttl := 8 * time.Second + + node := mock.Node() + alloc := mock.Alloc() + alloc.NodeID = node.ID + task := alloc.LookupTask("web") + task.Identities = []*structs.WorkloadIdentity{ + { + Name: "consul", + Audience: []string{"consul"}, + Env: true, + TTL: ttl, + }, + { + Name: "vault", + Audience: []string{"vault"}, + File: true, + TTL: ttl, + }, + } + + secretsDir := t.TempDir() + + widmgr := NewMockWIDMgr(task.Identities) + + mockTR := &MockTokenSetter{} + + stopCtx, stop := context.WithCancel(context.Background()) + t.Cleanup(stop) + + h := &identityHook{ + alloc: alloc, + task: task, + tokenDir: secretsDir, + envBuilder: taskenv.NewBuilder(node, alloc, task, alloc.Job.Region), + ts: mockTR, + widmgr: widmgr, + minWait: time.Second, + logger: testlog.HCLogger(t), + stopCtx: stopCtx, + stop: stop, + } + + start := time.Now() + must.NoError(t, h.Prestart(context.Background(), nil, nil)) + env := h.envBuilder.Build().EnvMap + + // Assert initial tokens were set in Prestart + must.Eq(t, alloc.SignedIdentities["web"], mockTR.defaultToken) + must.FileNotExists(t, filepath.Join(secretsDir, wiTokenFile)) + must.FileNotExists(t, filepath.Join(secretsDir, "nomad_consul.jwt")) + must.MapContainsKey(t, env, "NOMAD_TOKEN_consul") + must.FileExists(t, filepath.Join(secretsDir, "nomad_vault.jwt")) + + origConsul := env["NOMAD_TOKEN_consul"] + origVault := testutil.MustReadFile(t, secretsDir, "nomad_vault.jwt") + + // Tokens should be rotated by their expiration + wait := time.Until(start.Add(ttl)) + h.logger.Trace("sleeping until expiration", "wait", wait) + time.Sleep(wait) + + // Stop renewal before checking to ensure stopping works + must.NoError(t, h.Stop(context.Background(), nil, nil)) + time.Sleep(time.Second) // Stop is async so give renewal time to exit + + newConsul := h.envBuilder.Build().EnvMap["NOMAD_TOKEN_consul"] + must.StrContains(t, newConsul, ".") // ensure new token is JWTish + must.NotEq(t, newConsul, origConsul) + + newVault := testutil.MustReadFile(t, secretsDir, "nomad_vault.jwt") + must.StrContains(t, string(newVault), ".") // ensure new token is JWTish + must.NotEq(t, newVault, origVault) + + // Assert Stop work. Tokens should not have changed. + time.Sleep(wait) + must.Eq(t, newConsul, h.envBuilder.Build().EnvMap["NOMAD_TOKEN_consul"]) + must.Eq(t, newVault, testutil.MustReadFile(t, secretsDir, "nomad_vault.jwt")) +} + +// TestIdentityHook_RenewOne asserts token renewal only renews tokens with a TTL. +func TestIdentityHook_RenewOne(t *testing.T) { + ci.Parallel(t) + + ttl := 8 * time.Second + + node := mock.Node() + alloc := mock.Alloc() + alloc.NodeID = node.ID + alloc.SignedIdentities = map[string]string{"web": "does.not.matter"} + task := alloc.LookupTask("web") + task.Identities = []*structs.WorkloadIdentity{ + { + Name: "consul", + Audience: []string{"consul"}, + Env: true, + }, + { + Name: "vault", + Audience: []string{"vault"}, + File: true, + TTL: ttl, + }, + } + + secretsDir := t.TempDir() + + widmgr := NewMockWIDMgr(task.Identities) + + mockTR := &MockTokenSetter{} + + stopCtx, stop := context.WithCancel(context.Background()) + t.Cleanup(stop) + + h := &identityHook{ + alloc: alloc, + task: task, + tokenDir: secretsDir, + envBuilder: taskenv.NewBuilder(node, alloc, task, alloc.Job.Region), + ts: mockTR, + widmgr: widmgr, + minWait: time.Second, + logger: testlog.HCLogger(t), + stopCtx: stopCtx, + stop: stop, + } + + start := time.Now() + must.NoError(t, h.Prestart(context.Background(), nil, nil)) + env := h.envBuilder.Build().EnvMap + + // Assert initial tokens were set in Prestart + must.Eq(t, alloc.SignedIdentities["web"], mockTR.defaultToken) + must.FileNotExists(t, filepath.Join(secretsDir, wiTokenFile)) + must.FileNotExists(t, filepath.Join(secretsDir, "nomad_consul.jwt")) + must.MapContainsKey(t, env, "NOMAD_TOKEN_consul") + must.FileExists(t, filepath.Join(secretsDir, "nomad_vault.jwt")) + + origConsul := env["NOMAD_TOKEN_consul"] + origVault := testutil.MustReadFile(t, secretsDir, "nomad_vault.jwt") + + // One token should be rotated by their expiration + wait := time.Until(start.Add(ttl)) + h.logger.Trace("sleeping until expiration", "wait", wait) + time.Sleep(wait) + + // Stop renewal before checking to ensure stopping works + must.NoError(t, h.Stop(context.Background(), nil, nil)) + time.Sleep(time.Second) // Stop is async so give renewal time to exit + + newConsul := h.envBuilder.Build().EnvMap["NOMAD_TOKEN_consul"] + must.StrContains(t, newConsul, ".") // ensure new token is JWTish + must.Eq(t, newConsul, origConsul) + + newVault := testutil.MustReadFile(t, secretsDir, "nomad_vault.jwt") + must.StrContains(t, string(newVault), ".") // ensure new token is JWTish + must.NotEq(t, newVault, origVault) + + // Assert Stop work. Tokens should not have changed. + time.Sleep(wait) + must.Eq(t, newConsul, h.envBuilder.Build().EnvMap["NOMAD_TOKEN_consul"]) + must.Eq(t, newVault, testutil.MustReadFile(t, secretsDir, "nomad_vault.jwt")) +} + +// TestIdentityHook_ErrorWriting assert Prestart returns an error if the +// default token could not be written when requested. +func TestIdentityHook_ErrorWriting(t *testing.T) { + ci.Parallel(t) + + alloc := mock.Alloc() + alloc.SignedIdentities = map[string]string{"web": "does.not.need.to.be.valid"} + task := alloc.LookupTask("web") + task.Identity.File = true + node := mock.Node() + stopCtx, stop := context.WithCancel(context.Background()) + t.Cleanup(stop) + + h := &identityHook{ + alloc: alloc, + task: task, + tokenDir: "/this-should-not-exist", + envBuilder: taskenv.NewBuilder(node, alloc, task, alloc.Job.Region), + ts: &MockTokenSetter{}, + widmgr: NewMockWIDMgr(nil), + minWait: time.Second, + logger: testlog.HCLogger(t), + stopCtx: stopCtx, + stop: stop, + } + + // Prestart should fail when trying to write the default identity file + err := h.Prestart(context.Background(), nil, nil) + must.ErrorContains(t, err, "failed to write nomad token") +} + +// TestIdentityHook_GetIdentitiesMismatch asserts that if SignIdentities() does +// not return enough identities then Prestart fails. +func TestIdentityHook_GetIdentitiesMismatch(t *testing.T) { + ci.Parallel(t) + + alloc := mock.Alloc() + task := alloc.LookupTask("web") + task.Identities = []*structs.WorkloadIdentity{ + { + Name: "consul", + Audience: []string{"consul"}, + TTL: time.Minute, + }, + } + node := mock.Node() + stopCtx, stop := context.WithCancel(context.Background()) + t.Cleanup(stop) + + wids := []*structs.WorkloadIdentity{ + { + Name: "not-consul", + }, + } + h := &identityHook{ + alloc: alloc, + task: task, + tokenDir: t.TempDir(), + envBuilder: taskenv.NewBuilder(node, alloc, task, alloc.Job.Region), + ts: &MockTokenSetter{}, + widmgr: NewMockWIDMgr(wids), + minWait: time.Second, + logger: testlog.HCLogger(t), + stopCtx: stopCtx, + stop: stop, + } + + // Prestart should fail when trying to write the default identity file + err := h.Prestart(context.Background(), nil, nil) + must.ErrorContains(t, err, "error fetching alternate identities") +} diff --git a/client/allocrunner/taskrunner/task_runner_test.go b/client/allocrunner/taskrunner/task_runner_test.go index fc5a7d4fe..70121907e 100644 --- a/client/allocrunner/taskrunner/task_runner_test.go +++ b/client/allocrunner/taskrunner/task_runner_test.go @@ -65,22 +65,6 @@ func (m *MockTaskStateUpdater) TaskStateUpdated() { } } -// MockWIDMgr allows TaskRunner unit tests to avoid having to setup a Server, -// Client, and Allocation. -type MockWIDMgr struct{} - -func (m MockWIDMgr) SignIdentities(minIndex uint64, req []*structs.WorkloadIdentityRequest) ([]*structs.SignedWorkloadIdentity, error) { - swids := make([]*structs.SignedWorkloadIdentity, 0, len(req)) - for _, idReq := range req { - swids = append(swids, &structs.SignedWorkloadIdentity{ - WorkloadIdentityRequest: *idReq, - // Just the sample jwt from jwt.io so it "looks" like a jwt - JWT: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", - }) - } - return swids, nil -} - // testTaskRunnerConfig returns a taskrunner.Config for the given alloc+task // plus a cleanup func. func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName string) (*Config, func()) { @@ -152,7 +136,7 @@ func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName stri ServiceRegWrapper: wrapperMock, Getter: getter.TestSandbox(t), Wranglers: proclib.New(&proclib.Configs{Logger: testlog.HCLogger(t)}), - WIDMgr: MockWIDMgr{}, + WIDMgr: NewMockWIDMgr(nil), } return conf, trCleanup diff --git a/command/agent/job_endpoint.go b/command/agent/job_endpoint.go index e26825230..959ff2580 100644 --- a/command/agent/job_endpoint.go +++ b/command/agent/job_endpoint.go @@ -1213,6 +1213,7 @@ func ApiTaskToStructsTask(job *structs.Job, group *structs.TaskGroup, Audience: slices.Clone(id.Audience), Env: id.Env, File: id.File, + TTL: id.TTL, } } @@ -1228,6 +1229,7 @@ func ApiTaskToStructsTask(job *structs.Job, group *structs.TaskGroup, Audience: slices.Clone(id.Audience), Env: id.Env, File: id.File, + TTL: id.TTL, } } diff --git a/helper/retry_test.go b/helper/retry_test.go index cfe6c4a9c..685e97b52 100644 --- a/helper/retry_test.go +++ b/helper/retry_test.go @@ -74,3 +74,18 @@ func TestExpiryToRenewTime_Expired(t *testing.T) { must.Greater(t, min, renew) must.Less(t, min*2, renew) } + +// TestExpiryToRenewTime_Zero asserts that ExpiryToRenewTime handles the zero +// value for renewal time and returns the minimum. +func TestExpiryToRenewTime_Zero(t *testing.T) { + exp := time.Time{} + now := func() time.Time { + return time.Date(2023, 2, 1, 0, 0, 0, 0, time.UTC) + } + min := time.Hour + + renew := ExpiryToRenewTime(exp, now, min) + + must.Greater(t, min, renew) + must.Less(t, min*2, renew) +} diff --git a/nomad/alloc_endpoint.go b/nomad/alloc_endpoint.go index a30cb6df3..ab024213b 100644 --- a/nomad/alloc_endpoint.go +++ b/nomad/alloc_endpoint.go @@ -590,6 +590,7 @@ func (a *Alloc) SignIdentities(args *structs.AllocIdentitiesRequest, reply *stru reply.SignedIdentities = append(reply.SignedIdentities, &structs.SignedWorkloadIdentity{ WorkloadIdentityRequest: *idReq, JWT: token, + Expiration: claims.Expiry.Time(), }) break } diff --git a/nomad/structs/diff_test.go b/nomad/structs/diff_test.go index 138f7e470..b1d73f7dd 100644 --- a/nomad/structs/diff_test.go +++ b/nomad/structs/diff_test.go @@ -8320,6 +8320,12 @@ func TestTaskDiff(t *testing.T) { Old: "", New: "false", }, + { + Type: DiffTypeAdded, + Name: "TTL", + Old: "", + New: "0", + }, }, }, }, @@ -8350,6 +8356,12 @@ func TestTaskDiff(t *testing.T) { Name: "File", Old: "false", }, + { + Type: DiffTypeDeleted, + Name: "TTL", + Old: "0", + New: "", + }, }, }, }, @@ -8408,6 +8420,7 @@ func TestTaskDiff(t *testing.T) { Name: "vault", Audience: []string{"vault.io"}, File: true, + TTL: time.Hour, }, }, }, @@ -8450,6 +8463,12 @@ func TestTaskDiff(t *testing.T) { Old: "", New: "vault", }, + { + Type: DiffTypeAdded, + Name: "TTL", + Old: "", + New: "3600000000000", + }, }, }, }, @@ -8464,6 +8483,7 @@ func TestTaskDiff(t *testing.T) { Audience: []string{"consul.io"}, Env: true, File: false, + TTL: time.Hour, }, { Name: "vault", @@ -8479,6 +8499,7 @@ func TestTaskDiff(t *testing.T) { Audience: []string{"consul-prod.io"}, Env: false, File: true, + TTL: 2 * time.Hour, }, { // Modifying the previous block to be deleted and a new @@ -8486,6 +8507,7 @@ func TestTaskDiff(t *testing.T) { Name: "vault-dev", Audience: []string{"vault.io"}, File: true, + TTL: time.Hour, }, }, }, @@ -8528,6 +8550,12 @@ func TestTaskDiff(t *testing.T) { Old: "false", New: "true", }, + { + Type: DiffTypeEdited, + Name: "TTL", + Old: "3600000000000", + New: "7200000000000", + }, }, }, { @@ -8566,6 +8594,12 @@ func TestTaskDiff(t *testing.T) { Old: "", New: "vault-dev", }, + { + Type: DiffTypeAdded, + Name: "TTL", + Old: "", + New: "3600000000000", + }, }, }, { @@ -8604,6 +8638,12 @@ func TestTaskDiff(t *testing.T) { Old: "vault", New: "", }, + { + Type: DiffTypeDeleted, + Name: "TTL", + Old: "0", + New: "", + }, }, }, }, diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index 4335e6fe2..b78af764a 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -11300,16 +11300,17 @@ func NewIdentityClaims(job *Job, alloc *Allocation, taskName string, wid *Worklo } claims.TaskName = taskName - claims.Audience = wid.Audience - claims.SetSubject(job, alloc.TaskGroup, taskName, wid.Name) + claims.Audience = slices.Clone(wid.Audience) + claims.setSubject(job, alloc.TaskGroup, taskName, wid.Name) + claims.setExp(now, wid) claims.ID = uuid.Generate() return claims } -// SetSubject creates the standard subject claim for workload identities. -func (claims *IdentityClaims) SetSubject(job *Job, group, task, id string) { +// setSubject creates the standard subject claim for workload identities. +func (claims *IdentityClaims) setSubject(job *Job, group, task, id string) { claims.Subject = strings.Join([]string{ job.Region, job.Namespace, @@ -11320,6 +11321,16 @@ func (claims *IdentityClaims) SetSubject(job *Job, group, task, id string) { }, ":") } +// setExp sets the absolute time at which these identity claims expire. +func (claims *IdentityClaims) setExp(now time.Time, wid *WorkloadIdentity) { + if wid.TTL == 0 { + // No expiry + return + } + + claims.Expiry = jwt.NewNumericDate(now.Add(wid.TTL)) +} + // AllocationDiff is another named type for Allocation (to use the same fields), // which is used to represent the delta for an Allocation. If you need a method // defined on the al diff --git a/nomad/structs/workload_id.go b/nomad/structs/workload_id.go index c15a51070..2a1a33c79 100644 --- a/nomad/structs/workload_id.go +++ b/nomad/structs/workload_id.go @@ -6,6 +6,7 @@ package structs import ( "fmt" "slices" + "time" "github.com/hashicorp/go-multierror" ) @@ -63,6 +64,10 @@ type WorkloadIdentity struct { // ServiceName is used to bind the identity to a correct Consul service. ServiceName string + + // TTL is used to determine the expiration of the credentials created for + // this identity (eg the JWT "exp" claim). + TTL time.Duration } func (wi *WorkloadIdentity) Copy() *WorkloadIdentity { @@ -75,6 +80,7 @@ func (wi *WorkloadIdentity) Copy() *WorkloadIdentity { Env: wi.Env, File: wi.File, ServiceName: wi.ServiceName, + TTL: wi.TTL, } } @@ -103,6 +109,10 @@ func (wi *WorkloadIdentity) Equal(other *WorkloadIdentity) bool { return false } + if wi.TTL != other.TTL { + return false + } + return true } @@ -139,6 +149,14 @@ func (wi *WorkloadIdentity) Validate() error { } } + if wi.TTL > 0 && (wi.Name == "" || wi.Name == WorkloadIdentityDefaultName) { + mErr.Errors = append(mErr.Errors, fmt.Errorf("ttl for default identity not yet supported")) + } + + if wi.TTL < 0 { + mErr.Errors = append(mErr.Errors, fmt.Errorf("ttl must be >= 0")) + } + return mErr.ErrorOrNil() } @@ -147,13 +165,21 @@ func (wi *WorkloadIdentity) Warnings() error { return fmt.Errorf("must not be nil") } + var mErr multierror.Error + if n := len(wi.Audience); n == 0 { - return fmt.Errorf("identities without an audience are insecure") + mErr.Errors = append(mErr.Errors, fmt.Errorf("identities without an audience are insecure")) } else if n > 1 { - return fmt.Errorf("while multiple audiences is allowed, it is more secure to use 1 audience per identity") + mErr.Errors = append(mErr.Errors, fmt.Errorf("while multiple audiences is allowed, it is more secure to use 1 audience per identity")) } - return nil + if wi.Name != "" && wi.Name != WorkloadIdentityDefaultName { + if wi.TTL == 0 { + mErr.Errors = append(mErr.Errors, fmt.Errorf("identities without an expiration are insecure")) + } + } + + return mErr.ErrorOrNil() } // WorkloadIdentityRequest encapsulates the 3 parameters used to generated a @@ -168,7 +194,8 @@ type WorkloadIdentityRequest struct { // includes the JWT for the requested workload identity. type SignedWorkloadIdentity struct { WorkloadIdentityRequest - JWT string + JWT string + Expiration time.Time } // WorkloadIdentityRejection is the response to a WorkloadIdentityRequest that diff --git a/nomad/structs/workload_id_test.go b/nomad/structs/workload_id_test.go index c30f75426..5c452396c 100644 --- a/nomad/structs/workload_id_test.go +++ b/nomad/structs/workload_id_test.go @@ -6,6 +6,7 @@ package structs import ( "strings" "testing" + "time" "github.com/hashicorp/nomad/ci" "github.com/shoenig/test/must" @@ -45,6 +46,12 @@ func TestWorkloadIdentity_Equal(t *testing.T) { newWI.Audience = []string{"foo"} must.NotEqual(t, orig, newWI) + + newWI.Audience = orig.Audience + must.Equal(t, orig, newWI) + + newWI.TTL = 123 * time.Hour + must.NotEqual(t, orig, newWI) } // TestWorkloadIdentity_Validate asserts that canonicalized workload identities @@ -84,12 +91,14 @@ func TestWorkloadIdentity_Validate(t *testing.T) { Audience: []string{"http://nomadproject.io/"}, Env: true, File: true, + TTL: time.Hour, }, Exp: WorkloadIdentity{ Name: "foo-id", Audience: []string{"http://nomadproject.io/"}, Env: true, File: true, + TTL: time.Hour, }, }, { @@ -143,6 +152,26 @@ func TestWorkloadIdentity_Validate(t *testing.T) { }, Warn: "while multiple audiences is allowed, it is more secure to use 1 audience per identity", }, + { + Desc: "Bad TTL", + In: WorkloadIdentity{ + Name: "foo", + TTL: -1 * time.Hour, + }, + Err: "ttl must be >= 0", + }, + { + Desc: "No TTL", + In: WorkloadIdentity{ + Name: "foo", + Audience: []string{"foo"}, + }, + Exp: WorkloadIdentity{ + Name: "foo", + Audience: []string{"foo"}, + }, + Warn: "identities without an expiration are insecure", + }, } for _, tc := range cases { diff --git a/testutil/file.go b/testutil/file.go new file mode 100644 index 000000000..cefc85f4c --- /dev/null +++ b/testutil/file.go @@ -0,0 +1,20 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package testutil + +import ( + "os" + "path/filepath" + + testing "github.com/mitchellh/go-testing-interface" + "github.com/shoenig/test/must" +) + +// MustReadFile returns the contents of the specified file or fails the test. +// Multiple arguments are joined with filepath.Join. +func MustReadFile(t testing.T, path ...string) []byte { + contents, err := os.ReadFile(filepath.Join(path...)) + must.NoError(t, err) + return contents +}