From 6b8ddff1fa36e6bd034e170f2576e9da872678e3 Mon Sep 17 00:00:00 2001 From: Tim Gross Date: Wed, 16 Oct 2024 09:20:26 -0400 Subject: [PATCH] windows: set job object for executor and children (#24214) On Windows, if the `raw_exec` driver's executor exits, the child processes are not also killed. Create a Windows "job object" (not to be confused with a Nomad job) and add the executor to it. Child processes of the executor will inherit the job automatically. When the handle to the job object is freed (on executor exit), the job itself is destroyed and this causes all processes in that job to exit. Fixes: https://github.com/hashicorp/nomad/issues/23668 Ref: https://learn.microsoft.com/en-us/windows/win32/procthread/job-objects --- .changelog/24214.txt | 3 + .github/workflows/test-windows.yml | 2 + drivers/rawexec/driver_test.go | 98 ------------------ drivers/rawexec/driver_unix_test.go | 99 +++++++++++++++++++ drivers/rawexec/driver_windows_test.go | 96 ++++++++++++++++++ drivers/shared/executor/executor_test.go | 45 +-------- drivers/shared/executor/executor_windows.go | 33 ++++++- .../shared/executor/executor_windows_test.go | 88 +++++++++++++++++ drivers/shared/executor/utils_test.go | 47 +++++++++ 9 files changed, 369 insertions(+), 142 deletions(-) create mode 100644 .changelog/24214.txt create mode 100644 drivers/rawexec/driver_windows_test.go create mode 100644 drivers/shared/executor/executor_windows_test.go diff --git a/.changelog/24214.txt b/.changelog/24214.txt new file mode 100644 index 000000000..d0e59532d --- /dev/null +++ b/.changelog/24214.txt @@ -0,0 +1,3 @@ +```release-note:bug +windows: Fixed a bug where a crashed executor would orphan task processes +``` diff --git a/.github/workflows/test-windows.yml b/.github/workflows/test-windows.yml index dc5d961a5..6a8b3536c 100644 --- a/.github/workflows/test-windows.yml +++ b/.github/workflows/test-windows.yml @@ -87,6 +87,8 @@ jobs: gotestsum --format=short-verbose \ --junitfile results.xml \ github.com/hashicorp/nomad/drivers/docker \ + github.com/hashicorp/nomad/drivers/rawexec \ + github.com/hashicorp/nomad/drivers/shared/executor \ github.com/hashicorp/nomad/client/lib/fifo \ github.com/hashicorp/nomad/client/logmon \ github.com/hashicorp/nomad/client/allocrunner/taskrunner/template \ diff --git a/drivers/rawexec/driver_test.go b/drivers/rawexec/driver_test.go index 35a60fc2b..df360f5eb 100644 --- a/drivers/rawexec/driver_test.go +++ b/drivers/rawexec/driver_test.go @@ -12,7 +12,6 @@ import ( "path/filepath" "runtime" "strconv" - "sync" "syscall" "testing" "time" @@ -237,103 +236,6 @@ func TestRawExecDriver_StartWait(t *testing.T) { require.NoError(harness.DestroyTask(task.ID, true)) } -func TestRawExecDriver_StartWaitRecoverWaitStop(t *testing.T) { - ci.Parallel(t) - require := require.New(t) - - d := newEnabledRawExecDriver(t) - harness := dtestutil.NewDriverHarness(t, d) - defer harness.Kill() - - config := &Config{Enabled: true} - var data []byte - require.NoError(basePlug.MsgPackEncode(&data, config)) - bconfig := &basePlug.Config{ - PluginConfig: data, - AgentConfig: &base.AgentConfig{ - Driver: &base.ClientDriverConfig{ - Topology: d.nomadConfig.Topology, - }, - }, - } - require.NoError(harness.SetConfig(bconfig)) - - allocID := uuid.Generate() - taskName := "sleep" - task := &drivers.TaskConfig{ - AllocID: allocID, - ID: uuid.Generate(), - Name: taskName, - Env: defaultEnv(), - Resources: testResources(allocID, taskName), - } - tc := &TaskConfig{ - Command: testtask.Path(), - Args: []string{"sleep", "100s"}, - } - require.NoError(task.EncodeConcreteDriverConfig(&tc)) - - testtask.SetTaskConfigEnv(task) - - cleanup := harness.MkAllocDir(task, false) - defer cleanup() - - harness.MakeTaskCgroup(allocID, taskName) - - handle, _, err := harness.StartTask(task) - require.NoError(err) - - ch, err := harness.WaitTask(context.Background(), task.ID) - require.NoError(err) - - var waitDone bool - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - result := <-ch - require.Error(result.Err) - waitDone = true - }() - - originalStatus, err := d.InspectTask(task.ID) - require.NoError(err) - - d.tasks.Delete(task.ID) - - wg.Wait() - require.True(waitDone) - _, err = d.InspectTask(task.ID) - require.Equal(drivers.ErrTaskNotFound, err) - - err = d.RecoverTask(handle) - require.NoError(err) - - status, err := d.InspectTask(task.ID) - require.NoError(err) - require.Exactly(originalStatus, status) - - ch, err = harness.WaitTask(context.Background(), task.ID) - require.NoError(err) - - wg.Add(1) - waitDone = false - go func() { - defer wg.Done() - result := <-ch - require.NoError(result.Err) - require.NotZero(result.ExitCode) - require.Equal(9, result.Signal) - waitDone = true - }() - - time.Sleep(300 * time.Millisecond) - require.NoError(d.StopTask(task.ID, 0, "SIGKILL")) - wg.Wait() - require.NoError(d.DestroyTask(task.ID, false)) - require.True(waitDone) -} - func TestRawExecDriver_Start_Wait_AllocDir(t *testing.T) { ci.Parallel(t) require := require.New(t) diff --git a/drivers/rawexec/driver_unix_test.go b/drivers/rawexec/driver_unix_test.go index 4a620856a..c09e3e0eb 100644 --- a/drivers/rawexec/driver_unix_test.go +++ b/drivers/rawexec/driver_unix_test.go @@ -14,6 +14,7 @@ import ( "runtime" "strconv" "strings" + "sync" "syscall" "testing" "time" @@ -23,6 +24,7 @@ import ( "github.com/hashicorp/nomad/helper/testtask" "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/plugins/base" + basePlug "github.com/hashicorp/nomad/plugins/base" "github.com/hashicorp/nomad/plugins/drivers" dtestutil "github.com/hashicorp/nomad/plugins/drivers/testutils" "github.com/hashicorp/nomad/testutil" @@ -443,3 +445,100 @@ func TestRawExec_ExecTaskStreaming_User(t *testing.T) { require.Empty(t, stderr) require.Contains(t, stdout, "nobody") } + +func TestRawExecDriver_StartWaitRecoverWaitStop(t *testing.T) { + ci.Parallel(t) + require := require.New(t) + + d := newEnabledRawExecDriver(t) + harness := dtestutil.NewDriverHarness(t, d) + defer harness.Kill() + + config := &Config{Enabled: true} + var data []byte + require.NoError(basePlug.MsgPackEncode(&data, config)) + bconfig := &basePlug.Config{ + PluginConfig: data, + AgentConfig: &base.AgentConfig{ + Driver: &base.ClientDriverConfig{ + Topology: d.nomadConfig.Topology, + }, + }, + } + require.NoError(harness.SetConfig(bconfig)) + + allocID := uuid.Generate() + taskName := "sleep" + task := &drivers.TaskConfig{ + AllocID: allocID, + ID: uuid.Generate(), + Name: taskName, + Env: defaultEnv(), + Resources: testResources(allocID, taskName), + } + tc := &TaskConfig{ + Command: testtask.Path(), + Args: []string{"sleep", "100s"}, + } + require.NoError(task.EncodeConcreteDriverConfig(&tc)) + + testtask.SetTaskConfigEnv(task) + + cleanup := harness.MkAllocDir(task, false) + defer cleanup() + + harness.MakeTaskCgroup(allocID, taskName) + + handle, _, err := harness.StartTask(task) + require.NoError(err) + + ch, err := harness.WaitTask(context.Background(), task.ID) + require.NoError(err) + + var waitDone bool + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + result := <-ch + require.Error(result.Err) + waitDone = true + }() + + originalStatus, err := d.InspectTask(task.ID) + require.NoError(err) + + d.tasks.Delete(task.ID) + + wg.Wait() + require.True(waitDone) + _, err = d.InspectTask(task.ID) + require.Equal(drivers.ErrTaskNotFound, err) + + err = d.RecoverTask(handle) + require.NoError(err) + + status, err := d.InspectTask(task.ID) + require.NoError(err) + require.Exactly(originalStatus, status) + + ch, err = harness.WaitTask(context.Background(), task.ID) + require.NoError(err) + + wg.Add(1) + waitDone = false + go func() { + defer wg.Done() + result := <-ch + require.NoError(result.Err) + require.NotZero(result.ExitCode) + require.Equal(9, result.Signal) + waitDone = true + }() + + time.Sleep(300 * time.Millisecond) + require.NoError(d.StopTask(task.ID, 0, "SIGKILL")) + wg.Wait() + require.NoError(d.DestroyTask(task.ID, false)) + require.True(waitDone) +} diff --git a/drivers/rawexec/driver_windows_test.go b/drivers/rawexec/driver_windows_test.go new file mode 100644 index 000000000..68876b037 --- /dev/null +++ b/drivers/rawexec/driver_windows_test.go @@ -0,0 +1,96 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build windows + +package rawexec + +import ( + "os" + "testing" + "time" + + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/plugins/base" + "github.com/hashicorp/nomad/plugins/drivers" + dtestutil "github.com/hashicorp/nomad/plugins/drivers/testutils" + "github.com/shoenig/test/must" +) + +// TestRawExecDriver_ExecutorKill verifies that killing the executor will stop +// its child processes +func TestRawExecDriver_ExecutorKill(t *testing.T) { + ci.Parallel(t) + + d := newEnabledRawExecDriver(t) + harness := dtestutil.NewDriverHarness(t, d) + t.Cleanup(harness.Kill) + + config := &Config{Enabled: true} + var data []byte + must.NoError(t, base.MsgPackEncode(&data, config)) + bconfig := &base.Config{ + PluginConfig: data, + AgentConfig: &base.AgentConfig{ + Driver: &base.ClientDriverConfig{ + Topology: d.nomadConfig.Topology, + }, + }, + } + must.NoError(t, harness.SetConfig(bconfig)) + + allocID := uuid.Generate() + taskName := "test" + task := &drivers.TaskConfig{ + AllocID: allocID, + ID: uuid.Generate(), + Name: taskName, + Resources: testResources(allocID, taskName), + } + + taskConfig := map[string]interface{}{} + taskConfig["command"] = "Powershell.exe" + taskConfig["args"] = []string{"sleep", "100s"} + + must.NoError(t, task.EncodeConcreteDriverConfig(&taskConfig)) + + cleanup := harness.MkAllocDir(task, false) + t.Cleanup(cleanup) + + handle, _, err := harness.StartTask(task) + must.NoError(t, err) + + var taskState TaskState + must.NoError(t, handle.GetDriverState(&taskState)) + must.NoError(t, harness.WaitUntilStarted(task.ID, 1*time.Second)) + + // forcibly kill the executor, not the workload + must.NotEq(t, taskState.ReattachConfig.Pid, taskState.Pid) + proc, err := os.FindProcess(taskState.ReattachConfig.Pid) + must.NoError(t, err) + + taskProc, err := os.FindProcess(taskState.Pid) + must.NoError(t, err) + + must.NoError(t, proc.Kill()) + t.Logf("killed %d, waiting on %d to stop", taskState.ReattachConfig.Pid, taskState.Pid) + + t.Cleanup(func() { + if taskProc != nil { + taskProc.Kill() + } + }) + + done := make(chan struct{}) + go func() { + taskProc.Wait() + close(done) + }() + + select { + case <-time.After(5 * time.Second): + t.Fatal("expected child process to exit") + case <-done: + } +} diff --git a/drivers/shared/executor/executor_test.go b/drivers/shared/executor/executor_test.go index 50e415d66..a0e17e666 100644 --- a/drivers/shared/executor/executor_test.go +++ b/drivers/shared/executor/executor_test.go @@ -1,10 +1,11 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 +//go:build !windows + package executor import ( - "bytes" "context" "fmt" "io" @@ -12,7 +13,6 @@ import ( "path/filepath" "runtime" "strings" - "sync" "syscall" "testing" "time" @@ -59,15 +59,6 @@ var ( compute = topology.Compute() ) -type testExecCmd struct { - command *ExecCommand - allocDir *allocdir.AllocDir - - stdout *bytes.Buffer - stderr *bytes.Buffer - outputCopyDone *sync.WaitGroup -} - // testExecutorContext returns an ExecutorContext and AllocDir. // // The caller is responsible for calling AllocDir.Destroy() to cleanup. @@ -123,38 +114,6 @@ func testExecutorCommand(t *testing.T) *testExecCmd { return testCmd } -// configureTLogging configures a test command executor with buffer as Std{out|err} -// but using os.Pipe so it mimics non-test case where cmd is set with files as Std{out|err} -// the buffers can be used to read command output -func configureTLogging(t *testing.T, testcmd *testExecCmd) { - var stdout, stderr bytes.Buffer - var copyDone sync.WaitGroup - - stdoutPr, stdoutPw, err := os.Pipe() - require.NoError(t, err) - - stderrPr, stderrPw, err := os.Pipe() - require.NoError(t, err) - - copyDone.Add(2) - go func() { - defer copyDone.Done() - io.Copy(&stdout, stdoutPr) - }() - go func() { - defer copyDone.Done() - io.Copy(&stderr, stderrPr) - }() - - testcmd.stdout = &stdout - testcmd.stderr = &stderr - testcmd.outputCopyDone = ©Done - - testcmd.command.stdout = stdoutPw - testcmd.command.stderr = stderrPw - return -} - func TestExecutor_Start_Invalid(t *testing.T) { ci.Parallel(t) invalid := "/bin/foobar" diff --git a/drivers/shared/executor/executor_windows.go b/drivers/shared/executor/executor_windows.go index 457f29a6e..25134ece5 100644 --- a/drivers/shared/executor/executor_windows.go +++ b/drivers/shared/executor/executor_windows.go @@ -9,17 +9,48 @@ import ( "fmt" "os" "syscall" + "unsafe" "golang.org/x/sys/windows" ) -// configure new process group for child process +// configure new process group for child process and creates a JobObject for the +// executor. Children of the executor will be created in the same JobObject +// Ref: https://learn.microsoft.com/en-us/windows/win32/procthread/job-objects func (e *UniversalExecutor) setNewProcessGroup() error { // We need to check that as build flags includes windows for this file if e.childCmd.SysProcAttr == nil { e.childCmd.SysProcAttr = &syscall.SysProcAttr{} } e.childCmd.SysProcAttr.CreationFlags = syscall.CREATE_NEW_PROCESS_GROUP + + // note: we don't call CloseHandle on this job handle because we need to + // hold onto it until the executor exits + job, err := windows.CreateJobObject(nil, nil) + if err != nil { + return fmt.Errorf("could not create Windows job object for executor: %w", err) + } + + info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{ + BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{ + LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, + }, + } + _, err = windows.SetInformationJobObject( + job, + windows.JobObjectExtendedLimitInformation, + uintptr(unsafe.Pointer(&info)), + uint32(unsafe.Sizeof(info))) + if err != nil { + return fmt.Errorf("could not configure Windows job object for executor: %w", err) + } + + handle := windows.CurrentProcess() + err = windows.AssignProcessToJobObject(job, handle) + if err != nil { + return fmt.Errorf("could not assign executor to Windows job object: %w", err) + } + return nil } diff --git a/drivers/shared/executor/executor_windows_test.go b/drivers/shared/executor/executor_windows_test.go new file mode 100644 index 000000000..c54cd4972 --- /dev/null +++ b/drivers/shared/executor/executor_windows_test.go @@ -0,0 +1,88 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +//go:build windows + +package executor + +import ( + "context" + "os" + "testing" + "time" + + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/client/allocdir" + "github.com/hashicorp/nomad/client/lib/numalib" + "github.com/hashicorp/nomad/client/taskenv" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/drivers" + "github.com/hashicorp/nomad/plugins/drivers/fsisolation" + "github.com/shoenig/test/must" +) + +// testExecutorCommand sets up a test task environment. +func testExecutorCommand(t *testing.T) *testExecCmd { + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + taskEnv := taskenv.NewBuilder(mock.Node(), alloc, task, "global").Build() + + allocDir := allocdir.NewAllocDir(testlog.HCLogger(t), t.TempDir(), t.TempDir(), alloc.ID) + must.NoError(t, allocDir.Build()) + t.Cleanup(func() { allocDir.Destroy() }) + + must.NoError(t, allocDir.NewTaskDir(task).Build(fsisolation.None, nil, task.User)) + td := allocDir.TaskDirs[task.Name] + cmd := &ExecCommand{ + Env: taskEnv.List(), + TaskDir: td.Dir, + Resources: &drivers.Resources{ + NomadResources: &structs.AllocatedTaskResources{ + Cpu: structs.AllocatedCpuResources{ + CpuShares: 500, + }, + Memory: structs.AllocatedMemoryResources{ + MemoryMB: 256, + }, + }, + }, + } + + testCmd := &testExecCmd{ + command: cmd, + allocDir: allocDir, + } + configureTLogging(t, testCmd) + return testCmd +} + +func TestExecutor_ProcessExit(t *testing.T) { + ci.Parallel(t) + + topology := numalib.Scan(numalib.PlatformScanners()) + compute := topology.Compute() + + cmd := testExecutorCommand(t) + cmd.command.Cmd = "Powershell.exe" + cmd.command.Args = []string{"sleep", "30"} + executor := NewExecutor(testlog.HCLogger(t), compute) + + t.Cleanup(func() { executor.Shutdown("SIGKILL", 0) }) + + childPs, err := executor.Launch(cmd.command) + must.NoError(t, err) + must.NonZero(t, childPs.Pid) + + proc, err := os.FindProcess(childPs.Pid) + must.NoError(t, err) + must.NoError(t, proc.Kill()) + + ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) + t.Cleanup(cancel) + waitPs, err := executor.Wait(ctx) + must.NoError(t, err) + must.Eq(t, 1, waitPs.ExitCode) + must.Eq(t, childPs.Pid, waitPs.Pid) +} diff --git a/drivers/shared/executor/utils_test.go b/drivers/shared/executor/utils_test.go index 24a0598d0..b58a6854e 100644 --- a/drivers/shared/executor/utils_test.go +++ b/drivers/shared/executor/utils_test.go @@ -4,8 +4,13 @@ package executor import ( + "bytes" + "io" + "os" + "sync" "testing" + "github.com/hashicorp/nomad/client/allocdir" "github.com/stretchr/testify/require" ) @@ -29,3 +34,45 @@ func TestUtils_IsolationMode(t *testing.T) { require.Equal(t, tc.exp, result) } } + +type testExecCmd struct { + command *ExecCommand + allocDir *allocdir.AllocDir + + stdout *bytes.Buffer + stderr *bytes.Buffer + outputCopyDone *sync.WaitGroup +} + +// configureTLogging configures a test command executor with buffer as +// Std{out|err} but using os.Pipe so it mimics non-test case where cmd is set +// with files as Std{out|err} the buffers can be used to read command output +func configureTLogging(t *testing.T, testcmd *testExecCmd) { + t.Helper() + var stdout, stderr bytes.Buffer + var copyDone sync.WaitGroup + + stdoutPr, stdoutPw, err := os.Pipe() + require.NoError(t, err) + + stderrPr, stderrPw, err := os.Pipe() + require.NoError(t, err) + + copyDone.Add(2) + go func() { + defer copyDone.Done() + io.Copy(&stdout, stdoutPr) + }() + go func() { + defer copyDone.Done() + io.Copy(&stderr, stderrPr) + }() + + testcmd.stdout = &stdout + testcmd.stderr = &stderr + testcmd.outputCopyDone = ©Done + + testcmd.command.stdout = stdoutPw + testcmd.command.stderr = stderrPw + return +}