From d5b2d5078b120b8ace704edece02ed916c204a27 Mon Sep 17 00:00:00 2001 From: James Rasell Date: Tue, 1 Jul 2025 17:07:21 +0200 Subject: [PATCH] rpc: Generate node identities with node RPC handlers when needed. (#26165) When a Nomad client registers or re-registers, the RPC handler will generate and return a node identity if required. When an identity is generated, the signing key ID will be stored within the node object, to ensure a root key is not deleted while it is still in use. During normal client operation it will periodically heartbeat to the Nomad servers to indicate aliveness. The RPC handler that is used for this action has also been updated to conditionally perform identity generation. Performing it here means no extra RPC handlers are required and we inherit the jitter in identity generation from the heartbeat mechanism. The identity generation check methods are implemented on the RPC request arguments, so they are scoped to the required behaviour and can handle the nuance of each RPC. Failure to generate an identity is considered terminal to the RPC call. The client will include behaviour to retry this error, which is always caused by the encrypter not being ready unless the server's keyring has been corrupted. --- client/client_test.go | 1 + client/drain_test.go | 2 + command/acl_bootstrap_test.go | 6 +- command/agent/agent_test.go | 2 +- nomad/acl.go | 4 + nomad/auth/auth.go | 5 +- nomad/client_agent_endpoint_test.go | 5 +- nomad/client_alloc_endpoint_test.go | 1 + nomad/client_csi_endpoint_test.go | 1 + nomad/client_stats_endpoint_test.go | 1 + nomad/csi_endpoint_test.go | 11 +- nomad/drainer_int_test.go | 8 + nomad/encrypter.go | 11 +- nomad/eval_broker_test.go | 1 + nomad/host_volume_endpoint_test.go | 3 + nomad/node_endpoint.go | 148 +++++++- nomad/node_endpoint_test.go | 551 +++++++++++++++++++++++++++- nomad/rpc_test.go | 8 +- nomad/server_test.go | 1 + nomad/structs/identity.go | 12 +- nomad/structs/identity_test.go | 51 +++ nomad/structs/node.go | 177 +++++++++ nomad/structs/node_test.go | 370 +++++++++++++++++++ nomad/structs/structs.go | 63 ---- nomad/worker_test.go | 4 + 25 files changed, 1344 insertions(+), 103 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index fc6fb60e8..d3dafd194 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -703,6 +703,7 @@ func TestClient_SaveRestoreState(t *testing.T) { s1, _, cleanupS1 := testServer(t, nil) t.Cleanup(cleanupS1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) c1, cleanupC1 := TestClient(t, func(c *config.Config) { c.DevMode = false diff --git a/client/drain_test.go b/client/drain_test.go index d67a6219a..a995baf4c 100644 --- a/client/drain_test.go +++ b/client/drain_test.go @@ -29,6 +29,7 @@ func TestClient_SelfDrainConfig(t *testing.T) { srv, _, cleanupSRV := testServer(t, nil) defer cleanupSRV() testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) c1, cleanupC1 := TestClient(t, func(c *config.Config) { c.RPCHandler = srv @@ -81,6 +82,7 @@ func TestClient_SelfDrain_FailLocal(t *testing.T) { srv, _, cleanupSRV := testServer(t, nil) defer cleanupSRV() testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) c1, cleanupC1 := TestClient(t, func(c *config.Config) { c.RPCHandler = srv diff --git a/command/acl_bootstrap_test.go b/command/acl_bootstrap_test.go index 78c5fc566..14355c18d 100644 --- a/command/acl_bootstrap_test.go +++ b/command/acl_bootstrap_test.go @@ -23,7 +23,7 @@ func
TestACLBootstrapCommand(t *testing.T) { c.ACL.PolicyTTL = 0 } - srv, _, url := testServer(t, true, config) + srv, _, url := testServer(t, false, config) defer srv.Shutdown() must.Nil(t, srv.RootToken) @@ -101,7 +101,7 @@ func TestACLBootstrapCommand_WithOperatorFileBootstrapToken(t *testing.T) { err := os.WriteFile(file, []byte(mockToken.SecretID), 0700) must.NoError(t, err) - srv, _, url := testServer(t, true, config) + srv, _, url := testServer(t, false, config) defer srv.Shutdown() must.Nil(t, srv.RootToken) @@ -139,7 +139,7 @@ func TestACLBootstrapCommand_WithBadOperatorFileBootstrapToken(t *testing.T) { err := os.WriteFile(file, []byte(invalidToken), 0700) must.NoError(t, err) - srv, _, url := testServer(t, true, config) + srv, _, url := testServer(t, false, config) defer srv.Shutdown() must.Nil(t, srv.RootToken) diff --git a/command/agent/agent_test.go b/command/agent/agent_test.go index 4b77df344..cbd7a076d 100644 --- a/command/agent/agent_test.go +++ b/command/agent/agent_test.go @@ -1120,7 +1120,7 @@ func TestServer_Reload_TLS_Shared_Keyloader(t *testing.T) { TLSConfig: &config.TLSConfig{ EnableHTTP: true, EnableRPC: true, - VerifyServerHostname: true, + VerifyServerHostname: false, CAFile: foocafile, CertFile: fooclientcert, KeyFile: fooclientkey, diff --git a/nomad/acl.go b/nomad/acl.go index 78cfc052c..1b77dc565 100644 --- a/nomad/acl.go +++ b/nomad/acl.go @@ -16,6 +16,10 @@ func (s *Server) AuthenticateServerOnly(ctx *RPCContext, args structs.RequestWit return s.auth.AuthenticateServerOnly(ctx, args) } +func (s *Server) AuthenticateNodeIdentityGenerator(ctx *RPCContext, args structs.RequestWithIdentity) error { + return s.auth.AuthenticateNodeIdentityGenerator(ctx, args) +} + func (s *Server) AuthenticateClientOnly(ctx *RPCContext, args structs.RequestWithIdentity) (*acl.ACL, error) { return s.auth.AuthenticateClientOnly(ctx, args) } diff --git a/nomad/auth/auth.go b/nomad/auth/auth.go index 1e412961c..afbd5ddfb 100644 --- a/nomad/auth/auth.go +++ b/nomad/auth/auth.go @@ -217,10 +217,11 @@ func (s *Authenticator) Authenticate(ctx RPCContext, args structs.RequestWithIde return nil } -// ResolveACL is an authentication wrapper which handles resolving ACL tokens, +// ResolveACL is an authentication wrapper that handles resolving ACL tokens, // Workload Identities, or client secrets into acl.ACL objects. Exclusively // server-to-server or client-to-server requests should be using -// AuthenticateServerOnly or AuthenticateClientOnly and never use this method. +// AuthenticateServerOnly or AuthenticateClientOnly unless they use the +// AuthenticateNodeIdentityGenerator function. 
func (s *Authenticator) ResolveACL(args structs.RequestWithIdentity) (*acl.ACL, error) { identity := args.GetIdentity() if identity == nil { diff --git a/nomad/client_agent_endpoint_test.go b/nomad/client_agent_endpoint_test.go index 3dcaa9ef7..44e752010 100644 --- a/nomad/client_agent_endpoint_test.go +++ b/nomad/client_agent_endpoint_test.go @@ -854,7 +854,10 @@ func TestAgentHost_Server(t *testing.T) { } c, cleanupC := client.TestClient(t, func(c *config.Config) { - c.Servers = []string{s2.GetConfig().RPCAddr.String()} + c.Servers = []string{ + s1.GetConfig().RPCAddr.String(), + s2.GetConfig().RPCAddr.String(), + } c.EnableDebug = true }) defer cleanupC() diff --git a/nomad/client_alloc_endpoint_test.go b/nomad/client_alloc_endpoint_test.go index 48a5185bf..20a05679d 100644 --- a/nomad/client_alloc_endpoint_test.go +++ b/nomad/client_alloc_endpoint_test.go @@ -38,6 +38,7 @@ func TestClientAllocations_GarbageCollectAll_Local(t *testing.T) { defer cleanupS() codec := rpcClient(t, s) testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, s.Region()) c, cleanupC := client.TestClient(t, func(c *config.Config) { c.Servers = []string{s.config.RPCAddr.String()} diff --git a/nomad/client_csi_endpoint_test.go b/nomad/client_csi_endpoint_test.go index d2d127584..241379f89 100644 --- a/nomad/client_csi_endpoint_test.go +++ b/nomad/client_csi_endpoint_test.go @@ -474,6 +474,7 @@ func setupLocal(t *testing.T) rpc.ClientCodec { t.Cleanup(cleanupS1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) codec := rpcClient(t, s1) mockCSI := newMockClientCSI() diff --git a/nomad/client_stats_endpoint_test.go b/nomad/client_stats_endpoint_test.go index 55c439da5..6b8d01d36 100644 --- a/nomad/client_stats_endpoint_test.go +++ b/nomad/client_stats_endpoint_test.go @@ -29,6 +29,7 @@ func TestClientStats_Stats_Local(t *testing.T) { defer cleanupS() codec := rpcClient(t, s) testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, s.Region()) c, cleanupC := client.TestClient(t, func(c *config.Config) { c.Servers = []string{s.config.RPCAddr.String()} diff --git a/nomad/csi_endpoint_test.go b/nomad/csi_endpoint_test.go index 2d71a0f46..744dd38c2 100644 --- a/nomad/csi_endpoint_test.go +++ b/nomad/csi_endpoint_test.go @@ -1136,12 +1136,13 @@ func TestCSIVolumeEndpoint_List_PaginationFiltering(t *testing.T) { func TestCSIVolumeEndpoint_Create(t *testing.T) { ci.Parallel(t) var err error - srv, rootToken, shutdown := TestACLServer(t, func(c *Config) { + srv, _, shutdown := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer shutdown() testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) fake := newMockClientCSI() fake.NextValidateError = nil @@ -1158,6 +1159,7 @@ func TestCSIVolumeEndpoint_Create(t *testing.T) { client, cleanup := client.TestClientWithRPCs(t, func(c *cconfig.Config) { c.Servers = []string{srv.config.RPCAddr.String()} + c.TLSConfig = srv.config.TLSConfig }, map[string]interface{}{"CSI": fake}, ) @@ -1169,8 +1171,11 @@ func TestCSIVolumeEndpoint_Create(t *testing.T) { }).Node req0 := &structs.NodeRegisterRequest{ - Node: node, - WriteRequest: structs.WriteRequest{Region: "global", AuthToken: rootToken.SecretID}, + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: node.SecretID, + }, } var resp0 structs.NodeUpdateResponse err = client.RPC("Node.Register", req0, &resp0) diff --git a/nomad/drainer_int_test.go b/nomad/drainer_int_test.go 
index 02f4e3142..b4c7d4507 100644 --- a/nomad/drainer_int_test.go +++ b/nomad/drainer_int_test.go @@ -149,6 +149,7 @@ func TestDrainer_Simple_ServiceOnly(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node @@ -220,6 +221,7 @@ func TestDrainer_Simple_ServiceOnly_Deadline(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node @@ -277,6 +279,7 @@ func TestDrainer_DrainEmptyNode(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create an empty node @@ -312,6 +315,7 @@ func TestDrainer_AllTypes_Deadline(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node @@ -420,6 +424,7 @@ func TestDrainer_AllTypes_NoDeadline(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create two nodes, registering the second later @@ -551,6 +556,7 @@ func TestDrainer_AllTypes_Deadline_GarbageCollectedNode(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node @@ -668,6 +674,7 @@ func TestDrainer_MultipleNSes_ServiceOnly(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node @@ -762,6 +769,7 @@ func TestDrainer_Batch_TransitionToForce(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node diff --git a/nomad/encrypter.go b/nomad/encrypter.go index ab580031b..fa1330c5b 100644 --- a/nomad/encrypter.go +++ b/nomad/encrypter.go @@ -303,11 +303,12 @@ func (e *Encrypter) Decrypt(ciphertext []byte, keyID string) ([]byte, error) { // header name. const keyIDHeader = "kid" -// SignClaims signs the identity claim for the task and returns an encoded JWT -// (including both the claim and its signature) and the key ID of the key used -// to sign it, or an error. +// SignClaims signs the identity claim and returns an encoded JWT (including +// both the claim and its signature) and the key ID of the key used to sign it, +// or an error. // -// SignClaims adds the Issuer claim prior to signing. +// SignClaims adds the Issuer claim prior to signing if it is unset by the +// caller. 
func (e *Encrypter) SignClaims(claims *structs.IdentityClaims) (string, string, error) { if claims == nil { @@ -324,7 +325,7 @@ func (e *Encrypter) SignClaims(claims *structs.IdentityClaims) (string, string, claims.Issuer = e.issuer } - opts := (&jose.SignerOptions{}).WithHeader("kid", cs.rootKey.Meta.KeyID).WithType("JWT") + opts := (&jose.SignerOptions{}).WithHeader(keyIDHeader, cs.rootKey.Meta.KeyID).WithType("JWT") var sig jose.Signer if cs.rsaPrivateKey != nil { diff --git a/nomad/eval_broker_test.go b/nomad/eval_broker_test.go index d44a189b9..df117310d 100644 --- a/nomad/eval_broker_test.go +++ b/nomad/eval_broker_test.go @@ -1535,6 +1535,7 @@ func TestEvalBroker_IntegrationTest(t *testing.T) { defer cleanupS1() testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) codec := rpcClient(t, srv) store := srv.fsm.State() diff --git a/nomad/host_volume_endpoint_test.go b/nomad/host_volume_endpoint_test.go index 6a46c83b9..7433ff5aa 100644 --- a/nomad/host_volume_endpoint_test.go +++ b/nomad/host_volume_endpoint_test.go @@ -38,6 +38,7 @@ func TestHostVolumeEndpoint_CreateRegisterGetDelete(t *testing.T) { }) t.Cleanup(cleanupSrv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.config.Region) store := srv.fsm.State() c1, node1 := newMockHostVolumeClient(t, srv, "prod") @@ -434,6 +435,7 @@ func TestHostVolumeEndpoint_List(t *testing.T) { }) t.Cleanup(cleanupSrv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.config.Region) store := srv.fsm.State() codec := rpcClient(t, srv) @@ -809,6 +811,7 @@ func TestHostVolumeEndpoint_concurrency(t *testing.T) { srv, cleanup := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) t.Cleanup(cleanup) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.config.Region) c, node := newMockHostVolumeClient(t, srv, "default") diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index e0743379e..c04b1db00 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -4,12 +4,14 @@ package nomad import ( + "errors" "fmt" "net/http" "reflect" "sync" "time" + "github.com/go-jose/go-jose/v3/jwt" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-memdb" metrics "github.com/hashicorp/go-metrics/compat" @@ -91,9 +93,10 @@ func NewNodeEndpoint(srv *Server, ctx *RPCContext) *Node { // Register is used to upsert a client that is available for scheduling func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUpdateResponse) error { - // note that we trust-on-first use and the identity will be anonymous for - // that initial request; we lean on mTLS for handling that safely - authErr := n.srv.Authenticate(n.ctx, args) + + // The node register RPC is responsible for generating node identities, so + // we use the custom authentication method shared with UpdateStatus. + authErr := n.srv.AuthenticateNodeIdentityGenerator(n.ctx, args) isForwarded := args.IsForwarded() if done, err := n.srv.forward("Node.Register", args, args, reply); done { @@ -108,7 +111,15 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp return err } n.srv.MeasureRPCRate("node", structs.RateMetricWrite, args) - if authErr != nil { + + // The authentication error can be because the identity is expired. If we + // stopped the handler execution here, the node would never be able to + // register after being disconnected. + // + // Further within the RPC we check the supplied SecretID against the stored + // value in state. 
This acts as a secondary check and can be seen as a + // refresh token, in the event the identity is expired. + if authErr != nil && !errors.Is(authErr, jwt.ErrExpired) { return structs.ErrPermissionDenied } @@ -161,8 +172,13 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp args.Node.NodePool = structs.NodePoolDefault } + // The current time is used at a number of places in the registration + // workflow. Generating it once avoids multiple calls to time.Now() and also + // means the same time is used across all checks and sets. + timeNow := time.Now() + // Set the timestamp when the node is registered - args.Node.StatusUpdatedAt = time.Now().Unix() + args.Node.StatusUpdatedAt = timeNow.Unix() // Compute the node class if err := args.Node.ComputeClass(); err != nil { @@ -214,6 +230,40 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp if n.srv.Region() == n.srv.config.AuthoritativeRegion { args.CreateNodePool = true } + + // Track the TTL that will be used for the node identity. + var identityTTL time.Duration + + // The identity TTL is determined by the node pool the node is registered + // in. In the event the node registration is triggering creation of a new + // node pool, it will be created with the default TTL, so we use this for + // the identity. + nodePool, err := snap.NodePoolByName(ws, args.Node.NodePool) + if err != nil { + return fmt.Errorf("failed to query node pool: %v", err) + } + if nodePool == nil { + identityTTL = structs.DefaultNodePoolNodeIdentityTTL + } else { + identityTTL = nodePool.NodeIdentityTTL + } + + // Check if we need to generate a node identity. This must happen before we + // send the Raft message, as the signing key ID is set on the node if we + // generate one. + if args.ShouldGenerateNodeIdentity(authErr, timeNow.UTC(), identityTTL) { + + claims := structs.GenerateNodeIdentityClaims(args.Node, n.srv.Region(), identityTTL) + + signedJWT, signingKeyID, err := n.srv.encrypter.SignClaims(claims) + if err != nil { + return fmt.Errorf("failed to sign node identity claims: %v", err) + } + + reply.SignedIdentity = &signedJWT + args.Node.IdentitySigningKeyID = signingKeyID + } + _, index, err := n.srv.raftApply(structs.NodeRegisterRequestType, args) if err != nil { n.logger.Error("register failed", "error", err) @@ -509,9 +559,13 @@ func (n *Node) deregister(args *structs.NodeBatchDeregisterRequest, // │ │ // └──── ready ─────┘ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *structs.NodeUpdateResponse) error { - // UpdateStatus receives requests from client and servers that mark failed - // heartbeats, so we can't use AuthenticateClientOnly - authErr := n.srv.Authenticate(n.ctx, args) + + // The node update status RPC is responsible for generating node identities, + // so we use the custom authentication method shared with Register. + // + // Note; UpdateStatus receives requests from clients and servers that mark + // failed heartbeats. + authErr := n.srv.AuthenticateNodeIdentityGenerator(n.ctx, args) isForwarded := args.IsForwarded() if done, err := n.srv.forward("Node.UpdateStatus", args, args, reply); done { @@ -573,14 +627,62 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct // to track SecretIDs. // Update the timestamp of when the node status was updated - args.UpdatedAt = time.Now().Unix() + timeNow := time.Now() + args.UpdatedAt = timeNow.Unix() + + // Track the TTL that will be used for the node identity. 
+ var identityTTL time.Duration + + // The identity TTL is determined by the node pool the node is registered + // in. The pool should already exist, as the node is already registered. If + // it does not, we use the default TTL as we have no better value to use. + // + // Once the node pool is created, the node's identity will have the TTL set + // by the node pool on its renewal. + nodePool, err := snap.NodePoolByName(ws, node.NodePool) + if err != nil { + return fmt.Errorf("failed to query node pool: %v", err) + } + if nodePool == nil { + identityTTL = structs.DefaultNodePoolNodeIdentityTTL + } else { + identityTTL = nodePool.NodeIdentityTTL + } + + // Check and generate a node identity if needed. + if args.ShouldGenerateNodeIdentity(timeNow.UTC(), identityTTL) { + + claims := structs.GenerateNodeIdentityClaims(node, n.srv.Region(), identityTTL) + + // Sign the claims with the encrypter and conditionally handle the + // error. The IdentitySigningErrorIsTerminal method has a description of + // why we do this. + signedJWT, signingKeyID, err := n.srv.encrypter.SignClaims(claims) + if err != nil { + if args.IdentitySigningErrorIsTerminal(timeNow) { + return fmt.Errorf("failed to sign node identity claims: %v", err) + } else { + n.logger.Warn( + "failed to sign node identity claims, will retry on next heartbeat", + "error", err, "node_id", node.ID) + } + } + + reply.SignedIdentity = &signedJWT + args.IdentitySigningKeyID = signingKeyID + } else { + // Ensure the IdentitySigningKeyID is cleared if we are not generating a + // new identity. This is important so that we do not cause Raft + // updates unless we need to. + args.IdentitySigningKeyID = "" + } // Compute next status. switch node.Status { case structs.NodeStatusInit: if args.Status == structs.NodeStatusReady { - // Keep node in the initializing status if it has allocations but - // they are not updated. + // Keep the node in the initializing status if it has allocations, + // but they are not updated. allocs, err := snap.AllocsByNodeTerminal(ws, args.NodeID, false) if err != nil { return fmt.Errorf("failed to query node allocs: %v", err) @@ -592,13 +694,9 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct args.Status = structs.NodeStatusInit } - // Keep node in the initialing status if it's in a node pool that - // doesn't exist. - pool, err := snap.NodePoolByName(ws, node.NodePool) - if err != nil { - return fmt.Errorf("failed to query node pool: %v", err) - } - if pool == nil { + // Keep the node in the initializing status if it's in a node pool + // that doesn't exist. + if nodePool == nil { n.logger.Debug(fmt.Sprintf("marking node as %s due to missing node pool", structs.NodeStatusInit)) args.Status = structs.NodeStatusInit if !node.HasEvent(NodeWaitingForNodePool) { @@ -617,7 +715,19 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct // Commit this update via Raft var index uint64 - if node.Status != args.Status || args.NodeEvent != nil { + + // Only perform a Raft apply if we really have to, so we avoid unnecessary + // cluster traffic and CPU load. + // + // We must update state if: + // - The node informed us of a new status. + // - The node informed us of a new event. + // - We have generated an identity which has been signed with a different + // key ID compared to the last identity generated for the node.
+ if node.Status != args.Status || + args.NodeEvent != nil || + node.IdentitySigningKeyID != args.IdentitySigningKeyID && args.IdentitySigningKeyID != "" { + // Attach an event if we are updating the node status to ready when it // is down via a heartbeat if node.Status == structs.NodeStatusDown && args.NodeEvent == nil { diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index 0fae446de..92ad8b6a6 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v3/jwt" memdb "github.com/hashicorp/go-memdb" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2" "github.com/hashicorp/nomad/acl" @@ -37,6 +38,7 @@ func TestClientEndpoint_Register(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Check that we have no client connections require.Empty(s1.connectedNodes()) @@ -89,6 +91,267 @@ func TestClientEndpoint_Register(t *testing.T) { }) } +func TestNode_Register_Identity(t *testing.T) { + ci.Parallel(t) + + // This helper function verifies the identity token generated by the server + // in the Node.Register RPC call. + verifyIdentityFn := func( + t *testing.T, + testServer *Server, + token string, + node *structs.Node, + ttl time.Duration, + ) { + t.Helper() + + identityClaims, err := testServer.encrypter.VerifyClaim(token) + must.NoError(t, err) + + must.Eq(t, ttl, identityClaims.Expiry.Time().Sub(identityClaims.NotBefore.Time())) + must.True(t, identityClaims.IsNode()) + must.Eq(t, identityClaims.NodeIdentityClaims, &structs.NodeIdentityClaims{ + NodeID: node.ID, + NodeDatacenter: node.Datacenter, + NodeClass: node.NodeClass, + NodePool: node.NodePool, + }) + + // Identify the active encrypter key ID, which would have been used to + // sign the identity token. + _, keyID, err := testServer.encrypter.GetActiveKey() + must.NoError(t, err) + + // Perform a lookup of the node in state. The IdentitySigningKeyID field + // should be populated with the active encrypter key ID. + stateNodeResp, err := testServer.State().NodeByID(nil, node.ID) + must.NoError(t, err) + must.NotNil(t, stateNodeResp) + must.Eq(t, keyID, stateNodeResp.IdentitySigningKeyID) + } + + testCases := []struct { + name string + testFn func(t *testing.T, srv *Server, codec rpc.ClientCodec) + }{ + { + // Test the initial registration flow, where a node will not include + // an authentication token in the request. + // + // A later registration will not generate a new identity, as the + // included identity is still valid. + name: "identity generation and node reregister", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + + req := structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, req.Node, structs.DefaultNodePoolNodeIdentityTTL) + + req.WriteRequest.AuthToken = *resp.SignedIdentity + var resp2 structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp2)) + must.Nil(t, resp2.SignedIdentity) + }, + }, + { + // A node can register with a node pool that does not exist, and the + // server will create it on FSM write. 
In this case, the server + // should generate an identity with the default node pool identity + // TTL. + name: "create on register node pool", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + node.NodePool = "custom-pool" + + req := structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, req.Node, structs.DefaultNodePoolNodeIdentityTTL) + }, + }, + { + // A node can register with a node pool that exists, and the server + // will generate an identity with the node pool's identity TTL. + name: "non-default identity ttl", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + nodePool := mock.NodePool() + nodePool.NodeIdentityTTL = 168 * time.Hour + must.NoError(t, srv.State().UpsertNodePools(structs.MsgTypeTestSetup, 1000, []*structs.NodePool{nodePool})) + + node := mock.Node() + node.NodePool = nodePool.Name + + req := structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, req.Node, nodePool.NodeIdentityTTL) + }, + }, + { + // Ensure a new identity is generated if the identity within the + // request is close to expiration. + name: "identity close to expiration", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + timeNow := time.Now().UTC().Add(-20 * time.Hour) + timeJWTNow := jwt.NewNumericDate(timeNow) + + node := mock.Node() + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, 1000, node)) + + claims := structs.GenerateNodeIdentityClaims( + node, + srv.Region(), + structs.DefaultNodePoolNodeIdentityTTL, + ) + claims.IssuedAt = timeJWTNow + claims.NotBefore = timeJWTNow + claims.Expiry = jwt.NewNumericDate(time.Now().UTC().Add(4 * time.Hour)) + + signedToken, _, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + + req := structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: signedToken, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, req.Node, structs.DefaultNodePoolNodeIdentityTTL) + }, + }, + { + // A node could disconnect from the cluster for long enough for the + // identity to expire. When it reconnects and performs its + // reregistration, the server should generate a new identity. 
+ name: "identity expired", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, 1000, node)) + + claims := structs.GenerateNodeIdentityClaims( + node, + srv.Region(), + structs.DefaultNodePoolNodeIdentityTTL, + ) + claims.Expiry = jwt.NewNumericDate(time.Now().UTC().Add(-1 * time.Hour)) + + signedToken, _, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + + req := structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: signedToken, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, req.Node, structs.DefaultNodePoolNodeIdentityTTL) + }, + }, + { + // Ensure that if the node's SecretID is tampered with, the server + // rejects any attempt to register. This test is to gate against a + // potential regressions in how we handle identities within this + // RPC. + name: "identity expired secret ID tampered", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, 1000, node.Copy())) + + claims := structs.GenerateNodeIdentityClaims( + node, + srv.Region(), + structs.DefaultNodePoolNodeIdentityTTL, + ) + claims.Expiry = jwt.NewNumericDate(time.Now().UTC().Add(-1 * time.Hour)) + + signedToken, _, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + + node.SecretID = uuid.Generate() + + req := structs.NodeRegisterRequest{ + Node: node.Copy(), + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: signedToken, + }, + } + + var resp structs.NodeUpdateResponse + must.ErrorContains( + t, + msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp), + "node secret ID does not match", + ) + }, + }, + } + + // ACL enabled server test run. + testACLServer, _, aclServerCleanup := TestACLServer(t, func(c *Config) {}) + defer aclServerCleanup() + testACLCodec := rpcClient(t, testACLServer) + + testutil.WaitForLeader(t, testACLServer.RPC) + testutil.WaitForKeyring(t, testACLServer.RPC, testACLServer.config.Region) + + // ACL disabled server test run. + testServer, serverCleanup := TestServer(t, func(c *Config) {}) + defer serverCleanup() + testCodec := rpcClient(t, testServer) + + testutil.WaitForLeader(t, testServer.RPC) + testutil.WaitForKeyring(t, testServer.RPC, testServer.config.Region) + + for _, tc := range testCases { + t.Run("ACL_enabled_"+tc.name, func(t *testing.T) { + tc.testFn(t, testACLServer, testACLCodec) + }) + t.Run("ACL_disabled_"+tc.name, func(t *testing.T) { + tc.testFn(t, testServer, testCodec) + }) + } +} + // This test asserts that we only track node connections if they are not from // forwarded RPCs. This is essential otherwise we will think a Yamux session to // a Nomad server is actually the session to the node. 
@@ -106,8 +369,8 @@ func TestClientEndpoint_Register_NodeConn_Forwarded(t *testing.T) { }) defer cleanupS2() TestJoin(t, s1, s2) - testutil.WaitForLeader(t, s1.RPC) - testutil.WaitForLeader(t, s2.RPC) + testutil.WaitForLeaders(t, s1.RPC, s2.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Determine the non-leader server var leader, nonLeader *Server @@ -190,6 +453,7 @@ func TestClientEndpoint_Register_SecretMismatch(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -219,6 +483,7 @@ func TestClientEndpoint_Register_NodePool(t *testing.T) { defer cleanupS() codec := rpcClient(t, s) testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, s.config.Region) testCases := []struct { name string @@ -328,6 +593,7 @@ func TestClientEndpoint_Register_NodePool_Multiregion(t *testing.T) { defer cleanupS1() codec1 := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) s2, _, cleanupS2 := TestACLServer(t, func(c *Config) { c.Region = "region-2" @@ -340,6 +606,7 @@ func TestClientEndpoint_Register_NodePool_Multiregion(t *testing.T) { defer cleanupS2() codec2 := rpcClient(t, s2) testutil.WaitForLeader(t, s2.RPC) + testutil.WaitForKeyring(t, s2.RPC, s2.config.Region) // Verify that registering a node with a new node pool in the authoritative // region creates the node pool. @@ -504,6 +771,7 @@ func TestClientEndpoint_DeregisterOne(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -617,6 +885,7 @@ func TestClientEndpoint_UpdateStatus(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Check that we have no client connections require.Empty(s1.connectedNodes()) @@ -721,6 +990,7 @@ func TestClientEndpoint_UpdateStatus_Reconnect(t *testing.T) { codec := rpcClient(t, s) defer cleanupS() testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, s.config.Region) // Register node. node := mock.Node() @@ -914,6 +1184,7 @@ func TestClientEndpoint_UpdateStatus_HeartbeatRecovery(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Check that we have no client connections require.Empty(s1.connectedNodes()) @@ -964,6 +1235,7 @@ func TestClientEndpoint_Register_GetEvals(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Register a system job. job := mock.SystemJob() @@ -1055,6 +1327,7 @@ func TestClientEndpoint_UpdateStatus_GetEvals(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Register a system job. 
job := mock.SystemJob() @@ -1163,6 +1436,7 @@ func TestClientEndpoint_UpdateStatus_HeartbeatOnly(t *testing.T) { codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -1224,6 +1498,7 @@ func TestClientEndpoint_UpdateStatus_HeartbeatOnly_Advertise(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -1255,6 +1530,7 @@ func TestNode_UpdateStatus_ServiceRegistrations(t *testing.T) { testServer, serverCleanup := TestServer(t, nil) defer serverCleanup() testutil.WaitForLeader(t, testServer.RPC) + testutil.WaitForKeyring(t, testServer.RPC, testServer.config.Region) // Create a node and upsert this into state. node := mock.Node() @@ -1304,6 +1580,256 @@ func TestNode_UpdateStatus_ServiceRegistrations(t *testing.T) { must.NoError(t, nodeEndpoint.UpdateStatus(&args, &reply)) } +func TestNode_UpdateStatus_Identity(t *testing.T) { + ci.Parallel(t) + + // This helper function verifies the identity token generated by the server + // in the Node.UpdateStatus RPC call. + verifyIdentityFn := func( + t *testing.T, + testServer *Server, + token string, + node *structs.Node, + ttl time.Duration, + ) { + t.Helper() + + identityClaims, err := testServer.encrypter.VerifyClaim(token) + must.NoError(t, err) + + must.Eq(t, ttl, identityClaims.Expiry.Time().Sub(identityClaims.NotBefore.Time())) + must.True(t, identityClaims.IsNode()) + must.Eq(t, identityClaims.NodeIdentityClaims, &structs.NodeIdentityClaims{ + NodeID: node.ID, + NodeDatacenter: node.Datacenter, + NodeClass: node.NodeClass, + NodePool: node.NodePool, + }) + + // Identify the active encrypter key ID, which would have been used to + // sign the identity token. + _, keyID, err := testServer.encrypter.GetActiveKey() + must.NoError(t, err) + + // Perform a lookup of the node in state. The IdentitySigningKeyID field + // should be populated with the active encrypter key ID. + stateNodeResp, err := testServer.State().NodeByID(nil, node.ID) + must.NoError(t, err) + must.NotNil(t, stateNodeResp) + must.Eq(t, keyID, stateNodeResp.IdentitySigningKeyID) + } + + testCases := []struct { + name string + testFn func(t *testing.T, srv *Server, codec rpc.ClientCodec) + }{ + { + // Ensure that the Node.UpdateStatus RPC generates a new identity + // for a client authenticating using its secret ID. + name: "node secret ID authenticated default pool", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + must.Eq(t, "", node.IdentitySigningKeyID) + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, srv.raft.LastIndex(), node)) + + req := &structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusReady, + WriteRequest: structs.WriteRequest{ + Region: srv.Region(), + AuthToken: node.SecretID, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, node, structs.DefaultNodePoolNodeIdentityTTL) + }, + }, + { + // Ensure that the Node.UpdateStatus RPC generates a new identity + // for a client authenticating using its secret ID which belongs to + // a non-default node pool. 
+ name: "node secret ID authenticated non-default pool", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + nodePool := mock.NodePool() + nodePool.NodeIdentityTTL = 168 * time.Hour + must.NoError(t, srv.State().UpsertNodePools( + structs.MsgTypeTestSetup, + srv.raft.LastIndex(), + []*structs.NodePool{nodePool}, + )) + + node := mock.Node() + node.NodePool = nodePool.Name + + must.Eq(t, "", node.IdentitySigningKeyID) + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, srv.raft.LastIndex(), node)) + + req := &structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusReady, + WriteRequest: structs.WriteRequest{ + Region: srv.Region(), + AuthToken: node.SecretID, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, node, nodePool.NodeIdentityTTL) + }, + }, + { + // Nomad servers often call the Node.UpdateStatus RPC to notify that + // a node has missed its heartbeat. In this case, we should write + // the update to state, but not generate an identity token. + name: "leader acl token authenticated", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, srv.raft.LastIndex(), node)) + + req := &structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusDown, + WriteRequest: structs.WriteRequest{ + Region: srv.Region(), + AuthToken: srv.getLeaderAcl(), + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", req, &resp)) + must.Nil(t, resp.SignedIdentity) + + stateNode, err := srv.State().NodeByID(nil, node.ID) + must.NoError(t, err) + must.NotNil(t, stateNode) + must.Eq(t, structs.NodeStatusDown, stateNode.Status) + must.Greater(t, stateNode.CreateIndex, stateNode.ModifyIndex) + }, + }, + { + // Ensure a new identity is generated if the identity within the + // request is close to expiration. + name: "identity close to expiration", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + timeNow := time.Now().UTC().Add(-20 * time.Hour) + timeJWTNow := jwt.NewNumericDate(timeNow) + + node := mock.Node() + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, srv.raft.LastIndex(), node)) + + claims := structs.GenerateNodeIdentityClaims( + node, + srv.Region(), + structs.DefaultNodePoolNodeIdentityTTL, + ) + claims.IssuedAt = timeJWTNow + claims.NotBefore = timeJWTNow + claims.Expiry = jwt.NewNumericDate(time.Now().UTC().Add(4 * time.Hour)) + + signedToken, _, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + + req := structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusReady, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: signedToken, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, node, structs.DefaultNodePoolNodeIdentityTTL) + }, + }, + { + // Ensure a new identity is generated if the identity within the + // request is close to expiration and the new identity has a TTL set + // by its custom node pool configuration. 
+ name: "identity close to expiration custom pool", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + nodePool := mock.NodePool() + nodePool.NodeIdentityTTL = 168 * time.Hour + must.NoError(t, srv.State().UpsertNodePools( + structs.MsgTypeTestSetup, + srv.raft.LastIndex(), + []*structs.NodePool{nodePool}, + )) + + timeNow := time.Now().UTC().Add(-135 * time.Hour) + timeJWTNow := jwt.NewNumericDate(timeNow) + + node := mock.Node() + node.NodePool = nodePool.Name + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, srv.raft.LastIndex(), node)) + + claims := structs.GenerateNodeIdentityClaims( + node, + srv.Region(), + structs.DefaultNodePoolNodeIdentityTTL, + ) + claims.IssuedAt = timeJWTNow + claims.NotBefore = timeJWTNow + claims.Expiry = jwt.NewNumericDate(time.Now().UTC().Add(4 * time.Hour)) + + signedToken, _, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + + req := structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusReady, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: signedToken, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, node, nodePool.NodeIdentityTTL) + }, + }, + } + + // ACL enabled server test run. + testACLServer, _, aclServerCleanup := TestACLServer(t, func(c *Config) {}) + defer aclServerCleanup() + testACLCodec := rpcClient(t, testACLServer) + + testutil.WaitForLeader(t, testACLServer.RPC) + testutil.WaitForKeyring(t, testACLServer.RPC, testACLServer.config.Region) + + // ACL disabled server test run. + testServer, serverCleanup := TestServer(t, func(c *Config) {}) + defer serverCleanup() + testCodec := rpcClient(t, testServer) + + testutil.WaitForLeader(t, testServer.RPC) + testutil.WaitForKeyring(t, testServer.RPC, testServer.config.Region) + + for _, tc := range testCases { + t.Run("ACL_enabled_"+tc.name, func(t *testing.T) { + tc.testFn(t, testACLServer, testACLCodec) + }) + t.Run("ACL_disabled_"+tc.name, func(t *testing.T) { + tc.testFn(t, testServer, testCodec) + }) + } +} + // TestClientEndpoint_UpdateDrain asserts the ability to initiate drain // against a node and cancel that drain. 
It also asserts: // * an evaluation is created when the node becomes eligible @@ -1316,6 +1842,7 @@ func TestClientEndpoint_UpdateDrain(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Disable drainer to prevent drain from completing during test s1.nodeDrainer.SetEnabled(false, nil) @@ -1435,6 +1962,7 @@ func TestClientEndpoint_UpdatedDrainAndCompleted(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) state := s1.fsm.State() // Disable drainer for now @@ -1545,6 +2073,7 @@ func TestClientEndpoint_UpdatedDrainNoop(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) state := s1.fsm.State() // Create the register request @@ -1688,6 +2217,7 @@ func TestClientEndpoint_Drain_Down(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) require := require.New(t) // Register a node @@ -1820,6 +2350,7 @@ func TestClientEndpoint_UpdateEligibility(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -1933,6 +2464,7 @@ func TestClientEndpoint_GetNode(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -1966,10 +2498,14 @@ func TestClientEndpoint_GetNode(t *testing.T) { t.Fatalf("bad ComputedClass: %#v", resp2.Node) } + _, keyID, err := s1.encrypter.GetActiveKey() + must.NoError(t, err) + // Update the status updated at value node.StatusUpdatedAt = resp2.Node.StatusUpdatedAt node.SecretID = "" node.Events = resp2.Node.Events + node.IdentitySigningKeyID = keyID must.Eq(t, node, resp2.Node) // assert that the node register event was set correctly @@ -2167,6 +2703,7 @@ func TestClientEndpoint_GetAllocs(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2497,6 +3034,7 @@ func TestClientEndpoint_GetClientAllocs_Blocking(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2621,6 +3159,7 @@ func TestClientEndpoint_GetClientAllocs_Blocking_GC(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2699,6 +3238,7 @@ func TestClientEndpoint_GetClientAllocs_WithoutMigrateTokens(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2754,6 +3294,7 @@ func TestClientEndpoint_GetAllocs_Blocking(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2853,6 +3394,7 @@ func TestNode_UpdateAlloc(t 
*testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2933,6 +3475,7 @@ func TestNode_UpdateAlloc_NodeNotReady(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Register node. node := mock.Node() @@ -3109,6 +3652,7 @@ func TestClientEndpoint_BatchUpdate(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -3522,6 +4066,7 @@ func TestClientEndpoint_ListNodes(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -3594,6 +4139,7 @@ func TestClientEndpoint_ListNodes_Fields(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -3961,6 +4507,7 @@ func TestClientEndpoint_UpdateAlloc_Evals_ByTrigger(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() diff --git a/nomad/rpc_test.go b/nomad/rpc_test.go index 6c0446561..5cb00ec6c 100644 --- a/nomad/rpc_test.go +++ b/nomad/rpc_test.go @@ -254,6 +254,7 @@ func TestRPC_PlaintextRPCSucceedsWhenInUpgradeMode(t *testing.T) { s1, cleanupS1 := TestServer(t, func(c *Config) { c.DataDir = path.Join(dir, "node1") + c.Region = "regionFoo" c.TLSConfig = &config.TLSConfig{ EnableRPC: true, VerifyServerHostname: true, @@ -264,18 +265,19 @@ func TestRPC_PlaintextRPCSucceedsWhenInUpgradeMode(t *testing.T) { } }) defer cleanupS1() + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) - codec := rpcClient(t, s1) + tlsCodec := rpcClientWithTLS(t, s1, s1.config.TLSConfig) // Create the register request node := mock.Node() req := &structs.NodeRegisterRequest{ Node: node, - WriteRequest: structs.WriteRequest{Region: "global"}, + WriteRequest: structs.WriteRequest{Region: s1.Region()}, } var resp structs.GenericResponse - err := msgpackrpc.CallWithCodec(codec, "Node.Register", req, &resp) + err := msgpackrpc.CallWithCodec(tlsCodec, "Node.Register", req, &resp) assert.Nil(err) // Check that heartbeatTimers has the heartbeat ID diff --git a/nomad/server_test.go b/nomad/server_test.go index d7175af67..ea490c6e4 100644 --- a/nomad/server_test.go +++ b/nomad/server_test.go @@ -375,6 +375,7 @@ func TestServer_Reload_TLSConnections_TLSToPlaintext_OnlyRPC(t *testing.T) { } }) defer cleanupS1() + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) newTLSConfig := &config.TLSConfig{ EnableHTTP: true, diff --git a/nomad/structs/identity.go b/nomad/structs/identity.go index 41e43f99d..4c17a9b03 100644 --- a/nomad/structs/identity.go +++ b/nomad/structs/identity.go @@ -51,7 +51,17 @@ func (i *IdentityClaims) IsExpiring(now time.Time, ttl time.Duration) bool { // relative to the current time. threshold := now.Add(ttl / 3) - return i.Expiry.Time().Before(threshold) + return i.Expiry.Time().UTC().Before(threshold) +} + +// IsExpiringInThreshold checks if the identity JWT is expired or close to +// expiring. 
It uses a passed threshold to determine "close to expiring" which +// is not manipulated, unlike TTL in the IsExpiring method. +func (i *IdentityClaims) IsExpiringInThreshold(threshold time.Time) bool { + if i != nil && i.Expiry != nil { + return threshold.After(i.Expiry.Time()) + } + return false } // setExpiry sets the "expiry" or "exp" claim for the identity JWT. It is the diff --git a/nomad/structs/identity_test.go b/nomad/structs/identity_test.go index b35690f0c..8e09a7061 100644 --- a/nomad/structs/identity_test.go +++ b/nomad/structs/identity_test.go @@ -174,6 +174,57 @@ func TestIdentityClaims_IsExpiring(t *testing.T) { } } +func TestIdentityClaims_IsExpiringWithTTL(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + inputIdentityClaims *IdentityClaims + inputThreshold time.Time + expectedResult bool + }{ + { + name: "nil identity", + inputIdentityClaims: nil, + inputThreshold: time.Now(), + expectedResult: false, + }, + { + name: "no expiry", + inputIdentityClaims: &IdentityClaims{}, + inputThreshold: time.Now(), + expectedResult: false, + }, + { + name: "not close to expiring", + inputIdentityClaims: &IdentityClaims{ + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + }, + }, + inputThreshold: time.Now(), + expectedResult: false, + }, + { + name: "close to expiring", + inputIdentityClaims: &IdentityClaims{ + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now()), + }, + }, + inputThreshold: time.Now().Add(1 * time.Minute), + expectedResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := tc.inputIdentityClaims.IsExpiringInThreshold(tc.inputThreshold) + must.Eq(t, tc.expectedResult, actualOutput) + }) + } +} + func TestIdentityClaimsNg_setExpiry(t *testing.T) { ci.Parallel(t) diff --git a/nomad/structs/node.go b/nomad/structs/node.go index bcc0fec39..a5a308e3f 100644 --- a/nomad/structs/node.go +++ b/nomad/structs/node.go @@ -537,3 +537,180 @@ func GenerateNodeIdentityClaims(node *Node, region string, ttl time.Duration) *I return claims } + +// NodeRegisterRequest is used by the Node.Register RPC endpoint to register a +// node as being a schedulable entity. +type NodeRegisterRequest struct { + Node *Node + NodeEvent *NodeEvent + + // CreateNodePool is used to indicate that the node's node pool should be + // created along with the node registration if it doesn't exist. + CreateNodePool bool + + WriteRequest +} + +// ShouldGenerateNodeIdentity compliments the functionality within +// AuthenticateNodeIdentityGenerator to determine whether a new node identity +// should be generated within the RPC handler. +func (n *NodeRegisterRequest) ShouldGenerateNodeIdentity( + authErr error, + now time.Time, + ttl time.Duration, +) bool { + + // In the event the error is because the node identity is expired, we should + // generate a new identity. Without this, a disconnected node would never be + // able to re-register. Any other error is not a reason to generate a new + // identity. + if authErr != nil { + return errors.Is(authErr, jwt.ErrExpired) + } + + // If an ACL token or client ID is set, a node is attempting to register for + // the first time, or is re-registering using its secret ID. In either case, + // we should generate a new identity. + if n.identity.ACLToken != nil || n.identity.ClientID != "" { + return true + } + + // If we have reached this point, we can assume that the request is using a + // node identity. 
+ claims := n.GetIdentity().GetClaims() + + // It is possible that the node has been restarted and had its configuration + // updated. In this case, we should generate a new identity for the node, so + // it reflects its new claims. + if n.Node.NodePool != claims.NodeIdentityClaims.NodePool || + n.Node.NodeClass != claims.NodeIdentityClaims.NodeClass || + n.Node.Datacenter != claims.NodeIdentityClaims.NodeDatacenter { + return true + } + + // The final check is to see if the node identity is expiring. + return claims.IsExpiring(now, ttl) +} + +// NodeUpdateStatusRequest is used for Node.UpdateStatus endpoint +// to update the status of a node. +type NodeUpdateStatusRequest struct { + NodeID string + Status string + + // IdentitySigningKeyID is the ID of the root key used to sign the node's + // identity. This is not provided by the client, but is set by the server, + // so that the value can be propagated through Raft. + IdentitySigningKeyID string + + // ForceIdentityRenewal is used to force the Nomad server to generate a new + // identity for the node. + ForceIdentityRenewal bool + + NodeEvent *NodeEvent + UpdatedAt int64 + WriteRequest +} + +// ShouldGenerateNodeIdentity determines whether the handler should generate a +// new node identity based on the caller identity information. +func (n *NodeUpdateStatusRequest) ShouldGenerateNodeIdentity( + now time.Time, + ttl time.Duration, +) bool { + + identity := n.GetIdentity() + + // If the client ID is set, we should generate a new identity as the node + // has authenticated using its secret ID. + if identity.ClientID != "" { + return true + } + + // Confirm we have a node identity and then check for forced renewal or + // expiration. + if identity.GetClaims().IsNode() { + if n.ForceIdentityRenewal { + return true + } + return n.GetIdentity().GetClaims().IsExpiring(now, ttl) + } + + // No other conditions should generate a new identity. In the case of the + // update status endpoint, this will likely be a Nomad server propagating + // that a node has missed its heartbeat. + return false +} + +// IdentitySigningErrorIsTerminal determines if the RPC handler should return an +// error because it failed to sign a newly generated node identity. +// +// This is because a client might be connected to a follower at the point the +// root keyring is rotated. If the client heartbeats right at that moment and +// before the follower decrypts the key (e.g., network latency to external KMS), +// we will mark the node as down. This is despite identity being valid and the +// likelihood it will get a new identity signed on the next heartbeat. +func (n *NodeUpdateStatusRequest) IdentitySigningErrorIsTerminal(now time.Time) bool { + + identity := n.GetIdentity() + + // If the client has authenticated using a secret ID, we can continue to let + // it do that, until we successfully generate a new identity. + if identity.ClientID != "" { + return false + } + + // If the identity is a node identity, we can check if it is expiring. This + // check is used to determine if the RPC handler should return an error, so + // we use a short threshold of 10 minutes. This is to ensure we don't return + // errors unless we absolutely have to. + // + // A threshold of 10 minutes more than covers another heartbeat on the + // largest Nomad clusters, which can reach ~5 minutes. 
+ if identity.GetClaims().IsNode() { + return n.GetIdentity().GetClaims().IsExpiringInThreshold(now.Add(10 * time.Minute)) + } + + // No other condition should result in the RPC handler returning an error + // because we failed to sign the node identity. No caller should be able to + // reach this point, as identity generation should be gated by + // ShouldGenerateNodeIdentity. + return false +} + +// NodeUpdateResponse is used to respond to a node update. The object is a +// shared response used by the Node.Register, Node.Deregister, +// Node.BatchDeregister, Node.UpdateStatus, and Node.Evaluate RPCs. +type NodeUpdateResponse struct { + HeartbeatTTL time.Duration + EvalIDs []string + EvalCreateIndex uint64 + NodeModifyIndex uint64 + + // Features informs clients what enterprise features are allowed + Features uint64 + + // LeaderRPCAddr is the RPC address of the current Raft Leader. If + // empty, the current Nomad Server is in the minority of a partition. + LeaderRPCAddr string + + // NumNodes is the number of Nomad nodes attached to this quorum of + // Nomad Servers at the time of the response. This value can + // fluctuate based on the health of the cluster between heartbeats. + NumNodes int32 + + // Servers is the full list of known Nomad servers in the local + // region. + Servers []*NodeServerInfo + + // SchedulingEligibility is used to inform clients what the server-side + // has for their scheduling status during heartbeats. + SchedulingEligibility string + + // SignedIdentity is the newly signed node identity that the server has + // generated. The node should check if this is set, and if so, update its + // state with the new identity. + SignedIdentity *string + + QueryMeta +} diff --git a/nomad/structs/node_test.go b/nomad/structs/node_test.go index 95970aaed..57e3912ca 100644 --- a/nomad/structs/node_test.go +++ b/nomad/structs/node_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v3/jwt" "github.com/hashicorp/nomad/ci" "github.com/shoenig/test/must" "github.com/stretchr/testify/require" @@ -279,3 +280,372 @@ func TestGenerateNodeIdentityClaims(t *testing.T) { must.NotNil(t, claims.NotBefore) must.NotNil(t, claims.Expiry) } + +func TestNodeRegisterRequest_ShouldGenerateNodeIdentity(t *testing.T) { + ci.Parallel(t) + + // Generate a stable mock node for testing. 
+ mockNode := MockNode() + + testCases := []struct { + name string + inputNodeRegisterRequest *NodeRegisterRequest + inputAuthErr error + inputTime time.Time + inputTTL time.Duration + expectedOutput bool + }{ + { + name: "expired node identity", + inputNodeRegisterRequest: &NodeRegisterRequest{}, + inputAuthErr: jwt.ErrExpired, + inputTime: time.Now(), + inputTTL: 10 * time.Minute, + expectedOutput: true, + }, + { + name: "first time node registration", + inputNodeRegisterRequest: &NodeRegisterRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ACLToken: AnonymousACLToken, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now(), + inputTTL: 10 * time.Minute, + expectedOutput: true, + }, + { + name: "registration using node secret ID", + inputNodeRegisterRequest: &NodeRegisterRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ClientID: "client-id-1", + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now(), + inputTTL: 10 * time.Minute, + expectedOutput: true, + }, + { + name: "modified node node pool configuration", + inputNodeRegisterRequest: &NodeRegisterRequest{ + Node: mockNode, + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: mockNode.ID, + NodePool: "new-pool", + NodeClass: mockNode.NodeClass, + NodeDatacenter: mockNode.Datacenter, + }, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(23 * time.Hour)), + }, + }, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "modified node class configuration", + inputNodeRegisterRequest: &NodeRegisterRequest{ + Node: mockNode, + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: mockNode.ID, + NodePool: mockNode.NodePool, + NodeClass: "new-class", + NodeDatacenter: mockNode.Datacenter, + }, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(23 * time.Hour)), + }, + }, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "modified node datacenter configuration", + inputNodeRegisterRequest: &NodeRegisterRequest{ + Node: mockNode, + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: mockNode.ID, + NodePool: mockNode.NodePool, + NodeClass: mockNode.NodeClass, + NodeDatacenter: "new-datacenter", + }, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(23 * time.Hour)), + }, + }, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "expiring node identity", + inputNodeRegisterRequest: &NodeRegisterRequest{ + Node: mockNode, + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: mockNode.ID, + NodePool: mockNode.NodePool, + NodeClass: mockNode.NodeClass, + NodeDatacenter: mockNode.Datacenter, + }, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(5 * time.Minute)), + }, + }, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "no generation", + inputNodeRegisterRequest: &NodeRegisterRequest{ + Node: mockNode, + 
WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: mockNode.ID, + NodePool: mockNode.NodePool, + NodeClass: mockNode.NodeClass, + NodeDatacenter: mockNode.Datacenter, + }, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(24 * time.Hour)), + }, + }, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := tc.inputNodeRegisterRequest.ShouldGenerateNodeIdentity( + tc.inputAuthErr, + tc.inputTime, + tc.inputTTL, + ) + must.Eq(t, tc.expectedOutput, actualOutput) + }) + } +} + +func TestNodeUpdateStatusRequest_ShouldGenerateNodeIdentity(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + inputNodeRegisterRequest *NodeUpdateStatusRequest + inputTime time.Time + inputTTL time.Duration + expectedOutput bool + }{ + { + name: "authenticated by node secret ID", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ClientID: "client-id-1", + }, + }, + }, + inputTime: time.Now(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "expiring node identity", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(1 * time.Hour)), + }, + }, + }, + }, + }, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "not expiring node identity", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(24 * time.Hour)), + }, + }, + }, + }, + }, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: false, + }, + { + name: "not expiring forced renewal node identity", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + ForceIdentityRenewal: true, + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(24 * time.Hour)), + }, + }, + }, + }, + }, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "server authenticated request", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ACLToken: LeaderACLToken, + }, + }, + }, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := tc.inputNodeRegisterRequest.ShouldGenerateNodeIdentity( + tc.inputTime, + tc.inputTTL, + ) + must.Eq(t, tc.expectedOutput, actualOutput) + }) + } +} +func TestNodeUpdateStatusRequest_IdentitySigningErrorIsTerminal(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + inputNodeRegisterRequest *NodeUpdateStatusRequest + inputTime time.Time + expectedOutput bool + }{ + { + name: "not close to expiring", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: 
WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(24 * time.Hour).UTC()), + }, + }, + }, + }, + }, + inputTime: time.Now().UTC(), + expectedOutput: false, + }, + { + name: "very close to expiring", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC()), + }, + }, + }, + }, + }, + inputTime: time.Now().Add(1 * time.Minute).UTC(), + expectedOutput: true, + }, + { + name: "server authenticated request", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ACLToken: LeaderACLToken, + }, + }, + }, + inputTime: time.Now().UTC(), + expectedOutput: false, + }, + { + name: "client secret ID authenticated request", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ClientID: "client-id", + }, + }, + }, + inputTime: time.Now().UTC(), + expectedOutput: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := tc.inputNodeRegisterRequest.IdentitySigningErrorIsTerminal(tc.inputTime) + must.Eq(t, tc.expectedOutput, actualOutput) + }) + } +} diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index b88c20109..b67c85b61 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -597,19 +597,6 @@ type WriteMeta struct { Index uint64 } -// NodeRegisterRequest is used for Node.Register endpoint -// to register a node as being a schedulable entity. -type NodeRegisterRequest struct { - Node *Node - NodeEvent *NodeEvent - - // CreateNodePool is used to indicate that the node's node pool should be - // create along with the node registration if it doesn't exist. - CreateNodePool bool - - WriteRequest -} - // NodeDeregisterRequest is used for Node.Deregister endpoint // to deregister a node as being a schedulable entity. type NodeDeregisterRequest struct { @@ -643,26 +630,6 @@ type NodeServerInfo struct { Datacenter string } -// NodeUpdateStatusRequest is used for Node.UpdateStatus endpoint -// to update the status of a node. -type NodeUpdateStatusRequest struct { - NodeID string - Status string - - // IdentitySigningKeyID is the ID of the root key used to sign the node's - // identity. This is not provided by the client, but is set by the server, - // so that the value can be propagated through Raft. - IdentitySigningKeyID string - - // ForceIdentityRenewal is used to force the Nomad server to generate a new - // identity for the node. - ForceIdentityRenewal bool - - NodeEvent *NodeEvent - UpdatedAt int64 - WriteRequest -} - // NodeUpdateDrainRequest is used for updating the drain strategy type NodeUpdateDrainRequest struct { NodeID string @@ -1506,36 +1473,6 @@ type JobValidateResponse struct { Warnings string } -// NodeUpdateResponse is used to respond to a node update -type NodeUpdateResponse struct { - HeartbeatTTL time.Duration - EvalIDs []string - EvalCreateIndex uint64 - NodeModifyIndex uint64 - - // Features informs clients what enterprise features are allowed - Features uint64 - - // LeaderRPCAddr is the RPC address of the current Raft Leader. If - // empty, the current Nomad Server is in the minority of a partition. 
- LeaderRPCAddr string - - // NumNodes is the number of Nomad nodes attached to this quorum of - // Nomad Servers at the time of the response. This value can - // fluctuate based on the health of the cluster between heartbeats. - NumNodes int32 - - // Servers is the full list of known Nomad servers in the local - // region. - Servers []*NodeServerInfo - - // SchedulingEligibility is used to inform clients what the server-side - // has for their scheduling status during heartbeats. - SchedulingEligibility string - - QueryMeta -} - // NodeDrainUpdateResponse is used to respond to a node drain update type NodeDrainUpdateResponse struct { NodeModifyIndex uint64 diff --git a/nomad/worker_test.go b/nomad/worker_test.go index 4bd18c7ea..eb00f5806 100644 --- a/nomad/worker_test.go +++ b/nomad/worker_test.go @@ -522,6 +522,7 @@ func TestWorker_SubmitPlanNormalizedAllocations(t *testing.T) { }) defer cleanupS1() testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) // Register node node := mock.Node() @@ -574,6 +575,7 @@ func TestWorker_SubmitPlan_MissingNodeRefresh(t *testing.T) { }) defer cleanupS1() testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) // Register node node := mock.Node() @@ -648,6 +650,7 @@ func TestWorker_UpdateEval(t *testing.T) { }) defer cleanupS1() testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) // Register node node := mock.Node() @@ -699,6 +702,7 @@ func TestWorker_CreateEval(t *testing.T) { }) defer cleanupS1() testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) // Register node node := mock.Node()
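
For context when reviewing the NodeUpdateStatusRequest helpers added above, the sketch below shows how they are intended to compose inside a heartbeat-style RPC handler. It is an illustrative sketch only and is not the node_endpoint.go change in this patch; handleHeartbeatIdentity and signNodeIdentity are hypothetical stand-ins for the real handler and the encrypter-backed signing path.

package nomad

import (
	"errors"
	"time"

	"github.com/hashicorp/nomad/nomad/structs"
)

// signNodeIdentity is a hypothetical stand-in for the server's encrypter-backed
// signing of node identity claims; it is not part of this patch.
func signNodeIdentity(nodeID string) (signedJWT, keyID string, err error) {
	return "", "", errors.New("not implemented in this sketch")
}

// handleHeartbeatIdentity sketches how a heartbeat handler could use the new
// request helpers: it only generates an identity when the request calls for
// it, and only treats a signing failure as terminal when the caller's current
// identity is about to expire.
func handleHeartbeatIdentity(
	args *structs.NodeUpdateStatusRequest,
	reply *structs.NodeUpdateResponse,
	ttl time.Duration,
) error {

	now := time.Now().UTC()

	// Secret-ID authentication, forced renewal, or an expiring node identity
	// are the only reasons to mint a new identity on this code path.
	if !args.ShouldGenerateNodeIdentity(now, ttl) {
		return nil
	}

	signed, keyID, err := signNodeIdentity(args.NodeID)
	if err != nil {
		if args.IdentitySigningErrorIsTerminal(now) {
			return err
		}
		// Otherwise let the next heartbeat retry rather than marking the node
		// down while its current identity is still valid.
		return nil
	}

	// Record the signing key ID so the root key is not deleted while still in
	// use, and return the signed identity for the client to persist.
	args.IdentitySigningKeyID = keyID
	reply.SignedIdentity = &signed
	return nil
}

Because IdentitySigningErrorIsTerminal only reports true within a ten-minute window before expiry, the retry path above covers almost every transient keyring-not-ready failure, which matches the retry behaviour described in the commit message.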