diff --git a/client/allocrunner/taskrunner/task_runner_test.go b/client/allocrunner/taskrunner/task_runner_test.go index 54037dcbe..0ab5be2ff 100644 --- a/client/allocrunner/taskrunner/task_runner_test.go +++ b/client/allocrunner/taskrunner/task_runner_test.go @@ -265,14 +265,9 @@ func TestTaskRunner_DevicePropogation(t *testing.T) { } // Get the mock driver plugin - driverPlugin, err := conf.PluginSingletonLoader.Dispense( - mockdriver.PluginID.Name, - mockdriver.PluginID.PluginType, - nil, - conf.Logger, - ) + driverPlugin, err := conf.DriverManager.Dispense(mockdriver.PluginID.Name) require.NoError(err) - mockDriver := driverPlugin.Plugin().(*mockdriver.Driver) + mockDriver := driverPlugin.(*mockdriver.Driver) // Assert its config has been properly interpolated driverCfg, _ := mockDriver.GetTaskConfig() diff --git a/client/client.go b/client/client.go index 15acbb31b..538835f07 100644 --- a/client/client.go +++ b/client/client.go @@ -1,6 +1,7 @@ package client import ( + "context" "errors" "fmt" "io/ioutil" @@ -387,12 +388,19 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulServic return nil, fmt.Errorf("failed to setup vault client: %v", err) } - // Wait for plugin manangers to initialize - pluginReadyCh, err := c.pluginManagers.Ready(pluginmanager.DefaultManagerReadyTimeout) + // Wait for plugin managers to initialize. + // Plugins must be initialized before restore is called, otherwise restoring + // tasks that use uninitialized plugins will fail. 
+ ctx, cancel := context.WithTimeout(context.Background(), pluginmanager.DefaultManagerReadyTimeout) + defer cancel() + pluginReadyCh, err := c.pluginManagers.Ready(ctx) if err != nil { return nil, err } - <-pluginReadyCh + select { + case <-pluginReadyCh: + case <-ctx.Done(): + } // Restore the state if err := c.restoreState(); err != nil { diff --git a/client/devicemanager/manager.go b/client/devicemanager/manager.go index 8fb0abc79..94e075b6a 100644 --- a/client/devicemanager/manager.go +++ b/client/devicemanager/manager.go @@ -195,16 +195,14 @@ func (m *manager) Shutdown() { } func (m *manager) Ready() <-chan struct{} { - ret := make(chan struct{}) + ctx, cancel := context.WithTimeout(m.ctx, 5*time.Second) go func() { - ctx, cancel := context.WithTimeout(m.ctx, 5*time.Second) for _, i := range m.instances { i.WaitForFirstFingerprint(ctx) } cancel() - close(ret) }() - return ret + return ctx.Done() } // Reserve reserves the given allocated device. If the device is unknown, an diff --git a/client/pluginmanager/drivermanager/manager.go b/client/pluginmanager/drivermanager/manager.go index 0dabc47df..30f3cfc1c 100644 --- a/client/pluginmanager/drivermanager/manager.go +++ b/client/pluginmanager/drivermanager/manager.go @@ -191,11 +191,7 @@ func (m *manager) Run() { } // signal ready - select { - case <-m.ctx.Done(): - return - case m.readyCh <- struct{}{}: - } + close(m.readyCh) // wait for shutdown <-m.ctx.Done() @@ -216,25 +212,27 @@ func (m *manager) Shutdown() { } func (m *manager) Ready() <-chan struct{} { - ret := make(chan struct{}) + ctx, cancel := context.WithTimeout(m.ctx, time.Second*10) go func() { + defer cancel() // We don't want to start initial fingerprint wait until Run loop has // finished - <-m.readyCh + select { + case <-m.readyCh: + case <-m.ctx.Done(): + return + } var availDrivers []string - ctx, cancel := context.WithTimeout(m.ctx, time.Second*10) for name, instance := range m.instances { instance.WaitForFirstFingerprint(ctx) if 
instance.lastHealthState != drivers.HealthStateUndetected { availDrivers = append(availDrivers, name) } } - cancel() m.logger.Debug("detected drivers", "drivers", availDrivers) - close(ret) }() - return ret + return ctx.Done() } func (m *manager) loadReattachConfigs() error { @@ -299,13 +297,17 @@ func (m *manager) fetchPluginReattachConfig(id loader.PluginID) (*plugin.Reattac func (m *manager) RegisterEventHandler(driver, taskID string, handler EventHandler) { m.instancesMu.Lock() - m.instances[driver].registerEventHandler(taskID, handler) + if d, ok := m.instances[driver]; ok { + d.registerEventHandler(taskID, handler) + } m.instancesMu.Unlock() } func (m *manager) DeregisterEventHandler(driver, taskID string) { m.instancesMu.Lock() - m.instances[driver].deregisterEventHandler(taskID) + if d, ok := m.instances[driver]; ok { + d.deregisterEventHandler(taskID) + } m.instancesMu.Unlock() } diff --git a/client/pluginmanager/drivermanager/manager_test.go b/client/pluginmanager/drivermanager/manager_test.go index d3e6e6812..80c5d3cbd 100644 --- a/client/pluginmanager/drivermanager/manager_test.go +++ b/client/pluginmanager/drivermanager/manager_test.go @@ -15,8 +15,10 @@ import ( "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/plugins/base" "github.com/hashicorp/nomad/plugins/drivers" + dtu "github.com/hashicorp/nomad/plugins/drivers/testutils" "github.com/hashicorp/nomad/plugins/shared/loader" "github.com/hashicorp/nomad/testutil" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -43,7 +45,7 @@ func testSetup(t *testing.T) (chan *drivers.Fingerprint, chan *drivers.TaskEvent } func mockDriver(fpChan chan *drivers.Fingerprint, evChan chan *drivers.TaskEvent) drivers.DriverPlugin { - return &drivers.MockDriver{ + return &dtu.MockDriver{ FingerprintF: func(ctx context.Context) (<-chan *drivers.Fingerprint, error) { return fpChan, nil }, @@ -111,10 +113,7 @@ func TestMananger_Fingerprint(t *testing.T) { 
testutil.WaitForResult(func() (bool, error) { mgr.instancesMu.Lock() defer mgr.instancesMu.Unlock() - if len(mgr.instances) != 1 { - return false, fmt.Errorf("mananger should have registered an instance") - } - return true, nil + return len(mgr.instances) == 1, fmt.Errorf("mananger should have registered 1 instance") }, func(err error) { require.NoError(err) }) @@ -177,10 +176,7 @@ func TestMananger_TaskEvents(t *testing.T) { testutil.WaitForResult(func() (bool, error) { mgr.instancesMu.Lock() defer mgr.instancesMu.Unlock() - if len(mgr.instances) != 1 { - return false, fmt.Errorf("mananger should have registered 1 instance") - } - return true, nil + return len(mgr.instances) == 1, fmt.Errorf("mananger should have registered 1 instance") }, func(err error) { require.NoError(err) }) @@ -190,7 +186,7 @@ func TestMananger_TaskEvents(t *testing.T) { wg.Add(1) mgr.RegisterEventHandler("mock", "abc1", func(ev *drivers.TaskEvent) { defer wg.Done() - require.Exactly(event1, ev) + assert.Exactly(t, event1, ev) }) evChan <- event1 @@ -211,10 +207,7 @@ func TestManager_Run_AllowedDrivers(t *testing.T) { testutil.AssertUntil(200*time.Millisecond, func() (bool, error) { mgr.instancesMu.Lock() defer mgr.instancesMu.Unlock() - if len(mgr.instances) > 0 { - return false, fmt.Errorf("mananger should have no registered instances") - } - return true, nil + return len(mgr.instances) == 0, fmt.Errorf("mananger should have no registered instances") }, func(err error) { require.NoError(err) }) @@ -234,10 +227,7 @@ func TestManager_Run_BlockedDrivers(t *testing.T) { testutil.AssertUntil(200*time.Millisecond, func() (bool, error) { mgr.instancesMu.Lock() defer mgr.instancesMu.Unlock() - if len(mgr.instances) > 0 { - return false, fmt.Errorf("mananger should have no registered instances") - } - return true, nil + return len(mgr.instances) == 0, fmt.Errorf("mananger should have no registered instances") }, func(err error) { require.NoError(err) }) @@ -287,13 +277,10 @@ func 
TestManager_Run_AllowedBlockedDrivers_Combined(t *testing.T) { }(d) } - testutil.AssertUntil(200*time.Millisecond, func() (bool, error) { + testutil.AssertUntil(250*time.Millisecond, func() (bool, error) { mgr.instancesMu.Lock() defer mgr.instancesMu.Unlock() - if len(mgr.instances) > 1 { - return false, fmt.Errorf("mananger should have 1 registered instance") - } - return true, nil + return len(mgr.instances) < 2, fmt.Errorf("mananger should have 1 registered instance, %v", len(mgr.instances)) }, func(err error) { require.NoError(err) }) diff --git a/client/pluginmanager/group.go b/client/pluginmanager/group.go index 3f06c1d25..243a051eb 100644 --- a/client/pluginmanager/group.go +++ b/client/pluginmanager/group.go @@ -1,6 +1,7 @@ package pluginmanager import ( + "context" "fmt" "sync" "time" @@ -42,7 +43,6 @@ func (m *PluginGroup) RegisterAndRun(manager PluginManager) error { m.mLock.Lock() defer m.mLock.Unlock() if m.shutdown { - m.mLock.Unlock() return fmt.Errorf("plugin group already shutdown") } m.managers = append(m.managers, manager) @@ -57,7 +57,7 @@ func (m *PluginGroup) RegisterAndRun(manager PluginManager) error { // Ready returns a channel which will be closed once all plugin manangers are ready. 
// A timeout for waiting on each manager is given -func (m *PluginGroup) Ready(timeout time.Duration) (<-chan struct{}, error) { +func (m *PluginGroup) Ready(ctx context.Context) (<-chan struct{}, error) { m.mLock.Lock() defer m.mLock.Unlock() if m.shutdown { @@ -72,7 +72,7 @@ func (m *PluginGroup) Ready(timeout time.Duration) (<-chan struct{}, error) { defer wg.Done() select { case <-manager.Ready(): - case <-time.After(timeout): + case <-ctx.Done(): m.logger.Warn("timeout waiting for plugin manager to be ready", "plugin-type", manager.PluginType()) } diff --git a/client/pluginmanager/testing.go b/client/pluginmanager/testing.go index acb378414..93504cdb5 100644 --- a/client/pluginmanager/testing.go +++ b/client/pluginmanager/testing.go @@ -8,3 +8,8 @@ type MockPluginManager struct { func (m *MockPluginManager) Run() { m.RunF() } func (m *MockPluginManager) Shutdown() { m.ShutdownF() } func (m *MockPluginManager) PluginType() string { return "mock" } +func (m *MockPluginManager) Ready() <-chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch +}