diff --git a/drivers/exec/driver.go b/drivers/exec/driver.go index 391e44fda..eb4715293 100644 --- a/drivers/exec/driver.go +++ b/drivers/exec/driver.go @@ -530,3 +530,22 @@ func (d *Driver) ExecTask(taskID string, cmd []string, timeout time.Duration) (* }, }, nil } + +var _ drivers.ExecTaskStreamingRawDriver = (*Driver)(nil) + +func (d *Driver) ExecTaskStreamingRaw(ctx context.Context, + taskID string, + command []string, + tty bool, + stream drivers.ExecTaskStream) error { + + if len(command) == 0 { + return fmt.Errorf("error cmd must have atleast one value") + } + handle, ok := d.tasks.Get(taskID) + if !ok { + return drivers.ErrTaskNotFound + } + + return handle.exec.ExecStreaming(ctx, command, tty, stream) +} diff --git a/drivers/exec/driver_unix_test.go b/drivers/exec/driver_unix_test.go index 3d2844ba7..342993a8f 100644 --- a/drivers/exec/driver_unix_test.go +++ b/drivers/exec/driver_unix_test.go @@ -77,3 +77,33 @@ func TestExecDriver_StartWaitStop(t *testing.T) { require.NoError(harness.DestroyTask(task.ID, true)) } + +func TestExec_ExecTaskStreaming(t *testing.T) { + t.Parallel() + require := require.New(t) + + d := NewExecDriver(testlog.HCLogger(t)) + harness := dtestutil.NewDriverHarness(t, d) + defer harness.Kill() + + task := &drivers.TaskConfig{ + ID: uuid.Generate(), + Name: "sleep", + } + + cleanup := harness.MkAllocDir(task, false) + defer cleanup() + + tc := &TaskConfig{ + Command: "/bin/sleep", + Args: []string{"9000"}, + } + require.NoError(task.EncodeConcreteDriverConfig(&tc)) + + _, _, err := harness.StartTask(task) + require.NoError(err) + defer d.DestroyTask(task.ID, true) + + dtestutil.ExecTaskStreamingConformanceTests(t, harness, task.ID) + +} diff --git a/drivers/java/driver.go b/drivers/java/driver.go index 3a4a808d9..554875e9d 100644 --- a/drivers/java/driver.go +++ b/drivers/java/driver.go @@ -554,6 +554,25 @@ func (d *Driver) ExecTask(taskID string, cmd []string, timeout time.Duration) (* }, nil } +var _ drivers.ExecTaskStreamingRawDriver = (*Driver)(nil) + +func (d *Driver) ExecTaskStreamingRaw(ctx context.Context, + taskID string, + command []string, + tty bool, + stream drivers.ExecTaskStream) error { + + if len(command) == 0 { + return fmt.Errorf("error cmd must have atleast one value") + } + handle, ok := d.tasks.Get(taskID) + if !ok { + return drivers.ErrTaskNotFound + } + + return handle.exec.ExecStreaming(ctx, command, tty, stream) +} + // GetAbsolutePath returns the absolute path of the passed binary by resolving // it in the path and following symlinks. func GetAbsolutePath(bin string) (string, error) { diff --git a/drivers/java/driver_test.go b/drivers/java/driver_test.go index 4d431e29f..b4b3d0010 100644 --- a/drivers/java/driver_test.go +++ b/drivers/java/driver_test.go @@ -243,6 +243,35 @@ func TestJavaCmdArgs(t *testing.T) { } } +func TestJavaDriver_ExecTaskStreaming(t *testing.T) { + javaCompatible(t) + if !testutil.IsCI() { + t.Parallel() + } + + require := require.New(t) + d := NewDriver(testlog.HCLogger(t)) + harness := dtestutil.NewDriverHarness(t, d) + defer harness.Kill() + + tc := &TaskConfig{ + Class: "Hello", + Args: []string{"900"}, + } + task := basicTask(t, "demo-app", tc) + + cleanup := harness.MkAllocDir(task, true) + defer cleanup() + + copyFile("./test-resources/Hello.class", filepath.Join(task.TaskDir().Dir, "Hello.class"), t) + + _, _, err := harness.StartTask(task) + require.NoError(err) + defer d.DestroyTask(task.ID, true) + + dtestutil.ExecTaskStreamingConformanceTests(t, harness, task.ID) + +} func basicTask(t *testing.T, name string, taskConfig *TaskConfig) *drivers.TaskConfig { t.Helper() diff --git a/drivers/rawexec/driver.go b/drivers/rawexec/driver.go index bde086963..a79164c85 100644 --- a/drivers/rawexec/driver.go +++ b/drivers/rawexec/driver.go @@ -521,3 +521,22 @@ func (d *Driver) ExecTask(taskID string, cmd []string, timeout time.Duration) (* }, }, nil } + +var _ drivers.ExecTaskStreamingRawDriver = (*Driver)(nil) + +func (d *Driver) ExecTaskStreamingRaw(ctx context.Context, + taskID string, + command []string, + tty bool, + stream drivers.ExecTaskStream) error { + + if len(command) == 0 { + return fmt.Errorf("error cmd must have at least one value") + } + handle, ok := d.tasks.Get(taskID) + if !ok { + return drivers.ErrTaskNotFound + } + + return handle.exec.ExecStreaming(ctx, command, tty, stream) +} diff --git a/drivers/rawexec/driver_unix_test.go b/drivers/rawexec/driver_unix_test.go index 07491574f..d921e53b3 100644 --- a/drivers/rawexec/driver_unix_test.go +++ b/drivers/rawexec/driver_unix_test.go @@ -196,3 +196,37 @@ func TestRawExecDriver_StartWaitStop(t *testing.T) { require.NoError(harness.DestroyTask(task.ID, true)) } + +func TestRawExec_ExecTaskStreaming(t *testing.T) { + t.Parallel() + if runtime.GOOS == "darwin" { + t.Skip("skip running exec tasks on darwin as darwin has restrictions on starting tty shells") + } + require := require.New(t) + + d := NewRawExecDriver(testlog.HCLogger(t)) + harness := dtestutil.NewDriverHarness(t, d) + defer harness.Kill() + + task := &drivers.TaskConfig{ + ID: uuid.Generate(), + Name: "sleep", + } + + cleanup := harness.MkAllocDir(task, false) + defer cleanup() + + tc := &TaskConfig{ + Command: testtask.Path(), + Args: []string{"sleep", "9000s"}, + } + require.NoError(task.EncodeConcreteDriverConfig(&tc)) + testtask.SetTaskConfigEnv(task) + + _, _, err := harness.StartTask(task) + require.NoError(err) + defer d.DestroyTask(task.ID, true) + + dtestutil.ExecTaskStreamingConformanceTests(t, harness, task.ID) + +}