This commit is contained in:
Alex Dadgar
2018-10-31 18:00:30 -07:00
parent 57f40c7e3e
commit a8e95502fe
12 changed files with 699 additions and 37 deletions

View File

@@ -10,6 +10,7 @@ import (
log "github.com/hashicorp/go-hclog"
multierror "github.com/hashicorp/go-multierror"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/nomad/client/devicemanager/state"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/device"
@@ -40,11 +41,11 @@ type Manager interface {
type StateStorage interface {
// GetDevicePluginState is used to retrieve the device manager's plugin
// state.
GetDevicePluginState() (*PluginState, error)
GetDevicePluginState() (*state.PluginState, error)
// PutDevicePluginState is used to store the device manager's plugin
// state.
PutDevicePluginState(state *PluginState) error
PutDevicePluginState(state *state.PluginState) error
}
// UpdateNodeDevices is a callback for updating the set of devices on a node.
@@ -136,6 +137,12 @@ func (m *manager) Run() {
// Get device plugins
devices := m.loader.Catalog()[base.PluginTypeDevice]
if len(devices) == 0 {
m.logger.Debug("exiting since there are no device plugins")
m.cancel()
return
}
for _, d := range devices {
id := loader.PluginInfoID(d)
storeFn := func(c *plugin.ReattachConfig) error {
@@ -290,7 +297,7 @@ func (m *manager) storePluginReattachConfig(id loader.PluginID, c *plugin.Reatta
m.reattachConfigs[id] = shared.ReattachConfigFromGoPlugin(c)
// Persist the state
s := &PluginState{
s := &state.PluginState{
ReattachConfigs: make(map[string]*shared.ReattachConfig, len(m.reattachConfigs)),
}

View File

@@ -0,0 +1,501 @@
package devicemanager
import (
"context"
"fmt"
"strings"
"testing"
"time"
log "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/nomad/client/state"
"github.com/hashicorp/nomad/helper"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/device"
"github.com/hashicorp/nomad/plugins/shared/loader"
psstructs "github.com/hashicorp/nomad/plugins/shared/structs"
"github.com/hashicorp/nomad/testutil"
"github.com/kr/pretty"
"github.com/stretchr/testify/require"
)
var (
nvidiaDevice0ID = uuid.Generate()
nvidiaDevice1ID = uuid.Generate()
nvidiaDeviceGroup = &device.DeviceGroup{
Vendor: "nvidia",
Type: "gpu",
Name: "1080ti",
Devices: []*device.Device{
{
ID: nvidiaDevice0ID,
Healthy: true,
},
{
ID: nvidiaDevice1ID,
Healthy: true,
},
},
Attributes: map[string]*psstructs.Attribute{
"memory": {
Int: helper.Int64ToPtr(4),
Unit: "GB",
},
},
}
intelDeviceID = uuid.Generate()
intelDeviceGroup = &device.DeviceGroup{
Vendor: "intel",
Type: "gpu",
Name: "640GT",
Devices: []*device.Device{
{
ID: intelDeviceID,
Healthy: true,
},
},
Attributes: map[string]*psstructs.Attribute{
"memory": {
Int: helper.Int64ToPtr(2),
Unit: "GB",
},
},
}
nvidiaDeviceGroupStats = &device.DeviceGroupStats{
Vendor: "nvidia",
Type: "gpu",
Name: "1080ti",
InstanceStats: map[string]*device.DeviceStats{
nvidiaDevice0ID: &device.DeviceStats{
Summary: &device.StatValue{
IntNumeratorVal: 212,
Unit: "F",
Desc: "Temperature",
},
},
nvidiaDevice1ID: &device.DeviceStats{
Summary: &device.StatValue{
IntNumeratorVal: 218,
Unit: "F",
Desc: "Temperature",
},
},
},
}
intelDeviceGroupStats = &device.DeviceGroupStats{
Vendor: "intel",
Type: "gpu",
Name: "640GT",
InstanceStats: map[string]*device.DeviceStats{
intelDeviceID: &device.DeviceStats{
Summary: &device.StatValue{
IntNumeratorVal: 220,
Unit: "F",
Desc: "Temperature",
},
},
},
}
)
func baseTestConfig(t *testing.T) (
config *Config,
deviceUpdateCh chan []*structs.NodeDeviceResource,
catalog *loader.MockCatalog) {
// Create an update handler
deviceUpdates := make(chan []*structs.NodeDeviceResource, 1)
updateFn := func(devices []*structs.NodeDeviceResource) {
deviceUpdates <- devices
}
// Create a mock plugin catalog
mc := &loader.MockCatalog{}
// Create the config
config = &Config{
Logger: testlog.HCLogger(t),
PluginConfig: &base.ClientAgentConfig{},
StatsInterval: 100 * time.Millisecond,
State: state.NewMemDB(),
Updater: updateFn,
Loader: mc,
}
return config, deviceUpdates, mc
}
func configureCatalogWith(catalog *loader.MockCatalog, plugins map[*base.PluginInfoResponse]loader.PluginInstance) {
catalog.DispenseF = func(name, _ string, _ *base.ClientAgentConfig, _ log.Logger) (loader.PluginInstance, error) {
for info, v := range plugins {
if info.Name == name {
return v, nil
}
}
return nil, fmt.Errorf("no matching plugin")
}
catalog.ReattachF = func(name, _ string, _ *plugin.ReattachConfig) (loader.PluginInstance, error) {
for info, v := range plugins {
if info.Name == name {
return v, nil
}
}
return nil, fmt.Errorf("no matching plugin")
}
catalog.CatalogF = func() map[string][]*base.PluginInfoResponse {
devices := make([]*base.PluginInfoResponse, 0, len(plugins))
for k := range plugins {
devices = append(devices, k)
}
out := map[string][]*base.PluginInfoResponse{
base.PluginTypeDevice: devices,
}
return out
}
}
func pluginInfoResponse(name string) *base.PluginInfoResponse {
return &base.PluginInfoResponse{
Type: base.PluginTypeDevice,
PluginApiVersion: "v0.0.1",
PluginVersion: "v0.0.1",
Name: name,
}
}
// drainNodeDeviceUpdates drains all updates to the node device fingerprint channel
func drainNodeDeviceUpdates(ctx context.Context, in chan []*structs.NodeDeviceResource) {
go func() {
for {
select {
case <-ctx.Done():
return
case <-in:
}
}
}()
}
func deviceReserveFn(ids []string) (*device.ContainerReservation, error) {
return &device.ContainerReservation{
Envs: map[string]string{
"DEVICES": strings.Join(ids, ","),
},
}, nil
}
// nvidiaAndIntelDefaultPlugins adds an nvidia and intel mock plugin to the
// catalog
func nvidiaAndIntelDefaultPlugins(catalog *loader.MockCatalog) {
pluginInfoNvidia := pluginInfoResponse("nvidia")
deviceNvidia := &device.MockDevicePlugin{
MockPlugin: &base.MockPlugin{
PluginInfoF: base.StaticInfo(pluginInfoNvidia),
ConfigSchemaF: base.TestConfigSchema(),
SetConfigF: base.NoopSetConfig(),
},
FingerprintF: device.StaticFingerprinter([]*device.DeviceGroup{nvidiaDeviceGroup}),
ReserveF: deviceReserveFn,
StatsF: device.StaticStats([]*device.DeviceGroupStats{nvidiaDeviceGroupStats}),
}
pluginNvidia := loader.MockBasicExternalPlugin(deviceNvidia)
pluginInfoIntel := pluginInfoResponse("intel")
deviceIntel := &device.MockDevicePlugin{
MockPlugin: &base.MockPlugin{
PluginInfoF: base.StaticInfo(pluginInfoIntel),
ConfigSchemaF: base.TestConfigSchema(),
SetConfigF: base.NoopSetConfig(),
},
FingerprintF: device.StaticFingerprinter([]*device.DeviceGroup{intelDeviceGroup}),
ReserveF: deviceReserveFn,
StatsF: device.StaticStats([]*device.DeviceGroupStats{intelDeviceGroupStats}),
}
pluginIntel := loader.MockBasicExternalPlugin(deviceIntel)
// Configure the catalog with two plugins
configureCatalogWith(catalog, map[*base.PluginInfoResponse]loader.PluginInstance{
pluginInfoNvidia: pluginNvidia,
pluginInfoIntel: pluginIntel,
})
}
// Test collecting statistics from all devices
func TestManager_AllStats(t *testing.T) {
t.Parallel()
require := require.New(t)
config, updateCh, catalog := baseTestConfig(t)
nvidiaAndIntelDefaultPlugins(catalog)
m := New(config)
go m.Run()
defer m.Shutdown()
// Wait till we get a fingerprint result
select {
case <-time.After(5 * time.Second):
t.Fatal("timeout")
case devices := <-updateCh:
require.Len(devices, 2)
}
// Now collect all the stats
var stats []*device.DeviceGroupStats
testutil.WaitForResult(func() (bool, error) {
stats = m.AllStats()
l := len(stats)
if l == 2 {
return true, nil
}
return false, fmt.Errorf("expected count 2; got %d", l)
}, func(err error) {
t.Fatal(err)
})
// Check we got stats from both the devices
var nstats, istats bool
for _, stat := range stats {
switch stat.Vendor {
case "intel":
istats = true
case "nvidia":
nstats = true
default:
t.Fatalf("unexpected vendor %q", stat.Vendor)
}
}
require.True(nstats)
require.True(istats)
}
// Test collecting statistics from a particular device
func TestManager_DeviceStats(t *testing.T) {
t.Parallel()
require := require.New(t)
config, updateCh, catalog := baseTestConfig(t)
nvidiaAndIntelDefaultPlugins(catalog)
m := New(config)
go m.Run()
defer m.Shutdown()
// Wait till we get a fingerprint result
select {
case <-time.After(5 * time.Second):
t.Fatal("timeout")
case devices := <-updateCh:
require.Len(devices, 2)
}
testutil.WaitForResult(func() (bool, error) {
stats := m.AllStats()
l := len(stats)
if l == 2 {
t.Logf("% #v", pretty.Formatter(stats))
return true, nil
}
return false, fmt.Errorf("expected count 2; got %d", l)
}, func(err error) {
t.Fatal(err)
})
// Now collect the stats for one nvidia device
stat, err := m.DeviceStats(&structs.AllocatedDeviceResource{
Vendor: "nvidia",
Type: "gpu",
Name: "1080ti",
DeviceIDs: []string{nvidiaDevice1ID},
})
require.NoError(err)
require.NotNil(stat)
require.Len(stat.InstanceStats, 1)
require.Contains(stat.InstanceStats, nvidiaDevice1ID)
istat := stat.InstanceStats[nvidiaDevice1ID]
require.EqualValues(218, istat.Summary.IntNumeratorVal)
}
// Test reserving a particular device
func TestManager_Reserve(t *testing.T) {
t.Parallel()
r := require.New(t)
config, updateCh, catalog := baseTestConfig(t)
nvidiaAndIntelDefaultPlugins(catalog)
m := New(config)
go m.Run()
defer m.Shutdown()
// Wait till we get a fingerprint result
select {
case <-time.After(5 * time.Second):
t.Fatal("timeout")
case devices := <-updateCh:
r.Len(devices, 2)
}
cases := []struct {
in *structs.AllocatedDeviceResource
expected string
err bool
}{
{
in: &structs.AllocatedDeviceResource{
Vendor: "nvidia",
Type: "gpu",
Name: "1080ti",
DeviceIDs: []string{nvidiaDevice1ID},
},
expected: nvidiaDevice1ID,
},
{
in: &structs.AllocatedDeviceResource{
Vendor: "nvidia",
Type: "gpu",
Name: "1080ti",
DeviceIDs: []string{nvidiaDevice0ID},
},
expected: nvidiaDevice0ID,
},
{
in: &structs.AllocatedDeviceResource{
Vendor: "nvidia",
Type: "gpu",
Name: "1080ti",
DeviceIDs: []string{nvidiaDevice0ID, nvidiaDevice1ID},
},
expected: fmt.Sprintf("%s,%s", nvidiaDevice0ID, nvidiaDevice1ID),
},
{
in: &structs.AllocatedDeviceResource{
Vendor: "nvidia",
Type: "gpu",
Name: "1080ti",
DeviceIDs: []string{nvidiaDevice0ID, nvidiaDevice1ID, "foo"},
},
err: true,
},
{
in: &structs.AllocatedDeviceResource{
Vendor: "intel",
Type: "gpu",
Name: "640GT",
DeviceIDs: []string{intelDeviceID},
},
expected: intelDeviceID,
},
{
in: &structs.AllocatedDeviceResource{
Vendor: "intel",
Type: "gpu",
Name: "foo",
DeviceIDs: []string{intelDeviceID},
},
err: true,
},
}
for i, c := range cases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
r = require.New(t)
// Reserve a particular device
res, err := m.Reserve(c.in)
if !c.err {
r.NoError(err)
r.NotNil(res)
r.Len(res.Envs, 1)
r.Equal(res.Envs["DEVICES"], c.expected)
} else {
r.Error(err)
}
})
}
}
// Test that shutdown shutsdown the plugins
func TestManager_Shutdown(t *testing.T) {
t.Parallel()
require := require.New(t)
config, updateCh, catalog := baseTestConfig(t)
nvidiaAndIntelDefaultPlugins(catalog)
m := New(config)
go m.Run()
defer m.Shutdown()
// Wait till we get a fingerprint result
select {
case <-time.After(5 * time.Second):
t.Fatal("timeout")
case devices := <-updateCh:
require.Len(devices, 2)
}
// Call shutdown and assert that we killed the plugins
m.Shutdown()
for _, resp := range catalog.Catalog()[base.PluginTypeDevice] {
pinst, _ := catalog.Dispense(resp.Name, resp.Type, &base.ClientAgentConfig{}, config.Logger)
require.True(pinst.Exited())
}
}
// Test that startup shutsdown previously launched plugins
func TestManager_Run_ShutdownOld(t *testing.T) {
t.Parallel()
require := require.New(t)
config, updateCh, catalog := baseTestConfig(t)
nvidiaAndIntelDefaultPlugins(catalog)
m := New(config)
go m.Run()
defer m.Shutdown()
// Wait till we get a fingerprint result
select {
case <-time.After(5 * time.Second):
t.Fatal("timeout")
case devices := <-updateCh:
require.Len(devices, 2)
}
// Create a new manager with the same config so that it reads the old state
m2 := New(config)
go m2.Run()
defer m2.Shutdown()
testutil.WaitForResult(func() (bool, error) {
for _, resp := range catalog.Catalog()[base.PluginTypeDevice] {
pinst, _ := catalog.Dispense(resp.Name, resp.Type, &base.ClientAgentConfig{}, config.Logger)
if !pinst.Exited() {
return false, fmt.Errorf("plugin %q not shutdown", resp.Name)
}
}
return true, nil
}, func(err error) {
t.Fatal(err)
})
}

View File

@@ -0,0 +1,11 @@
package state
import "github.com/hashicorp/nomad/plugins/shared"
// PluginState is used to store the device managers state across restarts of the
// agent
type PluginState struct {
// ReattachConfigs are the set of reattach configs for plugin's launched by
// the device manager
ReattachConfigs map[string]*shared.ReattachConfig
}

View File

@@ -6,18 +6,9 @@ import (
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/device"
"github.com/hashicorp/nomad/plugins/shared"
psstructs "github.com/hashicorp/nomad/plugins/shared/structs"
)
// PluginState is used to store the device managers state across restarts of the
// agent
type PluginState struct {
// ReattachConfigs are the set of reattach configs for plugin's launched by
// the device manager
ReattachConfigs map[string]*shared.ReattachConfig
}
// UnknownDeviceError is returned when an operation is attempted on an unknown
// device.
type UnknownDeviceError struct {

View File

@@ -7,7 +7,7 @@ import (
"testing"
trstate "github.com/hashicorp/nomad/client/allocrunner/taskrunner/state"
"github.com/hashicorp/nomad/client/devicemanager"
dmstate "github.com/hashicorp/nomad/client/devicemanager/state"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/kr/pretty"
@@ -206,7 +206,7 @@ func TestStateDB_DeviceManager(t *testing.T) {
require.Nil(ps)
// Putting PluginState should work
state := &devicemanager.PluginState{}
state := &dmstate.PluginState{}
require.NoError(db.PutDevicePluginState(state))
// Getting should return the available state

View File

@@ -2,7 +2,7 @@ package state
import (
"github.com/hashicorp/nomad/client/allocrunner/taskrunner/state"
"github.com/hashicorp/nomad/client/devicemanager"
dmstate "github.com/hashicorp/nomad/client/devicemanager/state"
"github.com/hashicorp/nomad/nomad/structs"
)
@@ -44,11 +44,11 @@ type StateDB interface {
// GetDevicePluginState is used to retrieve the device manager's plugin
// state.
GetDevicePluginState() (*devicemanager.PluginState, error)
GetDevicePluginState() (*dmstate.PluginState, error)
// PutDevicePluginState is used to store the device manager's plugin
// state.
PutDevicePluginState(state *devicemanager.PluginState) error
PutDevicePluginState(state *dmstate.PluginState) error
// Close the database. Unsafe for further use after calling regardless
// of return value.

View File

@@ -4,7 +4,7 @@ import (
"sync"
"github.com/hashicorp/nomad/client/allocrunner/taskrunner/state"
"github.com/hashicorp/nomad/client/devicemanager"
dmstate "github.com/hashicorp/nomad/client/devicemanager/state"
"github.com/hashicorp/nomad/nomad/structs"
)
@@ -19,7 +19,7 @@ type MemDB struct {
taskState map[string]map[string]*structs.TaskState
// devicemanager -> plugin-state
devManagerPs *devicemanager.PluginState
devManagerPs *dmstate.PluginState
mu sync.RWMutex
}
@@ -135,7 +135,7 @@ func (m *MemDB) DeleteAllocationBucket(allocID string) error {
return nil
}
func (m *MemDB) PutDevicePluginState(ps *devicemanager.PluginState) error {
func (m *MemDB) PutDevicePluginState(ps *dmstate.PluginState) error {
m.mu.Lock()
defer m.mu.Unlock()
m.devManagerPs = ps
@@ -144,7 +144,7 @@ func (m *MemDB) PutDevicePluginState(ps *devicemanager.PluginState) error {
// GetDevicePluginState stores the device manager's plugin state or returns an
// error.
func (m *MemDB) GetDevicePluginState() (*devicemanager.PluginState, error) {
func (m *MemDB) GetDevicePluginState() (*dmstate.PluginState, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.devManagerPs, nil

View File

@@ -2,7 +2,7 @@ package state
import (
"github.com/hashicorp/nomad/client/allocrunner/taskrunner/state"
"github.com/hashicorp/nomad/client/devicemanager"
dmstate "github.com/hashicorp/nomad/client/devicemanager/state"
"github.com/hashicorp/nomad/nomad/structs"
)
@@ -41,11 +41,11 @@ func (n NoopDB) DeleteAllocationBucket(allocID string) error {
return nil
}
func (n NoopDB) PutDevicePluginState(ps *devicemanager.PluginState) error {
func (n NoopDB) PutDevicePluginState(ps *dmstate.PluginState) error {
return nil
}
func (n NoopDB) GetDevicePluginState() (*devicemanager.PluginState, error) {
func (n NoopDB) GetDevicePluginState() (*dmstate.PluginState, error) {
return nil, nil
}

View File

@@ -5,7 +5,7 @@ import (
"path/filepath"
trstate "github.com/hashicorp/nomad/client/allocrunner/taskrunner/state"
"github.com/hashicorp/nomad/client/devicemanager"
dmstate "github.com/hashicorp/nomad/client/devicemanager/state"
"github.com/hashicorp/nomad/helper/boltdd"
"github.com/hashicorp/nomad/nomad/structs"
)
@@ -21,7 +21,7 @@ allocations/ (bucket)
|--> task_runner persisted objects (k/v)
devicemanager/
|--> plugin-state -> *devicemanager.PluginState
|--> plugin-state -> *dmstate.PluginState
*/
var (
@@ -369,7 +369,7 @@ func getTaskBucket(tx *boltdd.Tx, allocID, taskName string) (*boltdd.Bucket, err
// PutDevicePluginState stores the device manager's plugin state or returns an
// error.
func (s *BoltStateDB) PutDevicePluginState(ps *devicemanager.PluginState) error {
func (s *BoltStateDB) PutDevicePluginState(ps *dmstate.PluginState) error {
return s.db.Update(func(tx *boltdd.Tx) error {
// Retrieve the root device manager bucket
devBkt, err := tx.CreateBucketIfNotExists(devManagerBucket)
@@ -383,8 +383,8 @@ func (s *BoltStateDB) PutDevicePluginState(ps *devicemanager.PluginState) error
// GetDevicePluginState stores the device manager's plugin state or returns an
// error.
func (s *BoltStateDB) GetDevicePluginState() (*devicemanager.PluginState, error) {
var ps *devicemanager.PluginState
func (s *BoltStateDB) GetDevicePluginState() (*dmstate.PluginState, error) {
var ps *dmstate.PluginState
err := s.db.View(func(tx *boltdd.Tx) error {
devBkt := tx.Bucket(devManagerBucket)
@@ -394,7 +394,7 @@ func (s *BoltStateDB) GetDevicePluginState() (*devicemanager.PluginState, error)
}
// Restore Plugin State if it exists
ps = &devicemanager.PluginState{}
ps = &dmstate.PluginState{}
if err := devBkt.Get(devManagerPluginStateKey, ps); err != nil {
if !boltdd.IsErrNotFound(err) {
return fmt.Errorf("failed to read device manager plugin state: %v", err)

View File

@@ -22,7 +22,7 @@ var (
Block: &hclspec.Spec_Attr{
Attr: &hclspec.Attr{
Type: "number",
Required: true,
Required: false,
},
},
},
@@ -46,13 +46,17 @@ type TestConfig struct {
Baz bool `cty:"baz" codec:"baz"`
}
type PluginInfoFn func() (*PluginInfoResponse, error)
type ConfigSchemaFn func() (*hclspec.Spec, error)
type SetConfigFn func([]byte, *ClientAgentConfig) error
// MockPlugin is used for testing.
// Each function can be set as a closure to make assertions about how data
// is passed through the base plugin layer.
type MockPlugin struct {
PluginInfoF func() (*PluginInfoResponse, error)
ConfigSchemaF func() (*hclspec.Spec, error)
SetConfigF func([]byte, *ClientAgentConfig) error
PluginInfoF PluginInfoFn
ConfigSchemaF ConfigSchemaFn
SetConfigF SetConfigFn
}
func (p *MockPlugin) PluginInfo() (*PluginInfoResponse, error) { return p.PluginInfoF() }
@@ -60,3 +64,30 @@ func (p *MockPlugin) ConfigSchema() (*hclspec.Spec, error) { return p.Config
func (p *MockPlugin) SetConfig(data []byte, cfg *ClientAgentConfig) error {
return p.SetConfigF(data, cfg)
}
// Below are static implementations of the base plugin functions
// StaticInfo returns the passed PluginInfoResponse with no error
func StaticInfo(out *PluginInfoResponse) PluginInfoFn {
return func() (*PluginInfoResponse, error) {
return out, nil
}
}
// StaticConfigSchema returns the passed Spec with no error
func StaticConfigSchema(out *hclspec.Spec) ConfigSchemaFn {
return func() (*hclspec.Spec, error) {
return out, nil
}
}
// TestConfigSchema returns a ConfigSchemaFn that statically returns the
// TestSpec
func TestConfigSchema() ConfigSchemaFn {
return StaticConfigSchema(TestSpec)
}
// NoopSetConfig is a noop implementation of set config
func NoopSetConfig() SetConfigFn {
return func(_ []byte, _ *ClientAgentConfig) error { return nil }
}

View File

@@ -7,14 +7,18 @@ import (
"github.com/hashicorp/nomad/plugins/base"
)
type FingerprintFn func(context.Context) (<-chan *FingerprintResponse, error)
type ReserveFn func([]string) (*ContainerReservation, error)
type StatsFn func(context.Context, time.Duration) (<-chan *StatsResponse, error)
// MockDevicePlugin is used for testing.
// Each function can be set as a closure to make assertions about how data
// is passed through the base plugin layer.
type MockDevicePlugin struct {
*base.MockPlugin
FingerprintF func(context.Context) (<-chan *FingerprintResponse, error)
ReserveF func([]string) (*ContainerReservation, error)
StatsF func(context.Context, time.Duration) (<-chan *StatsResponse, error)
FingerprintF FingerprintFn
ReserveF ReserveFn
StatsF StatsFn
}
func (p *MockDevicePlugin) Fingerprint(ctx context.Context) (<-chan *FingerprintResponse, error) {
@@ -28,3 +32,84 @@ func (p *MockDevicePlugin) Reserve(devices []string) (*ContainerReservation, err
func (p *MockDevicePlugin) Stats(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) {
return p.StatsF(ctx, interval)
}
// Below are static implementations of the device functions
// StaticFingerprinter fingerprints the passed devices just once
func StaticFingerprinter(devices []*DeviceGroup) FingerprintFn {
return func(_ context.Context) (<-chan *FingerprintResponse, error) {
outCh := make(chan *FingerprintResponse, 1)
outCh <- &FingerprintResponse{
Devices: devices,
}
return outCh, nil
}
}
// ErrorChFingerprinter returns an error fingerprinting over the channel
func ErrorChFingerprinter(err error) FingerprintFn {
return func(_ context.Context) (<-chan *FingerprintResponse, error) {
outCh := make(chan *FingerprintResponse, 1)
outCh <- &FingerprintResponse{
Error: err,
}
return outCh, nil
}
}
// StaticReserve returns the passed container reservation
func StaticReserve(out *ContainerReservation) ReserveFn {
return func(_ []string) (*ContainerReservation, error) {
return out, nil
}
}
// ErrorReserve returns the passed error
func ErrorReserve(err error) ReserveFn {
return func(_ []string) (*ContainerReservation, error) {
return nil, err
}
}
// StaticStats returns the passed statistics only updating the timestamp
func StaticStats(out []*DeviceGroupStats) StatsFn {
return func(ctx context.Context, intv time.Duration) (<-chan *StatsResponse, error) {
outCh := make(chan *StatsResponse, 1)
go func() {
ticker := time.NewTimer(0)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
ticker.Reset(intv)
}
now := time.Now()
for _, g := range out {
for _, i := range g.InstanceStats {
i.Timestamp = now
}
}
outCh <- &StatsResponse{
Groups: out,
}
}
}()
return outCh, nil
}
}
// ErrorChStats returns an error collecting stats over the channel
func ErrorChStats(err error) StatsFn {
return func(_ context.Context, _ time.Duration) (<-chan *StatsResponse, error) {
outCh := make(chan *StatsResponse, 1)
outCh <- &StatsResponse{
Error: err,
}
return outCh, nil
}
}

View File

@@ -1,8 +1,11 @@
package loader
import (
"net"
log "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/nomad/helper"
"github.com/hashicorp/nomad/plugins/base"
)
@@ -39,3 +42,36 @@ func (m *MockInstance) Kill() { m.KillF
func (m *MockInstance) ReattachConfig() (*plugin.ReattachConfig, bool) { return m.ReattachConfigF() }
func (m *MockInstance) Plugin() interface{} { return m.PluginF() }
func (m *MockInstance) Exited() bool { return m.ExitedF() }
// MockBasicExternalPlugin returns a MockInstance that simulates an external
// plugin returning it has been exited after kill is called. It returns the
// passed inst as the plugin
func MockBasicExternalPlugin(inst interface{}) *MockInstance {
killed := helper.BoolToPtr(false)
return &MockInstance{
InternalPlugin: false,
KillF: func() {
*killed = true
},
ReattachConfigF: func() (*plugin.ReattachConfig, bool) {
return &plugin.ReattachConfig{
Protocol: "tcp",
Addr: &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 3200,
Zone: "",
},
Pid: 1000,
}, true
},
PluginF: func() interface{} {
return inst
},
ExitedF: func() bool {
return *killed
},
}
}