From 089bce5ab4ac482ba5f032d8d44ef0333e4fbc33 Mon Sep 17 00:00:00 2001 From: Michael Schurter Date: Fri, 12 Oct 2018 16:56:13 -0700 Subject: [PATCH] drivers/mock: complete plugin impl --- drivers/mock/driver.go | 99 ++++++++++++++++++++++++++++++++++++------ drivers/mock/handle.go | 61 ++++++++++++++------------ 2 files changed, 119 insertions(+), 41 deletions(-) diff --git a/drivers/mock/driver.go b/drivers/mock/driver.go index 141049865..83baef225 100644 --- a/drivers/mock/driver.go +++ b/drivers/mock/driver.go @@ -15,6 +15,8 @@ import ( "github.com/hashicorp/nomad/plugins/base" "github.com/hashicorp/nomad/plugins/drivers" "github.com/hashicorp/nomad/plugins/shared/hclspec" + "github.com/hashicorp/nomad/plugins/shared/loader" + netctx "golang.org/x/net/context" ) const ( @@ -22,10 +24,22 @@ const ( pluginName = "mock" // fingerprintPeriod is the interval at which the driver will send fingerprint responses - fingerprintPeriod = 30 * time.Second + fingerprintPeriod = 500 * time.Millisecond ) var ( + // When the package is loaded the driver is registered as an internal plugin + // with the plugin catalog + PluginID = loader.PluginID{ + Name: pluginName, + PluginType: base.PluginTypeDriver, + } + + PluginConfig = &loader.InternalPluginConfig{ + Config: map[string]interface{}{}, + Factory: func(l hclog.Logger) interface{} { return NewMockDriver(l) }, + } + // pluginInfo is the response returned for the PluginInfo RPC pluginInfo = &base.PluginInfoResponse{ Type: base.PluginTypeDriver, @@ -66,7 +80,7 @@ var ( // capabilities is returned by the Capabilities RPC and indicates what // optional features this driver supports capabilities = &drivers.Capabilities{ - SendSignals: true, + SendSignals: false, Exec: true, FSIsolation: cstructs.FSIsolationNone, } @@ -99,6 +113,20 @@ type Driver struct { logger hclog.Logger } +// NewMockDriver returns a new DriverPlugin implementation +func NewMockDriver(logger hclog.Logger) drivers.DriverPlugin { + ctx, cancel := context.WithCancel(context.Background()) + logger = logger.Named(pluginName) + return &Driver{ + eventer: eventer.NewEventer(ctx, logger), + config: &Config{}, + tasks: newTaskStore(), + ctx: ctx, + signalShutdown: cancel, + logger: logger, + } +} + // Config is the configuration for the driver that applies to all tasks type Config struct { // ShutdownPeriodicAfter is a toggle that can be used during tests to @@ -203,7 +231,7 @@ func (d *Driver) Capabilities() (*drivers.Capabilities, error) { return capabilities, nil } -func (d *Driver) Fingerprint(ctx context.Context) (<-chan *drivers.Fingerprint, error) { +func (d *Driver) Fingerprint(ctx netctx.Context) (<-chan *drivers.Fingerprint, error) { ch := make(chan *drivers.Fingerprint) go d.handleFingerprint(ctx, ch) return ch, nil @@ -245,7 +273,8 @@ func (d *Driver) buildFingerprint() *drivers.Fingerprint { } func (d *Driver) RecoverTask(*drivers.TaskHandle) error { - panic("not implemented") + //TODO is there anything to do here? + return nil } func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstructs.DriverNetwork, error) { @@ -279,6 +308,8 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru net.PortMap = map[string]int{parts[0]: port} } + killCtx, killCancel := context.WithCancel(context.Background()) + h := &mockTaskHandle{ task: cfg, runFor: driverConfig.RunFor, @@ -288,8 +319,10 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru stdoutString: driverConfig.StdoutString, stdoutRepeat: driverConfig.StdoutRepeat, stdoutRepeatDur: driverConfig.StdoutRepeatDur, - logger: d.logger, - doneCh: make(chan struct{}), + logger: d.logger.With("task_name", cfg.Name), + waitCh: make(chan struct{}), + killCh: killCtx.Done(), + kill: killCancel, } if driverConfig.ExitErrMsg != "" { h.exitErr = errors.New(driverConfig.ExitErrMsg) @@ -317,7 +350,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru } -func (d *Driver) WaitTask(ctx context.Context, taskID string) (<-chan *drivers.ExitResult, error) { +func (d *Driver) WaitTask(ctx netctx.Context, taskID string) (<-chan *drivers.ExitResult, error) { handle, ok := d.tasks.Get(taskID) if !ok { return nil, drivers.ErrTaskNotFound @@ -337,16 +370,39 @@ func (d *Driver) handleWait(ctx context.Context, handle *mockTaskHandle, ch chan return case <-d.ctx.Done(): return - case <-handle.doneCh: + case <-handle.waitCh: ch <- handle.exitResult } } func (d *Driver) StopTask(taskID string, timeout time.Duration, signal string) error { - panic("not implemented") + h, ok := d.tasks.Get(taskID) + if !ok { + return drivers.ErrTaskNotFound + } + + d.logger.Debug("killing task", + "task_name", h.task.Name, + "kill_after", h.killAfter, + "kill_timeout", h.killTimeout, + ) + + select { + case <-h.waitCh: + d.logger.Debug("not killing task: already exited", "task_name", h.task.Name) + case <-time.After(h.killAfter): + d.logger.Debug("killing task due to kill_after", "task_name", h.task.Name) + h.kill() + case <-time.After(h.killTimeout): + d.logger.Debug("killing task after kill_timeout", "task_name", h.task.Name) + h.kill() + } + return nil } func (d *Driver) DestroyTask(taskID string, force bool) error { - panic("not implemented") + //TODO is there anything else to do here? + d.tasks.Delete(taskID) + return nil } func (d *Driver) InspectTask(taskID string) (*drivers.TaskStatus, error) { @@ -354,17 +410,32 @@ func (d *Driver) InspectTask(taskID string) (*drivers.TaskStatus, error) { } func (d *Driver) TaskStats(taskID string) (*cstructs.TaskResourceUsage, error) { - panic("not implemented") + //TODO return an error? + return nil, nil } -func (d *Driver) TaskEvents(context.Context) (<-chan *drivers.TaskEvent, error) { +func (d *Driver) TaskEvents(netctx.Context) (<-chan *drivers.TaskEvent, error) { panic("not implemented") } func (d *Driver) SignalTask(taskID string, signal string) error { - panic("not implemented") + h, ok := d.tasks.Get(taskID) + if !ok { + return drivers.ErrTaskNotFound + } + + return h.signalErr } func (d *Driver) ExecTask(taskID string, cmd []string, timeout time.Duration) (*drivers.ExecTaskResult, error) { - panic("not implemented") + h, ok := d.tasks.Get(taskID) + if !ok { + return nil, drivers.ErrTaskNotFound + } + + res := drivers.ExecTaskResult{ + Stdout: []byte(fmt.Sprintf("Exec(%q, %q)", h.task.Name, cmd)), + ExitResult: &drivers.ExitResult{}, + } + return &res, nil } diff --git a/drivers/mock/handle.go b/drivers/mock/handle.go index 2296610a1..214ecba1d 100644 --- a/drivers/mock/handle.go +++ b/drivers/mock/handle.go @@ -1,6 +1,7 @@ package mock import ( + "context" "io" "time" @@ -16,6 +17,7 @@ type mockTaskHandle struct { runFor time.Duration killAfter time.Duration killTimeout time.Duration + waitCh chan struct{} exitCode int exitSignal int exitErr error @@ -24,63 +26,68 @@ type mockTaskHandle struct { stdoutRepeat int stdoutRepeatDur time.Duration - doneCh chan struct{} - task *drivers.TaskConfig procState drivers.TaskState startedAt time.Time completedAt time.Time exitResult *drivers.ExitResult + + // Calling kill closes killCh if it is not already closed + kill context.CancelFunc + killCh <-chan struct{} } func (h *mockTaskHandle) run() { + defer close(h.waitCh) + + errCh := make(chan error, 1) // Setup logging output if h.stdoutString != "" { - go h.handleLogging() + go h.handleLogging(errCh) } timer := time.NewTimer(h.runFor) defer timer.Stop() - for { - select { - case <-timer.C: - select { - case <-h.doneCh: - // already closed - default: - close(h.doneCh) - } - case <-h.doneCh: - h.logger.Debug("finished running task", "name", h.task.Name) - h.exitResult = &drivers.ExitResult{ - ExitCode: h.exitCode, - Signal: h.exitSignal, - Err: h.exitErr, - } - return + + select { + case <-timer.C: + h.logger.Debug("run_for time elapsed; exiting", "run_for", h.runFor) + case <-h.killCh: + h.logger.Debug("killed; exiting") + case err := <-errCh: + h.logger.Error("error running mock task; exiting", "error", err) + h.exitResult = &drivers.ExitResult{ + Err: err, } + return } + + h.exitResult = &drivers.ExitResult{ + ExitCode: h.exitCode, + Signal: h.exitSignal, + Err: h.exitErr, + } + return } -func (h *mockTaskHandle) handleLogging() { +func (h *mockTaskHandle) handleLogging(errCh chan<- error) { stdout, err := fifo.Open(h.task.StdoutPath) if err != nil { - h.exitErr = err - close(h.doneCh) - h.logger.Error("failed to write to stdout: %v", err) + h.logger.Error("failed to write to stdout", "error", err) + errCh <- err return } for i := 0; i < h.stdoutRepeat; i++ { select { - case <-h.doneCh: + case <-h.waitCh: + h.logger.Warn("exiting before done writing output", "i", i, "total", h.stdoutRepeat) return case <-time.After(h.stdoutRepeatDur): if _, err := io.WriteString(stdout, h.stdoutString); err != nil { - h.exitErr = err - close(h.doneCh) h.logger.Error("failed to write to stdout", "error", err) + errCh <- err return } }