diff --git a/drivers/rawexec/driver.go b/drivers/rawexec/driver.go index cf1aa6b94..5d9f09096 100644 --- a/drivers/rawexec/driver.go +++ b/drivers/rawexec/driver.go @@ -448,7 +448,7 @@ func (r *RawExecDriver) SignalTask(taskID string, signal string) error { sig := os.Interrupt if s, ok := signals.SignalLookup[signal]; ok { - r.logger.Warn("signal to send to task unknown, using SIGINT", "signal", signal) + r.logger.Warn("signal to send to task unknown, using SIGINT", "signal", signal, "task_id", handle.task.ID) sig = s } return handle.exec.Signal(sig) diff --git a/drivers/rawexec/handle.go b/drivers/rawexec/handle.go new file mode 100644 index 000000000..09fe04e81 --- /dev/null +++ b/drivers/rawexec/handle.go @@ -0,0 +1,58 @@ +package rawexec + +import ( + "sync" + "time" + + hclog "github.com/hashicorp/go-hclog" + plugin "github.com/hashicorp/go-plugin" + "github.com/hashicorp/nomad/client/driver/executor" + "github.com/hashicorp/nomad/plugins/drivers" +) + +type rawExecTaskHandle struct { + exec executor.Executor + pid int + pluginClient *plugin.Client + logger hclog.Logger + + // stateLock syncs access to all fields below + stateLock sync.RWMutex + + task *drivers.TaskConfig + procState drivers.TaskState + startedAt time.Time + completedAt time.Time + exitResult *drivers.ExitResult +} + +func (h *rawExecTaskHandle) IsRunning() bool { + return h.procState == drivers.TaskStateRunning +} + +func (h *rawExecTaskHandle) run() { + + // since run is called immediatly after the handle is created this + // ensures the exitResult is initialized so we avoid a nil pointer + // thus it does not need to be included in the lock + if h.exitResult == nil { + h.exitResult = &drivers.ExitResult{} + } + + ps, err := h.exec.Wait() + h.stateLock.Lock() + defer h.stateLock.Unlock() + + if err != nil { + h.exitResult.Err = err + h.procState = drivers.TaskStateUnknown + h.completedAt = time.Now() + return + } + h.procState = drivers.TaskStateExited + h.exitResult.ExitCode = ps.ExitCode + h.exitResult.Signal = ps.Signal + h.completedAt = ps.Time + + // TODO: detect if the task OOMed +} diff --git a/drivers/rawexec/state.go b/drivers/rawexec/state.go index ecaefb87b..3a8e28e79 100644 --- a/drivers/rawexec/state.go +++ b/drivers/rawexec/state.go @@ -2,12 +2,6 @@ package rawexec import ( "sync" - "time" - - hclog "github.com/hashicorp/go-hclog" - plugin "github.com/hashicorp/go-plugin" - "github.com/hashicorp/nomad/client/driver/executor" - "github.com/hashicorp/nomad/plugins/drivers" ) type taskStore struct { @@ -37,50 +31,3 @@ func (ts *taskStore) Delete(id string) { defer ts.lock.Unlock() delete(ts.store, id) } - -type rawExecTaskHandle struct { - exec executor.Executor - pid int - pluginClient *plugin.Client - logger hclog.Logger - - // stateLock syncs access to all fields below - stateLock sync.RWMutex - - task *drivers.TaskConfig - procState drivers.TaskState - startedAt time.Time - completedAt time.Time - exitResult *drivers.ExitResult -} - -func (h *rawExecTaskHandle) IsRunning() bool { - return h.procState == drivers.TaskStateRunning -} - -func (h *rawExecTaskHandle) run() { - - // since run is called immediatly after the handle is created this - // ensures the exitResult is initialized so we avoid a nil pointer - // thus it does not need to be included in the lock - if h.exitResult == nil { - h.exitResult = &drivers.ExitResult{} - } - - ps, err := h.exec.Wait() - h.stateLock.Lock() - defer h.stateLock.Unlock() - - if err != nil { - h.exitResult.Err = err - h.procState = drivers.TaskStateUnknown - h.completedAt = time.Now() - return - } - h.procState = drivers.TaskStateExited - h.exitResult.ExitCode = ps.ExitCode - h.exitResult.Signal = ps.Signal - h.completedAt = ps.Time - - // TODO: detect if the task OOMed -} diff --git a/plugins/drivers/utils/eventer.go b/plugins/drivers/utils/eventer.go index 45a6aa5fe..737b7f608 100644 --- a/plugins/drivers/utils/eventer.go +++ b/plugins/drivers/utils/eventer.go @@ -1,7 +1,6 @@ package utils import ( - "fmt" "sync" "time" @@ -13,20 +12,27 @@ import ( var ( // DefaultSendEventTimeout is the timeout used when publishing events to consumers DefaultSendEventTimeout = 2 * time.Second + + // ConsumerGCInterval is the interval at which garbage collection of consumers + // occures + ConsumerGCInterval = time.Minute ) // Eventer is a utility to control broadcast of TaskEvents to multiple consumers. // It also implements the TaskEvents func in the DriverPlugin interface so that // it can be embedded in a implementing driver struct. type Eventer struct { - consumersLock sync.RWMutex // events is a channel were events to be broadcasted are sent + // This channel is never closed, because it's lifetime is tied to the + // life of the driver and closing creates some subtile race conditions + // between closing it and emitting events. events chan *drivers.TaskEvent // consumers is a slice of eventConsumers to broadcast events to. // access is gaurded by consumersLock RWMutex - consumers []*eventConsumer + consumers []*eventConsumer + consumersLock sync.RWMutex // ctx to allow control of event loop shutdown ctx context.Context @@ -34,6 +40,13 @@ type Eventer struct { logger hclog.Logger } +type eventConsumer struct { + timeout time.Duration + ctx context.Context + ch chan *drivers.TaskEvent + logger hclog.Logger +} + // NewEventer returns an Eventer with a running event loop that can be stopped // by closing the given stop channel func NewEventer(ctx context.Context, logger hclog.Logger) *Eventer { @@ -52,32 +65,48 @@ func (e *Eventer) eventLoop() { for { select { case <-e.ctx.Done(): - close(e.events) + e.logger.Debug("task event loop shutdown") return case event := <-e.events: - e.consumersLock.RLock() - for _, consumer := range e.consumers { - consumer.send(event) - } - e.consumersLock.RUnlock() + e.iterateConsumers(event) + case <-time.After(ConsumerGCInterval): + e.gcConsumers() } } } -type eventConsumer struct { - timeout time.Duration - ctx context.Context - ch chan *drivers.TaskEvent - logger hclog.Logger +func (e *Eventer) iterateConsumers(event *drivers.TaskEvent) { + e.consumersLock.Lock() + filtered := e.consumers[:0] + for _, consumer := range e.consumers { + select { + case <-time.After(consumer.timeout): + filtered = append(filtered, consumer) + e.logger.Warn("timeout sending event", "task_id", event.TaskID, "message", event.Message) + case <-consumer.ctx.Done(): + // consumer context finished, filtering it out of loop + close(consumer.ch) + case consumer.ch <- event: + filtered = append(filtered, consumer) + } + } + e.consumers = filtered + e.consumersLock.Unlock() } -func (c *eventConsumer) send(event *drivers.TaskEvent) { - select { - case <-time.After(c.timeout): - c.logger.Warn("timeout sending event", "task_id", event.TaskID, "message", event.Message) - case <-c.ctx.Done(): - case c.ch <- event: +func (e *Eventer) gcConsumers() { + e.consumersLock.Lock() + filtered := e.consumers[:0] + for _, consumer := range e.consumers { + select { + case <-consumer.ctx.Done(): + // consumer context finished, filtering it out of loop + default: + filtered = append(filtered, consumer) + } } + e.consumers = filtered + e.consumersLock.Unlock() } func (e *Eventer) newConsumer(ctx context.Context) *eventConsumer { @@ -98,38 +127,18 @@ func (e *Eventer) newConsumer(ctx context.Context) *eventConsumer { // TaskEvents is an implementation of the DriverPlugin.TaskEvents function func (e *Eventer) TaskEvents(ctx context.Context) (<-chan *drivers.TaskEvent, error) { consumer := e.newConsumer(ctx) - go e.handleConsumer(consumer) return consumer.ch, nil } -func (e *Eventer) handleConsumer(consumer *eventConsumer) { - // wait for consumer or eventer ctx to finish - select { - case <-consumer.ctx.Done(): - case <-e.ctx.Done(): - } - e.consumersLock.Lock() - defer e.consumersLock.Unlock() - defer close(consumer.ch) - - filtered := e.consumers[:0] - for _, c := range e.consumers { - if c != consumer { - filtered = append(filtered, c) - } - } - e.consumers = filtered -} - // EmitEvent can be used to broadcast a new event func (e *Eventer) EmitEvent(event *drivers.TaskEvent) error { select { case <-e.ctx.Done(): - return fmt.Errorf("error sending event, context canceled") + return e.ctx.Err() case e.events <- event: if e.logger.IsTrace() { - e.logger.Trace("emitting event", "event", *event) + e.logger.Trace("emitting event", "event", event) } } return nil diff --git a/plugins/drivers/utils/eventer_test.go b/plugins/drivers/utils/eventer_test.go index 6c7972a9a..763772920 100644 --- a/plugins/drivers/utils/eventer_test.go +++ b/plugins/drivers/utils/eventer_test.go @@ -15,7 +15,7 @@ func TestEventer(t *testing.T) { t.Parallel() require := require.New(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, _ := context.WithCancel(context.Background()) e := NewEventer(ctx, testlog.HCLogger(t)) events := []*drivers.TaskEvent{ @@ -33,7 +33,7 @@ func TestEventer(t *testing.T) { }, } - ctx1, cancel1 := context.WithCancel(context.Background()) + ctx1, _ := context.WithCancel(context.Background()) consumer1, err := e.TaskEvents(ctx1) require.NoError(err) ctx2 := (context.Background()) @@ -49,8 +49,8 @@ func TestEventer(t *testing.T) { for event := range consumer1 { i++ buffer1 = append(buffer1, event) - if i == 3 { - break + if i == len(events) { + return } } }() @@ -60,8 +60,8 @@ func TestEventer(t *testing.T) { for event := range consumer2 { i++ buffer2 = append(buffer2, event) - if i == 3 { - break + if i == len(events) { + return } } }() @@ -73,20 +73,45 @@ func TestEventer(t *testing.T) { wg.Wait() require.Exactly(events, buffer1) require.Exactly(events, buffer2) - cancel1() - time.Sleep(100 * time.Millisecond) +} + +func TestEventer_iterateConsumers(t *testing.T) { + t.Parallel() + require := require.New(t) + + e := &Eventer{ + events: make(chan *drivers.TaskEvent), + ctx: context.Background(), + logger: testlog.HCLogger(t), + } + + ev := &drivers.TaskEvent{ + TaskID: "a", + Timestamp: time.Now(), + } + + ctx1, cancel1 := context.WithCancel(context.Background()) + consumer, err := e.TaskEvents(ctx1) + require.NoError(err) require.Equal(1, len(e.consumers)) - require.NoError(e.EmitEvent(&drivers.TaskEvent{})) - ev, ok := <-consumer1 - require.Nil(ev) - require.False(ok) - ev, ok = <-consumer2 - require.NotNil(ev) - require.True(ok) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + ev1, ok := <-consumer + require.Exactly(ev, ev1) + require.True(ok) + }() + e.iterateConsumers(ev) + wg.Wait() - cancel() - time.Sleep(100 * time.Millisecond) - require.Zero(len(e.consumers)) - require.Error(e.EmitEvent(&drivers.TaskEvent{})) + go func() { + cancel1() + e.iterateConsumers(ev) + }() + ev1, ok := <-consumer + require.False(ok) + require.Nil(ev1) + require.Equal(0, len(e.consumers)) }