diff --git a/nomad/auth/auth.go b/nomad/auth/auth.go index af68a08ea..d55722f49 100644 --- a/nomad/auth/auth.go +++ b/nomad/auth/auth.go @@ -534,16 +534,18 @@ func (s *Authenticator) VerifyClaim(token string) (*structs.IdentityClaims, erro return claims, nil } + // If the claims are for a node identity, we can return them directly once + // we have verified the claim. In the happy path, we could read the node out + // of state and verify that it is found. However, it is possible the node + // has been garbage collected, and if we failed on that check, the node + // would not be able to register again without manual intervention. if claims.IsNode() { - if err := s.verifyNodeIdentityClaim(claims); err != nil { - return nil, err - } return claims, nil } - // Node introduction claims are a special case where we don't verify them - // against the state store, since they are used to introduce a node that - // does not yet exist. + // Node introduction claims are a case where we don't verify them against + // the state store, since they are used to introduce a node that does not + // yet exist. if claims.IsNodeIntroduction() { return claims, nil } @@ -572,23 +574,6 @@ func (s *Authenticator) verifyWorkloadIdentityClaim(claims *structs.IdentityClai return nil } -func (s *Authenticator) verifyNodeIdentityClaim(claims *structs.IdentityClaims) error { - - snap, err := s.getState().Snapshot() - if err != nil { - return err - } - node, err := snap.NodeByID(nil, claims.NodeIdentityClaims.NodeID) - if err != nil { - return err - } - if node == nil { - return errors.New("node does not exist") - } - - return nil -} - func (s *Authenticator) resolveClaims(claims *structs.IdentityClaims) (*acl.ACL, error) { // Nomad node identity claims currently map to a client ACL. 
If we open this diff --git a/nomad/auth/auth_test.go b/nomad/auth/auth_test.go index ff958771b..140f4696e 100644 --- a/nomad/auth/auth_test.go +++ b/nomad/auth/auth_test.go @@ -385,7 +385,7 @@ func TestAuthenticateDefault(t *testing.T) { }, }, { - name: "mTLS and ACLs with invalid node identity", + name: "mTLS and ACLs with node identity no state", testFn: func(t *testing.T, store *state.StateStore) { node := mock.Node() @@ -400,7 +400,7 @@ func TestAuthenticateDefault(t *testing.T) { args.AuthToken = token var ctx *testContext - must.ErrorContains(t, auth.Authenticate(ctx, args), "node does not exist") + must.NoError(t, auth.Authenticate(ctx, args)) }, }, } @@ -981,7 +981,7 @@ func TestAuthenticateClientOnly(t *testing.T) { }, }, { - name: "with mTLS and ACLs with server cert and invalid node identity", + name: "with mTLS and ACLs with server cert and node identity no state", testFn: func(t *testing.T, store *state.StateStore, node *structs.Node) { ctx := newTestContext(t, "server.global.nomad", "192.168.1.1") @@ -999,8 +999,8 @@ func TestAuthenticateClientOnly(t *testing.T) { args.AuthToken = token aclObj, err := auth.AuthenticateClientOnly(ctx, args) - must.Error(t, err) - must.Nil(t, aclObj) + must.NoError(t, err) + must.NotNil(t, aclObj) }, }, } @@ -1566,45 +1566,6 @@ func TestResolveClaims(t *testing.T) { } -func TestAuthenticator_verifyNodeIdentityClaim(t *testing.T) { - ci.Parallel(t) - - // Create our base test objects including a node that can be used in the - // tests. 
- testAuthenticator := testDefaultAuthenticator(t) - - mockNode := mock.Node() - must.NoError(t, testAuthenticator.getState().UpsertNode(structs.MsgTypeTestSetup, 100, mockNode)) - - testCases := []struct { - name string - inputClaims *structs.IdentityClaims - expectedOutput error - }{ - { - name: "node does not exist", - inputClaims: structs.GenerateNodeIdentityClaims(mock.Node(), "global", 1*time.Hour), - expectedOutput: errors.New("node does not exist"), - }, - { - name: "verified node claims", - inputClaims: structs.GenerateNodeIdentityClaims(mockNode, "global", 1*time.Hour), - expectedOutput: nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - actualOutput := testAuthenticator.verifyNodeIdentityClaim(tc.inputClaims) - if tc.expectedOutput == nil { - must.NoError(t, actualOutput) - } else { - must.EqError(t, actualOutput, tc.expectedOutput.Error()) - } - }) - } -} - func testStateStore(t *testing.T) *state.StateStore { sconfig := &state.StateStoreConfig{ Logger: testlog.HCLogger(t), diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index 2b9ee503c..b99dc6ffb 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -327,6 +327,19 @@ func (n *Node) newRegistrationAllowed( claims.NodeIntroductionIdentityClaims.NodeName == args.Node.Name) } + // In a less happy path, a node could be making a registration request after + // its state object has been removed via garbage collection. In this case, + // it will be using its existing node identity, and we can perform a check + // on the claim here. + // + // It's possible while it was down, the node's configuration changed, so we + // only check the node ID in this case. Later in the RPC handler, we will + // check if a new identity needs to be generated based on changed + // configuration. 
+ if claims.IsNode() { + claimsMatch = claims.NodeIdentityClaims.NodeID == args.Node.ID + } + // If there was no authentication error and the identity claims match the // node's claims, the registration is allowed to proceed. if authErr == nil && claimsMatch { @@ -349,10 +362,12 @@ func (n *Node) newRegistrationAllowed( "node_name", args.Node.Name, } - // If the node used a node introduction identity, add the claims for - // comparison to the logging pairs. + // If the node used a node introduction identity or node identity, add the + // claims for comparison to the logging pairs. if claims.IsNodeIntroduction() { loggingPairs = append(loggingPairs, claims.NodeIntroductionIdentityClaims.LoggingPairs()...) + } else if claims.IsNode() { + loggingPairs = append(loggingPairs, claims.NodeIdentityClaims.LoggingPairs()...) } // Make some effort to log a message that indicates why the node is failing diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index 1e36877ed..9f1f42fed 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -5088,6 +5088,58 @@ func TestNode_newRegistrationAllowed(t *testing.T) { testServer.auth.AuthenticateNodeIdentityGenerator(nodeEndpoint.ctx, &nodeRegisterRequest), )) }) + + t.Run("enforcement warn node identity claims match", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementWarn + + claims := structs.GenerateNodeIdentityClaims( + mockNode, + testServer.Region(), + 10*time.Minute, + ) + + signedJWT, _, err := testServer.encrypter.SignClaims(claims) + must.NoError(t, err) + + nodeRegisterRequest := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{ + AuthToken: signedJWT, + }, + } + + require.True(t, nodeEndpoint.newRegistrationAllowed( + &nodeRegisterRequest, + testServer.auth.AuthenticateNodeIdentityGenerator(nodeEndpoint.ctx, &nodeRegisterRequest), + )) + }) + + t.Run("enforcement strict node identity 
claims match", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementStrict + + claims := structs.GenerateNodeIdentityClaims( + mockNode, + testServer.Region(), + 10*time.Minute, + ) + + signedJWT, _, err := testServer.encrypter.SignClaims(claims) + must.NoError(t, err) + + nodeRegisterRequest := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{ + AuthToken: signedJWT, + }, + } + + require.True(t, nodeEndpoint.newRegistrationAllowed( + &nodeRegisterRequest, + testServer.auth.AuthenticateNodeIdentityGenerator(nodeEndpoint.ctx, &nodeRegisterRequest), + )) + }) } // TestNode_List_PaginationFiltering asserts that API pagination and filtering diff --git a/nomad/structs/node.go b/nomad/structs/node.go index 383a5aacf..e3dd6af21 100644 --- a/nomad/structs/node.go +++ b/nomad/structs/node.go @@ -539,6 +539,17 @@ func GenerateNodeIdentityClaims(node *Node, region string, ttl time.Duration) *I return claims } +// LoggingPairs returns a set of key-value pairs that can be used for logging +// purposes. +func (n *NodeIdentityClaims) LoggingPairs() []any { + return []any{ + "claim_node_id", n.NodeID, + "claim_node_pool", n.NodePool, + "claim_node_class", n.NodeClass, + "claim_node_datacenter", n.NodeDatacenter, + } +} + // NodeRegisterRequest is used by the Node.Register RPC endpoint to register a // node as being a schedulable entity. 
type NodeRegisterRequest struct { diff --git a/nomad/structs/node_test.go b/nomad/structs/node_test.go index d8138a61e..0cf7d4b2a 100644 --- a/nomad/structs/node_test.go +++ b/nomad/structs/node_test.go @@ -282,6 +282,31 @@ func TestGenerateNodeIdentityClaims(t *testing.T) { must.NotNil(t, claims.Expiry) } +func TestNodeIdentityClaims_LoggingPairs(t *testing.T) { + ci.Parallel(t) + + claims := GenerateNodeIdentityClaims( + &Node{ + ID: "node-id-1", + NodePool: "custom-pool", + NodeClass: "custom-class", + Datacenter: "euw2", + }, + "euw", + 10*time.Minute, + ) + must.Eq( + t, + []any{ + "claim_node_id", "node-id-1", + "claim_node_pool", "custom-pool", + "claim_node_class", "custom-class", + "claim_node_datacenter", "euw2", + }, + claims.NodeIdentityClaims.LoggingPairs(), + ) +} + func TestNodeRegisterRequest_Validate(t *testing.T) { ci.Parallel(t)