CSI: restart task on failing initial probe, instead of killing it (#25307)

When a CSI plugin is launched, we probe it until the csi_plugin.health_timeout
expires (30s by default). But if the plugin never becomes healthy within that
window, we don't restart the task, even though the documentation says we will.

Update the plugin supervisor to trigger a restart instead. We still exit the
supervisor loop at that point to avoid having the supervisor send probes to a
task that isn't running yet. This requires reworking the poststart hook to allow
the supervisor loop to be restarted when the task restarts.
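
As a rough illustration of that rework, here is a minimal, standalone sketch of
the pattern: a mutex-guarded "running" flag instead of sync.Once, so a later
poststart call (after a restart) can start the loop again once the previous one
has exited. The supervisor type below is an illustrative stand-in, not the real
csiPluginSupervisorHook.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// supervisor is a pared-down stand-in for csiPluginSupervisorHook: a
// mutex-guarded flag (rather than sync.Once) lets Poststart start the loop
// again after a task restart, once the previous loop has exited.
type supervisor struct {
	mu      sync.Mutex
	running bool
}

func (s *supervisor) Poststart(ctx context.Context) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.running {
		return // supervisor loop already running; nothing to do
	}
	s.running = true
	go s.loop(ctx)
}

func (s *supervisor) loop(ctx context.Context) {
	defer func() {
		// clear the flag so a later Poststart (after a task restart)
		// can start a fresh loop
		s.mu.Lock()
		s.running = false
		s.mu.Unlock()
	}()
	<-ctx.Done() // stand-in for the probe/fingerprint loop
}

func main() {
	s := &supervisor{}

	taskCtx, killTask := context.WithCancel(context.Background())
	s.Poststart(taskCtx) // first start: loop begins
	s.Poststart(taskCtx) // no-op: loop already running

	killTask()                        // task dies; loop exits and clears the flag
	time.Sleep(50 * time.Millisecond) // give the loop a moment to wind down

	s.Poststart(context.Background()) // task restarted: a fresh loop starts
	fmt.Println("supervisor loop restarted")
}
```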

In doing so, I identified that we weren't respecting the task kill context from
the poststart hook, which would leave the supervisor running in the window
between when a task is killed because it failed and when its stop hooks are
triggered. Combine the two contexts to make sure we stop the supervisor
whichever one gets closed first.
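
The diff below pulls in github.com/LK4D4/joincontext for that. For illustration
only, here is a stdlib-only sketch of the same "whichever closes first"
behavior; the joinContexts helper is hypothetical and not part of this change.

```go
package main

import (
	"context"
	"fmt"
)

// joinContexts returns a context that is cancelled as soon as either parent
// is done: the supervisor then stops whether the task was killed (Poststart
// context) or the Stop hook fired (shutdown context). This approximates what
// the joincontext library is used for in the diff below.
func joinContexts(a, b context.Context) (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithCancel(a) // already follows a's cancellation
	go func() {
		select {
		case <-b.Done():
			cancel() // propagate b's cancellation too
		case <-ctx.Done():
			// a was cancelled (or cancel was called); stop watching b
		}
	}()
	return ctx, cancel
}

func main() {
	poststartCtx, killTask := context.WithCancel(context.Background())
	shutdownCtx, stopHook := context.WithCancel(context.Background())
	defer stopHook()

	supervisorCtx, cancel := joinContexts(poststartCtx, shutdownCtx)
	defer cancel()

	killTask() // closing either parent closes the joined context
	<-supervisorCtx.Done()
	fmt.Println("supervisor stops:", supervisorCtx.Err())
}
```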

Fixes: https://github.com/hashicorp/nomad/issues/25293
Ref: https://hashicorp.atlassian.net/browse/NET-12264
Tim Gross
2025-03-07 10:04:59 -05:00
committed by GitHub
parent 768ba78e2d
commit f3d53e3e2b
2 changed files with 55 additions and 32 deletions

.changelog/25307.txt (new file)

@@ -0,0 +1,3 @@
+```release-note:bug
+csi: Fixed a bug where plugins that failed initial fingerprints would not be restarted
+```

client/allocrunner/taskrunner/plugin_supervisor_hook.go

@@ -11,6 +11,7 @@ import (
"sync"
"time"
"github.com/LK4D4/joincontext"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
@@ -45,9 +46,10 @@ type csiPluginSupervisorHook struct {
 	eventEmitter ti.EventEmitter
 	lifecycle    ti.TaskLifecycle
 
-	shutdownCtx      context.Context
-	shutdownCancelFn context.CancelFunc
-	runOnce          sync.Once
+	supervisorIsRunningLock sync.Mutex
+	supervisorIsRunning     bool
+	shutdownCtx             context.Context
+	shutdownCancelFn        context.CancelFunc
 
 	// previousHealthstate is used by the supervisor goroutine to track historic
 	// health states for gating task events.
@@ -120,6 +122,7 @@ func newCSIPluginSupervisorHook(config *csiPluginSupervisorHookConfig) *csiPlugi
 		task.CSIPluginConfig.HealthTimeout = 30 * time.Second
 	}
 
+	// this context will be closed only by csiPluginSupervisorHookConfig.Stop
 	shutdownCtx, cancelFn := context.WithCancel(context.Background())
 
 	hook := &csiPluginSupervisorHook{
@@ -154,11 +157,11 @@ func (h *csiPluginSupervisorHook) Prestart(ctx context.Context,
 	// Create the mount directory that the container will access if it doesn't
 	// already exist. Default to only nomad user access.
 	if err := os.MkdirAll(h.mountPoint, 0700); err != nil && !os.IsExist(err) {
-		return fmt.Errorf("failed to create mount point: %v", err)
+		return fmt.Errorf("failed to create mount point: %w", err)
 	}
 
 	if err := os.MkdirAll(h.socketMountPoint, 0700); err != nil && !os.IsExist(err) {
-		return fmt.Errorf("failed to create socket mount point: %v", err)
+		return fmt.Errorf("failed to create socket mount point: %w", err)
 	}
 
 	// where the socket will be mounted
@@ -230,19 +233,23 @@ func (h *csiPluginSupervisorHook) setSocketHook() {
 	h.socketPath = filepath.Join(h.socketMountPoint, structs.CSISocketName)
 }
 
-// Poststart is called after the task has started. Poststart is not
-// called if the allocation is terminal.
+// Poststart is called after the task has started (or restarted). Poststart is
+// not called if the allocation is terminal.
 //
 // The context is cancelled if the task is killed.
-func (h *csiPluginSupervisorHook) Poststart(_ context.Context, _ *interfaces.TaskPoststartRequest, _ *interfaces.TaskPoststartResponse) error {
-	// If we're already running the supervisor routine, then we don't need to try
-	// and restart it here as it only terminates on `Stop` hooks.
-	h.runOnce.Do(func() {
-		h.setSocketHook()
-		go h.ensureSupervisorLoop(h.shutdownCtx)
-	})
+func (h *csiPluginSupervisorHook) Poststart(ctx context.Context, _ *interfaces.TaskPoststartRequest, _ *interfaces.TaskPoststartResponse) error {
+	// If we're already running the supervisor routine, then we don't need to
+	// try and restart it here as it only terminates on `Stop` hooks and health
+	// timeouts (which restart the task)
+	h.supervisorIsRunningLock.Lock()
+	defer h.supervisorIsRunningLock.Unlock()
+	if h.supervisorIsRunning {
+		return nil
+	}
+
+	h.setSocketHook()
+	go h.ensureSupervisorLoop(ctx)
 	return nil
 }
@@ -263,15 +270,27 @@ func (h *csiPluginSupervisorHook) ensureSupervisorLoop(ctx context.Context) {
 	client := csi.NewClient(h.socketPath, h.logger.Named("csi_client").With(
 		"plugin.name", h.task.CSIPluginConfig.ID,
 		"plugin.type", h.task.CSIPluginConfig.Type))
-	defer client.Close()
+
+	// this context joins the context we get from the Poststart hook (closed by
+	// task failure) and the one closed by the Stop hook (triggered by task
+	// stop)
+	supervisorCtx, supervisorCtxCancel := joincontext.Join(ctx, h.shutdownCtx)
+
+	// this context is used for the health timeout. If we can't connect within
+	// this deadline, assume the plugin is broken so we can restart the task
+	startCtx, startCancelFn := context.WithTimeout(ctx, h.task.CSIPluginConfig.HealthTimeout)
+
+	defer func() {
+		h.supervisorIsRunningLock.Lock()
+		h.supervisorIsRunning = false
+		client.Close()
+		supervisorCtxCancel()
+		startCancelFn()
+		h.supervisorIsRunningLock.Unlock()
+	}()
 
 	t := time.NewTimer(0)
 
-	// We're in Poststart at this point, so if we can't connect within
-	// this deadline, assume it's broken so we can restart the task
-	startCtx, startCancelFn := context.WithTimeout(ctx, h.task.CSIPluginConfig.HealthTimeout)
-	defer startCancelFn()
 
 	var err error
 	var pluginHealthy bool
@@ -280,7 +299,9 @@ WAITFORREADY:
 	for {
 		select {
 		case <-startCtx.Done():
-			h.kill(ctx, fmt.Errorf("CSI plugin failed probe: %v", err))
+			h.restartTask(ctx, fmt.Errorf("CSI plugin failed probe: %w", err))
+			return
+		case <-supervisorCtx.Done():
 			return
 		case <-t.C:
 			pluginHealthy, err = h.supervisorLoopOnce(startCtx, client)
@@ -306,7 +327,7 @@ WAITFORREADY:
 	// Step 2: Register the plugin with the catalog.
 	deregisterPluginFn, err := h.registerPlugin(client, h.socketPath)
 	if err != nil {
-		h.kill(ctx, fmt.Errorf("CSI plugin failed to register: %v", err))
+		h.restartTask(ctx, fmt.Errorf("CSI plugin failed to register: %w", err))
 		return
 	}
 	// De-register plugins on task shutdown
@@ -317,10 +338,10 @@ WAITFORREADY:
 	t.Reset(0)
 	for {
 		select {
-		case <-ctx.Done():
+		case <-supervisorCtx.Done():
 			return
 		case <-t.C:
-			pluginHealthy, err := h.supervisorLoopOnce(ctx, client)
+			pluginHealthy, err := h.supervisorLoopOnce(supervisorCtx, client)
 			if err != nil {
 				h.logger.Error("CSI plugin fingerprinting failed", "error", err)
 			}
@@ -357,7 +378,7 @@ func (h *csiPluginSupervisorHook) registerPlugin(client csi.CSIPlugin, socketPat
 	// to get its vendor name and version
 	info, err := client.PluginInfo()
 	if err != nil {
-		return nil, fmt.Errorf("failed to probe plugin: %v", err)
+		return nil, fmt.Errorf("failed to probe plugin: %w", err)
 	}
 
 	mkInfoFn := func(pluginType string) *dynamicplugins.PluginInfo {
@@ -448,18 +469,17 @@ func (h *csiPluginSupervisorHook) Stop(_ context.Context, req *interfaces.TaskSt
 	return nil
 }
 
-func (h *csiPluginSupervisorHook) kill(ctx context.Context, reason error) {
-	h.logger.Error("killing task because plugin failed", "error", reason)
+func (h *csiPluginSupervisorHook) restartTask(ctx context.Context, reason error) {
+	h.logger.Error("restarting task because plugin failed", "error", reason)
 	event := structs.NewTaskEvent(structs.TaskPluginUnhealthy)
 	event.SetMessage(fmt.Sprintf("Error: %v", reason.Error()))
 	h.eventEmitter.EmitEvent(event)
 
-	if err := h.lifecycle.Kill(ctx,
-		structs.NewTaskEvent(structs.TaskKilling).
-			SetFailsTask().
+	if err := h.lifecycle.Restart(ctx,
+		structs.NewTaskEvent(structs.TaskRestarting).
 			SetDisplayMessage(fmt.Sprintf("CSI plugin did not become healthy before configured %v health timeout", h.task.CSIPluginConfig.HealthTimeout.String())),
-	); err != nil {
-		h.logger.Error("failed to kill task", "kill_reason", reason, "error", err)
+		true); err != nil {
+		h.logger.Error("failed to restart task", "restart_reason", reason, "error", err)
 	}
 }