From 2abd72d433b819da8905d79b1f2c44031f131f58 Mon Sep 17 00:00:00 2001 From: James Rasell Date: Tue, 16 Sep 2025 15:15:39 +0100 Subject: [PATCH] http: Fix client identity renew call when node ID is in URI. (#26773) When calling the client identity renew API, it is possible the target node ID is provided by either the URI or within the request body. This change fixes a bug where all calls using a node_id query parameter would be reject as it failed to decode the empty request body. Co-authored-by: Tim Gross --- command/agent/node_identity_endpoint.go | 13 +- command/agent/node_identity_endpoint_test.go | 144 ++++++++++++++----- 2 files changed, 117 insertions(+), 40 deletions(-) diff --git a/command/agent/node_identity_endpoint.go b/command/agent/node_identity_endpoint.go index 3d5e599de..7c5d275a5 100644 --- a/command/agent/node_identity_endpoint.go +++ b/command/agent/node_identity_endpoint.go @@ -56,14 +56,17 @@ func (s *HTTPServer) NodeIdentityRenewRequest(resp http.ResponseWriter, req *htt // Build the request by decoding the request body which will contain the // node ID and the common parameters. args := structs.NodeIdentityRenewReq{} - - if err := decodeBody(req, &args); err != nil { - return nil, CodedError(http.StatusBadRequest, err.Error()) - } - s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions) parseNode(req, &args.NodeID) + // If the request body is not empty, it is likely the caller is using this + // to indicate the node ID. Decode it. + if req.Body != nil && req.Body != http.NoBody { + if err := decodeBody(req, &args); err != nil { + return nil, CodedError(http.StatusBadRequest, err.Error()) + } + } + // Determine the handler to use useLocalClient, useClientRPC, useServerRPC := s.rpcHandlerForNode(args.NodeID) diff --git a/command/agent/node_identity_endpoint_test.go b/command/agent/node_identity_endpoint_test.go index 438c193a8..9d34fa72d 100644 --- a/command/agent/node_identity_endpoint_test.go +++ b/command/agent/node_identity_endpoint_test.go @@ -18,11 +18,67 @@ import ( func TestHTTPServer_NodeIdentityGetRequest(t *testing.T) { ci.Parallel(t) - t.Run("200 ok", func(t *testing.T) { + t.Run("405 invalid method", func(t *testing.T) { httpTest(t, cb, func(s *TestAgent) { respW := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/v1/client/identity", nil) + badMethods := []string{ + http.MethodConnect, + http.MethodDelete, + http.MethodHead, + http.MethodOptions, + http.MethodPatch, + http.MethodPost, + http.MethodPut, + http.MethodTrace, + } + + for _, method := range badMethods { + req, err := http.NewRequest(method, "/v1/client/identity", nil) + must.NoError(t, err) + + _, err = s.Server.NodeIdentityGetRequest(respW, req) + must.ErrorContains(t, err, "Invalid method") + + codedErr, ok := err.(HTTPCodedError) + must.True(t, ok) + must.Eq(t, http.StatusMethodNotAllowed, codedErr.Code()) + must.Eq(t, ErrInvalidMethod, codedErr.Error()) + } + }) + }) + + t.Run("400 query param with unknown node", func(t *testing.T) { + httpTest(t, nil, func(s *TestAgent) { + + respW := httptest.NewRecorder() + + req, err := http.NewRequest( + http.MethodGet, + "/v1/client/identity?node_id="+uuid.Generate(), + nil, + ) + must.NoError(t, err) + + _, err = s.Server.NodeIdentityGetRequest(respW, req) + must.ErrorContains(t, err, "Unknown node") + }) + }) + + t.Run("200 ok query param", func(t *testing.T) { + + // Enable the client, so we have something to renew. + configFn := func(c *Config) { c.Client.Enabled = true } + + httpTest(t, configFn, func(s *TestAgent) { + + respW := httptest.NewRecorder() + + req, err := http.NewRequest( + http.MethodGet, + "/v1/client/identity?node_id="+s.client.NodeID(), + nil, + ) must.NoError(t, err) obj, err := s.Server.NodeIdentityGetRequest(respW, req) @@ -54,36 +110,6 @@ func TestHTTPServer_NodeIdentityGetRequest(t *testing.T) { }) }) }) - - t.Run("405 invalid method", func(t *testing.T) { - httpTest(t, cb, func(s *TestAgent) { - respW := httptest.NewRecorder() - - badMethods := []string{ - http.MethodConnect, - http.MethodDelete, - http.MethodHead, - http.MethodOptions, - http.MethodPatch, - http.MethodPost, - http.MethodPut, - http.MethodTrace, - } - - for _, method := range badMethods { - req, err := http.NewRequest(method, "/v1/client/identity", nil) - must.NoError(t, err) - - _, err = s.Server.NodeIdentityGetRequest(respW, req) - must.ErrorContains(t, err, "Invalid method") - - codedErr, ok := err.(HTTPCodedError) - must.True(t, ok) - must.Eq(t, http.StatusMethodNotAllowed, codedErr.Code()) - must.Eq(t, ErrInvalidMethod, codedErr.Error()) - } - }) - }) } func TestHTTPServer_NodeIdentityRenewRequest(t *testing.T) { @@ -118,10 +144,15 @@ func TestHTTPServer_NodeIdentityRenewRequest(t *testing.T) { }) }) - t.Run("400 no node", func(t *testing.T) { + t.Run("400 body with unknown node", func(t *testing.T) { httpTest(t, nil, func(s *TestAgent) { - reqObj := structs.NodeIdentityRenewReq{NodeID: uuid.Generate()} + reqObj := structs.NodeIdentityRenewReq{ + NodeID: uuid.Generate(), + QueryOptions: structs.QueryOptions{ + Region: s.config().Region, + }, + } buf := encodeReq(reqObj) @@ -135,7 +166,24 @@ func TestHTTPServer_NodeIdentityRenewRequest(t *testing.T) { }) }) - t.Run("200 ok", func(t *testing.T) { + t.Run("400 query param with unknown node", func(t *testing.T) { + httpTest(t, nil, func(s *TestAgent) { + + respW := httptest.NewRecorder() + + req, err := http.NewRequest( + http.MethodPost, + "/v1/client/identity/renew?node_id="+uuid.Generate(), + nil, + ) + must.NoError(t, err) + + _, err = s.Server.NodeIdentityRenewRequest(respW, req) + must.ErrorContains(t, err, "Unknown node") + }) + }) + + t.Run("200 ok body", func(t *testing.T) { // Enable the client, so we have something to renew. configFn := func(c *Config) { c.Client.Enabled = true } @@ -160,4 +208,30 @@ func TestHTTPServer_NodeIdentityRenewRequest(t *testing.T) { must.True(t, ok) }) }) + + t.Run("200 ok query param", func(t *testing.T) { + + // Enable the client, so we have something to renew. + configFn := func(c *Config) { c.Client.Enabled = true } + + httpTest(t, configFn, func(s *TestAgent) { + + testutil.WaitForClient(t, s.RPC, s.client.NodeID(), s.config().Region) + + respW := httptest.NewRecorder() + + req, err := http.NewRequest( + http.MethodPost, + "/v1/client/identity/renew?node_id="+s.client.NodeID(), + nil, + ) + must.NoError(t, err) + + obj, err := s.Server.NodeIdentityRenewRequest(respW, req) + must.NoError(t, err) + + _, ok := obj.(structs.NodeIdentityRenewResp) + must.True(t, ok) + }) + }) }