From e79ce1f9d00d806045ab0338d70cd9042a4e3ee5 Mon Sep 17 00:00:00 2001 From: Mahmood Ali Date: Sun, 28 Apr 2019 17:18:52 -0400 Subject: [PATCH] drivers/mock: extract command related operations Extract command parsing and execution mocking into a separate struct. Also, allow mocking of different fs_isolation for testing. --- drivers/mock/command.go | 93 +++++++++++++++++++++ drivers/mock/driver.go | 180 +++++++++++++++++++++++++--------------- drivers/mock/handle.go | 83 ++---------------- 3 files changed, 214 insertions(+), 142 deletions(-) create mode 100644 drivers/mock/command.go diff --git a/drivers/mock/command.go b/drivers/mock/command.go new file mode 100644 index 000000000..02abfc519 --- /dev/null +++ b/drivers/mock/command.go @@ -0,0 +1,93 @@ +package mock + +import ( + "errors" + "io" + "sync" + "time" + + hclog "github.com/hashicorp/go-hclog" + bstructs "github.com/hashicorp/nomad/plugins/base/structs" + "github.com/hashicorp/nomad/plugins/drivers" +) + +func runCommand(c Command, stdout, stderr io.WriteCloser, cancelCh <-chan struct{}, pluginExitTimer <-chan time.Time, logger hclog.Logger) *drivers.ExitResult { + errCh := make(chan error, 1) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + runCommandOutput(stdout, c.StdoutString, c.StdoutRepeat, c.stdoutRepeatDuration, cancelCh, logger, errCh) + }() + + wg.Add(1) + go func() { + defer wg.Done() + runCommandOutput(stderr, c.StderrString, c.StderrRepeat, c.stderrRepeatDuration, cancelCh, logger, errCh) + }() + + timer := time.NewTimer(c.runForDuration) + defer timer.Stop() + + select { + case <-timer.C: + logger.Debug("run_for time elapsed; exiting", "run_for", c.RunFor) + case <-cancelCh: + logger.Debug("killed; exiting") + case <-pluginExitTimer: + logger.Debug("exiting plugin") + return &drivers.ExitResult{ + Err: bstructs.ErrPluginShutdown, + } + case err := <-errCh: + logger.Error("error running mock task; exiting", "error", err) + return &drivers.ExitResult{ + Err: err, + } + } + + wg.Wait() + + var exitErr error + if c.ExitErrMsg != "" { + exitErr = errors.New(c.ExitErrMsg) + } + + return &drivers.ExitResult{ + ExitCode: c.ExitCode, + Signal: c.ExitSignal, + Err: exitErr, + } +} + +func runCommandOutput(writer io.WriteCloser, + output string, outputRepeat int, repeatDuration time.Duration, + cancelCh <-chan struct{}, logger hclog.Logger, errCh chan error) { + + defer writer.Close() + + if output == "" { + return + } + + if _, err := io.WriteString(writer, output); err != nil { + logger.Error("failed to write to stdout", "error", err) + errCh <- err + return + } + + for i := 0; i < outputRepeat; i++ { + select { + case <-cancelCh: + logger.Warn("exiting before done writing output", "i", i, "total", outputRepeat) + return + case <-time.After(repeatDuration): + if _, err := io.WriteString(writer, output); err != nil { + logger.Error("failed to write to stdout", "error", err) + errCh <- err + return + } + } + } +} diff --git a/drivers/mock/driver.go b/drivers/mock/driver.go index 0f63d7302..5dced6c85 100644 --- a/drivers/mock/driver.go +++ b/drivers/mock/driver.go @@ -57,6 +57,10 @@ var ( // configSpec is the hcl specification returned by the ConfigSchema RPC configSpec = hclspec.NewObject(map[string]*hclspec.Spec{ + "fs_isolation": hclspec.NewDefault( + hclspec.NewAttr("fs_isolation", "string", false), + hclspec.NewLiteral(fmt.Sprintf("%q", drivers.FSIsolationNone)), + ), "shutdown_periodic_after": hclspec.NewDefault( hclspec.NewAttr("shutdown_periodic_after", "bool", false), hclspec.NewLiteral("false"), @@ -72,26 +76,22 @@ var ( "start_block_for": hclspec.NewAttr("start_block_for", "string", false), "kill_after": hclspec.NewAttr("kill_after", "string", false), "plugin_exit_after": hclspec.NewAttr("plugin_exit_after", "string", false), - "run_for": hclspec.NewAttr("run_for", "string", false), - "exit_code": hclspec.NewAttr("exit_code", "number", false), - "exit_signal": hclspec.NewAttr("exit_signal", "number", false), - "exit_err_msg": hclspec.NewAttr("exit_err_msg", "string", false), - "signal_error": hclspec.NewAttr("signal_error", "string", false), "driver_ip": hclspec.NewAttr("driver_ip", "string", false), "driver_advertise": hclspec.NewAttr("driver_advertise", "bool", false), "driver_port_map": hclspec.NewAttr("driver_port_map", "string", false), - "stdout_string": hclspec.NewAttr("stdout_string", "string", false), - "stdout_repeat": hclspec.NewAttr("stdout_repeat", "number", false), - "stdout_repeat_duration": hclspec.NewAttr("stdout_repeat_duration", "string", false), - }) - // capabilities is returned by the Capabilities RPC and indicates what - // optional features this driver supports - capabilities = &drivers.Capabilities{ - SendSignals: true, - Exec: true, - FSIsolation: drivers.FSIsolationNone, - } + "run_for": hclspec.NewAttr("run_for", "string", false), + "exit_code": hclspec.NewAttr("exit_code", "number", false), + "exit_signal": hclspec.NewAttr("exit_signal", "number", false), + "exit_err_msg": hclspec.NewAttr("exit_err_msg", "string", false), + "signal_error": hclspec.NewAttr("signal_error", "string", false), + "stdout_string": hclspec.NewAttr("stdout_string", "string", false), + "stdout_repeat": hclspec.NewAttr("stdout_repeat", "number", false), + "stdout_repeat_duration": hclspec.NewAttr("stdout_repeat_duration", "string", false), + "stderr_string": hclspec.NewAttr("stderr_string", "string", false), + "stderr_repeat": hclspec.NewAttr("stderr_repeat", "number", false), + "stderr_repeat_duration": hclspec.NewAttr("stderr_repeat_duration", "string", false), + }) ) // Driver is a mock DriverPlugin implementation @@ -100,6 +100,10 @@ type Driver struct { // event can be broadcast to all callers eventer *eventer.Eventer + // capabilities is returned by the Capabilities RPC and indicates what + // optional features this driver supports + capabilities *drivers.Capabilities + // config is the driver configuration set by the SetConfig RPC config *Config @@ -133,8 +137,16 @@ type Driver struct { func NewMockDriver(logger hclog.Logger) drivers.DriverPlugin { ctx, cancel := context.WithCancel(context.Background()) logger = logger.Named(pluginName) + + capabilities := &drivers.Capabilities{ + SendSignals: true, + Exec: true, + FSIsolation: drivers.FSIsolationNone, + } + return &Driver{ eventer: eventer.NewEventer(ctx, logger), + capabilities: capabilities, config: &Config{}, tasks: newTaskStore(), ctx: ctx, @@ -145,6 +157,8 @@ func NewMockDriver(logger hclog.Logger) drivers.DriverPlugin { // Config is the configuration for the driver that applies to all tasks type Config struct { + FSIsolation string `codec:"fs_isolation"` + // ShutdownPeriodicAfter is a toggle that can be used during tests to // "stop" a previously-functioning driver, allowing for testing of periodic // drivers and fingerprinters @@ -156,8 +170,52 @@ type Config struct { ShutdownPeriodicDuration time.Duration `codec:"shutdown_periodic_duration"` } +type Command struct { + // RunFor is the duration for which the fake task runs for. After this + // period the MockDriver responds to the task running indicating that the + // task has terminated + RunFor string `codec:"run_for"` + runForDuration time.Duration + + // ExitCode is the exit code with which the MockDriver indicates the task + // has exited + ExitCode int `codec:"exit_code"` + + // ExitSignal is the signal with which the MockDriver indicates the task has + // been killed + ExitSignal int `codec:"exit_signal"` + + // ExitErrMsg is the error message that the task returns while exiting + ExitErrMsg string `codec:"exit_err_msg"` + + // SignalErr is the error message that the task returns if signalled + SignalErr string `codec:"signal_error"` + + // StdoutString is the string that should be sent to stdout + StdoutString string `codec:"stdout_string"` + + // StdoutRepeat is the number of times the output should be sent. + StdoutRepeat int `codec:"stdout_repeat"` + + // StdoutRepeatDur is the duration between repeated outputs. + StdoutRepeatDur string `codec:"stdout_repeat_duration"` + stdoutRepeatDuration time.Duration + + // StderrString is the string that should be sent to stderr + StderrString string `codec:"stderr_string"` + + // StderrRepeat is the number of times the errput should be sent. + StderrRepeat int `codec:"stderr_repeat"` + + // StderrRepeatDur is the duration between repeated errputs. + StderrRepeatDur string `codec:"stderr_repeat_duration"` + stderrRepeatDuration time.Duration +} + // TaskConfig is the driver configuration of a task within a job type TaskConfig struct { + Command + // PluginExitAfter is the duration after which the mock driver indicates the // plugin has exited via the WaitTask call. PluginExitAfter string `codec:"plugin_exit_after"` @@ -179,26 +237,6 @@ type TaskConfig struct { KillAfter string `codec:"kill_after"` killAfterDuration time.Duration - // RunFor is the duration for which the fake task runs for. After this - // period the MockDriver responds to the task running indicating that the - // task has terminated - RunFor string `codec:"run_for"` - runForDuration time.Duration - - // ExitCode is the exit code with which the MockDriver indicates the task - // has exited - ExitCode int `codec:"exit_code"` - - // ExitSignal is the signal with which the MockDriver indicates the task has - // been killed - ExitSignal int `codec:"exit_signal"` - - // ExitErrMsg is the error message that the task returns while exiting - ExitErrMsg string `codec:"exit_err_msg"` - - // SignalErr is the error message that the task returns if signalled - SignalErr string `codec:"signal_error"` - // DriverIP will be returned as the DriverNetwork.IP from Start() DriverIP string `codec:"driver_ip"` @@ -209,16 +247,6 @@ type TaskConfig struct { // DriverPortMap will parse a label:number pair and return it in // DriverNetwork.PortMap from Start(). DriverPortMap string `codec:"driver_port_map"` - - // StdoutString is the string that should be sent to stdout - StdoutString string `codec:"stdout_string"` - - // StdoutRepeat is the number of times the output should be sent. - StdoutRepeat int `codec:"stdout_repeat"` - - // StdoutRepeatDur is the duration between repeated outputs. - StdoutRepeatDur string `codec:"stdout_repeat_duration"` - stdoutRepeatDuration time.Duration } type MockTaskState struct { @@ -245,6 +273,12 @@ func (d *Driver) SetConfig(cfg *base.Config) error { if d.config.ShutdownPeriodicAfter { d.shutdownFingerprintTime = time.Now().Add(d.config.ShutdownPeriodicDuration) } + + isolation := config.FSIsolation + if isolation != "" { + d.capabilities.FSIsolation = drivers.FSIsolation(isolation) + } + return nil } @@ -253,7 +287,7 @@ func (d *Driver) TaskConfigSchema() (*hclspec.Spec, error) { } func (d *Driver) Capabilities() (*drivers.Capabilities, error) { - return capabilities, nil + return d.capabilities, nil } func (d *Driver) Fingerprint(ctx context.Context) (<-chan *drivers.Fingerprint, error) { @@ -329,6 +363,23 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error { return nil } +func (c *Command) parseDurations() error { + var err error + if c.runForDuration, err = parseDuration(c.RunFor); err != nil { + return fmt.Errorf("run_for %v not a valid duration: %v", c.RunFor, err) + } + + if c.stdoutRepeatDuration, err = parseDuration(c.StdoutRepeatDur); err != nil { + return fmt.Errorf("stdout_repeat_duration %v not a valid duration: %v", c.stdoutRepeatDuration, err) + } + + if c.stderrRepeatDuration, err = parseDuration(c.StderrRepeatDur); err != nil { + return fmt.Errorf("stderr_repeat_duration %v not a valid duration: %v", c.stderrRepeatDuration, err) + } + + return nil +} + func parseDriverConfig(cfg *drivers.TaskConfig) (*TaskConfig, error) { var driverConfig TaskConfig if err := cfg.DecodeDriverConfig(&driverConfig); err != nil { @@ -340,16 +391,18 @@ func parseDriverConfig(cfg *drivers.TaskConfig) (*TaskConfig, error) { return nil, fmt.Errorf("start_block_for %v not a valid duration: %v", driverConfig.StartBlockFor, err) } - if driverConfig.runForDuration, err = parseDuration(driverConfig.RunFor); err != nil { - return nil, fmt.Errorf("run_for %v not a valid duration: %v", driverConfig.RunFor, err) - } - if driverConfig.pluginExitAfterDuration, err = parseDuration(driverConfig.PluginExitAfter); err != nil { return nil, fmt.Errorf("plugin_exit_after %v not a valid duration: %v", driverConfig.PluginExitAfter, err) } - if driverConfig.stdoutRepeatDuration, err = parseDuration(driverConfig.StdoutRepeatDur); err != nil { - return nil, fmt.Errorf("stdout_repeat_duration %v not a valid duration: %v", driverConfig.stdoutRepeatDuration, err) + if err = driverConfig.parseDurations(); err != nil { + return nil, err + } + + if driverConfig.ExecCommand != nil { + if err = driverConfig.ExecCommand.parseDurations(); err != nil { + return nil, err + } } return &driverConfig, nil @@ -359,26 +412,15 @@ func newTaskHandle(cfg *drivers.TaskConfig, driverConfig *TaskConfig, logger hcl killCtx, killCancel := context.WithCancel(context.Background()) h := &taskHandle{ taskConfig: cfg, - runFor: driverConfig.runForDuration, + command: driverConfig.Command, pluginExitAfter: driverConfig.pluginExitAfterDuration, killAfter: driverConfig.killAfterDuration, - exitCode: driverConfig.ExitCode, - exitSignal: driverConfig.ExitSignal, - stdoutString: driverConfig.StdoutString, - stdoutRepeat: driverConfig.StdoutRepeat, - stdoutRepeatDur: driverConfig.stdoutRepeatDuration, logger: logger.With("task_name", cfg.Name), - waitCh: make(chan struct{}), + waitCh: make(chan interface{}), killCh: killCtx.Done(), kill: killCancel, startedAt: time.Now(), } - if driverConfig.ExitErrMsg != "" { - h.exitErr = errors.New(driverConfig.ExitErrMsg) - } - if driverConfig.SignalErr != "" { - h.signalErr = fmt.Errorf(driverConfig.SignalErr) - } return h } @@ -541,7 +583,11 @@ func (d *Driver) SignalTask(taskID string, signal string) error { return drivers.ErrTaskNotFound } - return h.signalErr + if h.command.SignalErr == "" { + return nil + } + + return errors.New(h.command.SignalErr) } func (d *Driver) ExecTask(taskID string, cmd []string, timeout time.Duration) (*drivers.ExecTaskResult, error) { diff --git a/drivers/mock/handle.go b/drivers/mock/handle.go index 7940a6283..69b2f56dc 100644 --- a/drivers/mock/handle.go +++ b/drivers/mock/handle.go @@ -2,13 +2,11 @@ package mock import ( "context" - "io" "sync" "time" hclog "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/client/lib/fifo" - bstructs "github.com/hashicorp/nomad/plugins/base/structs" "github.com/hashicorp/nomad/plugins/drivers" ) @@ -16,19 +14,13 @@ import ( type taskHandle struct { logger hclog.Logger - runFor time.Duration pluginExitAfter time.Duration killAfter time.Duration - waitCh chan struct{} - exitCode int - exitSignal int - exitErr error - signalErr error - stdoutString string - stdoutRepeat int - stdoutRepeatDur time.Duration + waitCh chan interface{} - taskConfig *drivers.TaskConfig + taskConfig *drivers.TaskConfig + command Command + execCommand *Command // stateLock guards the procState field stateLock sync.RWMutex @@ -81,14 +73,6 @@ func (h *taskHandle) run() { h.procState = drivers.TaskStateRunning h.stateLock.Unlock() - errCh := make(chan error, 1) - - // Setup logging output - go h.handleLogging(errCh) - - timer := time.NewTimer(h.runFor) - defer timer.Stop() - var pluginExitTimer <-chan time.Time if h.pluginExitAfter != 0 { timer := time.NewTimer(h.pluginExitAfter) @@ -96,70 +80,19 @@ func (h *taskHandle) run() { pluginExitTimer = timer.C } - select { - case <-timer.C: - h.logger.Debug("run_for time elapsed; exiting", "run_for", h.runFor) - case <-h.killCh: - h.logger.Debug("killed; exiting") - case <-pluginExitTimer: - h.logger.Debug("exiting plugin") - h.exitResult = &drivers.ExitResult{ - Err: bstructs.ErrPluginShutdown, - } - - return - case err := <-errCh: - h.logger.Error("error running mock task; exiting", "error", err) - h.exitResult = &drivers.ExitResult{ - Err: err, - } - return - } - - h.exitResult = &drivers.ExitResult{ - ExitCode: h.exitCode, - Signal: h.exitSignal, - Err: h.exitErr, - } - return -} - -func (h *taskHandle) handleLogging(errCh chan<- error) { stdout, err := fifo.OpenWriter(h.taskConfig.StdoutPath) if err != nil { h.logger.Error("failed to write to stdout", "error", err) - errCh <- err + h.exitResult = &drivers.ExitResult{Err: err} return } stderr, err := fifo.OpenWriter(h.taskConfig.StderrPath) if err != nil { h.logger.Error("failed to write to stderr", "error", err) - errCh <- err - return - } - defer stderr.Close() - - if h.stdoutString == "" { + h.exitResult = &drivers.ExitResult{Err: err} return } - if _, err := io.WriteString(stdout, h.stdoutString); err != nil { - h.logger.Error("failed to write to stdout", "error", err) - errCh <- err - return - } - - for i := 0; i < h.stdoutRepeat; i++ { - select { - case <-h.waitCh: - h.logger.Warn("exiting before done writing output", "i", i, "total", h.stdoutRepeat) - return - case <-time.After(h.stdoutRepeatDur): - if _, err := io.WriteString(stdout, h.stdoutString); err != nil { - h.logger.Error("failed to write to stdout", "error", err) - errCh <- err - return - } - } - } + h.exitResult = runCommand(h.command, stdout, stderr, h.killCh, pluginExitTimer, h.logger) + return }