mirror of
https://github.com/kemko/nomad.git
synced 2026-01-01 16:05:42 +03:00
remove most static RPC handlers (#15451)
Nomad server components that aren't in the `nomad` package like the deployment watcher and volume watcher need to make RPC calls but can't import the Server struct to do so because it creates a circular reference. These components have a "shim" object that gets populated to pass a "static" handler that has no RPC context. Most RPC handlers are never used in this way, but during server setup we were constructing a set of static handlers for most RPC endpoints anyways. This is slightly wasteful but also confusing to developers who end up being encouraged to just copy what was being done for previous RPCs. This changeset includes the following refactorings: * Remove the static handlers field on the server * Instead construct just the specific static handlers we need to pass into the deployment watcher and volume watcher. * Remove the unnecessary static handler from heartbeater * Update various tests to avoid needing the static endpoints and have them use a endpoint constructed on the spot. Follow-up work will examine whether we can remove the RPCs from deployment watcher and volume watcher entirely, falling back to raft applies like node drainer does currently.
This commit is contained in:
@@ -427,7 +427,9 @@ func TestClientCSI_NodeForControllerPlugin(t *testing.T) {
|
||||
|
||||
plugin, err := state.CSIPluginByID(ws, "minnie")
|
||||
require.NoError(t, err)
|
||||
nodeIDs, err := srv.staticEndpoints.ClientCSI.clientIDsForController(plugin.ID)
|
||||
|
||||
clientCSI := NewClientCSIEndpoint(srv)
|
||||
nodeIDs, err := clientCSI.clientIDsForController(plugin.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(nodeIDs))
|
||||
// only node1 has both the controller and a recent Nomad version
|
||||
|
||||
@@ -1950,7 +1950,7 @@ func TestCSI_RPCVolumeAndPluginLookup(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// has controller
|
||||
c := srv.staticEndpoints.CSIVolume
|
||||
c := NewCSIVolumeEndpoint(srv, nil)
|
||||
plugin, vol, err := c.volAndPluginLookup(structs.DefaultNamespace, id0)
|
||||
require.NotNil(t, plugin)
|
||||
require.NotNil(t, vol)
|
||||
|
||||
@@ -168,7 +168,8 @@ func (h *nodeHeartbeater) invalidateHeartbeat(id string) {
|
||||
req.Status = structs.NodeStatusDisconnected
|
||||
}
|
||||
var resp structs.NodeUpdateResponse
|
||||
if err := h.staticEndpoints.Node.UpdateStatus(&req, &resp); err != nil {
|
||||
|
||||
if err := h.RPC("Node.UpdateStatus", &req, &resp); err != nil {
|
||||
h.logger.Error("update node status failed", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -911,7 +911,9 @@ func TestNode_UpdateStatus_ServiceRegistrations(t *testing.T) {
|
||||
}
|
||||
|
||||
var reply structs.NodeUpdateResponse
|
||||
require.NoError(t, testServer.staticEndpoints.Node.UpdateStatus(&args, &reply))
|
||||
|
||||
nodeEndpoint := NewNodeEndpoint(testServer, nil)
|
||||
require.NoError(t, nodeEndpoint.UpdateStatus(&args, &reply))
|
||||
|
||||
// Query our state, to ensure the node service registrations have been
|
||||
// removed.
|
||||
@@ -2643,7 +2645,7 @@ func TestClientEndpoint_BatchUpdate(t *testing.T) {
|
||||
|
||||
// Call to do the batch update
|
||||
bf := structs.NewBatchFuture()
|
||||
endpoint := s1.staticEndpoints.Node
|
||||
endpoint := NewNodeEndpoint(s1, nil)
|
||||
endpoint.batchUpdate(bf, []*structs.Allocation{clientAlloc}, nil)
|
||||
if err := bf.Wait(); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
@@ -2780,7 +2782,8 @@ func TestClientEndpoint_CreateNodeEvals(t *testing.T) {
|
||||
idx++
|
||||
|
||||
// Create some evaluations
|
||||
ids, index, err := s1.staticEndpoints.Node.createNodeEvals(node, 1)
|
||||
nodeEndpoint := NewNodeEndpoint(s1, nil)
|
||||
ids, index, err := nodeEndpoint.createNodeEvals(node, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
@@ -2877,7 +2880,8 @@ func TestClientEndpoint_CreateNodeEvals_MultipleNSes(t *testing.T) {
|
||||
idx++
|
||||
|
||||
// Create some evaluations
|
||||
evalIDs, index, err := s1.staticEndpoints.Node.createNodeEvals(node, 1)
|
||||
nodeEndpoint := NewNodeEndpoint(s1, nil)
|
||||
evalIDs, index, err := nodeEndpoint.createNodeEvals(node, 1)
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, index)
|
||||
require.Len(t, evalIDs, 2)
|
||||
@@ -2937,7 +2941,8 @@ func TestClientEndpoint_CreateNodeEvals_MultipleDCes(t *testing.T) {
|
||||
idx++
|
||||
|
||||
// Create evaluations
|
||||
evalIDs, index, err := s1.staticEndpoints.Node.createNodeEvals(node, 1)
|
||||
nodeEndpoint := NewNodeEndpoint(s1, nil)
|
||||
evalIDs, index, err := nodeEndpoint.createNodeEvals(node, 1)
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, index)
|
||||
require.Len(t, evalIDs, 1)
|
||||
|
||||
@@ -153,10 +153,6 @@ type Server struct {
|
||||
rpcTLS *tls.Config
|
||||
rpcCancel context.CancelFunc
|
||||
|
||||
// staticEndpoints is the set of static endpoints that can be reused across
|
||||
// all RPC connections
|
||||
staticEndpoints endpoints
|
||||
|
||||
// streamingRpcs is the registry holding our streaming RPC handlers.
|
||||
streamingRpcs *structs.StreamingRpcRegistry
|
||||
|
||||
@@ -1082,8 +1078,8 @@ func (s *Server) setupDeploymentWatcher() error {
|
||||
s.deploymentWatcher = deploymentwatcher.NewDeploymentsWatcher(
|
||||
s.logger,
|
||||
raftShim,
|
||||
s.staticEndpoints.Deployment,
|
||||
s.staticEndpoints.Job,
|
||||
NewDeploymentEndpoint(s, nil),
|
||||
NewJobEndpoints(s, nil),
|
||||
s.config.DeploymentQueryRateLimit,
|
||||
deploymentwatcher.CrossDeploymentUpdateBatchDuration,
|
||||
)
|
||||
@@ -1094,7 +1090,7 @@ func (s *Server) setupDeploymentWatcher() error {
|
||||
// setupVolumeWatcher creates a volume watcher that sends CSI RPCs
|
||||
func (s *Server) setupVolumeWatcher() error {
|
||||
s.volumeWatcher = volumewatcher.NewVolumesWatcher(
|
||||
s.logger, s.staticEndpoints.CSIVolume, s.getLeaderAcl())
|
||||
s.logger, NewCSIVolumeEndpoint(s, nil), s.getLeaderAcl())
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1197,69 +1193,44 @@ func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupRpcServer is used to populate an RPC server with endpoints
|
||||
// setupRpcServer is used to populate an RPC server with endpoints. This gets
|
||||
// called at startup but also once for every new RPC connection so that RPC
|
||||
// handlers can have per-connection context.
|
||||
func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) error {
|
||||
|
||||
// Add the static endpoints to the RPC server. These are the RPC handlers
|
||||
// that get used when component on the server is making an internal RPC
|
||||
// call, so we only need them to be initialized once and they have no RPC
|
||||
// context.
|
||||
if s.staticEndpoints.Status == nil {
|
||||
// note: Alloc, Plan have only dynamic endpoints
|
||||
s.staticEndpoints.ACL = NewACLEndpoint(s, nil)
|
||||
s.staticEndpoints.CSIVolume = NewCSIVolumeEndpoint(s, nil)
|
||||
s.staticEndpoints.CSIPlugin = NewCSIPluginEndpoint(s, nil)
|
||||
s.staticEndpoints.Deployment = NewDeploymentEndpoint(s, nil)
|
||||
s.staticEndpoints.Job = NewJobEndpoints(s, nil)
|
||||
s.staticEndpoints.Keyring = NewKeyringEndpoint(s, nil, s.encrypter)
|
||||
s.staticEndpoints.Namespace = NewNamespaceEndpoint(s, nil)
|
||||
s.staticEndpoints.Node = NewNodeEndpoint(s, nil)
|
||||
s.staticEndpoints.Operator = NewOperatorEndpoint(s, nil)
|
||||
s.staticEndpoints.Operator.register() // register the streaming RPCs
|
||||
s.staticEndpoints.Periodic = NewPeriodicEndpoint(s, nil)
|
||||
s.staticEndpoints.Region = NewRegionEndpoint(s, nil)
|
||||
s.staticEndpoints.Scaling = NewScalingEndpoint(s, nil)
|
||||
s.staticEndpoints.Search = NewSearchEndpoint(s, nil)
|
||||
s.staticEndpoints.ServiceRegistration = NewServiceRegistrationEndpoint(s, nil)
|
||||
s.staticEndpoints.Status = NewStatusEndpoint(s, nil)
|
||||
s.staticEndpoints.System = NewSystemEndpoint(s, nil)
|
||||
s.staticEndpoints.Variables = NewVariablesEndpoint(s, nil, s.encrypter)
|
||||
// The endpoints are client RPCs and don't include a connection
|
||||
// context. They also need to be registered as streaming endpoints in their
|
||||
// register() methods.
|
||||
|
||||
s.staticEndpoints.Enterprise = NewEnterpriseEndpoints(s, nil)
|
||||
clientAllocs := NewClientAllocationsEndpoint(s)
|
||||
clientAllocs.register()
|
||||
_ = server.Register(clientAllocs)
|
||||
|
||||
// These endpoints don't have a dynamic counterpart, so they'll need to
|
||||
// be re-registered per connection as well (see below)
|
||||
fsEndpoint := NewFileSystemEndpoint(s)
|
||||
fsEndpoint.register()
|
||||
_ = server.Register(fsEndpoint)
|
||||
|
||||
// Client endpoints
|
||||
s.staticEndpoints.ClientStats = NewClientStatsEndpoint(s)
|
||||
s.staticEndpoints.ClientAllocations = NewClientAllocationsEndpoint(s)
|
||||
s.staticEndpoints.ClientAllocations.register() // register the streaming RPCs
|
||||
s.staticEndpoints.ClientCSI = NewClientCSIEndpoint(s)
|
||||
agentEndpoint := NewAgentEndpoint(s)
|
||||
agentEndpoint.register()
|
||||
_ = server.Register(agentEndpoint)
|
||||
|
||||
// Streaming endpoints
|
||||
s.staticEndpoints.FileSystem = NewFileSystemEndpoint(s)
|
||||
s.staticEndpoints.FileSystem.register()
|
||||
// Event is a streaming-only endpoint so we don't want to register it as a
|
||||
// normal RPC
|
||||
eventEndpoint := NewEventEndpoint(s)
|
||||
eventEndpoint.register()
|
||||
|
||||
s.staticEndpoints.Agent = NewAgentEndpoint(s)
|
||||
s.staticEndpoints.Agent.register()
|
||||
// Operator takes a RPC context but also has a streaming RPC that needs to
|
||||
// be registered
|
||||
operatorEndpoint := NewOperatorEndpoint(s, ctx)
|
||||
operatorEndpoint.register()
|
||||
_ = server.Register(NewOperatorEndpoint(s, ctx))
|
||||
|
||||
s.staticEndpoints.Event = NewEventEndpoint(s)
|
||||
s.staticEndpoints.Event.register()
|
||||
}
|
||||
// These endpoints are client RPCs and don't include a connection context
|
||||
_ = server.Register(NewClientCSIEndpoint(s))
|
||||
_ = server.Register(NewClientStatsEndpoint(s))
|
||||
|
||||
// If an endpoint has any non-streaming RPCs doesn't have an RPC context,
|
||||
// we'll register the static handler here instead of creating a new dynamic
|
||||
// endpoint on each connection.
|
||||
|
||||
server.Register(s.staticEndpoints.ClientStats)
|
||||
server.Register(s.staticEndpoints.ClientAllocations)
|
||||
server.Register(s.staticEndpoints.ClientCSI)
|
||||
server.Register(s.staticEndpoints.FileSystem)
|
||||
server.Register(s.staticEndpoints.Agent)
|
||||
|
||||
// Dynamic endpoints are endpoints that include the connection context and
|
||||
// are created on each connection. Register all the dynamic endpoints with
|
||||
// the RPC server.
|
||||
// All other endpoints include the connection context and don't need to be
|
||||
// registered as streaming endpoints
|
||||
|
||||
_ = server.Register(NewACLEndpoint(s, ctx))
|
||||
_ = server.Register(NewAllocEndpoint(s, ctx))
|
||||
@@ -1271,7 +1242,6 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) error {
|
||||
_ = server.Register(NewKeyringEndpoint(s, ctx, s.encrypter))
|
||||
_ = server.Register(NewNamespaceEndpoint(s, ctx))
|
||||
_ = server.Register(NewNodeEndpoint(s, ctx))
|
||||
_ = server.Register(NewOperatorEndpoint(s, ctx))
|
||||
_ = server.Register(NewPeriodicEndpoint(s, ctx))
|
||||
_ = server.Register(NewPlanEndpoint(s, ctx))
|
||||
_ = server.Register(NewRegionEndpoint(s, ctx))
|
||||
|
||||
@@ -106,8 +106,10 @@ func TestVariablesEndpoint_auth(t *testing.T) {
|
||||
err = store.UpsertACLTokens(structs.MsgTypeTestSetup, 1150, []*structs.ACLToken{aclToken})
|
||||
must.NoError(t, err)
|
||||
|
||||
variablesRPC := NewVariablesEndpoint(srv, nil, srv.encrypter)
|
||||
|
||||
t.Run("terminal alloc should be denied", func(t *testing.T) {
|
||||
_, _, err = srv.staticEndpoints.Variables.handleMixedAuthEndpoint(
|
||||
_, _, err := variablesRPC.handleMixedAuthEndpoint(
|
||||
structs.QueryOptions{AuthToken: idToken, Namespace: ns}, acl.PolicyList,
|
||||
fmt.Sprintf("nomad/jobs/%s/web/web", jobID))
|
||||
must.EqError(t, err, structs.ErrPermissionDenied.Error())
|
||||
@@ -119,7 +121,7 @@ func TestVariablesEndpoint_auth(t *testing.T) {
|
||||
structs.MsgTypeTestSetup, 1200, []*structs.Allocation{alloc1}))
|
||||
|
||||
t.Run("wrong namespace should be denied", func(t *testing.T) {
|
||||
_, _, err = srv.staticEndpoints.Variables.handleMixedAuthEndpoint(
|
||||
_, _, err := variablesRPC.handleMixedAuthEndpoint(
|
||||
structs.QueryOptions{AuthToken: idToken, Namespace: structs.DefaultNamespace}, acl.PolicyList,
|
||||
fmt.Sprintf("nomad/jobs/%s/web/web", jobID))
|
||||
must.EqError(t, err, structs.ErrPermissionDenied.Error())
|
||||
@@ -349,7 +351,7 @@ func TestVariablesEndpoint_auth(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, _, err := srv.staticEndpoints.Variables.handleMixedAuthEndpoint(
|
||||
_, _, err := variablesRPC.handleMixedAuthEndpoint(
|
||||
structs.QueryOptions{AuthToken: tc.token, Namespace: ns}, tc.cap, tc.path)
|
||||
if tc.expectedErr == nil {
|
||||
must.NoError(t, err)
|
||||
|
||||
Reference in New Issue
Block a user