Handle time.Duration in mock

Mock driver config uses `time.Duration` fields but we initialize them
inconsistently, as time.Duration sometimes and as duration strings other
times.  Previously, `mapstructure` handles it and does the right thing.

This is no longer the case with MsgPack.  I could not find a good way to
bring back old behavior without too much complexity.  `MsgPack` extended
types weren't ideal here as we lose type information (e.g. int64 vs
string), and the input is a generic map and not a MsgPack serialization
of duration.

As such, I went with the simple solution of declaring the config field
as duration string, and panicing if the test doesn't pass a valid
string.

I found this to cause the smallest change in tests, but we can
alternatively force all to be int64 instead.
This commit is contained in:
Mahmood Ali
2018-11-04 21:22:29 -08:00
parent cc52606a07
commit 416b5240f4
2 changed files with 36 additions and 13 deletions

View File

@@ -64,9 +64,9 @@ var (
taskConfigSpec = hclspec.NewObject(map[string]*hclspec.Spec{
"start_error": hclspec.NewAttr("start_error", "string", false),
"start_error_recoverable": hclspec.NewAttr("start_error_recoverable", "bool", false),
"start_block_for": hclspec.NewAttr("start_block_for", "number", false),
"kill_after": hclspec.NewAttr("kill_after", "number", false),
"run_for": hclspec.NewAttr("run_for", "number", false),
"start_block_for": hclspec.NewAttr("start_block_for", "string", false),
"kill_after": hclspec.NewAttr("kill_after", "string", false),
"run_for": hclspec.NewAttr("run_for", "string", false),
"exit_code": hclspec.NewAttr("exit_code", "number", false),
"exit_signal": hclspec.NewAttr("exit_signal", "number", false),
"exit_err_msg": hclspec.NewAttr("exit_err_msg", "string", false),
@@ -76,7 +76,7 @@ var (
"driver_port_map": hclspec.NewAttr("driver_port_map", "string", false),
"stdout_string": hclspec.NewAttr("stdout_string", "string", false),
"stdout_repeat": hclspec.NewAttr("stdout_repeat", "number", false),
"stdout_repeat_duration": hclspec.NewAttr("stdout_repeat_duration", "number", false),
"stdout_repeat_duration": hclspec.NewAttr("stdout_repeat_duration", "string", false),
})
// capabilities is returned by the Capabilities RPC and indicates what
@@ -152,16 +152,16 @@ type TaskConfig struct {
StartErrRecoverable bool `codec:"start_error_recoverable"`
// StartBlockFor specifies a duration in which to block before returning
StartBlockFor time.Duration `codec:"start_block_for"`
StartBlockFor string `codec:"start_block_for"`
// KillAfter is the duration after which the mock driver indicates the task
// has exited after getting the initial SIGINT signal
KillAfter time.Duration `codec:"kill_after"`
KillAfter string `codec:"kill_after"`
// RunFor is the duration for which the fake task runs for. After this
// period the MockDriver responds to the task running indicating that the
// task has terminated
RunFor time.Duration `codec:"run_for"`
RunFor string `codec:"run_for"`
// ExitCode is the exit code with which the MockDriver indicates the task
// has exited
@@ -195,7 +195,7 @@ type TaskConfig struct {
StdoutRepeat int `codec:"stdout_repeat"`
// StdoutRepeatDur is the duration between repeated outputs.
StdoutRepeatDur time.Duration `codec:"stdout_repeat_duration"`
StdoutRepeatDur string `codec:"stdout_repeat_duration"`
}
type MockTaskState struct {
@@ -298,8 +298,8 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru
return nil, nil, err
}
if driverConfig.StartBlockFor != 0 {
time.Sleep(driverConfig.StartBlockFor)
if driverConfig.StartBlockFor != "" {
time.Sleep(parseDuration(driverConfig.StartBlockFor))
}
if driverConfig.StartErr != "" {
@@ -327,13 +327,13 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru
h := &taskHandle{
taskConfig: cfg,
runFor: driverConfig.RunFor,
killAfter: driverConfig.KillAfter,
runFor: parseDuration(driverConfig.RunFor),
killAfter: parseDuration(driverConfig.KillAfter),
exitCode: driverConfig.ExitCode,
exitSignal: driverConfig.ExitSignal,
stdoutString: driverConfig.StdoutString,
stdoutRepeat: driverConfig.StdoutRepeat,
stdoutRepeatDur: driverConfig.StdoutRepeatDur,
stdoutRepeatDur: parseDuration(driverConfig.StdoutRepeatDur),
logger: d.logger.With("task_name", cfg.Name),
waitCh: make(chan struct{}),
killCh: killCtx.Done(),
@@ -455,3 +455,21 @@ func (d *Driver) ExecTask(taskID string, cmd []string, timeout time.Duration) (*
}
return &res, nil
}
func parseDuration(s string) time.Duration {
if s == "" {
return time.Duration(0)
}
// check if it's an int64
if v, err := strconv.ParseInt(s, 10, 64); err == nil {
return time.Duration(v)
}
// try to parse it as duration
if v, err := time.ParseDuration(s); err == nil {
return v
}
panic(fmt.Errorf("value is not a duration: %v", s))
}

View File

@@ -114,6 +114,11 @@ func (h *taskHandle) handleLogging(errCh chan<- error) {
errCh <- err
return
}
if _, err := io.WriteString(stdout, h.stdoutString); err != nil {
h.logger.Error("failed to write to stdout", "error", err)
errCh <- err
return
}
for i := 0; i < h.stdoutRepeat; i++ {
select {