Reuse token if it exists on client reconnect (#26604)

Currently, every time a client starts, it creates a new Consul token per service or task. This PR changes that behaviour: it persists the Consul ACL tokens to the client state, and on start it looks up a stored token before creating a new one.

Fixes: #20184
Fixes: #20185
This commit is contained in:
Juana De La Cuesta
2025-09-04 15:27:57 +02:00
committed by GitHub
parent 3ad22ddad5
commit 2944a34b58
11 changed files with 308 additions and 65 deletions

2
.changelog/26604.txt Normal file
View File

@@ -0,0 +1,2 @@
```release-note:bug
consul: Fixed a bug where restarting the Nomad agent would cause Consul ACL tokens to be recreated

View File

@@ -119,6 +119,7 @@ func (ar *allocRunner) initRunnerHooks(config *clientconfig.Config) error {
consulClientConstructor: consul.NewConsulClientFactory(config),
hookResources: ar.hookResources,
logger: hookLogger,
db: ar.stateDB,
}),
newUpstreamAllocsHook(hookLogger, ar.prevAllocWatcher),
newDiskMigrationHook(hookLogger, ar.prevAllocMigrator, ar.allocDir),

View File

@@ -5,6 +5,8 @@ package allocrunner
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
consulapi "github.com/hashicorp/consul/api"
@@ -13,6 +15,7 @@ import (
"github.com/hashicorp/nomad/client/allocdir"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
"github.com/hashicorp/nomad/client/consul"
cstate "github.com/hashicorp/nomad/client/state"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/client/taskenv"
"github.com/hashicorp/nomad/client/widmgr"
@@ -26,7 +29,7 @@ type consulHook struct {
widmgr widmgr.IdentityManager
consulConfigs map[string]*structsc.ConsulConfig
consulClientConstructor consul.ConsulClientFunc
hookResources *cstructs.AllocHookResources
resourcesBackend *resourcesBackend
logger log.Logger
shutdownCtx context.Context
@@ -37,6 +40,7 @@ type consulHookConfig struct {
alloc *structs.Allocation
allocdir allocdir.Interface
widmgr widmgr.IdentityManager
db cstate.StateDB
// consulConfigs is a map of cluster names to Consul configs
consulConfigs map[string]*structsc.ConsulConfig
@@ -58,7 +62,7 @@ func newConsulHook(cfg consulHookConfig) *consulHook {
widmgr: cfg.widmgr,
consulConfigs: cfg.consulConfigs,
consulClientConstructor: cfg.consulClientConstructor,
hookResources: cfg.hookResources,
resourcesBackend: newResourcesBackend(cfg.alloc.ID, cfg.hookResources, cfg.db),
shutdownCtx: shutdownCtx,
shutdownCancelFn: shutdownCancelFn,
}
@@ -89,7 +93,10 @@ func (h *consulHook) Prerun(allocEnv *taskenv.TaskEnv) error {
}
// tokens are a map of Consul cluster to identity name to Consul ACL token.
tokens := map[string]map[string]*consulapi.ACLToken{}
tokens, err := h.resourcesBackend.loadAllocTokens()
if err != nil {
h.logger.Error("error reading stored ACL tokens", "error", err)
}
tg := job.LookupTaskGroup(h.alloc.TaskGroup)
if tg == nil { // this is always a programming error
@@ -117,7 +124,9 @@ func (h *consulHook) Prerun(allocEnv *taskenv.TaskEnv) error {
}
// write the tokens to hookResources
h.hookResources.SetConsulTokens(tokens)
if err := h.resourcesBackend.setConsulTokens(tokens); err != nil {
h.logger.Error("unable to update tokens in state", "error", err)
}
return nil
}
@@ -143,41 +152,44 @@ func (h *consulHook) prepareConsulTokensForTask(task *structs.Task, tg *structs.
return nil
}
// Find signed workload identity.
ti := *task.IdentityHandle(wid)
jwt, err := h.widmgr.Get(ti)
if err != nil {
return fmt.Errorf("error getting signed identity for task %s: %v", task.Name, err)
}
tokenName := widName + "/" + task.Name
token := tokens[clusterName][tokenName]
// Derive token for task.
req := consul.JWTLoginRequest{
JWT: jwt.JWT,
AuthMethodName: consulConfig.TaskIdentityAuthMethod,
Meta: map[string]string{
"requested_by": fmt.Sprintf("nomad_task_%s", task.Name),
},
}
token, err := h.getConsulToken(consulConfig.Name, req)
if err != nil {
return fmt.Errorf("failed to derive Consul token for task %s: %v", task.Name, err)
// If no token was previously stored, create one.
if token == nil {
// Find signed workload identity.
ti := *task.IdentityHandle(wid)
swi, err := h.widmgr.Get(ti)
if err != nil {
return fmt.Errorf("error getting signed identity for task %s: %v", task.Name, err)
}
h.logger.Debug("logging into consul", "name", ti.IdentityName, "type", ti.WorkloadType)
req := consul.JWTLoginRequest{
JWT: swi.JWT,
AuthMethodName: consulConfig.TaskIdentityAuthMethod,
Meta: map[string]string{
"requested_by": fmt.Sprintf("nomad_task_%s", task.Name),
},
}
token, err = h.getConsulToken(consulConfig.Name, req)
if err != nil {
return fmt.Errorf("failed to derive Consul token for task %s: %v", task.Name, err)
}
}
// Store token in results.
if _, ok = tokens[clusterName]; !ok {
tokens[clusterName] = make(map[string]*consulapi.ACLToken)
}
tokenName := widName + "/" + task.Name
tokens[clusterName][tokenName] = token
return nil
}
func (h *consulHook) prepareConsulTokensForServices(services []*structs.Service, tg *structs.TaskGroup, tokens map[string]map[string]*consulapi.ACLToken, env *taskenv.TaskEnv) error {
if len(services) == 0 {
return nil
}
var mErr *multierror.Error
for _, service := range services {
// Exit early if service doesn't need a Consul token.
@@ -192,38 +204,47 @@ func (h *consulHook) prepareConsulTokensForServices(services []*structs.Service,
}
// Find signed identity workload.
handle := *service.IdentityHandle(env.ReplaceEnv)
jwt, err := h.widmgr.Get(handle)
if err != nil {
mErr = multierror.Append(mErr, fmt.Errorf(
"error getting signed identity for service %s: %v",
service.Name, err,
))
continue
}
ti := *service.IdentityHandle(env.ReplaceEnv)
tokenName := service.Identity.Name
token := tokens[clusterName][tokenName]
// If no token was previously stored, create one.
if token == nil {
swi, err := h.widmgr.Get(ti)
if err != nil {
mErr = multierror.Append(mErr, fmt.Errorf(
"error getting signed identity for service %s: %v",
service.Name, err,
))
continue
}
h.logger.Debug("logging into consul", "name", ti.IdentityName, "type", ti.WorkloadType)
req := consul.JWTLoginRequest{
JWT: swi.JWT,
AuthMethodName: consulConfig.ServiceIdentityAuthMethod,
Meta: map[string]string{
"requested_by": fmt.Sprintf("nomad_service_%s", ti.InterpolatedWorkloadIdentifier),
},
}
token, err = h.getConsulToken(clusterName, req)
if err != nil {
mErr = multierror.Append(mErr, fmt.Errorf(
"failed to derive Consul token for service %s: %v",
service.Name, err,
))
continue
}
// Derive token for service.
req := consul.JWTLoginRequest{
JWT: jwt.JWT,
AuthMethodName: consulConfig.ServiceIdentityAuthMethod,
Meta: map[string]string{
"requested_by": fmt.Sprintf("nomad_service_%s", handle.InterpolatedWorkloadIdentifier),
},
}
token, err := h.getConsulToken(clusterName, req)
if err != nil {
mErr = multierror.Append(mErr, fmt.Errorf(
"failed to derive Consul token for service %s: %v",
service.Name, err,
))
continue
}
// Store token in results.
if _, ok = tokens[clusterName]; !ok {
tokens[clusterName] = make(map[string]*consulapi.ACLToken)
}
tokens[clusterName][service.Identity.Name] = token
tokens[clusterName][tokenName] = token
}
return mErr.ErrorOrNil()
@@ -254,13 +275,7 @@ func (h *consulHook) clientForCluster(cluster string) (consul.Client, error) {
// Postrun cleans up the Consul tokens after the tasks have exited.
func (h *consulHook) Postrun() error {
tokens := h.hookResources.GetConsulTokens()
err := h.revokeTokens(tokens)
if err != nil {
return err
}
h.hookResources.SetConsulTokens(tokens)
return nil
return h.Destroy()
}
// Shutdown will get called when the client is gracefully stopping.
@@ -271,12 +286,13 @@ func (h *consulHook) Shutdown() {
// Destroy cleans up any remaining Consul tokens if the alloc is GC'd or fails
// to restore after a client restart.
//
// The tokens currently held in the hook resources are revoked. If revocation
// fails the error is returned and neither the hook resources nor the client
// state are updated. Otherwise the token map is written back through the
// resources backend so the in-memory hook resources and the client state DB
// stay in sync.
func (h *consulHook) Destroy() error {
	tokens := h.resourcesBackend.getConsulTokens()

	if err := h.revokeTokens(tokens); err != nil {
		return err
	}

	// A failure to persist is logged rather than returned, matching how
	// Prerun treats the same call: the revocation itself has already
	// succeeded at this point.
	if err := h.resourcesBackend.setConsulTokens(tokens); err != nil {
		h.logger.Error("unable to update tokens in state", "error", err)
	}

	return nil
}
@@ -307,3 +323,99 @@ func (h *consulHook) revokeTokens(tokens map[string]map[string]*consulapi.ACLTok
return mErr.ErrorOrNil()
}
// resourcesBackend keeps the Consul ACL tokens of a single allocation in two
// places: the in-memory alloc hook resources and the client state DB.
type resourcesBackend struct {
// allocID is the ID of the allocation the tokens are stored under in the
// state DB
allocID string
// hookResources holds the in-memory copy of the tokens
hookResources *cstructs.AllocHookResources
// db is the client state DB the tokens are persisted to and loaded from
db cstate.StateDB
}
// newResourcesBackend returns a resourcesBackend for the allocation allocID
// that uses hr as its in-memory store and db as its persistent store.
func newResourcesBackend(allocID string, hr *cstructs.AllocHookResources, db cstate.StateDB) *resourcesBackend {
	rb := new(resourcesBackend)
	rb.allocID = allocID
	rb.hookResources = hr
	rb.db = db
	return rb
}
// decodeACLToken base64-decodes b64ACLToken and unmarshals the resulting JSON
// into token. An input that decodes to zero bytes is not an error and leaves
// token unmodified.
func decodeACLToken(b64ACLToken string, token *consulapi.ACLToken) error {
	raw, err := base64.StdEncoding.DecodeString(b64ACLToken)
	if err != nil {
		return fmt.Errorf("unable to process ACLToken: %w", err)
	}

	// Nothing to unmarshal
	if len(raw) == 0 {
		return nil
	}

	if err = json.Unmarshal(raw, token); err != nil {
		return fmt.Errorf("unable to unmarshal ACLToken: %w", err)
	}
	return nil
}
// encodeACLToken marshals token to JSON and returns the base64 (standard
// encoding) form of that JSON. It is the inverse of decodeACLToken.
func encodeACLToken(token *consulapi.ACLToken) (string, error) {
	payload, err := json.Marshal(token)
	if err != nil {
		return "", fmt.Errorf("unable to marshal ACL token: %w", err)
	}

	encoded := base64.StdEncoding.EncodeToString(payload)
	return encoded, nil
}
// loadAllocTokens reads the ACL tokens persisted for this alloc from the state
// DB and rebuilds the map of Consul cluster to token ID to ACL token.
//
// The returned map is never nil, even when an error is returned. If the DB
// read fails the map is empty. Entries that cannot be decoded are left out of
// the map and their errors are combined into the returned error.
func (rs *resourcesBackend) loadAllocTokens() (map[string]map[string]*consulapi.ACLToken, error) {
	result := make(map[string]map[string]*consulapi.ACLToken)

	stored, err := rs.db.GetAllocConsulACLTokens(rs.allocID)
	if err != nil {
		return result, err
	}

	var decodeErrs *multierror.Error
	for _, entry := range stored {
		decoded := new(consulapi.ACLToken)
		if err := decodeACLToken(entry.ACLToken, decoded); err != nil {
			decodeErrs = multierror.Append(decodeErrs, err)
			continue
		}

		// Lazily create the per-cluster map on first use
		byID := result[entry.Cluster]
		if byID == nil {
			byID = make(map[string]*consulapi.ACLToken)
			result[entry.Cluster] = byID
		}
		byID[entry.TokenID] = decoded
	}

	return result, decodeErrs.ErrorOrNil()
}
// setConsulTokens stores m, a map of Consul cluster to token ID to ACL token,
// in the in-memory hook resources and then persists it to the client state DB
// under this backend's alloc ID.
//
// The in-memory update always happens. A token that cannot be encoded is
// skipped so the remaining tokens are still persisted. Every encoding error,
// along with any error from the DB write, is combined into the returned
// error. The result is nil only if every token was encoded and the write
// succeeded.
func (rs *resourcesBackend) setConsulTokens(m map[string]map[string]*consulapi.ACLToken) error {
	rs.hookResources.SetConsulTokens(m)

	var mErr *multierror.Error
	ts := []*cstructs.ConsulACLToken{}
	for cCluster, tokens := range m {
		for tokenID, aclToken := range tokens {
			stringToken, err := encodeACLToken(aclToken)
			if err != nil {
				mErr = multierror.Append(mErr, err)
				continue
			}

			ts = append(ts, &cstructs.ConsulACLToken{
				Cluster:  cCluster,
				TokenID:  tokenID,
				ACLToken: stringToken,
			})
		}
	}

	// Report the write failure together with any encoding failures instead
	// of dropping the latter.
	if err := rs.db.PutAllocConsulACLTokens(rs.allocID, ts); err != nil {
		mErr = multierror.Append(mErr, err)
	}

	return mErr.ErrorOrNil()
}
// getConsulTokens returns the tokens currently held in the in-memory hook
// resources. It does not read from the state DB.
func (rs *resourcesBackend) getConsulTokens() map[string]map[string]*consulapi.ACLToken {
return rs.hookResources.GetConsulTokens()
}

View File

@@ -76,6 +76,7 @@ func consulHookTestHarness(t *testing.T) *consulHook {
consulConfigs: consulConfigs,
consulClientConstructor: consul.NewMockConsulClient,
hookResources: hookResources,
db: db,
logger: logger,
}
return newConsulHook(consulHookCfg)
@@ -263,7 +264,7 @@ func Test_consulHook_Postrun(t *testing.T) {
task := hook.alloc.LookupTask("web")
tokens := map[string]map[string]*consulapi.ACLToken{}
must.NoError(t, hook.prepareConsulTokensForTask(task, nil, tokens))
hook.hookResources.SetConsulTokens(tokens)
hook.resourcesBackend.setConsulTokens(tokens)
must.MapLen(t, 1, tokens)
// gracefully handle wrong tokens
@@ -273,6 +274,6 @@ func Test_consulHook_Postrun(t *testing.T) {
// hook resources should be cleared
must.NoError(t, hook.Postrun())
tokens = hook.hookResources.GetConsulTokens()
tokens = hook.resourcesBackend.getConsulTokens()
must.MapEmpty(t, tokens["default"])
}

View File

@@ -36,7 +36,8 @@ allocations/
|--> network_status -> networkStatusEntry{*structs.AllocNetworkStatus}
|--> acknowledged_state -> acknowledgedStateEntry{*arstate.State}
|--> alloc_volumes -> allocVolumeStatesEntry{arstate.AllocVolumes}
|--> identities -> allocIdentitiesEntry{}
|--> alloc_identities -> allocIdentitiesEntry{}
|--> alloc_consul_acl_token_identities -> allocConsulACLTokenEntry{}
|--> task-<name>/
|--> local_state -> *trstate.LocalState # Local-only state
|--> task_state -> *structs.TaskState # Syncs to servers
@@ -100,6 +101,10 @@ var (
// under
allocIdentityKey = []byte("alloc_identities")
// allocConsulACLTokenKey is the key []*cstructs.ConsulACLToken is stored
// under
allocConsulACLTokenKey = []byte("alloc_consul_acl_token_identities")
// checkResultsBucket is the bucket name in which check query results are stored
checkResultsBucket = []byte("check_results")
@@ -576,6 +581,55 @@ func (s *BoltStateDB) GetAllocIdentities(allocID string) ([]*structs.SignedWorkl
return entry.Identities, nil
}
// allocConsulACLTokenEntry wraps the Consul ACL tokens so we can safely add
// more state in the future without needing a new entry type
type allocConsulACLTokenEntry struct {
// Tokens are all the Consul ACL tokens stored for one alloc
Tokens []*cstructs.ConsulACLToken
}
// PutAllocConsulACLTokens stores all Consul ACL tokens for an alloc,
// overwriting whatever was previously stored under the same key. It returns
// an error if the alloc's bucket cannot be obtained or the write fails.
func (s *BoltStateDB) PutAllocConsulACLTokens(allocID string, tokens []*cstructs.ConsulACLToken, opts ...WriteOption) error {
	write := func(tx *boltdd.Tx) error {
		bkt, err := getAllocationBucket(tx, allocID)
		if err != nil {
			return err
		}

		return bkt.Put(allocConsulACLTokenKey, &allocConsulACLTokenEntry{Tokens: tokens})
	}

	return s.updateWithOptions(opts, write)
}
// GetAllocConsulACLTokens returns all Consul ACL tokens for an alloc. A nil
// slice and nil error are returned when there is no allocations bucket, no
// bucket for this alloc, or no tokens entry in it.
func (s *BoltStateDB) GetAllocConsulACLTokens(allocID string) ([]*cstructs.ConsulACLToken, error) {
	var stored allocConsulACLTokenEntry

	err := s.db.View(func(tx *boltdd.Tx) error {
		allocs := tx.Bucket(allocationsBucketName)
		if allocs == nil {
			// No previous state at all
			return nil
		}

		bkt := allocs.Bucket([]byte(allocID))
		if bkt == nil {
			// No previous state for this alloc
			return nil
		}

		return bkt.Get(allocConsulACLTokenKey, &stored)
	})

	switch {
	case boltdd.IsErrNotFound(err):
		// There may not be any previously created tokens
		return nil, nil
	case err != nil:
		return nil, err
	default:
		return stored.Tokens, nil
	}
}
// GetTaskRunnerState returns the LocalState and TaskState for a
// TaskRunner. LocalState or TaskState will be nil if they do not exist.
//

View File

@@ -173,6 +173,14 @@ func (m *ErrDB) Close() error {
return fmt.Errorf("Error!")
}
// PutAllocConsulACLTokens always fails; its arguments are ignored.
func (m *ErrDB) PutAllocConsulACLTokens(allocID string, tokens []*cstructs.ConsulACLToken, opts ...WriteOption) error {
return fmt.Errorf("Error!")
}
// GetAllocConsulACLTokens always fails, returning no tokens; its argument is
// ignored.
func (m *ErrDB) GetAllocConsulACLTokens(allocID string) ([]*cstructs.ConsulACLToken, error) {
return nil, fmt.Errorf("Error!")
}
func (m *ErrDB) PutNodeIdentity(_ string) error { return ErrDBError }
func (m *ErrDB) GetNodeIdentity() (string, error) { return "", ErrDBError }

View File

@@ -47,6 +47,9 @@ type MemDB struct {
// alloc_id -> []identities
identities map[string][]*structs.SignedWorkloadIdentity
// alloc_id -> []consulAclTokens
consulACLTokens map[string][]*cstructs.ConsulACLToken
// devicemanager -> plugin-state
devManagerPs *dmstate.PluginState
@@ -82,6 +85,7 @@ func NewMemDB(logger hclog.Logger) *MemDB {
taskState: make(map[string]map[string]*structs.TaskState),
checks: make(checks.ClientResults),
identities: make(map[string][]*structs.SignedWorkloadIdentity),
consulACLTokens: make(map[string][]*cstructs.ConsulACLToken),
dynamicHostVolumes: make(map[string]*cstructs.HostVolumeState),
clientIdentity: atomic.Value{},
logger: logger,
@@ -180,6 +184,20 @@ func (m *MemDB) GetAllocIdentities(allocID string) ([]*structs.SignedWorkloadIde
return m.identities[allocID], nil
}
// PutAllocConsulACLTokens replaces the in-memory Consul ACL tokens stored for
// allocID with tokens. The slice is stored as-is (not copied), opts are
// ignored, and the returned error is always nil.
func (m *MemDB) PutAllocConsulACLTokens(allocID string, tokens []*cstructs.ConsulACLToken, opts ...WriteOption) error {
m.mu.Lock()
defer m.mu.Unlock()
m.consulACLTokens[allocID] = tokens
return nil
}
// GetAllocConsulACLTokens returns the Consul ACL tokens stored in memory for
// allocID, or nil if none have been stored. The returned error is always nil.
func (m *MemDB) GetAllocConsulACLTokens(allocID string) ([]*cstructs.ConsulACLToken, error) {
	// This only reads the map, so a read lock is enough and lets concurrent
	// readers proceed.
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.consulACLTokens[allocID], nil
}
func (m *MemDB) GetTaskRunnerState(allocID string, taskName string) (*state.LocalState, *structs.TaskState, error) {
m.mu.RLock()
defer m.mu.RUnlock()

View File

@@ -157,9 +157,21 @@ func (n NoopDB) DeleteDynamicHostVolume(_ string) error {
return nil
}
func (n NoopDB) PutNodeIdentity(_ string) error { return nil }
func (n NoopDB) PutNodeIdentity(_ string) error {
return nil
}
func (n NoopDB) GetNodeIdentity() (string, error) { return "", nil }
func (n NoopDB) GetNodeIdentity() (string, error) {
return "", nil
}
// PutAllocConsulACLTokens stores nothing and always returns nil.
func (n NoopDB) PutAllocConsulACLTokens(allocID string, tokens []*cstructs.ConsulACLToken, opts ...WriteOption) error {
return nil
}
// GetAllocConsulACLTokens always returns no tokens and a nil error.
func (n NoopDB) GetAllocConsulACLTokens(allocID string) ([]*cstructs.ConsulACLToken, error) {
return nil, nil
}
func (n NoopDB) Close() error {
return nil

View File

@@ -511,6 +511,32 @@ func TestStateDB_NodeIdentity(t *testing.T) {
})
}
// TestStateDB_ConsulACLToken asserts that an alloc starts with no stored
// Consul ACL tokens and that tokens written with PutAllocConsulACLTokens are
// returned unchanged by GetAllocConsulACLTokens.
func TestStateDB_ConsulACLToken(t *testing.T) {
	ci.Parallel(t)

	testDB(t, func(t *testing.T, db StateDB) {
		alloc := mock.Alloc()
		must.NoError(t, db.PutAllocation(alloc))

		// Nothing has been stored for the alloc yet
		got, err := db.GetAllocConsulACLTokens(alloc.ID)
		must.NoError(t, err)
		must.Eq(t, nil, got)

		want := &cstructs.ConsulACLToken{
			Cluster:  "fake cluster",
			TokenID:  "workloadID",
			ACLToken: "token",
		}
		stored := []*cstructs.ConsulACLToken{want}
		must.NoError(t, db.PutAllocConsulACLTokens(alloc.ID, stored))

		// The stored token must round-trip
		got, err = db.GetAllocConsulACLTokens(alloc.ID)
		must.NoError(t, err)
		must.One(t, len(got))
		must.Eq(t, want, got[0])
	})
}
// TestStateDB_Upgrade asserts calling Upgrade on new databases always
// succeeds.
func TestStateDB_Upgrade(t *testing.T) {

View File

@@ -152,6 +152,9 @@ type StateDB interface {
// Close the database. Unsafe for further use after calling regardless
// of return value.
Close() error
PutAllocConsulACLTokens(allocID string, tokens []*cstructs.ConsulACLToken, opts ...WriteOption) error
GetAllocConsulACLTokens(allocID string) ([]*cstructs.ConsulACLToken, error)
}
// WriteOptions adjusts the way the data is persisted by the StateDB above. Default is

View File

@@ -409,3 +409,9 @@ var DriverStatsNotImplemented = errors.New("stats not implemented for driver")
type NodeRegistration struct {
HasRegistered bool
}
// ConsulACLToken is the form in which a single Consul ACL token created for
// an allocation is persisted in the client state.
type ConsulACLToken struct {
// Cluster is the name of the Consul cluster the token belongs to
Cluster string
// TokenID is the key the token is stored under within its cluster. The
// allocrunner consul hook uses the identity name for services and
// "<identity name>/<task name>" for tasks.
TokenID string
// ACLToken is the serialized token. The allocrunner consul hook writes it
// as the base64-encoded JSON of a Consul API ACLToken.
ACLToken string
}