From b027d8f771e68f3bf4aba57ce58a8aded04214eb Mon Sep 17 00:00:00 2001
From: Daniel Bennett
Date: Tue, 17 Oct 2023 15:15:00 -0500
Subject: [PATCH] do not embed *Server (#18786)

these structs embedding Server, then Server _also embedding them_,
confused my IDE, isn't necessary, and just feels wrong!
---
 nomad/heartbeat.go  | 26 ++++++-------
 nomad/plan_apply.go | 30 +++++++--------
 nomad/rpc.go        | 92 ++++++++++++++++++++++-----------------
 3 files changed, 73 insertions(+), 75 deletions(-)

diff --git a/nomad/heartbeat.go b/nomad/heartbeat.go
index 25ba5348a..2b207b6ed 100644
--- a/nomad/heartbeat.go
+++ b/nomad/heartbeat.go
@@ -35,7 +35,7 @@ var (
 // nodeHeartbeater is used to track expiration times of node heartbeats. If it
 // detects an expired node, the node status is updated to be 'down'.
 type nodeHeartbeater struct {
-	*Server
+	srv    *Server
 	logger log.Logger
 
 	// heartbeatTimers track the expiration time of each heartbeat that has
@@ -48,7 +48,7 @@ type nodeHeartbeater struct {
 // failed node heartbeats.
 func newNodeHeartbeater(s *Server) *nodeHeartbeater {
 	return &nodeHeartbeater{
-		Server: s,
+		srv:    s,
 		logger: s.logger.Named("heartbeat"),
 	}
 }
@@ -58,7 +58,7 @@ func newNodeHeartbeater(s *Server) *nodeHeartbeater {
 // the previously known set of timers.
 func (h *nodeHeartbeater) initializeHeartbeatTimers() error {
 	// Scan all nodes and reset their timer
-	snap, err := h.fsm.State().Snapshot()
+	snap, err := h.srv.fsm.State().Snapshot()
 	if err != nil {
 		return err
 	}
@@ -83,7 +83,7 @@ func (h *nodeHeartbeater) initializeHeartbeatTimers() error {
 		if node.TerminalStatus() {
 			continue
 		}
-		h.resetHeartbeatTimerLocked(node.ID, h.config.FailoverHeartbeatTTL)
+		h.resetHeartbeatTimerLocked(node.ID, h.srv.config.FailoverHeartbeatTTL)
 	}
 	return nil
 }
@@ -97,18 +97,18 @@ func (h *nodeHeartbeater) resetHeartbeatTimer(id string) (time.Duration, error)
 	// Do not create a timer for the node since we are not the leader. This
 	// check avoids the race in which leadership is lost but a timer is created
 	// on this server since it was servicing an RPC during a leadership loss.
-	if !h.IsLeader() {
+	if !h.srv.IsLeader() {
 		h.logger.Debug("ignoring resetting node TTL since this server is not the leader", "node_id", id)
 		return 0, heartbeatNotLeaderErr
 	}
 
 	// Compute the target TTL value
 	n := len(h.heartbeatTimers)
-	ttl := helper.RateScaledInterval(h.config.MaxHeartbeatsPerSecond, h.config.MinHeartbeatTTL, n)
+	ttl := helper.RateScaledInterval(h.srv.config.MaxHeartbeatsPerSecond, h.srv.config.MinHeartbeatTTL, n)
 	ttl += helper.RandomStagger(ttl)
 
 	// Reset the TTL
-	h.resetHeartbeatTimerLocked(id, ttl+h.config.HeartbeatGrace)
+	h.resetHeartbeatTimerLocked(id, ttl+h.srv.config.HeartbeatGrace)
 	return ttl, nil
 }
 
@@ -148,7 +148,7 @@ func (h *nodeHeartbeater) invalidateHeartbeat(id string) {
 	// Do not invalidate the node since we are not the leader. This check avoids
 	// the race in which leadership is lost but a timer is created on this
 	// server since it was servicing an RPC during a leadership loss.
-	if !h.IsLeader() {
+	if !h.srv.IsLeader() {
 		h.logger.Debug("ignoring node TTL since this server is not the leader", "node_id", id)
 		return
 	}
@@ -163,7 +163,7 @@ func (h *nodeHeartbeater) invalidateHeartbeat(id string) {
 		Status:    structs.NodeStatusDown,
 		NodeEvent: structs.NewNodeEvent().SetSubsystem(structs.NodeEventSubsystemCluster).SetMessage(NodeHeartbeatEventMissed),
 		WriteRequest: structs.WriteRequest{
-			Region: h.config.Region,
+			Region: h.srv.config.Region,
 		},
 	}
 
@@ -172,13 +172,13 @@ func (h *nodeHeartbeater) invalidateHeartbeat(id string) {
 	}
 
 	var resp structs.NodeUpdateResponse
-	if err := h.RPC("Node.UpdateStatus", &req, &resp); err != nil {
+	if err := h.srv.RPC("Node.UpdateStatus", &req, &resp); err != nil {
 		h.logger.Error("update node status failed", "error", err)
 	}
 }
 
 func (h *nodeHeartbeater) disconnectState(id string) (bool, bool) {
-	node, err := h.State().NodeByID(nil, id)
+	node, err := h.srv.State().NodeByID(nil, id)
 	if err != nil {
 		h.logger.Error("error retrieving node by id", "error", err)
 		return false, false
@@ -189,7 +189,7 @@ func (h *nodeHeartbeater) disconnectState(id string) (bool, bool) {
 		return false, false
 	}
 
-	allocs, err := h.State().AllocsByNode(nil, id)
+	allocs, err := h.srv.State().AllocsByNode(nil, id)
 	if err != nil {
 		h.logger.Error("error retrieving allocs by node", "error", err)
 		return false, false
@@ -257,7 +257,7 @@ func (h *nodeHeartbeater) heartbeatStats() {
 			h.heartbeatTimersLock.Unlock()
 			metrics.SetGauge([]string{"nomad", "heartbeat", "active"}, float32(num))
 
-		case <-h.shutdownCh:
+		case <-h.srv.shutdownCh:
 			return
 		}
 	}
diff --git a/nomad/plan_apply.go b/nomad/plan_apply.go
index 27b72f15c..ecf4d4ec5 100644
--- a/nomad/plan_apply.go
+++ b/nomad/plan_apply.go
@@ -22,8 +22,7 @@ import (
 // planner is used to manage the submitted allocation plans that are waiting
 // to be accessed by the leader
 type planner struct {
-	*Server
-	log log.Logger
+	srv *Server
 
 	// planQueue is used to manage the submitted allocation
 	// plans that are waiting to be assessed by the leader
@@ -63,8 +62,7 @@ func newPlanner(s *Server) (*planner, error) {
 	}
 
 	return &planner{
-		Server:         s,
-		log:            log,
+		srv:            s,
 		planQueue:      planQueue,
 		badNodeTracker: badNodeTracker,
 	}, nil
@@ -157,16 +155,16 @@ func (p *planner) planApply() {
 		if planIndexCh == nil || snap == nil {
 			snap, err = p.snapshotMinIndex(prevPlanResultIndex, pending.plan.SnapshotIndex)
 			if err != nil {
-				p.logger.Error("failed to snapshot state", "error", err)
+				p.srv.logger.Error("failed to snapshot state", "error", err)
 				pending.respond(nil, err)
 				continue
 			}
 		}
 
 		// Evaluate the plan
-		result, err := evaluatePlan(pool, snap, pending.plan, p.logger)
+		result, err := evaluatePlan(pool, snap, pending.plan, p.srv.logger)
 		if err != nil {
-			p.logger.Error("failed to evaluate plan", "error", err)
+			p.srv.logger.Error("failed to evaluate plan", "error", err)
 			pending.respond(nil, err)
 			continue
 		}
@@ -192,7 +190,7 @@ func (p *planner) planApply() {
 			prevPlanResultIndex = max(prevPlanResultIndex, idx)
 			snap, err = p.snapshotMinIndex(prevPlanResultIndex, pending.plan.SnapshotIndex)
 			if err != nil {
-				p.logger.Error("failed to update snapshot state", "error", err)
+				p.srv.logger.Error("failed to update snapshot state", "error", err)
 				pending.respond(nil, err)
 				continue
 			}
@@ -201,7 +199,7 @@ func (p *planner) planApply() {
 		// Dispatch the Raft transaction for the plan
 		future, err := p.applyPlan(pending.plan, result, snap)
 		if err != nil {
-			p.logger.Error("failed to submit plan", "error", err)
+			p.srv.logger.Error("failed to submit plan", "error", err)
 			pending.respond(nil, err)
 			continue
 		}
@@ -229,7 +227,7 @@ func (p *planner) snapshotMinIndex(prevPlanResultIndex, planSnapshotIndex uint64
 	// because schedulers won't dequeue more work while waiting.
 	const timeout = 10 * time.Second
 	ctx, cancel := context.WithTimeout(context.Background(), timeout)
-	snap, err := p.fsm.State().SnapshotMinIndex(ctx, minIndex)
+	snap, err := p.srv.fsm.State().SnapshotMinIndex(ctx, minIndex)
 	cancel()
 	if err == context.DeadlineExceeded {
 		return nil, fmt.Errorf("timed out after %s waiting for index=%d (previous plan result index=%d; plan snapshot index=%d)",
@@ -258,7 +256,7 @@ func (p *planner) applyPlan(plan *structs.Plan, result *structs.PlanResult, snap
 
 	preemptedJobIDs := make(map[structs.NamespacedID]struct{})
 
-	if ServersMeetMinimumVersion(p.Members(), p.Region(), MinVersionPlanNormalization, true) {
+	if ServersMeetMinimumVersion(p.srv.Members(), p.srv.Region(), MinVersionPlanNormalization, true) {
 		// Initialize the allocs request using the new optimized log entry format.
 		// Determine the minimum number of updates, could be more if there
 		// are multiple updates per node
@@ -280,7 +278,7 @@ func (p *planner) applyPlan(plan *structs.Plan, result *structs.PlanResult, snap
 		// to approximate the scheduling time.
 		updateAllocTimestamps(req.AllocsUpdated, unixNow)
 
-		err := signAllocIdentities(p.Server.encrypter, plan.Job, req.AllocsUpdated, now)
+		err := signAllocIdentities(p.srv.encrypter, plan.Job, req.AllocsUpdated, now)
 		if err != nil {
 			return nil, err
 		}
@@ -331,7 +329,7 @@ func (p *planner) applyPlan(plan *structs.Plan, result *structs.PlanResult, snap
 
 	var evals []*structs.Evaluation
 	for preemptedJobID := range preemptedJobIDs {
-		job, _ := p.State().JobByID(nil, preemptedJobID.Namespace, preemptedJobID.ID)
+		job, _ := p.srv.State().JobByID(nil, preemptedJobID.Namespace, preemptedJobID.ID)
 		if job != nil {
 			eval := &structs.Evaluation{
 				ID: uuid.Generate(),
@@ -350,14 +348,14 @@ func (p *planner) applyPlan(plan *structs.Plan, result *structs.PlanResult, snap
 	req.PreemptionEvals = evals
 
 	// Dispatch the Raft transaction
-	future, err := p.raftApplyFuture(structs.ApplyPlanResultsRequestType, &req)
+	future, err := p.srv.raftApplyFuture(structs.ApplyPlanResultsRequestType, &req)
 	if err != nil {
 		return nil, err
 	}
 
 	// Optimistically apply to our state view
 	if snap != nil {
-		nextIdx := p.raft.AppliedIndex() + 1
+		nextIdx := p.srv.raft.AppliedIndex() + 1
 		if err := snap.UpsertPlanResults(structs.ApplyPlanResultsRequestType, nextIdx, &req); err != nil {
 			return future, err
 		}
@@ -444,7 +442,7 @@ func (p *planner) asyncPlanWait(indexCh chan<- uint64, future raft.ApplyFuture,
 
 	// Wait for the plan to apply
 	if err := future.Error(); err != nil {
-		p.logger.Error("failed to apply plan", "error", err)
+		p.srv.logger.Error("failed to apply plan", "error", err)
 		pending.respond(nil, err)
 		return
 	}
diff --git a/nomad/rpc.go b/nomad/rpc.go
index d5154c92d..f8dcfc0a4 100644
--- a/nomad/rpc.go
+++ b/nomad/rpc.go
@@ -44,7 +44,7 @@ const (
 )
 
 type rpcHandler struct {
-	*Server
+	srv *Server
 
 	// connLimiter is used to limit the number of RPC connections per
 	// remote address. It is distinct from the HTTP connection limit.
@@ -68,7 +68,7 @@ func newRpcHandler(s *Server) *rpcHandler {
 	logger := s.logger.NamedIntercept("rpc")
 
 	r := rpcHandler{
-		Server:    s,
+		srv:       s,
 		connLimit: s.config.RPCMaxConnsPerClient,
 		logger:    logger,
 		gologger:  logger.StandardLoggerIntercept(&log.StandardLoggerOptions{InferLevels: true}),
@@ -177,7 +177,7 @@ func (ctx *RPCContext) GetRemoteIP() (net.IP, error) {
 
 // listen is used to listen for incoming RPC connections
 func (r *rpcHandler) listen(ctx context.Context) {
-	defer close(r.listenerCh)
+	defer close(r.srv.listenerCh)
 
 	var acceptLoopDelay time.Duration
 	for {
@@ -189,9 +189,9 @@ func (r *rpcHandler) listen(ctx context.Context) {
 		}
 
 		// Accept a connection
-		conn, err := r.rpcListener.Accept()
+		conn, err := r.srv.rpcListener.Accept()
 		if err != nil {
-			if r.shutdown {
+			if r.srv.shutdown {
 				return
 			}
 			r.handleAcceptErr(ctx, err, &acceptLoopDelay)
@@ -263,8 +263,8 @@ func (r *rpcHandler) handleAcceptErr(ctx context.Context, err error, loopDelay *
 func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) {
 	// Limit how long an unauthenticated client can hold the connection
 	// open before they send the magic byte.
-	if !rpcCtx.TLS && r.config.RPCHandshakeTimeout > 0 {
-		conn.SetDeadline(time.Now().Add(r.config.RPCHandshakeTimeout))
+	if !rpcCtx.TLS && r.srv.config.RPCHandshakeTimeout > 0 {
+		conn.SetDeadline(time.Now().Add(r.srv.config.RPCHandshakeTimeout))
 	}
 
 	// Read a single byte
@@ -279,13 +279,13 @@ func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCC
 
 	// Reset the deadline as we aren't sure what is expected next - it depends on
 	// the protocol.
-	if !rpcCtx.TLS && r.config.RPCHandshakeTimeout > 0 {
+	if !rpcCtx.TLS && r.srv.config.RPCHandshakeTimeout > 0 {
 		conn.SetDeadline(time.Time{})
 	}
 
 	// Enforce TLS if EnableRPC is set
-	if r.config.TLSConfig.EnableRPC && !rpcCtx.TLS && pool.RPCType(buf[0]) != pool.RpcTLS {
-		if !r.config.TLSConfig.RPCUpgradeMode {
+	if r.srv.config.TLSConfig.EnableRPC && !rpcCtx.TLS && pool.RPCType(buf[0]) != pool.RpcTLS {
+		if !r.srv.config.TLSConfig.RPCUpgradeMode {
 			r.logger.Warn("non-TLS connection attempted with RequireTLS set", "remote_addr", conn.RemoteAddr())
 			conn.Close()
 			return
@@ -297,12 +297,12 @@ func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCC
 	case pool.RpcNomad:
 		// Create an RPC Server and handle the request
 		server := rpc.NewServer()
-		r.setupRpcServer(server, rpcCtx)
+		r.srv.setupRpcServer(server, rpcCtx)
 		r.handleNomadConn(ctx, conn, server)
 
 		// Remove any potential mapping between a NodeID to this connection and
 		// close the underlying connection.
-		r.removeNodeConn(rpcCtx)
+		r.srv.removeNodeConn(rpcCtx)
 
 	case pool.RpcRaft:
 		metrics.IncrCounter([]string{"nomad", "rpc", "raft_handoff"}, 1)
@@ -311,13 +311,13 @@ func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCC
 			conn.Close()
 			return
 		}
-		r.raftLayer.Handoff(ctx, conn)
+		r.srv.raftLayer.Handoff(ctx, conn)
 
 	case pool.RpcMultiplex:
 		r.handleMultiplex(ctx, conn, rpcCtx)
 
 	case pool.RpcTLS:
-		if r.rpcTLS == nil {
+		if r.srv.rpcTLS == nil {
 			r.logger.Warn("TLS connection attempted, server not configured for TLS")
 			conn.Close()
 			return
@@ -330,7 +330,7 @@ func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCC
 			return
 		}
 
-		conn = tls.Server(conn, r.rpcTLS)
+		conn = tls.Server(conn, r.srv.rpcTLS)
 
 		// Force a handshake so we can get information about the TLS connection
 		// state.
@@ -344,8 +344,8 @@ func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCC
 		// Enforce handshake timeout during TLS handshake to prevent
 		// unauthenticated users from holding connections open
 		// indefinitely.
-		if r.config.RPCHandshakeTimeout > 0 {
-			tlsConn.SetDeadline(time.Now().Add(r.config.RPCHandshakeTimeout))
+		if r.srv.config.RPCHandshakeTimeout > 0 {
+			tlsConn.SetDeadline(time.Now().Add(r.srv.config.RPCHandshakeTimeout))
 		}
 
 		if err := tlsConn.Handshake(); err != nil {
@@ -355,7 +355,7 @@ func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCC
 		}
 
 		// Reset the deadline as unauthenticated users have now been rejected.
-		if r.config.RPCHandshakeTimeout > 0 {
+		if r.srv.config.RPCHandshakeTimeout > 0 {
 			tlsConn.SetDeadline(time.Time{})
 		}
 
@@ -402,7 +402,7 @@ func (r *rpcHandler) handleMultiplex(ctx context.Context, conn net.Conn, rpcCtx
 	defer func() {
 		// Remove any potential mapping between a NodeID to this connection and
 		// close the underlying connection.
-		r.removeNodeConn(rpcCtx)
+		r.srv.removeNodeConn(rpcCtx)
 		conn.Close()
 	}()
@@ -420,7 +420,7 @@ func (r *rpcHandler) handleMultiplex(ctx context.Context, conn net.Conn, rpcCtx
 
 	// Create the RPC server for this connection
 	rpcServer := rpc.NewServer()
-	r.setupRpcServer(rpcServer, rpcCtx)
+	r.srv.setupRpcServer(rpcServer, rpcCtx)
 
 	for {
 		// stop handling connections if context was cancelled
@@ -448,7 +448,7 @@ func (r *rpcHandler) handleNomadConn(ctx context.Context, conn net.Conn, server
 		case <-ctx.Done():
 			r.logger.Info("closing server RPC connection")
 			return
-		case <-r.shutdownCh:
+		case <-r.srv.shutdownCh:
 			return
 		default:
 		}
@@ -481,7 +481,7 @@ func (r *rpcHandler) handleStreamingConn(conn net.Conn) {
 	}
 
 	ack := structs.StreamingRpcAck{}
-	handler, err := r.streamingRpcs.GetHandler(header.Method)
+	handler, err := r.srv.streamingRpcs.GetHandler(header.Method)
 	if err != nil {
 		r.logger.Error("streaming RPC error", "error", err, "connection", conn)
 		metrics.IncrCounter([]string{"nomad", "streaming_rpc", "request_error"}, 1)
@@ -511,7 +511,7 @@ func (r *rpcHandler) handleMultiplexV2(ctx context.Context, conn net.Conn, rpcCt
 	defer func() {
 		// Remove any potential mapping between a NodeID to this connection and
 		// close the underlying connection.
-		r.removeNodeConn(rpcCtx)
+		r.srv.removeNodeConn(rpcCtx)
 		conn.Close()
 	}()
@@ -529,7 +529,7 @@ func (r *rpcHandler) handleMultiplexV2(ctx context.Context, conn net.Conn, rpcCt
 
 	// Create the RPC server for this connection
 	rpcServer := rpc.NewServer()
-	r.setupRpcServer(rpcServer, rpcCtx)
+	r.srv.setupRpcServer(rpcServer, rpcCtx)
 
 	for {
 		// stop handling connections if context was cancelled
@@ -579,7 +579,7 @@ func (r *rpcHandler) forward(method string, info structs.RPCInfo, args interface
 	}
 
 	// Handle region forwarding
-	if region != r.config.Region {
+	if region != r.srv.config.Region {
 		// Mark that we are forwarding the RPC
 		info.SetForwarded()
 		err := r.forwardRegion(region, method, args, reply)
@@ -616,10 +616,10 @@ func (r *rpcHandler) getLeaderForRPC() (*serverParts, error) {
 
 CHECK_LEADER:
 	// Find the leader
-	isLeader, remoteServer := r.getLeader()
+	isLeader, remoteServer := r.srv.getLeader()
 
 	// Handle the case we are the leader
-	if isLeader && r.Server.isReadyForConsistentReads() {
+	if isLeader && r.srv.isReadyForConsistentReads() {
 		return nil, nil
 	}
@@ -632,12 +632,12 @@ CHECK_LEADER:
 	if firstCheck.IsZero() {
 		firstCheck = time.Now()
 	}
-	if time.Since(firstCheck) < r.config.RPCHoldTimeout {
-		jitter := helper.RandomStagger(r.config.RPCHoldTimeout / structs.JitterFraction)
+	if time.Since(firstCheck) < r.srv.config.RPCHoldTimeout {
+		jitter := helper.RandomStagger(r.srv.config.RPCHoldTimeout / structs.JitterFraction)
 		select {
 		case <-time.After(jitter):
 			goto CHECK_LEADER
-		case <-r.shutdownCh:
+		case <-r.srv.shutdownCh:
 		}
 	}
@@ -680,7 +680,7 @@ func (r *rpcHandler) forwardLeader(server *serverParts, method string, args inte
 	if server == nil {
 		return structs.ErrNoLeader
 	}
-	return r.connPool.RPC(r.config.Region, server.Addr, method, args, reply)
+	return r.srv.connPool.RPC(r.srv.config.Region, server.Addr, method, args, reply)
 }
 
 // forwardServer is used to forward an RPC call to a particular server
@@ -689,14 +689,14 @@ func (r *rpcHandler) forwardServer(server *serverParts, method string, args inte
 	if server == nil {
 		return errors.New("must be given a valid server address")
 	}
-	return r.connPool.RPC(r.config.Region, server.Addr, method, args, reply)
+	return r.srv.connPool.RPC(r.srv.config.Region, server.Addr, method, args, reply)
 }
 
 func (r *rpcHandler) findRegionServer(region string) (*serverParts, error) {
-	r.peerLock.RLock()
-	defer r.peerLock.RUnlock()
+	r.srv.peerLock.RLock()
+	defer r.srv.peerLock.RUnlock()
 
-	servers := r.peers[region]
+	servers := r.srv.peers[region]
 	if len(servers) == 0 {
 		r.logger.Warn("no path found to region", "region", region)
 		return nil, structs.ErrNoRegionPath
@@ -716,15 +716,15 @@ func (r *rpcHandler) forwardRegion(region, method string, args interface{}, repl
 
 	// Forward to remote Nomad
 	metrics.IncrCounter([]string{"nomad", "rpc", "cross-region", region}, 1)
-	return r.connPool.RPC(region, server.Addr, method, args, reply)
+	return r.srv.connPool.RPC(region, server.Addr, method, args, reply)
 }
 
 func (r *rpcHandler) getServer(region, serverID string) (*serverParts, error) {
 	// Bail if we can't find any servers
-	r.peerLock.RLock()
-	defer r.peerLock.RUnlock()
+	r.srv.peerLock.RLock()
+	defer r.srv.peerLock.RUnlock()
 
-	servers := r.peers[region]
+	servers := r.srv.peers[region]
 	if len(servers) == 0 {
 		r.logger.Warn("no path found to region", "region", region)
 		return nil, structs.ErrNoRegionPath
@@ -744,7 +744,7 @@ func (r *rpcHandler) getServer(region, serverID string) (*serverParts, error) {
 // initial handshake, returning the connection or an error. It is the callers
 // responsibility to close the connection if there is no returned error.
 func (r *rpcHandler) streamingRpc(server *serverParts, method string) (net.Conn, error) {
-	c, err := r.connPool.StreamingRPC(r.config.Region, server.Addr)
+	c, err := r.srv.connPool.StreamingRPC(r.srv.config.Region, server.Addr)
 	if err != nil {
 		return nil, err
 	}
@@ -823,12 +823,12 @@ func (s *Server) raftApply(t structs.MessageType, msg any) (any, uint64, error)
 
 // setQueryMeta is used to populate the QueryMeta data for an RPC call
 func (r *rpcHandler) setQueryMeta(m *structs.QueryMeta) {
-	if r.IsLeader() {
+	if r.srv.IsLeader() {
 		m.LastContact = 0
 		m.KnownLeader = true
 	} else {
-		m.LastContact = time.Since(r.raft.LastContact())
-		leaderAddr, _ := r.raft.LeaderWithID()
+		m.LastContact = time.Since(r.srv.raft.LastContact())
+		leaderAddr, _ := r.srv.raft.LeaderWithID()
 		m.KnownLeader = (leaderAddr != "")
 	}
 }
@@ -878,7 +878,7 @@ RUN_QUERY:
 	// We capture the state store and its abandon channel but pass a snapshot to
 	// the blocking query function. We operate on the snapshot to allow separate
 	// calls to the state store not all wrapped within the same transaction.
-	state = r.fsm.State()
+	state = r.srv.fsm.State()
 	abandonCh := state.AbandonCh()
 	snap, _ := state.Snapshot()
 	stateSnap := &snap.StateStore
@@ -907,13 +907,13 @@ RUN_QUERY:
 
 func (r *rpcHandler) validateRaftTLS(rpcCtx *RPCContext) error {
 	// TLS is not configured or not to be enforced
-	tlsConf := r.config.TLSConfig
+	tlsConf := r.srv.config.TLSConfig
 	if !tlsConf.EnableRPC || !tlsConf.VerifyServerHostname || tlsConf.RPCUpgradeMode {
 		return nil
 	}
 
 	// check that `server.<region>.nomad` is present in cert
-	expected := "server." + r.Region() + ".nomad"
+	expected := "server." + r.srv.Region() + ".nomad"
 	err := rpcCtx.ValidateCertificateForName(expected)
 	if err != nil {
 		cert := rpcCtx.Certificate()
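
For context on the pattern the commit message describes, here is a minimal sketch using hypothetical worker and Server types, not the real Nomad structs: before this patch each subsystem embedded *Server while Server also embedded the subsystem, so fields and methods were promoted in both directions; afterwards the subsystem keeps an explicit named srv field and each formerly promoted selector gains a .srv hop.

package main

import "fmt"

// Before (kept in comments so this file compiles): mutual embedding.
// Each type promotes the other's fields and methods, so selectors like
// w.config resolve implicitly and the type graph is cyclic.
//
//	type worker struct {
//		*Server
//	}
//	type Server struct {
//		*worker
//		config string
//	}

// After: the subsystem holds an explicit, named back-reference. The
// server may still hold (or embed) its subsystems; only the
// subsystem-to-server direction stops being an embedding.
type worker struct {
	srv *Server
}

type Server struct {
	worker *worker
	config string
}

func newWorker(s *Server) *worker {
	return &worker{srv: s}
}

func (w *worker) describe() string {
	// Explicit hop through the named field instead of a promoted field.
	return fmt.Sprintf("worker of server with config %q", w.srv.config)
}

func main() {
	s := &Server{config: "region=global"}
	s.worker = newWorker(s)
	fmt.Println(s.worker.describe()) // worker of server with config "region=global"
}

The rewrite in the three files above is exactly this mechanical change: h.config becomes h.srv.config, p.logger becomes p.srv.logger, r.peers becomes r.srv.peers, and so on, with no change in behavior.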