From 26c3f19129e7a08cff755e1a239542791b3b3fcb Mon Sep 17 00:00:00 2001
From: James Rasell
Date: Wed, 18 Jun 2025 07:43:27 +0100
Subject: [PATCH 1/7] identity: Base server objects and mild rework of identity
 implementation to support node identities. (#26052)

When Nomad generates an identity for a node, the ID of the root key used to
sign the JWT will be stored as a field on the node object and written to state.
To provide fast lookup of nodes by their signing key, the node table schema has
been modified to include the key ID as an index.

To ensure a root key is not deleted while identities signed by it are still in
use, the Nomad state store has an in-use check. This check has been extended to
cover node identities.

Nomad node identities will have an expiration, defined by a TTL configured
within the node pool specification as a time duration. When not supplied by the
operator, a default value of 24 hours is applied. On cluster upgrades, a Nomad
server will restore from snapshot and/or replay logs, so the FSM has been
modified to ensure restored node pool objects include the default value. The
built-in "all" and "default" pools have also been updated to include this
default value.

Node identities are a new identity concept that will exist alongside workload
identities. This change introduces a new envelope identity claim which contains
generic public claims as well as either node or workload identity claims. This
allows us to use a single encryption and decryption path, regardless of the
underlying identity type. Where possible, node and workload identities use
common functions for identity claim generation.

The new node identity has the following claims:

* "nomad_node_id" - the node ID, which is typically generated as a UUID on the
  first boot of the Nomad client within the "ensureNodeID" function.
* "nomad_node_pool" - the node pool, a client configuration parameter which
  provides logical grouping of Nomad clients.
* "nomad_node_class" - the node class, a client configuration parameter which
  provides scheduling constraints for Nomad clients.
* "nomad_node_datacenter" - the node datacenter, a client configuration
  parameter which provides scheduling constraints for Nomad clients and a
  logical grouping method.
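As a rough usage sketch (not part of this patch), the server-side helpers added
here are expected to fit together as follows. The types and functions
(structs.Node, GenerateNodeIdentityClaims, IsNode, IsWorkload, IsExpiring) come
from this change; the node values, region, and TTL passed in are illustrative
placeholders only:

    package main

    import (
        "fmt"
        "time"

        "github.com/hashicorp/nomad/nomad/structs"
    )

    func main() {
        // A registered client node; only the fields read when building node
        // identity claims are set here.
        node := &structs.Node{
            ID:         "0b62d797-6f24-4a3c-9344-7b2565b28d4a",
            NodePool:   "default",
            NodeClass:  "compute",
            Datacenter: "dc1",
        }

        // Build claims ready for signing by the server keyring. The TTL would
        // normally come from the node pool's NodeIdentityTTL, which defaults
        // to 24 hours.
        claims := structs.GenerateNodeIdentityClaims(node, "global", 24*time.Hour)

        // The envelope reports which identity type it carries.
        fmt.Println(claims.IsNode(), claims.IsWorkload()) // true false

        // The subject follows "node:<region>:<node_pool>:<node_id>:default".
        fmt.Println(claims.Subject)

        // IsExpiring reports whether the identity is expired or within one
        // third of the TTL of expiring, which a client could use to decide
        // when to request a fresh identity on heartbeat.
        fmt.Println(claims.IsExpiring(time.Now(), 24*time.Hour))
    }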
--- client/widmgr/mock.go | 12 +- nomad/auth/auth_test.go | 20 +- nomad/encrypter.go | 22 +- nomad/encrypter_test.go | 18 +- nomad/fsm.go | 16 ++ nomad/fsm_test.go | 72 +++++- nomad/mock/mock.go | 7 +- nomad/node_pool_endpoint.go | 3 + nomad/node_pool_endpoint_test.go | 29 +-- nomad/state/schema.go | 11 +- nomad/state/state_store_keyring.go | 9 + nomad/state/state_store_keyring_test.go | 116 +++++++++ nomad/state/state_store_node_pools.go | 10 +- nomad/state/state_store_node_pools_test.go | 18 +- nomad/structs/identity.go | 100 ++++++++ nomad/structs/identity_test.go | 232 ++++++++++++++++++ nomad/structs/node.go | 45 ++++ nomad/structs/node_pool.go | 27 ++- nomad/structs/node_pool_test.go | 34 +++ nomad/structs/node_test.go | 25 ++ nomad/structs/structs.go | 9 + nomad/structs/structs_test.go | 6 +- nomad/structs/workload_id.go | 84 +++---- nomad/structs/workload_id_test.go | 262 ++++++++++++--------- 24 files changed, 953 insertions(+), 234 deletions(-) create mode 100644 nomad/structs/identity.go create mode 100644 nomad/structs/identity_test.go diff --git a/client/widmgr/mock.go b/client/widmgr/mock.go index 84bdab454..37a947928 100644 --- a/client/widmgr/mock.go +++ b/client/widmgr/mock.go @@ -73,15 +73,17 @@ func (m *MockWIDSigner) JSONWebKeySet() *jose.JSONWebKeySet { } } -func (m *MockWIDSigner) SignIdentities(minIndex uint64, req []*structs.WorkloadIdentityRequest) ([]*structs.SignedWorkloadIdentity, error) { +func (m *MockWIDSigner) SignIdentities(_ uint64, req []*structs.WorkloadIdentityRequest) ([]*structs.SignedWorkloadIdentity, error) { swids := make([]*structs.SignedWorkloadIdentity, 0, len(req)) for _, idReq := range req { // Set test values for default claims claims := &structs.IdentityClaims{ - Namespace: "default", - JobID: "test", - AllocationID: idReq.AllocID, - TaskName: idReq.WorkloadIdentifier, + WorkloadIdentityClaims: &structs.WorkloadIdentityClaims{ + Namespace: "default", + JobID: "test", + AllocationID: idReq.AllocID, + TaskName: idReq.WorkloadIdentifier, + }, } claims.ID = uuid.Generate() // If test has set workload identities. 
Lookup claims or reject unknown diff --git a/nomad/auth/auth_test.go b/nomad/auth/auth_test.go index 5de234474..e56f50e62 100644 --- a/nomad/auth/auth_test.go +++ b/nomad/auth/auth_test.go @@ -1038,17 +1038,21 @@ func TestResolveClaims(t *testing.T) { dispatchAlloc.Job.ParentID = alloc.JobID claims := &structs.IdentityClaims{ - Namespace: alloc.Namespace, - JobID: alloc.Job.ID, - AllocationID: alloc.ID, - TaskName: alloc.Job.TaskGroups[0].Tasks[0].Name, + WorkloadIdentityClaims: &structs.WorkloadIdentityClaims{ + Namespace: alloc.Namespace, + JobID: alloc.Job.ID, + AllocationID: alloc.ID, + TaskName: alloc.Job.TaskGroups[0].Tasks[0].Name, + }, } dispatchClaims := &structs.IdentityClaims{ - Namespace: dispatchAlloc.Namespace, - JobID: dispatchAlloc.Job.ID, - AllocationID: dispatchAlloc.ID, - TaskName: dispatchAlloc.Job.TaskGroups[0].Tasks[0].Name, + WorkloadIdentityClaims: &structs.WorkloadIdentityClaims{ + Namespace: dispatchAlloc.Namespace, + JobID: dispatchAlloc.Job.ID, + AllocationID: dispatchAlloc.ID, + TaskName: dispatchAlloc.Job.TaskGroups[0].Tasks[0].Name, + }, } // unrelated policy diff --git a/nomad/encrypter.go b/nomad/encrypter.go index 4bf5b5a66..ab580031b 100644 --- a/nomad/encrypter.go +++ b/nomad/encrypter.go @@ -45,8 +45,6 @@ type claimSigner interface { SignClaims(*structs.IdentityClaims) (string, string, error) } -var _ claimSigner = &Encrypter{} - // Encrypter is the keyring for encrypting variables and signing workload // identities. type Encrypter struct { @@ -351,8 +349,8 @@ func (e *Encrypter) SignClaims(claims *structs.IdentityClaims) (string, string, return raw, cs.rootKey.Meta.KeyID, nil } -// VerifyClaim accepts a previously-signed encoded claim and validates -// it before returning the claim +// VerifyClaim accepts a previously signed encoded claim and validates +// it before returning the claim. func (e *Encrypter) VerifyClaim(tokenString string) (*structs.IdentityClaims, error) { token, err := jwt.ParseSigned(tokenString) @@ -377,21 +375,21 @@ func (e *Encrypter) VerifyClaim(tokenString string) (*structs.IdentityClaims, er return nil, err } + claims := structs.IdentityClaims{} + // Validate the claims. - claims := &structs.IdentityClaims{} - if err := token.Claims(typedPubKey, claims); err != nil { + if err := token.Claims(typedPubKey, &claims); err != nil { return nil, fmt.Errorf("invalid signature: %w", err) } - //COMPAT Until we can guarantee there are no pre-1.7 JWTs in use we can only - // validate the signature and have no further expectations of the - // claims. - expect := jwt.Expected{} - if err := claims.Validate(expect); err != nil { + // COMPAT: Until we can guarantee there are no pre-1.7 JWTs in use, we can + // only validate the signature and have no further expectations of the + // claims. 
+ if err := claims.Validate(jwt.Expected{}); err != nil { return nil, fmt.Errorf("invalid claims: %w", err) } - return claims, nil + return &claims, nil } // AddUnwrappedKey stores the key in the keystore and creates a new cipher for diff --git a/nomad/encrypter_test.go b/nomad/encrypter_test.go index 8fbd3e876..bcdfbd0b9 100644 --- a/nomad/encrypter_test.go +++ b/nomad/encrypter_test.go @@ -25,6 +25,7 @@ import ( "github.com/hashicorp/nomad/helper/pointer" "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/nomad/auth" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs/config" @@ -42,6 +43,13 @@ var ( } ) +// Assert that the Encrypter implements the claimSigner and auth.Encrypter +// interfaces. +var ( + _ claimSigner = &Encrypter{} + _ auth.Encrypter = &Encrypter{} +) + type mockSigner struct { calls []*structs.IdentityClaims @@ -697,10 +705,12 @@ func TestEncrypter_Upgrade17(t *testing.T) { // Create a 1.6 style workload identity claims := &structs.IdentityClaims{ - Namespace: "default", - JobID: "fakejob", - AllocationID: uuid.Generate(), - TaskName: "faketask", + WorkloadIdentityClaims: &structs.WorkloadIdentityClaims{ + Namespace: "default", + JobID: "fakejob", + AllocationID: uuid.Generate(), + TaskName: "faketask", + }, } // Sign the claims and assert they were signed with EdDSA (the 1.6 signing diff --git a/nomad/fsm.go b/nomad/fsm.go index 717b16cfd..758b8321f 100644 --- a/nomad/fsm.go +++ b/nomad/fsm.go @@ -586,6 +586,15 @@ func (n *nomadFSM) applyNodePoolUpsert(msgType structs.MessageType, buf []byte, panic(fmt.Errorf("failed to decode request: %v", err)) } + // Nomad 1.11 added the NodeIdentityTTL field to NodePool. When the + // cluster is upgraded, we need to ensure that the field is set with + // its default value. The hash also needs to be recalculated since it would + // have changed. + for _, pool := range req.NodePools { + pool.Canonicalize() + _ = pool.SetHash() + } + if err := n.state.UpsertNodePools(msgType, index, req.NodePools); err != nil { n.logger.Error("UpsertNodePool failed", "error", err) return err @@ -1900,6 +1909,13 @@ func (n *nomadFSM) restoreImpl(old io.ReadCloser, filter *FSMFilter) error { return err } + // Nomad 1.11 added the NodeIdentityTTL field to NodePool. When the + // cluster is upgraded, we need to ensure that the field is set with + // its default value. The hash also needs to be recalculated since + // it would have changed. + pool.Canonicalize() + _ = pool.SetHash() + // Perform the restoration. if err := restore.NodePoolRestore(pool); err != nil { return err diff --git a/nomad/fsm_test.go b/nomad/fsm_test.go index b75808596..e76e33730 100644 --- a/nomad/fsm_test.go +++ b/nomad/fsm_test.go @@ -245,7 +245,7 @@ func TestFSM_UpsertNode_NodePool(t *testing.T) { validateFn func(*testing.T, *structs.Node, *structs.NodePool) }{ { - name: "node with empty node pool is placed in defualt", + name: "node with empty node pool is placed in default", setupReqFn: func(req *structs.NodeRegisterRequest) { req.Node.NodePool = "" }, @@ -724,7 +724,31 @@ func TestFSM_NodePoolUpsert(t *testing.T) { structs.NodePool{}, "CreateIndex", "ModifyIndex", + "NodeIdentityTTL", + "Hash", ))) + + // Nomad 1.11 introduced the NodeIdentityTTL field for node pools. To test + // the upgrade path, we upsert a node pool without the TTL set which mimics + // a server applying a pre-1.11 object. 
+ preTTLNodePool := mock.NodePool() + preTTLNodePool.NodeIdentityTTL = 0 + preTTLNodePool.SetHash() + + req = structs.NodePoolUpsertRequest{ + NodePools: []*structs.NodePool{preTTLNodePool}, + } + buf, err = structs.Encode(structs.NodePoolUpsertRequestType, req) + must.NoError(t, err) + must.Nil(t, fsm.Apply(makeLog(buf))) + + // Verify the apply function set the NodeIdentityTTL to the default value + // and recalculated the hash. + ws = memdb.NewWatchSet() + preTTLNodePoolResp, err := fsm.State().NodePoolByName(ws, preTTLNodePool.Name) + must.NoError(t, err) + must.NonZero(t, preTTLNodePoolResp.NodeIdentityTTL) + must.NotEq(t, preTTLNodePool.Hash, preTTLNodePoolResp.Hash) } func TestFSM_RegisterJob(t *testing.T) { @@ -2273,18 +2297,52 @@ func TestFSM_SnapshotRestore_NodePools(t *testing.T) { ci.Parallel(t) // Add some state - fsm := testFSM(t) - state := fsm.State() + testFSM := testFSM(t) + testState := testFSM.State() pool := mock.NodePool() - state.UpsertNodePools(structs.MsgTypeTestSetup, 1000, []*structs.NodePool{pool}) + must.NoError(t, + testState.UpsertNodePools( + structs.MsgTypeTestSetup, + 1000, []*structs.NodePool{pool}, + )) // Verify the contents - fsm2 := testSnapshotRestore(t, fsm) - state2 := fsm2.State() - out, _ := state2.NodePoolByName(nil, pool.Name) + testFSM2 := testSnapshotRestore(t, testFSM) + testState2 := testFSM2.State() + out, err := testState2.NodePoolByName(nil, pool.Name) + must.NoError(t, err) must.Eq(t, pool, out) } +func TestFSM_SnapshotRestore_NodePoolsPreTTL(t *testing.T) { + ci.Parallel(t) + + testFSM := testFSM(t) + testState := testFSM.State() + + // Nomad 1.11 introduced the NodeIdentityTTL field for node pools. To test + // the upgrade path, we upsert a node pool without the TTL set which mimics + // a server restoring a pre-1.11 snapshot. + pool := mock.NodePool() + pool.NodeIdentityTTL = 0 + pool.SetHash() + + must.NoError(t, + testState.UpsertNodePools( + structs.MsgTypeTestSetup, + 1000, []*structs.NodePool{pool}, + )) + + // Verify the apply function set the NodeIdentityTTL to the default value + // and recalculated the hash. 
+ testFSM2 := testSnapshotRestore(t, testFSM) + testState2 := testFSM2.State() + out, err := testState2.NodePoolByName(nil, pool.Name) + must.NoError(t, err) + must.NonZero(t, out.NodeIdentityTTL) + must.NotEq(t, pool.Hash, out.Hash) +} + func TestFSM_SnapshotRestore_Jobs(t *testing.T) { ci.Parallel(t) // Add some state diff --git a/nomad/mock/mock.go b/nomad/mock/mock.go index 4fe1854ce..cba6cb07b 100644 --- a/nomad/mock/mock.go +++ b/nomad/mock/mock.go @@ -246,9 +246,10 @@ func Namespace() *structs.Namespace { func NodePool() *structs.NodePool { pool := &structs.NodePool{ - Name: fmt.Sprintf("pool-%s", uuid.Short()), - Description: "test node pool", - Meta: map[string]string{"team": "test"}, + Name: fmt.Sprintf("pool-%s", uuid.Short()), + Description: "test node pool", + NodeIdentityTTL: 24 * time.Hour, + Meta: map[string]string{"team": "test"}, } pool.SetHash() return pool diff --git a/nomad/node_pool_endpoint.go b/nomad/node_pool_endpoint.go index d84edd09e..f647308c7 100644 --- a/nomad/node_pool_endpoint.go +++ b/nomad/node_pool_endpoint.go @@ -200,6 +200,9 @@ func (n *NodePool) UpsertNodePools(args *structs.NodePoolUpsertRequest, reply *s return structs.NewErrRPCCodedf(http.StatusBadRequest, "must specify at least one node pool") } for _, pool := range args.NodePools { + + pool.Canonicalize() + if err := pool.Validate(); err != nil { return structs.NewErrRPCCodedf(http.StatusBadRequest, "invalid node pool %q: %v", pool.Name, err) } diff --git a/nomad/node_pool_endpoint_test.go b/nomad/node_pool_endpoint_test.go index 5df1bf218..b68266005 100644 --- a/nomad/node_pool_endpoint_test.go +++ b/nomad/node_pool_endpoint_test.go @@ -654,8 +654,9 @@ func TestNodePoolEndpoint_UpsertNodePools(t *testing.T) { name: "update pool", pools: []*structs.NodePool{ { - Name: existing.Name, - Description: "updated pool", + Name: existing.Name, + Description: "updated pool", + NodeIdentityTTL: 24 * time.Hour, Meta: map[string]string{ "updated": "true", }, @@ -774,38 +775,38 @@ func TestNodePoolEndpoint_UpsertNodePool_ACL(t *testing.T) { name: "management token has full access", token: root.SecretID, pools: []*structs.NodePool{ - {Name: "dev-1"}, - {Name: "prod-1"}, - {Name: "qa-1"}, + {Name: "dev-1", NodeIdentityTTL: 24 * time.Minute}, + {Name: "prod-1", NodeIdentityTTL: 24 * time.Minute}, + {Name: "qa-1", NodeIdentityTTL: 24 * time.Minute}, }, }, { name: "allowed by policy", token: devToken.SecretID, pools: []*structs.NodePool{ - {Name: "dev-1"}, + {Name: "dev-1", NodeIdentityTTL: 24 * time.Minute}, }, }, { name: "allowed by capability", token: prodToken.SecretID, pools: []*structs.NodePool{ - {Name: "prod-1"}, + {Name: "prod-1", NodeIdentityTTL: 24 * time.Minute}, }, }, { name: "allowed by exact match", token: devSpecificToken.SecretID, pools: []*structs.NodePool{ - {Name: "dev-1"}, + {Name: "dev-1", NodeIdentityTTL: 24 * time.Minute}, }, }, { name: "token restricted to wildcard", token: devToken.SecretID, pools: []*structs.NodePool{ - {Name: "dev-1"}, // ok - {Name: "prod-1"}, // not ok + {Name: "dev-1", NodeIdentityTTL: 24 * time.Minute}, // ok + {Name: "prod-1", NodeIdentityTTL: 24 * time.Minute}, // not ok }, expectedErr: structs.ErrPermissionDenied.Error(), }, @@ -813,7 +814,7 @@ func TestNodePoolEndpoint_UpsertNodePool_ACL(t *testing.T) { name: "token restricted if not exact match", token: devSpecificToken.SecretID, pools: []*structs.NodePool{ - {Name: "dev-2"}, + {Name: "dev-2", NodeIdentityTTL: 24 * time.Minute}, }, expectedErr: structs.ErrPermissionDenied.Error(), }, @@ -821,7 +822,7 @@ func 
TestNodePoolEndpoint_UpsertNodePool_ACL(t *testing.T) { name: "no token", token: "", pools: []*structs.NodePool{ - {Name: "dev-2"}, + {Name: "dev-2", NodeIdentityTTL: 24 * time.Minute}, }, expectedErr: structs.ErrPermissionDenied.Error(), }, @@ -829,7 +830,7 @@ func TestNodePoolEndpoint_UpsertNodePool_ACL(t *testing.T) { name: "no policy", token: noPolicyToken.SecretID, pools: []*structs.NodePool{ - {Name: "dev-1"}, + {Name: "dev-1", NodeIdentityTTL: 24 * time.Minute}, }, expectedErr: structs.ErrPermissionDenied.Error(), }, @@ -837,7 +838,7 @@ func TestNodePoolEndpoint_UpsertNodePool_ACL(t *testing.T) { name: "no write", token: readOnlyToken.SecretID, pools: []*structs.NodePool{ - {Name: "dev-1"}, + {Name: "dev-1", NodeIdentityTTL: 24 * time.Minute}, }, expectedErr: structs.ErrPermissionDenied.Error(), }, diff --git a/nomad/state/schema.go b/nomad/state/schema.go index 021e7660e..7ef56b51b 100644 --- a/nomad/state/schema.go +++ b/nomad/state/schema.go @@ -16,6 +16,7 @@ const ( tableIndex = "index" TableNamespaces = "namespaces" + TableNodes = "nodes" TableNodePools = "node_pools" TableServiceRegistrations = "service_registrations" TableVariables = "variables" @@ -147,7 +148,7 @@ func indexTableSchema() *memdb.TableSchema { // This table is used to store all the client nodes that are registered. func nodeTableSchema() *memdb.TableSchema { return &memdb.TableSchema{ - Name: "nodes", + Name: TableNodes, Indexes: map[string]*memdb.IndexSchema{ // Primary index is used for node management // and simple direct lookup. ID is required to be @@ -176,6 +177,14 @@ func nodeTableSchema() *memdb.TableSchema { Field: "NodePool", }, }, + indexSigningKey: { + Name: indexSigningKey, + AllowMissing: true, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "IdentitySigningKeyID", + }, + }, }, } } diff --git a/nomad/state/state_store_keyring.go b/nomad/state/state_store_keyring.go index aab27dc93..d52f8b464 100644 --- a/nomad/state/state_store_keyring.go +++ b/nomad/state/state_store_keyring.go @@ -190,5 +190,14 @@ func (s *StateStore) IsRootKeyInUse(keyID string) (bool, error) { return true, nil } + iter, err = txn.Get(TableNodes, indexSigningKey, keyID) + if err != nil { + return false, err + } + node := iter.Next() + if node != nil { + return true, nil + } + return false, nil } diff --git a/nomad/state/state_store_keyring_test.go b/nomad/state/state_store_keyring_test.go index 7a7244a1b..f7d60918d 100644 --- a/nomad/state/state_store_keyring_test.go +++ b/nomad/state/state_store_keyring_test.go @@ -8,6 +8,7 @@ import ( "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/shoenig/test/must" ) @@ -84,3 +85,118 @@ func TestStateStore_WrappedRootKey_CRUD(t *testing.T) { // deleting non-existent keys is safe must.NoError(t, store.DeleteRootKey(index, uuid.Generate())) } + +func TestStateStore_IsRootKeyInUse(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + fn func(*StateStore) + }{ + { + name: "in use by alloc", + fn: func(store *StateStore) { + + keyID := uuid.Generate() + + mockAlloc := mock.Alloc() + mockAlloc.SigningKeyID = keyID + + must.NoError(t, store.UpsertAllocs( + structs.MsgTypeTestSetup, + 100, + []*structs.Allocation{mockAlloc}, + )) + + isInUse, err := store.IsRootKeyInUse(keyID) + must.NoError(t, err) + must.True(t, isInUse) + }, + }, + { + name: "in use by variable", + fn: func(store *StateStore) { + + keyID := uuid.Generate() + + 
mockVariable := mock.VariableEncrypted() + mockVariable.KeyID = keyID + + stateResp := store.VarSet(110, + &structs.VarApplyStateRequest{Var: mockVariable, Op: structs.VarOpSet}, + ) + + must.NoError(t, stateResp.Error) + must.Eq(t, structs.VarOpResultOk, stateResp.Result) + + isInUse, err := store.IsRootKeyInUse(keyID) + must.NoError(t, err) + must.True(t, isInUse) + }, + }, + { + name: "in use by node", + fn: func(store *StateStore) { + keyID := uuid.Generate() + + mockNode := mock.Node() + mockNode.IdentitySigningKeyID = keyID + + must.NoError(t, store.UpsertNode(structs.MsgTypeTestSetup, 120, mockNode)) + + isInUse, err := store.IsRootKeyInUse(keyID) + must.NoError(t, err) + must.True(t, isInUse) + }, + }, + { + name: "not in use", + fn: func(store *StateStore) { + + // Generate a random key ID to use to sign all the state + // objects. + keyID := uuid.Generate() + + // Create a node, variable, and alloc that all use the same key + // and write them to the store. + mockNode := mock.Node() + mockNode.IdentitySigningKeyID = keyID + + must.NoError(t, store.UpsertNode(structs.MsgTypeTestSetup, 130, mockNode)) + + mockVariable := mock.VariableEncrypted() + mockVariable.KeyID = keyID + + stateResp := store.VarSet(140, + &structs.VarApplyStateRequest{Var: mockVariable, Op: structs.VarOpSet}, + ) + + must.NoError(t, stateResp.Error) + must.Eq(t, structs.VarOpResultOk, stateResp.Result) + + mockAlloc := mock.Alloc() + mockAlloc.SigningKeyID = keyID + + must.NoError(t, store.UpsertAllocs( + structs.MsgTypeTestSetup, + 150, + []*structs.Allocation{mockAlloc}, + )) + + // Perform a check using a different key ID to ensure we get the + // expected result. + isInUse, err := store.IsRootKeyInUse(uuid.Generate()) + must.NoError(t, err) + must.False(t, isInUse) + }, + }, + } + + testStore := testStateStore(t) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.fn(testStore) + }) + } +} diff --git a/nomad/state/state_store_node_pools.go b/nomad/state/state_store_node_pools.go index f542b0014..58e33f0be 100644 --- a/nomad/state/state_store_node_pools.go +++ b/nomad/state/state_store_node_pools.go @@ -14,13 +14,15 @@ import ( // in the cluster. 
func (s *StateStore) nodePoolInit() error { allNodePool := &structs.NodePool{ - Name: structs.NodePoolAll, - Description: structs.NodePoolAllDescription, + Name: structs.NodePoolAll, + Description: structs.NodePoolAllDescription, + NodeIdentityTTL: structs.DefaultNodePoolNodeIdentityTTL, } defaultNodePool := &structs.NodePool{ - Name: structs.NodePoolDefault, - Description: structs.NodePoolDefaultDescription, + Name: structs.NodePoolDefault, + Description: structs.NodePoolDefaultDescription, + NodeIdentityTTL: structs.DefaultNodePoolNodeIdentityTTL, } return s.UpsertNodePools( diff --git a/nomad/state/state_store_node_pools_test.go b/nomad/state/state_store_node_pools_test.go index 4649f99b6..f1abea5db 100644 --- a/nomad/state/state_store_node_pools_test.go +++ b/nomad/state/state_store_node_pools_test.go @@ -132,20 +132,22 @@ func TestStateStore_NodePool_ByName(t *testing.T) { name: "find built-in pool all", pool: structs.NodePoolAll, expected: &structs.NodePool{ - Name: structs.NodePoolAll, - Description: structs.NodePoolAllDescription, - CreateIndex: 1, - ModifyIndex: 1, + Name: structs.NodePoolAll, + Description: structs.NodePoolAllDescription, + NodeIdentityTTL: structs.DefaultNodePoolNodeIdentityTTL, + CreateIndex: 1, + ModifyIndex: 1, }, }, { name: "find built-in pool default", pool: structs.NodePoolDefault, expected: &structs.NodePool{ - Name: structs.NodePoolDefault, - Description: structs.NodePoolDefaultDescription, - CreateIndex: 1, - ModifyIndex: 1, + Name: structs.NodePoolDefault, + Description: structs.NodePoolDefaultDescription, + NodeIdentityTTL: structs.DefaultNodePoolNodeIdentityTTL, + CreateIndex: 1, + ModifyIndex: 1, }, }, { diff --git a/nomad/structs/identity.go b/nomad/structs/identity.go new file mode 100644 index 000000000..41e43f99d --- /dev/null +++ b/nomad/structs/identity.go @@ -0,0 +1,100 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package structs + +import ( + "strings" + "time" + + "github.com/go-jose/go-jose/v3/jwt" +) + +// IdentityDefaultAud is the default audience to use for default Nomad +// identities. +const IdentityDefaultAud = "nomadproject.io" + +// IdentityClaims is an envelope for a Nomad identity JWT that can be either a +// node identity or a workload identity. It contains the specific claims for the +// identity type, as well as the common JWT claims. +type IdentityClaims struct { + + // *NodeIdentityClaims contains the claims specific to a node identity. + *NodeIdentityClaims + + // *WorkloadIdentityClaims contains the claims specific to a workload as + // defined by an allocation running on a client. + *WorkloadIdentityClaims + + // The public JWT claims for the identity. These claims are always present, + // regardless of whether the identity is for a node or workload. + jwt.Claims +} + +// IsNode checks if the identity JWT is a node identity. +func (i *IdentityClaims) IsNode() bool { return i != nil && i.NodeIdentityClaims != nil } + +// IsWorkload checks if the identity JWT is a workload identity. +func (i *IdentityClaims) IsWorkload() bool { return i != nil && i.WorkloadIdentityClaims != nil } + +// IsExpiring checks if the identity JWT is expired or close to expiring. Close +// is defined as within one-third of the TTL provided. +func (i *IdentityClaims) IsExpiring(now time.Time, ttl time.Duration) bool { + + // Protect against nil identity claims and fast circuit a check on an + // identity that does not have expiry. 
+ if i == nil || i.Expiry == nil { + return false + } + + // Calculate the threshold for "close to expiring" as one-third of the TTL + // relative to the current time. + threshold := now.Add(ttl / 3) + + return i.Expiry.Time().Before(threshold) +} + +// setExpiry sets the "expiry" or "exp" claim for the identity JWT. It is the +// absolute time at which the JWT will expire. +// +// If no TTL is provided, the expiry will not be set, which means the JWT will +// never expire. +func (i *IdentityClaims) setExpiry(now time.Time, ttl time.Duration) { + if ttl > 0 { + i.Expiry = jwt.NewNumericDate(now.Add(ttl)) + } +} + +// setAudience sets the "audience" or "aud" claim for the identity JWT. +func (i *IdentityClaims) setAudience(aud []string) { i.Audience = aud } + +// setNodeSubject sets the "subject" or "sub" claim for the node identity JWT. +// It follows the format "node:<region>:<node_pool>:<node_id>:default", where +// "default" indicates the identity name. While this is currently hardcoded, it +// could be configurable in the future as we expand the node identity offering +// and allow greater control of node access. +func (i *IdentityClaims) setNodeSubject(node *Node, region string) { + i.Subject = strings.Join([]string{ + "node", + region, + node.NodePool, + node.ID, + "default", + }, ":") +} + +// setWorkloadSubject sets the "subject" or "sub" claim for the workload +// identity JWT. It follows the format +// "<region>:<namespace>:<job_id>:<group>:<workload_identifier>:<identity_name>". The +// subject does not include a type identifier which differs from the node +// identity and is something we may want to change in the future. +func (i *IdentityClaims) setWorkloadSubject(job *Job, group, wID, id string) { + i.Subject = strings.Join([]string{ + job.Region, + job.Namespace, + job.GetIDforWorkloadIdentity(), + group, + wID, + id, + }, ":") +} diff --git a/nomad/structs/identity_test.go b/nomad/structs/identity_test.go new file mode 100644 index 000000000..b35690f0c --- /dev/null +++ b/nomad/structs/identity_test.go @@ -0,0 +1,232 @@ +// Copyright (c) HashiCorp, Inc.
+// SPDX-License-Identifier: BUSL-1.1 + +package structs + +import ( + "testing" + "time" + + "github.com/go-jose/go-jose/v3/jwt" + "github.com/hashicorp/nomad/ci" + "github.com/shoenig/test/must" +) + +func TestIdentityClaims_IsNode(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + inputIdentityClaims *IdentityClaims + expectedOutput bool + }{ + { + name: "nil identity claims", + inputIdentityClaims: nil, + expectedOutput: false, + }, + { + name: "no identity claims", + inputIdentityClaims: &IdentityClaims{}, + expectedOutput: false, + }, + { + name: "workload identity claims", + inputIdentityClaims: &IdentityClaims{ + WorkloadIdentityClaims: &WorkloadIdentityClaims{}, + }, + expectedOutput: false, + }, + { + name: "node identity claims", + inputIdentityClaims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + }, + expectedOutput: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := tc.inputIdentityClaims.IsNode() + must.Eq(t, tc.expectedOutput, actualOutput) + }) + } +} + +func TestIdentityClaims_IsWorkload(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + inputIdentityClaims *IdentityClaims + expectedOutput bool + }{ + { + name: "nil identity claims", + inputIdentityClaims: nil, + expectedOutput: false, + }, + { + name: "no identity claims", + inputIdentityClaims: &IdentityClaims{}, + expectedOutput: false, + }, + { + name: "node identity claims", + inputIdentityClaims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + }, + expectedOutput: false, + }, + { + name: "workload identity claims", + inputIdentityClaims: &IdentityClaims{ + WorkloadIdentityClaims: &WorkloadIdentityClaims{}, + }, + expectedOutput: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := tc.inputIdentityClaims.IsWorkload() + must.Eq(t, tc.expectedOutput, actualOutput) + }) + } +} + +func TestIdentityClaims_IsExpiring(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + inputIdentityClaims *IdentityClaims + inputNow time.Time + inputTTL time.Duration + expectedResult bool + }{ + { + name: "nil identity", + inputIdentityClaims: nil, + inputNow: time.Now(), + inputTTL: 10 * time.Minute, + expectedResult: false, + }, + { + name: "no expiry", + inputIdentityClaims: &IdentityClaims{}, + inputNow: time.Now(), + inputTTL: 10 * time.Minute, + expectedResult: false, + }, + { + name: "not expiring not close", + inputIdentityClaims: &IdentityClaims{ + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().Add(100 * time.Hour)), + }, + }, + inputNow: time.Now(), + inputTTL: 100 * time.Hour, + expectedResult: false, + }, + { + name: "not expiring close", + inputIdentityClaims: &IdentityClaims{ + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().Add(100 * time.Hour)), + }, + }, + inputNow: time.Now().Add(30 * time.Hour), + inputTTL: 100 * time.Hour, + expectedResult: false, + }, + { + name: "expired close", + inputIdentityClaims: &IdentityClaims{ + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().Add(100 * time.Hour)), + }, + }, + inputNow: time.Now().Add(67 * time.Hour), + inputTTL: 100 * time.Hour, + expectedResult: true, + }, + { + name: "expired not close", + inputIdentityClaims: &IdentityClaims{ + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().Add(100 * time.Hour)), + }, + }, + inputNow: time.Now().Add(200 * time.Hour), + inputTTL: 100 * time.Hour, + expectedResult: true, + }, + } 
+ + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := tc.inputIdentityClaims.IsExpiring(tc.inputNow, tc.inputTTL) + must.Eq(t, tc.expectedResult, actualOutput) + }) + } +} + +func TestIdentityClaimsNg_setExpiry(t *testing.T) { + ci.Parallel(t) + + timeNow := time.Now().UTC() + ttl := 10 * time.Minute + + claims := IdentityClaims{} + claims.setExpiry(timeNow, ttl) + + // Round the time to the nearest minute for comparison, to accommodate + // potential time differences in the test environment caused by function + // call overhead. This can be seen when running a suite of tests in + // parallel. + must.Eq(t, timeNow.Add(ttl).Round(time.Minute), + claims.Expiry.Time().UTC().Round(time.Minute)) +} + +func TestIdentityClaimsNg_setNodeSubject(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + inputNode *Node + inputRegion string + expectedSubject string + }{ + { + name: "global region", + inputNode: &Node{ + ID: "node-id-1", + NodePool: "default", + }, + inputRegion: "global", + expectedSubject: "node:global:default:node-id-1:default", + }, + { + name: "eu1 region", + inputNode: &Node{ + ID: "node-id-2", + NodePool: "nlp", + }, + inputRegion: "eu1", + expectedSubject: "node:eu1:nlp:node-id-2:default", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ci.Parallel(t) + + claims := IdentityClaims{} + claims.setNodeSubject(tc.inputNode, tc.inputRegion) + must.Eq(t, tc.expectedSubject, claims.Subject) + }) + } +} diff --git a/nomad/structs/node.go b/nomad/structs/node.go index baa809c12..bcc0fec39 100644 --- a/nomad/structs/node.go +++ b/nomad/structs/node.go @@ -11,7 +11,9 @@ import ( "strings" "time" + "github.com/go-jose/go-jose/v3/jwt" "github.com/hashicorp/hcl/v2/hclsyntax" + "github.com/hashicorp/nomad/helper/uuid" ) // CSITopology is a map of topological domains to topological segments. @@ -492,3 +494,46 @@ type NodeMetaResponse struct { // Static is the static Node metadata (set via agent configuration) Static map[string]string } + +// NodeIdentityClaims represents the claims for a Nomad node identity JWT. +type NodeIdentityClaims struct { + NodeID string `json:"nomad_node_id"` + NodePool string `json:"nomad_node_pool"` + NodeClass string `json:"nomad_node_class"` + NodeDatacenter string `json:"nomad_node_datacenter"` +} + +// GenerateNodeIdentityClaims creates a new NodeIdentityClaims for the given +// node and region. The returned claims will be ready for signing and returning +// to the node. +// +// The caller is responsible for ensuring that the passed arguments are valid. +func GenerateNodeIdentityClaims(node *Node, region string, ttl time.Duration) *IdentityClaims { + + // The time does not need to be passed into the function as an argument, as + // we only create a single identity per node at a time. This explains the + // difference with the workload identity claims, as each allocation can have + // multiple identities. 
+ timeNow := time.Now().UTC() + timeJWTNow := jwt.NewNumericDate(timeNow) + + claims := &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: node.ID, + NodePool: node.NodePool, + NodeClass: node.NodeClass, + NodeDatacenter: node.Datacenter, + }, + Claims: jwt.Claims{ + ID: uuid.Generate(), + IssuedAt: timeJWTNow, + NotBefore: timeJWTNow, + }, + } + + claims.setAudience([]string{IdentityDefaultAud}) + claims.setExpiry(timeNow, ttl) + claims.setNodeSubject(node, region) + + return claims +} diff --git a/nomad/structs/node_pool.go b/nomad/structs/node_pool.go index aa0ab96be..8d5138127 100644 --- a/nomad/structs/node_pool.go +++ b/nomad/structs/node_pool.go @@ -8,6 +8,7 @@ import ( "maps" "regexp" "sort" + "time" "github.com/hashicorp/go-multierror" "github.com/hashicorp/nomad/helper/pointer" @@ -28,6 +29,11 @@ const ( // maxNodePoolDescriptionLength is the maximum length allowed for a node // pool description. maxNodePoolDescriptionLength = 256 + + // DefaultNodePoolNodeIdentityTTL is the default TTL for node identities + // when the operator does not specify this as part of the node pool + // specification. + DefaultNodePoolNodeIdentityTTL = 24 * time.Hour ) var ( @@ -43,7 +49,7 @@ func ValidateNodePoolName(pool string) error { return nil } -// NodePool allows partioning infrastructure +// NodePool allows partitioning infrastructure type NodePool struct { // Name is the node pool name. It must be unique. Name string @@ -58,6 +64,9 @@ type NodePool struct { // node pool. SchedulerConfiguration *NodePoolSchedulerConfiguration + // NodeIdentityTTL is the time-to-live for node identities in the pool. + NodeIdentityTTL time.Duration + // Hash is the hash of the node pool which is used to efficiently diff when // we replicate pools across regions. Hash []byte @@ -87,6 +96,18 @@ func (n *NodePool) Validate() error { return mErr.ErrorOrNil() } +// Canonicalize is used to set default values for the node pool. This currently +// only sets the default TTL for node identities if it is not set. +func (n *NodePool) Canonicalize() { + if n == nil { + return + } + + if n.NodeIdentityTTL == 0 { + n.NodeIdentityTTL = DefaultNodePoolNodeIdentityTTL + } +} + // Copy returns a deep copy of the node pool. func (n *NodePool) Copy() *NodePool { if n == nil { @@ -151,6 +172,8 @@ func (n *NodePool) SetHash() []byte { // Write all the user set fields _, _ = hash.Write([]byte(n.Name)) _, _ = hash.Write([]byte(n.Description)) + _, _ = hash.Write([]byte(n.NodeIdentityTTL.String())) + if n.SchedulerConfiguration != nil { _, _ = hash.Write([]byte(n.SchedulerConfiguration.SchedulerAlgorithm)) @@ -184,7 +207,7 @@ func (n *NodePool) SetHash() []byte { return hashVal } -// NodePoolSchedulerConfiguration is the scheduler confinguration applied to a +// NodePoolSchedulerConfiguration is the scheduler configuration applied to a // node pool. 
// // When adding new values that should override global scheduler configuration, diff --git a/nomad/structs/node_pool_test.go b/nomad/structs/node_pool_test.go index eef8c8bef..135652b4c 100644 --- a/nomad/structs/node_pool_test.go +++ b/nomad/structs/node_pool_test.go @@ -6,12 +6,46 @@ package structs import ( "strings" "testing" + "time" "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/helper/pointer" "github.com/shoenig/test/must" ) +func TestNodePool_Canonicalize(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + inputNodePool *NodePool + expected *NodePool + }{ + { + name: "nil node pool", + inputNodePool: nil, + expected: nil, + }, + { + name: "identity ttl set", + inputNodePool: &NodePool{NodeIdentityTTL: 43830 * time.Hour}, + expected: &NodePool{NodeIdentityTTL: 43830 * time.Hour}, + }, + { + name: "identity ttl not set", + inputNodePool: &NodePool{}, + expected: &NodePool{NodeIdentityTTL: DefaultNodePoolNodeIdentityTTL}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.inputNodePool.Canonicalize() + must.Eq(t, tc.inputNodePool, tc.expected) + }) + } +} + func TestNodePool_Copy(t *testing.T) { ci.Parallel(t) diff --git a/nomad/structs/node_test.go b/nomad/structs/node_test.go index d9185a672..95970aaed 100644 --- a/nomad/structs/node_test.go +++ b/nomad/structs/node_test.go @@ -5,6 +5,7 @@ package structs import ( "testing" + "time" "github.com/hashicorp/nomad/ci" "github.com/shoenig/test/must" @@ -254,3 +255,27 @@ func TestCSITopology_Contains(t *testing.T) { } } + +func TestGenerateNodeIdentityClaims(t *testing.T) { + ci.Parallel(t) + + node := &Node{ + ID: "node-id-1", + NodePool: "custom-pool", + NodeClass: "custom-class", + Datacenter: "euw2", + } + + claims := GenerateNodeIdentityClaims(node, "euw", 10*time.Minute) + + must.Eq(t, "node-id-1", claims.NodeID) + must.Eq(t, "custom-pool", claims.NodePool) + must.Eq(t, "custom-class", claims.NodeClass) + must.Eq(t, "euw2", claims.NodeDatacenter) + must.StrEqFold(t, "node:euw:custom-pool:node-id-1:default", claims.Subject) + must.Eq(t, []string{IdentityDefaultAud}, claims.Audience) + must.NotNil(t, claims.ID) + must.NotNil(t, claims.IssuedAt) + must.NotNil(t, claims.NotBefore) + must.NotNil(t, claims.Expiry) +} diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index 8a76e83e0..bd576f5cf 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -2138,6 +2138,15 @@ type Node struct { // StatusDescription is meant to provide more human useful information StatusDescription string + // IdentitySigningKeyID is the ID of the root key used to sign the identity + // of the node. This is primarily used to ensure Nomad does not delete a + // root keyring that still has nodes with identities signed by it. + // + // This field is only set if the node has a workload identity and will be + // modified by the server when the node is registered or updated, and the + // signing key ID has changed from what is stored in state. + IdentitySigningKeyID string + // StatusUpdatedAt is the time stamp at which the state of the node was // updated, stored as Unix (no nano seconds!) 
StatusUpdatedAt int64 diff --git a/nomad/structs/structs_test.go b/nomad/structs/structs_test.go index eab6d433a..c3d6f759d 100644 --- a/nomad/structs/structs_test.go +++ b/nomad/structs/structs_test.go @@ -258,7 +258,9 @@ func TestAuthenticatedIdentity_String(t *testing.T) { name: "alloc claim", inputAuthenticatedIdentity: &AuthenticatedIdentity{ Claims: &IdentityClaims{ - AllocationID: "my-testing-alloc-id", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + AllocationID: "my-testing-alloc-id", + }, }, }, expectedOutput: "alloc:my-testing-alloc-id", @@ -8291,7 +8293,7 @@ func TestTaskIdentity_Canonicalize(t *testing.T) { // to the original field. must.NotNil(t, task.Identity) must.Eq(t, WorkloadIdentityDefaultName, task.Identity.Name) - must.Eq(t, []string{WorkloadIdentityDefaultAud}, task.Identity.Audience) + must.Eq(t, []string{IdentityDefaultAud}, task.Identity.Audience) must.False(t, task.Identity.Env) must.False(t, task.Identity.File) diff --git a/nomad/structs/workload_id.go b/nomad/structs/workload_id.go index f2b2b0c77..abb5aaa4d 100644 --- a/nomad/structs/workload_id.go +++ b/nomad/structs/workload_id.go @@ -21,9 +21,6 @@ const ( // Identity. WorkloadIdentityDefaultName = "default" - // WorkloadIdentityDefaultAud is the audience of the default identity. - WorkloadIdentityDefaultAud = "nomadproject.io" - // WorkloadIdentityVaultPrefix is the name prefix of workload identities // used to derive Vault tokens. WorkloadIdentityVaultPrefix = "vault_" @@ -63,9 +60,9 @@ var ( MinNomadVersionVaultWID = version.Must(version.NewVersion("1.7.0-a")) ) -// IdentityClaims are the input to a JWT identifying a workload. It +// WorkloadIdentityClaims are the input to a JWT identifying a workload. It // should never be serialized to msgpack unsigned. -type IdentityClaims struct { +type WorkloadIdentityClaims struct { Namespace string `json:"nomad_namespace"` JobID string `json:"nomad_job_id"` AllocationID string `json:"nomad_allocation_id"` @@ -79,15 +76,13 @@ type IdentityClaims struct { // ExtraClaims are added based on this identity's // WorkloadIdentityConfiguration, controlled by server configuration ExtraClaims map[string]string `json:"extra_claims,omitempty"` - - jwt.Claims } -// IdentityClaimsBuilder is used to build up all the context we need to create -// IdentityClaims from jobs, allocs, tasks, services, Vault and Consul -// configurations, etc. This lets us treat IdentityClaims as the immutable -// output of that process. -type IdentityClaimsBuilder struct { +// WorkloadIdentityClaimsBuilder is used to build up all the context we need to create +// WorkloadIdentityClaims from jobs, allocs, tasks, services, Vault and Consul +// configurations, etc. This lets us treat WorkloadIdentityClaims as the +// immutable output of that process. +type WorkloadIdentityClaimsBuilder struct { wid *WorkloadIdentity // from jobspec wihandle *WIHandle alloc *Allocation @@ -101,11 +96,11 @@ type IdentityClaimsBuilder struct { extras map[string]string } -// NewIdentityClaimsBuilder returns an initialized IdentityClaimsBuilder for the +// NewIdentityClaimsBuilder returns an initialized WorkloadIdentityClaimsBuilder for the // allocation and identity request. Because it may be called with a denormalized // Allocation in the plan applier, the Job must be passed in as a separate // parameter. 
-func NewIdentityClaimsBuilder(job *Job, alloc *Allocation, wihandle *WIHandle, wid *WorkloadIdentity) *IdentityClaimsBuilder { +func NewIdentityClaimsBuilder(job *Job, alloc *Allocation, wihandle *WIHandle, wid *WorkloadIdentity) *WorkloadIdentityClaimsBuilder { tg := job.LookupTaskGroup(alloc.TaskGroup) if tg == nil { return nil @@ -114,7 +109,7 @@ func NewIdentityClaimsBuilder(job *Job, alloc *Allocation, wihandle *WIHandle, w wid = DefaultWorkloadIdentity() } - return &IdentityClaimsBuilder{ + return &WorkloadIdentityClaimsBuilder{ alloc: alloc, job: job, wihandle: wihandle, @@ -125,7 +120,7 @@ func NewIdentityClaimsBuilder(job *Job, alloc *Allocation, wihandle *WIHandle, w } // WithTask adds a task to the builder context. -func (b *IdentityClaimsBuilder) WithTask(task *Task) *IdentityClaimsBuilder { +func (b *WorkloadIdentityClaimsBuilder) WithTask(task *Task) *WorkloadIdentityClaimsBuilder { if task == nil { return b } @@ -135,7 +130,7 @@ func (b *IdentityClaimsBuilder) WithTask(task *Task) *IdentityClaimsBuilder { // WithVault adds the task's vault block to the builder context. This should // only be called after WithTask. -func (b *IdentityClaimsBuilder) WithVault(extraClaims map[string]string) *IdentityClaimsBuilder { +func (b *WorkloadIdentityClaimsBuilder) WithVault(extraClaims map[string]string) *WorkloadIdentityClaimsBuilder { if !b.wid.IsVault() || b.task == nil { return b } @@ -148,7 +143,7 @@ func (b *IdentityClaimsBuilder) WithVault(extraClaims map[string]string) *Identi // WithConsul adds the group or task's consul block to the builder context. For // task identities, this should only be called after WithTask. -func (b *IdentityClaimsBuilder) WithConsul() *IdentityClaimsBuilder { +func (b *WorkloadIdentityClaimsBuilder) WithConsul() *WorkloadIdentityClaimsBuilder { if !b.wid.IsConsul() { return b } @@ -163,7 +158,7 @@ func (b *IdentityClaimsBuilder) WithConsul() *IdentityClaimsBuilder { // WithService adds a service block to the builder context. This should only be // called for service identities, and a builder for service identities will // never set the task_name claim. -func (b *IdentityClaimsBuilder) WithService(service *Service) *IdentityClaimsBuilder { +func (b *WorkloadIdentityClaimsBuilder) WithService(service *Service) *WorkloadIdentityClaimsBuilder { if b.wihandle.WorkloadType != WorkloadTypeService { return b } @@ -176,7 +171,7 @@ func (b *IdentityClaimsBuilder) WithService(service *Service) *IdentityClaimsBui } // WithNode add the allocation's node to the builder context. -func (b *IdentityClaimsBuilder) WithNode(node *Node) *IdentityClaimsBuilder { +func (b *WorkloadIdentityClaimsBuilder) WithNode(node *Node) *WorkloadIdentityClaimsBuilder { b.node = node return b } @@ -185,21 +180,24 @@ func (b *IdentityClaimsBuilder) WithNode(node *Node) *IdentityClaimsBuilder { // on the claim. The claim ID is random (nondeterministic) so multiple calls // with the same values will not return equal claims by design. JWT IDs should // never collide. 
-func (b *IdentityClaimsBuilder) Build(now time.Time) *IdentityClaims { +func (b *WorkloadIdentityClaimsBuilder) Build(now time.Time) *IdentityClaims { b.interpolate() jwtnow := jwt.NewNumericDate(now.UTC()) claims := &IdentityClaims{ - Namespace: b.alloc.Namespace, - JobID: b.job.GetIDforWorkloadIdentity(), - AllocationID: b.alloc.ID, - ServiceName: b.serviceName, + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + Namespace: b.alloc.Namespace, + JobID: b.job.GetIDforWorkloadIdentity(), + AllocationID: b.alloc.ID, + ServiceName: b.serviceName, + ExtraClaims: b.extras, + }, Claims: jwt.Claims{ NotBefore: jwtnow, IssuedAt: jwtnow, }, - ExtraClaims: b.extras, } + if b.task != nil && b.wihandle.WorkloadType != WorkloadTypeService { claims.TaskName = b.task.Name } @@ -211,9 +209,9 @@ func (b *IdentityClaimsBuilder) Build(now time.Time) *IdentityClaims { claims.VaultRole = b.vault.Role } - claims.Audience = slices.Clone(b.wid.Audience) - claims.setSubject(b.job, b.alloc.TaskGroup, b.wihandle.WorkloadIdentifier, b.wid.Name) - claims.setExp(now, b.wid) + claims.setAudience(slices.Clone(b.wid.Audience)) + claims.setWorkloadSubject(b.job, b.alloc.TaskGroup, b.wihandle.WorkloadIdentifier, b.wid.Name) + claims.setExpiry(now, b.wid.TTL) claims.ID = uuid.Generate() @@ -227,7 +225,7 @@ func strAttrGet[T any](x *T, fn func(x *T) string) string { return "" } -func (b *IdentityClaimsBuilder) interpolate() { +func (b *WorkloadIdentityClaimsBuilder) interpolate() { if len(b.extras) == 0 { return } @@ -256,28 +254,6 @@ func (b *IdentityClaimsBuilder) interpolate() { } } -// setSubject creates the standard subject claim for workload identities. -func (claims *IdentityClaims) setSubject(job *Job, group, widentifier, id string) { - claims.Subject = strings.Join([]string{ - job.Region, - job.Namespace, - job.GetIDforWorkloadIdentity(), - group, - widentifier, - id, - }, ":") -} - -// setExp sets the absolute time at which these identity claims expire. -func (claims *IdentityClaims) setExp(now time.Time, wid *WorkloadIdentity) { - if wid.TTL == 0 { - // No expiry - return - } - - claims.Expiry = jwt.NewNumericDate(now.Add(wid.TTL)) -} - // WorkloadIdentity is the jobspec block which determines if and how a workload // identity is exposed to tasks similar to the Vault block. // @@ -326,7 +302,7 @@ type WorkloadIdentity struct { func DefaultWorkloadIdentity() *WorkloadIdentity { return &WorkloadIdentity{ Name: WorkloadIdentityDefaultName, - Audience: []string{WorkloadIdentityDefaultAud}, + Audience: []string{IdentityDefaultAud}, } } @@ -421,7 +397,7 @@ func (wi *WorkloadIdentity) Canonicalize() { // The default identity is only valid for use with Nomad itself. if wi.Name == WorkloadIdentityDefaultName { - wi.Audience = []string{WorkloadIdentityDefaultAud} + wi.Audience = []string{IdentityDefaultAud} } if wi.ChangeSignal != "" { diff --git a/nomad/structs/workload_id_test.go b/nomad/structs/workload_id_test.go index 0386b26d4..7b604f54d 100644 --- a/nomad/structs/workload_id_test.go +++ b/nomad/structs/workload_id_test.go @@ -178,261 +178,303 @@ func TestNewIdentityClaims(t *testing.T) { expectedClaims := map[string]*IdentityClaims{ // group: no consul. 
"job/group/services/group-service": { - Namespace: "default", - JobID: "parentJob", - ServiceName: "group-service", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + Namespace: "default", + JobID: "parentJob", + ServiceName: "group-service", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:group:group-service:consul-service_group-service-http", Audience: jwt.Audience{"group-service.consul.io"}, }, - ExtraClaims: map[string]string{}, }, // group: no consul. // task: no consul, no vault. "job/group/task/default-identity": { - Namespace: "default", - JobID: "parentJob", - TaskName: "task", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + Namespace: "default", + JobID: "parentJob", + TaskName: "task", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:group:task:default-identity", Audience: jwt.Audience{"example.com"}, }, - ExtraClaims: map[string]string{}, }, "job/group/task/alt-identity": { - Namespace: "default", - JobID: "parentJob", - TaskName: "task", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + Namespace: "default", + JobID: "parentJob", + TaskName: "task", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:group:task:alt-identity", Audience: jwt.Audience{"alt.example.com"}, }, - ExtraClaims: map[string]string{}, }, // No ConsulNamespace because there is no consul block at either task // or group level. "job/group/task/consul_default": { - ConsulNamespace: "", - Namespace: "default", - JobID: "parentJob", - TaskName: "task", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + ConsulNamespace: "", + Namespace: "default", + JobID: "parentJob", + TaskName: "task", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:group:task:consul_default", Audience: jwt.Audience{"consul.io"}, }, - ExtraClaims: map[string]string{}, }, // No VaultNamespace because there is no vault block at either task // or group level. "job/group/task/vault_default": { - VaultNamespace: "", - Namespace: "default", - JobID: "parentJob", - TaskName: "task", - VaultRole: "", // not specified in jobspec + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + VaultNamespace: "", + Namespace: "default", + JobID: "parentJob", + TaskName: "task", + VaultRole: "", // not specified in jobspec + ExtraClaims: map[string]string{ + "nomad_workload_id": "global:default:parentJob", + }, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:group:task:vault_default", Audience: jwt.Audience{"vault.io"}, }, - ExtraClaims: map[string]string{ - "nomad_workload_id": "global:default:parentJob", - }, }, "job/group/task/services/task-service": { - Namespace: "default", - JobID: "parentJob", - ServiceName: "task-service", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + Namespace: "default", + JobID: "parentJob", + ServiceName: "task-service", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:group:task-service:consul-service_task-task-service-http", Audience: jwt.Audience{"task-service.consul.io"}, }, - ExtraClaims: map[string]string{}, }, // group: no consul. // task: with consul, with vault. 
"job/group/consul-vault-task/default-identity": { - Namespace: "default", - JobID: "parentJob", - TaskName: "consul-vault-task", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + Namespace: "default", + JobID: "parentJob", + TaskName: "consul-vault-task", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:group:consul-vault-task:default-identity", Audience: jwt.Audience{"example.com"}, }, - ExtraClaims: map[string]string{}, }, // Use task-level Consul namespace. "job/group/consul-vault-task/consul_default": { - ConsulNamespace: "task-consul-namespace", - Namespace: "default", - JobID: "parentJob", - TaskName: "consul-vault-task", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + ConsulNamespace: "task-consul-namespace", + Namespace: "default", + JobID: "parentJob", + TaskName: "consul-vault-task", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:group:consul-vault-task:consul_default", Audience: jwt.Audience{"consul.io"}, }, - ExtraClaims: map[string]string{}, }, // Use task-level Vault namespace. "job/group/consul-vault-task/vault_default": { - VaultNamespace: "vault-namespace", - Namespace: "default", - JobID: "parentJob", - TaskName: "consul-vault-task", - VaultRole: "role-from-spec-group", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + VaultNamespace: "vault-namespace", + Namespace: "default", + JobID: "parentJob", + TaskName: "consul-vault-task", + VaultRole: "role-from-spec-group", + ExtraClaims: map[string]string{ + "nomad_workload_id": "global:default:parentJob", + }, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:group:consul-vault-task:vault_default", Audience: jwt.Audience{"vault.io"}, }, - ExtraClaims: map[string]string{ - "nomad_workload_id": "global:default:parentJob", - }, }, // Use task-level Consul namespace for task services. "job/group/consul-vault-task/services/consul-vault-task-service": { - ConsulNamespace: "task-consul-namespace", - Namespace: "default", - JobID: "parentJob", - ServiceName: "consul-vault-task-service", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + ConsulNamespace: "task-consul-namespace", + Namespace: "default", + JobID: "parentJob", + ServiceName: "consul-vault-task-service", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:group:consul-vault-task-service:consul-service_consul-vault-task-service-http", Audience: jwt.Audience{"consul.io"}, }, - ExtraClaims: map[string]string{}, }, // group: with consul. // Use group-level Consul namespace for group services. "job/consul-group/services/group-service": { - ConsulNamespace: "group-consul-namespace", - Namespace: "default", - JobID: "parentJob", - ServiceName: "group-service", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + ConsulNamespace: "group-consul-namespace", + Namespace: "default", + JobID: "parentJob", + ServiceName: "group-service", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:consul-group:group-service:consul-service_group-service-http", Audience: jwt.Audience{"group-service.consul.io"}, }, - ExtraClaims: map[string]string{}, }, // group: with consul. // task: no consul, no vault. 
"job/consul-group/task/default-identity": { - Namespace: "default", - JobID: "parentJob", - TaskName: "task", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + Namespace: "default", + JobID: "parentJob", + TaskName: "task", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:consul-group:task:default-identity", Audience: jwt.Audience{"example.com"}, }, - ExtraClaims: map[string]string{}, }, "job/consul-group/task/alt-identity": { - Namespace: "default", - JobID: "parentJob", - TaskName: "task", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + Namespace: "default", + JobID: "parentJob", + TaskName: "task", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:consul-group:task:alt-identity", Audience: jwt.Audience{"alt.example.com"}, }, - ExtraClaims: map[string]string{}, }, // Use group-level Consul namespace because task doesn't have a consul // block. "job/consul-group/task/consul_default": { - ConsulNamespace: "group-consul-namespace", - Namespace: "default", - JobID: "parentJob", - TaskName: "task", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + ConsulNamespace: "group-consul-namespace", + Namespace: "default", + JobID: "parentJob", + TaskName: "task", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:consul-group:task:consul_default", Audience: jwt.Audience{"consul.io"}, }, - ExtraClaims: map[string]string{}, }, "job/consul-group/task/vault_default": { - Namespace: "default", - JobID: "parentJob", - TaskName: "task", - VaultRole: "", // not specified in jobspec + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + Namespace: "default", + JobID: "parentJob", + TaskName: "task", + VaultRole: "", // not specified in jobspec + ExtraClaims: map[string]string{ + "nomad_workload_id": "global:default:parentJob", + }, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:consul-group:task:vault_default", Audience: jwt.Audience{"vault.io"}, }, - ExtraClaims: map[string]string{ - "nomad_workload_id": "global:default:parentJob", - }, }, // Use group-level Consul namespace for task service because task // doesn't have a consul block. "job/consul-group/task/services/task-service": { - ConsulNamespace: "group-consul-namespace", - Namespace: "default", - JobID: "parentJob", - ServiceName: "task-service", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + ConsulNamespace: "group-consul-namespace", + Namespace: "default", + JobID: "parentJob", + ServiceName: "task-service", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:consul-group:task-service:consul-service_task-task-service-http", Audience: jwt.Audience{"task-service.consul.io"}, }, - ExtraClaims: map[string]string{}, }, // group: no consul. // task: with consul, with vault. "job/consul-group/consul-vault-task/default-identity": { - Namespace: "default", - JobID: "parentJob", - TaskName: "consul-vault-task", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + Namespace: "default", + JobID: "parentJob", + TaskName: "consul-vault-task", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:consul-group:consul-vault-task:default-identity", Audience: jwt.Audience{"example.com"}, }, - ExtraClaims: map[string]string{}, }, // Use task-level Consul namespace. 
"job/consul-group/consul-vault-task/consul_default": { - ConsulNamespace: "task-consul-namespace", - Namespace: "default", - JobID: "parentJob", - TaskName: "consul-vault-task", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + ConsulNamespace: "task-consul-namespace", + Namespace: "default", + JobID: "parentJob", + TaskName: "consul-vault-task", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:consul-group:consul-vault-task:consul_default", Audience: jwt.Audience{"consul.io"}, }, - ExtraClaims: map[string]string{}, }, "job/consul-group/consul-vault-task/vault_default": { - VaultNamespace: "vault-namespace", - Namespace: "default", - JobID: "parentJob", - TaskName: "consul-vault-task", - VaultRole: "role-from-spec-consul-group", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + VaultNamespace: "vault-namespace", + Namespace: "default", + JobID: "parentJob", + TaskName: "consul-vault-task", + VaultRole: "role-from-spec-consul-group", + ExtraClaims: map[string]string{ + "nomad_workload_id": "global:default:parentJob", + }, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:consul-group:consul-vault-task:vault_default", Audience: jwt.Audience{"vault.io"}, }, - ExtraClaims: map[string]string{ - "nomad_workload_id": "global:default:parentJob", - }, }, // Use task-level Consul namespace for task services. "job/consul-group/consul-vault-task/services/consul-task-service": { - ConsulNamespace: "task-consul-namespace", - Namespace: "default", - JobID: "parentJob", - ServiceName: "consul-task-service", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + ConsulNamespace: "task-consul-namespace", + Namespace: "default", + JobID: "parentJob", + ServiceName: "consul-task-service", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:consul-group:consul-task-service:consul-service_consul-vault-task-consul-task-service-http", Audience: jwt.Audience{"consul.io"}, }, - ExtraClaims: map[string]string{}, }, "job/group/consul-vault-task/services/consul-task-service": { - ConsulNamespace: "task-consul-namespace", - Namespace: "default", - JobID: "parentJob", - ServiceName: "consul-task-service", + WorkloadIdentityClaims: &WorkloadIdentityClaims{ + ConsulNamespace: "task-consul-namespace", + Namespace: "default", + JobID: "parentJob", + ServiceName: "consul-task-service", + ExtraClaims: map[string]string{}, + }, Claims: jwt.Claims{ Subject: "global:default:parentJob:group:consul-task-service:consul-service_consul-vault-task-consul-task-service-http", Audience: jwt.Audience{"task-service.consul.io"}, }, - ExtraClaims: map[string]string{}, }, } @@ -625,7 +667,7 @@ func TestWorkloadIdentity_Validate(t *testing.T) { In: WorkloadIdentity{}, Exp: WorkloadIdentity{ Name: WorkloadIdentityDefaultName, - Audience: []string{WorkloadIdentityDefaultAud}, + Audience: []string{IdentityDefaultAud}, }, }, { @@ -635,7 +677,7 @@ func TestWorkloadIdentity_Validate(t *testing.T) { }, Exp: WorkloadIdentity{ Name: WorkloadIdentityDefaultName, - Audience: []string{WorkloadIdentityDefaultAud}, + Audience: []string{IdentityDefaultAud}, }, }, { From 4a440d0b0eae0d170a49c96e494be1184fca8ea8 Mon Sep 17 00:00:00 2001 From: James Rasell Date: Thu, 26 Jun 2025 07:43:35 +0100 Subject: [PATCH 2/7] fsm: Add identity signing key handling for node status updates. (#26139) When a node heartbeats, the RPC handler will optionally generate an identity to return to the caller. 
The identity key ID will be stored in the node object, so we have tracking of keys in use. The state store has been updated to handle node status update requests that include a signing key ID. Rather than add another parameter into the function signature, the FSM function now takes the entire request object. --- nomad/fsm.go | 2 +- nomad/node_endpoint_test.go | 23 +++++++++++++----- nomad/state/events_test.go | 2 +- nomad/state/state_store.go | 33 ++++++++++++++++++-------- nomad/state/state_store_test.go | 41 ++++++++++++++++++++++++++++++--- nomad/structs/structs.go | 14 +++++++++-- 6 files changed, 92 insertions(+), 23 deletions(-) diff --git a/nomad/fsm.go b/nomad/fsm.go index 758b8321f..2b668764e 100644 --- a/nomad/fsm.go +++ b/nomad/fsm.go @@ -498,7 +498,7 @@ func (n *nomadFSM) applyStatusUpdate(msgType structs.MessageType, buf []byte, in panic(fmt.Errorf("failed to decode request: %v", err)) } - if err := n.state.UpdateNodeStatus(msgType, index, req.NodeID, req.Status, req.UpdatedAt, req.NodeEvent); err != nil { + if err := n.state.UpdateNodeStatus(msgType, index, &req); err != nil { n.logger.Error("UpdateNodeStatus failed", "error", err) return err } diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index a69ae4470..0fae446de 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -2958,9 +2958,12 @@ func TestNode_UpdateAlloc_NodeNotReady(t *testing.T) { must.NoError(t, store.UpsertJobSummary(99, mock.JobSummary(alloc.JobID))) must.NoError(t, store.UpsertAllocs(structs.MsgTypeTestSetup, 100, []*structs.Allocation{alloc})) - // Mark node as down. - must.NoError(t, store.UpdateNodeStatus( - structs.MsgTypeTestSetup, 101, node.ID, structs.NodeStatusDown, time.Now().UnixNano(), nil)) + downReq := structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusDown, + UpdatedAt: time.Now().UnixNano(), + } + must.NoError(t, store.UpdateNodeStatus(structs.MsgTypeTestSetup, 101, &downReq)) // Try to update alloc. updatedAlloc := new(structs.Allocation) @@ -2991,8 +2994,12 @@ func TestNode_UpdateAlloc_NodeNotReady(t *testing.T) { must.ErrorContains(t, err, "not found") // Mark node as ready and try again. 
- must.NoError(t, store.UpdateNodeStatus( - structs.MsgTypeTestSetup, 102, node.ID, structs.NodeStatusReady, time.Now().UnixNano(), nil)) + readyReq := structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusReady, + UpdatedAt: time.Now().UnixNano(), + } + must.NoError(t, store.UpdateNodeStatus(structs.MsgTypeTestSetup, 102, &readyReq)) updatedAlloc.NodeID = node.ID must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateAlloc", allocUpdateReq, &allocUpdateResp)) @@ -3752,8 +3759,12 @@ func TestClientEndpoint_ListNodes_Blocking(t *testing.T) { } // Node status update triggers watches + triggerReq := structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusDown, + } time.AfterFunc(100*time.Millisecond, func() { - errCh <- state.UpdateNodeStatus(structs.MsgTypeTestSetup, 40, node.ID, structs.NodeStatusDown, 0, nil) + errCh <- state.UpdateNodeStatus(structs.MsgTypeTestSetup, 40, &triggerReq) }) req.MinQueryIndex = 38 diff --git a/nomad/state/events_test.go b/nomad/state/events_test.go index a9dc877bb..3ce793ca0 100644 --- a/nomad/state/events_test.go +++ b/nomad/state/events_test.go @@ -377,7 +377,7 @@ func TestEventsFromChanges_NodeUpdateStatusRequest(t *testing.T) { NodeEvent: &structs.NodeEvent{Message: "down"}, } - must.NoError(t, s.UpdateNodeStatus(msgType, 100, req.NodeID, req.Status, req.UpdatedAt, req.NodeEvent)) + must.NoError(t, s.UpdateNodeStatus(msgType, 100, req)) events := WaitForEvents(t, s, 100, 1, 1*time.Second) must.Len(t, 1, events) diff --git a/nomad/state/state_store.go b/nomad/state/state_store.go index 2b5a005d5..e51d29a68 100644 --- a/nomad/state/state_store.go +++ b/nomad/state/state_store.go @@ -1097,21 +1097,26 @@ func (s *StateStore) deleteNodeTxn(txn *txn, index uint64, nodes []string) error } // UpdateNodeStatus is used to update the status of a node -func (s *StateStore) UpdateNodeStatus(msgType structs.MessageType, index uint64, nodeID, status string, updatedAt int64, event *structs.NodeEvent) error { +func (s *StateStore) UpdateNodeStatus( + msgType structs.MessageType, + index uint64, + req *structs.NodeUpdateStatusRequest, +) error { + txn := s.db.WriteTxnMsgT(msgType, index) defer txn.Abort() - if err := s.updateNodeStatusTxn(txn, nodeID, status, updatedAt, event); err != nil { + if err := s.updateNodeStatusTxn(txn, req); err != nil { return err } return txn.Commit() } -func (s *StateStore) updateNodeStatusTxn(txn *txn, nodeID, status string, updatedAt int64, event *structs.NodeEvent) error { +func (s *StateStore) updateNodeStatusTxn(txn *txn, req *structs.NodeUpdateStatusRequest) error { // Lookup the node - existing, err := txn.First("nodes", "id", nodeID) + existing, err := txn.First(TableNodes, indexID, req.NodeID) if err != nil { return fmt.Errorf("node lookup failed: %v", err) } @@ -1122,15 +1127,23 @@ func (s *StateStore) updateNodeStatusTxn(txn *txn, nodeID, status string, update // Copy the existing node existingNode := existing.(*structs.Node) copyNode := existingNode.Copy() - copyNode.StatusUpdatedAt = updatedAt + copyNode.StatusUpdatedAt = req.UpdatedAt + + // If the request has a signing key ID, we should update the node reference + // to this. We need to check for the empty string, as a new identity won't + // always be generated, and we don't want to overwrite the exiting entry + // with an empty string. 
+ if req.IdentitySigningKeyID != "" { + copyNode.IdentitySigningKeyID = req.IdentitySigningKeyID + } // Add the event if given - if event != nil { - appendNodeEvents(txn.Index, copyNode, []*structs.NodeEvent{event}) + if req.NodeEvent != nil { + appendNodeEvents(txn.Index, copyNode, []*structs.NodeEvent{req.NodeEvent}) } // Update the status in the copy - copyNode.Status = status + copyNode.Status = req.Status copyNode.ModifyIndex = txn.Index // Update last missed heartbeat if the node became unresponsive or reset it @@ -1143,10 +1156,10 @@ func (s *StateStore) updateNodeStatusTxn(txn *txn, nodeID, status string, update } // Insert the node - if err := txn.Insert("nodes", copyNode); err != nil { + if err := txn.Insert(TableNodes, copyNode); err != nil { return fmt.Errorf("node update failed: %v", err) } - if err := txn.Insert("index", &IndexEntry{"nodes", txn.Index}); err != nil { + if err := txn.Insert(tableIndex, &IndexEntry{TableNodes, txn.Index}); err != nil { return fmt.Errorf("index update failed: %v", err) } diff --git a/nomad/state/state_store_test.go b/nomad/state/state_store_test.go index a9ae00a8f..5774c6396 100644 --- a/nomad/state/state_store_test.go +++ b/nomad/state/state_store_test.go @@ -1497,7 +1497,17 @@ func TestStateStore_UpdateNodeStatus_Node(t *testing.T) { Timestamp: time.Now(), } - must.NoError(t, state.UpdateNodeStatus(structs.MsgTypeTestSetup, 801, node.ID, structs.NodeStatusReady, 70, event)) + signingKeyID := uuid.Generate() + + stateReq := structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusReady, + IdentitySigningKeyID: signingKeyID, + NodeEvent: event, + UpdatedAt: 70, + } + + must.NoError(t, state.UpdateNodeStatus(structs.MsgTypeTestSetup, 801, &stateReq)) must.True(t, watchFired(ws)) ws = memdb.NewWatchSet() @@ -1508,11 +1518,31 @@ func TestStateStore_UpdateNodeStatus_Node(t *testing.T) { must.Eq(t, 70, out.StatusUpdatedAt) must.Len(t, 2, out.Events) must.Eq(t, event.Message, out.Events[1].Message) + must.Eq(t, signingKeyID, out.IdentitySigningKeyID) - index, err := state.Index("nodes") + index, err := state.Index(TableNodes) must.NoError(t, err) must.Eq(t, 801, index) must.False(t, watchFired(ws)) + + // Send another update, but the signing key ID is empty, this should not + // overwrite the existing signing key ID. 
+	stateReq = structs.NodeUpdateStatusRequest{
+		NodeID:               node.ID,
+		Status:               structs.NodeStatusReady,
+		IdentitySigningKeyID: "",
+		NodeEvent: &structs.NodeEvent{
+			Message:   "Node even more ready foo",
+			Subsystem: structs.NodeEventSubsystemCluster,
+			Timestamp: time.Now(),
+		},
+		UpdatedAt: 80,
+	}
+
+	must.NoError(t, state.UpdateNodeStatus(structs.MsgTypeTestSetup, 802, &stateReq))
+	out, err = state.NodeByID(ws, node.ID)
+	must.NoError(t, err)
+	must.Eq(t, signingKeyID, out.IdentitySigningKeyID)
 }
 
 func TestStatStore_UpdateNodeStatus_LastMissedHeartbeatIndex(t *testing.T) {
@@ -1598,7 +1628,12 @@ func TestStatStore_UpdateNodeStatus_LastMissedHeartbeatIndex(t *testing.T) {
 
 	for i, status := range tc.transitions {
 		now := time.Now().UnixNano()
-		err := state.UpdateNodeStatus(structs.MsgTypeTestSetup, uint64(1000+i), node.ID, status, now, nil)
+		req := structs.NodeUpdateStatusRequest{
+			NodeID:    node.ID,
+			Status:    status,
+			UpdatedAt: now,
+		}
+		err := state.UpdateNodeStatus(structs.MsgTypeTestSetup, uint64(1000+i), &req)
 		must.NoError(t, err)
 
 		ws := memdb.NewWatchSet()
diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go
index bd576f5cf..aceb5fcfb 100644
--- a/nomad/structs/structs.go
+++ b/nomad/structs/structs.go
@@ -643,8 +643,18 @@ type NodeServerInfo struct {
 // NodeUpdateStatusRequest is used for Node.UpdateStatus endpoint
 // to update the status of a node.
 type NodeUpdateStatusRequest struct {
-	NodeID string
-	Status string
+	NodeID string
+	Status string
+
+	// IdentitySigningKeyID is the ID of the root key used to sign the node's
+	// identity. This is not provided by the client, but is set by the server,
+	// so that the value can be propagated through Raft.
+	IdentitySigningKeyID string
+
+	// ForceIdentityRenewal is used to force the Nomad server to generate a new
+	// identity for the node.
+	ForceIdentityRenewal bool
+
 	NodeEvent *NodeEvent
 	UpdatedAt int64
 	WriteRequest

From 4393c0e76ba76bae2dca81ddfb2621bf69eb10f1 Mon Sep 17 00:00:00 2001
From: James Rasell
Date: Fri, 27 Jun 2025 14:59:23 +0100
Subject: [PATCH 3/7] auth: Add authentication support for node identities.
 (#26148)

The authenticator process which performs RPC authentication has been
modified to support node identities. Node identities are verified by
ensuring the claimed node ID corresponds to a node written to Nomad
state.

The client-only and generic authenticate methods now support both node
secret IDs and node identities. They use a UUID check to distinguish
between the two options.

A new method has also been added to handle the specific RPCs that will
optionally generate node identities. While a new authenticator method
is not ideal, it is better than the alternative of having these RPCs
perform complex additional RPC context work to determine whether an
identity should be generated.

The TLS verification functionality has been pulled into its own method
to avoid further code duplication.
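For illustration, the UUID-versus-JWT branching described above can be
sketched as a small standalone Go program. This is a hedged sketch, not
the Nomad implementation: the uuidRe pattern and the classifyAuthToken
helper are assumptions standing in for the internal helper.IsUUID check
and the authenticator's dispatch, which additionally resolves UUIDs
against state (node secret ID or leader ACL token) and verifies
non-UUID tokens as signed identity claims.

package main

import (
	"fmt"
	"regexp"
)

// uuidRe matches the canonical 8-4-4-4-12 UUID layout used for node
// secret IDs; anything that does not match is assumed to be a JWT.
var uuidRe = regexp.MustCompile(
	`^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$`)

// classifyAuthToken mirrors the branching described above: an empty
// token is anonymous, a UUID is treated as a node secret ID (or the
// leader ACL token), and anything else is treated as an identity claim
// that must be verified as a signed JWT.
func classifyAuthToken(token string) string {
	switch {
	case token == "":
		return "anonymous"
	case uuidRe.MatchString(token):
		return "node-secret-or-leader-acl"
	default:
		return "identity-claim"
	}
}

func main() {
	fmt.Println(classifyAuthToken(""))                                     // anonymous
	fmt.Println(classifyAuthToken("9f2c1a7e-3d44-4b2a-9a1e-6c0f1b2d3e4f")) // node-secret-or-leader-acl
	fmt.Println(classifyAuthToken("eyJhbGciOiJSUzI1NiJ9.payload.sig"))     // identity-claim
}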
--- nomad/auth/auth.go | 227 ++++++++++++++++----- nomad/auth/auth_test.go | 411 ++++++++++++++++++++++++++++++++++++++- nomad/structs/structs.go | 5 +- 3 files changed, 586 insertions(+), 57 deletions(-) diff --git a/nomad/auth/auth.go b/nomad/auth/auth.go index 9435a2e61..1e412961c 100644 --- a/nomad/auth/auth.go +++ b/nomad/auth/auth.go @@ -268,35 +268,86 @@ func (s *Authenticator) AuthenticateServerOnly(ctx RPCContext, args structs.Requ identity := &structs.AuthenticatedIdentity{RemoteIP: remoteIP} defer args.SetIdentity(identity) // always set the identity, even on errors - if s.verifyTLS.Load() && !ctx.IsStatic() { - tlsCert := ctx.Certificate() - if tlsCert == nil { - return nil, errors.New("missing certificate information") - } - - // set on the identity whether or not its valid for server RPC, so we - // can capture it for metrics - identity.TLSName = tlsCert.Subject.CommonName - _, err := validateCertificateForNames(tlsCert, s.validServerCertNames) - if err != nil { - return nil, err - } - return acl.ServerACL, nil - } - // Note: if servers had auth tokens like clients do, we would be able to // verify them here and only return the server ACL for actual servers even // if mTLS was disabled. Without mTLS, any request can spoof server RPCs. // This is known and documented in the Security Model: // https://developer.hashicorp.com/nomad/docs/concepts/security#requirements + if err := verifyTLS(s.verifyTLS.Load(), ctx, s.validServerCertNames, identity); err != nil { + return nil, err + } + return acl.ServerACL, nil } +// AuthenticateNodeIdentityGenerator is used for RPC endpoints (Node.Register +// and Node.UpdateStatus) that have the potential to generate node identities. +// +// While the Authenticate method serves as a complete general purpose +// authenticator, in some critical cases for identity generation checking, it +// swallows the information needed. +func (s *Authenticator) AuthenticateNodeIdentityGenerator(ctx RPCContext, args structs.RequestWithIdentity) error { + + remoteIP, err := ctx.GetRemoteIP() // capture for metrics + if err != nil { + s.logger.Error("could not determine remote address", "error", err) + } + + identity := &structs.AuthenticatedIdentity{RemoteIP: remoteIP} + defer args.SetIdentity(identity) + + if err := verifyTLS(s.verifyTLS.Load(), ctx, s.validClientCertNames, identity); err != nil { + return err + } + + authToken := args.GetAuthToken() + + // If the auth token is empty, we treat it as an anonymous request. In the + // event of a node registration, this means the node is not yet registered. + if authToken == "" { + identity.ACLToken = structs.AnonymousACLToken + return nil + } + + // If the auth token is a UUID, we check whether it's a node secret ID or + // the leader's ACL token. If it's not a UUID, we assume it's a node + // identity. Anything outside these cases is not supported and no identity + // will be set. + if helper.IsUUID(authToken) { + if leaderAcl := s.getLeaderACL(); leaderAcl != "" && authToken == leaderAcl { + identity.ACLToken = structs.LeaderACLToken + } else { + node, err := s.getState().NodeBySecretID(nil, authToken) + if err != nil { + return fmt.Errorf("could not resolve node secret: %w", err) + } + if node == nil { + return structs.ErrPermissionDenied + } + identity.ClientID = node.ID + } + } else { + // When verifying a node identity claim, we do not want to swallow the + // initial error. This is because the caller may want to handle the + // error type in the case that the JWT is expired. 
+ claims, err := s.VerifyClaim(authToken) + if err != nil { + return err + } + if !claims.IsNode() { + return structs.ErrPermissionDenied + } + identity.Claims = claims + } + return nil +} + // AuthenticateClientOnly returns an ACL object for use *only* with internal // RPCs originating from clients (including those forwarded). This should never // be used for RPCs that serve HTTP endpoints to avoid confused deputy attacks // by making a request to a client that's forwarded. It should also not be used -// with Node.Register, which should use AuthenticateClientTOFU +// with Node.Register or NodeUpdateStatus, which should use +// AuthenticateNodeIdentityGenerator. // // The returned ACL object is always a acl.ClientACL but in the future this // could be extended to allow clients access only to their own pool and @@ -311,40 +362,70 @@ func (s *Authenticator) AuthenticateClientOnly(ctx RPCContext, args structs.Requ identity := &structs.AuthenticatedIdentity{RemoteIP: remoteIP} defer args.SetIdentity(identity) // always set the identity, even on errors - if s.verifyTLS.Load() && !ctx.IsStatic() { - tlsCert := ctx.Certificate() - if tlsCert == nil { - return nil, errors.New("missing certificate information") - } + if err := verifyTLS(s.verifyTLS.Load(), ctx, s.validClientCertNames, identity); err != nil { + return nil, err + } - // set on the identity whether or not its valid for server RPC, so we - // can capture it for metrics - identity.TLSName = tlsCert.Subject.CommonName - _, err := validateCertificateForNames(tlsCert, s.validClientCertNames) + authToken := args.GetAuthToken() + if authToken == "" { + return nil, structs.ErrPermissionDenied + } + + // If the auth token is a UUID, we treat it as a node secret ID. Otherwise, + // we assume it's a node identity claim. Anything outside these cases is not + // permitted when using this method. + if helper.IsUUID(authToken) { + node, err := s.getState().NodeBySecretID(nil, authToken) + if err != nil { + return nil, fmt.Errorf("could not resolve node secret: %w", err) + } + if node == nil { + return nil, structs.ErrPermissionDenied + } + identity.ClientID = node.ID + } else { + claims, err := s.VerifyClaim(authToken) if err != nil { return nil, err } + if !claims.IsNode() { + return nil, structs.ErrPermissionDenied + } + identity.ClientID = claims.NodeIdentityClaims.NodeID + identity.Claims = claims } - secretID := args.GetAuthToken() - if secretID == "" { - return nil, structs.ErrPermissionDenied - } - - // Otherwise, see if the secret ID belongs to a node. We should - // reach this point only on first connection. - node, err := s.getState().NodeBySecretID(nil, secretID) - if err != nil { - // this is a go-memdb error; shouldn't happen - return nil, fmt.Errorf("could not resolve node secret: %w", err) - } - if node == nil { - return nil, structs.ErrPermissionDenied - } - identity.ClientID = node.ID return acl.ClientACL, nil } +// verifyTLS is a helper function that performs TLS verification, if required, +// given the passed RPCContext and valid names. +// +// It will always set the TLSName on the identity if we are performing +// verification, so callers don't have to worry about setting it themselves. 
+func verifyTLS(verify bool, ctx RPCContext, validNames []string, identity *structs.AuthenticatedIdentity) error { + + if verify && !ctx.IsStatic() { + + tlsCert := ctx.Certificate() + if tlsCert == nil { + return errors.New("missing certificate information") + } + + // Always set on the identity, even before validating the name, so we + // can capture it for metrics. + identity.TLSName = tlsCert.Subject.CommonName + + // Perform the certificate validation, using the passed valid names. + _, err := validateCertificateForNames(tlsCert, validNames) + if err != nil { + return err + } + } + + return nil +} + // validateCertificateForNames returns true if the certificate is valid for any // of the given domain names. func validateCertificateForNames(cert *x509.Certificate, expectedNames []string) (bool, error) { @@ -432,37 +513,83 @@ func (s *Authenticator) ResolveToken(secretID string) (*acl.ACL, error) { return resolveTokenFromSnapshotCache(snap, s.aclCache, secretID) } -// VerifyClaim asserts that the token is valid and that the resulting allocation -// ID belongs to a non-terminal allocation. This should usually not be called by -// RPC handlers, and exists only to support the ACL.WhoAmI endpoint. +// VerifyClaim asserts that the token is valid. If it is for a workload +// identity, it will ensure that the resulting allocation ID belongs to a +// non-terminal allocation. If the token is for a node identity, it will ensure +// the node ID matches the claim. +// +// This should usually not be called by RPC handlers. func (s *Authenticator) VerifyClaim(token string) (*structs.IdentityClaims, error) { claims, err := s.encrypter.VerifyClaim(token) if err != nil { return nil, err } + + if claims.IsWorkload() { + if err := s.verifyWorkloadIdentityClaim(claims); err != nil { + return nil, err + } + return claims, nil + } + + if claims.IsNode() { + if err := s.verifyNodeIdentityClaim(claims); err != nil { + return nil, err + } + return claims, nil + } + + return nil, errors.New("failed to determine claim type") +} + +func (s *Authenticator) verifyWorkloadIdentityClaim(claims *structs.IdentityClaims) error { snap, err := s.getState().Snapshot() if err != nil { - return nil, err + return err } alloc, err := snap.AllocByID(nil, claims.AllocationID) if err != nil { - return nil, err + return err } if alloc == nil || alloc.Job == nil { - return nil, fmt.Errorf("allocation does not exist") + return fmt.Errorf("allocation does not exist") } // the claims for terminal allocs are always treated as expired if alloc.ClientTerminalStatus() { - return nil, fmt.Errorf("allocation is terminal") + return fmt.Errorf("allocation is terminal") } - return claims, nil + return nil +} + +func (s *Authenticator) verifyNodeIdentityClaim(claims *structs.IdentityClaims) error { + + snap, err := s.getState().Snapshot() + if err != nil { + return err + } + node, err := snap.NodeByID(nil, claims.NodeIdentityClaims.NodeID) + if err != nil { + return err + } + if node == nil { + return errors.New("node does not exist") + } + + return nil } func (s *Authenticator) resolveClaims(claims *structs.IdentityClaims) (*acl.ACL, error) { + // Nomad node identity claims currently map to a client ACL. If we open this + // up in the future, we will want to modify this section to perform similar + // work that is done for workload claims. 
+ if claims.IsNode() { + return acl.ClientACL, nil + } + policies, err := s.ResolvePoliciesForClaims(claims) if err != nil { return nil, err diff --git a/nomad/auth/auth_test.go b/nomad/auth/auth_test.go index e56f50e62..b3059f6b7 100644 --- a/nomad/auth/auth_test.go +++ b/nomad/auth/auth_test.go @@ -359,6 +359,50 @@ func TestAuthenticateDefault(t *testing.T) { must.True(t, aclObj.IsManagement()) }, }, + { + name: "mTLS and ACLs with node identity", + testFn: func(t *testing.T, store *state.StateStore) { + + node := mock.Node() + must.NoError(t, store.UpsertNode(structs.MsgTypeTestSetup, 1000, node)) + + claims := structs.GenerateNodeIdentityClaims(node, "global", 1*time.Hour) + + auth := testAuthenticator(t, store, true, true) + token, err := auth.encrypter.(*testEncrypter).signClaim(claims) + must.NoError(t, err) + + args := &structs.GenericRequest{} + args.AuthToken = token + var ctx *testContext + + must.NoError(t, auth.Authenticate(ctx, args)) + must.Eq(t, "client:"+node.ID, args.GetIdentity().String()) + + aclObj, err := auth.ResolveACL(args) + must.NoError(t, err) + must.Eq(t, acl.ClientACL, aclObj) + }, + }, + { + name: "mTLS and ACLs with invalid node identity", + testFn: func(t *testing.T, store *state.StateStore) { + + node := mock.Node() + + claims := structs.GenerateNodeIdentityClaims(node, "global", 1*time.Hour) + + auth := testAuthenticator(t, store, true, true) + token, err := auth.encrypter.(*testEncrypter).signClaim(claims) + must.NoError(t, err) + + args := &structs.GenericRequest{} + args.AuthToken = token + var ctx *testContext + + must.ErrorContains(t, auth.Authenticate(ctx, args), "node does not exist") + }, + }, } for _, tc := range testCases { @@ -464,6 +508,253 @@ func TestAuthenticateServerOnly(t *testing.T) { } } +func TestAuthenticator_AuthenticateClientRegistration(t *testing.T) { + ci.Parallel(t) + + testAuthenticator := func( + t *testing.T, + store *state.StateStore, + hasACLs, + verifyTLS bool, + ) *Authenticator { + + leaderACL := uuid.Generate() + + return NewAuthenticator(&AuthenticatorConfig{ + StateFn: func() *state.StateStore { return store }, + Logger: testlog.HCLogger(t), + GetLeaderACLFn: func() string { return leaderACL }, + AclsEnabled: hasACLs, + VerifyTLS: verifyTLS, + Region: "global", + Encrypter: newTestEncrypter(), + }) + } + + testCases := []struct { + name string + testFn func(*testing.T, *state.StateStore) + }{ + { + name: "incorrect mTLS", + testFn: func(t *testing.T, store *state.StateStore) { + ctx := newTestContext(t, "pony.global.nomad", "192.168.1.1") + + args := structs.GenericRequest{} + + auth := testAuthenticator(t, store, false, true) + must.ErrorContains(t, auth.AuthenticateNodeIdentityGenerator(ctx, &args), "invalid certificate") + }, + }, + { + name: "client mTLS with no auth", + testFn: func(t *testing.T, store *state.StateStore) { + ctx := newTestContext(t, "client.global.nomad", "192.168.1.1") + + args := structs.GenericRequest{} + + auth := testAuthenticator(t, store, false, true) + must.NoError(t, auth.AuthenticateNodeIdentityGenerator(ctx, &args)) + + aclObj, err := auth.ResolveACL(&args) + must.NoError(t, err) + must.True(t, aclObj.AllowClientOp()) + }, + }, + { + name: "no mTLS no acl with no auth", + testFn: func(t *testing.T, store *state.StateStore) { + ctx := newTestContext(t, noTLSCtx, "192.168.1.1") + + args := structs.GenericRequest{} + + auth := testAuthenticator(t, store, false, false) + must.Nil(t, auth.AuthenticateNodeIdentityGenerator(ctx, &args)) + + aclObj, err := auth.ResolveACL(&args) + 
must.NoError(t, err) + must.False(t, aclObj.AllowServerOp()) + must.False(t, aclObj.AllowServerOp()) + }, + }, + { + name: "no mTLS acl with no auth", + testFn: func(t *testing.T, store *state.StateStore) { + ctx := newTestContext(t, noTLSCtx, "192.168.1.1") + + args := structs.GenericRequest{} + + auth := testAuthenticator(t, store, true, false) + must.Nil(t, auth.AuthenticateNodeIdentityGenerator(ctx, &args)) + + aclObj, err := auth.ResolveACL(&args) + must.NoError(t, err) + must.False(t, aclObj.AllowServerOp()) + must.False(t, aclObj.AllowServerOp()) + }, + }, + { + name: "no mTLS no acl with server leader token auth", + testFn: func(t *testing.T, store *state.StateStore) { + + auth := testAuthenticator(t, store, false, false) + + ctx := newTestContext(t, noTLSCtx, "192.168.1.1") + + args := structs.GenericRequest{ + QueryOptions: structs.QueryOptions{ + AuthToken: auth.getLeaderACL(), + }, + } + + must.NoError(t, auth.AuthenticateNodeIdentityGenerator(ctx, &args)) + aclObj, err := auth.ResolveACL(&args) + must.NoError(t, err) + must.True(t, aclObj.AllowServerOp() || aclObj.AllowClientOp()) + }, + }, + { + name: "mTLS acl with server leader token auth", + testFn: func(t *testing.T, store *state.StateStore) { + + auth := testAuthenticator(t, store, true, true) + + ctx := newTestContext(t, "server.global.nomad", "192.168.1.1") + + args := structs.GenericRequest{ + QueryOptions: structs.QueryOptions{ + AuthToken: auth.getLeaderACL(), + }, + } + + must.NoError(t, auth.AuthenticateNodeIdentityGenerator(ctx, &args)) + aclObj, err := auth.ResolveACL(&args) + must.NoError(t, err) + must.True(t, aclObj.AllowServerOp()) + }, + }, + { + name: "mTLS no acl with server leader token auth", + testFn: func(t *testing.T, store *state.StateStore) { + + auth := testAuthenticator(t, store, false, true) + + ctx := newTestContext(t, "server.global.nomad", "192.168.1.1") + + args := structs.GenericRequest{ + QueryOptions: structs.QueryOptions{ + AuthToken: auth.getLeaderACL(), + }, + } + + must.NoError(t, auth.AuthenticateNodeIdentityGenerator(ctx, &args)) + aclObj, err := auth.ResolveACL(&args) + must.NoError(t, err) + must.True(t, aclObj.AllowClientOp()) + }, + }, + { + name: "mTLS no acl with node secret token auth", + testFn: func(t *testing.T, store *state.StateStore) { + + node := mock.Node() + must.NoError(t, store.UpsertNode(structs.MsgTypeTestSetup, 1000, node)) + + auth := testAuthenticator(t, store, false, true) + + ctx := newTestContext(t, "client.global.nomad", "192.168.1.1") + + args := structs.GenericRequest{ + QueryOptions: structs.QueryOptions{ + AuthToken: node.SecretID, + }, + } + + must.NoError(t, auth.AuthenticateNodeIdentityGenerator(ctx, &args)) + aclObj, err := auth.ResolveACL(&args) + must.NoError(t, err) + must.True(t, aclObj.AllowClientOp()) + }, + }, + { + name: "mTLS acl with node secret token auth", + testFn: func(t *testing.T, store *state.StateStore) { + + node := mock.Node() + must.NoError(t, store.UpsertNode(structs.MsgTypeTestSetup, 1000, node)) + + auth := testAuthenticator(t, store, true, true) + + ctx := newTestContext(t, "client.global.nomad", "192.168.1.1") + + args := structs.GenericRequest{ + QueryOptions: structs.QueryOptions{ + AuthToken: node.SecretID, + }, + } + + must.NoError(t, auth.AuthenticateNodeIdentityGenerator(ctx, &args)) + aclObj, err := auth.ResolveACL(&args) + must.NoError(t, err) + must.True(t, aclObj.AllowClientOp()) + }, + }, + { + name: "mTLS acl with bad node secret token auth", + testFn: func(t *testing.T, store *state.StateStore) { + + node := 
mock.Node() + must.NoError(t, store.UpsertNode(structs.MsgTypeTestSetup, 1000, node)) + + auth := testAuthenticator(t, store, true, true) + + ctx := newTestContext(t, "client.global.nomad", "192.168.1.1") + + args := structs.GenericRequest{ + QueryOptions: structs.QueryOptions{ + AuthToken: node.ID, + }, + } + + must.ErrorContains(t, auth.AuthenticateNodeIdentityGenerator(ctx, &args), "Permission denied") + }, + }, + { + name: "mTLS acl with node identity", + testFn: func(t *testing.T, store *state.StateStore) { + + node := mock.Node() + must.NoError(t, store.UpsertNode(structs.MsgTypeTestSetup, 1000, node)) + + claims := structs.GenerateNodeIdentityClaims(node, "global", 1*time.Hour) + + auth := testAuthenticator(t, store, true, true) + token, err := auth.encrypter.(*testEncrypter).signClaim(claims) + must.NoError(t, err) + + ctx := newTestContext(t, "client.global.nomad", "192.168.1.1") + + args := structs.GenericRequest{ + QueryOptions: structs.QueryOptions{ + AuthToken: token, + }, + } + + must.NoError(t, auth.AuthenticateNodeIdentityGenerator(ctx, &args)) + aclObj, err := auth.ResolveACL(&args) + must.NoError(t, err) + must.True(t, aclObj.AllowClientOp()) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.testFn(t, testStateStore(t)) + }) + } +} + func TestAuthenticateClientOnly(t *testing.T) { ci.Parallel(t) @@ -478,7 +769,7 @@ func TestAuthenticateClientOnly(t *testing.T) { AclsEnabled: hasACLs, VerifyTLS: verifyTLS, Region: "global", - Encrypter: nil, + Encrypter: newTestEncrypter(), }) } @@ -487,7 +778,7 @@ func TestAuthenticateClientOnly(t *testing.T) { testFn func(*testing.T, *state.StateStore, *structs.Node) }{ { - name: "no mTLS or ACLs but no node secret", + name: "no mTLS or ACLs but no auth token", testFn: func(t *testing.T, store *state.StateStore, node *structs.Node) { ctx := newTestContext(t, noTLSCtx, "192.168.1.1") args := &structs.GenericRequest{} @@ -535,7 +826,7 @@ func TestAuthenticateClientOnly(t *testing.T) { }, }, { - name: "no mTLS but with ACLs and bad secret", + name: "no mTLS but with ACLs and bad auth token", testFn: func(t *testing.T, store *state.StateStore, node *structs.Node) { ctx := newTestContext(t, noTLSCtx, "192.168.1.1") args := &structs.GenericRequest{} @@ -567,7 +858,7 @@ func TestAuthenticateClientOnly(t *testing.T) { }, }, { - name: "with mTLS and ACLs with server cert but bad token", + name: "with mTLS and ACLs with server cert but bad auth token", testFn: func(t *testing.T, store *state.StateStore, node *structs.Node) { ctx := newTestContext(t, "server.global.nomad", "192.168.1.1") args := &structs.GenericRequest{} @@ -583,7 +874,7 @@ func TestAuthenticateClientOnly(t *testing.T) { }, }, { - name: "with mTLS and ACLs with server cert and valid token", + name: "with mTLS and ACLs with server cert and valid secret ID token", testFn: func(t *testing.T, store *state.StateStore, node *structs.Node) { ctx := newTestContext(t, "server.global.nomad", "192.168.1.1") args := &structs.GenericRequest{} @@ -615,13 +906,82 @@ func TestAuthenticateClientOnly(t *testing.T) { must.True(t, aclObj.AllowClientOp()) }, }, + { + name: "with mTLS and ACLs with client cert and valid node identity", + testFn: func(t *testing.T, store *state.StateStore, node *structs.Node) { + ctx := newTestContext(t, "client.global.nomad", "192.168.1.1") + + auth := testAuthenticator(t, store, true, true) + + claims := structs.GenerateNodeIdentityClaims(node, "global", 1*time.Hour) + + token, err := 
auth.encrypter.(*testEncrypter).signClaim(claims) + must.NoError(t, err) + + args := &structs.GenericRequest{} + args.AuthToken = token + + aclObj, err := auth.AuthenticateClientOnly(ctx, args) + must.NoError(t, err) + + must.Eq(t, "client:"+node.ID, args.GetIdentity().String()) + must.NotNil(t, aclObj) + must.True(t, aclObj.AllowClientOp()) + }, + }, + { + name: "with mTLS and ACLs with server cert and valid node identity", + testFn: func(t *testing.T, store *state.StateStore, node *structs.Node) { + ctx := newTestContext(t, "server.global.nomad", "192.168.1.1") + + auth := testAuthenticator(t, store, true, true) + + claims := structs.GenerateNodeIdentityClaims(node, "global", 1*time.Hour) + + token, err := auth.encrypter.(*testEncrypter).signClaim(claims) + must.NoError(t, err) + + args := &structs.GenericRequest{} + args.AuthToken = token + + aclObj, err := auth.AuthenticateClientOnly(ctx, args) + must.NoError(t, err) + + must.Eq(t, "client:"+node.ID, args.GetIdentity().String()) + must.NotNil(t, aclObj) + must.True(t, aclObj.AllowClientOp()) + }, + }, + { + name: "with mTLS and ACLs with server cert and invalid node identity", + testFn: func(t *testing.T, store *state.StateStore, node *structs.Node) { + ctx := newTestContext(t, "server.global.nomad", "192.168.1.1") + + auth := testAuthenticator(t, store, true, true) + + copiedNode := node.Copy() + copiedNode.ID = uuid.Generate() + + claims := structs.GenerateNodeIdentityClaims(copiedNode, "global", 1*time.Hour) + + token, err := auth.encrypter.(*testEncrypter).signClaim(claims) + must.NoError(t, err) + + args := &structs.GenericRequest{} + args.AuthToken = token + + aclObj, err := auth.AuthenticateClientOnly(ctx, args) + must.Error(t, err) + must.Nil(t, aclObj) + }, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { node := mock.Node() store := testStateStore(t) - store.UpsertNode(structs.MsgTypeTestSetup, 100, node) + must.NoError(t, store.UpsertNode(structs.MsgTypeTestSetup, 100, node)) tc.testFn(t, store, node) }) } @@ -1178,6 +1538,45 @@ func TestResolveClaims(t *testing.T) { } +func TestAuthenticator_verifyNodeIdentityClaim(t *testing.T) { + ci.Parallel(t) + + // Create our base test objects including a node that can be used in the + // tests. 
+	testAuthenticator := testDefaultAuthenticator(t)
+
+	mockNode := mock.Node()
+	must.NoError(t, testAuthenticator.getState().UpsertNode(structs.MsgTypeTestSetup, 100, mockNode))
+
+	testCases := []struct {
+		name           string
+		inputClaims    *structs.IdentityClaims
+		expectedOutput error
+	}{
+		{
+			name:           "node does not exist",
+			inputClaims:    structs.GenerateNodeIdentityClaims(mock.Node(), "global", 1*time.Hour),
+			expectedOutput: errors.New("node does not exist"),
+		},
+		{
+			name:           "verified node claims",
+			inputClaims:    structs.GenerateNodeIdentityClaims(mockNode, "global", 1*time.Hour),
+			expectedOutput: nil,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			actualOutput := testAuthenticator.verifyNodeIdentityClaim(tc.inputClaims)
+			if tc.expectedOutput == nil {
+				must.NoError(t, actualOutput)
+			} else {
+				must.EqError(t, actualOutput, tc.expectedOutput.Error())
+			}
+		})
+	}
+}
+
 func testStateStore(t *testing.T) *state.StateStore {
 	sconfig := &state.StateStoreConfig{
 		Logger: testlog.HCLogger(t),
diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go
index aceb5fcfb..b88c20109 100644
--- a/nomad/structs/structs.go
+++ b/nomad/structs/structs.go
@@ -542,12 +542,15 @@ func (ai *AuthenticatedIdentity) String() string {
 	if ai.ACLToken != nil && ai.ACLToken != AnonymousACLToken {
 		return "token:" + ai.ACLToken.AccessorID
 	}
-	if ai.Claims != nil {
+	if ai.Claims != nil && ai.Claims.IsWorkload() {
 		return "alloc:" + ai.Claims.AllocationID
 	}
 	if ai.ClientID != "" {
 		return "client:" + ai.ClientID
 	}
+	if ai.Claims != nil && ai.Claims.IsNode() {
+		return "client:" + ai.Claims.NodeID
+	}
 	return ai.TLSName + ":" + ai.RemoteIP.String()
 }

From d5b2d5078b120b8ace704edece02ed916c204a27 Mon Sep 17 00:00:00 2001
From: James Rasell
Date: Tue, 1 Jul 2025 17:07:21 +0200
Subject: [PATCH 4/7] rpc: Generate node identities with node RPC handlers
 when needed. (#26165)

When a Nomad client registers or re-registers, the RPC handler will
generate and return a node identity if required. When an identity is
generated, the signing key ID will be stored within the node object, to
ensure a root key is not deleted while it is still in use.

During normal operation the client will periodically heartbeat to the
Nomad servers to indicate liveness. The RPC handler used for this
action has also been updated to conditionally perform identity
generation. Performing it here means no extra RPC handlers are required
and identity generation inherits the jitter of the heartbeat mechanism.

The identity generation checks are methods on the RPC request
arguments, so they are scoped to the required behaviour and can handle
the nuance of each RPC.

Failure to generate an identity is considered terminal to the RPC call.
The client includes behaviour to retry this error, which is always
caused by the encrypter not being ready unless the server's keyring has
been corrupted.
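To make the renewal decision concrete, the following is a minimal
sketch assuming a hypothetical shouldRenewIdentity helper with a
half-TTL threshold; the actual ShouldGenerateNodeIdentity checks on the
request structs may use different conditions. Because the check runs as
part of the heartbeat RPC, renewal naturally inherits the heartbeat
jitter mentioned above.

package main

import (
	"fmt"
	"time"
)

// shouldRenewIdentity returns true when the server should sign a new
// node identity: the node has no identity yet, renewal is forced, or
// less than half of the TTL remains before the current identity
// expires.
func shouldRenewIdentity(expiry time.Time, ttl time.Duration, force bool, now time.Time) bool {
	if force || expiry.IsZero() {
		return true
	}
	return expiry.Sub(now) < ttl/2
}

func main() {
	now := time.Now().UTC()
	ttl := 24 * time.Hour // example TTL; the node pool spec configures the real value

	fmt.Println(shouldRenewIdentity(time.Time{}, ttl, false, now))           // true: no identity yet
	fmt.Println(shouldRenewIdentity(now.Add(20*time.Hour), ttl, false, now)) // false: plenty of headroom
	fmt.Println(shouldRenewIdentity(now.Add(2*time.Hour), ttl, false, now))  // true: inside the renewal window
}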
--- client/client_test.go | 1 + client/drain_test.go | 2 + command/acl_bootstrap_test.go | 6 +- command/agent/agent_test.go | 2 +- nomad/acl.go | 4 + nomad/auth/auth.go | 5 +- nomad/client_agent_endpoint_test.go | 5 +- nomad/client_alloc_endpoint_test.go | 1 + nomad/client_csi_endpoint_test.go | 1 + nomad/client_stats_endpoint_test.go | 1 + nomad/csi_endpoint_test.go | 11 +- nomad/drainer_int_test.go | 8 + nomad/encrypter.go | 11 +- nomad/eval_broker_test.go | 1 + nomad/host_volume_endpoint_test.go | 3 + nomad/node_endpoint.go | 148 +++++++- nomad/node_endpoint_test.go | 551 +++++++++++++++++++++++++++- nomad/rpc_test.go | 8 +- nomad/server_test.go | 1 + nomad/structs/identity.go | 12 +- nomad/structs/identity_test.go | 51 +++ nomad/structs/node.go | 177 +++++++++ nomad/structs/node_test.go | 370 +++++++++++++++++++ nomad/structs/structs.go | 63 ---- nomad/worker_test.go | 4 + 25 files changed, 1344 insertions(+), 103 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index fc6fb60e8..d3dafd194 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -703,6 +703,7 @@ func TestClient_SaveRestoreState(t *testing.T) { s1, _, cleanupS1 := testServer(t, nil) t.Cleanup(cleanupS1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) c1, cleanupC1 := TestClient(t, func(c *config.Config) { c.DevMode = false diff --git a/client/drain_test.go b/client/drain_test.go index d67a6219a..a995baf4c 100644 --- a/client/drain_test.go +++ b/client/drain_test.go @@ -29,6 +29,7 @@ func TestClient_SelfDrainConfig(t *testing.T) { srv, _, cleanupSRV := testServer(t, nil) defer cleanupSRV() testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) c1, cleanupC1 := TestClient(t, func(c *config.Config) { c.RPCHandler = srv @@ -81,6 +82,7 @@ func TestClient_SelfDrain_FailLocal(t *testing.T) { srv, _, cleanupSRV := testServer(t, nil) defer cleanupSRV() testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) c1, cleanupC1 := TestClient(t, func(c *config.Config) { c.RPCHandler = srv diff --git a/command/acl_bootstrap_test.go b/command/acl_bootstrap_test.go index 78c5fc566..14355c18d 100644 --- a/command/acl_bootstrap_test.go +++ b/command/acl_bootstrap_test.go @@ -23,7 +23,7 @@ func TestACLBootstrapCommand(t *testing.T) { c.ACL.PolicyTTL = 0 } - srv, _, url := testServer(t, true, config) + srv, _, url := testServer(t, false, config) defer srv.Shutdown() must.Nil(t, srv.RootToken) @@ -101,7 +101,7 @@ func TestACLBootstrapCommand_WithOperatorFileBootstrapToken(t *testing.T) { err := os.WriteFile(file, []byte(mockToken.SecretID), 0700) must.NoError(t, err) - srv, _, url := testServer(t, true, config) + srv, _, url := testServer(t, false, config) defer srv.Shutdown() must.Nil(t, srv.RootToken) @@ -139,7 +139,7 @@ func TestACLBootstrapCommand_WithBadOperatorFileBootstrapToken(t *testing.T) { err := os.WriteFile(file, []byte(invalidToken), 0700) must.NoError(t, err) - srv, _, url := testServer(t, true, config) + srv, _, url := testServer(t, false, config) defer srv.Shutdown() must.Nil(t, srv.RootToken) diff --git a/command/agent/agent_test.go b/command/agent/agent_test.go index 4b77df344..cbd7a076d 100644 --- a/command/agent/agent_test.go +++ b/command/agent/agent_test.go @@ -1120,7 +1120,7 @@ func TestServer_Reload_TLS_Shared_Keyloader(t *testing.T) { TLSConfig: &config.TLSConfig{ EnableHTTP: true, EnableRPC: true, - VerifyServerHostname: true, + VerifyServerHostname: false, CAFile: foocafile, CertFile: 
fooclientcert, KeyFile: fooclientkey, diff --git a/nomad/acl.go b/nomad/acl.go index 78cfc052c..1b77dc565 100644 --- a/nomad/acl.go +++ b/nomad/acl.go @@ -16,6 +16,10 @@ func (s *Server) AuthenticateServerOnly(ctx *RPCContext, args structs.RequestWit return s.auth.AuthenticateServerOnly(ctx, args) } +func (s *Server) AuthenticateNodeIdentityGenerator(ctx *RPCContext, args structs.RequestWithIdentity) error { + return s.auth.AuthenticateNodeIdentityGenerator(ctx, args) +} + func (s *Server) AuthenticateClientOnly(ctx *RPCContext, args structs.RequestWithIdentity) (*acl.ACL, error) { return s.auth.AuthenticateClientOnly(ctx, args) } diff --git a/nomad/auth/auth.go b/nomad/auth/auth.go index 1e412961c..afbd5ddfb 100644 --- a/nomad/auth/auth.go +++ b/nomad/auth/auth.go @@ -217,10 +217,11 @@ func (s *Authenticator) Authenticate(ctx RPCContext, args structs.RequestWithIde return nil } -// ResolveACL is an authentication wrapper which handles resolving ACL tokens, +// ResolveACL is an authentication wrapper that handles resolving ACL tokens, // Workload Identities, or client secrets into acl.ACL objects. Exclusively // server-to-server or client-to-server requests should be using -// AuthenticateServerOnly or AuthenticateClientOnly and never use this method. +// AuthenticateServerOnly or AuthenticateClientOnly unless they use the +// AuthenticateNodeIdentityGenerator function. func (s *Authenticator) ResolveACL(args structs.RequestWithIdentity) (*acl.ACL, error) { identity := args.GetIdentity() if identity == nil { diff --git a/nomad/client_agent_endpoint_test.go b/nomad/client_agent_endpoint_test.go index 3dcaa9ef7..44e752010 100644 --- a/nomad/client_agent_endpoint_test.go +++ b/nomad/client_agent_endpoint_test.go @@ -854,7 +854,10 @@ func TestAgentHost_Server(t *testing.T) { } c, cleanupC := client.TestClient(t, func(c *config.Config) { - c.Servers = []string{s2.GetConfig().RPCAddr.String()} + c.Servers = []string{ + s1.GetConfig().RPCAddr.String(), + s2.GetConfig().RPCAddr.String(), + } c.EnableDebug = true }) defer cleanupC() diff --git a/nomad/client_alloc_endpoint_test.go b/nomad/client_alloc_endpoint_test.go index 48a5185bf..20a05679d 100644 --- a/nomad/client_alloc_endpoint_test.go +++ b/nomad/client_alloc_endpoint_test.go @@ -38,6 +38,7 @@ func TestClientAllocations_GarbageCollectAll_Local(t *testing.T) { defer cleanupS() codec := rpcClient(t, s) testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, s.Region()) c, cleanupC := client.TestClient(t, func(c *config.Config) { c.Servers = []string{s.config.RPCAddr.String()} diff --git a/nomad/client_csi_endpoint_test.go b/nomad/client_csi_endpoint_test.go index d2d127584..241379f89 100644 --- a/nomad/client_csi_endpoint_test.go +++ b/nomad/client_csi_endpoint_test.go @@ -474,6 +474,7 @@ func setupLocal(t *testing.T) rpc.ClientCodec { t.Cleanup(cleanupS1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) codec := rpcClient(t, s1) mockCSI := newMockClientCSI() diff --git a/nomad/client_stats_endpoint_test.go b/nomad/client_stats_endpoint_test.go index 55c439da5..6b8d01d36 100644 --- a/nomad/client_stats_endpoint_test.go +++ b/nomad/client_stats_endpoint_test.go @@ -29,6 +29,7 @@ func TestClientStats_Stats_Local(t *testing.T) { defer cleanupS() codec := rpcClient(t, s) testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, s.Region()) c, cleanupC := client.TestClient(t, func(c *config.Config) { c.Servers = []string{s.config.RPCAddr.String()} diff --git 
a/nomad/csi_endpoint_test.go b/nomad/csi_endpoint_test.go index 2d71a0f46..744dd38c2 100644 --- a/nomad/csi_endpoint_test.go +++ b/nomad/csi_endpoint_test.go @@ -1136,12 +1136,13 @@ func TestCSIVolumeEndpoint_List_PaginationFiltering(t *testing.T) { func TestCSIVolumeEndpoint_Create(t *testing.T) { ci.Parallel(t) var err error - srv, rootToken, shutdown := TestACLServer(t, func(c *Config) { + srv, _, shutdown := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer shutdown() testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) fake := newMockClientCSI() fake.NextValidateError = nil @@ -1158,6 +1159,7 @@ func TestCSIVolumeEndpoint_Create(t *testing.T) { client, cleanup := client.TestClientWithRPCs(t, func(c *cconfig.Config) { c.Servers = []string{srv.config.RPCAddr.String()} + c.TLSConfig = srv.config.TLSConfig }, map[string]interface{}{"CSI": fake}, ) @@ -1169,8 +1171,11 @@ func TestCSIVolumeEndpoint_Create(t *testing.T) { }).Node req0 := &structs.NodeRegisterRequest{ - Node: node, - WriteRequest: structs.WriteRequest{Region: "global", AuthToken: rootToken.SecretID}, + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: node.SecretID, + }, } var resp0 structs.NodeUpdateResponse err = client.RPC("Node.Register", req0, &resp0) diff --git a/nomad/drainer_int_test.go b/nomad/drainer_int_test.go index 02f4e3142..b4c7d4507 100644 --- a/nomad/drainer_int_test.go +++ b/nomad/drainer_int_test.go @@ -149,6 +149,7 @@ func TestDrainer_Simple_ServiceOnly(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node @@ -220,6 +221,7 @@ func TestDrainer_Simple_ServiceOnly_Deadline(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node @@ -277,6 +279,7 @@ func TestDrainer_DrainEmptyNode(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create an empty node @@ -312,6 +315,7 @@ func TestDrainer_AllTypes_Deadline(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node @@ -420,6 +424,7 @@ func TestDrainer_AllTypes_NoDeadline(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create two nodes, registering the second later @@ -551,6 +556,7 @@ func TestDrainer_AllTypes_Deadline_GarbageCollectedNode(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node @@ -668,6 +674,7 @@ func TestDrainer_MultipleNSes_ServiceOnly(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node @@ -762,6 +769,7 @@ func TestDrainer_Batch_TransitionToForce(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node diff --git 
a/nomad/encrypter.go b/nomad/encrypter.go index ab580031b..fa1330c5b 100644 --- a/nomad/encrypter.go +++ b/nomad/encrypter.go @@ -303,11 +303,12 @@ func (e *Encrypter) Decrypt(ciphertext []byte, keyID string) ([]byte, error) { // header name. const keyIDHeader = "kid" -// SignClaims signs the identity claim for the task and returns an encoded JWT -// (including both the claim and its signature) and the key ID of the key used -// to sign it, or an error. +// SignClaims signs the identity claim and returns an encoded JWT (including +// both the claim and its signature) and the key ID of the key used to sign it, +// or an error. // -// SignClaims adds the Issuer claim prior to signing. +// SignClaims adds the Issuer claim prior to signing if it is unset by the +// caller. func (e *Encrypter) SignClaims(claims *structs.IdentityClaims) (string, string, error) { if claims == nil { @@ -324,7 +325,7 @@ func (e *Encrypter) SignClaims(claims *structs.IdentityClaims) (string, string, claims.Issuer = e.issuer } - opts := (&jose.SignerOptions{}).WithHeader("kid", cs.rootKey.Meta.KeyID).WithType("JWT") + opts := (&jose.SignerOptions{}).WithHeader(keyIDHeader, cs.rootKey.Meta.KeyID).WithType("JWT") var sig jose.Signer if cs.rsaPrivateKey != nil { diff --git a/nomad/eval_broker_test.go b/nomad/eval_broker_test.go index d44a189b9..df117310d 100644 --- a/nomad/eval_broker_test.go +++ b/nomad/eval_broker_test.go @@ -1535,6 +1535,7 @@ func TestEvalBroker_IntegrationTest(t *testing.T) { defer cleanupS1() testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) codec := rpcClient(t, srv) store := srv.fsm.State() diff --git a/nomad/host_volume_endpoint_test.go b/nomad/host_volume_endpoint_test.go index 6a46c83b9..7433ff5aa 100644 --- a/nomad/host_volume_endpoint_test.go +++ b/nomad/host_volume_endpoint_test.go @@ -38,6 +38,7 @@ func TestHostVolumeEndpoint_CreateRegisterGetDelete(t *testing.T) { }) t.Cleanup(cleanupSrv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.config.Region) store := srv.fsm.State() c1, node1 := newMockHostVolumeClient(t, srv, "prod") @@ -434,6 +435,7 @@ func TestHostVolumeEndpoint_List(t *testing.T) { }) t.Cleanup(cleanupSrv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.config.Region) store := srv.fsm.State() codec := rpcClient(t, srv) @@ -809,6 +811,7 @@ func TestHostVolumeEndpoint_concurrency(t *testing.T) { srv, cleanup := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) t.Cleanup(cleanup) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.config.Region) c, node := newMockHostVolumeClient(t, srv, "default") diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index e0743379e..c04b1db00 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -4,12 +4,14 @@ package nomad import ( + "errors" "fmt" "net/http" "reflect" "sync" "time" + "github.com/go-jose/go-jose/v3/jwt" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-memdb" metrics "github.com/hashicorp/go-metrics/compat" @@ -91,9 +93,10 @@ func NewNodeEndpoint(srv *Server, ctx *RPCContext) *Node { // Register is used to upsert a client that is available for scheduling func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUpdateResponse) error { - // note that we trust-on-first use and the identity will be anonymous for - // that initial request; we lean on mTLS for handling that safely - authErr := n.srv.Authenticate(n.ctx, args) + + // The node register RPC is 
responsible for generating node identities, so + // we use the custom authentication method shared with UpdateStatus. + authErr := n.srv.AuthenticateNodeIdentityGenerator(n.ctx, args) isForwarded := args.IsForwarded() if done, err := n.srv.forward("Node.Register", args, args, reply); done { @@ -108,7 +111,15 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp return err } n.srv.MeasureRPCRate("node", structs.RateMetricWrite, args) - if authErr != nil { + + // The authentication error can be because the identity is expired. If we + // stopped the handler execution here, the node would never be able to + // register after being disconnected. + // + // Further within the RPC we check the supplied SecretID against the stored + // value in state. This acts as a secondary check and can be seen as a + // refresh token, in the event the identity is expired. + if authErr != nil && !errors.Is(authErr, jwt.ErrExpired) { return structs.ErrPermissionDenied } @@ -161,8 +172,13 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp args.Node.NodePool = structs.NodePoolDefault } + // The current time is used at a number of places in the registration + // workflow. Generating it once avoids multiple calls to time.Now() and also + // means the same time is used across all checks and sets. + timeNow := time.Now() + // Set the timestamp when the node is registered - args.Node.StatusUpdatedAt = time.Now().Unix() + args.Node.StatusUpdatedAt = timeNow.Unix() // Compute the node class if err := args.Node.ComputeClass(); err != nil { @@ -214,6 +230,40 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp if n.srv.Region() == n.srv.config.AuthoritativeRegion { args.CreateNodePool = true } + + // Track the TTL that will be used for the node identity. + var identityTTL time.Duration + + // The identity TTL is determined by the node pool the node is registered + // in. In the event the node registration is triggering creation of a new + // node pool, it will be created with the default TTL, so we use this for + // the identity. + nodePool, err := snap.NodePoolByName(ws, args.Node.NodePool) + if err != nil { + return fmt.Errorf("failed to query node pool: %v", err) + } + if nodePool == nil { + identityTTL = structs.DefaultNodePoolNodeIdentityTTL + } else { + identityTTL = nodePool.NodeIdentityTTL + } + + // Check if we need to generate a node identity. This must happen before we + // send the Raft message, as the signing key ID is set on the node if we + // generate one. 
+ if args.ShouldGenerateNodeIdentity(authErr, timeNow.UTC(), identityTTL) { + + claims := structs.GenerateNodeIdentityClaims(args.Node, n.srv.Region(), identityTTL) + + signedJWT, signingKeyID, err := n.srv.encrypter.SignClaims(claims) + if err != nil { + return fmt.Errorf("failed to sign node identity claims: %v", err) + } + + reply.SignedIdentity = &signedJWT + args.Node.IdentitySigningKeyID = signingKeyID + } + _, index, err := n.srv.raftApply(structs.NodeRegisterRequestType, args) if err != nil { n.logger.Error("register failed", "error", err) @@ -509,9 +559,13 @@ func (n *Node) deregister(args *structs.NodeBatchDeregisterRequest, // │ │ // └──── ready ─────┘ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *structs.NodeUpdateResponse) error { - // UpdateStatus receives requests from client and servers that mark failed - // heartbeats, so we can't use AuthenticateClientOnly - authErr := n.srv.Authenticate(n.ctx, args) + + // The node update status RPC is responsible for generating node identities, + // so we use the custom authentication method shared with Register. + // + // Note; UpdateStatus receives requests from clients and servers that mark + // failed heartbeats. + authErr := n.srv.AuthenticateNodeIdentityGenerator(n.ctx, args) isForwarded := args.IsForwarded() if done, err := n.srv.forward("Node.UpdateStatus", args, args, reply); done { @@ -573,14 +627,62 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct // to track SecretIDs. // Update the timestamp of when the node status was updated - args.UpdatedAt = time.Now().Unix() + timeNow := time.Now() + args.UpdatedAt = timeNow.Unix() + + // Track the TTL that will be used for the node identity. + var identityTTL time.Duration + + // The identity TTL is determined by the node pool the node is registered + // in. The pool should already exist, as the node is already registered. If + // it does not, we use the default TTL as we have no better value to use. + // + // Once the node pool is created, the node's identity will have the TTL set + // by the node pool on its renewal. + nodePool, err := snap.NodePoolByName(ws, node.NodePool) + if err != nil { + return fmt.Errorf("failed to query node pool: %v", err) + } + if nodePool == nil { + identityTTL = structs.DefaultNodePoolNodeIdentityTTL + } else { + identityTTL = nodePool.NodeIdentityTTL + } + + // Check and generate a node identity if needed. + if args.ShouldGenerateNodeIdentity(timeNow.UTC(), identityTTL) { + + claims := structs.GenerateNodeIdentityClaims(node, n.srv.Region(), identityTTL) + + // Sign the claims with the encrypter and conditionally handle the + // error. The IdentitySigningErrorTerminal method has a description of + // why we do this. + signedJWT, signingKeyID, err := n.srv.encrypter.SignClaims(claims) + if err != nil { + if args.IdentitySigningErrorIsTerminal(timeNow) { + return fmt.Errorf("failed to sign node identity claims: %v", err) + } else { + n.logger.Warn( + "failed to sign node identity claims, will retry on next heartbeat", + "error", err, "node_id", node.ID) + } + } + + reply.SignedIdentity = &signedJWT + args.IdentitySigningKeyID = signingKeyID + } else { + // Ensure the IdentitySigningKeyID is cleared if we are not generating a + // new identity. This is important to ensure that we do not cause Raft + // updates unless we need to. + args.IdentitySigningKeyID = "" + } // Compute next status. 
switch node.Status { case structs.NodeStatusInit: if args.Status == structs.NodeStatusReady { - // Keep node in the initializing status if it has allocations but - // they are not updated. + // Keep the node in the initializing status if it has allocations, + // but they are not updated. allocs, err := snap.AllocsByNodeTerminal(ws, args.NodeID, false) if err != nil { return fmt.Errorf("failed to query node allocs: %v", err) @@ -592,13 +694,9 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct args.Status = structs.NodeStatusInit } - // Keep node in the initialing status if it's in a node pool that - // doesn't exist. - pool, err := snap.NodePoolByName(ws, node.NodePool) - if err != nil { - return fmt.Errorf("failed to query node pool: %v", err) - } - if pool == nil { + // Keep the node in the initialing status if it's in a node pool + // that doesn't exist. + if nodePool == nil { n.logger.Debug(fmt.Sprintf("marking node as %s due to missing node pool", structs.NodeStatusInit)) args.Status = structs.NodeStatusInit if !node.HasEvent(NodeWaitingForNodePool) { @@ -617,7 +715,19 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct // Commit this update via Raft var index uint64 - if node.Status != args.Status || args.NodeEvent != nil { + + // Only perform a Raft apply if we really have to, so we avoid unnecessary + // cluster traffic and CPU load. + // + // We must update state if: + // - The node informed us of a new status. + // - The node informed us of a new event. + // - We have generated an identity which has been signed with a different + // key ID compared to the last identity generated for the node. + if node.Status != args.Status || + args.NodeEvent != nil || + node.IdentitySigningKeyID != args.IdentitySigningKeyID && args.IdentitySigningKeyID != "" { + // Attach an event if we are updating the node status to ready when it // is down via a heartbeat if node.Status == structs.NodeStatusDown && args.NodeEvent == nil { diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index 0fae446de..92ad8b6a6 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v3/jwt" memdb "github.com/hashicorp/go-memdb" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2" "github.com/hashicorp/nomad/acl" @@ -37,6 +38,7 @@ func TestClientEndpoint_Register(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Check that we have no client connections require.Empty(s1.connectedNodes()) @@ -89,6 +91,267 @@ func TestClientEndpoint_Register(t *testing.T) { }) } +func TestNode_Register_Identity(t *testing.T) { + ci.Parallel(t) + + // This helper function verifies the identity token generated by the server + // in the Node.Register RPC call. 
+ verifyIdentityFn := func( + t *testing.T, + testServer *Server, + token string, + node *structs.Node, + ttl time.Duration, + ) { + t.Helper() + + identityClaims, err := testServer.encrypter.VerifyClaim(token) + must.NoError(t, err) + + must.Eq(t, ttl, identityClaims.Expiry.Time().Sub(identityClaims.NotBefore.Time())) + must.True(t, identityClaims.IsNode()) + must.Eq(t, identityClaims.NodeIdentityClaims, &structs.NodeIdentityClaims{ + NodeID: node.ID, + NodeDatacenter: node.Datacenter, + NodeClass: node.NodeClass, + NodePool: node.NodePool, + }) + + // Identify the active encrypter key ID, which would have been used to + // sign the identity token. + _, keyID, err := testServer.encrypter.GetActiveKey() + must.NoError(t, err) + + // Perform a lookup of the node in state. The IdentitySigningKeyID field + // should be populated with the active encrypter key ID. + stateNodeResp, err := testServer.State().NodeByID(nil, node.ID) + must.NoError(t, err) + must.NotNil(t, stateNodeResp) + must.Eq(t, keyID, stateNodeResp.IdentitySigningKeyID) + } + + testCases := []struct { + name string + testFn func(t *testing.T, srv *Server, codec rpc.ClientCodec) + }{ + { + // Test the initial registration flow, where a node will not include + // an authentication token in the request. + // + // A later registration will not generate a new identity, as the + // included identity is still valid. + name: "identity generation and node reregister", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + + req := structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, req.Node, structs.DefaultNodePoolNodeIdentityTTL) + + req.WriteRequest.AuthToken = *resp.SignedIdentity + var resp2 structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp2)) + must.Nil(t, resp2.SignedIdentity) + }, + }, + { + // A node can register with a node pool that does not exist, and the + // server will create it on FSM write. In this case, the server + // should generate an identity with the default node pool identity + // TTL. + name: "create on register node pool", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + node.NodePool = "custom-pool" + + req := structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, req.Node, structs.DefaultNodePoolNodeIdentityTTL) + }, + }, + { + // A node can register with a node pool that exists, and the server + // will generate an identity with the node pool's identity TTL. 
+ name: "non-default identity ttl", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + nodePool := mock.NodePool() + nodePool.NodeIdentityTTL = 168 * time.Hour + must.NoError(t, srv.State().UpsertNodePools(structs.MsgTypeTestSetup, 1000, []*structs.NodePool{nodePool})) + + node := mock.Node() + node.NodePool = nodePool.Name + + req := structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, req.Node, nodePool.NodeIdentityTTL) + }, + }, + { + // Ensure a new identity is generated if the identity within the + // request is close to expiration. + name: "identity close to expiration", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + timeNow := time.Now().UTC().Add(-20 * time.Hour) + timeJWTNow := jwt.NewNumericDate(timeNow) + + node := mock.Node() + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, 1000, node)) + + claims := structs.GenerateNodeIdentityClaims( + node, + srv.Region(), + structs.DefaultNodePoolNodeIdentityTTL, + ) + claims.IssuedAt = timeJWTNow + claims.NotBefore = timeJWTNow + claims.Expiry = jwt.NewNumericDate(time.Now().UTC().Add(4 * time.Hour)) + + signedToken, _, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + + req := structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: signedToken, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, req.Node, structs.DefaultNodePoolNodeIdentityTTL) + }, + }, + { + // A node could disconnect from the cluster for long enough for the + // identity to expire. When it reconnects and performs its + // reregistration, the server should generate a new identity. + name: "identity expired", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, 1000, node)) + + claims := structs.GenerateNodeIdentityClaims( + node, + srv.Region(), + structs.DefaultNodePoolNodeIdentityTTL, + ) + claims.Expiry = jwt.NewNumericDate(time.Now().UTC().Add(-1 * time.Hour)) + + signedToken, _, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + + req := structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: signedToken, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, req.Node, structs.DefaultNodePoolNodeIdentityTTL) + }, + }, + { + // Ensure that if the node's SecretID is tampered with, the server + // rejects any attempt to register. This test is to gate against a + // potential regressions in how we handle identities within this + // RPC. 
+ name: "identity expired secret ID tampered", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, 1000, node.Copy())) + + claims := structs.GenerateNodeIdentityClaims( + node, + srv.Region(), + structs.DefaultNodePoolNodeIdentityTTL, + ) + claims.Expiry = jwt.NewNumericDate(time.Now().UTC().Add(-1 * time.Hour)) + + signedToken, _, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + + node.SecretID = uuid.Generate() + + req := structs.NodeRegisterRequest{ + Node: node.Copy(), + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: signedToken, + }, + } + + var resp structs.NodeUpdateResponse + must.ErrorContains( + t, + msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp), + "node secret ID does not match", + ) + }, + }, + } + + // ACL enabled server test run. + testACLServer, _, aclServerCleanup := TestACLServer(t, func(c *Config) {}) + defer aclServerCleanup() + testACLCodec := rpcClient(t, testACLServer) + + testutil.WaitForLeader(t, testACLServer.RPC) + testutil.WaitForKeyring(t, testACLServer.RPC, testACLServer.config.Region) + + // ACL disabled server test run. + testServer, serverCleanup := TestServer(t, func(c *Config) {}) + defer serverCleanup() + testCodec := rpcClient(t, testServer) + + testutil.WaitForLeader(t, testServer.RPC) + testutil.WaitForKeyring(t, testServer.RPC, testServer.config.Region) + + for _, tc := range testCases { + t.Run("ACL_enabled_"+tc.name, func(t *testing.T) { + tc.testFn(t, testACLServer, testACLCodec) + }) + t.Run("ACL_disabled_"+tc.name, func(t *testing.T) { + tc.testFn(t, testServer, testCodec) + }) + } +} + // This test asserts that we only track node connections if they are not from // forwarded RPCs. This is essential otherwise we will think a Yamux session to // a Nomad server is actually the session to the node. @@ -106,8 +369,8 @@ func TestClientEndpoint_Register_NodeConn_Forwarded(t *testing.T) { }) defer cleanupS2() TestJoin(t, s1, s2) - testutil.WaitForLeader(t, s1.RPC) - testutil.WaitForLeader(t, s2.RPC) + testutil.WaitForLeaders(t, s1.RPC, s2.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Determine the non-leader server var leader, nonLeader *Server @@ -190,6 +453,7 @@ func TestClientEndpoint_Register_SecretMismatch(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -219,6 +483,7 @@ func TestClientEndpoint_Register_NodePool(t *testing.T) { defer cleanupS() codec := rpcClient(t, s) testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, s.config.Region) testCases := []struct { name string @@ -328,6 +593,7 @@ func TestClientEndpoint_Register_NodePool_Multiregion(t *testing.T) { defer cleanupS1() codec1 := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) s2, _, cleanupS2 := TestACLServer(t, func(c *Config) { c.Region = "region-2" @@ -340,6 +606,7 @@ func TestClientEndpoint_Register_NodePool_Multiregion(t *testing.T) { defer cleanupS2() codec2 := rpcClient(t, s2) testutil.WaitForLeader(t, s2.RPC) + testutil.WaitForKeyring(t, s2.RPC, s2.config.Region) // Verify that registering a node with a new node pool in the authoritative // region creates the node pool. 
@@ -504,6 +771,7 @@ func TestClientEndpoint_DeregisterOne(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -617,6 +885,7 @@ func TestClientEndpoint_UpdateStatus(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Check that we have no client connections require.Empty(s1.connectedNodes()) @@ -721,6 +990,7 @@ func TestClientEndpoint_UpdateStatus_Reconnect(t *testing.T) { codec := rpcClient(t, s) defer cleanupS() testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, s.config.Region) // Register node. node := mock.Node() @@ -914,6 +1184,7 @@ func TestClientEndpoint_UpdateStatus_HeartbeatRecovery(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Check that we have no client connections require.Empty(s1.connectedNodes()) @@ -964,6 +1235,7 @@ func TestClientEndpoint_Register_GetEvals(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Register a system job. job := mock.SystemJob() @@ -1055,6 +1327,7 @@ func TestClientEndpoint_UpdateStatus_GetEvals(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Register a system job. job := mock.SystemJob() @@ -1163,6 +1436,7 @@ func TestClientEndpoint_UpdateStatus_HeartbeatOnly(t *testing.T) { codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -1224,6 +1498,7 @@ func TestClientEndpoint_UpdateStatus_HeartbeatOnly_Advertise(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -1255,6 +1530,7 @@ func TestNode_UpdateStatus_ServiceRegistrations(t *testing.T) { testServer, serverCleanup := TestServer(t, nil) defer serverCleanup() testutil.WaitForLeader(t, testServer.RPC) + testutil.WaitForKeyring(t, testServer.RPC, testServer.config.Region) // Create a node and upsert this into state. node := mock.Node() @@ -1304,6 +1580,256 @@ func TestNode_UpdateStatus_ServiceRegistrations(t *testing.T) { must.NoError(t, nodeEndpoint.UpdateStatus(&args, &reply)) } +func TestNode_UpdateStatus_Identity(t *testing.T) { + ci.Parallel(t) + + // This helper function verifies the identity token generated by the server + // in the Node.UpdateStatus RPC call. + verifyIdentityFn := func( + t *testing.T, + testServer *Server, + token string, + node *structs.Node, + ttl time.Duration, + ) { + t.Helper() + + identityClaims, err := testServer.encrypter.VerifyClaim(token) + must.NoError(t, err) + + must.Eq(t, ttl, identityClaims.Expiry.Time().Sub(identityClaims.NotBefore.Time())) + must.True(t, identityClaims.IsNode()) + must.Eq(t, identityClaims.NodeIdentityClaims, &structs.NodeIdentityClaims{ + NodeID: node.ID, + NodeDatacenter: node.Datacenter, + NodeClass: node.NodeClass, + NodePool: node.NodePool, + }) + + // Identify the active encrypter key ID, which would have been used to + // sign the identity token. 
+ _, keyID, err := testServer.encrypter.GetActiveKey() + must.NoError(t, err) + + // Perform a lookup of the node in state. The IdentitySigningKeyID field + // should be populated with the active encrypter key ID. + stateNodeResp, err := testServer.State().NodeByID(nil, node.ID) + must.NoError(t, err) + must.NotNil(t, stateNodeResp) + must.Eq(t, keyID, stateNodeResp.IdentitySigningKeyID) + } + + testCases := []struct { + name string + testFn func(t *testing.T, srv *Server, codec rpc.ClientCodec) + }{ + { + // Ensure that the Node.UpdateStatus RPC generates a new identity + // for a client authenticating using its secret ID. + name: "node secret ID authenticated default pool", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + must.Eq(t, "", node.IdentitySigningKeyID) + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, srv.raft.LastIndex(), node)) + + req := &structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusReady, + WriteRequest: structs.WriteRequest{ + Region: srv.Region(), + AuthToken: node.SecretID, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, node, structs.DefaultNodePoolNodeIdentityTTL) + }, + }, + { + // Ensure that the Node.UpdateStatus RPC generates a new identity + // for a client authenticating using its secret ID which belongs to + // a non-default node pool. + name: "node secret ID authenticated non-default pool", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + nodePool := mock.NodePool() + nodePool.NodeIdentityTTL = 168 * time.Hour + must.NoError(t, srv.State().UpsertNodePools( + structs.MsgTypeTestSetup, + srv.raft.LastIndex(), + []*structs.NodePool{nodePool}, + )) + + node := mock.Node() + node.NodePool = nodePool.Name + + must.Eq(t, "", node.IdentitySigningKeyID) + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, srv.raft.LastIndex(), node)) + + req := &structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusReady, + WriteRequest: structs.WriteRequest{ + Region: srv.Region(), + AuthToken: node.SecretID, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, node, nodePool.NodeIdentityTTL) + }, + }, + { + // Nomad servers often call the Node.UpdateStatus RPC to notify that + // a node has missed its heartbeat. In this case, we should write + // the update to state, but not generate an identity token. 
+ name: "leader acl token authenticated", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, srv.raft.LastIndex(), node)) + + req := &structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusDown, + WriteRequest: structs.WriteRequest{ + Region: srv.Region(), + AuthToken: srv.getLeaderAcl(), + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", req, &resp)) + must.Nil(t, resp.SignedIdentity) + + stateNode, err := srv.State().NodeByID(nil, node.ID) + must.NoError(t, err) + must.NotNil(t, stateNode) + must.Eq(t, structs.NodeStatusDown, stateNode.Status) + must.Greater(t, stateNode.CreateIndex, stateNode.ModifyIndex) + }, + }, + { + // Ensure a new identity is generated if the identity within the + // request is close to expiration. + name: "identity close to expiration", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + timeNow := time.Now().UTC().Add(-20 * time.Hour) + timeJWTNow := jwt.NewNumericDate(timeNow) + + node := mock.Node() + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, srv.raft.LastIndex(), node)) + + claims := structs.GenerateNodeIdentityClaims( + node, + srv.Region(), + structs.DefaultNodePoolNodeIdentityTTL, + ) + claims.IssuedAt = timeJWTNow + claims.NotBefore = timeJWTNow + claims.Expiry = jwt.NewNumericDate(time.Now().UTC().Add(4 * time.Hour)) + + signedToken, _, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + + req := structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusReady, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: signedToken, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, node, structs.DefaultNodePoolNodeIdentityTTL) + }, + }, + { + // Ensure a new identity is generated if the identity within the + // request is close to expiration and the new identity has a TTL set + // by its custom node pool configuration. 
+ name: "identity close to expiration custom pool", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + nodePool := mock.NodePool() + nodePool.NodeIdentityTTL = 168 * time.Hour + must.NoError(t, srv.State().UpsertNodePools( + structs.MsgTypeTestSetup, + srv.raft.LastIndex(), + []*structs.NodePool{nodePool}, + )) + + timeNow := time.Now().UTC().Add(-135 * time.Hour) + timeJWTNow := jwt.NewNumericDate(timeNow) + + node := mock.Node() + node.NodePool = nodePool.Name + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, srv.raft.LastIndex(), node)) + + claims := structs.GenerateNodeIdentityClaims( + node, + srv.Region(), + structs.DefaultNodePoolNodeIdentityTTL, + ) + claims.IssuedAt = timeJWTNow + claims.NotBefore = timeJWTNow + claims.Expiry = jwt.NewNumericDate(time.Now().UTC().Add(4 * time.Hour)) + + signedToken, _, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + + req := structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusReady, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: signedToken, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, node, nodePool.NodeIdentityTTL) + }, + }, + } + + // ACL enabled server test run. + testACLServer, _, aclServerCleanup := TestACLServer(t, func(c *Config) {}) + defer aclServerCleanup() + testACLCodec := rpcClient(t, testACLServer) + + testutil.WaitForLeader(t, testACLServer.RPC) + testutil.WaitForKeyring(t, testACLServer.RPC, testACLServer.config.Region) + + // ACL disabled server test run. + testServer, serverCleanup := TestServer(t, func(c *Config) {}) + defer serverCleanup() + testCodec := rpcClient(t, testServer) + + testutil.WaitForLeader(t, testServer.RPC) + testutil.WaitForKeyring(t, testServer.RPC, testServer.config.Region) + + for _, tc := range testCases { + t.Run("ACL_enabled_"+tc.name, func(t *testing.T) { + tc.testFn(t, testACLServer, testACLCodec) + }) + t.Run("ACL_disabled_"+tc.name, func(t *testing.T) { + tc.testFn(t, testServer, testCodec) + }) + } +} + // TestClientEndpoint_UpdateDrain asserts the ability to initiate drain // against a node and cancel that drain. 
It also asserts: // * an evaluation is created when the node becomes eligible @@ -1316,6 +1842,7 @@ func TestClientEndpoint_UpdateDrain(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Disable drainer to prevent drain from completing during test s1.nodeDrainer.SetEnabled(false, nil) @@ -1435,6 +1962,7 @@ func TestClientEndpoint_UpdatedDrainAndCompleted(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) state := s1.fsm.State() // Disable drainer for now @@ -1545,6 +2073,7 @@ func TestClientEndpoint_UpdatedDrainNoop(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) state := s1.fsm.State() // Create the register request @@ -1688,6 +2217,7 @@ func TestClientEndpoint_Drain_Down(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) require := require.New(t) // Register a node @@ -1820,6 +2350,7 @@ func TestClientEndpoint_UpdateEligibility(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -1933,6 +2464,7 @@ func TestClientEndpoint_GetNode(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -1966,10 +2498,14 @@ func TestClientEndpoint_GetNode(t *testing.T) { t.Fatalf("bad ComputedClass: %#v", resp2.Node) } + _, keyID, err := s1.encrypter.GetActiveKey() + must.NoError(t, err) + // Update the status updated at value node.StatusUpdatedAt = resp2.Node.StatusUpdatedAt node.SecretID = "" node.Events = resp2.Node.Events + node.IdentitySigningKeyID = keyID must.Eq(t, node, resp2.Node) // assert that the node register event was set correctly @@ -2167,6 +2703,7 @@ func TestClientEndpoint_GetAllocs(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2497,6 +3034,7 @@ func TestClientEndpoint_GetClientAllocs_Blocking(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2621,6 +3159,7 @@ func TestClientEndpoint_GetClientAllocs_Blocking_GC(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2699,6 +3238,7 @@ func TestClientEndpoint_GetClientAllocs_WithoutMigrateTokens(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2754,6 +3294,7 @@ func TestClientEndpoint_GetAllocs_Blocking(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2853,6 +3394,7 @@ func TestNode_UpdateAlloc(t 
*testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2933,6 +3475,7 @@ func TestNode_UpdateAlloc_NodeNotReady(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Register node. node := mock.Node() @@ -3109,6 +3652,7 @@ func TestClientEndpoint_BatchUpdate(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -3522,6 +4066,7 @@ func TestClientEndpoint_ListNodes(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -3594,6 +4139,7 @@ func TestClientEndpoint_ListNodes_Fields(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -3961,6 +4507,7 @@ func TestClientEndpoint_UpdateAlloc_Evals_ByTrigger(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() diff --git a/nomad/rpc_test.go b/nomad/rpc_test.go index 6c0446561..5cb00ec6c 100644 --- a/nomad/rpc_test.go +++ b/nomad/rpc_test.go @@ -254,6 +254,7 @@ func TestRPC_PlaintextRPCSucceedsWhenInUpgradeMode(t *testing.T) { s1, cleanupS1 := TestServer(t, func(c *Config) { c.DataDir = path.Join(dir, "node1") + c.Region = "regionFoo" c.TLSConfig = &config.TLSConfig{ EnableRPC: true, VerifyServerHostname: true, @@ -264,18 +265,19 @@ func TestRPC_PlaintextRPCSucceedsWhenInUpgradeMode(t *testing.T) { } }) defer cleanupS1() + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) - codec := rpcClient(t, s1) + tlsCodec := rpcClientWithTLS(t, s1, s1.config.TLSConfig) // Create the register request node := mock.Node() req := &structs.NodeRegisterRequest{ Node: node, - WriteRequest: structs.WriteRequest{Region: "global"}, + WriteRequest: structs.WriteRequest{Region: s1.Region()}, } var resp structs.GenericResponse - err := msgpackrpc.CallWithCodec(codec, "Node.Register", req, &resp) + err := msgpackrpc.CallWithCodec(tlsCodec, "Node.Register", req, &resp) assert.Nil(err) // Check that heartbeatTimers has the heartbeat ID diff --git a/nomad/server_test.go b/nomad/server_test.go index d7175af67..ea490c6e4 100644 --- a/nomad/server_test.go +++ b/nomad/server_test.go @@ -375,6 +375,7 @@ func TestServer_Reload_TLSConnections_TLSToPlaintext_OnlyRPC(t *testing.T) { } }) defer cleanupS1() + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) newTLSConfig := &config.TLSConfig{ EnableHTTP: true, diff --git a/nomad/structs/identity.go b/nomad/structs/identity.go index 41e43f99d..4c17a9b03 100644 --- a/nomad/structs/identity.go +++ b/nomad/structs/identity.go @@ -51,7 +51,17 @@ func (i *IdentityClaims) IsExpiring(now time.Time, ttl time.Duration) bool { // relative to the current time. threshold := now.Add(ttl / 3) - return i.Expiry.Time().Before(threshold) + return i.Expiry.Time().UTC().Before(threshold) +} + +// IsExpiringInThreshold checks if the identity JWT is expired or close to +// expiring. 
It uses a passed threshold to determine "close to expiring" which +// is not manipulated, unlike TTL in the IsExpiring method. +func (i *IdentityClaims) IsExpiringInThreshold(threshold time.Time) bool { + if i != nil && i.Expiry != nil { + return threshold.After(i.Expiry.Time()) + } + return false } // setExpiry sets the "expiry" or "exp" claim for the identity JWT. It is the diff --git a/nomad/structs/identity_test.go b/nomad/structs/identity_test.go index b35690f0c..8e09a7061 100644 --- a/nomad/structs/identity_test.go +++ b/nomad/structs/identity_test.go @@ -174,6 +174,57 @@ func TestIdentityClaims_IsExpiring(t *testing.T) { } } +func TestIdentityClaims_IsExpiringWithTTL(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + inputIdentityClaims *IdentityClaims + inputThreshold time.Time + expectedResult bool + }{ + { + name: "nil identity", + inputIdentityClaims: nil, + inputThreshold: time.Now(), + expectedResult: false, + }, + { + name: "no expiry", + inputIdentityClaims: &IdentityClaims{}, + inputThreshold: time.Now(), + expectedResult: false, + }, + { + name: "not close to expiring", + inputIdentityClaims: &IdentityClaims{ + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + }, + }, + inputThreshold: time.Now(), + expectedResult: false, + }, + { + name: "close to expiring", + inputIdentityClaims: &IdentityClaims{ + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now()), + }, + }, + inputThreshold: time.Now().Add(1 * time.Minute), + expectedResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := tc.inputIdentityClaims.IsExpiringInThreshold(tc.inputThreshold) + must.Eq(t, tc.expectedResult, actualOutput) + }) + } +} + func TestIdentityClaimsNg_setExpiry(t *testing.T) { ci.Parallel(t) diff --git a/nomad/structs/node.go b/nomad/structs/node.go index bcc0fec39..a5a308e3f 100644 --- a/nomad/structs/node.go +++ b/nomad/structs/node.go @@ -537,3 +537,180 @@ func GenerateNodeIdentityClaims(node *Node, region string, ttl time.Duration) *I return claims } + +// NodeRegisterRequest is used by the Node.Register RPC endpoint to register a +// node as being a schedulable entity. +type NodeRegisterRequest struct { + Node *Node + NodeEvent *NodeEvent + + // CreateNodePool is used to indicate that the node's node pool should be + // created along with the node registration if it doesn't exist. + CreateNodePool bool + + WriteRequest +} + +// ShouldGenerateNodeIdentity compliments the functionality within +// AuthenticateNodeIdentityGenerator to determine whether a new node identity +// should be generated within the RPC handler. +func (n *NodeRegisterRequest) ShouldGenerateNodeIdentity( + authErr error, + now time.Time, + ttl time.Duration, +) bool { + + // In the event the error is because the node identity is expired, we should + // generate a new identity. Without this, a disconnected node would never be + // able to re-register. Any other error is not a reason to generate a new + // identity. + if authErr != nil { + return errors.Is(authErr, jwt.ErrExpired) + } + + // If an ACL token or client ID is set, a node is attempting to register for + // the first time, or is re-registering using its secret ID. In either case, + // we should generate a new identity. + if n.identity.ACLToken != nil || n.identity.ClientID != "" { + return true + } + + // If we have reached this point, we can assume that the request is using a + // node identity. 
+ claims := n.GetIdentity().GetClaims() + + // It is possible that the node has been restarted and had its configuration + // updated. In this case, we should generate a new identity for the node, so + // it reflects its new claims. + if n.Node.NodePool != claims.NodeIdentityClaims.NodePool || + n.Node.NodeClass != claims.NodeIdentityClaims.NodeClass || + n.Node.Datacenter != claims.NodeIdentityClaims.NodeDatacenter { + return true + } + + // The final check is to see if the node identity is expiring. + return claims.IsExpiring(now, ttl) +} + +// NodeUpdateStatusRequest is used for Node.UpdateStatus endpoint +// to update the status of a node. +type NodeUpdateStatusRequest struct { + NodeID string + Status string + + // IdentitySigningKeyID is the ID of the root key used to sign the node's + // identity. This is not provided by the client, but is set by the server, + // so that the value can be propagated through Raft. + IdentitySigningKeyID string + + // ForceIdentityRenewal is used to force the Nomad server to generate a new + // identity for the node. + ForceIdentityRenewal bool + + NodeEvent *NodeEvent + UpdatedAt int64 + WriteRequest +} + +// ShouldGenerateNodeIdentity determines whether the handler should generate a +// new node identity based on the caller identity information. +func (n *NodeUpdateStatusRequest) ShouldGenerateNodeIdentity( + now time.Time, + ttl time.Duration, +) bool { + + identity := n.GetIdentity() + + // If the client ID is set, we should generate a new identity as the node + // has authenticated using its secret ID. + if identity.ClientID != "" { + return true + } + + // Confirm we have a node identity and then check for forced renewal or + // expiration. + if identity.GetClaims().IsNode() { + if n.ForceIdentityRenewal { + return true + } + return n.GetIdentity().GetClaims().IsExpiring(now, ttl) + } + + // No other conditions should generate a new identity. In the case of the + // update status endpoint, this will likely be a Nomad server propagating + // that a node has missed its heartbeat. + return false +} + +// IdentitySigningErrorIsTerminal determines if the RPC handler should return an +// error because it failed to sign a newly generated node identity. +// +// This is because a client might be connected to a follower at the point the +// root keyring is rotated. If the client heartbeats right at that moment and +// before the follower decrypts the key (e.g., network latency to external KMS), +// we will mark the node as down. This is despite identity being valid and the +// likelihood it will get a new identity signed on the next heartbeat. +func (n *NodeUpdateStatusRequest) IdentitySigningErrorIsTerminal(now time.Time) bool { + + identity := n.GetIdentity() + + // If the client has authenticated using a secret ID, we can continue to let + // it do that, until we successfully generate a new identity. + if identity.ClientID != "" { + return false + } + + // If the identity is a node identity, we can check if it is expiring. This + // check is used to determine if the RPC handler should return an error, so + // we use a short threshold of 10 minutes. This is to ensure we don't return + // errors unless we absolutely have to. + // + // A threshold of 10 minutes more than covers another heartbeat on the + // largest Nomad clusters, which can reach ~5 minutes. 
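+ //
+ // For example, with an identity that expires at 12:05, a signing failure
+ // at 11:56 is terminal because 11:56 plus the 10 minute threshold is past
+ // the expiry, while a failure at 11:50 is only logged by the caller and
+ // the node retries on its next heartbeat with a still-valid identity.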
+ if identity.GetClaims().IsNode() { + return n.GetIdentity().GetClaims().IsExpiringInThreshold(now.Add(10 * time.Minute)) + } + + // No other condition should result in the RPC handler returning an error + // because we failed to sign the node identity. No caller should be able to + // reach this point, as identity generation should be gated by + // ShouldGenerateNodeIdentity. + return false +} + +// NodeUpdateResponse is used to respond to a node update. The object is a +// shared response used by the Node.Register, Node.Deregister, +// Node.BatchDeregister, Node.UpdateStatus, and Node.Evaluate RPCs. +type NodeUpdateResponse struct { + HeartbeatTTL time.Duration + EvalIDs []string + EvalCreateIndex uint64 + NodeModifyIndex uint64 + + // Features informs clients what enterprise features are allowed + Features uint64 + + // LeaderRPCAddr is the RPC address of the current Raft Leader. If + // empty, the current Nomad Server is in the minority of a partition. + LeaderRPCAddr string + + // NumNodes is the number of Nomad nodes attached to this quorum of + // Nomad Servers at the time of the response. This value can + // fluctuate based on the health of the cluster between heartbeats. + NumNodes int32 + + // Servers is the full list of known Nomad servers in the local + // region. + Servers []*NodeServerInfo + + // SchedulingEligibility is used to inform clients what the server-side + // has for their scheduling status during heartbeats. + SchedulingEligibility string + + // SignedIdentity is the newly signed node identity that the server has + // generated. The node should check if this is set, and if so, update its + // state with the new identity. + SignedIdentity *string + + QueryMeta +} diff --git a/nomad/structs/node_test.go b/nomad/structs/node_test.go index 95970aaed..57e3912ca 100644 --- a/nomad/structs/node_test.go +++ b/nomad/structs/node_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v3/jwt" "github.com/hashicorp/nomad/ci" "github.com/shoenig/test/must" "github.com/stretchr/testify/require" @@ -279,3 +280,372 @@ func TestGenerateNodeIdentityClaims(t *testing.T) { must.NotNil(t, claims.NotBefore) must.NotNil(t, claims.Expiry) } + +func TestNodeRegisterRequest_ShouldGenerateNodeIdentity(t *testing.T) { + ci.Parallel(t) + + // Generate a stable mock node for testing. 
+ mockNode := MockNode() + + testCases := []struct { + name string + inputNodeRegisterRequest *NodeRegisterRequest + inputAuthErr error + inputTime time.Time + inputTTL time.Duration + expectedOutput bool + }{ + { + name: "expired node identity", + inputNodeRegisterRequest: &NodeRegisterRequest{}, + inputAuthErr: jwt.ErrExpired, + inputTime: time.Now(), + inputTTL: 10 * time.Minute, + expectedOutput: true, + }, + { + name: "first time node registration", + inputNodeRegisterRequest: &NodeRegisterRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ACLToken: AnonymousACLToken, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now(), + inputTTL: 10 * time.Minute, + expectedOutput: true, + }, + { + name: "registration using node secret ID", + inputNodeRegisterRequest: &NodeRegisterRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ClientID: "client-id-1", + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now(), + inputTTL: 10 * time.Minute, + expectedOutput: true, + }, + { + name: "modified node node pool configuration", + inputNodeRegisterRequest: &NodeRegisterRequest{ + Node: mockNode, + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: mockNode.ID, + NodePool: "new-pool", + NodeClass: mockNode.NodeClass, + NodeDatacenter: mockNode.Datacenter, + }, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(23 * time.Hour)), + }, + }, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "modified node class configuration", + inputNodeRegisterRequest: &NodeRegisterRequest{ + Node: mockNode, + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: mockNode.ID, + NodePool: mockNode.NodePool, + NodeClass: "new-class", + NodeDatacenter: mockNode.Datacenter, + }, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(23 * time.Hour)), + }, + }, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "modified node datacenter configuration", + inputNodeRegisterRequest: &NodeRegisterRequest{ + Node: mockNode, + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: mockNode.ID, + NodePool: mockNode.NodePool, + NodeClass: mockNode.NodeClass, + NodeDatacenter: "new-datacenter", + }, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(23 * time.Hour)), + }, + }, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "expiring node identity", + inputNodeRegisterRequest: &NodeRegisterRequest{ + Node: mockNode, + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: mockNode.ID, + NodePool: mockNode.NodePool, + NodeClass: mockNode.NodeClass, + NodeDatacenter: mockNode.Datacenter, + }, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(5 * time.Minute)), + }, + }, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "no generation", + inputNodeRegisterRequest: &NodeRegisterRequest{ + Node: mockNode, + 
WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: mockNode.ID, + NodePool: mockNode.NodePool, + NodeClass: mockNode.NodeClass, + NodeDatacenter: mockNode.Datacenter, + }, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(24 * time.Hour)), + }, + }, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := tc.inputNodeRegisterRequest.ShouldGenerateNodeIdentity( + tc.inputAuthErr, + tc.inputTime, + tc.inputTTL, + ) + must.Eq(t, tc.expectedOutput, actualOutput) + }) + } +} + +func TestNodeUpdateStatusRequest_ShouldGenerateNodeIdentity(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + inputNodeRegisterRequest *NodeUpdateStatusRequest + inputTime time.Time + inputTTL time.Duration + expectedOutput bool + }{ + { + name: "authenticated by node secret ID", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ClientID: "client-id-1", + }, + }, + }, + inputTime: time.Now(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "expiring node identity", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(1 * time.Hour)), + }, + }, + }, + }, + }, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "not expiring node identity", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(24 * time.Hour)), + }, + }, + }, + }, + }, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: false, + }, + { + name: "not expiring forced renewal node identity", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + ForceIdentityRenewal: true, + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(24 * time.Hour)), + }, + }, + }, + }, + }, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "server authenticated request", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ACLToken: LeaderACLToken, + }, + }, + }, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := tc.inputNodeRegisterRequest.ShouldGenerateNodeIdentity( + tc.inputTime, + tc.inputTTL, + ) + must.Eq(t, tc.expectedOutput, actualOutput) + }) + } +} +func TestNodeUpdateStatusRequest_IdentitySigningErrorIsTerminal(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + inputNodeRegisterRequest *NodeUpdateStatusRequest + inputTime time.Time + expectedOutput bool + }{ + { + name: "not close to expiring", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: 
WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(24 * time.Hour).UTC()), + }, + }, + }, + }, + }, + inputTime: time.Now().UTC(), + expectedOutput: false, + }, + { + name: "very close to expiring", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC()), + }, + }, + }, + }, + }, + inputTime: time.Now().Add(1 * time.Minute).UTC(), + expectedOutput: true, + }, + { + name: "server authenticated request", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ACLToken: LeaderACLToken, + }, + }, + }, + inputTime: time.Now().UTC(), + expectedOutput: false, + }, + { + name: "client secret ID authenticated request", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ClientID: "client-id", + }, + }, + }, + inputTime: time.Now().UTC(), + expectedOutput: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := tc.inputNodeRegisterRequest.IdentitySigningErrorIsTerminal(tc.inputTime) + must.Eq(t, tc.expectedOutput, actualOutput) + }) + } +} diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index b88c20109..b67c85b61 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -597,19 +597,6 @@ type WriteMeta struct { Index uint64 } -// NodeRegisterRequest is used for Node.Register endpoint -// to register a node as being a schedulable entity. -type NodeRegisterRequest struct { - Node *Node - NodeEvent *NodeEvent - - // CreateNodePool is used to indicate that the node's node pool should be - // create along with the node registration if it doesn't exist. - CreateNodePool bool - - WriteRequest -} - // NodeDeregisterRequest is used for Node.Deregister endpoint // to deregister a node as being a schedulable entity. type NodeDeregisterRequest struct { @@ -643,26 +630,6 @@ type NodeServerInfo struct { Datacenter string } -// NodeUpdateStatusRequest is used for Node.UpdateStatus endpoint -// to update the status of a node. -type NodeUpdateStatusRequest struct { - NodeID string - Status string - - // IdentitySigningKeyID is the ID of the root key used to sign the node's - // identity. This is not provided by the client, but is set by the server, - // so that the value can be propagated through Raft. - IdentitySigningKeyID string - - // ForceIdentityRenewal is used to force the Nomad server to generate a new - // identity for the node. - ForceIdentityRenewal bool - - NodeEvent *NodeEvent - UpdatedAt int64 - WriteRequest -} - // NodeUpdateDrainRequest is used for updating the drain strategy type NodeUpdateDrainRequest struct { NodeID string @@ -1506,36 +1473,6 @@ type JobValidateResponse struct { Warnings string } -// NodeUpdateResponse is used to respond to a node update -type NodeUpdateResponse struct { - HeartbeatTTL time.Duration - EvalIDs []string - EvalCreateIndex uint64 - NodeModifyIndex uint64 - - // Features informs clients what enterprise features are allowed - Features uint64 - - // LeaderRPCAddr is the RPC address of the current Raft Leader. If - // empty, the current Nomad Server is in the minority of a partition. 
- LeaderRPCAddr string - - // NumNodes is the number of Nomad nodes attached to this quorum of - // Nomad Servers at the time of the response. This value can - // fluctuate based on the health of the cluster between heartbeats. - NumNodes int32 - - // Servers is the full list of known Nomad servers in the local - // region. - Servers []*NodeServerInfo - - // SchedulingEligibility is used to inform clients what the server-side - // has for their scheduling status during heartbeats. - SchedulingEligibility string - - QueryMeta -} - // NodeDrainUpdateResponse is used to respond to a node drain update type NodeDrainUpdateResponse struct { NodeModifyIndex uint64 diff --git a/nomad/worker_test.go b/nomad/worker_test.go index 4bd18c7ea..eb00f5806 100644 --- a/nomad/worker_test.go +++ b/nomad/worker_test.go @@ -522,6 +522,7 @@ func TestWorker_SubmitPlanNormalizedAllocations(t *testing.T) { }) defer cleanupS1() testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) // Register node node := mock.Node() @@ -574,6 +575,7 @@ func TestWorker_SubmitPlan_MissingNodeRefresh(t *testing.T) { }) defer cleanupS1() testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) // Register node node := mock.Node() @@ -648,6 +650,7 @@ func TestWorker_UpdateEval(t *testing.T) { }) defer cleanupS1() testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) // Register node node := mock.Node() @@ -699,6 +702,7 @@ func TestWorker_CreateEval(t *testing.T) { }) defer cleanupS1() testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) // Register node node := mock.Node() From 2f302051021b844f4c4fad549c49bb4298cb6e1a Mon Sep 17 00:00:00 2001 From: James Rasell Date: Mon, 7 Jul 2025 16:28:27 +0200 Subject: [PATCH 5/7] client: Add state functionality for set and get client identities. (#26184) The Nomad client will persist its own identity within its state store for restart persistence. The added benefit of using it over the filesystem is that it supports transactions. This is useful when considering the identity will be renewed periodically. --- client/state/db_bolt.go | 42 +++++++++++++++++++++++++++ client/state/db_error.go | 4 +++ client/state/db_mem.go | 18 ++++++++++++ client/state/db_noop.go | 4 +++ client/state/db_test.go | 18 ++++++++++++ client/state/interface.go | 8 +++++ command/operator_client_state.go | 15 ++++++++-- command/operator_client_state_test.go | 6 ++++ 8 files changed, 113 insertions(+), 2 deletions(-) diff --git a/client/state/db_bolt.go b/client/state/db_bolt.go index bef111f6e..c952500b6 100644 --- a/client/state/db_bolt.go +++ b/client/state/db_bolt.go @@ -140,6 +140,12 @@ var ( nodeRegistrationKey = []byte("node_registration") hostVolBucket = []byte("host_volumes_to_create") + + // nodeIdentityBucket and nodeIdentityBucketStateKey are used to persist + // the client identity and its state. Each client will only have a single + // identity, so we use a single key value for the storage. + nodeIdentityBucket = []byte("node_identity") + nodeIdentityBucketStateKey = []byte("node_identity_state") ) // taskBucketName returns the bucket name for the given task name. @@ -1089,6 +1095,42 @@ func (s *BoltStateDB) DeleteDynamicHostVolume(id string) error { }) } +// clientIdentity wraps the signed client identity so we can safely add more +// state in the future without needing a new entry type. 
+type clientIdentity struct { + SignedIdentity string +} + +func (s *BoltStateDB) PutNodeIdentity(identity string) error { + return s.db.Update(func(tx *boltdd.Tx) error { + b, err := tx.CreateBucketIfNotExists(nodeIdentityBucket) + if err != nil { + return err + } + + identityWrapper := clientIdentity{SignedIdentity: identity} + + return b.Put(nodeIdentityBucketStateKey, &identityWrapper) + }) +} + +func (s *BoltStateDB) GetNodeIdentity() (string, error) { + var identityWrapper clientIdentity + err := s.db.View(func(tx *boltdd.Tx) error { + b := tx.Bucket(nodeIdentityBucket) + if b == nil { + return nil + } + return b.Get(nodeIdentityBucketStateKey, &identityWrapper) + }) + + if boltdd.IsErrNotFound(err) { + return "", nil + } + + return identityWrapper.SignedIdentity, err +} + // init initializes metadata entries in a newly created state database. func (s *BoltStateDB) init() error { return s.db.Update(func(tx *boltdd.Tx) error { diff --git a/client/state/db_error.go b/client/state/db_error.go index 6c99defa2..6edfbbdd9 100644 --- a/client/state/db_error.go +++ b/client/state/db_error.go @@ -172,3 +172,7 @@ func (m *ErrDB) DeleteDynamicHostVolume(_ string) error { func (m *ErrDB) Close() error { return fmt.Errorf("Error!") } + +func (m *ErrDB) PutNodeIdentity(_ string) error { return ErrDBError } + +func (m *ErrDB) GetNodeIdentity() (string, error) { return "", ErrDBError } diff --git a/client/state/db_mem.go b/client/state/db_mem.go index 32abd883e..4fd827852 100644 --- a/client/state/db_mem.go +++ b/client/state/db_mem.go @@ -6,6 +6,7 @@ package state import ( "maps" "sync" + "sync/atomic" "github.com/hashicorp/go-hclog" arstate "github.com/hashicorp/nomad/client/allocrunner/state" @@ -62,6 +63,9 @@ type MemDB struct { dynamicHostVolumes map[string]*cstructs.HostVolumeState + // clientIdentity is the persisted identity of the client. 
+ clientIdentity atomic.Value + logger hclog.Logger mu sync.RWMutex @@ -79,6 +83,7 @@ func NewMemDB(logger hclog.Logger) *MemDB { checks: make(checks.ClientResults), identities: make(map[string][]*structs.SignedWorkloadIdentity), dynamicHostVolumes: make(map[string]*cstructs.HostVolumeState), + clientIdentity: atomic.Value{}, logger: logger, } } @@ -379,6 +384,19 @@ func (m *MemDB) DeleteDynamicHostVolume(s string) error { return nil } +func (m *MemDB) PutNodeIdentity(identity string) error { + m.clientIdentity.Store(identity) + return nil +} + +func (m *MemDB) GetNodeIdentity() (string, error) { + if obj := m.clientIdentity.Load(); obj == nil { + return "", nil + } else { + return obj.(string), nil + } +} + func (m *MemDB) Close() error { m.mu.Lock() defer m.mu.Unlock() diff --git a/client/state/db_noop.go b/client/state/db_noop.go index 09488c181..3c53ae57c 100644 --- a/client/state/db_noop.go +++ b/client/state/db_noop.go @@ -157,6 +157,10 @@ func (n NoopDB) DeleteDynamicHostVolume(_ string) error { return nil } +func (n NoopDB) PutNodeIdentity(_ string) error { return nil } + +func (n NoopDB) GetNodeIdentity() (string, error) { return "", nil } + func (n NoopDB) Close() error { return nil } diff --git a/client/state/db_test.go b/client/state/db_test.go index 3a03cf3a2..23f6fb55d 100644 --- a/client/state/db_test.go +++ b/client/state/db_test.go @@ -493,6 +493,24 @@ func TestStateDB_CheckResult(t *testing.T) { } +func TestStateDB_NodeIdentity(t *testing.T) { + ci.Parallel(t) + + testDB(t, func(t *testing.T, db StateDB) { + identity, err := db.GetNodeIdentity() + must.NoError(t, err) + must.Eq(t, "", identity) + + fakeIdentity := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.KMUFsIDTnFmyG3nMiGM6H9FNFUROf3wh7SmqJp-QV30" + + must.NoError(t, db.PutNodeIdentity(fakeIdentity)) + + identity, err = db.GetNodeIdentity() + must.NoError(t, err) + must.Eq(t, fakeIdentity, identity) + }) +} + // TestStateDB_Upgrade asserts calling Upgrade on new databases always // succeeds. func TestStateDB_Upgrade(t *testing.T) { diff --git a/client/state/interface.go b/client/state/interface.go index 0460a75e2..136466d95 100644 --- a/client/state/interface.go +++ b/client/state/interface.go @@ -141,6 +141,14 @@ type StateDB interface { GetDynamicHostVolumes() ([]*cstructs.HostVolumeState, error) DeleteDynamicHostVolume(string) error + // PutNodeIdentity stores the signed identity JWT for the client. + PutNodeIdentity(identity string) error + + // GetNodeIdentity retrieves the signed identity JWT for the client. If the + // client has not generated an identity, this will return an empty string + // and no error. + GetNodeIdentity() (string, error) + // Close the database. Unsafe for further use after calling regardless // of return value. Close() error diff --git a/command/operator_client_state.go b/command/operator_client_state.go index b761d4bdb..ef5c2d9cb 100644 --- a/command/operator_client_state.go +++ b/command/operator_client_state.go @@ -132,8 +132,18 @@ func (c *OperatorClientStateCommand) Run(args []string) int { Tasks: tasks, } } + + // Get the node identity state, which is useful when debugging to see the + // real and current identity the node is using. 
+ nodeIdentity, err := db.GetNodeIdentity() + if err != nil { + c.Ui.Error(fmt.Sprintf("failed to get node identity state: %v", err)) + return 1 + } + output := debugOutput{ - Allocations: data, + Allocations: data, + NodeIdentity: nodeIdentity, } bytes, err := json.Marshal(output) if err != nil { @@ -146,7 +156,8 @@ func (c *OperatorClientStateCommand) Run(args []string) int { } type debugOutput struct { - Allocations map[string]*clientStateAlloc + Allocations map[string]*clientStateAlloc + NodeIdentity string } type clientStateAlloc struct { diff --git a/command/operator_client_state_test.go b/command/operator_client_state_test.go index 2220d33cc..696e5784b 100644 --- a/command/operator_client_state_test.go +++ b/command/operator_client_state_test.go @@ -38,10 +38,16 @@ func TestOperatorClientStateCommand(t *testing.T) { alloc := structs.MockAlloc() err = db.PutAllocation(alloc) must.NoError(t, err) + + // Write a node identity to the DB, so we can test that the command reads + // this data. + must.NoError(t, db.PutNodeIdentity("mynodeidentity")) + must.NoError(t, db.Close()) // run against an incomplete client state directory code = cmd.Run([]string{dir}) must.Eq(t, 0, code) must.StrContains(t, ui.OutputWriter.String(), alloc.ID) + must.StrContains(t, ui.OutputWriter.String(), "NodeIdentity\":\"mynodeidentity") } From 8096ea4129f223a6ea7175628e7a19d093d3ac99 Mon Sep 17 00:00:00 2001 From: James Rasell Date: Mon, 14 Jul 2025 15:24:43 +0200 Subject: [PATCH 6/7] client: Handle identities from servers and use for RPC auth. (#26218) Nomad servers, if upgraded, can return node identities as part of the register and update/heartbeat response objects. The Nomad client will now handle this and store it as appropriate within its memory and statedb. The client will now use any stored identity for RPC authentication with a fallback to the secretID. This supports upgrades paths where the Nomad clients are updated before the Nomad servers. --- client/client.go | 140 +++++++++++++++++++++++--- client/client_test.go | 84 ++++++++++++++++ client/drain.go | 8 +- client/identity.go | 21 ++++ client/identity_test.go | 32 ++++++ client/rpc.go | 2 +- client/serviceregistration/nsd/nsd.go | 26 ++++- client/widmgr/signer.go | 24 ++++- 8 files changed, 312 insertions(+), 25 deletions(-) create mode 100644 client/identity.go create mode 100644 client/identity_test.go diff --git a/client/client.go b/client/client.go index d993903c2..74abd0c16 100644 --- a/client/client.go +++ b/client/client.go @@ -15,6 +15,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" consulapi "github.com/hashicorp/consul/api" @@ -333,6 +334,11 @@ type Client struct { // users is a pool of dynamic workload users users dynamic.Pool + + // identity is the node identity token that has been generated and signed by + // the servers. This is used to authenticate the client to the servers when + // performing RPC calls. 
+ identity atomic.Value } var ( @@ -395,6 +401,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxie getter: getter.New(cfg.Artifact, logger), EnterpriseClient: newEnterpriseClient(logger), allocrunnerFactory: cfg.AllocRunnerFactory, + identity: atomic.Value{}, } // we can't have this set in the default Config because of import cycles @@ -604,6 +611,23 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxie logger.Warn("batch fingerprint operation timed out; proceeding to register with fingerprinted plugins so far") } + // Attempt to pull the node identity from the state database. If the client + // is starting for the first time, this will be empty, so avoid an + // unnecessary set call to the client atomic. This needs to happen before we + // start heartbeating to avoid unnecessary identity generation and load on + // the Nomad servers. + // + // If the DB returns an error, it is more than likely that the full + // restoration will fail. It isn't terminal for us at this point though, as + // we can generate a new identity on registration. + clientIdentity, err := c.stateDB.GetNodeIdentity() + if err != nil { + logger.Error("failed to get client identity from state", "error", err) + } + if clientIdentity != "" { + c.setNodeIdentityToken(clientIdentity) + } + // Register and then start heartbeating to the servers. c.shutdownGroup.Go(c.registerAndHeartbeat) @@ -903,11 +927,57 @@ func (c *Client) NodeID() string { return c.GetConfig().Node.ID } -// secretNodeID returns the secret node ID for the given client +// secretNodeID returns the secret node ID for the given client. This is no +// longer used as the primary authentication method for Nomad clients. In fully +// upgraded clusters, the node identity token is used instead. It will still be +// used if the client has been upgraded, but the Nomad server has not. Most +// callers should use the nodeAuthToken function instead of this as it correctly +// handles both authentication token methods. There are some limited places +// where the secret node ID is still used on the RPC request object such as +// "Node.GetClientAllocs". func (c *Client) secretNodeID() string { return c.GetConfig().Node.SecretID } +// nodeAuthToken will return the authentication token for the client. This will +// return the node identity token if it is set, otherwise it will return the +// secret node ID. +// +// The callers of this should be moved to nodeIdentityToken in Nomad 1.13 when +// all clients should be using the node identity token. +func (c *Client) nodeAuthToken() string { + if nID := c.nodeIdentityToken(); nID != "" { + return nID + } + return c.secretNodeID() +} + +// nodeIdentityToken returns the node identity token for the given client. If +// the client is coming up for the first time, restarting, or is in a cluster +// where the Nomad servers have not been upgraded to support the node identity, +// this will be empty. Callers should use the nodeAuthToken function instead of +// this as it correctly handles both authentication token methods. +func (c *Client) nodeIdentityToken() string { + if v := c.identity.Load(); v != nil { + return v.(string) + } + return "" +} + +// setNodeIdentityToken handles storing and updating all the client backend +// processes with a new node identity token. +func (c *Client) setNodeIdentityToken(token string) { + + // Store the token on the client as the first step, so it's available for + // use by all RPCs immediately. 
+ c.identity.Store(token) + + // Update the Nomad service registration handler and workload identity + // signer processes. + assertAndSetNodeIdentityToken(c.nomadService, token) + assertAndSetNodeIdentityToken(c.widsigner, token) +} + // Shutdown is used to tear down the client func (c *Client) Shutdown() error { c.shutdownLock.Lock() @@ -1954,7 +2024,7 @@ func (c *Client) submitNodeEvents(events []*structs.NodeEvent) error { NodeEvents: nodeEvents, WriteRequest: structs.WriteRequest{ Region: c.Region(), - AuthToken: c.secretNodeID(), + AuthToken: c.nodeAuthToken(), }, } var resp structs.EmitNodeEventsResponse @@ -2053,7 +2123,7 @@ func (c *Client) getRegistrationToken() string { select { case <-c.registeredCh: - return c.secretNodeID() + return c.nodeAuthToken() default: // If we haven't yet closed the registeredCh we're either starting for // the 1st time or we've just restarted. Check the local state to see if @@ -2065,7 +2135,7 @@ func (c *Client) getRegistrationToken() string { } if registration != nil && registration.HasRegistered { c.registeredOnce.Do(func() { close(c.registeredCh) }) - return c.secretNodeID() + return c.nodeAuthToken() } } return "" @@ -2086,6 +2156,11 @@ func (c *Client) registerNode(authToken string) error { return err } + // + if err := c.handleNodeUpdateResponse(resp); err != nil { + return err + } + // Signal that we've registered once so that RPCs sent from the client can // send authenticated requests. Persist this information in the state so // that we don't block restoring running allocs when restarting while @@ -2100,11 +2175,6 @@ func (c *Client) registerNode(authToken string) error { close(c.registeredCh) }) - err := c.handleNodeUpdateResponse(resp) - if err != nil { - return err - } - // Update the node status to ready after we register. c.UpdateConfig(func(c *config.Config) { c.Node.Status = structs.NodeStatusReady @@ -2131,7 +2201,7 @@ func (c *Client) updateNodeStatus() error { Status: structs.NodeStatusReady, WriteRequest: structs.WriteRequest{ Region: c.Region(), - AuthToken: c.secretNodeID(), + AuthToken: c.nodeAuthToken(), }, } var resp structs.NodeUpdateResponse @@ -2175,9 +2245,8 @@ func (c *Client) updateNodeStatus() error { } }) - err := c.handleNodeUpdateResponse(resp) - if err != nil { - return fmt.Errorf("heartbeat response returned no valid servers") + if err := c.handleNodeUpdateResponse(resp); err != nil { + return fmt.Errorf("failed to handle node update response: %w", err) } // If there's no Leader in the response we may be talking to a partitioned @@ -2195,6 +2264,20 @@ func (c *Client) handleNodeUpdateResponse(resp structs.NodeUpdateResponse) error // rebalance rate. c.servers.SetNumNodes(resp.NumNodes) + // If the response includes a new identity, set it and save it to the state + // DB. + // + // In the unlikely event that we cannot write the identity to the state DB, + // we do not want to set the client identity token. That would mean the + // client memory state and persistent state DB are out of sync. Instead, we + // return an error and wait until the next heartbeat to try again. 
+ if resp.SignedIdentity != nil { + if err := c.stateDB.PutNodeIdentity(*resp.SignedIdentity); err != nil { + return fmt.Errorf("error saving client identity: %w", err) + } + c.setNodeIdentityToken(*resp.SignedIdentity) + } + // Convert []*NodeServerInfo to []*servers.Server nomadServers := make([]*servers.Server, 0, len(resp.Servers)) for _, s := range resp.Servers { @@ -2277,7 +2360,7 @@ func (c *Client) allocSync() { Alloc: toSync, WriteRequest: structs.WriteRequest{ Region: c.Region(), - AuthToken: c.secretNodeID(), + AuthToken: c.nodeAuthToken(), }, } @@ -2336,6 +2419,27 @@ type allocUpdates struct { // watchAllocations is used to scan for updates to allocations func (c *Client) watchAllocations(updates chan *allocUpdates) { + + // The request object is generated as soon as this function is called, but + // the RPC can block on the register channel being closed. If we are + // starting for the first time and have not got our identity, the + // authentication token could be set to an empty string. This will result in + // a failed RPC when the call is unblocked. + // + // Although this will be quickly retried, we want to ensure that we do not + // throw errors into the logs or perform calls we know will fail if we can + // avoid it. Therefore, we wait for the registered channel to be closed, + // indicating the client has registered and has an identity token. + // + // This is a prevalent problem when the Nomad agent is run in development + // mode, as the server needs to start and have its encrypter ready, before + // it can generate identities. + select { + case <-c.shutdownCh: + return + case <-c.registeredCh: + } + // The request and response for getting the map of allocations that should // be running on the Node to their AllocModifyIndex which is incremented // when the allocation is updated by the servers. @@ -2352,7 +2456,7 @@ func (c *Client) watchAllocations(updates chan *allocUpdates) { // After the first request, only require monotonically // increasing state. AllowStale: false, - AuthToken: c.secretNodeID(), + AuthToken: c.nodeAuthToken(), }, } var resp structs.NodeClientAllocsResponse @@ -2363,7 +2467,7 @@ func (c *Client) watchAllocations(updates chan *allocUpdates) { QueryOptions: structs.QueryOptions{ Region: c.Region(), AllowStale: true, - AuthToken: c.secretNodeID(), + AuthToken: c.nodeAuthToken(), }, } var allocsResp structs.AllocsGetResponse @@ -2373,6 +2477,9 @@ OUTER: // Get the allocation modify index map, blocking for updates. We will // use this to determine exactly what allocations need to be downloaded // in full. + + req.AuthToken = c.nodeAuthToken() + resp = structs.NodeClientAllocsResponse{} err := c.RPC("Node.GetClientAllocs", &req, &resp) if err != nil { @@ -2463,6 +2570,7 @@ OUTER: // Pull the allocations that need to be updated. 
allocsReq.AllocIDs = pull allocsReq.MinQueryIndex = pullIndex - 1 + allocsReq.AuthToken = c.nodeAuthToken() allocsResp = structs.AllocsGetResponse{} if err := c.RPC("Alloc.GetAllocs", &allocsReq, &allocsResp); err != nil { c.logger.Error("error querying updated allocations", "error", err) diff --git a/client/client_test.go b/client/client_test.go index d3dafd194..cfc3cd369 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -23,6 +23,7 @@ import ( trstate "github.com/hashicorp/nomad/client/allocrunner/taskrunner/state" "github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/fingerprint" + "github.com/hashicorp/nomad/client/servers" regMock "github.com/hashicorp/nomad/client/serviceregistration/mock" "github.com/hashicorp/nomad/client/state" cstate "github.com/hashicorp/nomad/client/state" @@ -30,6 +31,7 @@ import ( "github.com/hashicorp/nomad/command/agent/consul" "github.com/hashicorp/nomad/helper/pluginutils/catalog" "github.com/hashicorp/nomad/helper/pluginutils/singleton" + "github.com/hashicorp/nomad/helper/pointer" "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad" @@ -1393,6 +1395,38 @@ func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) { } } +func TestClient_nodeAuthToken(t *testing.T) { + ci.Parallel(t) + + testClient, testClientCleanup := TestClient(t, func(c *config.Config) { + c.Node.ID = uuid.Generate() + }) + defer func() { + _ = testClientCleanup() + }() + + must.Eq(t, testClient.GetConfig().Node.SecretID, testClient.nodeAuthToken()) + + testClient.setNodeIdentityToken("my-identity-token") + must.Eq(t, "my-identity-token", testClient.nodeAuthToken()) +} + +func TestClient_setNodeIdentityToken(t *testing.T) { + ci.Parallel(t) + + testClient, testClientCleanup := TestClient(t, func(c *config.Config) { + c.Node.ID = uuid.Generate() + }) + defer func() { + _ = testClientCleanup() + }() + + must.Eq(t, "", testClient.nodeIdentityToken()) + + testClient.setNodeIdentityToken("my-identity-token") + must.Eq(t, "my-identity-token", testClient.nodeIdentityToken()) +} + // TestClient_ServerList tests client methods that interact with the internal // nomad server list. func TestClient_ServerList(t *testing.T) { @@ -1419,6 +1453,56 @@ func TestClient_ServerList(t *testing.T) { } } +func TestClient_handleNodeUpdateResponse(t *testing.T) { + ci.Parallel(t) + + testClient, testClientCleanup := TestClient(t, func(c *config.Config) { + c.StateDBFactory = func(logger hclog.Logger, stateDir string) (state.StateDB, error) { + return cstate.NewMemDB(logger), nil + } + }) + defer func() { + _ = testClientCleanup() + }() + + // Assert our starting state, so we can ensure we are not testing for values + // that already exist. + must.Eq(t, 0, testClient.servers.NumNodes()) + must.Eq(t, 0, testClient.servers.NumServers()) + must.Eq(t, []*servers.Server{}, testClient.servers.GetServers()) + must.Eq(t, "", testClient.nodeIdentityToken()) + + stateIdentity, err := testClient.stateDB.GetNodeIdentity() + must.NoError(t, err) + must.Eq(t, "", stateIdentity) + + updateResp := structs.NodeUpdateResponse{ + NumNodes: 1010, + Servers: []*structs.NodeServerInfo{ + {RPCAdvertiseAddr: "10.0.0.1:4647", Datacenter: "dc1"}, + {RPCAdvertiseAddr: "10.0.0.2:4647", Datacenter: "dc1"}, + {RPCAdvertiseAddr: "10.0.0.3:4647", Datacenter: "dc1"}, + }, + SignedIdentity: pointer.Of("node-identity"), + } + + // Perform the update and test the outcome. 
+ must.NoError(t, testClient.handleNodeUpdateResponse(updateResp)) + + must.Eq(t, 1010, testClient.servers.NumNodes()) + must.Eq(t, 3, testClient.servers.NumServers()) + must.SliceContainsAllEqual(t, []*servers.Server{ + {Addr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 4647}}, + {Addr: &net.TCPAddr{IP: net.ParseIP("10.0.0.2"), Port: 4647}}, + {Addr: &net.TCPAddr{IP: net.ParseIP("10.0.0.3"), Port: 4647}}, + }, testClient.servers.GetServers()) + must.Eq(t, "node-identity", testClient.nodeIdentityToken()) + + stateIdentity, err = testClient.stateDB.GetNodeIdentity() + must.NoError(t, err) + must.Eq(t, "node-identity", stateIdentity) +} + func TestClient_UpdateNodeFromDevicesAccumulates(t *testing.T) { ci.Parallel(t) diff --git a/client/drain.go b/client/drain.go index a1bb34e8c..6644bb8bd 100644 --- a/client/drain.go +++ b/client/drain.go @@ -34,7 +34,9 @@ func (c *Client) DrainSelf() error { MarkEligible: false, Meta: map[string]string{"message": "shutting down"}, WriteRequest: structs.WriteRequest{ - Region: c.Region(), AuthToken: c.secretNodeID()}, + Region: c.Region(), + AuthToken: c.nodeAuthToken(), + }, } if drainSpec.Deadline > 0 { drainReq.DrainStrategy.ForceDeadline = now.Add(drainSpec.Deadline) @@ -94,7 +96,9 @@ func (c *Client) pollServerForDrainStatus(ctx context.Context, interval time.Dur NodeID: c.NodeID(), SecretID: c.secretNodeID(), QueryOptions: structs.QueryOptions{ - Region: c.Region(), AuthToken: c.secretNodeID()}, + Region: c.Region(), + AuthToken: c.nodeAuthToken(), + }, } var statusResp structs.SingleNodeResponse diff --git a/client/identity.go b/client/identity.go new file mode 100644 index 000000000..b108aadf4 --- /dev/null +++ b/client/identity.go @@ -0,0 +1,21 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package client + +// NodeIdentityHandler is an interface that allows setting a node identity +// token. The client uses this to inform its subsystems about a new node +// identity that it should use for RPC calls. +type NodeIdentityHandler interface { + SetNodeIdentityToken(token string) +} + +// assertAndSetNodeIdentityToken expects the passed interface implements +// NodeIdentityHandler and calls SetNodeIdentityToken. It is a programming error +// if the interface does not implement NodeIdentityHandler and will panic. The +// test file performs test assertions. +func assertAndSetNodeIdentityToken(impl any, token string) { + if impl != nil { + impl.(NodeIdentityHandler).SetNodeIdentityToken(token) + } +} diff --git a/client/identity_test.go b/client/identity_test.go new file mode 100644 index 000000000..6f2863b6e --- /dev/null +++ b/client/identity_test.go @@ -0,0 +1,32 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package client + +import ( + "testing" + + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/client/serviceregistration/nsd" + "github.com/hashicorp/nomad/client/widmgr" + "github.com/shoenig/test/must" +) + +var ( + _ NodeIdentityHandler = (*widmgr.Signer)(nil) + _ NodeIdentityHandler = (*nsd.ServiceRegistrationHandler)(nil) +) + +func Test_assertAndSetNodeIdentityToken(t *testing.T) { + ci.Parallel(t) + + // Call the function with a non-nil object that implements the interface and + // verify that SetNodeIdentityToken is called with the expected token. 
+ testImpl := &testHandler{} + assertAndSetNodeIdentityToken(testImpl, "test-token") + must.Eq(t, "test-token", testImpl.t) +} + +type testHandler struct{ t string } + +func (t *testHandler) SetNodeIdentityToken(token string) { t.t = token } diff --git a/client/rpc.go b/client/rpc.go index 9d3441119..69db60c45 100644 --- a/client/rpc.go +++ b/client/rpc.go @@ -493,7 +493,7 @@ func resolveServer(s string) (net.Addr, error) { func (c *Client) Ping(srv net.Addr) error { pingRequest := &structs.GenericRequest{ QueryOptions: structs.QueryOptions{ - AuthToken: c.secretNodeID(), + AuthToken: c.nodeAuthToken(), }, } var reply struct{} diff --git a/client/serviceregistration/nsd/nsd.go b/client/serviceregistration/nsd/nsd.go index 45a5aaf14..3cf97a2fb 100644 --- a/client/serviceregistration/nsd/nsd.go +++ b/client/serviceregistration/nsd/nsd.go @@ -9,6 +9,7 @@ import ( "fmt" "strings" "sync" + "sync/atomic" "time" "github.com/hashicorp/go-hclog" @@ -33,6 +34,11 @@ type ServiceRegistrationHandler struct { // registering new ones. registrationEnabled bool + // nodeAuthToken is the token the node is using for RPC authentication with + // the servers. This is an atomic value as the node identity is periodically + // renewed, meaning this value is updated while potentially being read. + nodeAuthToken atomic.Value + // shutDownCh coordinates shutting down the handler and any long-running // processes, such as the RPC retry. shutDownCh chan struct{} @@ -102,6 +108,11 @@ func NewServiceRegistrationHandler(log hclog.Logger, cfg *ServiceRegistrationHan return s } +// SetNodeIdentityToken fulfills the NodeIdentityHandler interface, allowing +// the client to update the node identity token used for RPC calls when it is +// renewed. +func (s *ServiceRegistrationHandler) SetNodeIdentityToken(token string) { s.nodeAuthToken.Store(token) } + func (s *ServiceRegistrationHandler) RegisterWorkload(workload *serviceregistration.WorkloadServices) error { // Check whether we are enabled or not first. Hitting this likely means // there is a bug within the implicit constraint, or process using it, as @@ -148,7 +159,7 @@ func (s *ServiceRegistrationHandler) RegisterWorkload(workload *serviceregistrat Services: registrations, WriteRequest: structs.WriteRequest{ Region: s.cfg.Region, - AuthToken: s.cfg.NodeSecret, + AuthToken: s.authToken(), }, } @@ -201,7 +212,7 @@ func (s *ServiceRegistrationHandler) removeWorkload( WriteRequest: structs.WriteRequest{ Region: s.cfg.Region, Namespace: workload.ProviderNamespace, - AuthToken: s.cfg.NodeSecret, + AuthToken: s.authToken(), }, } @@ -390,3 +401,14 @@ func (s *ServiceRegistrationHandler) generateNomadServiceRegistration( Port: port, }, nil } + +// authToken returns the current authentication token used for RPC calls. It +// will use the node identity token if it is set, otherwise it will fallback to +// the node secret. This handles the case where the node is upgraded before the +// Nomad servers and should be removed in Nomad 1.13. +func (s *ServiceRegistrationHandler) authToken() string { + if id := s.nodeAuthToken.Load(); id != nil { + return id.(string) + } + return s.cfg.NodeSecret +} diff --git a/client/widmgr/signer.go b/client/widmgr/signer.go index 2102d5324..a137c4598 100644 --- a/client/widmgr/signer.go +++ b/client/widmgr/signer.go @@ -5,6 +5,7 @@ package widmgr import ( "fmt" + "sync/atomic" "github.com/hashicorp/go-multierror" "github.com/hashicorp/nomad/nomad/structs" @@ -34,9 +35,10 @@ type SignerConfig struct { // Signer fetches and validates workload identities. 
type Signer struct { - nodeSecret string - region string - rpc RPCer + nodeSecret string + nodeIdentityToken atomic.Value + region string + rpc RPCer } // NewSigner workload identity manager. @@ -48,6 +50,11 @@ func NewSigner(c SignerConfig) *Signer { } } +// SetNodeIdentityToken fulfills the NodeIdentityHandler interface, allowing +// the client to update the node identity token used for RPC calls when it is +// renewed. +func (s *Signer) SetNodeIdentityToken(token string) { s.nodeIdentityToken.Store(token) } + // SignIdentities wraps the Alloc.SignIdentities RPC and retrieves signed // workload identities. The minIndex should be set to the lowest allocation // CreateIndex to ensure that the server handling the request isn't so stale @@ -62,6 +69,15 @@ func (s *Signer) SignIdentities(minIndex uint64, req []*structs.WorkloadIdentity return nil, fmt.Errorf("no identities to sign") } + // Default to using the node secret, but if the node identity token is set, + // this will be used instead. This handles the case where the node is + // upgraded before the Nomad servers and should be removed in Nomad 1.13. + authToken := s.nodeSecret + + if id := s.nodeIdentityToken.Load(); id != nil { + authToken = id.(string) + } + args := structs.AllocIdentitiesRequest{ Identities: req, QueryOptions: structs.QueryOptions{ @@ -73,7 +89,7 @@ func (s *Signer) SignIdentities(minIndex uint64, req []*structs.WorkloadIdentity // Server to block at least until the Allocation is created. MinQueryIndex: minIndex - 1, AllowStale: true, - AuthToken: s.nodeSecret, + AuthToken: authToken, }, } reply := structs.AllocIdentitiesResponse{} From 953a1491808232b0c091a96ebf28600c509ac625 Mon Sep 17 00:00:00 2001 From: James Rasell Date: Wed, 16 Jul 2025 15:56:00 +0200 Subject: [PATCH 7/7] client: Allow operators to force a client to renew its identity. (#26277) The Nomad client will have its identity renewed according to the TTL which defaults to 24h. In certain situations such as root keyring rotation, operators may want to force clients to renew their identities before the TTL threshold is met. This change introduces a client HTTP and RPC endpoint which will instruct the node to request a new identity at its next heartbeat. This can be used via the API or a new command. While this is a manual intervention step on top of the any keyring rotation, it dramatically reduces the initial feature complexity as it provides an asynchronous and efficient method of renewal that utilises existing functionality. 
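As a rough sketch of the operator-facing flow only (not part of this change's
code), the renewal can be driven through the Go api package added below. The
node ID is a placeholder and would normally come from a node listing, and the
token in use needs node write permissions:

    package main

    import (
        "log"

        "github.com/hashicorp/nomad/api"
    )

    func main() {
        // Connect using the standard NOMAD_ADDR/NOMAD_TOKEN environment settings.
        client, err := api.NewClient(api.DefaultConfig())
        if err != nil {
            log.Fatal(err)
        }

        // Placeholder node ID for illustration only.
        req := api.NodeIdentityRenewRequest{
            NodeID: "f27444bc-9c47-1d33-a6a9-3b39f5d7a101",
        }

        // Posts to /v1/client/identity/renew. The target client only flags the
        // renewal; the new identity is issued at its next heartbeat.
        if _, err := client.Nodes().Identity().Renew(&req, nil); err != nil {
            log.Fatal(err)
        }
    }

The equivalent CLI invocation is "nomad node identity renew <node_id>".
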
--- api/node_identity.go | 33 ++++++++ api/node_identity_test.go | 29 +++++++ client/client.go | 34 +++++++- client/node_identity_endpoint.go | 33 ++++++++ client/node_identity_endpoint_test.go | 103 ++++++++++++++++++++++++ client/rpc.go | 17 ++-- command/agent/http.go | 1 + command/agent/node_identity_endpoint.go | 43 ++++++++++ command/commands.go | 10 +++ command/node_identity.go | 34 ++++++++ command/node_identity_renew.go | 88 ++++++++++++++++++++ nomad/client_identity_endpoint.go | 47 +++++++++++ nomad/client_identity_endpoint_test.go | 85 +++++++++++++++++++ nomad/node_endpoint.go | 2 + nomad/server.go | 1 + nomad/structs/node.go | 22 +++++ 16 files changed, 574 insertions(+), 8 deletions(-) create mode 100644 api/node_identity.go create mode 100644 api/node_identity_test.go create mode 100644 client/node_identity_endpoint.go create mode 100644 client/node_identity_endpoint_test.go create mode 100644 command/agent/node_identity_endpoint.go create mode 100644 command/node_identity.go create mode 100644 command/node_identity_renew.go create mode 100644 nomad/client_identity_endpoint.go create mode 100644 nomad/client_identity_endpoint_test.go diff --git a/api/node_identity.go b/api/node_identity.go new file mode 100644 index 000000000..497ebcd23 --- /dev/null +++ b/api/node_identity.go @@ -0,0 +1,33 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package api + +type NodeIdentityRenewRequest struct { + NodeID string +} + +type NodeIdentityRenewResponse struct{} + +type NodeIdentity struct { + client *Client +} + +func (n *Nodes) Identity() *NodeIdentity { + return &NodeIdentity{client: n.client} +} + +// Renew instructs the node to request a new identity from the server at its +// next heartbeat. +// +// The request uses query options to control the forwarding behavior of the +// request only. Parameters such as Filter, WaitTime, and WaitIndex are not used +// and ignored. +func (n *NodeIdentity) Renew(req *NodeIdentityRenewRequest, qo *QueryOptions) (*NodeIdentityRenewResponse, error) { + var out NodeIdentityRenewResponse + _, err := n.client.postQuery("/v1/client/identity/renew", req, &out, qo) + if err != nil { + return nil, err + } + return &out, nil +} diff --git a/api/node_identity_test.go b/api/node_identity_test.go new file mode 100644 index 000000000..56f682f15 --- /dev/null +++ b/api/node_identity_test.go @@ -0,0 +1,29 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package api + +import ( + "testing" + + "github.com/hashicorp/nomad/api/internal/testutil" + "github.com/shoenig/test/must" +) + +func TestNodeIdentity_Renew(t *testing.T) { + testutil.Parallel(t) + + configCallback := func(c *testutil.TestServerConfig) { c.DevMode = true } + testClient, testServer := makeClient(t, nil, configCallback) + defer testServer.Stop() + + nodeID := oneNodeFromNodeList(t, testClient.Nodes()).ID + + req := NodeIdentityRenewRequest{ + NodeID: nodeID, + } + + resp, err := testClient.Nodes().Identity().Renew(&req, nil) + must.NoError(t, err) + must.NotNil(t, resp) +} diff --git a/client/client.go b/client/client.go index 74abd0c16..de9b465ec 100644 --- a/client/client.go +++ b/client/client.go @@ -339,6 +339,11 @@ type Client struct { // the servers. This is used to authenticate the client to the servers when // performing RPC calls. identity atomic.Value + + // identityForceRenewal is used to force the client to renew its identity + // at the next heartbeat. It is set by an operator calling the node identity + // renew RPC method. 
+ identityForceRenewal atomic.Bool } var ( @@ -402,6 +407,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxie EnterpriseClient: newEnterpriseClient(logger), allocrunnerFactory: cfg.AllocRunnerFactory, identity: atomic.Value{}, + identityForceRenewal: atomic.Bool{}, } // we can't have this set in the default Config because of import cycles @@ -968,6 +974,10 @@ func (c *Client) nodeIdentityToken() string { // processes with a new node identity token. func (c *Client) setNodeIdentityToken(token string) { + // It's a bit of a simple log line, but it is useful to know when the client + // has renewed or set its node identity token. + c.logger.Info("setting node identity token") + // Store the token on the client as the first step, so it's available for // use by all RPCs immediately. c.identity.Store(token) @@ -2204,6 +2214,14 @@ func (c *Client) updateNodeStatus() error { AuthToken: c.nodeAuthToken(), }, } + + // Check if the client has been informed to force a renewal of its identity, + // and set the flag in the request if so. + if c.identityForceRenewal.Load() { + c.logger.Debug("forcing identity renewal") + req.ForceIdentityRenewal = true + } + var resp structs.NodeUpdateResponse if err := c.RPC("Node.UpdateStatus", &req, &resp); err != nil { c.triggerDiscovery() @@ -2226,7 +2244,17 @@ func (c *Client) updateNodeStatus() error { c.heartbeatLock.Unlock() c.logger.Trace("next heartbeat", "period", resp.HeartbeatTTL) - if resp.Index != 0 { + // The Nomad server will return an index of greater than zero when a Raft + // update has occurred, indicating a change in the state of the persisted + // node object. + // + // This can be due to a Nomad server invalidating the node's heartbeat timer + // and marking the node as down. In this case, we want to log a warning for + // the operator to see the client missed a heartbeat. If the server + // responded with a new identity, we assume the client did not miss a + // heartbeat. If we did, this line would appear each time the identity was + // renewed, which could confuse cluster operators. + if resp.Index != 0 && resp.SignedIdentity == nil { c.logger.Debug("state updated", "node_status", req.Status) // We have potentially missed our TTL log how delayed we were @@ -2276,6 +2304,10 @@ func (c *Client) handleNodeUpdateResponse(resp structs.NodeUpdateResponse) error return fmt.Errorf("error saving client identity: %w", err) } c.setNodeIdentityToken(*resp.SignedIdentity) + + // If the operator forced this renewal, reset the flag so that we don't + // keep renewing the identity on every heartbeat. + c.identityForceRenewal.Store(false) } // Convert []*NodeServerInfo to []*servers.Server diff --git a/client/node_identity_endpoint.go b/client/node_identity_endpoint.go new file mode 100644 index 000000000..8d3eb9289 --- /dev/null +++ b/client/node_identity_endpoint.go @@ -0,0 +1,33 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package client + +import ( + "github.com/hashicorp/nomad/nomad/structs" +) + +type NodeIdentity struct { + c *Client +} + +func newNodeIdentityEndpoint(c *Client) *NodeIdentity { + n := &NodeIdentity{c: c} + return n +} + +func (n *NodeIdentity) Renew(args *structs.NodeIdentityRenewReq, _ *structs.NodeIdentityRenewResp) error { + + // Check node write permissions. 
+ if aclObj, err := n.c.ResolveToken(args.AuthToken); err != nil { + return err + } else if !aclObj.AllowNodeWrite() { + return structs.ErrPermissionDenied + } + + // Store the node identity renewal request on the client, so it can be + // picked up at the next heartbeat. + n.c.identityForceRenewal.Store(true) + + return nil +} diff --git a/client/node_identity_endpoint_test.go b/client/node_identity_endpoint_test.go new file mode 100644 index 000000000..cdbbd06e6 --- /dev/null +++ b/client/node_identity_endpoint_test.go @@ -0,0 +1,103 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package client + +import ( + "testing" + + "github.com/hashicorp/nomad/acl" + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/client/config" + "github.com/hashicorp/nomad/nomad" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/testutil" + "github.com/shoenig/test/must" +) + +func TestNodeIdentity_Renew(t *testing.T) { + ci.Parallel(t) + + // Create a test ACL server and client and perform our node identity renewal + // tests against it. + testACLServer, testServerToken, testACLServerCleanup := nomad.TestACLServer(t, nil) + t.Cleanup(func() { testACLServerCleanup() }) + testutil.WaitForLeader(t, testACLServer.RPC) + + testACLClient, testACLClientCleanup := TestClient(t, func(c *config.Config) { + c.ACLEnabled = true + c.Servers = []string{testACLServer.GetConfig().RPCAddr.String()} + }) + t.Cleanup(func() { _ = testACLClientCleanup() }) + testutil.WaitForClientStatusWithToken( + t, testACLServer.RPC, testACLClient.NodeID(), testACLClient.Region(), + structs.NodeStatusReady, testServerToken.SecretID, + ) + + t.Run("acl_denied", func(t *testing.T) { + must.ErrorContains( + t, + testACLClient.ClientRPC( + structs.NodeIdentityRenewRPCMethod, + &structs.NodeIdentityRenewReq{}, + &structs.NodeIdentityRenewResp{}, + ), + structs.ErrPermissionDenied.Error(), + ) + }) + + t.Run("acl_valid", func(t *testing.T) { + + aclPolicy := mock.NodePolicy(acl.PolicyWrite) + aclToken := mock.CreatePolicyAndToken(t, testACLServer.State(), 10, t.Name(), aclPolicy) + + req := structs.NodeIdentityRenewReq{ + NodeID: testACLClient.NodeID(), + QueryOptions: structs.QueryOptions{ + AuthToken: aclToken.SecretID, + }, + } + + must.NoError( + t, + testACLClient.ClientRPC( + structs.NodeIdentityRenewRPCMethod, + &req, + &structs.NodeIdentityRenewResp{}, + ), + ) + + renewalVal := testACLClient.identityForceRenewal.Load() + must.True(t, renewalVal) + }) + + // Create a test non-ACL server and client and perform our node identity + // renewal tests against it. 
+ testServer, testServerCleanup := nomad.TestServer(t, nil) + t.Cleanup(func() { testServerCleanup() }) + testutil.WaitForLeader(t, testServer.RPC) + + testClient, testClientCleanup := TestClient(t, func(c *config.Config) { + c.Servers = []string{testServer.GetConfig().RPCAddr.String()} + }) + t.Cleanup(func() { _ = testClientCleanup() }) + testutil.WaitForClient(t, testServer.RPC, testClient.NodeID(), testClient.Region()) + + t.Run("non_acl_valid", func(t *testing.T) { + must.NoError( + t, + testClient.ClientRPC( + structs.NodeIdentityRenewRPCMethod, + &structs.NodeIdentityRenewReq{ + NodeID: testClient.NodeID(), + QueryOptions: structs.QueryOptions{}, + }, + &structs.NodeIdentityRenewResp{}, + ), + ) + + renewalVal := testClient.identityForceRenewal.Load() + must.True(t, renewalVal) + }) +} diff --git a/client/rpc.go b/client/rpc.go index 69db60c45..846589724 100644 --- a/client/rpc.go +++ b/client/rpc.go @@ -22,13 +22,14 @@ import ( // rpcEndpoints holds the RPC endpoints type rpcEndpoints struct { - ClientStats *ClientStats - CSI *CSI - FileSystem *FileSystem - Allocations *Allocations - Agent *Agent - NodeMeta *NodeMeta - HostVolume *HostVolume + ClientStats *ClientStats + CSI *CSI + FileSystem *FileSystem + Allocations *Allocations + Agent *Agent + NodeIdentity *NodeIdentity + NodeMeta *NodeMeta + HostVolume *HostVolume } // ClientRPC is used to make a local, client only RPC call @@ -301,6 +302,7 @@ func (c *Client) setupClientRpc(rpcs map[string]interface{}) { c.endpoints.FileSystem = NewFileSystemEndpoint(c) c.endpoints.Allocations = NewAllocationsEndpoint(c) c.endpoints.Agent = NewAgentEndpoint(c) + c.endpoints.NodeIdentity = newNodeIdentityEndpoint(c) c.endpoints.NodeMeta = newNodeMetaEndpoint(c) c.endpoints.HostVolume = newHostVolumesEndpoint(c) c.setupClientRpcServer(c.rpcServer) @@ -317,6 +319,7 @@ func (c *Client) setupClientRpcServer(server *rpc.Server) { server.Register(c.endpoints.FileSystem) server.Register(c.endpoints.Allocations) server.Register(c.endpoints.Agent) + _ = server.Register(c.endpoints.NodeIdentity) server.Register(c.endpoints.NodeMeta) server.Register(c.endpoints.HostVolume) } diff --git a/command/agent/http.go b/command/agent/http.go index 52c552677..67e88fa06 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -450,6 +450,7 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) { s.mux.Handle("/v1/client/stats", wrapCORS(s.wrap(s.ClientStatsRequest))) s.mux.Handle("/v1/client/allocation/", wrapCORS(s.wrap(s.ClientAllocRequest))) s.mux.Handle("/v1/client/metadata", wrapCORS(s.wrap(s.NodeMetaRequest))) + s.mux.Handle("/v1/client/identity/renew", wrapCORS(s.wrap(s.NodeIdentityRenewRequest))) s.mux.HandleFunc("/v1/agent/self", s.wrap(s.AgentSelfRequest)) s.mux.HandleFunc("/v1/agent/join", s.wrap(s.AgentJoinRequest)) diff --git a/command/agent/node_identity_endpoint.go b/command/agent/node_identity_endpoint.go new file mode 100644 index 000000000..4109c98e5 --- /dev/null +++ b/command/agent/node_identity_endpoint.go @@ -0,0 +1,43 @@ +// Copyright (c) HashiCorp, Inc. 
+// SPDX-License-Identifier: BUSL-1.1 + +package agent + +import ( + "net/http" + + "github.com/hashicorp/nomad/nomad/structs" +) + +func (s *HTTPServer) NodeIdentityRenewRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + // Build the request by parsing all common parameters and node id + args := structs.NodeIdentityRenewReq{} + s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions) + parseNode(req, &args.NodeID) + + // Determine the handler to use + useLocalClient, useClientRPC, useServerRPC := s.rpcHandlerForNode(args.NodeID) + + // Make the RPC + var reply structs.NodeIdentityRenewResp + var rpcErr error + if useLocalClient { + rpcErr = s.agent.Client().ClientRPC(structs.NodeIdentityRenewRPCMethod, &args, &reply) + } else if useClientRPC { + rpcErr = s.agent.Client().RPC(structs.NodeIdentityRenewRPCMethod, &args, &reply) + } else if useServerRPC { + rpcErr = s.agent.Server().RPC(structs.NodeIdentityRenewRPCMethod, &args, &reply) + } else { + rpcErr = CodedError(400, "no local Node and node_id not provided") + } + + if rpcErr != nil { + if structs.IsErrNoNodeConn(rpcErr) { + rpcErr = CodedError(404, rpcErr.Error()) + } + + return nil, rpcErr + } + + return reply, nil +} diff --git a/command/commands.go b/command/commands.go index 0766ea43f..dbcd37bd6 100644 --- a/command/commands.go +++ b/command/commands.go @@ -634,6 +634,16 @@ func Commands(metaPtr *Meta, agentUi cli.Ui) map[string]cli.CommandFactory { Meta: meta, }, nil }, + "node identity": func() (cli.Command, error) { + return &NodeIdentityCommand{ + Meta: meta, + }, nil + }, + "node identity renew": func() (cli.Command, error) { + return &NodeIdentityRenewCommand{ + Meta: meta, + }, nil + }, "node meta": func() (cli.Command, error) { return &NodeMetaCommand{ Meta: meta, diff --git a/command/node_identity.go b/command/node_identity.go new file mode 100644 index 000000000..095ec6e44 --- /dev/null +++ b/command/node_identity.go @@ -0,0 +1,34 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package command + +import ( + "strings" + + "github.com/hashicorp/cli" +) + +type NodeIdentityCommand struct { + Meta +} + +func (n *NodeIdentityCommand) Help() string { + helpText := ` +Usage: nomad node identity [subcommand] + + Interact with a node's identity. All commands interact directly with a client + and require setting the target node via its 36 character ID. + + Please see the individual subcommand help for detailed usage information. +` + return strings.TrimSpace(helpText) +} + +func (n *NodeIdentityCommand) Synopsis() string { return "Force renewal of a nodes identity" } + +func (n *NodeIdentityCommand) Name() string { return "node identity" } + +func (n *NodeIdentityCommand) Run(_ []string) int { + return cli.RunResultHelp +} diff --git a/command/node_identity_renew.go b/command/node_identity_renew.go new file mode 100644 index 000000000..991e9c5a4 --- /dev/null +++ b/command/node_identity_renew.go @@ -0,0 +1,88 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package command + +import ( + "fmt" + "strings" + + "github.com/hashicorp/nomad/api" + "github.com/posener/complete" +) + +type NodeIdentityRenewCommand struct { + Meta +} + +func (n *NodeIdentityRenewCommand) Help() string { + helpText := ` +Usage: nomad node identity renew [options] + + Instruct a node to renew its identity at the next heartbeat. This command only + applies to client agents. 
+ +General Options: + + ` + generalOptionsUsage(usageOptsDefault|usageOptsNoNamespace) + + return strings.TrimSpace(helpText) +} + +func (n *NodeIdentityRenewCommand) Synopsis() string { return "Force a node to renew its identity" } + +func (n *NodeIdentityRenewCommand) Name() string { return "node identity renew" } + +func (n *NodeIdentityRenewCommand) Run(args []string) int { + + flags := n.Meta.FlagSet(n.Name(), FlagSetClient) + flags.Usage = func() { n.Ui.Output(n.Help()) } + + if err := flags.Parse(args); err != nil { + return 1 + } + args = flags.Args() + + if len(args) != 1 { + n.Ui.Error("This command takes one argument: ") + n.Ui.Error(commandErrorText(n)) + return 1 + } + + // Get the HTTP client + client, err := n.Meta.Client() + if err != nil { + n.Ui.Error(fmt.Sprintf("Error initializing client: %s", err)) + return 1 + } + + nodeID := args[0] + + // Lookup nodeID + if nodeID != "" { + nodeID, err = lookupNodeID(client.Nodes(), nodeID) + if err != nil { + n.Ui.Error(err.Error()) + return 1 + } + } + + req := api.NodeIdentityRenewRequest{ + NodeID: nodeID, + } + + if _, err := client.Nodes().Identity().Renew(&req, nil); err != nil { + n.Ui.Error(fmt.Sprintf("Error requesting node identity renewal: %s", err)) + return 1 + } + + return 0 +} + +func (n *NodeIdentityRenewCommand) AutocompleteFlags() complete.Flags { + return n.Meta.AutocompleteFlags(FlagSetClient) +} + +func (n *NodeIdentityRenewCommand) AutocompleteArgs() complete.Predictor { + return nodePredictor(n.Client, nil) +} diff --git a/nomad/client_identity_endpoint.go b/nomad/client_identity_endpoint.go new file mode 100644 index 000000000..78235d546 --- /dev/null +++ b/nomad/client_identity_endpoint.go @@ -0,0 +1,47 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package nomad + +import ( + "time" + + metrics "github.com/hashicorp/go-metrics/compat" + "github.com/hashicorp/nomad/nomad/structs" +) + +type NodeIdentity struct { + srv *Server +} + +func newNodeIdentityEndpoint(srv *Server) *NodeIdentity { + return &NodeIdentity{ + srv: srv, + } +} + +func (n *NodeIdentity) Renew(args *structs.NodeIdentityRenewReq, reply *structs.NodeIdentityRenewResp) error { + + // Prevent infinite loop between the leader and the follower with the target + // node connection. + args.QueryOptions.AllowStale = true + + authErr := n.srv.Authenticate(nil, args) + if done, err := n.srv.forward(structs.NodeIdentityRenewRPCMethod, args, args, reply); done { + return err + } + n.srv.MeasureRPCRate("client_identity", structs.RateMetricWrite, args) + if authErr != nil { + return structs.ErrPermissionDenied + } + defer metrics.MeasureSince([]string{"nomad", "client_identity", "renew"}, time.Now()) + + // Check node write permissions + if aclObj, err := n.srv.ResolveACL(args); err != nil { + return err + } else if !aclObj.AllowNodeWrite() { + return structs.ErrPermissionDenied + } + + return n.srv.forwardClientRPC(structs.NodeIdentityRenewRPCMethod, args.NodeID, args, reply) +} diff --git a/nomad/client_identity_endpoint_test.go b/nomad/client_identity_endpoint_test.go new file mode 100644 index 000000000..a40289b88 --- /dev/null +++ b/nomad/client_identity_endpoint_test.go @@ -0,0 +1,85 @@ +// Copyright (c) HashiCorp, Inc. 
+// SPDX-License-Identifier: BUSL-1.1 + +package nomad + +import ( + "testing" + + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/client" + "github.com/hashicorp/nomad/client/config" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/testutil" + "github.com/shoenig/test/must" +) + +func TestNodeIdentity_Renew_Forward(t *testing.T) { + ci.Parallel(t) + + servers := []*Server{} + for i := 0; i < 3; i++ { + s, cleanup := TestServer(t, func(c *Config) { + c.BootstrapExpect = 3 + c.NumSchedulers = 0 + }) + t.Cleanup(cleanup) + servers = append(servers, s) + } + + TestJoin(t, servers...) + leader := testutil.WaitForLeaders(t, servers[0].RPC, servers[1].RPC, servers[2].RPC) + + followers := []string{} + for _, s := range servers { + if addr := s.config.RPCAddr.String(); addr != leader { + followers = append(followers, addr) + } + } + t.Logf("leader=%s followers=%q", leader, followers) + + clients := make([]*client.Client, 4) + + for i := 0; i < 4; i++ { + c, cleanup := client.TestClient(t, func(c *config.Config) { + c.Servers = followers + }) + t.Cleanup(func() { _ = cleanup() }) + clients[i] = c + } + for _, c := range clients { + testutil.WaitForClient(t, servers[0].RPC, c.NodeID(), c.Region()) + } + + agentRPCs := []func(string, any, any) error{} + nodeIDs := make([]string, 0, len(clients)) + + // Build list of agents and node IDs + for _, s := range servers { + agentRPCs = append(agentRPCs, s.RPC) + } + + for _, c := range clients { + agentRPCs = append(agentRPCs, c.RPC) + nodeIDs = append(nodeIDs, c.NodeID()) + } + + // Iterate through all the agent RPCs to ensure that the renew RPC will + // succeed, no matter which agent we connect to. + for _, agentRPC := range agentRPCs { + for _, nodeID := range nodeIDs { + args := &structs.NodeIdentityRenewReq{ + NodeID: nodeID, + QueryOptions: structs.QueryOptions{ + Region: clients[0].Region(), + }, + } + must.NoError(t, + agentRPC(structs.NodeIdentityRenewRPCMethod, + args, + &structs.NodeIdentityRenewResp{}, + ), + ) + } + } +} diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index c04b1db00..d752be090 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -262,6 +262,8 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp reply.SignedIdentity = &signedJWT args.Node.IdentitySigningKeyID = signingKeyID + } else if originalNode != nil { + args.Node.IdentitySigningKeyID = originalNode.IdentitySigningKeyID } _, index, err := n.srv.raftApply(structs.NodeRegisterRequestType, args) diff --git a/nomad/server.go b/nomad/server.go index c6f7b0611..649fecf02 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -1282,6 +1282,7 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) { // These endpoints are client RPCs and don't include a connection context _ = server.Register(NewClientStatsEndpoint(s)) _ = server.Register(newNodeMetaEndpoint(s)) + _ = server.Register(newNodeIdentityEndpoint(s)) // These endpoints have their streaming component registered in // setupStreamingEndpoints, but their non-streaming RPCs are registered diff --git a/nomad/structs/node.go b/nomad/structs/node.go index a5a308e3f..8e62ddbf0 100644 --- a/nomad/structs/node.go +++ b/nomad/structs/node.go @@ -714,3 +714,25 @@ type NodeUpdateResponse struct { QueryMeta } + +const ( + // NodeIdentityRenewRPCMethod is the RPC method for instructing a client to + // forcibly request a renewal of its node identity at the next heartbeat. 
+ // + // Args: NodeIdentityRenewReq + // Reply: NodeIdentityRenewResp + NodeIdentityRenewRPCMethod = "NodeIdentity.Renew" +) + +// NodeIdentityRenewReq is used to instruct the Nomad server to renew the client +// identity at its next heartbeat regardless of whether it is close to +// expiration. +type NodeIdentityRenewReq struct { + NodeID string + + // This is a client RPC, so we must use query options which allow us to set + // AllowStale=true. + QueryOptions +} + +type NodeIdentityRenewResp struct{}
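
For readers skimming the series: the client, the Nomad service registration
handler, and the workload identity signer all apply the same authentication
fallback, preferring the signed node identity when the servers have issued one
and otherwise using the legacy node secret so partially upgraded clusters keep
working. A minimal, self-contained sketch of that shape (illustrative names
only, not the actual Nomad types):

    // Package main sketches the auth-token fallback used throughout this
    // series. The real logic lives in Client.nodeAuthToken,
    // ServiceRegistrationHandler.authToken, and Signer.SignIdentities.
    package main

    import (
        "fmt"
        "sync/atomic"
    )

    type tokenSource struct {
        nodeSecret string       // static secret ID from the client configuration
        identity   atomic.Value // stores the signed identity JWT as a string
    }

    // setIdentity records a newly signed identity; the client persists it to
    // its state DB before updating the in-memory value so the two stay in sync.
    func (t *tokenSource) setIdentity(jwt string) { t.identity.Store(jwt) }

    // authToken returns the identity token when present, otherwise the secret.
    func (t *tokenSource) authToken() string {
        if v := t.identity.Load(); v != nil {
            return v.(string)
        }
        return t.nodeSecret
    }

    func main() {
        ts := &tokenSource{nodeSecret: "legacy-secret-id"}
        fmt.Println(ts.authToken()) // legacy-secret-id

        ts.setIdentity("signed-node-identity-jwt")
        fmt.Println(ts.authToken()) // signed-node-identity-jwt
    }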