diff --git a/client/client_test.go b/client/client_test.go index fc6fb60e8..d3dafd194 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -703,6 +703,7 @@ func TestClient_SaveRestoreState(t *testing.T) { s1, _, cleanupS1 := testServer(t, nil) t.Cleanup(cleanupS1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) c1, cleanupC1 := TestClient(t, func(c *config.Config) { c.DevMode = false diff --git a/client/drain_test.go b/client/drain_test.go index d67a6219a..a995baf4c 100644 --- a/client/drain_test.go +++ b/client/drain_test.go @@ -29,6 +29,7 @@ func TestClient_SelfDrainConfig(t *testing.T) { srv, _, cleanupSRV := testServer(t, nil) defer cleanupSRV() testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) c1, cleanupC1 := TestClient(t, func(c *config.Config) { c.RPCHandler = srv @@ -81,6 +82,7 @@ func TestClient_SelfDrain_FailLocal(t *testing.T) { srv, _, cleanupSRV := testServer(t, nil) defer cleanupSRV() testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) c1, cleanupC1 := TestClient(t, func(c *config.Config) { c.RPCHandler = srv diff --git a/command/acl_bootstrap_test.go b/command/acl_bootstrap_test.go index 78c5fc566..14355c18d 100644 --- a/command/acl_bootstrap_test.go +++ b/command/acl_bootstrap_test.go @@ -23,7 +23,7 @@ func TestACLBootstrapCommand(t *testing.T) { c.ACL.PolicyTTL = 0 } - srv, _, url := testServer(t, true, config) + srv, _, url := testServer(t, false, config) defer srv.Shutdown() must.Nil(t, srv.RootToken) @@ -101,7 +101,7 @@ func TestACLBootstrapCommand_WithOperatorFileBootstrapToken(t *testing.T) { err := os.WriteFile(file, []byte(mockToken.SecretID), 0700) must.NoError(t, err) - srv, _, url := testServer(t, true, config) + srv, _, url := testServer(t, false, config) defer srv.Shutdown() must.Nil(t, srv.RootToken) @@ -139,7 +139,7 @@ func TestACLBootstrapCommand_WithBadOperatorFileBootstrapToken(t *testing.T) { err := os.WriteFile(file, []byte(invalidToken), 0700) must.NoError(t, err) - srv, _, url := testServer(t, true, config) + srv, _, url := testServer(t, false, config) defer srv.Shutdown() must.Nil(t, srv.RootToken) diff --git a/command/agent/agent_test.go b/command/agent/agent_test.go index 4b77df344..cbd7a076d 100644 --- a/command/agent/agent_test.go +++ b/command/agent/agent_test.go @@ -1120,7 +1120,7 @@ func TestServer_Reload_TLS_Shared_Keyloader(t *testing.T) { TLSConfig: &config.TLSConfig{ EnableHTTP: true, EnableRPC: true, - VerifyServerHostname: true, + VerifyServerHostname: false, CAFile: foocafile, CertFile: fooclientcert, KeyFile: fooclientkey, diff --git a/nomad/acl.go b/nomad/acl.go index 78cfc052c..1b77dc565 100644 --- a/nomad/acl.go +++ b/nomad/acl.go @@ -16,6 +16,10 @@ func (s *Server) AuthenticateServerOnly(ctx *RPCContext, args structs.RequestWit return s.auth.AuthenticateServerOnly(ctx, args) } +func (s *Server) AuthenticateNodeIdentityGenerator(ctx *RPCContext, args structs.RequestWithIdentity) error { + return s.auth.AuthenticateNodeIdentityGenerator(ctx, args) +} + func (s *Server) AuthenticateClientOnly(ctx *RPCContext, args structs.RequestWithIdentity) (*acl.ACL, error) { return s.auth.AuthenticateClientOnly(ctx, args) } diff --git a/nomad/auth/auth.go b/nomad/auth/auth.go index 1e412961c..afbd5ddfb 100644 --- a/nomad/auth/auth.go +++ b/nomad/auth/auth.go @@ -217,10 +217,11 @@ func (s *Authenticator) Authenticate(ctx RPCContext, args structs.RequestWithIde return nil } -// ResolveACL is an authentication wrapper which handles 
resolving ACL tokens, +// ResolveACL is an authentication wrapper that handles resolving ACL tokens, // Workload Identities, or client secrets into acl.ACL objects. Exclusively // server-to-server or client-to-server requests should be using -// AuthenticateServerOnly or AuthenticateClientOnly and never use this method. +// AuthenticateServerOnly or AuthenticateClientOnly unless they use the +// AuthenticateNodeIdentityGenerator function. func (s *Authenticator) ResolveACL(args structs.RequestWithIdentity) (*acl.ACL, error) { identity := args.GetIdentity() if identity == nil { diff --git a/nomad/client_agent_endpoint_test.go b/nomad/client_agent_endpoint_test.go index 3dcaa9ef7..44e752010 100644 --- a/nomad/client_agent_endpoint_test.go +++ b/nomad/client_agent_endpoint_test.go @@ -854,7 +854,10 @@ func TestAgentHost_Server(t *testing.T) { } c, cleanupC := client.TestClient(t, func(c *config.Config) { - c.Servers = []string{s2.GetConfig().RPCAddr.String()} + c.Servers = []string{ + s1.GetConfig().RPCAddr.String(), + s2.GetConfig().RPCAddr.String(), + } c.EnableDebug = true }) defer cleanupC() diff --git a/nomad/client_alloc_endpoint_test.go b/nomad/client_alloc_endpoint_test.go index 48a5185bf..20a05679d 100644 --- a/nomad/client_alloc_endpoint_test.go +++ b/nomad/client_alloc_endpoint_test.go @@ -38,6 +38,7 @@ func TestClientAllocations_GarbageCollectAll_Local(t *testing.T) { defer cleanupS() codec := rpcClient(t, s) testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, s.Region()) c, cleanupC := client.TestClient(t, func(c *config.Config) { c.Servers = []string{s.config.RPCAddr.String()} diff --git a/nomad/client_csi_endpoint_test.go b/nomad/client_csi_endpoint_test.go index d2d127584..241379f89 100644 --- a/nomad/client_csi_endpoint_test.go +++ b/nomad/client_csi_endpoint_test.go @@ -474,6 +474,7 @@ func setupLocal(t *testing.T) rpc.ClientCodec { t.Cleanup(cleanupS1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) codec := rpcClient(t, s1) mockCSI := newMockClientCSI() diff --git a/nomad/client_stats_endpoint_test.go b/nomad/client_stats_endpoint_test.go index 55c439da5..6b8d01d36 100644 --- a/nomad/client_stats_endpoint_test.go +++ b/nomad/client_stats_endpoint_test.go @@ -29,6 +29,7 @@ func TestClientStats_Stats_Local(t *testing.T) { defer cleanupS() codec := rpcClient(t, s) testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, s.Region()) c, cleanupC := client.TestClient(t, func(c *config.Config) { c.Servers = []string{s.config.RPCAddr.String()} diff --git a/nomad/csi_endpoint_test.go b/nomad/csi_endpoint_test.go index 2d71a0f46..744dd38c2 100644 --- a/nomad/csi_endpoint_test.go +++ b/nomad/csi_endpoint_test.go @@ -1136,12 +1136,13 @@ func TestCSIVolumeEndpoint_List_PaginationFiltering(t *testing.T) { func TestCSIVolumeEndpoint_Create(t *testing.T) { ci.Parallel(t) var err error - srv, rootToken, shutdown := TestACLServer(t, func(c *Config) { + srv, _, shutdown := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer shutdown() testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) fake := newMockClientCSI() fake.NextValidateError = nil @@ -1158,6 +1159,7 @@ func TestCSIVolumeEndpoint_Create(t *testing.T) { client, cleanup := client.TestClientWithRPCs(t, func(c *cconfig.Config) { c.Servers = []string{srv.config.RPCAddr.String()} + c.TLSConfig = srv.config.TLSConfig }, map[string]interface{}{"CSI": fake}, ) @@ -1169,8 +1171,11 @@ func 
TestCSIVolumeEndpoint_Create(t *testing.T) { }).Node req0 := &structs.NodeRegisterRequest{ - Node: node, - WriteRequest: structs.WriteRequest{Region: "global", AuthToken: rootToken.SecretID}, + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: node.SecretID, + }, } var resp0 structs.NodeUpdateResponse err = client.RPC("Node.Register", req0, &resp0) diff --git a/nomad/drainer_int_test.go b/nomad/drainer_int_test.go index 02f4e3142..b4c7d4507 100644 --- a/nomad/drainer_int_test.go +++ b/nomad/drainer_int_test.go @@ -149,6 +149,7 @@ func TestDrainer_Simple_ServiceOnly(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node @@ -220,6 +221,7 @@ func TestDrainer_Simple_ServiceOnly_Deadline(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node @@ -277,6 +279,7 @@ func TestDrainer_DrainEmptyNode(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create an empty node @@ -312,6 +315,7 @@ func TestDrainer_AllTypes_Deadline(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node @@ -420,6 +424,7 @@ func TestDrainer_AllTypes_NoDeadline(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create two nodes, registering the second later @@ -551,6 +556,7 @@ func TestDrainer_AllTypes_Deadline_GarbageCollectedNode(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node @@ -668,6 +674,7 @@ func TestDrainer_MultipleNSes_ServiceOnly(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node @@ -762,6 +769,7 @@ func TestDrainer_Batch_TransitionToForce(t *testing.T) { defer cleanupSrv() codec := rpcClient(t, srv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) store := srv.State() // Create a node diff --git a/nomad/encrypter.go b/nomad/encrypter.go index ab580031b..fa1330c5b 100644 --- a/nomad/encrypter.go +++ b/nomad/encrypter.go @@ -303,11 +303,12 @@ func (e *Encrypter) Decrypt(ciphertext []byte, keyID string) ([]byte, error) { // header name. const keyIDHeader = "kid" -// SignClaims signs the identity claim for the task and returns an encoded JWT -// (including both the claim and its signature) and the key ID of the key used -// to sign it, or an error. +// SignClaims signs the identity claim and returns an encoded JWT (including +// both the claim and its signature) and the key ID of the key used to sign it, +// or an error. // -// SignClaims adds the Issuer claim prior to signing. +// SignClaims adds the Issuer claim prior to signing if it is unset by the +// caller. 
func (e *Encrypter) SignClaims(claims *structs.IdentityClaims) (string, string, error) { if claims == nil { @@ -324,7 +325,7 @@ func (e *Encrypter) SignClaims(claims *structs.IdentityClaims) (string, string, claims.Issuer = e.issuer } - opts := (&jose.SignerOptions{}).WithHeader("kid", cs.rootKey.Meta.KeyID).WithType("JWT") + opts := (&jose.SignerOptions{}).WithHeader(keyIDHeader, cs.rootKey.Meta.KeyID).WithType("JWT") var sig jose.Signer if cs.rsaPrivateKey != nil { diff --git a/nomad/eval_broker_test.go b/nomad/eval_broker_test.go index d44a189b9..df117310d 100644 --- a/nomad/eval_broker_test.go +++ b/nomad/eval_broker_test.go @@ -1535,6 +1535,7 @@ func TestEvalBroker_IntegrationTest(t *testing.T) { defer cleanupS1() testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.Region()) codec := rpcClient(t, srv) store := srv.fsm.State() diff --git a/nomad/host_volume_endpoint_test.go b/nomad/host_volume_endpoint_test.go index 6a46c83b9..7433ff5aa 100644 --- a/nomad/host_volume_endpoint_test.go +++ b/nomad/host_volume_endpoint_test.go @@ -38,6 +38,7 @@ func TestHostVolumeEndpoint_CreateRegisterGetDelete(t *testing.T) { }) t.Cleanup(cleanupSrv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.config.Region) store := srv.fsm.State() c1, node1 := newMockHostVolumeClient(t, srv, "prod") @@ -434,6 +435,7 @@ func TestHostVolumeEndpoint_List(t *testing.T) { }) t.Cleanup(cleanupSrv) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.config.Region) store := srv.fsm.State() codec := rpcClient(t, srv) @@ -809,6 +811,7 @@ func TestHostVolumeEndpoint_concurrency(t *testing.T) { srv, cleanup := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) t.Cleanup(cleanup) testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, srv.config.Region) c, node := newMockHostVolumeClient(t, srv, "default") diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index e0743379e..c04b1db00 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -4,12 +4,14 @@ package nomad import ( + "errors" "fmt" "net/http" "reflect" "sync" "time" + "github.com/go-jose/go-jose/v3/jwt" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-memdb" metrics "github.com/hashicorp/go-metrics/compat" @@ -91,9 +93,10 @@ func NewNodeEndpoint(srv *Server, ctx *RPCContext) *Node { // Register is used to upsert a client that is available for scheduling func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUpdateResponse) error { - // note that we trust-on-first use and the identity will be anonymous for - // that initial request; we lean on mTLS for handling that safely - authErr := n.srv.Authenticate(n.ctx, args) + + // The node register RPC is responsible for generating node identities, so + // we use the custom authentication method shared with UpdateStatus. + authErr := n.srv.AuthenticateNodeIdentityGenerator(n.ctx, args) isForwarded := args.IsForwarded() if done, err := n.srv.forward("Node.Register", args, args, reply); done { @@ -108,7 +111,15 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp return err } n.srv.MeasureRPCRate("node", structs.RateMetricWrite, args) - if authErr != nil { + + // The authentication error can be because the identity is expired. If we + // stopped the handler execution here, the node would never be able to + // register after being disconnected. + // + // Further within the RPC we check the supplied SecretID against the stored + // value in state. 
This acts as a secondary check and can be seen as a + // refresh token, in the event the identity is expired. + if authErr != nil && !errors.Is(authErr, jwt.ErrExpired) { return structs.ErrPermissionDenied } @@ -161,8 +172,13 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp args.Node.NodePool = structs.NodePoolDefault } + // The current time is used at a number of places in the registration + // workflow. Generating it once avoids multiple calls to time.Now() and also + // means the same time is used across all checks and sets. + timeNow := time.Now() + // Set the timestamp when the node is registered - args.Node.StatusUpdatedAt = time.Now().Unix() + args.Node.StatusUpdatedAt = timeNow.Unix() // Compute the node class if err := args.Node.ComputeClass(); err != nil { @@ -214,6 +230,40 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp if n.srv.Region() == n.srv.config.AuthoritativeRegion { args.CreateNodePool = true } + + // Track the TTL that will be used for the node identity. + var identityTTL time.Duration + + // The identity TTL is determined by the node pool the node is registered + // in. In the event the node registration is triggering creation of a new + // node pool, it will be created with the default TTL, so we use this for + // the identity. + nodePool, err := snap.NodePoolByName(ws, args.Node.NodePool) + if err != nil { + return fmt.Errorf("failed to query node pool: %v", err) + } + if nodePool == nil { + identityTTL = structs.DefaultNodePoolNodeIdentityTTL + } else { + identityTTL = nodePool.NodeIdentityTTL + } + + // Check if we need to generate a node identity. This must happen before we + // send the Raft message, as the signing key ID is set on the node if we + // generate one. + if args.ShouldGenerateNodeIdentity(authErr, timeNow.UTC(), identityTTL) { + + claims := structs.GenerateNodeIdentityClaims(args.Node, n.srv.Region(), identityTTL) + + signedJWT, signingKeyID, err := n.srv.encrypter.SignClaims(claims) + if err != nil { + return fmt.Errorf("failed to sign node identity claims: %v", err) + } + + reply.SignedIdentity = &signedJWT + args.Node.IdentitySigningKeyID = signingKeyID + } + _, index, err := n.srv.raftApply(structs.NodeRegisterRequestType, args) if err != nil { n.logger.Error("register failed", "error", err) @@ -509,9 +559,13 @@ func (n *Node) deregister(args *structs.NodeBatchDeregisterRequest, // │ │ // └──── ready ─────┘ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *structs.NodeUpdateResponse) error { - // UpdateStatus receives requests from client and servers that mark failed - // heartbeats, so we can't use AuthenticateClientOnly - authErr := n.srv.Authenticate(n.ctx, args) + + // The node update status RPC is responsible for generating node identities, + // so we use the custom authentication method shared with Register. + // + // Note; UpdateStatus receives requests from clients and servers that mark + // failed heartbeats. + authErr := n.srv.AuthenticateNodeIdentityGenerator(n.ctx, args) isForwarded := args.IsForwarded() if done, err := n.srv.forward("Node.UpdateStatus", args, args, reply); done { @@ -573,14 +627,62 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct // to track SecretIDs. // Update the timestamp of when the node status was updated - args.UpdatedAt = time.Now().Unix() + timeNow := time.Now() + args.UpdatedAt = timeNow.Unix() + + // Track the TTL that will be used for the node identity. 
+ var identityTTL time.Duration + + // The identity TTL is determined by the node pool the node is registered + // in. The pool should already exist, as the node is already registered. If + // it does not, we use the default TTL as we have no better value to use. + // + // Once the node pool is created, the node's identity will have the TTL set + // by the node pool on its renewal. + nodePool, err := snap.NodePoolByName(ws, node.NodePool) + if err != nil { + return fmt.Errorf("failed to query node pool: %v", err) + } + if nodePool == nil { + identityTTL = structs.DefaultNodePoolNodeIdentityTTL + } else { + identityTTL = nodePool.NodeIdentityTTL + } + + // Check and generate a node identity if needed. + if args.ShouldGenerateNodeIdentity(timeNow.UTC(), identityTTL) { + + claims := structs.GenerateNodeIdentityClaims(node, n.srv.Region(), identityTTL) + + // Sign the claims with the encrypter and conditionally handle the + // error. The IdentitySigningErrorTerminal method has a description of + // why we do this. + signedJWT, signingKeyID, err := n.srv.encrypter.SignClaims(claims) + if err != nil { + if args.IdentitySigningErrorIsTerminal(timeNow) { + return fmt.Errorf("failed to sign node identity claims: %v", err) + } else { + n.logger.Warn( + "failed to sign node identity claims, will retry on next heartbeat", + "error", err, "node_id", node.ID) + } + } + + reply.SignedIdentity = &signedJWT + args.IdentitySigningKeyID = signingKeyID + } else { + // Ensure the IdentitySigningKeyID is cleared if we are not generating a + // new identity. This is important to ensure that we do not cause Raft + // updates unless we need to. + args.IdentitySigningKeyID = "" + } // Compute next status. switch node.Status { case structs.NodeStatusInit: if args.Status == structs.NodeStatusReady { - // Keep node in the initializing status if it has allocations but - // they are not updated. + // Keep the node in the initializing status if it has allocations, + // but they are not updated. allocs, err := snap.AllocsByNodeTerminal(ws, args.NodeID, false) if err != nil { return fmt.Errorf("failed to query node allocs: %v", err) @@ -592,13 +694,9 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct args.Status = structs.NodeStatusInit } - // Keep node in the initialing status if it's in a node pool that - // doesn't exist. - pool, err := snap.NodePoolByName(ws, node.NodePool) - if err != nil { - return fmt.Errorf("failed to query node pool: %v", err) - } - if pool == nil { + // Keep the node in the initialing status if it's in a node pool + // that doesn't exist. + if nodePool == nil { n.logger.Debug(fmt.Sprintf("marking node as %s due to missing node pool", structs.NodeStatusInit)) args.Status = structs.NodeStatusInit if !node.HasEvent(NodeWaitingForNodePool) { @@ -617,7 +715,19 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct // Commit this update via Raft var index uint64 - if node.Status != args.Status || args.NodeEvent != nil { + + // Only perform a Raft apply if we really have to, so we avoid unnecessary + // cluster traffic and CPU load. + // + // We must update state if: + // - The node informed us of a new status. + // - The node informed us of a new event. + // - We have generated an identity which has been signed with a different + // key ID compared to the last identity generated for the node. 
+ if node.Status != args.Status || + args.NodeEvent != nil || + node.IdentitySigningKeyID != args.IdentitySigningKeyID && args.IdentitySigningKeyID != "" { + // Attach an event if we are updating the node status to ready when it // is down via a heartbeat if node.Status == structs.NodeStatusDown && args.NodeEvent == nil { diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index 0fae446de..92ad8b6a6 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v3/jwt" memdb "github.com/hashicorp/go-memdb" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2" "github.com/hashicorp/nomad/acl" @@ -37,6 +38,7 @@ func TestClientEndpoint_Register(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Check that we have no client connections require.Empty(s1.connectedNodes()) @@ -89,6 +91,267 @@ func TestClientEndpoint_Register(t *testing.T) { }) } +func TestNode_Register_Identity(t *testing.T) { + ci.Parallel(t) + + // This helper function verifies the identity token generated by the server + // in the Node.Register RPC call. + verifyIdentityFn := func( + t *testing.T, + testServer *Server, + token string, + node *structs.Node, + ttl time.Duration, + ) { + t.Helper() + + identityClaims, err := testServer.encrypter.VerifyClaim(token) + must.NoError(t, err) + + must.Eq(t, ttl, identityClaims.Expiry.Time().Sub(identityClaims.NotBefore.Time())) + must.True(t, identityClaims.IsNode()) + must.Eq(t, identityClaims.NodeIdentityClaims, &structs.NodeIdentityClaims{ + NodeID: node.ID, + NodeDatacenter: node.Datacenter, + NodeClass: node.NodeClass, + NodePool: node.NodePool, + }) + + // Identify the active encrypter key ID, which would have been used to + // sign the identity token. + _, keyID, err := testServer.encrypter.GetActiveKey() + must.NoError(t, err) + + // Perform a lookup of the node in state. The IdentitySigningKeyID field + // should be populated with the active encrypter key ID. + stateNodeResp, err := testServer.State().NodeByID(nil, node.ID) + must.NoError(t, err) + must.NotNil(t, stateNodeResp) + must.Eq(t, keyID, stateNodeResp.IdentitySigningKeyID) + } + + testCases := []struct { + name string + testFn func(t *testing.T, srv *Server, codec rpc.ClientCodec) + }{ + { + // Test the initial registration flow, where a node will not include + // an authentication token in the request. + // + // A later registration will not generate a new identity, as the + // included identity is still valid. + name: "identity generation and node reregister", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + + req := structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, req.Node, structs.DefaultNodePoolNodeIdentityTTL) + + req.WriteRequest.AuthToken = *resp.SignedIdentity + var resp2 structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp2)) + must.Nil(t, resp2.SignedIdentity) + }, + }, + { + // A node can register with a node pool that does not exist, and the + // server will create it on FSM write. 
In this case, the server + // should generate an identity with the default node pool identity + // TTL. + name: "create on register node pool", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + node.NodePool = "custom-pool" + + req := structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, req.Node, structs.DefaultNodePoolNodeIdentityTTL) + }, + }, + { + // A node can register with a node pool that exists, and the server + // will generate an identity with the node pool's identity TTL. + name: "non-default identity ttl", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + nodePool := mock.NodePool() + nodePool.NodeIdentityTTL = 168 * time.Hour + must.NoError(t, srv.State().UpsertNodePools(structs.MsgTypeTestSetup, 1000, []*structs.NodePool{nodePool})) + + node := mock.Node() + node.NodePool = nodePool.Name + + req := structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, req.Node, nodePool.NodeIdentityTTL) + }, + }, + { + // Ensure a new identity is generated if the identity within the + // request is close to expiration. + name: "identity close to expiration", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + timeNow := time.Now().UTC().Add(-20 * time.Hour) + timeJWTNow := jwt.NewNumericDate(timeNow) + + node := mock.Node() + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, 1000, node)) + + claims := structs.GenerateNodeIdentityClaims( + node, + srv.Region(), + structs.DefaultNodePoolNodeIdentityTTL, + ) + claims.IssuedAt = timeJWTNow + claims.NotBefore = timeJWTNow + claims.Expiry = jwt.NewNumericDate(time.Now().UTC().Add(4 * time.Hour)) + + signedToken, _, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + + req := structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: signedToken, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, req.Node, structs.DefaultNodePoolNodeIdentityTTL) + }, + }, + { + // A node could disconnect from the cluster for long enough for the + // identity to expire. When it reconnects and performs its + // reregistration, the server should generate a new identity. 
+ name: "identity expired", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, 1000, node)) + + claims := structs.GenerateNodeIdentityClaims( + node, + srv.Region(), + structs.DefaultNodePoolNodeIdentityTTL, + ) + claims.Expiry = jwt.NewNumericDate(time.Now().UTC().Add(-1 * time.Hour)) + + signedToken, _, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + + req := structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: signedToken, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, req.Node, structs.DefaultNodePoolNodeIdentityTTL) + }, + }, + { + // Ensure that if the node's SecretID is tampered with, the server + // rejects any attempt to register. This test is to gate against a + // potential regressions in how we handle identities within this + // RPC. + name: "identity expired secret ID tampered", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, 1000, node.Copy())) + + claims := structs.GenerateNodeIdentityClaims( + node, + srv.Region(), + structs.DefaultNodePoolNodeIdentityTTL, + ) + claims.Expiry = jwt.NewNumericDate(time.Now().UTC().Add(-1 * time.Hour)) + + signedToken, _, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + + node.SecretID = uuid.Generate() + + req := structs.NodeRegisterRequest{ + Node: node.Copy(), + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: signedToken, + }, + } + + var resp structs.NodeUpdateResponse + must.ErrorContains( + t, + msgpackrpc.CallWithCodec(codec, "Node.Register", &req, &resp), + "node secret ID does not match", + ) + }, + }, + } + + // ACL enabled server test run. + testACLServer, _, aclServerCleanup := TestACLServer(t, func(c *Config) {}) + defer aclServerCleanup() + testACLCodec := rpcClient(t, testACLServer) + + testutil.WaitForLeader(t, testACLServer.RPC) + testutil.WaitForKeyring(t, testACLServer.RPC, testACLServer.config.Region) + + // ACL disabled server test run. + testServer, serverCleanup := TestServer(t, func(c *Config) {}) + defer serverCleanup() + testCodec := rpcClient(t, testServer) + + testutil.WaitForLeader(t, testServer.RPC) + testutil.WaitForKeyring(t, testServer.RPC, testServer.config.Region) + + for _, tc := range testCases { + t.Run("ACL_enabled_"+tc.name, func(t *testing.T) { + tc.testFn(t, testACLServer, testACLCodec) + }) + t.Run("ACL_disabled_"+tc.name, func(t *testing.T) { + tc.testFn(t, testServer, testCodec) + }) + } +} + // This test asserts that we only track node connections if they are not from // forwarded RPCs. This is essential otherwise we will think a Yamux session to // a Nomad server is actually the session to the node. 
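Reviewer note, not part of the patch: the renewal trigger exercised by the "identity close to expiration" cases above is the existing IdentityClaims.IsExpiring rule (threshold = now + ttl/3, shown further down in this diff in nomad/structs/identity.go). Below is a minimal, self-contained sketch of that check, assuming only the go-jose/v3 jwt package this change already imports; the helper name isExpiring is illustrative, not part of the change.

package main

import (
	"fmt"
	"time"

	"github.com/go-jose/go-jose/v3/jwt"
)

// isExpiring mirrors IdentityClaims.IsExpiring as used by Node.Register and
// Node.UpdateStatus in this change: the server re-signs an identity once less
// than a third of the node pool's identity TTL remains.
func isExpiring(expiry *jwt.NumericDate, now time.Time, ttl time.Duration) bool {
	threshold := now.Add(ttl / 3)
	return expiry.Time().UTC().Before(threshold)
}

func main() {
	now := time.Now().UTC()
	ttl := 24 * time.Hour // the TTL used by most test cases above

	// 23 hours remaining: outside the 8-hour renewal window, no new identity.
	fmt.Println(isExpiring(jwt.NewNumericDate(now.Add(23*time.Hour)), now, ttl))

	// 1 hour remaining: inside the window, the server signs a fresh identity.
	fmt.Println(isExpiring(jwt.NewNumericDate(now.Add(time.Hour)), now, ttl))
}

With the 24-hour TTL used in these tests the renewal window starts once fewer than 8 hours remain, which is why the 23-hour expiries in the "modified node ..." cases rely on the configuration drift rather than expiry to force a new identity.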
@@ -106,8 +369,8 @@ func TestClientEndpoint_Register_NodeConn_Forwarded(t *testing.T) { }) defer cleanupS2() TestJoin(t, s1, s2) - testutil.WaitForLeader(t, s1.RPC) - testutil.WaitForLeader(t, s2.RPC) + testutil.WaitForLeaders(t, s1.RPC, s2.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Determine the non-leader server var leader, nonLeader *Server @@ -190,6 +453,7 @@ func TestClientEndpoint_Register_SecretMismatch(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -219,6 +483,7 @@ func TestClientEndpoint_Register_NodePool(t *testing.T) { defer cleanupS() codec := rpcClient(t, s) testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, s.config.Region) testCases := []struct { name string @@ -328,6 +593,7 @@ func TestClientEndpoint_Register_NodePool_Multiregion(t *testing.T) { defer cleanupS1() codec1 := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) s2, _, cleanupS2 := TestACLServer(t, func(c *Config) { c.Region = "region-2" @@ -340,6 +606,7 @@ func TestClientEndpoint_Register_NodePool_Multiregion(t *testing.T) { defer cleanupS2() codec2 := rpcClient(t, s2) testutil.WaitForLeader(t, s2.RPC) + testutil.WaitForKeyring(t, s2.RPC, s2.config.Region) // Verify that registering a node with a new node pool in the authoritative // region creates the node pool. @@ -504,6 +771,7 @@ func TestClientEndpoint_DeregisterOne(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -617,6 +885,7 @@ func TestClientEndpoint_UpdateStatus(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Check that we have no client connections require.Empty(s1.connectedNodes()) @@ -721,6 +990,7 @@ func TestClientEndpoint_UpdateStatus_Reconnect(t *testing.T) { codec := rpcClient(t, s) defer cleanupS() testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, s.config.Region) // Register node. node := mock.Node() @@ -914,6 +1184,7 @@ func TestClientEndpoint_UpdateStatus_HeartbeatRecovery(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Check that we have no client connections require.Empty(s1.connectedNodes()) @@ -964,6 +1235,7 @@ func TestClientEndpoint_Register_GetEvals(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Register a system job. job := mock.SystemJob() @@ -1055,6 +1327,7 @@ func TestClientEndpoint_UpdateStatus_GetEvals(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Register a system job. 
job := mock.SystemJob() @@ -1163,6 +1436,7 @@ func TestClientEndpoint_UpdateStatus_HeartbeatOnly(t *testing.T) { codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -1224,6 +1498,7 @@ func TestClientEndpoint_UpdateStatus_HeartbeatOnly_Advertise(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -1255,6 +1530,7 @@ func TestNode_UpdateStatus_ServiceRegistrations(t *testing.T) { testServer, serverCleanup := TestServer(t, nil) defer serverCleanup() testutil.WaitForLeader(t, testServer.RPC) + testutil.WaitForKeyring(t, testServer.RPC, testServer.config.Region) // Create a node and upsert this into state. node := mock.Node() @@ -1304,6 +1580,256 @@ func TestNode_UpdateStatus_ServiceRegistrations(t *testing.T) { must.NoError(t, nodeEndpoint.UpdateStatus(&args, &reply)) } +func TestNode_UpdateStatus_Identity(t *testing.T) { + ci.Parallel(t) + + // This helper function verifies the identity token generated by the server + // in the Node.UpdateStatus RPC call. + verifyIdentityFn := func( + t *testing.T, + testServer *Server, + token string, + node *structs.Node, + ttl time.Duration, + ) { + t.Helper() + + identityClaims, err := testServer.encrypter.VerifyClaim(token) + must.NoError(t, err) + + must.Eq(t, ttl, identityClaims.Expiry.Time().Sub(identityClaims.NotBefore.Time())) + must.True(t, identityClaims.IsNode()) + must.Eq(t, identityClaims.NodeIdentityClaims, &structs.NodeIdentityClaims{ + NodeID: node.ID, + NodeDatacenter: node.Datacenter, + NodeClass: node.NodeClass, + NodePool: node.NodePool, + }) + + // Identify the active encrypter key ID, which would have been used to + // sign the identity token. + _, keyID, err := testServer.encrypter.GetActiveKey() + must.NoError(t, err) + + // Perform a lookup of the node in state. The IdentitySigningKeyID field + // should be populated with the active encrypter key ID. + stateNodeResp, err := testServer.State().NodeByID(nil, node.ID) + must.NoError(t, err) + must.NotNil(t, stateNodeResp) + must.Eq(t, keyID, stateNodeResp.IdentitySigningKeyID) + } + + testCases := []struct { + name string + testFn func(t *testing.T, srv *Server, codec rpc.ClientCodec) + }{ + { + // Ensure that the Node.UpdateStatus RPC generates a new identity + // for a client authenticating using its secret ID. + name: "node secret ID authenticated default pool", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + must.Eq(t, "", node.IdentitySigningKeyID) + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, srv.raft.LastIndex(), node)) + + req := &structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusReady, + WriteRequest: structs.WriteRequest{ + Region: srv.Region(), + AuthToken: node.SecretID, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, node, structs.DefaultNodePoolNodeIdentityTTL) + }, + }, + { + // Ensure that the Node.UpdateStatus RPC generates a new identity + // for a client authenticating using its secret ID which belongs to + // a non-default node pool. 
+ name: "node secret ID authenticated non-default pool", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + nodePool := mock.NodePool() + nodePool.NodeIdentityTTL = 168 * time.Hour + must.NoError(t, srv.State().UpsertNodePools( + structs.MsgTypeTestSetup, + srv.raft.LastIndex(), + []*structs.NodePool{nodePool}, + )) + + node := mock.Node() + node.NodePool = nodePool.Name + + must.Eq(t, "", node.IdentitySigningKeyID) + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, srv.raft.LastIndex(), node)) + + req := &structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusReady, + WriteRequest: structs.WriteRequest{ + Region: srv.Region(), + AuthToken: node.SecretID, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, node, nodePool.NodeIdentityTTL) + }, + }, + { + // Nomad servers often call the Node.UpdateStatus RPC to notify that + // a node has missed its heartbeat. In this case, we should write + // the update to state, but not generate an identity token. + name: "leader acl token authenticated", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + node := mock.Node() + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, srv.raft.LastIndex(), node)) + + req := &structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusDown, + WriteRequest: structs.WriteRequest{ + Region: srv.Region(), + AuthToken: srv.getLeaderAcl(), + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", req, &resp)) + must.Nil(t, resp.SignedIdentity) + + stateNode, err := srv.State().NodeByID(nil, node.ID) + must.NoError(t, err) + must.NotNil(t, stateNode) + must.Eq(t, structs.NodeStatusDown, stateNode.Status) + must.Greater(t, stateNode.CreateIndex, stateNode.ModifyIndex) + }, + }, + { + // Ensure a new identity is generated if the identity within the + // request is close to expiration. + name: "identity close to expiration", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + timeNow := time.Now().UTC().Add(-20 * time.Hour) + timeJWTNow := jwt.NewNumericDate(timeNow) + + node := mock.Node() + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, srv.raft.LastIndex(), node)) + + claims := structs.GenerateNodeIdentityClaims( + node, + srv.Region(), + structs.DefaultNodePoolNodeIdentityTTL, + ) + claims.IssuedAt = timeJWTNow + claims.NotBefore = timeJWTNow + claims.Expiry = jwt.NewNumericDate(time.Now().UTC().Add(4 * time.Hour)) + + signedToken, _, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + + req := structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusReady, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: signedToken, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, node, structs.DefaultNodePoolNodeIdentityTTL) + }, + }, + { + // Ensure a new identity is generated if the identity within the + // request is close to expiration and the new identity has a TTL set + // by its custom node pool configuration. 
+ name: "identity close to expiration custom pool", + testFn: func(t *testing.T, srv *Server, codec rpc.ClientCodec) { + + nodePool := mock.NodePool() + nodePool.NodeIdentityTTL = 168 * time.Hour + must.NoError(t, srv.State().UpsertNodePools( + structs.MsgTypeTestSetup, + srv.raft.LastIndex(), + []*structs.NodePool{nodePool}, + )) + + timeNow := time.Now().UTC().Add(-135 * time.Hour) + timeJWTNow := jwt.NewNumericDate(timeNow) + + node := mock.Node() + node.NodePool = nodePool.Name + must.NoError(t, srv.State().UpsertNode(structs.MsgTypeTestSetup, srv.raft.LastIndex(), node)) + + claims := structs.GenerateNodeIdentityClaims( + node, + srv.Region(), + structs.DefaultNodePoolNodeIdentityTTL, + ) + claims.IssuedAt = timeJWTNow + claims.NotBefore = timeJWTNow + claims.Expiry = jwt.NewNumericDate(time.Now().UTC().Add(4 * time.Hour)) + + signedToken, _, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + + req := structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusReady, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: signedToken, + }, + } + + var resp structs.NodeUpdateResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", &req, &resp)) + must.NotNil(t, resp.SignedIdentity) + verifyIdentityFn(t, srv, *resp.SignedIdentity, node, nodePool.NodeIdentityTTL) + }, + }, + } + + // ACL enabled server test run. + testACLServer, _, aclServerCleanup := TestACLServer(t, func(c *Config) {}) + defer aclServerCleanup() + testACLCodec := rpcClient(t, testACLServer) + + testutil.WaitForLeader(t, testACLServer.RPC) + testutil.WaitForKeyring(t, testACLServer.RPC, testACLServer.config.Region) + + // ACL disabled server test run. + testServer, serverCleanup := TestServer(t, func(c *Config) {}) + defer serverCleanup() + testCodec := rpcClient(t, testServer) + + testutil.WaitForLeader(t, testServer.RPC) + testutil.WaitForKeyring(t, testServer.RPC, testServer.config.Region) + + for _, tc := range testCases { + t.Run("ACL_enabled_"+tc.name, func(t *testing.T) { + tc.testFn(t, testACLServer, testACLCodec) + }) + t.Run("ACL_disabled_"+tc.name, func(t *testing.T) { + tc.testFn(t, testServer, testCodec) + }) + } +} + // TestClientEndpoint_UpdateDrain asserts the ability to initiate drain // against a node and cancel that drain. 
It also asserts: // * an evaluation is created when the node becomes eligible @@ -1316,6 +1842,7 @@ func TestClientEndpoint_UpdateDrain(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Disable drainer to prevent drain from completing during test s1.nodeDrainer.SetEnabled(false, nil) @@ -1435,6 +1962,7 @@ func TestClientEndpoint_UpdatedDrainAndCompleted(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) state := s1.fsm.State() // Disable drainer for now @@ -1545,6 +2073,7 @@ func TestClientEndpoint_UpdatedDrainNoop(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) state := s1.fsm.State() // Create the register request @@ -1688,6 +2217,7 @@ func TestClientEndpoint_Drain_Down(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) require := require.New(t) // Register a node @@ -1820,6 +2350,7 @@ func TestClientEndpoint_UpdateEligibility(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -1933,6 +2464,7 @@ func TestClientEndpoint_GetNode(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -1966,10 +2498,14 @@ func TestClientEndpoint_GetNode(t *testing.T) { t.Fatalf("bad ComputedClass: %#v", resp2.Node) } + _, keyID, err := s1.encrypter.GetActiveKey() + must.NoError(t, err) + // Update the status updated at value node.StatusUpdatedAt = resp2.Node.StatusUpdatedAt node.SecretID = "" node.Events = resp2.Node.Events + node.IdentitySigningKeyID = keyID must.Eq(t, node, resp2.Node) // assert that the node register event was set correctly @@ -2167,6 +2703,7 @@ func TestClientEndpoint_GetAllocs(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2497,6 +3034,7 @@ func TestClientEndpoint_GetClientAllocs_Blocking(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2621,6 +3159,7 @@ func TestClientEndpoint_GetClientAllocs_Blocking_GC(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2699,6 +3238,7 @@ func TestClientEndpoint_GetClientAllocs_WithoutMigrateTokens(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2754,6 +3294,7 @@ func TestClientEndpoint_GetAllocs_Blocking(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2853,6 +3394,7 @@ func TestNode_UpdateAlloc(t 
*testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -2933,6 +3475,7 @@ func TestNode_UpdateAlloc_NodeNotReady(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Register node. node := mock.Node() @@ -3109,6 +3652,7 @@ func TestClientEndpoint_BatchUpdate(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -3522,6 +4066,7 @@ func TestClientEndpoint_ListNodes(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -3594,6 +4139,7 @@ func TestClientEndpoint_ListNodes_Fields(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() @@ -3961,6 +4507,7 @@ func TestClientEndpoint_UpdateAlloc_Evals_ByTrigger(t *testing.T) { defer cleanupS1() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) // Create the register request node := mock.Node() diff --git a/nomad/rpc_test.go b/nomad/rpc_test.go index 6c0446561..5cb00ec6c 100644 --- a/nomad/rpc_test.go +++ b/nomad/rpc_test.go @@ -254,6 +254,7 @@ func TestRPC_PlaintextRPCSucceedsWhenInUpgradeMode(t *testing.T) { s1, cleanupS1 := TestServer(t, func(c *Config) { c.DataDir = path.Join(dir, "node1") + c.Region = "regionFoo" c.TLSConfig = &config.TLSConfig{ EnableRPC: true, VerifyServerHostname: true, @@ -264,18 +265,19 @@ func TestRPC_PlaintextRPCSucceedsWhenInUpgradeMode(t *testing.T) { } }) defer cleanupS1() + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) - codec := rpcClient(t, s1) + tlsCodec := rpcClientWithTLS(t, s1, s1.config.TLSConfig) // Create the register request node := mock.Node() req := &structs.NodeRegisterRequest{ Node: node, - WriteRequest: structs.WriteRequest{Region: "global"}, + WriteRequest: structs.WriteRequest{Region: s1.Region()}, } var resp structs.GenericResponse - err := msgpackrpc.CallWithCodec(codec, "Node.Register", req, &resp) + err := msgpackrpc.CallWithCodec(tlsCodec, "Node.Register", req, &resp) assert.Nil(err) // Check that heartbeatTimers has the heartbeat ID diff --git a/nomad/server_test.go b/nomad/server_test.go index d7175af67..ea490c6e4 100644 --- a/nomad/server_test.go +++ b/nomad/server_test.go @@ -375,6 +375,7 @@ func TestServer_Reload_TLSConnections_TLSToPlaintext_OnlyRPC(t *testing.T) { } }) defer cleanupS1() + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) newTLSConfig := &config.TLSConfig{ EnableHTTP: true, diff --git a/nomad/structs/identity.go b/nomad/structs/identity.go index 41e43f99d..4c17a9b03 100644 --- a/nomad/structs/identity.go +++ b/nomad/structs/identity.go @@ -51,7 +51,17 @@ func (i *IdentityClaims) IsExpiring(now time.Time, ttl time.Duration) bool { // relative to the current time. threshold := now.Add(ttl / 3) - return i.Expiry.Time().Before(threshold) + return i.Expiry.Time().UTC().Before(threshold) +} + +// IsExpiringInThreshold checks if the identity JWT is expired or close to +// expiring. 
It uses a passed threshold to determine "close to expiring" which +// is not manipulated, unlike TTL in the IsExpiring method. +func (i *IdentityClaims) IsExpiringInThreshold(threshold time.Time) bool { + if i != nil && i.Expiry != nil { + return threshold.After(i.Expiry.Time()) + } + return false } // setExpiry sets the "expiry" or "exp" claim for the identity JWT. It is the diff --git a/nomad/structs/identity_test.go b/nomad/structs/identity_test.go index b35690f0c..8e09a7061 100644 --- a/nomad/structs/identity_test.go +++ b/nomad/structs/identity_test.go @@ -174,6 +174,57 @@ func TestIdentityClaims_IsExpiring(t *testing.T) { } } +func TestIdentityClaims_IsExpiringWithTTL(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + inputIdentityClaims *IdentityClaims + inputThreshold time.Time + expectedResult bool + }{ + { + name: "nil identity", + inputIdentityClaims: nil, + inputThreshold: time.Now(), + expectedResult: false, + }, + { + name: "no expiry", + inputIdentityClaims: &IdentityClaims{}, + inputThreshold: time.Now(), + expectedResult: false, + }, + { + name: "not close to expiring", + inputIdentityClaims: &IdentityClaims{ + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + }, + }, + inputThreshold: time.Now(), + expectedResult: false, + }, + { + name: "close to expiring", + inputIdentityClaims: &IdentityClaims{ + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now()), + }, + }, + inputThreshold: time.Now().Add(1 * time.Minute), + expectedResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := tc.inputIdentityClaims.IsExpiringInThreshold(tc.inputThreshold) + must.Eq(t, tc.expectedResult, actualOutput) + }) + } +} + func TestIdentityClaimsNg_setExpiry(t *testing.T) { ci.Parallel(t) diff --git a/nomad/structs/node.go b/nomad/structs/node.go index bcc0fec39..a5a308e3f 100644 --- a/nomad/structs/node.go +++ b/nomad/structs/node.go @@ -537,3 +537,180 @@ func GenerateNodeIdentityClaims(node *Node, region string, ttl time.Duration) *I return claims } + +// NodeRegisterRequest is used by the Node.Register RPC endpoint to register a +// node as being a schedulable entity. +type NodeRegisterRequest struct { + Node *Node + NodeEvent *NodeEvent + + // CreateNodePool is used to indicate that the node's node pool should be + // created along with the node registration if it doesn't exist. + CreateNodePool bool + + WriteRequest +} + +// ShouldGenerateNodeIdentity compliments the functionality within +// AuthenticateNodeIdentityGenerator to determine whether a new node identity +// should be generated within the RPC handler. +func (n *NodeRegisterRequest) ShouldGenerateNodeIdentity( + authErr error, + now time.Time, + ttl time.Duration, +) bool { + + // In the event the error is because the node identity is expired, we should + // generate a new identity. Without this, a disconnected node would never be + // able to re-register. Any other error is not a reason to generate a new + // identity. + if authErr != nil { + return errors.Is(authErr, jwt.ErrExpired) + } + + // If an ACL token or client ID is set, a node is attempting to register for + // the first time, or is re-registering using its secret ID. In either case, + // we should generate a new identity. + if n.identity.ACLToken != nil || n.identity.ClientID != "" { + return true + } + + // If we have reached this point, we can assume that the request is using a + // node identity. 
+ claims := n.GetIdentity().GetClaims() + + // It is possible that the node has been restarted and had its configuration + // updated. In this case, we should generate a new identity for the node, so + // it reflects its new claims. + if n.Node.NodePool != claims.NodeIdentityClaims.NodePool || + n.Node.NodeClass != claims.NodeIdentityClaims.NodeClass || + n.Node.Datacenter != claims.NodeIdentityClaims.NodeDatacenter { + return true + } + + // The final check is to see if the node identity is expiring. + return claims.IsExpiring(now, ttl) +} + +// NodeUpdateStatusRequest is used for Node.UpdateStatus endpoint +// to update the status of a node. +type NodeUpdateStatusRequest struct { + NodeID string + Status string + + // IdentitySigningKeyID is the ID of the root key used to sign the node's + // identity. This is not provided by the client, but is set by the server, + // so that the value can be propagated through Raft. + IdentitySigningKeyID string + + // ForceIdentityRenewal is used to force the Nomad server to generate a new + // identity for the node. + ForceIdentityRenewal bool + + NodeEvent *NodeEvent + UpdatedAt int64 + WriteRequest +} + +// ShouldGenerateNodeIdentity determines whether the handler should generate a +// new node identity based on the caller identity information. +func (n *NodeUpdateStatusRequest) ShouldGenerateNodeIdentity( + now time.Time, + ttl time.Duration, +) bool { + + identity := n.GetIdentity() + + // If the client ID is set, we should generate a new identity as the node + // has authenticated using its secret ID. + if identity.ClientID != "" { + return true + } + + // Confirm we have a node identity and then check for forced renewal or + // expiration. + if identity.GetClaims().IsNode() { + if n.ForceIdentityRenewal { + return true + } + return n.GetIdentity().GetClaims().IsExpiring(now, ttl) + } + + // No other conditions should generate a new identity. In the case of the + // update status endpoint, this will likely be a Nomad server propagating + // that a node has missed its heartbeat. + return false +} + +// IdentitySigningErrorIsTerminal determines if the RPC handler should return an +// error because it failed to sign a newly generated node identity. +// +// This is because a client might be connected to a follower at the point the +// root keyring is rotated. If the client heartbeats right at that moment and +// before the follower decrypts the key (e.g., network latency to external KMS), +// we will mark the node as down. This is despite identity being valid and the +// likelihood it will get a new identity signed on the next heartbeat. +func (n *NodeUpdateStatusRequest) IdentitySigningErrorIsTerminal(now time.Time) bool { + + identity := n.GetIdentity() + + // If the client has authenticated using a secret ID, we can continue to let + // it do that, until we successfully generate a new identity. + if identity.ClientID != "" { + return false + } + + // If the identity is a node identity, we can check if it is expiring. This + // check is used to determine if the RPC handler should return an error, so + // we use a short threshold of 10 minutes. This is to ensure we don't return + // errors unless we absolutely have to. + // + // A threshold of 10 minutes more than covers another heartbeat on the + // largest Nomad clusters, which can reach ~5 minutes. 
+ if identity.GetClaims().IsNode() { + return n.GetIdentity().GetClaims().IsExpiringInThreshold(now.Add(10 * time.Minute)) + } + + // No other condition should result in the RPC handler returning an error + // because we failed to sign the node identity. No caller should be able to + // reach this point, as identity generation should be gated by + // ShouldGenerateNodeIdentity. + return false +} + +// NodeUpdateResponse is used to respond to a node update. The object is a +// shared response used by the Node.Register, Node.Deregister, +// Node.BatchDeregister, Node.UpdateStatus, and Node.Evaluate RPCs. +type NodeUpdateResponse struct { + HeartbeatTTL time.Duration + EvalIDs []string + EvalCreateIndex uint64 + NodeModifyIndex uint64 + + // Features informs clients what enterprise features are allowed + Features uint64 + + // LeaderRPCAddr is the RPC address of the current Raft Leader. If + // empty, the current Nomad Server is in the minority of a partition. + LeaderRPCAddr string + + // NumNodes is the number of Nomad nodes attached to this quorum of + // Nomad Servers at the time of the response. This value can + // fluctuate based on the health of the cluster between heartbeats. + NumNodes int32 + + // Servers is the full list of known Nomad servers in the local + // region. + Servers []*NodeServerInfo + + // SchedulingEligibility is used to inform clients what the server-side + // has for their scheduling status during heartbeats. + SchedulingEligibility string + + // SignedIdentity is the newly signed node identity that the server has + // generated. The node should check if this is set, and if so, update its + // state with the new identity. + SignedIdentity *string + + QueryMeta +} diff --git a/nomad/structs/node_test.go b/nomad/structs/node_test.go index 95970aaed..57e3912ca 100644 --- a/nomad/structs/node_test.go +++ b/nomad/structs/node_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v3/jwt" "github.com/hashicorp/nomad/ci" "github.com/shoenig/test/must" "github.com/stretchr/testify/require" @@ -279,3 +280,372 @@ func TestGenerateNodeIdentityClaims(t *testing.T) { must.NotNil(t, claims.NotBefore) must.NotNil(t, claims.Expiry) } + +func TestNodeRegisterRequest_ShouldGenerateNodeIdentity(t *testing.T) { + ci.Parallel(t) + + // Generate a stable mock node for testing. 
+ mockNode := MockNode() + + testCases := []struct { + name string + inputNodeRegisterRequest *NodeRegisterRequest + inputAuthErr error + inputTime time.Time + inputTTL time.Duration + expectedOutput bool + }{ + { + name: "expired node identity", + inputNodeRegisterRequest: &NodeRegisterRequest{}, + inputAuthErr: jwt.ErrExpired, + inputTime: time.Now(), + inputTTL: 10 * time.Minute, + expectedOutput: true, + }, + { + name: "first time node registration", + inputNodeRegisterRequest: &NodeRegisterRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ACLToken: AnonymousACLToken, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now(), + inputTTL: 10 * time.Minute, + expectedOutput: true, + }, + { + name: "registration using node secret ID", + inputNodeRegisterRequest: &NodeRegisterRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ClientID: "client-id-1", + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now(), + inputTTL: 10 * time.Minute, + expectedOutput: true, + }, + { + name: "modified node node pool configuration", + inputNodeRegisterRequest: &NodeRegisterRequest{ + Node: mockNode, + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: mockNode.ID, + NodePool: "new-pool", + NodeClass: mockNode.NodeClass, + NodeDatacenter: mockNode.Datacenter, + }, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(23 * time.Hour)), + }, + }, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "modified node class configuration", + inputNodeRegisterRequest: &NodeRegisterRequest{ + Node: mockNode, + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: mockNode.ID, + NodePool: mockNode.NodePool, + NodeClass: "new-class", + NodeDatacenter: mockNode.Datacenter, + }, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(23 * time.Hour)), + }, + }, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "modified node datacenter configuration", + inputNodeRegisterRequest: &NodeRegisterRequest{ + Node: mockNode, + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: mockNode.ID, + NodePool: mockNode.NodePool, + NodeClass: mockNode.NodeClass, + NodeDatacenter: "new-datacenter", + }, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(23 * time.Hour)), + }, + }, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "expiring node identity", + inputNodeRegisterRequest: &NodeRegisterRequest{ + Node: mockNode, + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: mockNode.ID, + NodePool: mockNode.NodePool, + NodeClass: mockNode.NodeClass, + NodeDatacenter: mockNode.Datacenter, + }, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(5 * time.Minute)), + }, + }, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "no generation", + inputNodeRegisterRequest: &NodeRegisterRequest{ + Node: mockNode, + 
WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{ + NodeID: mockNode.ID, + NodePool: mockNode.NodePool, + NodeClass: mockNode.NodeClass, + NodeDatacenter: mockNode.Datacenter, + }, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(24 * time.Hour)), + }, + }, + }, + }, + }, + inputAuthErr: nil, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := tc.inputNodeRegisterRequest.ShouldGenerateNodeIdentity( + tc.inputAuthErr, + tc.inputTime, + tc.inputTTL, + ) + must.Eq(t, tc.expectedOutput, actualOutput) + }) + } +} + +func TestNodeUpdateStatusRequest_ShouldGenerateNodeIdentity(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + inputNodeRegisterRequest *NodeUpdateStatusRequest + inputTime time.Time + inputTTL time.Duration + expectedOutput bool + }{ + { + name: "authenticated by node secret ID", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ClientID: "client-id-1", + }, + }, + }, + inputTime: time.Now(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "expiring node identity", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(1 * time.Hour)), + }, + }, + }, + }, + }, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "not expiring node identity", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(24 * time.Hour)), + }, + }, + }, + }, + }, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: false, + }, + { + name: "not expiring forced renewal node identity", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + ForceIdentityRenewal: true, + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(24 * time.Hour)), + }, + }, + }, + }, + }, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: true, + }, + { + name: "server authenticated request", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ACLToken: LeaderACLToken, + }, + }, + }, + inputTime: time.Now().UTC(), + inputTTL: 24 * time.Hour, + expectedOutput: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := tc.inputNodeRegisterRequest.ShouldGenerateNodeIdentity( + tc.inputTime, + tc.inputTTL, + ) + must.Eq(t, tc.expectedOutput, actualOutput) + }) + } +} +func TestNodeUpdateStatusRequest_IdentitySigningErrorIsTerminal(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + inputNodeRegisterRequest *NodeUpdateStatusRequest + inputTime time.Time + expectedOutput bool + }{ + { + name: "not close to expiring", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: 
WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC().Add(24 * time.Hour).UTC()), + }, + }, + }, + }, + }, + inputTime: time.Now().UTC(), + expectedOutput: false, + }, + { + name: "very close to expiring", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + Claims: &IdentityClaims{ + NodeIdentityClaims: &NodeIdentityClaims{}, + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().UTC()), + }, + }, + }, + }, + }, + inputTime: time.Now().Add(1 * time.Minute).UTC(), + expectedOutput: true, + }, + { + name: "server authenticated request", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ACLToken: LeaderACLToken, + }, + }, + }, + inputTime: time.Now().UTC(), + expectedOutput: false, + }, + { + name: "client secret ID authenticated request", + inputNodeRegisterRequest: &NodeUpdateStatusRequest{ + WriteRequest: WriteRequest{ + identity: &AuthenticatedIdentity{ + ClientID: "client-id", + }, + }, + }, + inputTime: time.Now().UTC(), + expectedOutput: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := tc.inputNodeRegisterRequest.IdentitySigningErrorIsTerminal(tc.inputTime) + must.Eq(t, tc.expectedOutput, actualOutput) + }) + } +} diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index b88c20109..b67c85b61 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -597,19 +597,6 @@ type WriteMeta struct { Index uint64 } -// NodeRegisterRequest is used for Node.Register endpoint -// to register a node as being a schedulable entity. -type NodeRegisterRequest struct { - Node *Node - NodeEvent *NodeEvent - - // CreateNodePool is used to indicate that the node's node pool should be - // create along with the node registration if it doesn't exist. - CreateNodePool bool - - WriteRequest -} - // NodeDeregisterRequest is used for Node.Deregister endpoint // to deregister a node as being a schedulable entity. type NodeDeregisterRequest struct { @@ -643,26 +630,6 @@ type NodeServerInfo struct { Datacenter string } -// NodeUpdateStatusRequest is used for Node.UpdateStatus endpoint -// to update the status of a node. -type NodeUpdateStatusRequest struct { - NodeID string - Status string - - // IdentitySigningKeyID is the ID of the root key used to sign the node's - // identity. This is not provided by the client, but is set by the server, - // so that the value can be propagated through Raft. - IdentitySigningKeyID string - - // ForceIdentityRenewal is used to force the Nomad server to generate a new - // identity for the node. - ForceIdentityRenewal bool - - NodeEvent *NodeEvent - UpdatedAt int64 - WriteRequest -} - // NodeUpdateDrainRequest is used for updating the drain strategy type NodeUpdateDrainRequest struct { NodeID string @@ -1506,36 +1473,6 @@ type JobValidateResponse struct { Warnings string } -// NodeUpdateResponse is used to respond to a node update -type NodeUpdateResponse struct { - HeartbeatTTL time.Duration - EvalIDs []string - EvalCreateIndex uint64 - NodeModifyIndex uint64 - - // Features informs clients what enterprise features are allowed - Features uint64 - - // LeaderRPCAddr is the RPC address of the current Raft Leader. If - // empty, the current Nomad Server is in the minority of a partition. 
- LeaderRPCAddr string - - // NumNodes is the number of Nomad nodes attached to this quorum of - // Nomad Servers at the time of the response. This value can - // fluctuate based on the health of the cluster between heartbeats. - NumNodes int32 - - // Servers is the full list of known Nomad servers in the local - // region. - Servers []*NodeServerInfo - - // SchedulingEligibility is used to inform clients what the server-side - // has for their scheduling status during heartbeats. - SchedulingEligibility string - - QueryMeta -} - // NodeDrainUpdateResponse is used to respond to a node drain update type NodeDrainUpdateResponse struct { NodeModifyIndex uint64 diff --git a/nomad/worker_test.go b/nomad/worker_test.go index 4bd18c7ea..eb00f5806 100644 --- a/nomad/worker_test.go +++ b/nomad/worker_test.go @@ -522,6 +522,7 @@ func TestWorker_SubmitPlanNormalizedAllocations(t *testing.T) { }) defer cleanupS1() testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) // Register node node := mock.Node() @@ -574,6 +575,7 @@ func TestWorker_SubmitPlan_MissingNodeRefresh(t *testing.T) { }) defer cleanupS1() testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) // Register node node := mock.Node() @@ -648,6 +650,7 @@ func TestWorker_UpdateEval(t *testing.T) { }) defer cleanupS1() testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) // Register node node := mock.Node() @@ -699,6 +702,7 @@ func TestWorker_CreateEval(t *testing.T) { }) defer cleanupS1() testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.Region()) // Register node node := mock.Node()
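
Reviewer note: the heartbeat-time behaviour encoded by the new ShouldGenerateNodeIdentity and IdentitySigningErrorIsTerminal helpers can be summarised with the small, self-contained sketch below. It is illustrative only: heartbeatIdentity, shouldRenew, and signingErrorIsTerminal are simplified stand-ins invented for this sketch (they are not Nomad's structs types or the actual RPC handler), and the half-TTL renewal window is an assumed placeholder for the IsExpiring check rather than its exact logic.

// Illustrative sketch only; names and thresholds here are assumptions, not
// Nomad's real API.
package main

import (
	"errors"
	"fmt"
	"time"
)

// heartbeatIdentity models the caller-identity fields the decision depends on.
type heartbeatIdentity struct {
	clientID     string    // set when the node authenticated with its secret ID
	isNode       bool      // true when the caller presented a node identity claim
	expiry       time.Time // expiry of the presented node identity, if any
	forceRenewal bool      // operator-requested renewal
}

// shouldRenew mirrors the renewal conditions: secret-ID authentication, forced
// renewal, or an identity approaching expiry (here: within half the identity
// TTL, an assumed stand-in for the IsExpiring check).
func shouldRenew(id heartbeatIdentity, now time.Time, ttl time.Duration) bool {
	if id.clientID != "" {
		return true
	}
	if id.isNode {
		if id.forceRenewal {
			return true
		}
		return id.expiry.Before(now.Add(ttl / 2))
	}
	return false
}

// signingErrorIsTerminal mirrors the "only fail the RPC when we must" rule: a
// signing failure is terminal only when the presented node identity is within
// a short (10 minute) window of expiring.
func signingErrorIsTerminal(id heartbeatIdentity, now time.Time) bool {
	if id.clientID != "" {
		return false
	}
	if id.isNode {
		return id.expiry.Before(now.Add(10 * time.Minute))
	}
	return false
}

func main() {
	now := time.Now()
	id := heartbeatIdentity{isNode: true, expiry: now.Add(30 * time.Minute)}

	// Simulate a follower that has not yet decrypted the rotated root key.
	signErr := errors.New("keyring not yet decrypted on this follower")

	if shouldRenew(id, now, 24*time.Hour) && signErr != nil {
		if signingErrorIsTerminal(id, now) {
			fmt.Println("heartbeat fails:", signErr)
		} else {
			fmt.Println("heartbeat succeeds; signing retried on the next heartbeat")
		}
	}
}

The point the sketch is meant to highlight is that a transient signing failure (for example, a follower that has not yet decrypted a rotated root key) degrades to retrying on the next heartbeat instead of marking the node down, unless the node's current identity is about to lapse.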