diff --git a/drivers/mock/driver.go b/drivers/mock/driver.go index 734ae61f8..68bfb0335 100644 --- a/drivers/mock/driver.go +++ b/drivers/mock/driver.go @@ -64,9 +64,9 @@ var ( taskConfigSpec = hclspec.NewObject(map[string]*hclspec.Spec{ "start_error": hclspec.NewAttr("start_error", "string", false), "start_error_recoverable": hclspec.NewAttr("start_error_recoverable", "bool", false), - "start_block_for": hclspec.NewAttr("start_block_for", "number", false), - "kill_after": hclspec.NewAttr("kill_after", "number", false), - "run_for": hclspec.NewAttr("run_for", "number", false), + "start_block_for": hclspec.NewAttr("start_block_for", "string", false), + "kill_after": hclspec.NewAttr("kill_after", "string", false), + "run_for": hclspec.NewAttr("run_for", "string", false), "exit_code": hclspec.NewAttr("exit_code", "number", false), "exit_signal": hclspec.NewAttr("exit_signal", "number", false), "exit_err_msg": hclspec.NewAttr("exit_err_msg", "string", false), @@ -76,7 +76,7 @@ var ( "driver_port_map": hclspec.NewAttr("driver_port_map", "string", false), "stdout_string": hclspec.NewAttr("stdout_string", "string", false), "stdout_repeat": hclspec.NewAttr("stdout_repeat", "number", false), - "stdout_repeat_duration": hclspec.NewAttr("stdout_repeat_duration", "number", false), + "stdout_repeat_duration": hclspec.NewAttr("stdout_repeat_duration", "string", false), }) // capabilities is returned by the Capabilities RPC and indicates what @@ -152,16 +152,16 @@ type TaskConfig struct { StartErrRecoverable bool `codec:"start_error_recoverable"` // StartBlockFor specifies a duration in which to block before returning - StartBlockFor time.Duration `codec:"start_block_for"` + StartBlockFor string `codec:"start_block_for"` // KillAfter is the duration after which the mock driver indicates the task // has exited after getting the initial SIGINT signal - KillAfter time.Duration `codec:"kill_after"` + KillAfter string `codec:"kill_after"` // RunFor is the duration for which the fake task runs for. After this // period the MockDriver responds to the task running indicating that the // task has terminated - RunFor time.Duration `codec:"run_for"` + RunFor string `codec:"run_for"` // ExitCode is the exit code with which the MockDriver indicates the task // has exited @@ -195,7 +195,7 @@ type TaskConfig struct { StdoutRepeat int `codec:"stdout_repeat"` // StdoutRepeatDur is the duration between repeated outputs. - StdoutRepeatDur time.Duration `codec:"stdout_repeat_duration"` + StdoutRepeatDur string `codec:"stdout_repeat_duration"` } type MockTaskState struct { @@ -298,8 +298,8 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru return nil, nil, err } - if driverConfig.StartBlockFor != 0 { - time.Sleep(driverConfig.StartBlockFor) + if driverConfig.StartBlockFor != "" { + time.Sleep(parseDuration(driverConfig.StartBlockFor)) } if driverConfig.StartErr != "" { @@ -327,13 +327,13 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru h := &taskHandle{ taskConfig: cfg, - runFor: driverConfig.RunFor, - killAfter: driverConfig.KillAfter, + runFor: parseDuration(driverConfig.RunFor), + killAfter: parseDuration(driverConfig.KillAfter), exitCode: driverConfig.ExitCode, exitSignal: driverConfig.ExitSignal, stdoutString: driverConfig.StdoutString, stdoutRepeat: driverConfig.StdoutRepeat, - stdoutRepeatDur: driverConfig.StdoutRepeatDur, + stdoutRepeatDur: parseDuration(driverConfig.StdoutRepeatDur), logger: d.logger.With("task_name", cfg.Name), waitCh: make(chan struct{}), killCh: killCtx.Done(), @@ -455,3 +455,21 @@ func (d *Driver) ExecTask(taskID string, cmd []string, timeout time.Duration) (* } return &res, nil } + +func parseDuration(s string) time.Duration { + if s == "" { + return time.Duration(0) + } + + // check if it's an int64 + if v, err := strconv.ParseInt(s, 10, 64); err == nil { + return time.Duration(v) + } + + // try to parse it as duration + if v, err := time.ParseDuration(s); err == nil { + return v + } + + panic(fmt.Errorf("value is not a duration: %v", s)) +} diff --git a/drivers/mock/handle.go b/drivers/mock/handle.go index 67297c8f0..44ad4932e 100644 --- a/drivers/mock/handle.go +++ b/drivers/mock/handle.go @@ -114,6 +114,11 @@ func (h *taskHandle) handleLogging(errCh chan<- error) { errCh <- err return } + if _, err := io.WriteString(stdout, h.stdoutString); err != nil { + h.logger.Error("failed to write to stdout", "error", err) + errCh <- err + return + } for i := 0; i < h.stdoutRepeat; i++ { select {