From f3d53e3e2b4c96f5c3e87f5bcb9a1a48ec9e956b Mon Sep 17 00:00:00 2001
From: Tim Gross
Date: Fri, 7 Mar 2025 10:04:59 -0500
Subject: [PATCH] CSI: restart task on failing initial probe, instead of
 killing it (#25307)

When a CSI plugin is launched, we probe it until the csi_plugin.health_timeout
expires (by default 30s). But if the plugin never becomes healthy, we're not
restarting the task as documented. Update the plugin supervisor to trigger a
restart instead. We still exit the supervisor loop at that point to avoid
having the supervisor send probes to a task that isn't running yet.

This requires reworking the poststart hook to allow the supervisor loop to be
restarted when the task restarts. In doing so, I identified that we weren't
respecting the task kill context from the poststart hook, which would leave
the supervisor running in the window between when a task is killed because it
failed and when its stop hooks are triggered. Combine the two contexts to make
sure we stop the supervisor whichever context gets closed first.

Fixes: https://github.com/hashicorp/nomad/issues/25293
Ref: https://hashicorp.atlassian.net/browse/NET-12264
---
 .changelog/25307.txt                          |  3 +
 .../taskrunner/plugin_supervisor_hook.go      | 84 ++++++++++++-------
 2 files changed, 55 insertions(+), 32 deletions(-)
 create mode 100644 .changelog/25307.txt

diff --git a/.changelog/25307.txt b/.changelog/25307.txt
new file mode 100644
index 000000000..e0bafec12
--- /dev/null
+++ b/.changelog/25307.txt
@@ -0,0 +1,3 @@
+```release-note:bug
+csi: Fixed a bug where plugins that failed initial fingerprints would not be restarted
+```
diff --git a/client/allocrunner/taskrunner/plugin_supervisor_hook.go b/client/allocrunner/taskrunner/plugin_supervisor_hook.go
index e98f7d0da..3972b38ee 100644
--- a/client/allocrunner/taskrunner/plugin_supervisor_hook.go
+++ b/client/allocrunner/taskrunner/plugin_supervisor_hook.go
@@ -11,6 +11,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/LK4D4/joincontext"
 	hclog "github.com/hashicorp/go-hclog"
 	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
 	ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
@@ -45,9 +46,10 @@ type csiPluginSupervisorHook struct {
 	eventEmitter ti.EventEmitter
 	lifecycle    ti.TaskLifecycle
 
-	shutdownCtx      context.Context
-	shutdownCancelFn context.CancelFunc
-	runOnce          sync.Once
+	supervisorIsRunningLock sync.Mutex
+	supervisorIsRunning     bool
+	shutdownCtx             context.Context
+	shutdownCancelFn        context.CancelFunc
 
 	// previousHealthstate is used by the supervisor goroutine to track historic
 	// health states for gating task events.
@@ -120,6 +122,7 @@ func newCSIPluginSupervisorHook(config *csiPluginSupervisorHookConfig) *csiPlugi
 		task.CSIPluginConfig.HealthTimeout = 30 * time.Second
 	}
 
+	// this context will be closed only by the Stop hook
 	shutdownCtx, cancelFn := context.WithCancel(context.Background())
 
 	hook := &csiPluginSupervisorHook{
@@ -154,11 +157,11 @@ func (h *csiPluginSupervisorHook) Prestart(ctx context.Context,
 	// Create the mount directory that the container will access if it doesn't
 	// already exist. Default to only nomad user access.
 	if err := os.MkdirAll(h.mountPoint, 0700); err != nil && !os.IsExist(err) {
-		return fmt.Errorf("failed to create mount point: %v", err)
+		return fmt.Errorf("failed to create mount point: %w", err)
 	}
 
 	if err := os.MkdirAll(h.socketMountPoint, 0700); err != nil && !os.IsExist(err) {
-		return fmt.Errorf("failed to create socket mount point: %v", err)
+		return fmt.Errorf("failed to create socket mount point: %w", err)
 	}
 
 	// where the socket will be mounted
@@ -230,19 +233,23 @@ func (h *csiPluginSupervisorHook) setSocketHook() {
 	h.socketPath = filepath.Join(h.socketMountPoint, structs.CSISocketName)
 }
 
-// Poststart is called after the task has started. Poststart is not
-// called if the allocation is terminal.
+// Poststart is called after the task has started (or restarted). Poststart is
+// not called if the allocation is terminal.
 //
 // The context is cancelled if the task is killed.
-func (h *csiPluginSupervisorHook) Poststart(_ context.Context, _ *interfaces.TaskPoststartRequest, _ *interfaces.TaskPoststartResponse) error {
+func (h *csiPluginSupervisorHook) Poststart(ctx context.Context, _ *interfaces.TaskPoststartRequest, _ *interfaces.TaskPoststartResponse) error {
 
-	// If we're already running the supervisor routine, then we don't need to try
-	// and restart it here as it only terminates on `Stop` hooks.
-	h.runOnce.Do(func() {
-		h.setSocketHook()
-		go h.ensureSupervisorLoop(h.shutdownCtx)
-	})
+	// If we're already running the supervisor routine, then we don't need to
+	// try and restart it here as it only terminates on `Stop` hooks and health
+	// timeouts (which restart the task)
+	h.supervisorIsRunningLock.Lock()
+	defer h.supervisorIsRunningLock.Unlock()
+	if h.supervisorIsRunning {
+		return nil
+	}
+
+	h.setSocketHook()
+	go h.ensureSupervisorLoop(ctx)
 	return nil
 }
 
@@ -263,15 +270,27 @@ func (h *csiPluginSupervisorHook) ensureSupervisorLoop(ctx context.Context) {
 	client := csi.NewClient(h.socketPath, h.logger.Named("csi_client").With(
 		"plugin.name", h.task.CSIPluginConfig.ID,
 		"plugin.type", h.task.CSIPluginConfig.Type))
-	defer client.Close()
+
+	// this context joins the context we get from the Poststart hook (closed by
+	// task failure) and the one closed by the Stop hook (triggered by task
+	// stop)
+	supervisorCtx, supervisorCtxCancel := joincontext.Join(ctx, h.shutdownCtx)
+
+	// this context is used for the health timeout. If we can't connect within
+	// this deadline, assume the plugin is broken so we can restart the task
+	startCtx, startCancelFn := context.WithTimeout(ctx, h.task.CSIPluginConfig.HealthTimeout)
+
+	defer func() {
+		h.supervisorIsRunningLock.Lock()
+		h.supervisorIsRunning = false
+		client.Close()
+		supervisorCtxCancel()
+		startCancelFn()
+		h.supervisorIsRunningLock.Unlock()
+	}()
 
 	t := time.NewTimer(0)
 
-	// We're in Poststart at this point, so if we can't connect within
-	// this deadline, assume it's broken so we can restart the task
-	startCtx, startCancelFn := context.WithTimeout(ctx, h.task.CSIPluginConfig.HealthTimeout)
-	defer startCancelFn()
-
 	var err error
 	var pluginHealthy bool
 
@@ -280,7 +299,9 @@ WAITFORREADY:
 	for {
 		select {
 		case <-startCtx.Done():
-			h.kill(ctx, fmt.Errorf("CSI plugin failed probe: %v", err))
+			h.restartTask(ctx, fmt.Errorf("CSI plugin failed probe: %w", err))
+			return
+		case <-supervisorCtx.Done():
 			return
 		case <-t.C:
 			pluginHealthy, err = h.supervisorLoopOnce(startCtx, client)
@@ -306,7 +327,7 @@ WAITFORREADY:
 	// Step 2: Register the plugin with the catalog.
 	deregisterPluginFn, err := h.registerPlugin(client, h.socketPath)
 	if err != nil {
-		h.kill(ctx, fmt.Errorf("CSI plugin failed to register: %v", err))
+		h.restartTask(ctx, fmt.Errorf("CSI plugin failed to register: %w", err))
 		return
 	}
 	// De-register plugins on task shutdown
@@ -317,10 +338,10 @@ WAITFORREADY:
 	t.Reset(0)
 	for {
 		select {
-		case <-ctx.Done():
+		case <-supervisorCtx.Done():
 			return
 		case <-t.C:
-			pluginHealthy, err := h.supervisorLoopOnce(ctx, client)
+			pluginHealthy, err := h.supervisorLoopOnce(supervisorCtx, client)
 			if err != nil {
 				h.logger.Error("CSI plugin fingerprinting failed", "error", err)
 			}
@@ -357,7 +378,7 @@ func (h *csiPluginSupervisorHook) registerPlugin(client csi.CSIPlugin, socketPat
 	// to get its vendor name and version
 	info, err := client.PluginInfo()
 	if err != nil {
-		return nil, fmt.Errorf("failed to probe plugin: %v", err)
+		return nil, fmt.Errorf("failed to probe plugin: %w", err)
 	}
 
 	mkInfoFn := func(pluginType string) *dynamicplugins.PluginInfo {
@@ -448,18 +469,17 @@ func (h *csiPluginSupervisorHook) Stop(_ context.Context, req *interfaces.TaskSt
 	return nil
 }
 
-func (h *csiPluginSupervisorHook) kill(ctx context.Context, reason error) {
-	h.logger.Error("killing task because plugin failed", "error", reason)
+func (h *csiPluginSupervisorHook) restartTask(ctx context.Context, reason error) {
+	h.logger.Error("restarting task because plugin failed", "error", reason)
 	event := structs.NewTaskEvent(structs.TaskPluginUnhealthy)
 	event.SetMessage(fmt.Sprintf("Error: %v", reason.Error()))
 	h.eventEmitter.EmitEvent(event)
 
-	if err := h.lifecycle.Kill(ctx,
-		structs.NewTaskEvent(structs.TaskKilling).
-			SetFailsTask().
+	if err := h.lifecycle.Restart(ctx,
+		structs.NewTaskEvent(structs.TaskRestarting).
 			SetDisplayMessage(fmt.Sprintf("CSI plugin did not become healthy before configured %v health timeout", h.task.CSIPluginConfig.HealthTimeout.String())),
-	); err != nil {
-		h.logger.Error("failed to kill task", "kill_reason", reason, "error", err)
+		true); err != nil {
+		h.logger.Error("failed to restart task", "restart_reason", reason, "error", err)
 	}
 }
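
For readers unfamiliar with the joined-context pattern the patch adopts, the following standalone sketch (not part of the patch) shows the behavior the supervisor loop relies on: a context produced by joincontext.Join is done as soon as either parent context is done, whichever closes first. The Join(ctx1, ctx2) signature matches the call the patch itself makes; the poststartCtx/killTask and shutdownCtx/stopHook names are illustrative stand-ins for the Poststart hook's context and the hook's shutdownCtx, not identifiers from the change.

```go
// Sketch of the joined-context behavior; names are illustrative stand-ins.
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/LK4D4/joincontext"
)

func main() {
	// stands in for the Poststart hook's context (cancelled when the task is killed)
	poststartCtx, killTask := context.WithCancel(context.Background())
	// stands in for the hook's shutdownCtx (cancelled by the Stop hook)
	shutdownCtx, stopHook := context.WithCancel(context.Background())
	defer stopHook()

	// supervisorCtx is done once either parent context is cancelled
	supervisorCtx, cancel := joincontext.Join(poststartCtx, shutdownCtx)
	defer cancel()

	go func() {
		// simulate the task being killed before its Stop hook ever fires
		time.Sleep(100 * time.Millisecond)
		killTask()
	}()

	<-supervisorCtx.Done()
	fmt.Println("supervisor loop would exit now:", supervisorCtx.Err())
}
```

Joining the two contexts is what closes the window described in the commit message: if the task is killed before its stop hooks run, the Poststart context alone is enough to stop the supervisor, and the Stop hook still stops it on a normal shutdown.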