diff --git a/nomad/client_csi_endpoint_test.go b/nomad/client_csi_endpoint_test.go index 34dc7d6a5..382ec2d05 100644 --- a/nomad/client_csi_endpoint_test.go +++ b/nomad/client_csi_endpoint_test.go @@ -427,7 +427,9 @@ func TestClientCSI_NodeForControllerPlugin(t *testing.T) { plugin, err := state.CSIPluginByID(ws, "minnie") require.NoError(t, err) - nodeIDs, err := srv.staticEndpoints.ClientCSI.clientIDsForController(plugin.ID) + + clientCSI := NewClientCSIEndpoint(srv) + nodeIDs, err := clientCSI.clientIDsForController(plugin.ID) require.NoError(t, err) require.Equal(t, 1, len(nodeIDs)) // only node1 has both the controller and a recent Nomad version diff --git a/nomad/csi_endpoint_test.go b/nomad/csi_endpoint_test.go index 4a4decf86..60b7368a0 100644 --- a/nomad/csi_endpoint_test.go +++ b/nomad/csi_endpoint_test.go @@ -1950,7 +1950,7 @@ func TestCSI_RPCVolumeAndPluginLookup(t *testing.T) { require.NoError(t, err) // has controller - c := srv.staticEndpoints.CSIVolume + c := NewCSIVolumeEndpoint(srv, nil) plugin, vol, err := c.volAndPluginLookup(structs.DefaultNamespace, id0) require.NotNil(t, plugin) require.NotNil(t, vol) diff --git a/nomad/heartbeat.go b/nomad/heartbeat.go index e76b39541..a7c5c7f7c 100644 --- a/nomad/heartbeat.go +++ b/nomad/heartbeat.go @@ -168,7 +168,8 @@ func (h *nodeHeartbeater) invalidateHeartbeat(id string) { req.Status = structs.NodeStatusDisconnected } var resp structs.NodeUpdateResponse - if err := h.staticEndpoints.Node.UpdateStatus(&req, &resp); err != nil { + + if err := h.RPC("Node.UpdateStatus", &req, &resp); err != nil { h.logger.Error("update node status failed", "error", err) } } diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index 082773ec3..32aa966b5 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -911,7 +911,9 @@ func TestNode_UpdateStatus_ServiceRegistrations(t *testing.T) { } var reply structs.NodeUpdateResponse - require.NoError(t, testServer.staticEndpoints.Node.UpdateStatus(&args, &reply)) + + nodeEndpoint := NewNodeEndpoint(testServer, nil) + require.NoError(t, nodeEndpoint.UpdateStatus(&args, &reply)) // Query our state, to ensure the node service registrations have been // removed. @@ -2643,7 +2645,7 @@ func TestClientEndpoint_BatchUpdate(t *testing.T) { // Call to do the batch update bf := structs.NewBatchFuture() - endpoint := s1.staticEndpoints.Node + endpoint := NewNodeEndpoint(s1, nil) endpoint.batchUpdate(bf, []*structs.Allocation{clientAlloc}, nil) if err := bf.Wait(); err != nil { t.Fatalf("err: %v", err) @@ -2780,7 +2782,8 @@ func TestClientEndpoint_CreateNodeEvals(t *testing.T) { idx++ // Create some evaluations - ids, index, err := s1.staticEndpoints.Node.createNodeEvals(node, 1) + nodeEndpoint := NewNodeEndpoint(s1, nil) + ids, index, err := nodeEndpoint.createNodeEvals(node, 1) if err != nil { t.Fatalf("err: %v", err) } @@ -2877,7 +2880,8 @@ func TestClientEndpoint_CreateNodeEvals_MultipleNSes(t *testing.T) { idx++ // Create some evaluations - evalIDs, index, err := s1.staticEndpoints.Node.createNodeEvals(node, 1) + nodeEndpoint := NewNodeEndpoint(s1, nil) + evalIDs, index, err := nodeEndpoint.createNodeEvals(node, 1) require.NoError(t, err) require.NotZero(t, index) require.Len(t, evalIDs, 2) @@ -2937,7 +2941,8 @@ func TestClientEndpoint_CreateNodeEvals_MultipleDCes(t *testing.T) { idx++ // Create evaluations - evalIDs, index, err := s1.staticEndpoints.Node.createNodeEvals(node, 1) + nodeEndpoint := NewNodeEndpoint(s1, nil) + evalIDs, index, err := nodeEndpoint.createNodeEvals(node, 1) require.NoError(t, err) require.NotZero(t, index) require.Len(t, evalIDs, 1) diff --git a/nomad/server.go b/nomad/server.go index bf897080a..d9f8679fc 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -153,10 +153,6 @@ type Server struct { rpcTLS *tls.Config rpcCancel context.CancelFunc - // staticEndpoints is the set of static endpoints that can be reused across - // all RPC connections - staticEndpoints endpoints - // streamingRpcs is the registry holding our streaming RPC handlers. streamingRpcs *structs.StreamingRpcRegistry @@ -1082,8 +1078,8 @@ func (s *Server) setupDeploymentWatcher() error { s.deploymentWatcher = deploymentwatcher.NewDeploymentsWatcher( s.logger, raftShim, - s.staticEndpoints.Deployment, - s.staticEndpoints.Job, + NewDeploymentEndpoint(s, nil), + NewJobEndpoints(s, nil), s.config.DeploymentQueryRateLimit, deploymentwatcher.CrossDeploymentUpdateBatchDuration, ) @@ -1094,7 +1090,7 @@ func (s *Server) setupDeploymentWatcher() error { // setupVolumeWatcher creates a volume watcher that sends CSI RPCs func (s *Server) setupVolumeWatcher() error { s.volumeWatcher = volumewatcher.NewVolumesWatcher( - s.logger, s.staticEndpoints.CSIVolume, s.getLeaderAcl()) + s.logger, NewCSIVolumeEndpoint(s, nil), s.getLeaderAcl()) return nil } @@ -1197,69 +1193,44 @@ func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error { return nil } -// setupRpcServer is used to populate an RPC server with endpoints +// setupRpcServer is used to populate an RPC server with endpoints. This gets +// called at startup but also once for every new RPC connection so that RPC +// handlers can have per-connection context. func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) error { - // Add the static endpoints to the RPC server. These are the RPC handlers - // that get used when component on the server is making an internal RPC - // call, so we only need them to be initialized once and they have no RPC - // context. - if s.staticEndpoints.Status == nil { - // note: Alloc, Plan have only dynamic endpoints - s.staticEndpoints.ACL = NewACLEndpoint(s, nil) - s.staticEndpoints.CSIVolume = NewCSIVolumeEndpoint(s, nil) - s.staticEndpoints.CSIPlugin = NewCSIPluginEndpoint(s, nil) - s.staticEndpoints.Deployment = NewDeploymentEndpoint(s, nil) - s.staticEndpoints.Job = NewJobEndpoints(s, nil) - s.staticEndpoints.Keyring = NewKeyringEndpoint(s, nil, s.encrypter) - s.staticEndpoints.Namespace = NewNamespaceEndpoint(s, nil) - s.staticEndpoints.Node = NewNodeEndpoint(s, nil) - s.staticEndpoints.Operator = NewOperatorEndpoint(s, nil) - s.staticEndpoints.Operator.register() // register the streaming RPCs - s.staticEndpoints.Periodic = NewPeriodicEndpoint(s, nil) - s.staticEndpoints.Region = NewRegionEndpoint(s, nil) - s.staticEndpoints.Scaling = NewScalingEndpoint(s, nil) - s.staticEndpoints.Search = NewSearchEndpoint(s, nil) - s.staticEndpoints.ServiceRegistration = NewServiceRegistrationEndpoint(s, nil) - s.staticEndpoints.Status = NewStatusEndpoint(s, nil) - s.staticEndpoints.System = NewSystemEndpoint(s, nil) - s.staticEndpoints.Variables = NewVariablesEndpoint(s, nil, s.encrypter) + // The endpoints are client RPCs and don't include a connection + // context. They also need to be registered as streaming endpoints in their + // register() methods. - s.staticEndpoints.Enterprise = NewEnterpriseEndpoints(s, nil) + clientAllocs := NewClientAllocationsEndpoint(s) + clientAllocs.register() + _ = server.Register(clientAllocs) - // These endpoints don't have a dynamic counterpart, so they'll need to - // be re-registered per connection as well (see below) + fsEndpoint := NewFileSystemEndpoint(s) + fsEndpoint.register() + _ = server.Register(fsEndpoint) - // Client endpoints - s.staticEndpoints.ClientStats = NewClientStatsEndpoint(s) - s.staticEndpoints.ClientAllocations = NewClientAllocationsEndpoint(s) - s.staticEndpoints.ClientAllocations.register() // register the streaming RPCs - s.staticEndpoints.ClientCSI = NewClientCSIEndpoint(s) + agentEndpoint := NewAgentEndpoint(s) + agentEndpoint.register() + _ = server.Register(agentEndpoint) - // Streaming endpoints - s.staticEndpoints.FileSystem = NewFileSystemEndpoint(s) - s.staticEndpoints.FileSystem.register() + // Event is a streaming-only endpoint so we don't want to register it as a + // normal RPC + eventEndpoint := NewEventEndpoint(s) + eventEndpoint.register() - s.staticEndpoints.Agent = NewAgentEndpoint(s) - s.staticEndpoints.Agent.register() + // Operator takes a RPC context but also has a streaming RPC that needs to + // be registered + operatorEndpoint := NewOperatorEndpoint(s, ctx) + operatorEndpoint.register() + _ = server.Register(NewOperatorEndpoint(s, ctx)) - s.staticEndpoints.Event = NewEventEndpoint(s) - s.staticEndpoints.Event.register() - } + // These endpoints are client RPCs and don't include a connection context + _ = server.Register(NewClientCSIEndpoint(s)) + _ = server.Register(NewClientStatsEndpoint(s)) - // If an endpoint has any non-streaming RPCs doesn't have an RPC context, - // we'll register the static handler here instead of creating a new dynamic - // endpoint on each connection. - - server.Register(s.staticEndpoints.ClientStats) - server.Register(s.staticEndpoints.ClientAllocations) - server.Register(s.staticEndpoints.ClientCSI) - server.Register(s.staticEndpoints.FileSystem) - server.Register(s.staticEndpoints.Agent) - - // Dynamic endpoints are endpoints that include the connection context and - // are created on each connection. Register all the dynamic endpoints with - // the RPC server. + // All other endpoints include the connection context and don't need to be + // registered as streaming endpoints _ = server.Register(NewACLEndpoint(s, ctx)) _ = server.Register(NewAllocEndpoint(s, ctx)) @@ -1271,7 +1242,6 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) error { _ = server.Register(NewKeyringEndpoint(s, ctx, s.encrypter)) _ = server.Register(NewNamespaceEndpoint(s, ctx)) _ = server.Register(NewNodeEndpoint(s, ctx)) - _ = server.Register(NewOperatorEndpoint(s, ctx)) _ = server.Register(NewPeriodicEndpoint(s, ctx)) _ = server.Register(NewPlanEndpoint(s, ctx)) _ = server.Register(NewRegionEndpoint(s, ctx)) diff --git a/nomad/variables_endpoint_test.go b/nomad/variables_endpoint_test.go index d23898a61..ff0191230 100644 --- a/nomad/variables_endpoint_test.go +++ b/nomad/variables_endpoint_test.go @@ -106,8 +106,10 @@ func TestVariablesEndpoint_auth(t *testing.T) { err = store.UpsertACLTokens(structs.MsgTypeTestSetup, 1150, []*structs.ACLToken{aclToken}) must.NoError(t, err) + variablesRPC := NewVariablesEndpoint(srv, nil, srv.encrypter) + t.Run("terminal alloc should be denied", func(t *testing.T) { - _, _, err = srv.staticEndpoints.Variables.handleMixedAuthEndpoint( + _, _, err := variablesRPC.handleMixedAuthEndpoint( structs.QueryOptions{AuthToken: idToken, Namespace: ns}, acl.PolicyList, fmt.Sprintf("nomad/jobs/%s/web/web", jobID)) must.EqError(t, err, structs.ErrPermissionDenied.Error()) @@ -119,7 +121,7 @@ func TestVariablesEndpoint_auth(t *testing.T) { structs.MsgTypeTestSetup, 1200, []*structs.Allocation{alloc1})) t.Run("wrong namespace should be denied", func(t *testing.T) { - _, _, err = srv.staticEndpoints.Variables.handleMixedAuthEndpoint( + _, _, err := variablesRPC.handleMixedAuthEndpoint( structs.QueryOptions{AuthToken: idToken, Namespace: structs.DefaultNamespace}, acl.PolicyList, fmt.Sprintf("nomad/jobs/%s/web/web", jobID)) must.EqError(t, err, structs.ErrPermissionDenied.Error()) @@ -349,7 +351,7 @@ func TestVariablesEndpoint_auth(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - _, _, err := srv.staticEndpoints.Variables.handleMixedAuthEndpoint( + _, _, err := variablesRPC.handleMixedAuthEndpoint( structs.QueryOptions{AuthToken: tc.token, Namespace: ns}, tc.cap, tc.path) if tc.expectedErr == nil { must.NoError(t, err)