diff --git a/api/node_identity.go b/api/node_identity.go index 497ebcd23..d30a890d9 100644 --- a/api/node_identity.go +++ b/api/node_identity.go @@ -3,6 +3,18 @@ package api +// NodeIdentityGetRequest represents the request to retrieve the node identity +// claims for a specific node. +type NodeIdentityGetRequest struct { + NodeID string +} + +// NodeIdentityGetResponse represents the response containing the node identity +// claims. +type NodeIdentityGetResponse struct { + Claims map[string]any +} + type NodeIdentityRenewRequest struct { NodeID string } @@ -17,6 +29,34 @@ func (n *Nodes) Identity() *NodeIdentity { return &NodeIdentity{client: n.client} } +// Get retrieves the node identity claims for the node specified within the +// request object. +// +// The request uses query options to control the forwarding behavior of the +// request only. Parameters such as Filter, WaitTime, and WaitIndex are not used +// and ignored. +func (n *NodeIdentity) Get(req *NodeIdentityGetRequest, qo *QueryOptions) (*NodeIdentityGetResponse, error) { + + if qo == nil { + qo = &QueryOptions{} + } + + if qo.Params == nil { + qo.Params = make(map[string]string) + } + + if req.NodeID != "" { + qo.Params["node_id"] = req.NodeID + } + + var out NodeIdentityGetResponse + + if _, err := n.client.query("/v1/client/identity", &out, qo); err != nil { + return nil, err + } + return &out, nil +} + // Renew instructs the node to request a new identity from the server at its // next heartbeat. // diff --git a/api/node_identity_test.go b/api/node_identity_test.go index 56f682f15..1887177cf 100644 --- a/api/node_identity_test.go +++ b/api/node_identity_test.go @@ -10,6 +10,25 @@ import ( "github.com/shoenig/test/must" ) +func TestNodeIdentity_Get(t *testing.T) { + testutil.Parallel(t) + + configCallback := func(c *testutil.TestServerConfig) { c.DevMode = true } + testClient, testServer := makeClient(t, nil, configCallback) + defer testServer.Stop() + + nodeID := oneNodeFromNodeList(t, testClient.Nodes()).ID + + req := NodeIdentityGetRequest{ + NodeID: nodeID, + } + + resp, err := testClient.Nodes().Identity().Get(&req, nil) + must.NoError(t, err) + must.NotNil(t, resp) + must.MapLen(t, 9, resp.Claims) +} + func TestNodeIdentity_Renew(t *testing.T) { testutil.Parallel(t) diff --git a/client/node_identity_endpoint.go b/client/node_identity_endpoint.go index 8d3eb9289..fc498ae11 100644 --- a/client/node_identity_endpoint.go +++ b/client/node_identity_endpoint.go @@ -4,6 +4,9 @@ package client import ( + "fmt" + + "github.com/go-jose/go-jose/v3/jwt" "github.com/hashicorp/nomad/nomad/structs" ) @@ -16,6 +19,34 @@ func newNodeIdentityEndpoint(c *Client) *NodeIdentity { return n } +func (n *NodeIdentity) Get(args *structs.NodeIdentityGetReq, resp *structs.NodeIdentityGetResp) error { + + // Check for node read permissions. + if aclObj, err := n.c.ResolveToken(args.AuthToken); err != nil { + return err + } else if !aclObj.AllowNodeRead() { + return structs.ErrPermissionDenied + } + + // Parse the signed JWT token from the node identity and extract the claims + // into a map. This is done to avoid exposing the key material of the signed + // JWT token, but still results in all the claims which is perfect for + // debugging and introspection purposes. + parsedJWT, err := jwt.ParseSigned(n.c.nodeIdentityToken()) + if err != nil { + return fmt.Errorf("failed to parsed signed token: %w", err) + } + + claims := make(map[string]any) + + if err := parsedJWT.UnsafeClaimsWithoutVerification(&claims); err != nil { + return fmt.Errorf("failed to extract claims from token: %w", err) + } + + resp.Claims = claims + return nil +} + func (n *NodeIdentity) Renew(args *structs.NodeIdentityRenewReq, _ *structs.NodeIdentityRenewResp) error { // Check node write permissions. diff --git a/client/node_identity_endpoint_test.go b/client/node_identity_endpoint_test.go index cdbbd06e6..647570cbd 100644 --- a/client/node_identity_endpoint_test.go +++ b/client/node_identity_endpoint_test.go @@ -16,6 +16,139 @@ import ( "github.com/shoenig/test/must" ) +func TestNodeIdentity_Get(t *testing.T) { + ci.Parallel(t) + + // Create a test ACL server and client and perform our node identity get + // tests against it. + testACLServer, testServerToken, testACLServerCleanup := nomad.TestACLServer(t, nil) + t.Cleanup(func() { testACLServerCleanup() }) + testutil.WaitForLeader(t, testACLServer.RPC) + + testACLClient, testACLClientCleanup := TestClient(t, func(c *config.Config) { + c.ACLEnabled = true + c.Servers = []string{testACLServer.GetConfig().RPCAddr.String()} + }) + t.Cleanup(func() { _ = testACLClientCleanup() }) + testutil.WaitForClientStatusWithToken( + t, testACLServer.RPC, testACLClient.NodeID(), testACLClient.Region(), + structs.NodeStatusReady, testServerToken.SecretID, + ) + + t.Run("acl_denied", func(t *testing.T) { + must.ErrorContains( + t, + testACLClient.ClientRPC( + structs.NodeIdentityGetRPCMethod, + &structs.NodeIdentityGetReq{}, + &structs.NodeIdentityGetResp{}, + ), + structs.ErrPermissionDenied.Error(), + ) + }) + + t.Run("acl_valid", func(t *testing.T) { + + aclPolicy := mock.NodePolicy(acl.PolicyRead) + aclToken := mock.CreatePolicyAndToken(t, testACLServer.State(), 10, t.Name(), aclPolicy) + + req := structs.NodeIdentityGetReq{ + NodeID: testACLClient.NodeID(), + QueryOptions: structs.QueryOptions{ + AuthToken: aclToken.SecretID, + }, + } + + var resp structs.NodeIdentityGetResp + + must.NoError( + t, + testACLClient.ClientRPC( + structs.NodeIdentityGetRPCMethod, + &req, + &resp, + ), + ) + + must.MapLen(t, 10, resp.Claims) + + must.MapContainsKeys(t, resp.Claims, []string{ + "aud", + "exp", + "jti", + "nbf", + "sub", + "iat", + "nomad_node_class", + "nomad_node_datacenter", + "nomad_node_id", + "nomad_node_pool", + }) + + must.MapContainsValues(t, resp.Claims, []any{ + "nomadproject.io", + testACLClient.NodeID(), + testACLClient.Datacenter(), + testACLClient.Node().NodeClass, + testACLClient.Node().NodePool, + }) + }) + + // Create a test non-ACL server and client and perform our node identity get + // tests against it. + testServer, testServerCleanup := nomad.TestServer(t, nil) + t.Cleanup(func() { testServerCleanup() }) + testutil.WaitForLeader(t, testServer.RPC) + + testClient, testClientCleanup := TestClient(t, func(c *config.Config) { + c.Servers = []string{testServer.GetConfig().RPCAddr.String()} + }) + t.Cleanup(func() { _ = testClientCleanup() }) + testutil.WaitForClient(t, testServer.RPC, testClient.NodeID(), testClient.Region()) + + t.Run("non_acl_valid", func(t *testing.T) { + + req := structs.NodeIdentityGetReq{ + NodeID: testACLClient.NodeID(), + QueryOptions: structs.QueryOptions{}, + } + + var resp structs.NodeIdentityGetResp + + must.NoError( + t, + testClient.ClientRPC( + structs.NodeIdentityGetRPCMethod, + &req, + &resp, + ), + ) + + must.MapLen(t, 10, resp.Claims) + + must.MapContainsKeys(t, resp.Claims, []string{ + "aud", + "exp", + "jti", + "nbf", + "sub", + "iat", + "nomad_node_class", + "nomad_node_datacenter", + "nomad_node_id", + "nomad_node_pool", + }) + + must.MapContainsValues(t, resp.Claims, []any{ + "nomadproject.io", + testClient.NodeID(), + testClient.Datacenter(), + testClient.Node().NodeClass, + testClient.Node().NodePool, + }) + }) +} + func TestNodeIdentity_Renew(t *testing.T) { ci.Parallel(t) diff --git a/command/agent/http.go b/command/agent/http.go index 40857a18f..a66ea8385 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -451,6 +451,7 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) { s.mux.Handle("/v1/client/stats", wrapCORS(s.wrap(s.ClientStatsRequest))) s.mux.Handle("/v1/client/allocation/", wrapCORS(s.wrap(s.ClientAllocRequest))) s.mux.Handle("/v1/client/metadata", wrapCORS(s.wrap(s.NodeMetaRequest))) + s.mux.Handle("/v1/client/identity", wrapCORS(s.wrap(s.NodeIdentityGetRequest))) s.mux.Handle("/v1/client/identity/renew", wrapCORS(s.wrap(s.NodeIdentityRenewRequest))) s.mux.HandleFunc("/v1/agent/self", s.wrap(s.AgentSelfRequest)) diff --git a/command/agent/node_identity_endpoint.go b/command/agent/node_identity_endpoint.go index 4109c98e5..7c460067b 100644 --- a/command/agent/node_identity_endpoint.go +++ b/command/agent/node_identity_endpoint.go @@ -9,7 +9,44 @@ import ( "github.com/hashicorp/nomad/nomad/structs" ) -func (s *HTTPServer) NodeIdentityRenewRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) { +func (s *HTTPServer) NodeIdentityGetRequest(resp http.ResponseWriter, req *http.Request) (any, error) { + + if req.Method != http.MethodGet { + return nil, CodedError(http.StatusMethodNotAllowed, ErrInvalidMethod) + } + + // Build the request by parsing all common parameters and node id + args := structs.NodeIdentityGetReq{} + s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions) + parseNode(req, &args.NodeID) + + // Determine the handler to use + useLocalClient, useClientRPC, useServerRPC := s.rpcHandlerForNode(args.NodeID) + + // Make the RPC + var reply structs.NodeIdentityGetResp + var rpcErr error + if useLocalClient { + rpcErr = s.agent.Client().ClientRPC(structs.NodeIdentityGetRPCMethod, &args, &reply) + } else if useClientRPC { + rpcErr = s.agent.Client().RPC(structs.NodeIdentityGetRPCMethod, &args, &reply) + } else if useServerRPC { + rpcErr = s.agent.Server().RPC(structs.NodeIdentityGetRPCMethod, &args, &reply) + } else { + rpcErr = CodedError(http.StatusBadRequest, "no local Node and node_id not provided") + } + + if rpcErr != nil { + if structs.IsErrNoNodeConn(rpcErr) { + rpcErr = CodedError(http.StatusNotFound, rpcErr.Error()) + } + return nil, rpcErr + } + + return reply, nil +} + +func (s *HTTPServer) NodeIdentityRenewRequest(resp http.ResponseWriter, req *http.Request) (any, error) { // Build the request by parsing all common parameters and node id args := structs.NodeIdentityRenewReq{} s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions) @@ -28,12 +65,12 @@ func (s *HTTPServer) NodeIdentityRenewRequest(resp http.ResponseWriter, req *htt } else if useServerRPC { rpcErr = s.agent.Server().RPC(structs.NodeIdentityRenewRPCMethod, &args, &reply) } else { - rpcErr = CodedError(400, "no local Node and node_id not provided") + rpcErr = CodedError(http.StatusBadRequest, "no local Node and node_id not provided") } if rpcErr != nil { if structs.IsErrNoNodeConn(rpcErr) { - rpcErr = CodedError(404, rpcErr.Error()) + rpcErr = CodedError(http.StatusNotFound, rpcErr.Error()) } return nil, rpcErr diff --git a/command/agent/node_identity_endpoint_test.go b/command/agent/node_identity_endpoint_test.go new file mode 100644 index 000000000..c2c8270e5 --- /dev/null +++ b/command/agent/node_identity_endpoint_test.go @@ -0,0 +1,85 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package agent + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/shoenig/test/must" +) + +func TestHTTPServer_NodeIdentityGetRequest(t *testing.T) { + ci.Parallel(t) + + t.Run("200 ok", func(t *testing.T) { + httpTest(t, cb, func(s *TestAgent) { + respW := httptest.NewRecorder() + + req, err := http.NewRequest(http.MethodGet, "/v1/client/identity", nil) + must.NoError(t, err) + + obj, err := s.Server.NodeIdentityGetRequest(respW, req) + must.NoError(t, err) + must.Eq(t, http.StatusOK, respW.Code) + + resp, ok := obj.(structs.NodeIdentityGetResp) + must.True(t, ok) + + must.MapLen(t, 9, resp.Claims) + + must.MapContainsKeys(t, resp.Claims, []string{ + "aud", + "exp", + "jti", + "nbf", + "sub", + "iat", + "nomad_node_datacenter", + "nomad_node_id", + "nomad_node_pool", + }) + + must.MapContainsValues(t, resp.Claims, []any{ + "nomadproject.io", + s.client.NodeID(), + s.client.Datacenter(), + s.client.Node().NodePool, + }) + }) + }) + + t.Run("405 invalid method", func(t *testing.T) { + httpTest(t, cb, func(s *TestAgent) { + respW := httptest.NewRecorder() + + badMethods := []string{ + http.MethodConnect, + http.MethodDelete, + http.MethodHead, + http.MethodOptions, + http.MethodPatch, + http.MethodPost, + http.MethodPut, + http.MethodTrace, + } + + for _, method := range badMethods { + req, err := http.NewRequest(method, "/v1/client/identity", nil) + must.NoError(t, err) + + _, err = s.Server.NodeIdentityGetRequest(respW, req) + must.ErrorContains(t, err, "Invalid method") + + codedErr, ok := err.(HTTPCodedError) + must.True(t, ok) + must.Eq(t, http.StatusMethodNotAllowed, codedErr.Code()) + must.Eq(t, ErrInvalidMethod, codedErr.Error()) + } + }) + }) +} diff --git a/command/commands.go b/command/commands.go index 586fecbfb..14ec45b67 100644 --- a/command/commands.go +++ b/command/commands.go @@ -644,6 +644,11 @@ func Commands(metaPtr *Meta, agentUi cli.Ui) map[string]cli.CommandFactory { Meta: meta, }, nil }, + "node identity get": func() (cli.Command, error) { + return &NodeIdentityGetCommand{ + Meta: meta, + }, nil + }, "node identity renew": func() (cli.Command, error) { return &NodeIdentityRenewCommand{ Meta: meta, diff --git a/command/node_identity_get.go b/command/node_identity_get.go new file mode 100644 index 000000000..8657d509c --- /dev/null +++ b/command/node_identity_get.go @@ -0,0 +1,161 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package command + +import ( + "fmt" + "sort" + "strings" + "time" + + "github.com/hashicorp/nomad/api" + "github.com/posener/complete" +) + +type NodeIdentityGetCommand struct { + Meta + + // Command flags are stored below for use across the command. + json bool + tmpl string +} + +func (n *NodeIdentityGetCommand) Help() string { + helpText := ` +Usage: nomad node identity get [options] + + Get the identity claims for a node. This command only applies to client + agents. + + If ACLs are enabled, this command requires a token with the 'node:read' + capability. + +General Options: + + ` + generalOptionsUsage(usageOptsDefault|usageOptsNoNamespace) + ` + +Get Options: + + -json + Output the node identity claims in a JSON format. + + -t + Format and display the node identity claims using a Go template. +` + + return strings.TrimSpace(helpText) +} + +func (n *NodeIdentityGetCommand) Synopsis() string { return "Detail a node's identity claims" } + +func (n *NodeIdentityGetCommand) Name() string { return "node identity get" } + +func (n *NodeIdentityGetCommand) Run(args []string) int { + + flags := n.Meta.FlagSet(n.Name(), FlagSetClient) + flags.BoolVar(&n.json, "json", false, "") + flags.StringVar(&n.tmpl, "t", "", "") + flags.Usage = func() { n.Ui.Output(n.Help()) } + + if err := flags.Parse(args); err != nil { + return 1 + } + args = flags.Args() + + if len(args) != 1 { + n.Ui.Error("This command takes one argument: ") + n.Ui.Error(commandErrorText(n)) + return 1 + } + + // Get the HTTP client + client, err := n.Meta.Client() + if err != nil { + n.Ui.Error(fmt.Sprintf("Error initializing client: %s", err)) + return 1 + } + + nodeID, err := lookupNodeID(client.Nodes(), args[0]) + if err != nil { + n.Ui.Error(err.Error()) + return 1 + } + + req := api.NodeIdentityGetRequest{NodeID: nodeID} + + resp, err := client.Nodes().Identity().Get(&req, nil) + if err != nil { + n.Ui.Error(fmt.Sprintf("Error requesting node identity: %s", err)) + return 1 + } + + return n.ouputClaims(resp.Claims) +} + +func (n *NodeIdentityGetCommand) ouputClaims(claims map[string]any) int { + + // If the user has requested JSON output or a template, format the claims + // accordingly. + if n.json || len(n.tmpl) > 0 { + out, err := Format(n.json, n.tmpl, claims) + if err != nil { + n.Ui.Error(err.Error()) + return 1 + } + + n.Ui.Output(out) + return 0 + } + + var genericClaims, nomadClaims []string + + // Iterate through the claims and separate the generic and Nomad-specific + // claims. This will allow us to group them in the output. + for key := range claims { + if strings.HasPrefix(key, "nomad") { + nomadClaims = append(nomadClaims, key) + } else { + genericClaims = append(genericClaims, key) + } + } + + // Sort the claims alphabetically for consistent output. + sort.Strings(genericClaims) + sort.Strings(nomadClaims) + + output := make([]string, len(genericClaims)+len(nomadClaims)+1) + output[0] = "Claim Key|Claim Value" + + for i, key := range genericClaims { + + // The generic claims currently include timestamps which come to the CLI + // as float64 values. We need to correctly convert these into a + // human-readable format. All other claims are string values. + switch valT := claims[key].(type) { + case float64: + output[i+1] = fmt.Sprintf("%s | %v", key, formatTime(time.Unix(int64(valT), 0))) + default: + output[i+1] = fmt.Sprintf("%s | %s", key, valT) + } + } + + for i, key := range nomadClaims { + output[i+1+len(genericClaims)] = fmt.Sprintf("%s | %s", key, claims[key]) + } + + n.Ui.Output(formatList(output)) + return 0 +} + +func (n *NodeIdentityGetCommand) AutocompleteFlags() complete.Flags { + return mergeAutocompleteFlags(n.Meta.AutocompleteFlags(FlagSetClient), + complete.Flags{ + "-json": complete.PredictNothing, + "-t": complete.PredictAnything, + }) +} + +func (n *NodeIdentityGetCommand) AutocompleteArgs() complete.Predictor { + return nodePredictor(n.Client, nil) +} diff --git a/command/node_identity_get_test.go b/command/node_identity_get_test.go new file mode 100644 index 000000000..b31b2ce55 --- /dev/null +++ b/command/node_identity_get_test.go @@ -0,0 +1,82 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package command + +import ( + "encoding/json" + "testing" + + "github.com/hashicorp/cli" + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/testutil" + "github.com/shoenig/test/must" +) + +func TestNodeIdentityGetCommand_Implements(t *testing.T) { + ci.Parallel(t) + var _ cli.Command = &NodeIntroCreateCommand{} +} + +func TestNodeIdentityGetCommand_Run(t *testing.T) { + ci.Parallel(t) + + srv, _, url := testServer(t, true, nil) + defer srv.Shutdown() + + // Wait until our test node is ready. + testutil.WaitForClient( + t, + srv.Agent.Client().RPC, + srv.Agent.Client().NodeID(), + srv.Agent.Client().Region(), + ) + + ui := cli.NewMockUi() + + cmd := &NodeIdentityGetCommand{ + Meta: Meta{ + Ui: ui, + flagAddress: url, + }, + } + + t.Run("with no command argument", func(t *testing.T) { + t.Cleanup(func() { resetUI(ui) }) + + must.One(t, cmd.Run([]string{})) + must.StrContains(t, ui.ErrorWriter.String(), "This command takes one argument") + }) + + t.Run("node not found", func(t *testing.T) { + t.Cleanup(func() { resetUI(ui) }) + + must.One(t, cmd.Run([]string{"--address=" + url, "f4b2f0a1-7898-ad4e-de19-d9fc9a773961"})) + must.StrContains(t, ui.ErrorWriter.String(), "No node(s) with prefix or id") + }) + + t.Run("standard output", func(t *testing.T) { + t.Cleanup(func() { resetUI(ui) }) + + must.Zero(t, cmd.Run([]string{"--address=" + url, srv.Agent.Client().NodeID()})) + must.StrContains(t, ui.OutputWriter.String(), "Claim Key") + must.StrContains(t, ui.OutputWriter.String(), "Claim Value") + }) + + t.Run("json output", func(t *testing.T) { + t.Cleanup(func() { resetUI(ui) }) + + must.Zero(t, cmd.Run([]string{"--address=" + url, "-json", srv.Agent.Client().NodeID()})) + + var resp map[string]any + must.NoError(t, json.Unmarshal(ui.OutputWriter.Bytes(), &resp)) + must.MapContainsKey(t, resp, "nomad_node_id") + }) + + t.Run("template output", func(t *testing.T) { + t.Cleanup(func() { resetUI(ui) }) + + must.Zero(t, cmd.Run([]string{"--address=" + url, "-t", "{{ .nomad_node_id }}", srv.Agent.Client().NodeID()})) + must.StrContains(t, ui.OutputWriter.String(), srv.Agent.Client().NodeID()) + }) +} diff --git a/nomad/client_identity_endpoint.go b/nomad/client_identity_endpoint.go index 78235d546..bcc271d1e 100644 --- a/nomad/client_identity_endpoint.go +++ b/nomad/client_identity_endpoint.go @@ -20,6 +20,32 @@ func newNodeIdentityEndpoint(srv *Server) *NodeIdentity { } } +func (n *NodeIdentity) Get(args *structs.NodeIdentityGetReq, reply *structs.NodeIdentityGetResp) error { + + // Prevent infinite loop between the leader and the follower with the target + // node connection. + args.QueryOptions.AllowStale = true + + authErr := n.srv.Authenticate(nil, args) + if done, err := n.srv.forward(structs.NodeIdentityGetRPCMethod, args, args, reply); done { + return err + } + n.srv.MeasureRPCRate("client_identity", structs.RateMetricRead, args) + if authErr != nil { + return structs.ErrPermissionDenied + } + defer metrics.MeasureSince([]string{"nomad", "client_identity", "get"}, time.Now()) + + // Check node read permissions + if aclObj, err := n.srv.ResolveACL(args); err != nil { + return err + } else if !aclObj.AllowNodeRead() { + return structs.ErrPermissionDenied + } + + return n.srv.forwardClientRPC(structs.NodeIdentityGetRPCMethod, args.NodeID, args, reply) +} + func (n *NodeIdentity) Renew(args *structs.NodeIdentityRenewReq, reply *structs.NodeIdentityRenewResp) error { // Prevent infinite loop between the leader and the follower with the target diff --git a/nomad/client_identity_endpoint_test.go b/nomad/client_identity_endpoint_test.go index a40289b88..35327efdf 100644 --- a/nomad/client_identity_endpoint_test.go +++ b/nomad/client_identity_endpoint_test.go @@ -14,6 +14,76 @@ import ( "github.com/shoenig/test/must" ) +func TestNodeIdentity_Get_Forward(t *testing.T) { + ci.Parallel(t) + + servers := []*Server{} + for range 3 { + s, cleanup := TestServer(t, func(c *Config) { + c.BootstrapExpect = 3 + c.NumSchedulers = 0 + }) + t.Cleanup(cleanup) + servers = append(servers, s) + } + + TestJoin(t, servers...) + leader := testutil.WaitForLeaders(t, servers[0].RPC, servers[1].RPC, servers[2].RPC) + + followers := []string{} + for _, s := range servers { + if addr := s.config.RPCAddr.String(); addr != leader { + followers = append(followers, addr) + } + } + t.Logf("leader=%s followers=%q", leader, followers) + + clients := make([]*client.Client, 4) + + for i := range 4 { + c, cleanup := client.TestClient(t, func(c *config.Config) { + c.Servers = followers + }) + t.Cleanup(func() { _ = cleanup() }) + clients[i] = c + } + for _, c := range clients { + testutil.WaitForClient(t, servers[0].RPC, c.NodeID(), c.Region()) + } + + agentRPCs := []func(string, any, any) error{} + nodeIDs := make([]string, 0, len(clients)) + + // Build list of agents and node IDs + for _, s := range servers { + agentRPCs = append(agentRPCs, s.RPC) + } + + for _, c := range clients { + agentRPCs = append(agentRPCs, c.RPC) + nodeIDs = append(nodeIDs, c.NodeID()) + } + + // Iterate through all the agent RPCs to ensure that the renew RPC will + // succeed, no matter which agent we connect to. + for _, agentRPC := range agentRPCs { + for _, nodeID := range nodeIDs { + args := &structs.NodeIdentityGetReq{ + NodeID: nodeID, + QueryOptions: structs.QueryOptions{ + Region: clients[0].Region(), + }, + } + must.NoError(t, + agentRPC(structs.NodeIdentityGetRPCMethod, + args, + &structs.NodeIdentityGetResp{}, + ), + ) + } + } +} + func TestNodeIdentity_Renew_Forward(t *testing.T) { ci.Parallel(t) diff --git a/nomad/structs/node.go b/nomad/structs/node.go index d8d33eff2..e491c9867 100644 --- a/nomad/structs/node.go +++ b/nomad/structs/node.go @@ -768,6 +768,13 @@ type NodeUpdateResponse struct { } const ( + // NodeIdentityGetRPCMethod is the RPC method for retrieving a client's + // currently stored node identity. + // + // Args: NodeIdentityGetReq + // Reply: NodeIdentityGetResp + NodeIdentityGetRPCMethod = "NodeIdentity.Get" + // NodeIdentityRenewRPCMethod is the RPC method for instructing a client to // forcibly request a renewal of its node identity at the next heartbeat. // @@ -776,6 +783,21 @@ const ( NodeIdentityRenewRPCMethod = "NodeIdentity.Renew" ) +type NodeIdentityGetReq struct { + NodeID string + + // This is a client RPC, so we must use query options which allow us to set + // AllowStale=true. + QueryOptions +} + +type NodeIdentityGetResp struct { + + // Claims contains the node identity claims that are currently being + // utilized by the client. + Claims map[string]any +} + // NodeIdentityRenewReq is used to instruct the Nomad server to renew the client // identity at its next heartbeat regardless of whether it is close to // expiration. diff --git a/nomad/structs/node_test.go b/nomad/structs/node_test.go index f1f2b777e..8d4ae9bc3 100644 --- a/nomad/structs/node_test.go +++ b/nomad/structs/node_test.go @@ -825,6 +825,13 @@ func TestNodeUpdateStatusRequest_IdentitySigningErrorIsTerminal(t *testing.T) { } } +func TestNodeIdentityGetReq_QueryOptions(t *testing.T) { + ci.Parallel(t) + + req := &NodeIdentityGetReq{} + must.True(t, req.IsRead()) +} + func Test_DefaultNodeIntroductionConfig(t *testing.T) { ci.Parallel(t)