diff --git a/drivers/exec/driver_test.go b/drivers/exec/driver_test.go index 59a3ac606..02061ef2c 100644 --- a/drivers/exec/driver_test.go +++ b/drivers/exec/driver_test.go @@ -910,10 +910,7 @@ func TestExecDriver_OOMKilled(t *testing.T) { ci.Parallel(t) ctestutils.ExecCompatible(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - d := newExecDriverTest(t, ctx) + d := newExecDriverTest(t, t.Context()) harness := dtestutil.NewDriverHarness(t, d) allocID := uuid.Generate() name := "oom-killed" @@ -923,12 +920,11 @@ func TestExecDriver_OOMKilled(t *testing.T) { Name: name, Resources: testResources(allocID, name), } - task.Resources.LinuxResources.MemoryLimitBytes = 10 * 1024 * 1024 - task.Resources.NomadResources.Memory.MemoryMB = 10 tc := &TaskConfig{ Command: "/bin/tail", Args: []string{"/dev/zero"}, + ModePID: "private", } must.NoError(t, task.EncodeConcreteDriverConfig(&tc)) @@ -938,7 +934,7 @@ func TestExecDriver_OOMKilled(t *testing.T) { handle, _, err := harness.StartTask(task) must.NoError(t, err) - ch, err := harness.WaitTask(context.Background(), handle.Config.ID) + ch, err := harness.WaitTask(t.Context(), handle.Config.ID) must.NoError(t, err) result := <-ch must.False(t, result.Successful(), must.Sprint("container should OOM")) diff --git a/drivers/shared/executor/executor_linux.go b/drivers/shared/executor/executor_linux.go index 1e90aa8cf..662d51571 100644 --- a/drivers/shared/executor/executor_linux.go +++ b/drivers/shared/executor/executor_linux.go @@ -276,19 +276,17 @@ func (l *LibcontainerExecutor) wait() { // Best effort detection of OOMs. It's possible for us to miss OOM notifications in // the event that the wait returns before we read from the OOM notification channel var oomKilled atomic.Bool - go func() { - oomCh, err := l.container.NotifyOOM() - if err != nil { - l.logger.Error("failed to get OOM notification channel for container(%s): %v", l.id, err) - return - } - - for range oomCh { - oomKilled.Store(true) - // We can terminate this goroutine as soon as we've seen the first OOM - return - } - }() + oomCh, err := l.container.NotifyOOM() + if err != nil { + l.logger.Error("failed to get OOM notification channel for container(%s): %v", l.id, err) + } else { + go func() { + for range oomCh { + oomKilled.Store(true) + return // Exit goroutine on first OOM + } + }() + } ps, err := l.userProc.Wait() if err != nil { diff --git a/drivers/shared/executor/executor_linux_test.go b/drivers/shared/executor/executor_linux_test.go index a203975e2..4899894c8 100644 --- a/drivers/shared/executor/executor_linux_test.go +++ b/drivers/shared/executor/executor_linux_test.go @@ -287,8 +287,6 @@ func TestExecutor_OOMKilled(t *testing.T) { execCmd.ResourceLimits = true execCmd.ModePID = "private" execCmd.ModeIPC = "private" - execCmd.Resources.LinuxResources.MemoryLimitBytes = 10 * 1024 * 1024 - execCmd.Resources.NomadResources.Memory.MemoryMB = 10 executor := NewExecutorWithIsolation(testlog.HCLogger(t), compute) defer executor.Shutdown("SIGKILL", 0)