nomad/client/widmgr/widmgr.go
Tim Gross e168548341 provide allocrunner hooks with prebuilt taskenv and fix mutation bugs (#25373)
Some of our allocrunner hooks require a task environment for interpolating values based on the node or allocation. But several of the hooks accept an already-built environment (or builder) and then keep it in memory. Both the environment and the builder retain a copy of all the node attributes and allocation metadata, which balloons memory usage until the allocation is GC'd.

While we'd like to look into ways to avoid keeping the allocrunner around entirely (see #25372), for now we can significantly reduce memory usage by creating the task environment on-demand when calling allocrunner methods, rather than persisting it in the allocrunner hooks.
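Roughly, the shape of the change looks like the sketch below. This is illustrative only: `exampleHook`, `newEnvFn`, and the `Prerun` wiring are hypothetical stand-ins rather than the real hook types; only `taskenv.TaskEnv`, `WithTask`, and `ReplaceEnv` come from the actual client packages. Instead of storing a built `*taskenv.TaskEnv` (or a builder) on the hook, the hook keeps a small constructor and builds the environment only while the method runs:

```go
package allocrunner

import (
    "github.com/hashicorp/nomad/client/taskenv"
    "github.com/hashicorp/nomad/nomad/structs"
)

// exampleHook is a hypothetical allocrunner hook that needs interpolation. It
// no longer retains a built TaskEnv (node attributes, alloc metadata, etc.)
// for the lifetime of the allocation.
type exampleHook struct {
    alloc *structs.Allocation

    // newEnvFn builds the allocation's task environment on demand.
    newEnvFn func() *taskenv.TaskEnv
}

func (h *exampleHook) Prerun() error {
    // Built here, used for interpolation, and eligible for GC as soon as
    // Prerun returns, rather than living until the allocation is GC'd.
    env := h.newEnvFn()

    tg := h.alloc.Job.LookupTaskGroup(h.alloc.TaskGroup)
    for _, task := range tg.Tasks {
        // Per-task copy so one task's variables can't bleed into another
        // task's service interpolation (see the WithTask fix below).
        taskEnv := env.WithTask(h.alloc, task)
        for _, service := range task.Services {
            name := taskEnv.ReplaceEnv(service.Name)
            _ = name // register or compare the interpolated service here
        }
    }
    return nil
}
```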

In doing so, we uncover two other bugs:
* The WID manager, the group service hook, and the checks hook have to interpolate services for specific tasks. They mutated a taskenv builder to do so, but every mutation wrote into the same environment map. When a group has multiple tasks, an environment variable set by one task could be interpolated into another task's service definition if that other task didn't set the variable itself. Only service definition interpolation is affected; this does not leak env vars across running tasks, because each taskrunner has its own builder.

  To fix this, we move the `UpdateTask` method off the builder and onto the taskenv as the `WithTask` method. This makes a shallow copy of the taskenv with a deep clone of the environment map used for interpolation, and then overwrites the cloned environment with the task's own values (see the sketch after this list).

* The checks hook interpolates Nomad native service checks only on `Prerun` and not on `Update`. This could cause unexpected deregistration and registration of checks during in-place updates. To fix this, we make sure we interpolate in the `Update` method.
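To make the copy-on-write semantics concrete, here is a minimal, self-contained sketch of the behavior `WithTask` is described as having. `miniEnv` and its fields are hypothetical stand-ins for the real `taskenv.TaskEnv`, which carries far more state; the point is only the shallow-copy-plus-cloned-map pattern that keeps one task's variables out of another task's service interpolation:

```go
package main

import (
    "fmt"
    "maps"
)

// miniEnv is a hypothetical stand-in for taskenv.TaskEnv: a shallow-copyable
// struct whose interpolation map must be cloned before any per-task mutation.
type miniEnv struct {
    NodeAttrs map[string]string // shared, read-only
    EnvMap    map[string]string // map used for interpolation
}

// WithTask mirrors the semantics described above: shallow copy of the env,
// deep clone of the interpolation map, then overlay the task's variables.
func (e *miniEnv) WithTask(taskVars map[string]string) *miniEnv {
    newEnv := *e                         // shallow copy of the env itself
    newEnv.EnvMap = maps.Clone(e.EnvMap) // never write to the shared map
    for k, v := range taskVars {
        newEnv.EnvMap[k] = v
    }
    return &newEnv
}

func main() {
    group := &miniEnv{EnvMap: map[string]string{"NOMAD_GROUP_NAME": "cache"}}

    a := group.WithTask(map[string]string{"REDIS_PORT": "6379"})
    b := group.WithTask(nil) // the second task sets no variables

    // With a single mutated builder, task A's variable could previously be
    // interpolated into task B's service definition. With per-task clones
    // the environments stay isolated.
    fmt.Println(a.EnvMap["REDIS_PORT"]) // "6379"
    fmt.Println(b.EnvMap["REDIS_PORT"]) // "" (not leaked)
}
```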

I also bumped into an incorrectly implemented interface in the CSI hook. That fix, along with some better guardrails, has been pulled out into https://github.com/hashicorp/nomad/pull/25472.

Fixes: https://github.com/hashicorp/nomad/issues/25269
Fixes: https://hashicorp.atlassian.net/browse/NET-12310
Ref: https://github.com/hashicorp/nomad/issues/25372
2025-03-24 12:05:04 -04:00


// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package widmgr

import (
    "context"
    "fmt"
    "slices"
    "sync"
    "time"

    "github.com/hashicorp/go-hclog"

    cstate "github.com/hashicorp/nomad/client/state"
    "github.com/hashicorp/nomad/client/taskenv"
    "github.com/hashicorp/nomad/helper"
    "github.com/hashicorp/nomad/nomad/structs"
)

// IdentityManager defines a manager responsible for signing and renewing
// signed identities. At runtime it is implemented by *widmgr.WIDMgr.
type IdentityManager interface {
    Run() error
    Get(structs.WIHandle) (*structs.SignedWorkloadIdentity, error)
    Watch(structs.WIHandle) (<-chan *structs.SignedWorkloadIdentity, func())
    Shutdown()
}

type WIDMgr struct {
    allocID                 string
    defaultSignedIdentities map[string]string // signed by the plan applier
    minIndex                uint64
    widSpecs                map[structs.WIHandle]*structs.WorkloadIdentity // workload handle -> WI
    signer                  IdentitySigner
    db                      cstate.StateDB

    // lastToken are the last retrieved signed workload identifiers keyed by
    // TaskIdentity
    lastToken     map[structs.WIHandle]*structs.SignedWorkloadIdentity
    lastTokenLock sync.RWMutex

    // watchers is a map of task identities to slices of channels (each identity
    // can have multiple watchers)
    watchers     map[structs.WIHandle][]chan *structs.SignedWorkloadIdentity
    watchersLock sync.Mutex

    // minWait is the minimum amount of time to wait before renewing. Settable to
    // ease testing.
    minWait time.Duration

    stopCtx context.Context
    stop    context.CancelFunc

    logger hclog.Logger
}

func NewWIDMgr(signer IdentitySigner, a *structs.Allocation, db cstate.StateDB, logger hclog.Logger, allocEnv *taskenv.TaskEnv) *WIDMgr {
    widspecs := map[structs.WIHandle]*structs.WorkloadIdentity{}

    tg := a.Job.LookupTaskGroup(a.TaskGroup)
    for _, service := range tg.Services {
        if service.Identity != nil {
            handle := *service.IdentityHandle(allocEnv.ReplaceEnv)
            widspecs[handle] = service.Identity
        }
    }

    for _, task := range tg.Tasks {
        // Omit default identity as it does not expire
        for _, id := range task.Identities {
            widspecs[*task.IdentityHandle(id)] = id
        }

        // derive a task-specific environment for interpolating this task's
        // services
        taskEnv := allocEnv.WithTask(a, task)
        for _, service := range task.Services {
            if service.Identity != nil {
                handle := *service.IdentityHandle(taskEnv.ReplaceEnv)
                widspecs[handle] = service.Identity
            }
        }
    }

    // Create a context for the renew loop. This context will be canceled when
    // the allocation is stopped or agent is shutting down
    stopCtx, stop := context.WithCancel(context.Background())

    return &WIDMgr{
        allocID:                 a.ID,
        defaultSignedIdentities: a.SignedIdentities,
        minIndex:                a.CreateIndex,
        widSpecs:                widspecs,
        signer:                  signer,
        db:                      db,
        minWait:                 10 * time.Second,
        lastToken:               map[structs.WIHandle]*structs.SignedWorkloadIdentity{},
        watchers:                map[structs.WIHandle][]chan *structs.SignedWorkloadIdentity{},
        stopCtx:                 stopCtx,
        stop:                    stop,
        logger:                  logger.Named("widmgr"),
    }
}

// SetMinWait sets the minimum time for renewals
func (m *WIDMgr) SetMinWait(t time.Duration) {
    m.minWait = t
}

// Run blocks until identities are initially signed and then renews them in a
// goroutine. The goroutine is stopped when WIDMgr.Shutdown is called.
//
// If an error is returned the identities could not be fetched and the renewal
// goroutine was not started.
func (m *WIDMgr) Run() error {
    if len(m.widSpecs) == 0 && len(m.defaultSignedIdentities) == 0 {
        m.logger.Debug("no workload identities to retrieve or renew")
        return nil
    }

    m.logger.Debug("retrieving and renewing workload identities", "num_identities", len(m.widSpecs))

    hasExpired, err := m.restoreStoredIdentities()
    if err != nil {
        m.logger.Warn("failed to get signed identities from state DB, refreshing from server", "error", err)
    }

    if hasExpired {
        if err := m.getInitialIdentities(); err != nil {
            return fmt.Errorf("failed to fetch signed identities: %w", err)
        }
    }

    go m.renew()
    return nil
}

// Get retrieves the latest signed identity or returns an error. It must be
// called after Run and does not block.
//
// For retrieving tokens which might be renewed callers should use Watch
// instead to avoid missing new tokens retrieved by Run between Get and Watch
// calls.
func (m *WIDMgr) Get(id structs.WIHandle) (*structs.SignedWorkloadIdentity, error) {
    token := m.get(id)
    if token == nil {
        // This is an error as every identity should have a token by the time Get
        // is called.
        return nil, fmt.Errorf("unable to find token for workload %q and identity %q", id.WorkloadIdentifier, id.IdentityName)
    }

    return token, nil
}

func (m *WIDMgr) get(id structs.WIHandle) *structs.SignedWorkloadIdentity {
    m.lastTokenLock.RLock()
    defer m.lastTokenLock.RUnlock()

    return m.lastToken[id]
}

// Watch returns a channel that sends new signed identities until it is closed
// due to shutdown. Must be called after Run.
//
// The caller must call the returned func to stop watching, and must ensure the
// watched id actually exists; otherwise the channel never returns a result.
func (m *WIDMgr) Watch(id structs.WIHandle) (<-chan *structs.SignedWorkloadIdentity, func()) {
    // If Shutdown has been called return a closed chan
    if m.stopCtx.Err() != nil {
        c := make(chan *structs.SignedWorkloadIdentity)
        close(c)
        return c, func() {}
    }

    m.watchersLock.Lock()
    defer m.watchersLock.Unlock()

    // Buffer of 1 so sends don't block on receives
    c := make(chan *structs.SignedWorkloadIdentity, 1)

    // Append rather than overwrite, so an identity can have multiple watchers
    m.watchers[id] = append(m.watchers[id], c)

    // Create a cancel func for watchers to deregister when they exit.
    cancel := func() {
        m.watchersLock.Lock()
        defer m.watchersLock.Unlock()
        m.watchers[id] = slices.DeleteFunc(
            m.watchers[id],
            func(ch chan *structs.SignedWorkloadIdentity) bool { return ch == c },
        )
    }

    // Prime chan with latest token to avoid a race condition where consumers
    // could miss a token update between Get and Watch calls.
    if token := m.get(id); token != nil {
        c <- token
    }

    return c, cancel
}

// Shutdown stops renewal and closes all watch chans.
func (m *WIDMgr) Shutdown() {
    m.watchersLock.Lock()
    defer m.watchersLock.Unlock()

    m.stop()

    for _, w := range m.watchers {
        for _, c := range w {
            close(c)
        }
    }

    // ensure it's safe to call Shutdown multiple times
    m.watchers = map[structs.WIHandle][]chan *structs.SignedWorkloadIdentity{}
}

// restoreStoredIdentities recreates the state of the WIDMgr from a previously
// saved state, so that we can avoid asking for all identities again after a
// client agent restart. It returns true if the caller should immediately call
// getInitialIdentities because one or more of the identities has expired.
func (m *WIDMgr) restoreStoredIdentities() (bool, error) {
    storedIdentities, err := m.db.GetAllocIdentities(m.allocID)
    if err != nil {
        return true, err
    }
    if len(storedIdentities) == 0 {
        return true, nil
    }

    m.lastTokenLock.Lock()
    defer m.lastTokenLock.Unlock()

    var hasExpired bool
    for _, identity := range storedIdentities {
        if !identity.Expiration.IsZero() && identity.Expiration.Before(time.Now()) {
            hasExpired = true
        }
        m.lastToken[identity.WIHandle] = identity
    }

    return hasExpired, nil
}

// SignForTesting signs all the identities in m.widSpecs, typically with the
// mock signer. This should only be used for testing downstream hooks.
func (m *WIDMgr) SignForTesting() {
    m.getInitialIdentities()
}

// getInitialIdentities fetches all signed identities or returns an error. It
// should be run once when the WIDMgr first runs.
func (m *WIDMgr) getInitialIdentities() error {
    // get the default identity signed by the plan applier
    defaultTokens := map[structs.WIHandle]*structs.SignedWorkloadIdentity{}
    for taskName, signature := range m.defaultSignedIdentities {
        id := structs.WIHandle{
            WorkloadIdentifier: taskName,
            IdentityName:       "default",
        }
        widReq := structs.WorkloadIdentityRequest{
            AllocID:  m.allocID,
            WIHandle: id,
        }
        defaultTokens[id] = &structs.SignedWorkloadIdentity{
            WorkloadIdentityRequest: widReq,
            JWT:                     signature,
            Expiration:              time.Time{},
        }
    }

    if len(m.widSpecs) == 0 && len(defaultTokens) == 0 {
        return nil
    }

    m.lastTokenLock.Lock()
    defer m.lastTokenLock.Unlock()

    reqs := make([]*structs.WorkloadIdentityRequest, 0, len(m.widSpecs))
    for wiHandle := range m.widSpecs {
        reqs = append(reqs, &structs.WorkloadIdentityRequest{
            AllocID:  m.allocID,
            WIHandle: wiHandle,
        })
    }

    // Get signed workload identities
    signedWIDs := []*structs.SignedWorkloadIdentity{}
    if len(m.widSpecs) != 0 {
        var err error
        signedWIDs, err = m.signer.SignIdentities(m.minIndex, reqs)
        if err != nil {
            return err
        }
    }

    // Store default identity tokens
    for id, token := range defaultTokens {
        m.lastToken[id] = token
    }

    // Index initial workload identities by name
    for _, swid := range signedWIDs {
        m.lastToken[swid.WIHandle] = swid
    }

    return m.db.PutAllocIdentities(m.allocID, signedWIDs)
}

// renew fetches new signed workload identity tokens before the existing tokens
// expire.
func (m *WIDMgr) renew() {
    if len(m.widSpecs) == 0 {
        return
    }

    reqs := make([]*structs.WorkloadIdentityRequest, 0, len(m.widSpecs))
    for workloadHandle, widspec := range m.widSpecs {
        if widspec.TTL == 0 {
            continue
        }
        reqs = append(reqs, &structs.WorkloadIdentityRequest{
            AllocID:  m.allocID,
            WIHandle: workloadHandle,
        })
    }

    if len(reqs) == 0 {
        m.logger.Trace("no workload identities expire")
        return
    }

    renewNow := false
    minExp := time.Time{}

    for workloadHandle, wid := range m.widSpecs {
        if wid.TTL == 0 {
            // No TTL, so no need to renew it
            continue
        }

        token := m.get(workloadHandle)
        if token == nil {
            // Missing a signature, treat this case as already expired so
            // we get a token ASAP
            m.logger.Debug("missing token for identity", "identity", wid.Name)
            renewNow = true
            continue
        }

        if minExp.IsZero() || token.Expiration.Before(minExp) {
            minExp = token.Expiration
        }
    }

    var wait time.Duration
    if !renewNow {
        wait = helper.ExpiryToRenewTime(minExp, time.Now, m.minWait)
    }

    timer, timerStop := helper.NewStoppedTimer()
    defer timerStop()

    var retry uint64

    for {
        // we need to handle stopCtx.Err() and manually stop the subscribers
        if err := m.stopCtx.Err(); err != nil {
            // close watchers and shutdown
            m.Shutdown()
            return
        }

        m.logger.Debug("waiting to renew identities", "num", len(reqs), "wait", wait)
        timer.Reset(wait)
        select {
        case <-timer.C:
            m.logger.Trace("getting new signed identities", "num", len(reqs))
        case <-m.stopCtx.Done():
            // close watchers and shutdown
            m.Shutdown()
            return
        }

        // Renew all tokens together since it's cheap
        tokens, err := m.signer.SignIdentities(m.minIndex, reqs)
        if err != nil {
            retry++
            wait = helper.Backoff(m.minWait, time.Hour, retry) + helper.RandomStagger(m.minWait)
            m.logger.Error("error renewing workload identities", "error", err, "next", wait)
            continue
        }

        if len(tokens) == 0 {
            retry++
            wait = helper.Backoff(m.minWait, time.Hour, retry) + helper.RandomStagger(m.minWait)
            m.logger.Error("error renewing workload identities", "error", "no tokens", "next", wait)
            continue
        }

        // Reset next expiration time
        minExp = time.Time{}

        for _, token := range tokens {
            // Set for getters
            m.lastTokenLock.Lock()
            m.lastToken[token.WIHandle] = token
            m.lastTokenLock.Unlock()

            // Send to watchers
            m.watchersLock.Lock()
            m.send(token.WIHandle, token)
            m.watchersLock.Unlock()

            // Set next expiration time
            if minExp.IsZero() || token.Expiration.Before(minExp) {
                minExp = token.Expiration
            }
        }

        // Success! Set next renewal and reset retries
        wait = helper.ExpiryToRenewTime(minExp, time.Now, m.minWait)
        retry = 0
    }
}

// send must be called while holding the m.watchersLock
func (m *WIDMgr) send(id structs.WIHandle, token *structs.SignedWorkloadIdentity) {
    w, ok := m.watchers[id]
    if !ok {
        // No watchers
        return
    }

    for _, c := range w {
        // Pop any unreceived tokens
        select {
        case <-c:
        default:
        }

        // Send new token, should never block since this is the only sender and
        // watchersLock is held
        c <- token
    }
}
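
// watchIdentityExample is an illustrative usage sketch, not part of the
// manager itself: it shows how a caller such as an identity hook is expected
// to drive the IdentityManager interface (Run once, then Watch for renewals,
// and cancel the watch when done). The "web"/"default" handle and the way the
// JWT is consumed are hypothetical.
func watchIdentityExample(mgr IdentityManager) error {
    // Run blocks until every identity has an initial signature, then keeps
    // renewing them in a background goroutine until Shutdown is called.
    if err := mgr.Run(); err != nil {
        return err
    }

    // A handle names a workload (task or service) and one of its identities.
    handle := structs.WIHandle{
        WorkloadIdentifier: "web",
        IdentityName:       "default",
    }

    // Watch primes the channel with the latest token so the caller cannot
    // miss a renewal between Get and Watch. The cancel func deregisters the
    // watcher; the channel itself is closed by Shutdown.
    tokens, cancel := mgr.Watch(handle)
    defer cancel()

    for token := range tokens {
        _ = token.JWT // hand each renewed JWT to the consumer
    }
    return nil
}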