drivers: defer executor cleanup func to fix executor leak (#24495)

This commit is contained in:
Michael Smithhisler
2024-12-02 12:25:32 -05:00
committed by GitHub
parent e963d55ea0
commit 11ae64acb0
4 changed files with 96 additions and 23 deletions

3
.changelog/24495.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:bug
drivers: fix executor leak when drivers error starting tasks
```

View File

@@ -456,7 +456,7 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error {
return nil return nil
} }
func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drivers.DriverNetwork, error) { func (d *Driver) StartTask(cfg *drivers.TaskConfig) (handle *drivers.TaskHandle, network *drivers.DriverNetwork, err error) {
if _, ok := d.tasks.Get(cfg.ID); ok { if _, ok := d.tasks.Get(cfg.ID); ok {
return nil, nil, fmt.Errorf("task with ID %q already started", cfg.ID) return nil, nil, fmt.Errorf("task with ID %q already started", cfg.ID)
} }
@@ -481,7 +481,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
} }
d.logger.Info("starting task", "driver_cfg", hclog.Fmt("%+v", driverConfig)) d.logger.Info("starting task", "driver_cfg", hclog.Fmt("%+v", driverConfig))
handle := drivers.NewTaskHandle(taskHandleVersion) handle = drivers.NewTaskHandle(taskHandleVersion)
handle.Config = cfg handle.Config = cfg
pluginLogFile := filepath.Join(cfg.TaskDir().Dir, "executor.out") pluginLogFile := filepath.Join(cfg.TaskDir().Dir, "executor.out")
@@ -492,13 +492,6 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
Compute: d.compute, Compute: d.compute,
} }
exec, pluginClient, err := executor.CreateExecutor(
d.logger.With("task_name", handle.Config.Name, "alloc_id", handle.Config.AllocID),
d.nomadConfig, executorConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to create executor: %v", err)
}
user := cfg.User user := cfg.User
if cfg.DNS != nil { if cfg.DNS != nil {
dnsMount, err := resolvconf.GenerateDNSMount(cfg.TaskDir().Dir, cfg.DNS) dnsMount, err := resolvconf.GenerateDNSMount(cfg.TaskDir().Dir, cfg.DNS)
@@ -516,6 +509,19 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
} }
d.logger.Debug("task capabilities", "capabilities", caps) d.logger.Debug("task capabilities", "capabilities", caps)
exec, pluginClient, err := executor.CreateExecutor(
d.logger.With("task_name", handle.Config.Name, "alloc_id", handle.Config.AllocID),
d.nomadConfig, executorConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to create executor: %v", err)
}
// prevent leaking executor in error scenarios
defer func() {
if err != nil {
pluginClient.Kill()
}
}()
execCmd := &executor.ExecCommand{ execCmd := &executor.ExecCommand{
Cmd: driverConfig.Command, Cmd: driverConfig.Command,
Args: driverConfig.Args, Args: driverConfig.Args,
@@ -538,7 +544,6 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
ps, err := exec.Launch(execCmd) ps, err := exec.Launch(execCmd)
if err != nil { if err != nil {
pluginClient.Kill()
return nil, nil, fmt.Errorf("failed to launch command with executor: %v", err) return nil, nil, fmt.Errorf("failed to launch command with executor: %v", err)
} }
@@ -562,7 +567,6 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
if err := handle.SetDriverState(&driverState); err != nil { if err := handle.SetDriverState(&driverState); err != nil {
d.logger.Error("failed to start task, error setting driver state", "error", err) d.logger.Error("failed to start task, error setting driver state", "error", err)
_ = exec.Shutdown("", 0) _ = exec.Shutdown("", 0)
pluginClient.Kill()
return nil, nil, fmt.Errorf("failed to set driver state: %v", err) return nil, nil, fmt.Errorf("failed to set driver state: %v", err)
} }

View File

@@ -35,6 +35,7 @@ import (
"github.com/hashicorp/nomad/testutil" "github.com/hashicorp/nomad/testutil"
"github.com/shoenig/test/must" "github.com/shoenig/test/must"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
) )
type mockIDValidator struct{} type mockIDValidator struct{}
@@ -347,9 +348,70 @@ func TestExecDriver_StartWaitRecover(t *testing.T) {
require.NoError(t, harness.DestroyTask(task.ID, true)) require.NoError(t, harness.DestroyTask(task.ID, true))
} }
func TestExecDriver_NoOrphanedExecutor(t *testing.T) {
ci.Parallel(t)
ctestutils.ExecCompatible(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
d := newExecDriverTest(t, ctx)
harness := dtestutil.NewDriverHarness(t, d)
defer harness.Kill()
config := &Config{
NoPivotRoot: false,
DefaultModePID: executor.IsolationModePrivate,
DefaultModeIPC: executor.IsolationModePrivate,
}
var data []byte
must.NoError(t, base.MsgPackEncode(&data, config))
baseConfig := &base.Config{
PluginConfig: data,
AgentConfig: &base.AgentConfig{
Driver: &base.ClientDriverConfig{
Topology: d.(*Driver).nomadConfig.Topology,
},
},
}
must.NoError(t, harness.SetConfig(baseConfig))
allocID := uuid.Generate()
taskName := "test"
task := &drivers.TaskConfig{
AllocID: allocID,
ID: uuid.Generate(),
Name: taskName,
Resources: testResources(allocID, taskName),
}
cleanup := harness.MkAllocDir(task, true)
defer cleanup()
taskConfig := map[string]interface{}{}
taskConfig["command"] = "force-an-error"
must.NoError(t, task.EncodeConcreteDriverConfig(&taskConfig))
_, _, err := harness.StartTask(task)
must.Error(t, err)
defer harness.DestroyTask(task.ID, true)
testPid := unix.Getpid()
tids, err := os.ReadDir(fmt.Sprintf("/proc/%d/task", testPid))
must.NoError(t, err)
for _, tid := range tids {
children, err := os.ReadFile(fmt.Sprintf("/proc/%d/task/%s/children", testPid, tid.Name()))
must.NoError(t, err)
pids := strings.Fields(string(children))
must.Eq(t, 0, len(pids))
}
}
// TestExecDriver_NoOrphans asserts that when the main // TestExecDriver_NoOrphans asserts that when the main
// task dies, the orphans in the PID namespaces are killed by the kernel // task dies, the orphans in the PID namespaces are killed by the kernel
func TestExecDriver_NoOrphans(t *testing.T) { func TestExecDriver_NoOrphanedTasks(t *testing.T) {
ci.Parallel(t) ci.Parallel(t)
ctestutils.ExecCompatible(t) ctestutils.ExecCompatible(t)

View File

@@ -429,7 +429,7 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error {
return nil return nil
} }
func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drivers.DriverNetwork, error) { func (d *Driver) StartTask(cfg *drivers.TaskConfig) (handle *drivers.TaskHandle, network *drivers.DriverNetwork, err error) {
if _, ok := d.tasks.Get(cfg.ID); ok { if _, ok := d.tasks.Get(cfg.ID); ok {
return nil, nil, fmt.Errorf("task with ID %q already started", cfg.ID) return nil, nil, fmt.Errorf("task with ID %q already started", cfg.ID)
} }
@@ -456,7 +456,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
d.logger.Info("starting java task", "driver_cfg", hclog.Fmt("%+v", driverConfig), "args", args) d.logger.Info("starting java task", "driver_cfg", hclog.Fmt("%+v", driverConfig), "args", args)
handle := drivers.NewTaskHandle(taskHandleVersion) handle = drivers.NewTaskHandle(taskHandleVersion)
handle.Config = cfg handle.Config = cfg
pluginLogFile := filepath.Join(cfg.TaskDir().Dir, "executor.out") pluginLogFile := filepath.Join(cfg.TaskDir().Dir, "executor.out")
@@ -467,13 +467,6 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
Compute: d.nomadConfig.Topology.Compute(), Compute: d.nomadConfig.Topology.Compute(),
} }
exec, pluginClient, err := executor.CreateExecutor(
d.logger.With("task_name", handle.Config.Name, "alloc_id", handle.Config.AllocID),
d.nomadConfig, executorConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to create executor: %v", err)
}
user := cfg.User user := cfg.User
if user == "" { if user == "" {
user = "nobody" user = "nobody"
@@ -495,6 +488,19 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
} }
d.logger.Debug("task capabilities", "capabilities", caps) d.logger.Debug("task capabilities", "capabilities", caps)
exec, pluginClient, err := executor.CreateExecutor(
d.logger.With("task_name", handle.Config.Name, "alloc_id", handle.Config.AllocID),
d.nomadConfig, executorConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to create executor: %v", err)
}
// prevent leaking executor in error scenarios
defer func() {
if err != nil {
pluginClient.Kill()
}
}()
execCmd := &executor.ExecCommand{ execCmd := &executor.ExecCommand{
Cmd: absPath, Cmd: absPath,
Args: args, Args: args,
@@ -516,7 +522,6 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
ps, err := exec.Launch(execCmd) ps, err := exec.Launch(execCmd)
if err != nil { if err != nil {
pluginClient.Kill()
return nil, nil, fmt.Errorf("failed to launch command with executor: %v", err) return nil, nil, fmt.Errorf("failed to launch command with executor: %v", err)
} }
@@ -540,7 +545,6 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
if err := handle.SetDriverState(&driverState); err != nil { if err := handle.SetDriverState(&driverState); err != nil {
d.logger.Error("failed to start task, error setting driver state", "error", err) d.logger.Error("failed to start task, error setting driver state", "error", err)
exec.Shutdown("", 0) exec.Shutdown("", 0)
pluginClient.Kill()
return nil, nil, fmt.Errorf("failed to set driver state: %v", err) return nil, nil, fmt.Errorf("failed to set driver state: %v", err)
} }