remove most static RPC handlers (#15451)

Nomad server components that aren't in the `nomad` package like the deployment
watcher and volume watcher need to make RPC calls but can't import the Server
struct to do so because it creates a circular reference. These components have a
"shim" object that gets populated to pass a "static" handler that has no RPC
context.

Most RPC handlers are never used in this way, but during server setup we were
constructing a set of static handlers for most RPC endpoints anyways. This is
slightly wasteful but also confusing to developers who end up being encouraged
to just copy what was being done for previous RPCs.

This changeset includes the following refactorings:
* Remove the static handlers field on the server
* Instead construct just the specific static handlers we need to pass into the
  deployment watcher and volume watcher.
* Remove the unnecessary static handler from heartbeater
* Update various tests to avoid needing the static endpoints and have them use a
  endpoint constructed on the spot.

Follow-up work will examine whether we can remove the RPCs from deployment
watcher and volume watcher entirely, falling back to raft applies like node
drainer does currently.
This commit is contained in:
Tim Gross
2022-12-02 10:12:05 -05:00
committed by GitHub
parent bfcb93c434
commit b352225eed
6 changed files with 53 additions and 73 deletions

View File

@@ -427,7 +427,9 @@ func TestClientCSI_NodeForControllerPlugin(t *testing.T) {
plugin, err := state.CSIPluginByID(ws, "minnie")
require.NoError(t, err)
nodeIDs, err := srv.staticEndpoints.ClientCSI.clientIDsForController(plugin.ID)
clientCSI := NewClientCSIEndpoint(srv)
nodeIDs, err := clientCSI.clientIDsForController(plugin.ID)
require.NoError(t, err)
require.Equal(t, 1, len(nodeIDs))
// only node1 has both the controller and a recent Nomad version

View File

@@ -1950,7 +1950,7 @@ func TestCSI_RPCVolumeAndPluginLookup(t *testing.T) {
require.NoError(t, err)
// has controller
c := srv.staticEndpoints.CSIVolume
c := NewCSIVolumeEndpoint(srv, nil)
plugin, vol, err := c.volAndPluginLookup(structs.DefaultNamespace, id0)
require.NotNil(t, plugin)
require.NotNil(t, vol)

View File

@@ -168,7 +168,8 @@ func (h *nodeHeartbeater) invalidateHeartbeat(id string) {
req.Status = structs.NodeStatusDisconnected
}
var resp structs.NodeUpdateResponse
if err := h.staticEndpoints.Node.UpdateStatus(&req, &resp); err != nil {
if err := h.RPC("Node.UpdateStatus", &req, &resp); err != nil {
h.logger.Error("update node status failed", "error", err)
}
}

View File

@@ -911,7 +911,9 @@ func TestNode_UpdateStatus_ServiceRegistrations(t *testing.T) {
}
var reply structs.NodeUpdateResponse
require.NoError(t, testServer.staticEndpoints.Node.UpdateStatus(&args, &reply))
nodeEndpoint := NewNodeEndpoint(testServer, nil)
require.NoError(t, nodeEndpoint.UpdateStatus(&args, &reply))
// Query our state, to ensure the node service registrations have been
// removed.
@@ -2643,7 +2645,7 @@ func TestClientEndpoint_BatchUpdate(t *testing.T) {
// Call to do the batch update
bf := structs.NewBatchFuture()
endpoint := s1.staticEndpoints.Node
endpoint := NewNodeEndpoint(s1, nil)
endpoint.batchUpdate(bf, []*structs.Allocation{clientAlloc}, nil)
if err := bf.Wait(); err != nil {
t.Fatalf("err: %v", err)
@@ -2780,7 +2782,8 @@ func TestClientEndpoint_CreateNodeEvals(t *testing.T) {
idx++
// Create some evaluations
ids, index, err := s1.staticEndpoints.Node.createNodeEvals(node, 1)
nodeEndpoint := NewNodeEndpoint(s1, nil)
ids, index, err := nodeEndpoint.createNodeEvals(node, 1)
if err != nil {
t.Fatalf("err: %v", err)
}
@@ -2877,7 +2880,8 @@ func TestClientEndpoint_CreateNodeEvals_MultipleNSes(t *testing.T) {
idx++
// Create some evaluations
evalIDs, index, err := s1.staticEndpoints.Node.createNodeEvals(node, 1)
nodeEndpoint := NewNodeEndpoint(s1, nil)
evalIDs, index, err := nodeEndpoint.createNodeEvals(node, 1)
require.NoError(t, err)
require.NotZero(t, index)
require.Len(t, evalIDs, 2)
@@ -2937,7 +2941,8 @@ func TestClientEndpoint_CreateNodeEvals_MultipleDCes(t *testing.T) {
idx++
// Create evaluations
evalIDs, index, err := s1.staticEndpoints.Node.createNodeEvals(node, 1)
nodeEndpoint := NewNodeEndpoint(s1, nil)
evalIDs, index, err := nodeEndpoint.createNodeEvals(node, 1)
require.NoError(t, err)
require.NotZero(t, index)
require.Len(t, evalIDs, 1)

View File

@@ -153,10 +153,6 @@ type Server struct {
rpcTLS *tls.Config
rpcCancel context.CancelFunc
// staticEndpoints is the set of static endpoints that can be reused across
// all RPC connections
staticEndpoints endpoints
// streamingRpcs is the registry holding our streaming RPC handlers.
streamingRpcs *structs.StreamingRpcRegistry
@@ -1082,8 +1078,8 @@ func (s *Server) setupDeploymentWatcher() error {
s.deploymentWatcher = deploymentwatcher.NewDeploymentsWatcher(
s.logger,
raftShim,
s.staticEndpoints.Deployment,
s.staticEndpoints.Job,
NewDeploymentEndpoint(s, nil),
NewJobEndpoints(s, nil),
s.config.DeploymentQueryRateLimit,
deploymentwatcher.CrossDeploymentUpdateBatchDuration,
)
@@ -1094,7 +1090,7 @@ func (s *Server) setupDeploymentWatcher() error {
// setupVolumeWatcher creates a volume watcher that sends CSI RPCs
func (s *Server) setupVolumeWatcher() error {
s.volumeWatcher = volumewatcher.NewVolumesWatcher(
s.logger, s.staticEndpoints.CSIVolume, s.getLeaderAcl())
s.logger, NewCSIVolumeEndpoint(s, nil), s.getLeaderAcl())
return nil
}
@@ -1197,69 +1193,44 @@ func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error {
return nil
}
// setupRpcServer is used to populate an RPC server with endpoints
// setupRpcServer is used to populate an RPC server with endpoints. This gets
// called at startup but also once for every new RPC connection so that RPC
// handlers can have per-connection context.
func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) error {
// Add the static endpoints to the RPC server. These are the RPC handlers
// that get used when component on the server is making an internal RPC
// call, so we only need them to be initialized once and they have no RPC
// context.
if s.staticEndpoints.Status == nil {
// note: Alloc, Plan have only dynamic endpoints
s.staticEndpoints.ACL = NewACLEndpoint(s, nil)
s.staticEndpoints.CSIVolume = NewCSIVolumeEndpoint(s, nil)
s.staticEndpoints.CSIPlugin = NewCSIPluginEndpoint(s, nil)
s.staticEndpoints.Deployment = NewDeploymentEndpoint(s, nil)
s.staticEndpoints.Job = NewJobEndpoints(s, nil)
s.staticEndpoints.Keyring = NewKeyringEndpoint(s, nil, s.encrypter)
s.staticEndpoints.Namespace = NewNamespaceEndpoint(s, nil)
s.staticEndpoints.Node = NewNodeEndpoint(s, nil)
s.staticEndpoints.Operator = NewOperatorEndpoint(s, nil)
s.staticEndpoints.Operator.register() // register the streaming RPCs
s.staticEndpoints.Periodic = NewPeriodicEndpoint(s, nil)
s.staticEndpoints.Region = NewRegionEndpoint(s, nil)
s.staticEndpoints.Scaling = NewScalingEndpoint(s, nil)
s.staticEndpoints.Search = NewSearchEndpoint(s, nil)
s.staticEndpoints.ServiceRegistration = NewServiceRegistrationEndpoint(s, nil)
s.staticEndpoints.Status = NewStatusEndpoint(s, nil)
s.staticEndpoints.System = NewSystemEndpoint(s, nil)
s.staticEndpoints.Variables = NewVariablesEndpoint(s, nil, s.encrypter)
// The endpoints are client RPCs and don't include a connection
// context. They also need to be registered as streaming endpoints in their
// register() methods.
s.staticEndpoints.Enterprise = NewEnterpriseEndpoints(s, nil)
clientAllocs := NewClientAllocationsEndpoint(s)
clientAllocs.register()
_ = server.Register(clientAllocs)
// These endpoints don't have a dynamic counterpart, so they'll need to
// be re-registered per connection as well (see below)
fsEndpoint := NewFileSystemEndpoint(s)
fsEndpoint.register()
_ = server.Register(fsEndpoint)
// Client endpoints
s.staticEndpoints.ClientStats = NewClientStatsEndpoint(s)
s.staticEndpoints.ClientAllocations = NewClientAllocationsEndpoint(s)
s.staticEndpoints.ClientAllocations.register() // register the streaming RPCs
s.staticEndpoints.ClientCSI = NewClientCSIEndpoint(s)
agentEndpoint := NewAgentEndpoint(s)
agentEndpoint.register()
_ = server.Register(agentEndpoint)
// Streaming endpoints
s.staticEndpoints.FileSystem = NewFileSystemEndpoint(s)
s.staticEndpoints.FileSystem.register()
// Event is a streaming-only endpoint so we don't want to register it as a
// normal RPC
eventEndpoint := NewEventEndpoint(s)
eventEndpoint.register()
s.staticEndpoints.Agent = NewAgentEndpoint(s)
s.staticEndpoints.Agent.register()
// Operator takes a RPC context but also has a streaming RPC that needs to
// be registered
operatorEndpoint := NewOperatorEndpoint(s, ctx)
operatorEndpoint.register()
_ = server.Register(NewOperatorEndpoint(s, ctx))
s.staticEndpoints.Event = NewEventEndpoint(s)
s.staticEndpoints.Event.register()
}
// These endpoints are client RPCs and don't include a connection context
_ = server.Register(NewClientCSIEndpoint(s))
_ = server.Register(NewClientStatsEndpoint(s))
// If an endpoint has any non-streaming RPCs doesn't have an RPC context,
// we'll register the static handler here instead of creating a new dynamic
// endpoint on each connection.
server.Register(s.staticEndpoints.ClientStats)
server.Register(s.staticEndpoints.ClientAllocations)
server.Register(s.staticEndpoints.ClientCSI)
server.Register(s.staticEndpoints.FileSystem)
server.Register(s.staticEndpoints.Agent)
// Dynamic endpoints are endpoints that include the connection context and
// are created on each connection. Register all the dynamic endpoints with
// the RPC server.
// All other endpoints include the connection context and don't need to be
// registered as streaming endpoints
_ = server.Register(NewACLEndpoint(s, ctx))
_ = server.Register(NewAllocEndpoint(s, ctx))
@@ -1271,7 +1242,6 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) error {
_ = server.Register(NewKeyringEndpoint(s, ctx, s.encrypter))
_ = server.Register(NewNamespaceEndpoint(s, ctx))
_ = server.Register(NewNodeEndpoint(s, ctx))
_ = server.Register(NewOperatorEndpoint(s, ctx))
_ = server.Register(NewPeriodicEndpoint(s, ctx))
_ = server.Register(NewPlanEndpoint(s, ctx))
_ = server.Register(NewRegionEndpoint(s, ctx))

View File

@@ -106,8 +106,10 @@ func TestVariablesEndpoint_auth(t *testing.T) {
err = store.UpsertACLTokens(structs.MsgTypeTestSetup, 1150, []*structs.ACLToken{aclToken})
must.NoError(t, err)
variablesRPC := NewVariablesEndpoint(srv, nil, srv.encrypter)
t.Run("terminal alloc should be denied", func(t *testing.T) {
_, _, err = srv.staticEndpoints.Variables.handleMixedAuthEndpoint(
_, _, err := variablesRPC.handleMixedAuthEndpoint(
structs.QueryOptions{AuthToken: idToken, Namespace: ns}, acl.PolicyList,
fmt.Sprintf("nomad/jobs/%s/web/web", jobID))
must.EqError(t, err, structs.ErrPermissionDenied.Error())
@@ -119,7 +121,7 @@ func TestVariablesEndpoint_auth(t *testing.T) {
structs.MsgTypeTestSetup, 1200, []*structs.Allocation{alloc1}))
t.Run("wrong namespace should be denied", func(t *testing.T) {
_, _, err = srv.staticEndpoints.Variables.handleMixedAuthEndpoint(
_, _, err := variablesRPC.handleMixedAuthEndpoint(
structs.QueryOptions{AuthToken: idToken, Namespace: structs.DefaultNamespace}, acl.PolicyList,
fmt.Sprintf("nomad/jobs/%s/web/web", jobID))
must.EqError(t, err, structs.ErrPermissionDenied.Error())
@@ -349,7 +351,7 @@ func TestVariablesEndpoint_auth(t *testing.T) {
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
_, _, err := srv.staticEndpoints.Variables.handleMixedAuthEndpoint(
_, _, err := variablesRPC.handleMixedAuthEndpoint(
structs.QueryOptions{AuthToken: tc.token, Namespace: ns}, tc.cap, tc.path)
if tc.expectedErr == nil {
must.NoError(t, err)