diff --git a/api/tasks.go b/api/tasks.go index 4e05a2cd3..7e68fea25 100644 --- a/api/tasks.go +++ b/api/tasks.go @@ -643,6 +643,7 @@ type Task struct { Templates []*Template DispatchPayload *DispatchPayloadConfig VolumeMounts []*VolumeMount + CSIPluginConfig *TaskCSIPluginConfig `mapstructure:"csi_plugin" json:"csi_plugin,omitempty"` Leader bool ShutdownDelay time.Duration `mapstructure:"shutdown_delay"` KillSignal string `mapstructure:"kill_signal"` @@ -683,6 +684,9 @@ func (t *Task) Canonicalize(tg *TaskGroup, job *Job) { if t.Lifecycle.Empty() { t.Lifecycle = nil } + if t.CSIPluginConfig != nil { + t.CSIPluginConfig.Canonicalize() + } } // TaskArtifact is used to download artifacts before running a task. @@ -909,3 +913,48 @@ type TaskEvent struct { TaskSignal string GenericSource string } + +// CSIPluginType is an enum string that encapsulates the valid options for a +// CSIPlugin stanza's Type. These modes will allow the plugin to be used in +// different ways by the client. +type CSIPluginType string + +const ( + // CSIPluginTypeNode indicates that Nomad should only use the plugin for + // performing Node RPCs against the provided plugin. + CSIPluginTypeNode CSIPluginType = "node" + + // CSIPluginTypeController indicates that Nomad should only use the plugin for + // performing Controller RPCs against the provided plugin. + CSIPluginTypeController CSIPluginType = "controller" + + // CSIPluginTypeMonolith indicates that Nomad can use the provided plugin for + // both controller and node rpcs. + CSIPluginTypeMonolith CSIPluginType = "monolith" +) + +// TaskCSIPluginConfig contains the data that is required to setup a task as a +// CSI plugin. This will be used by the csi_plugin_supervisor_hook to configure +// mounts for the plugin and initiate the connection to the plugin catalog. +type TaskCSIPluginConfig struct { + // ID is the identifier of the plugin. + // Ideally this should be the FQDN of the plugin. + ID string `mapstructure:"id"` + + // CSIPluginType instructs Nomad on how to handle processing a plugin + Type CSIPluginType `mapstructure:"type"` + + // MountDir is the destination that nomad should mount in its CSI + // directory for the plugin. It will then expect a file called CSISocketName + // to be created by the plugin, and will provide references into + // "MountDir/CSIIntermediaryDirname/VolumeName/AllocID for mounts. + // + // Default is /csi. + MountDir string `mapstructure:"mount_dir"` +} + +func (t *TaskCSIPluginConfig) Canonicalize() { + if t.MountDir == "" { + t.MountDir = "/csi" + } +} diff --git a/client/allocrunner/alloc_runner.go b/client/allocrunner/alloc_runner.go index fdd62ad98..9c8286c2d 100644 --- a/client/allocrunner/alloc_runner.go +++ b/client/allocrunner/alloc_runner.go @@ -17,6 +17,7 @@ import ( "github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/consul" "github.com/hashicorp/nomad/client/devicemanager" + "github.com/hashicorp/nomad/client/dynamicplugins" cinterfaces "github.com/hashicorp/nomad/client/interfaces" "github.com/hashicorp/nomad/client/pluginmanager/drivermanager" cstate "github.com/hashicorp/nomad/client/state" @@ -134,6 +135,10 @@ type allocRunner struct { // prevAllocMigrator allows the migration of a previous allocations alloc dir. prevAllocMigrator allocwatcher.PrevAllocMigrator + // dynamicRegistry contains all locally registered dynamic plugins (e.g csi + // plugins). 
+ dynamicRegistry dynamicplugins.Registry + // devicemanager is used to mount devices as well as lookup device // statistics devicemanager devicemanager.Manager @@ -178,6 +183,7 @@ func NewAllocRunner(config *Config) (*allocRunner, error) { deviceStatsReporter: config.DeviceStatsReporter, prevAllocWatcher: config.PrevAllocWatcher, prevAllocMigrator: config.PrevAllocMigrator, + dynamicRegistry: config.DynamicRegistry, devicemanager: config.DeviceManager, driverManager: config.DriverManager, serversContactedCh: config.ServersContactedCh, @@ -218,6 +224,7 @@ func (ar *allocRunner) initTaskRunners(tasks []*structs.Task) error { Logger: ar.logger, StateDB: ar.stateDB, StateUpdater: ar, + DynamicRegistry: ar.dynamicRegistry, Consul: ar.consulClient, ConsulSI: ar.sidsClient, Vault: ar.vaultClient, diff --git a/client/allocrunner/config.go b/client/allocrunner/config.go index a9240b3a3..4893c9604 100644 --- a/client/allocrunner/config.go +++ b/client/allocrunner/config.go @@ -6,6 +6,7 @@ import ( clientconfig "github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/consul" "github.com/hashicorp/nomad/client/devicemanager" + "github.com/hashicorp/nomad/client/dynamicplugins" "github.com/hashicorp/nomad/client/interfaces" "github.com/hashicorp/nomad/client/pluginmanager/drivermanager" cstate "github.com/hashicorp/nomad/client/state" @@ -48,6 +49,10 @@ type Config struct { // PrevAllocMigrator allows the migration of a previous allocations alloc dir PrevAllocMigrator allocwatcher.PrevAllocMigrator + // DynamicRegistry contains all locally registered dynamic plugins (e.g csi + // plugins). + DynamicRegistry dynamicplugins.Registry + // DeviceManager is used to mount devices as well as lookup device // statistics DeviceManager devicemanager.Manager diff --git a/client/allocrunner/taskrunner/plugin_supervisor_hook.go b/client/allocrunner/taskrunner/plugin_supervisor_hook.go new file mode 100644 index 000000000..5774c4548 --- /dev/null +++ b/client/allocrunner/taskrunner/plugin_supervisor_hook.go @@ -0,0 +1,333 @@ +package taskrunner + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sync" + "time" + + hclog "github.com/hashicorp/go-hclog" + "github.com/hashicorp/nomad/client/allocrunner/interfaces" + ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces" + "github.com/hashicorp/nomad/client/dynamicplugins" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/csi" + "github.com/hashicorp/nomad/plugins/drivers" +) + +// csiPluginSupervisorHook manages supervising plugins that are running as Nomad +// tasks. These plugins will be fingerprinted and it will manage connecting them +// to their requisite plugin manager. +// +// It provides a couple of things to a task running inside Nomad. These are: +// * A mount to the `plugin_mount_dir`, that will then be used by Nomad +// to connect to the nested plugin and handle volume mounts. +// * When the task has started, it starts a loop of attempting to connect to the +// plugin, to perform initial fingerprinting of the plugins capabilities before +// notifying the plugin manager of the plugin. 
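//
// As a concrete illustration (the plugin id and paths below are placeholders,
// not values required by this change): for a "node" plugin with id
// "org.example.csi", the hook creates a host directory under the client
// state dir at <state_dir>/csi/node/org.example.csi, mounts it into the task
// at the stanza's mount_dir, and then waits for the plugin to create the
// csi.sock socket inside that directory before registering it with the
// plugin catalog.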
+type csiPluginSupervisorHook struct { + logger hclog.Logger + alloc *structs.Allocation + task *structs.Task + runner *TaskRunner + mountPoint string + + // eventEmitter is used to emit events to the task + eventEmitter ti.EventEmitter + + shutdownCtx context.Context + shutdownCancelFn context.CancelFunc + + running bool + runningLock sync.Mutex + + // previousHealthstate is used by the supervisor goroutine to track historic + // health states for gating task events. + previousHealthState bool +} + +// The plugin supervisor uses the PrestartHook mechanism to setup the requisite +// mount points and configuration for the task that exposes a CSI plugin. +var _ interfaces.TaskPrestartHook = &csiPluginSupervisorHook{} + +// The plugin supervisor uses the PoststartHook mechanism to start polling the +// plugin for readiness and supported functionality before registering the +// plugin with the catalog. +var _ interfaces.TaskPoststartHook = &csiPluginSupervisorHook{} + +// The plugin supervisor uses the StopHook mechanism to deregister the plugin +// with the catalog and to ensure any mounts are cleaned up. +var _ interfaces.TaskStopHook = &csiPluginSupervisorHook{} + +func newCSIPluginSupervisorHook(csiRootDir string, eventEmitter ti.EventEmitter, runner *TaskRunner, logger hclog.Logger) *csiPluginSupervisorHook { + task := runner.Task() + pluginRoot := filepath.Join(csiRootDir, string(task.CSIPluginConfig.Type), task.CSIPluginConfig.ID) + + shutdownCtx, cancelFn := context.WithCancel(context.Background()) + + hook := &csiPluginSupervisorHook{ + alloc: runner.Alloc(), + runner: runner, + logger: logger, + task: task, + mountPoint: pluginRoot, + shutdownCtx: shutdownCtx, + shutdownCancelFn: cancelFn, + eventEmitter: eventEmitter, + } + + return hook +} + +func (*csiPluginSupervisorHook) Name() string { + return "csi_plugin_supervisor" +} + +// Prestart is called before the task is started including after every +// restart. This requires that the mount paths for a plugin be idempotent, +// despite us not knowing the name of the plugin ahead of time. +// Because of this, we use the allocid_taskname as the unique identifier for a +// plugin on the filesystem. +func (h *csiPluginSupervisorHook) Prestart(ctx context.Context, + req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error { + // Create the mount directory that the container will access if it doesn't + // already exist. Default to only user access. + if err := os.MkdirAll(h.mountPoint, 0700); err != nil && !os.IsExist(err) { + return fmt.Errorf("failed to create mount point: %v", err) + } + + configMount := &drivers.MountConfig{ + TaskPath: h.task.CSIPluginConfig.MountDir, + HostPath: h.mountPoint, + Readonly: false, + PropagationMode: "bidirectional", + } + + mounts := ensureMountpointInserted(h.runner.hookResources.getMounts(), configMount) + h.runner.hookResources.setMounts(mounts) + + resp.Done = true + return nil +} + +// Poststart is called after the task has started. Poststart is not +// called if the allocation is terminal. +// +// The context is cancelled if the task is killed. +func (h *csiPluginSupervisorHook) Poststart(_ context.Context, _ *interfaces.TaskPoststartRequest, _ *interfaces.TaskPoststartResponse) error { + // If we're already running the supervisor routine, then we don't need to try + // and restart it here as it only terminates on `Stop` hooks. 
+ h.runningLock.Lock() + if h.running { + h.runningLock.Unlock() + return nil + } + h.runningLock.Unlock() + + go h.ensureSupervisorLoop(h.shutdownCtx) + return nil +} + +// ensureSupervisorLoop should be called in a goroutine. It will terminate when +// the passed in context is terminated. +// +// The supervisor works by: +// - Initially waiting for the plugin to become available. This loop is expensive +// and may do things like create new gRPC Clients on every iteration. +// - After receiving an initial healthy status, it will inform the plugin catalog +// of the plugin, registering it with the plugins fingerprinted capabilities. +// - We then perform a more lightweight check, simply probing the plugin on a less +// frequent interval to ensure it is still alive, emitting task events when this +// status changes. +// +// Deeper fingerprinting of the plugin is implemented by the csimanager. +func (h *csiPluginSupervisorHook) ensureSupervisorLoop(ctx context.Context) { + h.runningLock.Lock() + if h.running == true { + h.runningLock.Unlock() + return + } + h.running = true + h.runningLock.Unlock() + + defer func() { + h.runningLock.Lock() + h.running = false + h.runningLock.Unlock() + }() + + socketPath := filepath.Join(h.mountPoint, structs.CSISocketName) + t := time.NewTimer(0) + + // Step 1: Wait for the plugin to initially become available. +WAITFORREADY: + for { + select { + case <-ctx.Done(): + return + case <-t.C: + pluginHealthy, err := h.supervisorLoopOnce(ctx, socketPath) + if err != nil || !pluginHealthy { + h.logger.Info("CSI Plugin not ready", "error", err) + + // Plugin is not yet returning healthy, because we want to optimise for + // quickly bringing a plugin online, we use a short timeout here. + // TODO(dani): Test with more plugins and adjust. + t.Reset(5 * time.Second) + continue + } + + // Mark the plugin as healthy in a task event + h.previousHealthState = pluginHealthy + event := structs.NewTaskEvent(structs.TaskPluginHealthy) + event.SetMessage(fmt.Sprintf("plugin: %s", h.task.CSIPluginConfig.ID)) + h.eventEmitter.EmitEvent(event) + + break WAITFORREADY + } + } + + // Step 2: Register the plugin with the catalog. + deregisterPluginFn, err := h.registerPlugin(socketPath) + if err != nil { + h.logger.Error("CSI Plugin registration failed", "error", err) + event := structs.NewTaskEvent(structs.TaskPluginUnhealthy) + event.SetMessage(fmt.Sprintf("failed to register plugin: %s, reason: %v", h.task.CSIPluginConfig.ID, err)) + h.eventEmitter.EmitEvent(event) + } + + // Step 3: Start the lightweight supervisor loop. + t.Reset(0) + for { + select { + case <-ctx.Done(): + // De-register plugins on task shutdown + deregisterPluginFn() + return + case <-t.C: + pluginHealthy, err := h.supervisorLoopOnce(ctx, socketPath) + if err != nil { + h.logger.Error("CSI Plugin fingerprinting failed", "error", err) + } + + // The plugin has transitioned to a healthy state. Emit an event. + if !h.previousHealthState && pluginHealthy { + event := structs.NewTaskEvent(structs.TaskPluginHealthy) + event.SetMessage(fmt.Sprintf("plugin: %s", h.task.CSIPluginConfig.ID)) + h.eventEmitter.EmitEvent(event) + } + + // The plugin has transitioned to an unhealthy state. Emit an event. 
+ if h.previousHealthState && !pluginHealthy { + event := structs.NewTaskEvent(structs.TaskPluginUnhealthy) + if err != nil { + event.SetMessage(fmt.Sprintf("error: %v", err)) + } else { + event.SetMessage("Unknown Reason") + } + h.eventEmitter.EmitEvent(event) + } + + h.previousHealthState = pluginHealthy + + // This loop is informational and in some plugins this may be expensive to + // validate. We use a longer timeout (30s) to avoid causing undue work. + t.Reset(30 * time.Second) + } + } +} + +func (h *csiPluginSupervisorHook) registerPlugin(socketPath string) (func(), error) { + mkInfoFn := func(pluginType string) *dynamicplugins.PluginInfo { + return &dynamicplugins.PluginInfo{ + Type: pluginType, + Name: h.task.CSIPluginConfig.ID, + Version: "1.0.0", + ConnectionInfo: &dynamicplugins.PluginConnectionInfo{ + SocketPath: socketPath, + }, + } + } + + registrations := []*dynamicplugins.PluginInfo{} + + switch h.task.CSIPluginConfig.Type { + case structs.CSIPluginTypeController: + registrations = append(registrations, mkInfoFn(dynamicplugins.PluginTypeCSIController)) + case structs.CSIPluginTypeNode: + registrations = append(registrations, mkInfoFn(dynamicplugins.PluginTypeCSINode)) + case structs.CSIPluginTypeMonolith: + registrations = append(registrations, mkInfoFn(dynamicplugins.PluginTypeCSIController)) + registrations = append(registrations, mkInfoFn(dynamicplugins.PluginTypeCSINode)) + } + + deregistrationFns := []func(){} + + for _, reg := range registrations { + if err := h.runner.dynamicRegistry.RegisterPlugin(reg); err != nil { + for _, fn := range deregistrationFns { + fn() + } + return nil, err + } + + deregistrationFns = append(deregistrationFns, func() { + err := h.runner.dynamicRegistry.DeregisterPlugin(reg.Type, reg.Name) + if err != nil { + h.logger.Error("failed to deregister csi plugin", "name", reg.Name, "type", reg.Type, "error", err) + } + }) + } + + return func() { + for _, fn := range deregistrationFns { + fn() + } + }, nil +} + +func (h *csiPluginSupervisorHook) supervisorLoopOnce(ctx context.Context, socketPath string) (bool, error) { + _, err := os.Stat(socketPath) + if err != nil { + return false, fmt.Errorf("failed to stat socket: %v", err) + } + + client, err := csi.NewClient(socketPath) + defer client.Close() + if err != nil { + return false, fmt.Errorf("failed to create csi client: %v", err) + } + + healthy, err := client.PluginProbe(ctx) + if err != nil { + return false, fmt.Errorf("failed to probe plugin: %v", err) + } + + return healthy, nil +} + +// Stop is called after the task has exited and will not be started +// again. It is the only hook guaranteed to be executed whenever +// TaskRunner.Run is called (and not gracefully shutting down). +// Therefore it may be called even when prestart and the other hooks +// have not. +// +// Stop hooks must be idempotent. The context is cancelled prematurely if the +// task is killed. 
+func (h *csiPluginSupervisorHook) Stop(_ context.Context, req *interfaces.TaskStopRequest, _ *interfaces.TaskStopResponse) error { + h.shutdownCancelFn() + return nil +} + +func ensureMountpointInserted(mounts []*drivers.MountConfig, mount *drivers.MountConfig) []*drivers.MountConfig { + for _, mnt := range mounts { + if mnt.IsEqual(mount) { + return mounts + } + } + + mounts = append(mounts, mount) + return mounts +} diff --git a/client/allocrunner/taskrunner/task_runner.go b/client/allocrunner/taskrunner/task_runner.go index a24b634e5..9982db96b 100644 --- a/client/allocrunner/taskrunner/task_runner.go +++ b/client/allocrunner/taskrunner/task_runner.go @@ -19,6 +19,7 @@ import ( "github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/consul" "github.com/hashicorp/nomad/client/devicemanager" + "github.com/hashicorp/nomad/client/dynamicplugins" cinterfaces "github.com/hashicorp/nomad/client/interfaces" "github.com/hashicorp/nomad/client/pluginmanager/drivermanager" cstate "github.com/hashicorp/nomad/client/state" @@ -194,6 +195,9 @@ type TaskRunner struct { // handlers driverManager drivermanager.Manager + // dynamicRegistry is where dynamic plugins should be registered. + dynamicRegistry dynamicplugins.Registry + // maxEvents is the capacity of the TaskEvents on the TaskState. // Defaults to defaultMaxEvents but overrideable for testing. maxEvents int @@ -227,6 +231,9 @@ type Config struct { // ConsulSI is the client to use for managing Consul SI tokens ConsulSI consul.ServiceIdentityAPI + // DynamicRegistry is where dynamic plugins should be registered. + DynamicRegistry dynamicplugins.Registry + // Vault is the client to use to derive and renew Vault tokens Vault vaultclient.VaultClient @@ -285,6 +292,7 @@ func NewTaskRunner(config *Config) (*TaskRunner, error) { taskName: config.Task.Name, taskLeader: config.Task.Leader, envBuilder: envBuilder, + dynamicRegistry: config.DynamicRegistry, consulClient: config.Consul, siClient: config.ConsulSI, vaultClient: config.Vault, diff --git a/client/allocrunner/taskrunner/task_runner_hooks.go b/client/allocrunner/taskrunner/task_runner_hooks.go index 549b8316e..470ecd2db 100644 --- a/client/allocrunner/taskrunner/task_runner_hooks.go +++ b/client/allocrunner/taskrunner/task_runner_hooks.go @@ -3,6 +3,7 @@ package taskrunner import ( "context" "fmt" + "path/filepath" "sync" "time" @@ -69,6 +70,11 @@ func (tr *TaskRunner) initHooks() { newDeviceHook(tr.devicemanager, hookLogger), } + // If the task has a CSI stanza, add the hook. 
+ if task.CSIPluginConfig != nil { + tr.runnerHooks = append(tr.runnerHooks, newCSIPluginSupervisorHook(filepath.Join(tr.clientConfig.StateDir, "csi"), tr, tr, hookLogger)) + } + // If Vault is enabled, add the hook if task.Vault != nil { tr.runnerHooks = append(tr.runnerHooks, newVaultHook(&vaultHookConfig{ diff --git a/client/client.go b/client/client.go index 6996875f4..aa9ecbf97 100644 --- a/client/client.go +++ b/client/client.go @@ -26,8 +26,10 @@ import ( "github.com/hashicorp/nomad/client/config" consulApi "github.com/hashicorp/nomad/client/consul" "github.com/hashicorp/nomad/client/devicemanager" + "github.com/hashicorp/nomad/client/dynamicplugins" "github.com/hashicorp/nomad/client/fingerprint" "github.com/hashicorp/nomad/client/pluginmanager" + "github.com/hashicorp/nomad/client/pluginmanager/csimanager" "github.com/hashicorp/nomad/client/pluginmanager/drivermanager" "github.com/hashicorp/nomad/client/servers" "github.com/hashicorp/nomad/client/state" @@ -42,6 +44,7 @@ import ( "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/structs" nconfig "github.com/hashicorp/nomad/nomad/structs/config" + "github.com/hashicorp/nomad/plugins/csi" "github.com/hashicorp/nomad/plugins/device" "github.com/hashicorp/nomad/plugins/drivers" vaultapi "github.com/hashicorp/vault/api" @@ -258,6 +261,9 @@ type Client struct { // pluginManagers is the set of PluginManagers registered by the client pluginManagers *pluginmanager.PluginGroup + // csimanager is responsible for managing csi plugins. + csimanager pluginmanager.PluginManager + // devicemanger is responsible for managing device plugins. devicemanager devicemanager.Manager @@ -279,6 +285,10 @@ type Client struct { // successfully run once. serversContactedCh chan struct{} serversContactedOnce sync.Once + + // dynamicRegistry provides access to plugins that are dynamically registered + // with a nomad client. Currently only used for CSI. 
+ dynamicRegistry dynamicplugins.Registry } var ( @@ -331,11 +341,20 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulServic invalidAllocs: make(map[string]struct{}), serversContactedCh: make(chan struct{}), serversContactedOnce: sync.Once{}, + dynamicRegistry: dynamicplugins.NewRegistry(map[string]dynamicplugins.PluginDispenser{ + dynamicplugins.PluginTypeCSIController: func(info *dynamicplugins.PluginInfo) (interface{}, error) { + return csi.NewClient(info.ConnectionInfo.SocketPath) + }, + dynamicplugins.PluginTypeCSINode: func(info *dynamicplugins.PluginInfo) (interface{}, error) { + return csi.NewClient(info.ConnectionInfo.SocketPath) + }, + }), } c.batchNodeUpdates = newBatchNodeUpdates( c.updateNodeFromDriver, c.updateNodeFromDevices, + c.updateNodeFromCSI, ) // Initialize the server manager @@ -383,6 +402,16 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulServic allowlistDrivers := cfg.ReadStringListToMap("driver.whitelist") blocklistDrivers := cfg.ReadStringListToMap("driver.blacklist") + // Setup the csi manager + csiConfig := &csimanager.Config{ + Logger: c.logger, + DynamicRegistry: c.dynamicRegistry, + UpdateNodeCSIInfoFunc: c.batchNodeUpdates.updateNodeFromCSI, + } + csiManager := csimanager.New(csiConfig) + c.csimanager = csiManager + c.pluginManagers.RegisterAndRun(csiManager) + // Setup the driver manager driverConfig := &drivermanager.Config{ Logger: c.logger, @@ -1054,6 +1083,7 @@ func (c *Client) restoreState() error { Vault: c.vaultClient, PrevAllocWatcher: prevAllocWatcher, PrevAllocMigrator: prevAllocMigrator, + DynamicRegistry: c.dynamicRegistry, DeviceManager: c.devicemanager, DriverManager: c.drivermanager, ServersContactedCh: c.serversContactedCh, @@ -1279,6 +1309,12 @@ func (c *Client) setupNode() error { if node.Drivers == nil { node.Drivers = make(map[string]*structs.DriverInfo) } + if node.CSIControllerPlugins == nil { + node.CSIControllerPlugins = make(map[string]*structs.CSIInfo) + } + if node.CSINodePlugins == nil { + node.CSINodePlugins = make(map[string]*structs.CSIInfo) + } if node.Meta == nil { node.Meta = make(map[string]string) } @@ -2310,6 +2346,7 @@ func (c *Client) addAlloc(alloc *structs.Allocation, migrateToken string) error DeviceStatsReporter: c, PrevAllocWatcher: prevAllocWatcher, PrevAllocMigrator: prevAllocMigrator, + DynamicRegistry: c.dynamicRegistry, DeviceManager: c.devicemanager, DriverManager: c.drivermanager, } diff --git a/client/dynamicplugins/registry.go b/client/dynamicplugins/registry.go new file mode 100644 index 000000000..b1aa06130 --- /dev/null +++ b/client/dynamicplugins/registry.go @@ -0,0 +1,338 @@ +// dynamicplugins is a package that manages dynamic plugins in Nomad. +// It exposes a registry that allows for plugins to be registered/deregistered +// and also allows subscribers to receive real time updates of these events. +package dynamicplugins + +import ( + "context" + "errors" + "fmt" + "sync" +) + +const ( + PluginTypeCSIController = "csi-controller" + PluginTypeCSINode = "csi-node" +) + +// Registry is an interface that allows for the dynamic registration of plugins +// that are running as Nomad Tasks. 
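//
// A minimal usage sketch (the plugin name and socket path below are
// illustrative placeholders, not values defined by this package):
//
//	info := &PluginInfo{
//		Type:           PluginTypeCSINode,
//		Name:           "org.example.csi",
//		Version:        "1.0.0",
//		ConnectionInfo: &PluginConnectionInfo{SocketPath: "/csi/csi.sock"},
//	}
//	if err := registry.RegisterPlugin(info); err != nil {
//		return err
//	}
//	// DispensePlugin returns whatever the PluginDispenser for that type
//	// produced; the concrete type depends on how NewRegistry was configured.
//	raw, err := registry.DispensePlugin(PluginTypeCSINode, "org.example.csi")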
+type Registry interface { + RegisterPlugin(info *PluginInfo) error + DeregisterPlugin(ptype, name string) error + + ListPlugins(ptype string) []*PluginInfo + DispensePlugin(ptype, name string) (interface{}, error) + + PluginsUpdatedCh(ctx context.Context, ptype string) <-chan *PluginUpdateEvent + + Shutdown() +} + +type PluginDispenser func(info *PluginInfo) (interface{}, error) + +// NewRegistry takes a map of `plugintype` to PluginDispenser functions +// that should be used to vend clients for plugins to be used. +func NewRegistry(dispensers map[string]PluginDispenser) Registry { + return &dynamicRegistry{ + plugins: make(map[string]map[string]*PluginInfo), + broadcasters: make(map[string]*pluginEventBroadcaster), + dispensers: dispensers, + } +} + +// PluginInfo is the metadata that is stored by the registry for a given plugin. +type PluginInfo struct { + Name string + Type string + Version string + + // ConnectionInfo should only be used externally during `RegisterPlugin` and + // may not be exposed in the future. + ConnectionInfo *PluginConnectionInfo +} + +// PluginConnectionInfo is the data required to connect to the plugin. +// note: We currently only support Unix Domain Sockets, but this may be expanded +// to support other connection modes in the future. +type PluginConnectionInfo struct { + // SocketPath is the path to the plugins api socket. + SocketPath string +} + +// EventType is the enum of events that will be emitted by a Registry's +// PluginsUpdatedCh. +type EventType string + +const ( + // EventTypeRegistered is emitted by the Registry when a new plugin has been + // registered. + EventTypeRegistered EventType = "registered" + // EventTypeDeregistered is emitted by the Registry when a plugin has been + // removed. + EventTypeDeregistered EventType = "deregistered" +) + +// PluginUpdateEvent is a struct that is sent over a PluginsUpdatedCh when +// plugins are added or removed from the registry. +type PluginUpdateEvent struct { + EventType EventType + Info *PluginInfo +} + +type dynamicRegistry struct { + plugins map[string]map[string]*PluginInfo + pluginsLock sync.RWMutex + + broadcasters map[string]*pluginEventBroadcaster + broadcastersLock sync.Mutex + + dispensers map[string]PluginDispenser +} + +func (d *dynamicRegistry) RegisterPlugin(info *PluginInfo) error { + if info.Type == "" { + // This error shouldn't make it to a production cluster and is to aid + // developers during the development of new plugin types. + return errors.New("Plugin.Type must not be empty") + } + + if info.ConnectionInfo == nil { + // This error shouldn't make it to a production cluster and is to aid + // developers during the development of new plugin types. + return errors.New("Plugin.ConnectionInfo must not be nil") + } + + if info.Name == "" { + // This error shouldn't make it to a production cluster and is to aid + // developers during the development of new plugin types. 
+ return errors.New("Plugin.Name must not be empty") + } + + d.pluginsLock.Lock() + defer d.pluginsLock.Unlock() + + pmap, ok := d.plugins[info.Type] + if !ok { + pmap = make(map[string]*PluginInfo, 1) + d.plugins[info.Type] = pmap + } + + pmap[info.Name] = info + + broadcaster := d.broadcasterForPluginType(info.Type) + event := &PluginUpdateEvent{ + EventType: EventTypeRegistered, + Info: info, + } + broadcaster.broadcast(event) + + return nil +} + +func (d *dynamicRegistry) broadcasterForPluginType(ptype string) *pluginEventBroadcaster { + d.broadcastersLock.Lock() + defer d.broadcastersLock.Unlock() + + broadcaster, ok := d.broadcasters[ptype] + if !ok { + broadcaster = newPluginEventBroadcaster() + d.broadcasters[ptype] = broadcaster + } + + return broadcaster +} + +func (d *dynamicRegistry) DeregisterPlugin(ptype, name string) error { + d.pluginsLock.Lock() + defer d.pluginsLock.Unlock() + + if ptype == "" { + // This error shouldn't make it to a production cluster and is to aid + // developers during the development of new plugin types. + return errors.New("must specify plugin type to deregister") + } + if name == "" { + // This error shouldn't make it to a production cluster and is to aid + // developers during the development of new plugin types. + return errors.New("must specify plugin name to deregister") + } + + pmap, ok := d.plugins[ptype] + if !ok { + // If this occurs there's a bug in the registration handler. + return fmt.Errorf("no plugins registered for type: %s", ptype) + } + + info, ok := pmap[name] + if !ok { + // plugin already deregistered, don't send events or try re-deleting. + return nil + } + delete(pmap, name) + + broadcaster := d.broadcasterForPluginType(ptype) + event := &PluginUpdateEvent{ + EventType: EventTypeDeregistered, + Info: info, + } + broadcaster.broadcast(event) + + return nil +} + +func (d *dynamicRegistry) ListPlugins(ptype string) []*PluginInfo { + d.pluginsLock.RLock() + defer d.pluginsLock.RUnlock() + + pmap, ok := d.plugins[ptype] + if !ok { + return nil + } + + plugins := make([]*PluginInfo, 0, len(pmap)) + + for _, info := range pmap { + plugins = append(plugins, info) + } + + return plugins +} + +func (d *dynamicRegistry) DispensePlugin(ptype string, name string) (interface{}, error) { + d.pluginsLock.Lock() + defer d.pluginsLock.Unlock() + + if ptype == "" { + // This error shouldn't make it to a production cluster and is to aid + // developers during the development of new plugin types. + return nil, errors.New("must specify plugin type to deregister") + } + if name == "" { + // This error shouldn't make it to a production cluster and is to aid + // developers during the development of new plugin types. + return nil, errors.New("must specify plugin name to deregister") + } + + dispenseFunc, ok := d.dispensers[ptype] + if !ok { + // This error shouldn't make it to a production cluster and is to aid + // developers during the development of new plugin types. + return nil, fmt.Errorf("no plugin dispenser found for type: %s", ptype) + } + + pmap, ok := d.plugins[ptype] + if !ok { + return nil, fmt.Errorf("no plugins registered for type: %s", ptype) + } + + info, ok := pmap[name] + if !ok { + return nil, fmt.Errorf("plugin %s for type %s not found", name, ptype) + } + + return dispenseFunc(info) +} + +// PluginsUpdatedCh returns a channel over which plugin events for the requested +// plugin type will be emitted. These events are strongly ordered and will never +// be dropped. 
+// +// The receiving channel _must not_ be closed before the provided context is +// cancelled. +func (d *dynamicRegistry) PluginsUpdatedCh(ctx context.Context, ptype string) <-chan *PluginUpdateEvent { + b := d.broadcasterForPluginType(ptype) + ch := b.subscribe() + go func() { + select { + case <-b.shutdownCh: + return + case <-ctx.Done(): + b.unsubscribe(ch) + } + }() + + return ch +} + +func (d *dynamicRegistry) Shutdown() { + for _, b := range d.broadcasters { + b.shutdown() + } +} + +type pluginEventBroadcaster struct { + stopCh chan struct{} + shutdownCh chan struct{} + publishCh chan *PluginUpdateEvent + + subscriptions map[chan *PluginUpdateEvent]struct{} + subscriptionsLock sync.RWMutex +} + +func newPluginEventBroadcaster() *pluginEventBroadcaster { + b := &pluginEventBroadcaster{ + stopCh: make(chan struct{}), + shutdownCh: make(chan struct{}), + publishCh: make(chan *PluginUpdateEvent, 1), + subscriptions: make(map[chan *PluginUpdateEvent]struct{}), + } + go b.run() + return b +} + +func (p *pluginEventBroadcaster) run() { + for { + select { + case <-p.stopCh: + close(p.shutdownCh) + return + case msg := <-p.publishCh: + p.subscriptionsLock.RLock() + for msgCh := range p.subscriptions { + select { + case msgCh <- msg: + } + } + p.subscriptionsLock.RUnlock() + } + } +} + +func (p *pluginEventBroadcaster) shutdown() { + close(p.stopCh) + + // Wait for loop to exit before closing subscriptions + <-p.shutdownCh + + p.subscriptionsLock.Lock() + for sub := range p.subscriptions { + delete(p.subscriptions, sub) + close(sub) + } + p.subscriptionsLock.Unlock() +} + +func (p *pluginEventBroadcaster) broadcast(e *PluginUpdateEvent) { + p.publishCh <- e +} + +func (p *pluginEventBroadcaster) subscribe() chan *PluginUpdateEvent { + p.subscriptionsLock.Lock() + defer p.subscriptionsLock.Unlock() + + ch := make(chan *PluginUpdateEvent, 1) + p.subscriptions[ch] = struct{}{} + return ch +} + +func (p *pluginEventBroadcaster) unsubscribe(ch chan *PluginUpdateEvent) { + p.subscriptionsLock.Lock() + defer p.subscriptionsLock.Unlock() + + _, ok := p.subscriptions[ch] + if ok { + delete(p.subscriptions, ch) + close(ch) + } +} diff --git a/client/dynamicplugins/registry_test.go b/client/dynamicplugins/registry_test.go new file mode 100644 index 000000000..a3feaaac5 --- /dev/null +++ b/client/dynamicplugins/registry_test.go @@ -0,0 +1,171 @@ +package dynamicplugins + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestPluginEventBroadcaster_SendsMessagesToAllClients(t *testing.T) { + t.Parallel() + b := newPluginEventBroadcaster() + defer close(b.stopCh) + var rcv1, rcv2 bool + + ch1 := b.subscribe() + ch2 := b.subscribe() + + listenFunc := func(ch chan *PluginUpdateEvent, updateBool *bool) { + select { + case <-ch: + *updateBool = true + } + } + + go listenFunc(ch1, &rcv1) + go listenFunc(ch2, &rcv2) + + b.broadcast(&PluginUpdateEvent{}) + + require.Eventually(t, func() bool { + return rcv1 == true && rcv2 == true + }, 1*time.Second, 200*time.Millisecond) +} + +func TestPluginEventBroadcaster_UnsubscribeWorks(t *testing.T) { + t.Parallel() + + b := newPluginEventBroadcaster() + defer close(b.stopCh) + var rcv1 bool + + ch1 := b.subscribe() + + listenFunc := func(ch chan *PluginUpdateEvent, updateBool *bool) { + select { + case e := <-ch: + if e == nil { + *updateBool = true + } + } + } + + go listenFunc(ch1, &rcv1) + + b.unsubscribe(ch1) + + b.broadcast(&PluginUpdateEvent{}) + + require.Eventually(t, func() bool { + return rcv1 == true + }, 
1*time.Second, 200*time.Millisecond) +} + +func TestDynamicRegistry_RegisterPlugin_SendsUpdateEvents(t *testing.T) { + t.Parallel() + r := NewRegistry(nil) + + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + + ch := r.PluginsUpdatedCh(ctx, "csi") + receivedRegistrationEvent := false + + listenFunc := func(ch <-chan *PluginUpdateEvent, updateBool *bool) { + select { + case e := <-ch: + if e == nil { + return + } + + if e.EventType == EventTypeRegistered { + *updateBool = true + } + } + } + + go listenFunc(ch, &receivedRegistrationEvent) + + err := r.RegisterPlugin(&PluginInfo{ + Type: "csi", + Name: "my-plugin", + ConnectionInfo: &PluginConnectionInfo{}, + }) + + require.NoError(t, err) + + require.Eventually(t, func() bool { + return receivedRegistrationEvent == true + }, 1*time.Second, 200*time.Millisecond) +} + +func TestDynamicRegistry_DeregisterPlugin_SendsUpdateEvents(t *testing.T) { + t.Parallel() + r := NewRegistry(nil) + + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + + ch := r.PluginsUpdatedCh(ctx, "csi") + receivedDeregistrationEvent := false + + listenFunc := func(ch <-chan *PluginUpdateEvent, updateBool *bool) { + for { + select { + case e := <-ch: + if e == nil { + return + } + + if e.EventType == EventTypeDeregistered { + *updateBool = true + } + } + } + } + + go listenFunc(ch, &receivedDeregistrationEvent) + + err := r.RegisterPlugin(&PluginInfo{ + Type: "csi", + Name: "my-plugin", + ConnectionInfo: &PluginConnectionInfo{}, + }) + require.NoError(t, err) + + err = r.DeregisterPlugin("csi", "my-plugin") + require.NoError(t, err) + + require.Eventually(t, func() bool { + return receivedDeregistrationEvent == true + }, 1*time.Second, 200*time.Millisecond) +} + +func TestDynamicRegistry_DispensePlugin_Works(t *testing.T) { + dispenseFn := func(i *PluginInfo) (interface{}, error) { + return struct{}{}, nil + } + + registry := NewRegistry(map[string]PluginDispenser{"csi": dispenseFn}) + + err := registry.RegisterPlugin(&PluginInfo{ + Type: "csi", + Name: "my-plugin", + ConnectionInfo: &PluginConnectionInfo{}, + }) + require.NoError(t, err) + + result, err := registry.DispensePlugin("unknown-type", "unknown-name") + require.Nil(t, result) + require.EqualError(t, err, "no plugin dispenser found for type: unknown-type") + + result, err = registry.DispensePlugin("csi", "unknown-name") + require.Nil(t, result) + require.EqualError(t, err, "plugin unknown-name for type csi not found") + + result, err = registry.DispensePlugin("csi", "my-plugin") + require.NotNil(t, result) + require.NoError(t, err) +} diff --git a/client/node_updater.go b/client/node_updater.go index 702cfe8c2..115150da5 100644 --- a/client/node_updater.go +++ b/client/node_updater.go @@ -7,6 +7,7 @@ import ( "time" "github.com/hashicorp/nomad/client/devicemanager" + "github.com/hashicorp/nomad/client/pluginmanager/csimanager" "github.com/hashicorp/nomad/client/pluginmanager/drivermanager" "github.com/hashicorp/nomad/nomad/structs" ) @@ -40,6 +41,23 @@ SEND_BATCH: c.configLock.Lock() defer c.configLock.Unlock() + // csi updates + var csiChanged bool + c.batchNodeUpdates.batchCSIUpdates(func(name string, info *structs.CSIInfo) { + if c.updateNodeFromCSIControllerLocked(name, info) { + if c.config.Node.CSIControllerPlugins[name].UpdateTime.IsZero() { + c.config.Node.CSIControllerPlugins[name].UpdateTime = time.Now() + } + csiChanged = true + } + if c.updateNodeFromCSINodeLocked(name, info) { + if c.config.Node.CSINodePlugins[name].UpdateTime.IsZero() { + 
c.config.Node.CSINodePlugins[name].UpdateTime = time.Now() + } + csiChanged = true + } + }) + // driver node updates var driverChanged bool c.batchNodeUpdates.batchDriverUpdates(func(driver string, info *structs.DriverInfo) { @@ -61,13 +79,128 @@ SEND_BATCH: }) // only update the node if changes occurred - if driverChanged || devicesChanged { + if driverChanged || devicesChanged || csiChanged { c.updateNodeLocked() } close(c.fpInitialized) } +// updateNodeFromCSI receives a CSIInfo struct for the plugin and updates the +// node accordingly +func (c *Client) updateNodeFromCSI(name string, info *structs.CSIInfo) { + c.configLock.Lock() + defer c.configLock.Unlock() + + changed := false + + if c.updateNodeFromCSIControllerLocked(name, info) { + if c.config.Node.CSIControllerPlugins[name].UpdateTime.IsZero() { + c.config.Node.CSIControllerPlugins[name].UpdateTime = time.Now() + } + changed = true + } + + if c.updateNodeFromCSINodeLocked(name, info) { + if c.config.Node.CSINodePlugins[name].UpdateTime.IsZero() { + c.config.Node.CSINodePlugins[name].UpdateTime = time.Now() + } + changed = true + } + + if changed { + c.updateNodeLocked() + } +} + +// updateNodeFromCSIControllerLocked makes the changes to the node from a csi +// update but does not send the update to the server. c.configLock must be held +// before calling this func. +// +// It is safe to call for all CSI Updates, but will only perform changes when +// a ControllerInfo field is present. +func (c *Client) updateNodeFromCSIControllerLocked(name string, info *structs.CSIInfo) bool { + var changed bool + if info.ControllerInfo == nil { + return false + } + i := info.Copy() + i.NodeInfo = nil + + oldController, hadController := c.config.Node.CSIControllerPlugins[name] + if !hadController { + // If the controller info has not yet been set, do that here + changed = true + c.config.Node.CSIControllerPlugins[name] = i + } else { + // The controller info has already been set, fix it up + if !oldController.Equal(i) { + c.config.Node.CSIControllerPlugins[name] = i + changed = true + } + + // If health state has changed, trigger node event + if oldController.Healthy != i.Healthy || oldController.HealthDescription != i.HealthDescription { + changed = true + if i.HealthDescription != "" { + event := &structs.NodeEvent{ + Subsystem: "CSI", + Message: i.HealthDescription, + Timestamp: time.Now(), + Details: map[string]string{"plugin": name, "type": "controller"}, + } + c.triggerNodeEvent(event) + } + } + } + + return changed +} + +// updateNodeFromCSINodeLocked makes the changes to the node from a csi +// update but does not send the update to the server. c.configLock must be hel +// before calling this func. +// +// It is safe to call for all CSI Updates, but will only perform changes when +// a NodeInfo field is present. 
+func (c *Client) updateNodeFromCSINodeLocked(name string, info *structs.CSIInfo) bool { + var changed bool + if info.NodeInfo == nil { + return false + } + i := info.Copy() + i.ControllerInfo = nil + + oldNode, hadNode := c.config.Node.CSINodePlugins[name] + if !hadNode { + // If the Node info has not yet been set, do that here + changed = true + c.config.Node.CSINodePlugins[name] = i + } else { + // The node info has already been set, fix it up + if !oldNode.Equal(info) { + c.config.Node.CSINodePlugins[name] = i + changed = true + } + + // If health state has changed, trigger node event + if oldNode.Healthy != i.Healthy || oldNode.HealthDescription != i.HealthDescription { + changed = true + if i.HealthDescription != "" { + event := &structs.NodeEvent{ + Subsystem: "CSI", + Message: i.HealthDescription, + Timestamp: time.Now(), + Details: map[string]string{"plugin": name, "type": "node"}, + } + c.triggerNodeEvent(event) + } + } + } + + return changed +} + // updateNodeFromDriver receives a DriverInfo struct for the driver and updates // the node accordingly func (c *Client) updateNodeFromDriver(name string, info *structs.DriverInfo) { @@ -187,20 +320,66 @@ type batchNodeUpdates struct { devicesBatched bool devicesCB devicemanager.UpdateNodeDevicesFn devicesMu sync.Mutex + + // access to csi fields must hold csiMu lock + csiNodePlugins map[string]*structs.CSIInfo + csiControllerPlugins map[string]*structs.CSIInfo + csiBatched bool + csiCB csimanager.UpdateNodeCSIInfoFunc + csiMu sync.Mutex } func newBatchNodeUpdates( driverCB drivermanager.UpdateNodeDriverInfoFn, - devicesCB devicemanager.UpdateNodeDevicesFn) *batchNodeUpdates { + devicesCB devicemanager.UpdateNodeDevicesFn, + csiCB csimanager.UpdateNodeCSIInfoFunc) *batchNodeUpdates { return &batchNodeUpdates{ - drivers: make(map[string]*structs.DriverInfo), - driverCB: driverCB, - devices: []*structs.NodeDeviceResource{}, - devicesCB: devicesCB, + drivers: make(map[string]*structs.DriverInfo), + driverCB: driverCB, + devices: []*structs.NodeDeviceResource{}, + devicesCB: devicesCB, + csiNodePlugins: make(map[string]*structs.CSIInfo), + csiControllerPlugins: make(map[string]*structs.CSIInfo), + csiCB: csiCB, } } +// updateNodeFromCSI implements csimanager.UpdateNodeCSIInfoFunc and is used in +// the csi manager to send csi fingerprints to the server. Currently it registers +// all plugins as both controller and node plugins. +// TODO: separate node and controller plugin handling. 
+func (b *batchNodeUpdates) updateNodeFromCSI(plugin string, info *structs.CSIInfo) { + b.csiMu.Lock() + defer b.csiMu.Unlock() + if b.csiBatched { + b.csiCB(plugin, info) + return + } + + b.csiNodePlugins[plugin] = info + b.csiControllerPlugins[plugin] = info +} + +// batchCSIUpdates sends all of the batched CSI updates by calling f for each +// plugin batched +func (b *batchNodeUpdates) batchCSIUpdates(f csimanager.UpdateNodeCSIInfoFunc) error { + b.csiMu.Lock() + defer b.csiMu.Unlock() + if b.csiBatched { + return fmt.Errorf("csi updates already batched") + } + + b.csiBatched = true + for plugin, info := range b.csiNodePlugins { + f(plugin, info) + } + for plugin, info := range b.csiControllerPlugins { + f(plugin, info) + } + return nil +} + // updateNodeFromDriver implements drivermanager.UpdateNodeDriverInfoFn and is // used in the driver manager to send driver fingerprints to func (b *batchNodeUpdates) updateNodeFromDriver(driver string, info *structs.DriverInfo) { diff --git a/client/pluginmanager/csimanager/instance.go b/client/pluginmanager/csimanager/instance.go new file mode 100644 index 000000000..9de20ce8b --- /dev/null +++ b/client/pluginmanager/csimanager/instance.go @@ -0,0 +1,203 @@ +package csimanager + +import ( + "context" + "fmt" + "time" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/nomad/client/dynamicplugins" + "github.com/hashicorp/nomad/helper" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/csi" +) + +const managerFingerprintInterval = 30 * time.Second + +// instanceManager is used to manage the fingerprinting and supervision of a +// single CSI Plugin. +type instanceManager struct { + info *dynamicplugins.PluginInfo + logger hclog.Logger + + updater UpdateNodeCSIInfoFunc + + shutdownCtx context.Context + shutdownCtxCancelFn context.CancelFunc + shutdownCh chan struct{} + + fingerprintNode bool + fingerprintController bool + + client csi.CSIPlugin +} + +func newInstanceManager(logger hclog.Logger, updater UpdateNodeCSIInfoFunc, p *dynamicplugins.PluginInfo) *instanceManager { + ctx, cancelFn := context.WithCancel(context.Background()) + + return &instanceManager{ + logger: logger.Named(p.Name), + info: p, + updater: updater, + + fingerprintNode: p.Type == dynamicplugins.PluginTypeCSINode, + fingerprintController: p.Type == dynamicplugins.PluginTypeCSIController, + + shutdownCtx: ctx, + shutdownCtxCancelFn: cancelFn, + shutdownCh: make(chan struct{}), + } +} + +func (i *instanceManager) run() { + c, err := csi.NewClient(i.info.ConnectionInfo.SocketPath) + if err != nil { + i.logger.Error("failed to setup instance manager client", "error", err) + close(i.shutdownCh) + return + } + i.client = c + + go i.runLoop() +} + +func (i *instanceManager) requestCtxWithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(i.shutdownCtx, timeout) +} + +func (i *instanceManager) runLoop() { + // basicInfo holds a cache of data that should not change within a CSI plugin. + // This allows us to minimize the number of requests we make to plugins on each + // run of the fingerprinter, and reduces the chances of performing overly + // expensive actions repeatedly, and improves stability of data through + // transient failures. 
+ var basicInfo *structs.CSIInfo + + timer := time.NewTimer(0) + for { + select { + case <-i.shutdownCtx.Done(): + if i.client != nil { + i.client.Close() + i.client = nil + } + close(i.shutdownCh) + return + case <-timer.C: + ctx, cancelFn := i.requestCtxWithTimeout(managerFingerprintInterval) + + if basicInfo == nil { + info, err := i.buildBasicFingerprint(ctx) + if err != nil { + // If we receive a fingerprinting error, update the stats with as much + // info as possible and wait for the next fingerprint interval. + info.HealthDescription = fmt.Sprintf("failed initial fingerprint with err: %v", err) + cancelFn() + i.updater(i.info.Name, basicInfo) + timer.Reset(managerFingerprintInterval) + continue + } + + // If fingerprinting succeeded, we don't need to repopulate the basic + // info and we can stop here. + basicInfo = info + } + + info := basicInfo.Copy() + var fp *structs.CSIInfo + var err error + + if i.fingerprintNode { + fp, err = i.buildNodeFingerprint(ctx, info) + } else if i.fingerprintController { + fp, err = i.buildControllerFingerprint(ctx, info) + } + + if err != nil { + info.Healthy = false + info.HealthDescription = fmt.Sprintf("failed fingerprinting with error: %v", err) + } else { + info = fp + } + + cancelFn() + i.updater(i.info.Name, info) + timer.Reset(managerFingerprintInterval) + } + } +} + +func (i *instanceManager) buildControllerFingerprint(ctx context.Context, base *structs.CSIInfo) (*structs.CSIInfo, error) { + fp := base.Copy() + + healthy, err := i.client.PluginProbe(ctx) + if err != nil { + return nil, err + } + fp.SetHealthy(healthy) + + return fp, nil +} + +func (i *instanceManager) buildNodeFingerprint(ctx context.Context, base *structs.CSIInfo) (*structs.CSIInfo, error) { + fp := base.Copy() + + healthy, err := i.client.PluginProbe(ctx) + if err != nil { + return nil, err + } + fp.SetHealthy(healthy) + + return fp, nil +} + +func structCSITopologyFromCSITopology(a *csi.Topology) *structs.CSITopology { + if a == nil { + return nil + } + + return &structs.CSITopology{ + Segments: helper.CopyMapStringString(a.Segments), + } +} + +func (i *instanceManager) buildBasicFingerprint(ctx context.Context) (*structs.CSIInfo, error) { + info := &structs.CSIInfo{ + PluginID: i.info.Name, + Healthy: false, + HealthDescription: "initial fingerprint not completed", + } + + if i.fingerprintNode { + info.NodeInfo = &structs.CSINodeInfo{} + } + if i.fingerprintController { + info.ControllerInfo = &structs.CSIControllerInfo{} + } + + capabilities, err := i.client.PluginGetCapabilities(ctx) + if err != nil { + return info, err + } + + info.RequiresControllerPlugin = capabilities.HasControllerService() + info.RequiresTopologies = capabilities.HasToplogies() + + if i.fingerprintNode { + nodeInfo, err := i.client.NodeGetInfo(ctx) + if err != nil { + return info, err + } + + info.NodeInfo.ID = nodeInfo.NodeID + info.NodeInfo.MaxVolumes = nodeInfo.MaxVolumes + info.NodeInfo.AccessibleTopology = structCSITopologyFromCSITopology(nodeInfo.AccessibleTopology) + } + + return info, nil +} + +func (i *instanceManager) shutdown() { + i.shutdownCtxCancelFn() + <-i.shutdownCh +} diff --git a/client/pluginmanager/csimanager/instance_test.go b/client/pluginmanager/csimanager/instance_test.go new file mode 100644 index 000000000..ca30a321e --- /dev/null +++ b/client/pluginmanager/csimanager/instance_test.go @@ -0,0 +1,159 @@ +package csimanager + +import ( + "context" + "errors" + "testing" + + "github.com/hashicorp/nomad/client/dynamicplugins" + "github.com/hashicorp/nomad/helper/testlog" + 
"github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/csi" + "github.com/hashicorp/nomad/plugins/csi/fake" + "github.com/stretchr/testify/require" +) + +func setupTestNodeInstanceManager(t *testing.T) (*fake.Client, *instanceManager) { + tp := &fake.Client{} + + logger := testlog.HCLogger(t) + pinfo := &dynamicplugins.PluginInfo{ + Name: "test-plugin", + } + + return tp, &instanceManager{ + logger: logger, + info: pinfo, + client: tp, + fingerprintNode: true, + } +} + +func TestBuildBasicFingerprint_Node(t *testing.T) { + tt := []struct { + Name string + + Capabilities *csi.PluginCapabilitySet + CapabilitiesErr error + CapabilitiesCallCount int64 + + NodeInfo *csi.NodeGetInfoResponse + NodeInfoErr error + NodeInfoCallCount int64 + + ExpectedCSIInfo *structs.CSIInfo + ExpectedErr error + }{ + { + Name: "Minimal successful response", + + Capabilities: &csi.PluginCapabilitySet{}, + CapabilitiesCallCount: 1, + + NodeInfo: &csi.NodeGetInfoResponse{ + NodeID: "foobar", + MaxVolumes: 5, + AccessibleTopology: nil, + }, + NodeInfoCallCount: 1, + + ExpectedCSIInfo: &structs.CSIInfo{ + PluginID: "test-plugin", + Healthy: false, + HealthDescription: "initial fingerprint not completed", + NodeInfo: &structs.CSINodeInfo{ + ID: "foobar", + MaxVolumes: 5, + }, + }, + }, + { + Name: "Successful response with capabilities and topologies", + + Capabilities: csi.NewTestPluginCapabilitySet(true, false), + CapabilitiesCallCount: 1, + + NodeInfo: &csi.NodeGetInfoResponse{ + NodeID: "foobar", + MaxVolumes: 5, + AccessibleTopology: &csi.Topology{ + Segments: map[string]string{ + "com.hashicorp.nomad/node-id": "foobar", + }, + }, + }, + NodeInfoCallCount: 1, + + ExpectedCSIInfo: &structs.CSIInfo{ + PluginID: "test-plugin", + Healthy: false, + HealthDescription: "initial fingerprint not completed", + + RequiresTopologies: true, + + NodeInfo: &structs.CSINodeInfo{ + ID: "foobar", + MaxVolumes: 5, + AccessibleTopology: &structs.CSITopology{ + Segments: map[string]string{ + "com.hashicorp.nomad/node-id": "foobar", + }, + }, + }, + }, + }, + { + Name: "PluginGetCapabilities Failed", + + CapabilitiesErr: errors.New("request failed"), + CapabilitiesCallCount: 1, + + NodeInfoCallCount: 0, + + ExpectedCSIInfo: &structs.CSIInfo{ + PluginID: "test-plugin", + Healthy: false, + HealthDescription: "initial fingerprint not completed", + NodeInfo: &structs.CSINodeInfo{}, + }, + ExpectedErr: errors.New("request failed"), + }, + { + Name: "NodeGetInfo Failed", + + Capabilities: &csi.PluginCapabilitySet{}, + CapabilitiesCallCount: 1, + + NodeInfoErr: errors.New("request failed"), + NodeInfoCallCount: 1, + + ExpectedCSIInfo: &structs.CSIInfo{ + PluginID: "test-plugin", + Healthy: false, + HealthDescription: "initial fingerprint not completed", + NodeInfo: &structs.CSINodeInfo{}, + }, + ExpectedErr: errors.New("request failed"), + }, + } + + for _, test := range tt { + t.Run(test.Name, func(t *testing.T) { + client, im := setupTestNodeInstanceManager(t) + + client.NextPluginGetCapabilitiesResponse = test.Capabilities + client.NextPluginGetCapabilitiesErr = test.CapabilitiesErr + + client.NextNodeGetInfoResponse = test.NodeInfo + client.NextNodeGetInfoErr = test.NodeInfoErr + + info, err := im.buildBasicFingerprint(context.TODO()) + + require.Equal(t, test.ExpectedCSIInfo, info) + require.Equal(t, test.ExpectedErr, err) + + require.Equal(t, test.CapabilitiesCallCount, client.PluginGetCapabilitiesCallCount) + require.Equal(t, test.NodeInfoCallCount, client.NodeGetInfoCallCount) + }) + } +} diff --git 
a/client/pluginmanager/csimanager/manager.go b/client/pluginmanager/csimanager/manager.go new file mode 100644 index 000000000..ebcbcc89f --- /dev/null +++ b/client/pluginmanager/csimanager/manager.go @@ -0,0 +1,153 @@ +package csimanager + +import ( + "context" + "sync" + "time" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/nomad/client/dynamicplugins" + "github.com/hashicorp/nomad/client/pluginmanager" + "github.com/hashicorp/nomad/nomad/structs" +) + +// defaultPluginResyncPeriod is the time interval used to do a full resync +// against the dynamicplugins, to account for missed updates. +const defaultPluginResyncPeriod = 30 * time.Second + +// UpdateNodeCSIInfoFunc is the callback used to update the node from +// fingerprinting +type UpdateNodeCSIInfoFunc func(string, *structs.CSIInfo) + +type Config struct { + Logger hclog.Logger + DynamicRegistry dynamicplugins.Registry + UpdateNodeCSIInfoFunc UpdateNodeCSIInfoFunc + PluginResyncPeriod time.Duration +} + +// New returns a new PluginManager that will handle managing CSI plugins from +// the dynamicRegistry from the provided Config. +func New(config *Config) pluginmanager.PluginManager { + // Use a dedicated internal context for managing plugin shutdown. + ctx, cancelFn := context.WithCancel(context.Background()) + + if config.PluginResyncPeriod == 0 { + config.PluginResyncPeriod = defaultPluginResyncPeriod + } + + return &csiManager{ + logger: config.Logger, + registry: config.DynamicRegistry, + instances: make(map[string]map[string]*instanceManager), + + updateNodeCSIInfoFunc: config.UpdateNodeCSIInfoFunc, + pluginResyncPeriod: config.PluginResyncPeriod, + + shutdownCtx: ctx, + shutdownCtxCancelFn: cancelFn, + shutdownCh: make(chan struct{}), + } +} + +type csiManager struct { + // instances should only be accessed from the run() goroutine and the shutdown + // fn. It is a map of PluginType : [PluginName : instanceManager] + instances map[string]map[string]*instanceManager + + registry dynamicplugins.Registry + logger hclog.Logger + pluginResyncPeriod time.Duration + + updateNodeCSIInfoFunc UpdateNodeCSIInfoFunc + + shutdownCtx context.Context + shutdownCtxCancelFn context.CancelFunc + shutdownCh chan struct{} +} + +// Run starts a plugin manager and should return early +func (c *csiManager) Run() { + go c.runLoop() +} + +func (c *csiManager) runLoop() { + // TODO: Subscribe to the events channel from the registry to receive dynamic + // updates without a full resync + timer := time.NewTimer(0) + for { + select { + case <-c.shutdownCtx.Done(): + close(c.shutdownCh) + return + case <-timer.C: + c.resyncPluginsFromRegistry("csi-controller") + c.resyncPluginsFromRegistry("csi-node") + timer.Reset(c.pluginResyncPeriod) + } + } +} + +// resyncPluginsFromRegistry does a full sync of the running instance managers +// against those in the registry. Eventually we should primarily be using +// update events from the registry, but this is an ok fallback for now. +func (c *csiManager) resyncPluginsFromRegistry(ptype string) { + plugins := c.registry.ListPlugins(ptype) + seen := make(map[string]struct{}, len(plugins)) + + pluginMap, ok := c.instances[ptype] + if !ok { + pluginMap = make(map[string]*instanceManager) + c.instances[ptype] = pluginMap + } + + // For every plugin in the registry, ensure that we have an existing plugin + // running. Also build the map of valid plugin names. 
+ for _, plugin := range plugins { + seen[plugin.Name] = struct{}{} + if _, ok := pluginMap[plugin.Name]; !ok { + c.logger.Debug("detected new CSI plugin", "name", plugin.Name, "type", ptype) + mgr := newInstanceManager(c.logger, c.updateNodeCSIInfoFunc, plugin) + pluginMap[plugin.Name] = mgr + mgr.run() + } + } + + // For every instance manager, if we did not find it during the plugin + // iterator, shut it down and remove it from the table. + for name, mgr := range pluginMap { + if _, ok := seen[name]; !ok { + c.logger.Info("shutting down CSI plugin", "name", name, "type", ptype) + mgr.shutdown() + delete(pluginMap, name) + } + } +} + +// Shutdown should gracefully shutdown all plugins managed by the manager. +// It must block until shutdown is complete +func (c *csiManager) Shutdown() { + // Shut down the run loop + c.shutdownCtxCancelFn() + + // Wait for plugin manager shutdown to complete + <-c.shutdownCh + + // Shutdown all the instance managers in parallel + var wg sync.WaitGroup + for _, pluginMap := range c.instances { + for _, mgr := range pluginMap { + wg.Add(1) + go func(mgr *instanceManager) { + mgr.shutdown() + wg.Done() + }(mgr) + } + } + wg.Wait() +} + +// PluginType is the type of plugin which the manager manages +func (c *csiManager) PluginType() string { + return "csi" +} diff --git a/client/pluginmanager/csimanager/manager_test.go b/client/pluginmanager/csimanager/manager_test.go new file mode 100644 index 000000000..408168ca2 --- /dev/null +++ b/client/pluginmanager/csimanager/manager_test.go @@ -0,0 +1,111 @@ +package csimanager + +import ( + "testing" + "time" + + "github.com/hashicorp/nomad/client/dynamicplugins" + "github.com/hashicorp/nomad/client/pluginmanager" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/stretchr/testify/require" +) + +var _ pluginmanager.PluginManager = (*csiManager)(nil) + +var fakePlugin = &dynamicplugins.PluginInfo{ + Name: "my-plugin", + Type: "csi-controller", + ConnectionInfo: &dynamicplugins.PluginConnectionInfo{}, +} + +func setupRegistry() dynamicplugins.Registry { + return dynamicplugins.NewRegistry( + map[string]dynamicplugins.PluginDispenser{ + "csi-controller": func(*dynamicplugins.PluginInfo) (interface{}, error) { + return nil, nil + }, + }) +} + +func TestCSIManager_Setup_Shutdown(t *testing.T) { + r := setupRegistry() + defer r.Shutdown() + + cfg := &Config{ + Logger: testlog.HCLogger(t), + DynamicRegistry: r, + UpdateNodeCSIInfoFunc: func(string, *structs.CSIInfo) {}, + } + pm := New(cfg).(*csiManager) + pm.Run() + pm.Shutdown() +} + +func TestCSIManager_RegisterPlugin(t *testing.T) { + registry := setupRegistry() + defer registry.Shutdown() + + require.NotNil(t, registry) + + cfg := &Config{ + Logger: testlog.HCLogger(t), + DynamicRegistry: registry, + UpdateNodeCSIInfoFunc: func(string, *structs.CSIInfo) {}, + } + pm := New(cfg).(*csiManager) + defer pm.Shutdown() + + require.NotNil(t, pm.registry) + + err := registry.RegisterPlugin(fakePlugin) + require.Nil(t, err) + + pm.Run() + + require.Eventually(t, func() bool { + pmap, ok := pm.instances[fakePlugin.Type] + if !ok { + return false + } + + _, ok = pmap[fakePlugin.Name] + return ok + }, 5*time.Second, 10*time.Millisecond) +} + +func TestCSIManager_DeregisterPlugin(t *testing.T) { + registry := setupRegistry() + defer registry.Shutdown() + + require.NotNil(t, registry) + + cfg := &Config{ + Logger: testlog.HCLogger(t), + DynamicRegistry: registry, + UpdateNodeCSIInfoFunc: func(string, *structs.CSIInfo) {}, + 
PluginResyncPeriod: 500 * time.Millisecond, + } + pm := New(cfg).(*csiManager) + defer pm.Shutdown() + + require.NotNil(t, pm.registry) + + err := registry.RegisterPlugin(fakePlugin) + require.Nil(t, err) + + pm.Run() + + require.Eventually(t, func() bool { + _, ok := pm.instances[fakePlugin.Type][fakePlugin.Name] + return ok + }, 5*time.Second, 10*time.Millisecond) + + err = registry.DeregisterPlugin(fakePlugin.Type, fakePlugin.Name) + require.Nil(t, err) + + require.Eventually(t, func() bool { + _, ok := pm.instances[fakePlugin.Type][fakePlugin.Name] + return !ok + }, 5*time.Second, 10*time.Millisecond) +} diff --git a/command/agent/job_endpoint.go b/command/agent/job_endpoint.go index b394ed357..0483c018f 100644 --- a/command/agent/job_endpoint.go +++ b/command/agent/job_endpoint.go @@ -812,6 +812,7 @@ func ApiTaskToStructsTask(apiTask *api.Task, structsTask *structs.Task) { structsTask.Kind = structs.TaskKind(apiTask.Kind) structsTask.Constraints = ApiConstraintsToStructs(apiTask.Constraints) structsTask.Affinities = ApiAffinitiesToStructs(apiTask.Affinities) + structsTask.CSIPluginConfig = ApiCSIPluginConfigToStructsCSIPluginConfig(apiTask.CSIPluginConfig) if l := len(apiTask.VolumeMounts); l != 0 { structsTask.VolumeMounts = make([]*structs.VolumeMount, l) @@ -933,6 +934,18 @@ func ApiTaskToStructsTask(apiTask *api.Task, structsTask *structs.Task) { } } +func ApiCSIPluginConfigToStructsCSIPluginConfig(apiConfig *api.TaskCSIPluginConfig) *structs.TaskCSIPluginConfig { + if apiConfig == nil { + return nil + } + + sc := &structs.TaskCSIPluginConfig{} + sc.ID = apiConfig.ID + sc.Type = structs.CSIPluginType(apiConfig.Type) + sc.MountDir = apiConfig.MountDir + return sc +} + func ApiResourcesToStructs(in *api.Resources) *structs.Resources { if in == nil { return nil diff --git a/jobspec/parse_task.go b/jobspec/parse_task.go index dbd20abdd..a59c88331 100644 --- a/jobspec/parse_task.go +++ b/jobspec/parse_task.go @@ -74,6 +74,7 @@ func parseTask(item *ast.ObjectItem) (*api.Task, error) { "kill_signal", "kind", "volume_mount", + "csi_plugin", } if err := helper.CheckHCLKeys(listVal, valid); err != nil { return nil, err @@ -97,6 +98,7 @@ func parseTask(item *ast.ObjectItem) (*api.Task, error) { delete(m, "template") delete(m, "vault") delete(m, "volume_mount") + delete(m, "csi_plugin") // Build the task var t api.Task @@ -135,6 +137,25 @@ func parseTask(item *ast.ObjectItem) (*api.Task, error) { t.Services = services } + if o := listVal.Filter("csi_plugin"); len(o.Items) > 0 { + if len(o.Items) != 1 { + return nil, fmt.Errorf("csi_plugin -> Expected single stanza, got %d", len(o.Items)) + } + i := o.Elem().Items[0] + + var m map[string]interface{} + if err := hcl.DecodeObject(&m, i.Val); err != nil { + return nil, err + } + + var cfg api.TaskCSIPluginConfig + if err := mapstructure.WeakDecode(m, &cfg); err != nil { + return nil, err + } + + t.CSIPluginConfig = &cfg + } + // If we have config, then parse that if o := listVal.Filter("config"); len(o.Items) > 0 { for _, o := range o.Elem().Items { diff --git a/jobspec/parse_test.go b/jobspec/parse_test.go index 13639a1b2..ed66f05ae 100644 --- a/jobspec/parse_test.go +++ b/jobspec/parse_test.go @@ -569,6 +569,30 @@ func TestParse(t *testing.T) { }, false, }, + { + "csi-plugin.hcl", + &api.Job{ + ID: helper.StringToPtr("binstore-storagelocker"), + Name: helper.StringToPtr("binstore-storagelocker"), + TaskGroups: []*api.TaskGroup{ + { + Name: helper.StringToPtr("binsl"), + Tasks: []*api.Task{ + { + Name: "binstore", + Driver: "docker", + 
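+							// The csi_plugin block in csi-plugin.hcl is expected
+							// to decode into this TaskCSIPluginConfig.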
CSIPluginConfig: &api.TaskCSIPluginConfig{ + ID: "org.hashicorp.csi", + Type: api.CSIPluginTypeMonolith, + MountDir: "/csi/test", + }, + }, + }, + }, + }, + }, + false, + }, { "service-check-initial-status.hcl", &api.Job{ diff --git a/jobspec/test-fixtures/csi-plugin.hcl b/jobspec/test-fixtures/csi-plugin.hcl new file mode 100644 index 000000000..b879da184 --- /dev/null +++ b/jobspec/test-fixtures/csi-plugin.hcl @@ -0,0 +1,13 @@ +job "binstore-storagelocker" { + group "binsl" { + task "binstore" { + driver = "docker" + + csi_plugin { + id = "org.hashicorp.csi" + type = "monolith" + mount_dir = "/csi/test" + } + } + } +} diff --git a/nomad/structs/csi.go b/nomad/structs/csi.go new file mode 100644 index 000000000..b00b2fa72 --- /dev/null +++ b/nomad/structs/csi.go @@ -0,0 +1,68 @@ +package structs + +// CSISocketName is the filename that Nomad expects plugins to create inside the +// PluginMountDir. +const CSISocketName = "csi.sock" + +// CSIIntermediaryDirname is the name of the directory inside the PluginMountDir +// where Nomad will expect plugins to create intermediary mounts for volumes. +const CSIIntermediaryDirname = "volumes" + +// CSIPluginType is an enum string that encapsulates the valid options for a +// CSIPlugin stanza's Type. These modes will allow the plugin to be used in +// different ways by the client. +type CSIPluginType string + +const ( + // CSIPluginTypeNode indicates that Nomad should only use the plugin for + // performing Node RPCs against the provided plugin. + CSIPluginTypeNode CSIPluginType = "node" + + // CSIPluginTypeController indicates that Nomad should only use the plugin for + // performing Controller RPCs against the provided plugin. + CSIPluginTypeController CSIPluginType = "controller" + + // CSIPluginTypeMonolith indicates that Nomad can use the provided plugin for + // both controller and node rpcs. + CSIPluginTypeMonolith CSIPluginType = "monolith" +) + +// CSIPluginTypeIsValid validates the given CSIPluginType string and returns +// true only when a correct plugin type is specified. +func CSIPluginTypeIsValid(pt CSIPluginType) bool { + switch pt { + case CSIPluginTypeNode, CSIPluginTypeController, CSIPluginTypeMonolith: + return true + default: + return false + } +} + +// TaskCSIPluginConfig contains the data that is required to setup a task as a +// CSI plugin. This will be used by the csi_plugin_supervisor_hook to configure +// mounts for the plugin and initiate the connection to the plugin catalog. +type TaskCSIPluginConfig struct { + // ID is the identifier of the plugin. + // Ideally this should be the FQDN of the plugin. + ID string + + // Type instructs Nomad on how to handle processing a plugin + Type CSIPluginType + + // MountDir is the destination that nomad should mount in its CSI + // directory for the plugin. It will then expect a file called CSISocketName + // to be created by the plugin, and will provide references into + // "MountDir/CSIIntermediaryDirname/{VolumeName}/{AllocID} for mounts. + MountDir string +} + +func (t *TaskCSIPluginConfig) Copy() *TaskCSIPluginConfig { + if t == nil { + return nil + } + + nt := new(TaskCSIPluginConfig) + *nt = *t + + return nt +} diff --git a/nomad/structs/node.go b/nomad/structs/node.go index 76758fb8e..7143c42e2 100644 --- a/nomad/structs/node.go +++ b/nomad/structs/node.go @@ -1,11 +1,177 @@ package structs import ( + "reflect" "time" "github.com/hashicorp/nomad/helper" ) +// CSITopology is a map of topological domains to topological segments. 
+// A topological domain is a sub-division of a cluster, like "region",
+// "zone", "rack", etc.
+//
+// According to CSI, there are a few requirements for the keys within this map:
+// - Valid keys have two segments: an OPTIONAL prefix and name, separated
+//   by a slash (/), for example: "com.company.example/zone".
+// - The key name segment is REQUIRED. The prefix is OPTIONAL.
+// - The key name MUST be 63 characters or less, begin and end with an
+//   alphanumeric character ([a-z0-9A-Z]), and contain only dashes (-),
+//   underscores (_), dots (.), or alphanumerics in between, for example
+//   "zone".
+// - The key prefix MUST be 63 characters or less, begin and end with a
+//   lower-case alphanumeric character ([a-z0-9]), contain only
+//   dashes (-), dots (.), or lower-case alphanumerics in between, and
+//   follow domain name notation format
+//   (https://tools.ietf.org/html/rfc1035#section-2.3.1).
+// - The key prefix SHOULD include the plugin's host company name and/or
+//   the plugin name, to minimize the possibility of collisions with keys
+//   from other plugins.
+// - If a key prefix is specified, it MUST be identical across all
+//   topology keys returned by the SP (across all RPCs).
+// - Keys MUST be case-insensitive. Meaning the keys "Zone" and "zone"
+//   MUST not both exist.
+// - Each value (topological segment) MUST contain 1 or more strings.
+// - Each string MUST be 63 characters or less and begin and end with an
+//   alphanumeric character with '-', '_', '.', or alphanumerics in
+//   between.
+//
+// However, Nomad applies lighter restrictions to these, as they are only
+// ever referenced by their own plugin within the scheduler, so collisions
+// and related concerns are less of an issue. We may implement these
+// restrictions in the future.
+type CSITopology struct {
+	Segments map[string]string
+}
+
+func (t *CSITopology) Copy() *CSITopology {
+	if t == nil {
+		return nil
+	}
+
+	return &CSITopology{
+		Segments: helper.CopyMapStringString(t.Segments),
+	}
+}
+
+// CSINodeInfo is the fingerprinted data from a CSI Plugin that is specific to
+// the Node API.
+type CSINodeInfo struct {
+	// ID is the identity of a given nomad client as observed by the storage
+	// provider.
+	ID string
+
+	// MaxVolumes is the maximum number of volumes that can be attached to the
+	// current host via this provider.
+	// If 0 then unlimited volumes may be attached.
+	MaxVolumes int64
+
+	// AccessibleTopology specifies where (regions, zones, racks, etc.) the node is
+	// accessible from within the storage provider.
+	//
+	// A plugin that returns this field MUST also set the `RequiresTopologies`
+	// property.
+	//
+	// This field is OPTIONAL. If it is not specified, then we assume that the
+	// node is not subject to any topological constraint, and MAY
+	// schedule workloads that reference any volume V, such that there are
+	// no topological constraints declared for V.
+	//
+	// Example 1:
+	//   accessible_topology =
+	//     {"region": "R1", "zone": "Z2"}
+	// Indicates the node exists within the "region" "R1" and the "zone"
+	// "Z2" within the storage provider.
+	AccessibleTopology *CSITopology
+}
+
+func (n *CSINodeInfo) Copy() *CSINodeInfo {
+	if n == nil {
+		return nil
+	}
+
+	nc := new(CSINodeInfo)
+	*nc = *n
+	nc.AccessibleTopology = n.AccessibleTopology.Copy()
+
+	return nc
+}
+
+// CSIControllerInfo is the fingerprinted data from a CSI Plugin that is specific to
+// the Controller API.
+type CSIControllerInfo struct {
+	// Currently empty
+}
+
+func (c *CSIControllerInfo) Copy() *CSIControllerInfo {
+	if c == nil {
+		return nil
+	}
+
+	nc := new(CSIControllerInfo)
+	*nc = *c
+
+	return nc
+}
+
+// CSIInfo is the current state of a single CSI Plugin. This is updated regularly
+// as plugin health changes on the node.
+type CSIInfo struct {
+	PluginID          string
+	Healthy           bool
+	HealthDescription string
+	UpdateTime        time.Time
+
+	// RequiresControllerPlugin is set when the CSI Plugin returns the
+	// CONTROLLER_SERVICE capability. When this is true, the volumes should not be
+	// scheduled on this client until a matching controller plugin is available.
+	RequiresControllerPlugin bool
+
+	// RequiresTopologies is set when the CSI Plugin returns the
+	// VOLUME_ACCESSIBLE_CONSTRAINTS capability. When this is true, we must
+	// respect the Volume and Node Topology information.
+	RequiresTopologies bool
+
+	// CSI Specific metadata
+	ControllerInfo *CSIControllerInfo `json:",omitempty"`
+	NodeInfo       *CSINodeInfo       `json:",omitempty"`
+}
+
+func (c *CSIInfo) Copy() *CSIInfo {
+	if c == nil {
+		return nil
+	}
+
+	nc := new(CSIInfo)
+	*nc = *c
+	nc.ControllerInfo = c.ControllerInfo.Copy()
+	nc.NodeInfo = c.NodeInfo.Copy()
+
+	return nc
+}
+
+func (c *CSIInfo) SetHealthy(hs bool) {
+	c.Healthy = hs
+	if hs {
+		c.HealthDescription = "healthy"
+	} else {
+		c.HealthDescription = "unhealthy"
+	}
+}
+
+func (c *CSIInfo) Equal(o *CSIInfo) bool {
+	if c == nil || o == nil {
+		return c == o
+	}
+
+	nc := *c
+	nc.UpdateTime = time.Time{}
+	no := *o
+	no.UpdateTime = time.Time{}
+
+	return reflect.DeepEqual(nc, no)
+}
+
 // DriverInfo is the current state of a single driver. This is updated
 // regularly as driver health changes on the node.
 type DriverInfo struct {
diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go
index 28942c1b7..ee475bccb 100644
--- a/nomad/structs/structs.go
+++ b/nomad/structs/structs.go
@@ -1659,6 +1659,11 @@ type Node struct {
 	// Drivers is a map of driver names to current driver information
 	Drivers map[string]*DriverInfo
+	// CSIControllerPlugins is a map of plugin names to current CSI Plugin info
+	CSIControllerPlugins map[string]*CSIInfo
+	// CSINodePlugins is a map of plugin names to current CSI Plugin info
+	CSINodePlugins map[string]*CSIInfo
+
 	// HostVolumes is a map of host volume names to their configuration
 	HostVolumes map[string]*ClientHostVolumeConfig
@@ -1705,6 +1710,8 @@ func (n *Node) Copy() *Node {
 	nn.Meta = helper.CopyMapStringString(nn.Meta)
 	nn.Events = copyNodeEvents(n.Events)
 	nn.DrainStrategy = nn.DrainStrategy.Copy()
+	nn.CSIControllerPlugins = copyNodeCSI(nn.CSIControllerPlugins)
+	nn.CSINodePlugins = copyNodeCSI(nn.CSINodePlugins)
 	nn.Drivers = copyNodeDrivers(n.Drivers)
 	nn.HostVolumes = copyNodeHostVolumes(n.HostVolumes)
 	return nn
@@ -1724,6 +1731,21 @@ func copyNodeEvents(events []*NodeEvent) []*NodeEvent {
 	return c
 }
+// copyNodeCSI is a helper to copy a map of CSIInfo
+func copyNodeCSI(plugins map[string]*CSIInfo) map[string]*CSIInfo {
+	l := len(plugins)
+	if l == 0 {
+		return nil
+	}
+
+	c := make(map[string]*CSIInfo, l)
+	for plugin, info := range plugins {
+		c[plugin] = info.Copy()
+	}
+
+	return c
+}
+
 // copyNodeDrivers is a helper to copy a map of DriverInfo
 func copyNodeDrivers(drivers map[string]*DriverInfo) map[string]*DriverInfo {
 	l := len(drivers)
@@ -5556,6 +5578,9 @@ type Task struct {
 	// Used internally to manage tasks according to their TaskKind.
Initial use case // is for Consul Connect Kind TaskKind + + // CSIPluginConfig is used to configure the plugin supervisor for the task. + CSIPluginConfig *TaskCSIPluginConfig } // UsesConnect is for conveniently detecting if the Task is able to make use @@ -5593,6 +5618,7 @@ func (t *Task) Copy() *Task { nt.Constraints = CopySliceConstraints(nt.Constraints) nt.Affinities = CopySliceAffinities(nt.Affinities) nt.VolumeMounts = CopySliceVolumeMount(nt.VolumeMounts) + nt.CSIPluginConfig = nt.CSIPluginConfig.Copy() nt.Vault = nt.Vault.Copy() nt.Resources = nt.Resources.Copy() @@ -5811,6 +5837,19 @@ func (t *Task) Validate(ephemeralDisk *EphemeralDisk, jobType string, tgServices } } + // Validate CSI Plugin Config + if t.CSIPluginConfig != nil { + if t.CSIPluginConfig.ID == "" { + mErr.Errors = append(mErr.Errors, fmt.Errorf("CSIPluginConfig must have a non-empty PluginID")) + } + + if !CSIPluginTypeIsValid(t.CSIPluginConfig.Type) { + mErr.Errors = append(mErr.Errors, fmt.Errorf("CSIPluginConfig PluginType must be one of 'node', 'controller', or 'monolith', got: \"%s\"", t.CSIPluginConfig.Type)) + } + + // TODO: Investigate validation of the PluginMountDir. Not much we can do apart from check IsAbs until after we understand its execution environment though :( + } + return mErr.ErrorOrNil() } @@ -6336,6 +6375,12 @@ const ( // TaskRestoreFailed indicates Nomad was unable to reattach to a // restored task. TaskRestoreFailed = "Failed Restoring Task" + + // TaskPluginUnhealthy indicates that a plugin managed by Nomad became unhealthy + TaskPluginUnhealthy = "Plugin became unhealthy" + + // TaskPluginHealthy indicates that a plugin managed by Nomad became healthy + TaskPluginHealthy = "Plugin became healthy" ) // TaskEvent is an event that effects the state of a task and contains meta-data diff --git a/nomad/structs/structs_test.go b/nomad/structs/structs_test.go index f43ebf527..cba53774d 100644 --- a/nomad/structs/structs_test.go +++ b/nomad/structs/structs_test.go @@ -1781,6 +1781,55 @@ func TestTask_Validate_LogConfig(t *testing.T) { } } +func TestTask_Validate_CSIPluginConfig(t *testing.T) { + table := []struct { + name string + pc *TaskCSIPluginConfig + expectedErr string + }{ + { + name: "no errors when not specified", + pc: nil, + }, + { + name: "requires non-empty plugin id", + pc: &TaskCSIPluginConfig{}, + expectedErr: "CSIPluginConfig must have a non-empty PluginID", + }, + { + name: "requires valid plugin type", + pc: &TaskCSIPluginConfig{ + ID: "com.hashicorp.csi", + Type: "nonsense", + }, + expectedErr: "CSIPluginConfig PluginType must be one of 'node', 'controller', or 'monolith', got: \"nonsense\"", + }, + } + + for _, tt := range table { + t.Run(tt.name, func(t *testing.T) { + task := &Task{ + CSIPluginConfig: tt.pc, + } + ephemeralDisk := &EphemeralDisk{ + SizeMB: 1, + } + + err := task.Validate(ephemeralDisk, JobTypeService, nil) + mErr := err.(*multierror.Error) + if tt.expectedErr != "" { + if !strings.Contains(mErr.Errors[4].Error(), tt.expectedErr) { + t.Fatalf("err: %s", err) + } + } else { + if len(mErr.Errors) != 4 { + t.Fatalf("unexpected err: %s", mErr.Errors[4]) + } + } + }) + } +} + func TestTask_Validate_Template(t *testing.T) { bad := &Template{} diff --git a/plugins/csi/client.go b/plugins/csi/client.go new file mode 100644 index 000000000..5647eeba7 --- /dev/null +++ b/plugins/csi/client.go @@ -0,0 +1,210 @@ +package csi + +import ( + "context" + "fmt" + "net" + "time" + + csipbv1 "github.com/container-storage-interface/spec/lib/go/csi" + 
"github.com/hashicorp/nomad/plugins/base" + "github.com/hashicorp/nomad/plugins/shared/hclspec" + "google.golang.org/grpc" +) + +type NodeGetInfoResponse struct { + NodeID string + MaxVolumes int64 + AccessibleTopology *Topology +} + +// Topology is a map of topological domains to topological segments. +// A topological domain is a sub-division of a cluster, like "region", +// "zone", "rack", etc. +// +// According to CSI, there are a few requirements for the keys within this map: +// - Valid keys have two segments: an OPTIONAL prefix and name, separated +// by a slash (/), for example: "com.company.example/zone". +// - The key name segment is REQUIRED. The prefix is OPTIONAL. +// - The key name MUST be 63 characters or less, begin and end with an +// alphanumeric character ([a-z0-9A-Z]), and contain only dashes (-), +// underscores (_), dots (.), or alphanumerics in between, for example +// "zone". +// - The key prefix MUST be 63 characters or less, begin and end with a +// lower-case alphanumeric character ([a-z0-9]), contain only +// dashes (-), dots (.), or lower-case alphanumerics in between, and +// follow domain name notation format +// (https://tools.ietf.org/html/rfc1035#section-2.3.1). +// - The key prefix SHOULD include the plugin's host company name and/or +// the plugin name, to minimize the possibility of collisions with keys +// from other plugins. +// - If a key prefix is specified, it MUST be identical across all +// topology keys returned by the SP (across all RPCs). +// - Keys MUST be case-insensitive. Meaning the keys "Zone" and "zone" +// MUST not both exist. +// - Each value (topological segment) MUST contain 1 or more strings. +// - Each string MUST be 63 characters or less and begin and end with an +// alphanumeric character with '-', '_', '.', or alphanumerics in +// between. +type Topology struct { + Segments map[string]string +} + +type client struct { + conn *grpc.ClientConn + identityClient csipbv1.IdentityClient + controllerClient csipbv1.ControllerClient + nodeClient csipbv1.NodeClient +} + +func (c *client) Close() error { + if c.conn != nil { + return c.conn.Close() + } + return nil +} + +func NewClient(addr string) (CSIPlugin, error) { + if addr == "" { + return nil, fmt.Errorf("address is empty") + } + + conn, err := newGrpcConn(addr) + if err != nil { + return nil, err + } + + return &client{ + conn: conn, + identityClient: csipbv1.NewIdentityClient(conn), + controllerClient: csipbv1.NewControllerClient(conn), + nodeClient: csipbv1.NewNodeClient(conn), + }, nil +} + +func newGrpcConn(addr string) (*grpc.ClientConn, error) { + conn, err := grpc.Dial( + addr, + grpc.WithInsecure(), + grpc.WithDialer(func(target string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout("unix", target, timeout) + }), + ) + + if err != nil { + return nil, fmt.Errorf("failed to open grpc connection to addr: %s, err: %v", addr, err) + } + + return conn, nil +} + +// PluginInfo describes the type and version of a plugin as required by the nomad +// base.BasePlugin interface. +func (c *client) PluginInfo() (*base.PluginInfoResponse, error) { + name, err := c.PluginGetInfo(context.TODO()) + if err != nil { + return nil, err + } + + return &base.PluginInfoResponse{ + Type: "csi", + PluginApiVersions: []string{"1.0.0"}, // TODO: fingerprint csi version + PluginVersion: "1.0.0", // TODO: get plugin version from somewhere?! 
+ Name: name, + }, nil +} + +// ConfigSchema returns the schema for parsing the plugins configuration as +// required by the base.BasePlugin interface. It will always return nil. +func (c *client) ConfigSchema() (*hclspec.Spec, error) { + return nil, nil +} + +// SetConfig is used to set the configuration by passing a MessagePack +// encoding of it. +func (c *client) SetConfig(_ *base.Config) error { + return fmt.Errorf("unsupported") +} + +func (c *client) PluginProbe(ctx context.Context) (bool, error) { + req, err := c.identityClient.Probe(ctx, &csipbv1.ProbeRequest{}) + if err != nil { + return false, err + } + + wrapper := req.GetReady() + + // wrapper.GetValue() protects against wrapper being `nil`, and returns false. + ready := wrapper.GetValue() + + if wrapper == nil { + // If the plugin returns a nil value for ready, then it should be + // interpreted as the plugin is ready for compatibility with plugins that + // do not do health checks. + ready = true + } + + return ready, nil +} + +func (c *client) PluginGetInfo(ctx context.Context) (string, error) { + if c == nil { + return "", fmt.Errorf("Client not initialized") + } + if c.identityClient == nil { + return "", fmt.Errorf("Client not initialized") + } + + req, err := c.identityClient.GetPluginInfo(ctx, &csipbv1.GetPluginInfoRequest{}) + if err != nil { + return "", err + } + + name := req.GetName() + if name == "" { + return "", fmt.Errorf("PluginGetInfo: plugin returned empty name field") + } + + return name, nil +} + +func (c *client) PluginGetCapabilities(ctx context.Context) (*PluginCapabilitySet, error) { + if c == nil { + return nil, fmt.Errorf("Client not initialized") + } + if c.identityClient == nil { + return nil, fmt.Errorf("Client not initialized") + } + + resp, err := c.identityClient.GetPluginCapabilities(ctx, &csipbv1.GetPluginCapabilitiesRequest{}) + if err != nil { + return nil, err + } + + return NewPluginCapabilitySet(resp), nil +} + +func (c *client) NodeGetInfo(ctx context.Context) (*NodeGetInfoResponse, error) { + if c == nil { + return nil, fmt.Errorf("Client not initialized") + } + if c.nodeClient == nil { + return nil, fmt.Errorf("Client not initialized") + } + + result := &NodeGetInfoResponse{} + + resp, err := c.nodeClient.NodeGetInfo(ctx, &csipbv1.NodeGetInfoRequest{}) + if err != nil { + return nil, err + } + + if resp.GetNodeId() == "" { + return nil, fmt.Errorf("plugin failed to return nodeid") + } + + result.NodeID = resp.GetNodeId() + result.MaxVolumes = resp.GetMaxVolumesPerNode() + + return result, nil +} diff --git a/plugins/csi/client_test.go b/plugins/csi/client_test.go new file mode 100644 index 000000000..882eacfad --- /dev/null +++ b/plugins/csi/client_test.go @@ -0,0 +1,191 @@ +package csi + +import ( + "context" + "fmt" + "testing" + + csipbv1 "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/golang/protobuf/ptypes/wrappers" + fake "github.com/hashicorp/nomad/plugins/csi/testing" + "github.com/stretchr/testify/require" +) + +func newTestClient() (*fake.IdentityClient, CSIPlugin) { + ic := &fake.IdentityClient{} + client := &client{ + identityClient: ic, + } + + return ic, client +} + +func TestClient_RPC_PluginProbe(t *testing.T) { + cases := []struct { + Name string + ResponseErr error + ProbeResponse *csipbv1.ProbeResponse + ExpectedResponse bool + ExpectedErr error + }{ + { + Name: "handles underlying grpc errors", + ResponseErr: fmt.Errorf("some grpc error"), + ExpectedErr: fmt.Errorf("some grpc error"), + }, + { + Name: "returns false for ready when the 
provider returns false", + ProbeResponse: &csipbv1.ProbeResponse{ + Ready: &wrappers.BoolValue{Value: false}, + }, + ExpectedResponse: false, + }, + { + Name: "returns true for ready when the provider returns true", + ProbeResponse: &csipbv1.ProbeResponse{ + Ready: &wrappers.BoolValue{Value: true}, + }, + ExpectedResponse: true, + }, + { + /* When a SP does not return a ready value, a CO MAY treat this as ready. + We do so because example plugins rely on this behaviour. We may + re-evaluate this decision in the future. */ + Name: "returns true for ready when the provider returns a nil wrapper", + ProbeResponse: &csipbv1.ProbeResponse{ + Ready: nil, + }, + ExpectedResponse: true, + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + ic, client := newTestClient() + defer client.Close() + + ic.NextErr = c.ResponseErr + ic.NextPluginProbe = c.ProbeResponse + + resp, err := client.PluginProbe(context.TODO()) + if c.ExpectedErr != nil { + require.Error(t, c.ExpectedErr, err) + } + + require.Equal(t, c.ExpectedResponse, resp) + }) + } + +} + +func TestClient_RPC_PluginInfo(t *testing.T) { + cases := []struct { + Name string + ResponseErr error + InfoResponse *csipbv1.GetPluginInfoResponse + ExpectedResponse string + ExpectedErr error + }{ + { + Name: "handles underlying grpc errors", + ResponseErr: fmt.Errorf("some grpc error"), + ExpectedErr: fmt.Errorf("some grpc error"), + }, + { + Name: "returns an error if we receive an empty `name`", + InfoResponse: &csipbv1.GetPluginInfoResponse{ + Name: "", + }, + ExpectedErr: fmt.Errorf("PluginGetInfo: plugin returned empty name field"), + }, + { + Name: "returns the name when successfully retrieved and not empty", + InfoResponse: &csipbv1.GetPluginInfoResponse{ + Name: "com.hashicorp.storage", + }, + ExpectedResponse: "com.hashicorp.storage", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + ic, client := newTestClient() + defer client.Close() + + ic.NextErr = c.ResponseErr + ic.NextPluginInfo = c.InfoResponse + + resp, err := client.PluginGetInfo(context.TODO()) + if c.ExpectedErr != nil { + require.Error(t, c.ExpectedErr, err) + } + + require.Equal(t, c.ExpectedResponse, resp) + }) + } + +} + +func TestClient_RPC_PluginGetCapabilities(t *testing.T) { + cases := []struct { + Name string + ResponseErr error + Response *csipbv1.GetPluginCapabilitiesResponse + ExpectedResponse *PluginCapabilitySet + ExpectedErr error + }{ + { + Name: "handles underlying grpc errors", + ResponseErr: fmt.Errorf("some grpc error"), + ExpectedErr: fmt.Errorf("some grpc error"), + }, + { + Name: "HasControllerService is true when it's part of the response", + Response: &csipbv1.GetPluginCapabilitiesResponse{ + Capabilities: []*csipbv1.PluginCapability{ + { + Type: &csipbv1.PluginCapability_Service_{ + Service: &csipbv1.PluginCapability_Service{ + Type: csipbv1.PluginCapability_Service_CONTROLLER_SERVICE, + }, + }, + }, + }, + }, + ExpectedResponse: &PluginCapabilitySet{hasControllerService: true}, + }, + { + Name: "HasTopologies is true when it's part of the response", + Response: &csipbv1.GetPluginCapabilitiesResponse{ + Capabilities: []*csipbv1.PluginCapability{ + { + Type: &csipbv1.PluginCapability_Service_{ + Service: &csipbv1.PluginCapability_Service{ + Type: csipbv1.PluginCapability_Service_VOLUME_ACCESSIBILITY_CONSTRAINTS, + }, + }, + }, + }, + }, + ExpectedResponse: &PluginCapabilitySet{hasTopologies: true}, + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + ic, client := newTestClient() + 
defer client.Close() + + ic.NextErr = c.ResponseErr + ic.NextPluginCapabilities = c.Response + + resp, err := client.PluginGetCapabilities(context.TODO()) + if c.ExpectedErr != nil { + require.Error(t, c.ExpectedErr, err) + } + + require.Equal(t, c.ExpectedResponse, resp) + }) + } + +} diff --git a/plugins/csi/fake/client.go b/plugins/csi/fake/client.go new file mode 100644 index 000000000..dc8477363 --- /dev/null +++ b/plugins/csi/fake/client.go @@ -0,0 +1,112 @@ +// fake is a package that includes fake implementations of public interfaces +// from the CSI package for testing. +package fake + +import ( + "context" + "errors" + "sync" + + "github.com/hashicorp/nomad/plugins/base" + "github.com/hashicorp/nomad/plugins/csi" + "github.com/hashicorp/nomad/plugins/shared/hclspec" +) + +var _ csi.CSIPlugin = &Client{} + +// Client is a mock implementation of the csi.CSIPlugin interface for use in testing +// external components +type Client struct { + Mu sync.RWMutex + + NextPluginInfoResponse *base.PluginInfoResponse + NextPluginInfoErr error + PluginInfoCallCount int64 + + NextPluginProbeResponse bool + NextPluginProbeErr error + PluginProbeCallCount int64 + + NextPluginGetInfoResponse string + NextPluginGetInfoErr error + PluginGetInfoCallCount int64 + + NextPluginGetCapabilitiesResponse *csi.PluginCapabilitySet + NextPluginGetCapabilitiesErr error + PluginGetCapabilitiesCallCount int64 + + NextNodeGetInfoResponse *csi.NodeGetInfoResponse + NextNodeGetInfoErr error + NodeGetInfoCallCount int64 +} + +// PluginInfo describes the type and version of a plugin. +func (c *Client) PluginInfo() (*base.PluginInfoResponse, error) { + c.Mu.Lock() + defer c.Mu.Unlock() + + c.PluginInfoCallCount++ + + return c.NextPluginInfoResponse, c.NextPluginInfoErr +} + +// ConfigSchema returns the schema for parsing the plugins configuration. +func (c *Client) ConfigSchema() (*hclspec.Spec, error) { + return nil, errors.New("Unsupported") +} + +// SetConfig is used to set the configuration by passing a MessagePack +// encoding of it. +func (c *Client) SetConfig(a *base.Config) error { + return errors.New("Unsupported") +} + +// PluginProbe is used to verify that the plugin is in a healthy state +func (c *Client) PluginProbe(ctx context.Context) (bool, error) { + c.Mu.Lock() + defer c.Mu.Unlock() + + c.PluginProbeCallCount++ + + return c.NextPluginProbeResponse, c.NextPluginProbeErr +} + +// PluginGetInfo is used to return semantic data about the plugin. +// Response: +// - string: name, the name of the plugin in domain notation format. +func (c *Client) PluginGetInfo(ctx context.Context) (string, error) { + c.Mu.Lock() + defer c.Mu.Unlock() + + c.PluginGetInfoCallCount++ + + return c.NextPluginGetInfoResponse, c.NextPluginGetInfoErr +} + +// PluginGetCapabilities is used to return the available capabilities from the +// identity service. This currently only looks for the CONTROLLER_SERVICE and +// Accessible Topology Support +func (c *Client) PluginGetCapabilities(ctx context.Context) (*csi.PluginCapabilitySet, error) { + c.Mu.Lock() + defer c.Mu.Unlock() + + c.PluginGetCapabilitiesCallCount++ + + return c.NextPluginGetCapabilitiesResponse, c.NextPluginGetCapabilitiesErr +} + +// NodeGetInfo is used to return semantic data about the current node in +// respect to the SP. 
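+// The fake returns the canned NextNodeGetInfoResponse/NextNodeGetInfoErr and
+// records the call in NodeGetInfoCallCount.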
+func (c *Client) NodeGetInfo(ctx context.Context) (*csi.NodeGetInfoResponse, error) {
+	c.Mu.Lock()
+	defer c.Mu.Unlock()
+
+	c.NodeGetInfoCallCount++
+
+	return c.NextNodeGetInfoResponse, c.NextNodeGetInfoErr
+}
+
+// Shutdown the client and ensure any connections are cleaned up.
+func (c *Client) Close() error {
+	return nil
+}
diff --git a/plugins/csi/plugin.go b/plugins/csi/plugin.go
new file mode 100644
index 000000000..646c0b5f9
--- /dev/null
+++ b/plugins/csi/plugin.go
@@ -0,0 +1,85 @@
+package csi
+
+import (
+	"context"
+
+	csipbv1 "github.com/container-storage-interface/spec/lib/go/csi"
+	"github.com/hashicorp/nomad/plugins/base"
+)
+
+// CSIPlugin implements a lightweight abstraction layer around a CSI Plugin.
+// It validates that responses from storage providers (SPs) conform to the
+// specification before returning response data or an error.
+type CSIPlugin interface {
+	base.BasePlugin
+
+	// PluginProbe is used to verify that the plugin is in a healthy state
+	PluginProbe(ctx context.Context) (bool, error)
+
+	// PluginGetInfo is used to return semantic data about the plugin.
+	// Response:
+	//  - string: name, the name of the plugin in domain notation format.
+	PluginGetInfo(ctx context.Context) (string, error)
+
+	// PluginGetCapabilities is used to return the available capabilities from the
+	// identity service. This currently only looks for the CONTROLLER_SERVICE and
+	// Accessible Topology Support
+	PluginGetCapabilities(ctx context.Context) (*PluginCapabilitySet, error)
+
+	// NodeGetInfo is used to return semantic data about the current node in
+	// respect to the SP.
+	NodeGetInfo(ctx context.Context) (*NodeGetInfoResponse, error)
+
+	// Shutdown the client and ensure any connections are cleaned up.
+	Close() error
+}
+
+type PluginCapabilitySet struct {
+	hasControllerService bool
+	hasTopologies        bool
+}
+
+func (p *PluginCapabilitySet) HasControllerService() bool {
+	return p.hasControllerService
+}
+
+// HasTopologies indicates whether the plugin reports topology (volume
+// accessibility) constraints, meaning its volumes may not be equally
+// accessible from every node in the cluster.
+// If true, we MUST use the topology information when scheduling workloads.
+func (p *PluginCapabilitySet) HasToplogies() bool { + return p.hasTopologies +} + +func (p *PluginCapabilitySet) IsEqual(o *PluginCapabilitySet) bool { + return p.hasControllerService == o.hasControllerService && p.hasTopologies == o.hasTopologies +} + +func NewTestPluginCapabilitySet(topologies, controller bool) *PluginCapabilitySet { + return &PluginCapabilitySet{ + hasTopologies: topologies, + hasControllerService: controller, + } +} + +func NewPluginCapabilitySet(capabilities *csipbv1.GetPluginCapabilitiesResponse) *PluginCapabilitySet { + cs := &PluginCapabilitySet{} + + pluginCapabilities := capabilities.GetCapabilities() + + for _, pcap := range pluginCapabilities { + if svcCap := pcap.GetService(); svcCap != nil { + switch svcCap.Type { + case csipbv1.PluginCapability_Service_UNKNOWN: + continue + case csipbv1.PluginCapability_Service_CONTROLLER_SERVICE: + cs.hasControllerService = true + case csipbv1.PluginCapability_Service_VOLUME_ACCESSIBILITY_CONSTRAINTS: + cs.hasTopologies = true + default: + continue + } + } + } + + return cs +} diff --git a/plugins/csi/testing/client.go b/plugins/csi/testing/client.go new file mode 100644 index 000000000..a84841be4 --- /dev/null +++ b/plugins/csi/testing/client.go @@ -0,0 +1,43 @@ +package testing + +import ( + "context" + + csipbv1 "github.com/container-storage-interface/spec/lib/go/csi" + "google.golang.org/grpc" +) + +// IdentityClient is a CSI identity client used for testing +type IdentityClient struct { + NextErr error + NextPluginInfo *csipbv1.GetPluginInfoResponse + NextPluginCapabilities *csipbv1.GetPluginCapabilitiesResponse + NextPluginProbe *csipbv1.ProbeResponse +} + +// NewIdentityClient returns a new IdentityClient +func NewIdentityClient() *IdentityClient { + return &IdentityClient{} +} + +func (f *IdentityClient) Reset() { + f.NextErr = nil + f.NextPluginInfo = nil + f.NextPluginCapabilities = nil + f.NextPluginProbe = nil +} + +// GetPluginInfo returns plugin info +func (f *IdentityClient) GetPluginInfo(ctx context.Context, in *csipbv1.GetPluginInfoRequest, opts ...grpc.CallOption) (*csipbv1.GetPluginInfoResponse, error) { + return f.NextPluginInfo, f.NextErr +} + +// GetPluginCapabilities implements csi method +func (f *IdentityClient) GetPluginCapabilities(ctx context.Context, in *csipbv1.GetPluginCapabilitiesRequest, opts ...grpc.CallOption) (*csipbv1.GetPluginCapabilitiesResponse, error) { + return f.NextPluginCapabilities, f.NextErr +} + +// Probe implements csi method +func (f *IdentityClient) Probe(ctx context.Context, in *csipbv1.ProbeRequest, opts ...grpc.CallOption) (*csipbv1.ProbeResponse, error) { + return f.NextPluginProbe, f.NextErr +}
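Below is a minimal, illustrative sketch of how a caller might fingerprint a plugin over its socket using the plugins/csi API introduced above. The fingerprintOnce helper, the socket path construction, and the timeout are assumptions for illustration only and are not part of this change.

package main

import (
	"context"
	"fmt"
	"log"
	"path/filepath"
	"time"

	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/nomad/plugins/csi"
)

// fingerprintOnce dials the plugin socket under mountDir (the csi_plugin
// mount_dir, default "/csi") and performs the identity and node RPCs exposed
// by the csi package.
func fingerprintOnce(mountDir string) error {
	// Plugins are expected to create CSISocketName inside their mount dir.
	socket := filepath.Join(mountDir, structs.CSISocketName) // e.g. /csi/csi.sock

	plugin, err := csi.NewClient(socket)
	if err != nil {
		return err
	}
	defer plugin.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	ready, err := plugin.PluginProbe(ctx)
	if err != nil {
		return err
	}
	if !ready {
		return fmt.Errorf("plugin is not ready")
	}

	name, err := plugin.PluginGetInfo(ctx)
	if err != nil {
		return err
	}

	caps, err := plugin.PluginGetCapabilities(ctx)
	if err != nil {
		return err
	}

	// NodeGetInfo is only meaningful for node and monolith plugins.
	nodeInfo, err := plugin.NodeGetInfo(ctx)
	if err != nil {
		return err
	}

	fmt.Printf("plugin=%s controller=%v node_id=%s max_volumes=%d\n",
		name, caps.HasControllerService(), nodeInfo.NodeID, nodeInfo.MaxVolumes)
	return nil
}

func main() {
	if err := fingerprintOnce("/csi"); err != nil {
		log.Fatal(err)
	}
}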