diff --git a/command/agent/node_identity_endpoint.go b/command/agent/node_identity_endpoint.go index 7c460067b..3d5e599de 100644 --- a/command/agent/node_identity_endpoint.go +++ b/command/agent/node_identity_endpoint.go @@ -47,8 +47,20 @@ func (s *HTTPServer) NodeIdentityGetRequest(resp http.ResponseWriter, req *http. } func (s *HTTPServer) NodeIdentityRenewRequest(resp http.ResponseWriter, req *http.Request) (any, error) { - // Build the request by parsing all common parameters and node id + + // Only allow POST and PUT methods. + if !(req.Method == http.MethodPut || req.Method == http.MethodPost) { + return nil, CodedError(http.StatusMethodNotAllowed, ErrInvalidMethod) + } + + // Build the request by decoding the request body which will contain the + // node ID and the common parameters. args := structs.NodeIdentityRenewReq{} + + if err := decodeBody(req, &args); err != nil { + return nil, CodedError(http.StatusBadRequest, err.Error()) + } + s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions) parseNode(req, &args.NodeID) diff --git a/command/agent/node_identity_endpoint_test.go b/command/agent/node_identity_endpoint_test.go index c2c8270e5..438c193a8 100644 --- a/command/agent/node_identity_endpoint_test.go +++ b/command/agent/node_identity_endpoint_test.go @@ -9,7 +9,9 @@ import ( "testing" "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/testutil" "github.com/shoenig/test/must" ) @@ -83,3 +85,79 @@ func TestHTTPServer_NodeIdentityGetRequest(t *testing.T) { }) }) } + +func TestHTTPServer_NodeIdentityRenewRequest(t *testing.T) { + ci.Parallel(t) + + t.Run("405 invalid method", func(t *testing.T) { + httpTest(t, nil, func(s *TestAgent) { + respW := httptest.NewRecorder() + + badMethods := []string{ + http.MethodConnect, + http.MethodDelete, + http.MethodGet, + http.MethodHead, + http.MethodOptions, + http.MethodPatch, + http.MethodTrace, + } + + for _, method := range badMethods { + req, err := http.NewRequest(method, "/v1/client/identity/renew", nil) + must.NoError(t, err) + + _, err = s.Server.NodeIdentityRenewRequest(respW, req) + must.ErrorContains(t, err, "Invalid method") + + codedErr, ok := err.(HTTPCodedError) + must.True(t, ok) + must.Eq(t, http.StatusMethodNotAllowed, codedErr.Code()) + must.Eq(t, ErrInvalidMethod, codedErr.Error()) + } + }) + }) + + t.Run("400 no node", func(t *testing.T) { + httpTest(t, nil, func(s *TestAgent) { + + reqObj := structs.NodeIdentityRenewReq{NodeID: uuid.Generate()} + + buf := encodeReq(reqObj) + + respW := httptest.NewRecorder() + + req, err := http.NewRequest(http.MethodPost, "/v1/client/identity/renew", buf) + must.NoError(t, err) + + _, err = s.Server.NodeIdentityRenewRequest(respW, req) + must.ErrorContains(t, err, "Unknown node") + }) + }) + + t.Run("200 ok", func(t *testing.T) { + + // Enable the client, so we have something to renew. + configFn := func(c *Config) { c.Client.Enabled = true } + + httpTest(t, configFn, func(s *TestAgent) { + + testutil.WaitForClient(t, s.RPC, s.client.NodeID(), s.config().Region) + + reqObj := structs.NodeIdentityRenewReq{NodeID: s.client.NodeID()} + + buf := encodeReq(reqObj) + + respW := httptest.NewRecorder() + + req, err := http.NewRequest(http.MethodPost, "/v1/client/identity/renew", buf) + must.NoError(t, err) + + obj, err := s.Server.NodeIdentityRenewRequest(respW, req) + must.NoError(t, err) + + _, ok := obj.(structs.NodeIdentityRenewResp) + must.True(t, ok) + }) + }) +}