// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package widmgr

import (
	"context"
	"fmt"
	"slices"
	"sync"
	"time"

	"github.com/hashicorp/go-hclog"

	cstate "github.com/hashicorp/nomad/client/state"
	"github.com/hashicorp/nomad/client/taskenv"
	"github.com/hashicorp/nomad/helper"
	"github.com/hashicorp/nomad/nomad/structs"
)

// IdentityManager defines a manager responsible for signing and renewing
// signed identities. At runtime it is implemented by *widmgr.WIDMgr.
type IdentityManager interface {
	Run() error
	Get(structs.WIHandle) (*structs.SignedWorkloadIdentity, error)
	Watch(structs.WIHandle) (<-chan *structs.SignedWorkloadIdentity, func())
	Shutdown()
}
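
// The typical lifecycle, shown as an illustrative sketch rather than code
// taken from a real caller (mgr, handle, and the error handling are assumed
// for the example):
//
//	var mgr IdentityManager = NewWIDMgr(signer, alloc, db, logger, allocEnv)
//	if err := mgr.Run(); err != nil {
//		return err // initial signing failed; the renewal loop never started
//	}
//	swid, err := mgr.Get(handle) // non-blocking; only valid after Run
//	// ...
//	mgr.Shutdown() // stops renewal and closes all Watch channels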

type WIDMgr struct {
	allocID                 string
	defaultSignedIdentities map[string]string // signed by the plan applier
	minIndex                uint64
	widSpecs                map[structs.WIHandle]*structs.WorkloadIdentity // workload handle -> WI
	signer                  IdentitySigner
	db                      cstate.StateDB

	// lastToken holds the most recently retrieved signed workload identities,
	// keyed by WIHandle
	lastToken     map[structs.WIHandle]*structs.SignedWorkloadIdentity
	lastTokenLock sync.RWMutex

	// watchers is a map of task identities to slices of channels (each identity
	// can have multiple watchers)
	watchers     map[structs.WIHandle][]chan *structs.SignedWorkloadIdentity
	watchersLock sync.Mutex

	// minWait is the minimum amount of time to wait before renewing. Settable
	// to ease testing.
	minWait time.Duration

	stopCtx context.Context
	stop    context.CancelFunc

	logger hclog.Logger
}

func NewWIDMgr(signer IdentitySigner, a *structs.Allocation, db cstate.StateDB, logger hclog.Logger, allocEnv *taskenv.TaskEnv) *WIDMgr {
	widspecs := map[structs.WIHandle]*structs.WorkloadIdentity{}
	tg := a.Job.LookupTaskGroup(a.TaskGroup)
	for _, service := range tg.Services {
		if service.Identity != nil {
			handle := *service.IdentityHandle(allocEnv.ReplaceEnv)
			widspecs[handle] = service.Identity
		}
	}

	for _, task := range tg.Tasks {
		// Omit the default identity as it does not expire
		for _, id := range task.Identities {
			widspecs[*task.IdentityHandle(id)] = id
		}

		// update the builder for this task
		taskEnv := allocEnv.WithTask(a, task)
		for _, service := range task.Services {
			if service.Identity != nil {
				handle := *service.IdentityHandle(taskEnv.ReplaceEnv)
				widspecs[handle] = service.Identity
			}
		}
	}

	// Create a context for the renew loop. This context will be canceled when
	// the allocation is stopped or the agent is shutting down.
	stopCtx, stop := context.WithCancel(context.Background())

	return &WIDMgr{
		allocID:                 a.ID,
		defaultSignedIdentities: a.SignedIdentities,
		minIndex:                a.CreateIndex,
		widSpecs:                widspecs,
		signer:                  signer,
		db:                      db,
		minWait:                 10 * time.Second,
		lastToken:               map[structs.WIHandle]*structs.SignedWorkloadIdentity{},
		watchers:                map[structs.WIHandle][]chan *structs.SignedWorkloadIdentity{},
		stopCtx:                 stopCtx,
		stop:                    stop,
		logger:                  logger.Named("widmgr"),
	}
}
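
// A minimal construction sketch (not from a real caller); signer, alloc,
// stateDB, logger, and allocEnv are assumed to be supplied by the client's
// alloc runner:
//
//	mgr := NewWIDMgr(signer, alloc, stateDB, logger, allocEnv)
//	mgr.SetMinWait(5 * time.Second) // optional; mainly useful in tests
//	if err := mgr.Run(); err != nil {
//		logger.Error("failed to sign workload identities", "error", err)
//	}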

// SetMinWait sets the minimum time for renewals
func (m *WIDMgr) SetMinWait(t time.Duration) {
	m.minWait = t
}

// Run blocks until identities are initially signed and then renews them in a
// goroutine. The goroutine is stopped when WIDMgr.Shutdown is called.
//
// If an error is returned the identities could not be fetched and the renewal
// goroutine was not started.
func (m *WIDMgr) Run() error {
	if len(m.widSpecs) == 0 && len(m.defaultSignedIdentities) == 0 {
		m.logger.Debug("no workload identities to retrieve or renew")
		return nil
	}

	m.logger.Debug("retrieving and renewing workload identities", "num_identities", len(m.widSpecs))

	hasExpired, err := m.restoreStoredIdentities()
	if err != nil {
		m.logger.Warn("failed to get signed identities from state DB, refreshing from server",
			"error", err)
	}

	if hasExpired {
		if err := m.getInitialIdentities(); err != nil {
			return fmt.Errorf("failed to fetch signed identities: %w", err)
		}
	}

	go m.renew()
	return nil
}

// Get retrieves the latest signed identity or returns an error. It must be
// called after Run and does not block.
//
// For retrieving tokens which might be renewed, callers should use Watch
// instead to avoid missing new tokens retrieved by Run between Get and Watch
// calls.
func (m *WIDMgr) Get(id structs.WIHandle) (*structs.SignedWorkloadIdentity, error) {
	token := m.get(id)
	if token == nil {
		// This is an error as every identity should have a token by the time
		// Get is called.
		return nil, fmt.Errorf("unable to find token for workload %q and identity %q", id.WorkloadIdentifier, id.IdentityName)
	}
	return token, nil
}

func (m *WIDMgr) get(id structs.WIHandle) *structs.SignedWorkloadIdentity {
	m.lastTokenLock.RLock()
	defer m.lastTokenLock.RUnlock()

	return m.lastToken[id]
}

// Watch returns a channel that sends new signed identities until it is closed
// due to shutdown. Must be called after Run.
//
// Callers must call the returned func to stop watching, and must ensure the
// watched id actually exists; otherwise the channel never receives a result.
func (m *WIDMgr) Watch(id structs.WIHandle) (<-chan *structs.SignedWorkloadIdentity, func()) {
	// If Shutdown has been called return a closed chan
	if m.stopCtx.Err() != nil {
		c := make(chan *structs.SignedWorkloadIdentity)
		close(c)
		return c, func() {}
	}

	m.watchersLock.Lock()
	defer m.watchersLock.Unlock()

	// Buffer of 1 so sends don't block on receives
	c := make(chan *structs.SignedWorkloadIdentity, 1)

	// Append rather than overwrite: an identity can have multiple watchers,
	// and appending to a nil slice allocates it.
	m.watchers[id] = append(m.watchers[id], c)

	// Create a cancel func for watchers to deregister when they exit.
	cancel := func() {
		m.watchersLock.Lock()
		defer m.watchersLock.Unlock()
		m.watchers[id] = slices.DeleteFunc(
			m.watchers[id],
			func(ch chan *structs.SignedWorkloadIdentity) bool { return ch == c },
		)
	}

	// Prime chan with the latest token to avoid a race condition where
	// consumers could miss a token update between Get and Watch calls.
	if token := m.get(id); token != nil {
		c <- token
	}

	return c, cancel
}
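
// An illustrative consumer loop (a sketch, not taken from a real caller):
// receive renewed tokens until the channel is closed by Shutdown, and always
// call the cancel func so the watcher is deregistered:
//
//	ws, cancel := mgr.Watch(handle)
//	defer cancel()
//	for swid := range ws {
//		// swid is the freshest *structs.SignedWorkloadIdentity for handle;
//		// hand it to whatever consumes the identity
//	}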

// Shutdown stops renewal and closes all watch chans.
func (m *WIDMgr) Shutdown() {
	m.watchersLock.Lock()
	defer m.watchersLock.Unlock()

	m.stop()

	for _, w := range m.watchers {
		for _, c := range w {
			close(c)
		}
	}

	// ensure it's safe to call Shutdown multiple times
	m.watchers = map[structs.WIHandle][]chan *structs.SignedWorkloadIdentity{}
}

// restoreStoredIdentities recreates the state of the WIDMgr from a previously
// saved state, so that we can avoid asking for all identities again after a
// client agent restart. It returns true if the caller should immediately call
// getInitialIdentities because one or more of the identities has expired.
func (m *WIDMgr) restoreStoredIdentities() (bool, error) {
	storedIdentities, err := m.db.GetAllocIdentities(m.allocID)
	if err != nil {
		return true, err
	}
	if len(storedIdentities) == 0 {
		return true, nil
	}

	m.lastTokenLock.Lock()
	defer m.lastTokenLock.Unlock()

	var hasExpired bool
	for _, identity := range storedIdentities {
		if !identity.Expiration.IsZero() && identity.Expiration.Before(time.Now()) {
			hasExpired = true
		}
		m.lastToken[identity.WIHandle] = identity
	}

	return hasExpired, nil
}

// SignForTesting signs all the identities in m.widSpecs, typically with the
// mock signer. This should only be used for testing downstream hooks.
func (m *WIDMgr) SignForTesting() {
	m.getInitialIdentities()
}

// getInitialIdentities fetches all signed identities or returns an error. It
// should be run once when the WIDMgr first runs.
func (m *WIDMgr) getInitialIdentities() error {
	// get the default identity signed by the plan applier
	defaultTokens := map[structs.WIHandle]*structs.SignedWorkloadIdentity{}
	for taskName, signature := range m.defaultSignedIdentities {
		id := structs.WIHandle{
			WorkloadIdentifier: taskName,
			IdentityName:       "default",
		}
		widReq := structs.WorkloadIdentityRequest{
			AllocID:  m.allocID,
			WIHandle: id,
		}
		defaultTokens[id] = &structs.SignedWorkloadIdentity{
			WorkloadIdentityRequest: widReq,
			JWT:                     signature,
			Expiration:              time.Time{},
		}
	}

	if len(m.widSpecs) == 0 && len(defaultTokens) == 0 {
		return nil
	}

	m.lastTokenLock.Lock()
	defer m.lastTokenLock.Unlock()

	reqs := make([]*structs.WorkloadIdentityRequest, 0, len(m.widSpecs))
	for wiHandle := range m.widSpecs {
		reqs = append(reqs, &structs.WorkloadIdentityRequest{
			AllocID:  m.allocID,
			WIHandle: wiHandle,
		})
	}

	// Get signed workload identities
	signedWIDs := []*structs.SignedWorkloadIdentity{}
	if len(m.widSpecs) != 0 {
		var err error
		signedWIDs, err = m.signer.SignIdentities(m.minIndex, reqs)
		if err != nil {
			return err
		}
	}

	// Store default identity tokens
	for id, token := range defaultTokens {
		m.lastToken[id] = token
	}

	// Index initial workload identities by name
	for _, swid := range signedWIDs {
		m.lastToken[swid.WIHandle] = swid
	}

	return m.db.PutAllocIdentities(m.allocID, signedWIDs)
}

// renew fetches new signed workload identity tokens before the existing
// tokens expire.
func (m *WIDMgr) renew() {
	if len(m.widSpecs) == 0 {
		return
	}

	reqs := make([]*structs.WorkloadIdentityRequest, 0, len(m.widSpecs))
	for workloadHandle, widspec := range m.widSpecs {
		if widspec.TTL == 0 {
			continue
		}

		reqs = append(reqs, &structs.WorkloadIdentityRequest{
			AllocID:  m.allocID,
			WIHandle: workloadHandle,
		})
	}

	if len(reqs) == 0 {
		m.logger.Trace("no workload identities expire")
		return
	}

	renewNow := false
	minExp := time.Time{}

	for workloadHandle, wid := range m.widSpecs {
		if wid.TTL == 0 {
			// No TTL, so no need to renew it
			continue
		}

		token := m.get(workloadHandle)
		if token == nil {
			// Missing a signature, treat this case as already expired so
			// we get a token ASAP
			m.logger.Debug("missing token for identity", "identity", wid.Name)
			renewNow = true
			continue
		}

		if minExp.IsZero() || token.Expiration.Before(minExp) {
			minExp = token.Expiration
		}
	}

	var wait time.Duration
	if !renewNow {
		wait = helper.ExpiryToRenewTime(minExp, time.Now, m.minWait)
	}

	timer, timerStop := helper.NewStoppedTimer()
	defer timerStop()

	var retry uint64
	for {
		// we need to handle stopCtx.Err() and manually stop the subscribers
		if err := m.stopCtx.Err(); err != nil {
			// close watchers and shutdown
			m.Shutdown()
			return
		}

		m.logger.Debug("waiting to renew identities", "num", len(reqs), "wait", wait)
		timer.Reset(wait)
		select {
		case <-timer.C:
			m.logger.Trace("getting new signed identities", "num", len(reqs))
		case <-m.stopCtx.Done():
			// close watchers and shutdown
			m.Shutdown()
			return
		}

		// Renew all tokens together since it's cheap
		tokens, err := m.signer.SignIdentities(m.minIndex, reqs)
		if err != nil {
			retry++
			wait = helper.Backoff(m.minWait, time.Hour, retry) + helper.RandomStagger(m.minWait)
			m.logger.Error("error renewing workload identities", "error", err, "next", wait)
			continue
		}

		if len(tokens) == 0 {
			retry++
			wait = helper.Backoff(m.minWait, time.Hour, retry) + helper.RandomStagger(m.minWait)
			m.logger.Error("error renewing workload identities", "error", "no tokens", "next", wait)
			continue
		}

		// Reset next expiration time
		minExp = time.Time{}

		for _, token := range tokens {
			// Set for getters
			m.lastTokenLock.Lock()
			m.lastToken[token.WIHandle] = token
			m.lastTokenLock.Unlock()

			// Send to watchers
			m.watchersLock.Lock()
			m.send(token.WIHandle, token)
			m.watchersLock.Unlock()

			// Set next expiration time
			if minExp.IsZero() || token.Expiration.Before(minExp) {
				minExp = token.Expiration
			}
		}

		// Success! Set next renewal and reset retries
		wait = helper.ExpiryToRenewTime(minExp, time.Now, m.minWait)
		retry = 0
	}
}

// send must be called while holding the m.watchersLock
func (m *WIDMgr) send(id structs.WIHandle, token *structs.SignedWorkloadIdentity) {
	w, ok := m.watchers[id]
	if !ok {
		// No watchers
		return
	}

	for _, c := range w {
		// Pop any unreceived tokens
		select {
		case <-c:
		default:
		}

		// Send new token, should never block since this is the only sender
		// and watchersLock is held
		c <- token
	}
}
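
// The drain-then-send pattern in send is a general Go idiom for channels that
// should hold only the latest value. A standalone sketch of the same idiom,
// with names invented for illustration; as in send, it is only non-blocking
// when at most one sender runs at a time (here send is serialized by
// watchersLock):
//
//	latest := make(chan int, 1) // buffer of 1 holds the newest value
//	publish := func(v int) {
//		select {
//		case <-latest: // drop a stale, unreceived value
//		default:
//		}
//		latest <- v // cannot block: the buffer slot is now free
//	}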